jsoncomm: transparently handle huge messages via fds

The existing jsoncomm is a work of beauty. For very big arguments
however the used `SOCK_SEQPACKET` hits the limitations of the
kernel network buffer size (see also [0]). This led to various
workarounds in #824, #1331, #1836 where parts of the request are
encoded as part of the json method call and parts are passed
through a side-channel using fd-passing.

This commit changes the code so that the fd channel is automatically
and transparently created and the workarounds are removed. A test
is added that ensures that very big messages can be passed.

[0] https://github.com/osbuild/osbuild/pull/1833
This commit is contained in:
Michael Vogt 2024-08-07 16:49:30 +02:00 committed by Achilleas Koutsou
parent d67fa48c17
commit 0abdfb9041
6 changed files with 130 additions and 87 deletions

View file

@ -17,17 +17,15 @@ osbuild is the path. The input options are just passed to the
""" """
import abc import abc
import contextlib
import hashlib import hashlib
import json import json
import os import os
import tempfile
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
from osbuild import host from osbuild import host
from osbuild.util.types import PathLike from osbuild.util.types import PathLike
from .objectstore import ObjectStore, StoreClient, StoreServer from .objectstore import StoreClient, StoreServer
class Input: class Input:
@ -67,7 +65,7 @@ class InputManager:
self.root = root self.root = root
self.inputs: Dict[str, Input] = {} self.inputs: Dict[str, Input] = {}
def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]: def map(self, ip: Input) -> Tuple[str, Dict]:
target = os.path.join(self.root, ip.name) target = os.path.join(self.root, ip.name)
os.makedirs(target) os.makedirs(target)
@ -88,12 +86,8 @@ class InputManager:
} }
} }
with make_args_and_reply_files(store.tmp, args) as (fd_args, fd_reply): client = self.service_manager.start(f"input/{ip.name}", ip.info.path)
fds = [fd_args, fd_reply] reply = client.call("map", args)
client = self.service_manager.start(f"input/{ip.name}", ip.info.path)
_, _ = client.call_with_fds("map", {}, fds)
with os.fdopen(os.dup(fd_reply)) as f:
reply = json.loads(f.read())
path = reply["path"] path = reply["path"]
@ -107,15 +101,6 @@ class InputManager:
return reply return reply
@contextlib.contextmanager
def make_args_and_reply_files(tmp, args):
with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f_args, \
tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f_reply:
json.dump(args, f_args)
f_args.seek(0)
yield f_args.fileno(), f_reply.fileno()
class InputService(host.Service): class InputService(host.Service):
"""Input host service""" """Input host service"""
@ -129,21 +114,14 @@ class InputService(host.Service):
def stop(self): def stop(self):
self.unmap() self.unmap()
def dispatch(self, method: str, _, fds): def dispatch(self, method: str, args, fds):
if method == "map": if method == "map":
# map() sends fd[0] to read the arguments from and fd[1] to
# write the reply back. This avoids running into EMSGSIZE
with os.fdopen(fds.steal(0)) as f:
args = json.load(f)
store = StoreClient(connect_to=args["api"]["store"]) store = StoreClient(connect_to=args["api"]["store"])
r = self.map(store, r = self.map(store,
args["origin"], args["origin"],
args["refs"], args["refs"],
args["target"], args["target"],
args["options"]) args["options"])
with os.fdopen(fds.steal(1), "w") as f: return r, None
f.write(json.dumps(r))
f.seek(0)
return "{}", None
raise host.ProtocolError("Unknown method") raise host.ProtocolError("Unknown method")

View file

@ -210,7 +210,7 @@ class Stage:
ipmgr = InputManager(mgr, storeapi, inputs_tmpdir) ipmgr = InputManager(mgr, storeapi, inputs_tmpdir)
for key, ip in self.inputs.items(): for key, ip in self.inputs.items():
data = ipmgr.map(ip, store) data = ipmgr.map(ip)
inputs[key] = data inputs[key] = data
devmgr = DeviceManager(mgr, build_root.dev, tree) devmgr = DeviceManager(mgr, build_root.dev, tree)

View file

@ -1,5 +1,4 @@
import abc import abc
import contextlib
import hashlib import hashlib
import json import json
import os import os
@ -29,6 +28,7 @@ class Source:
cache = os.path.join(store.store, "sources") cache = os.path.join(store.store, "sources")
args = { args = {
"items": self.items,
"options": self.options, "options": self.options,
"cache": cache, "cache": cache,
"output": None, "output": None,
@ -36,20 +36,10 @@ class Source:
} }
client = mgr.start(f"source/{source}", self.info.path) client = mgr.start(f"source/{source}", self.info.path)
reply = client.call("download", args)
with self.make_items_file(store.tmp) as fd:
fds = [fd]
reply = client.call_with_fds("download", args, fds)
return reply return reply
@contextlib.contextmanager
def make_items_file(self, tmp):
with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f:
json.dump(self.items, f)
f.seek(0)
yield f.fileno()
# "name", "id", "stages", "results" is only here to make it look like a # pipeline for the monitor. This should be revisited at some point
# pipeline for the monitor. This should be revisited at some point # pipeline for the monitor. This should be revisited at some point
# and maybe the monitor should get first-class support for # and maybe the monitor should get first-class support for
@ -103,12 +93,6 @@ class SourceService(host.Service):
"""Returns True if the item to download is in cache. """ """Returns True if the item to download is in cache. """
return os.path.isfile(f"{self.cache}/{checksum}") return os.path.isfile(f"{self.cache}/{checksum}")
@staticmethod
def load_items(fds):
with os.fdopen(fds.steal(0)) as f:
items = json.load(f)
return items
def setup(self, args): def setup(self, args):
self.cache = os.path.join(args["cache"], self.content_type) self.cache = os.path.join(args["cache"], self.content_type)
os.makedirs(self.cache, exist_ok=True) os.makedirs(self.cache, exist_ok=True)
@ -118,7 +102,7 @@ class SourceService(host.Service):
if method == "download": if method == "download":
self.setup(args) self.setup(args)
with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir: with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir:
self.fetch_all(SourceService.load_items(fds)) self.fetch_all(args["items"])
return None, None return None, None
raise host.ProtocolError("Unknown method") raise host.ProtocolError("Unknown method")

View file

@ -5,18 +5,31 @@ serialization. It uses unix-domain-datagram-sockets and provides a simple
unicast message transmission. unicast message transmission.
""" """
import array import array
import contextlib import contextlib
import errno import errno
import json import json
import os import os
import socket import socket
from typing import Any, Optional from typing import Any, List, Optional
from .types import PathLike from .types import PathLike
@contextlib.contextmanager
def memfd(name):
fd = os.memfd_create(name, 0)
try:
yield fd
finally:
os.close(fd)
# this marker is used when the arguments are passed via a filedescriptor
# because they exceed the allowed size for a network packet
ARGS_VIA_FD_MARKER = b"<args-via-fd>"
class FdSet: class FdSet:
"""File-Descriptor Set """File-Descriptor Set
@ -92,6 +105,19 @@ class FdSet:
return v return v
def wmem_max() -> int:
""" Return the kernel's maximum send socket buffer size in bytes
When /proc is not mounted return a conservative estimate (64kb).
"""
try:
with open("/proc/sys/net/core/wmem_max", encoding="utf8") as wmem_file:
return int(wmem_file.read().strip())
except FileNotFoundError:
# conservative estimate for systems that have no /proc mounted
return 64_000
class Socket(contextlib.AbstractContextManager): class Socket(contextlib.AbstractContextManager):
"""Communication Socket """Communication Socket
@ -353,7 +379,17 @@ class Socket(contextlib.AbstractContextManager):
if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS: if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
assert len(data) % fds.itemsize == 0 assert len(data) % fds.itemsize == 0
fds.frombytes(data) fds.frombytes(data)
fdset = FdSet(rawfds=fds) # Next we need to check if the serialized data comes via an FD
# or via the message. FDs are used if the data size is big to
# avoid running into errno.EMSGSIZE
if msg[0] == ARGS_VIA_FD_MARKER:
fd_payload = fds[0]
fdset = FdSet(rawfds=fds[1:])
with os.fdopen(fd_payload) as f:
serialized = f.read()
else:
fdset = FdSet(rawfds=fds)
serialized = msg[0]
# Check the returned message flags. If the message was truncated, we # Check the returned message flags. If the message was truncated, we
# have to discard it. This shouldn't happen, but there is no harm in # have to discard it. This shouldn't happen, but there is no harm in
@ -364,13 +400,38 @@ class Socket(contextlib.AbstractContextManager):
raise BufferError raise BufferError
try: try:
payload = json.loads(msg[0]) payload = json.loads(serialized)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise BufferError from e raise BufferError from e
return (payload, fdset, msg[3]) return (payload, fdset, msg[3])
def send(self, payload: object, *, fds: Optional[list] = None): def _send_via_fd(self, serialized: bytes, fds: List[int]):
assert self._socket is not None
with memfd("jsoncomm/payload") as fd_payload:
os.write(fd_payload, serialized)
os.lseek(fd_payload, 0, 0)
cmsg = []
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd_payload] + fds)))
n = self._socket.sendmsg([ARGS_VIA_FD_MARKER], cmsg, 0)
assert n == len(ARGS_VIA_FD_MARKER)
def _send_via_sendmsg(self, serialized: bytes, fds: List[int]):
assert self._socket is not None
cmsg = []
if fds:
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
try:
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, len(serialized))
n = self._socket.sendmsg([serialized], cmsg, 0)
except OSError as exc:
if exc.errno == errno.EMSGSIZE:
raise BufferError(
f"jsoncomm message size {len(serialized)} is too big") from exc
raise exc
assert n == len(serialized)
def send(self, payload: object, *, fds: Optional[list] = None) -> None:
"""Send Message """Send Message
Send a new message via this socket. This operation is synchronous. The Send a new message via this socket. This operation is synchronous. The
@ -399,19 +460,14 @@ class Socket(contextlib.AbstractContextManager):
if not self._socket: if not self._socket:
raise RuntimeError("Tried to send without socket.") raise RuntimeError("Tried to send without socket.")
serialized = json.dumps(payload).encode() if not fds:
cmsg = [] fds = []
if fds:
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
try: serialized = json.dumps(payload).encode()
n = self._socket.sendmsg([serialized], cmsg, 0) if len(serialized) > wmem_max():
except OSError as exc: self._send_via_fd(serialized, fds)
if exc.errno == errno.EMSGSIZE: else:
raise BufferError( self._send_via_sendmsg(serialized, fds)
f"jsoncomm message size {len(serialized)} is too big") from exc
raise exc
assert n == len(serialized)
def send_and_recv(self, payload: object, *, fds: Optional[list] = None): def send_and_recv(self, payload: object, *, fds: Optional[list] = None):
"""Send a message and wait for a reply """Send a message and wait for a reply

View file

@ -1,9 +1,7 @@
import json
import os import os
from unittest.mock import patch from unittest.mock import call, patch
from osbuild import inputs from osbuild import inputs
from osbuild.util.jsoncomm import FdSet
class FakeInputService(inputs.InputService): class FakeInputService(inputs.InputService):
@ -11,7 +9,7 @@ class FakeInputService(inputs.InputService):
# do not call "super().__init__()" here to make it testable # do not call "super().__init__()" here to make it testable
self.map_calls = [] self.map_calls = []
def map(self, _store, origin, refs, target, options): def map(self, store, origin, refs, target, options):
self.map_calls.append([origin, refs, target, options]) self.map_calls.append([origin, refs, target, options])
return "complex", 2, "reply" return "complex", 2, "reply"
@ -20,8 +18,6 @@ def test_inputs_dispatches_map(tmp_path):
store_api_path = tmp_path / "api-store" store_api_path = tmp_path / "api-store"
store_api_path.write_text("") store_api_path.write_text("")
args_path = tmp_path / "args"
reply_path = tmp_path / "reply"
args = { args = {
"api": { "api": {
"store": os.fspath(store_api_path), "store": os.fspath(store_api_path),
@ -31,17 +27,14 @@ def test_inputs_dispatches_map(tmp_path):
"target": "some-target", "target": "some-target",
"options": "some-options", "options": "some-options",
} }
args_path.write_text(json.dumps(args))
reply_path.write_text("")
with args_path.open() as f_args, reply_path.open("w") as f_reply: fake_service = FakeInputService(args="some")
fd_args, fd_reply = os.dup(f_args.fileno()), os.dup(f_reply.fileno()) with patch.object(inputs, "StoreClient") as mocked_store_client_klass:
fds = FdSet.from_list([fd_args, fd_reply]) r = fake_service.dispatch("map", args, None)
fake_service = FakeInputService(args="some") assert mocked_store_client_klass.call_args_list == [
with patch.object(inputs, "StoreClient"): call(connect_to=os.fspath(store_api_path)),
r = fake_service.dispatch("map", None, fds) ]
assert r == ('{}', None) assert fake_service.map_calls == [
assert fake_service.map_calls == [ ["some-origin", "some-refs", "some-target", "some-options"],
["some-origin", "some-refs", "some-target", "some-options"], ]
] assert r == (("complex", 2, "reply"), None)
assert reply_path.read_text() == '["complex", 2, "reply"]'

View file

@ -2,13 +2,17 @@
# Tests for the 'osbuild.util.jsoncomm' module. # Tests for the 'osbuild.util.jsoncomm' module.
# #
# pylint: disable=protected-access
import asyncio import asyncio
import errno import errno
import json
import os import os
import pathlib import pathlib
import tempfile import tempfile
import unittest import unittest
from concurrent import futures from concurrent import futures
from unittest.mock import patch
import pytest import pytest
@ -220,11 +224,39 @@ class TestUtilJsonComm(unittest.TestCase):
pong, _, _ = a.recv() pong, _, _ = a.recv()
self.assertEqual(ping, pong) self.assertEqual(ping, pong)
def test_send_and_recv_tons_of_data_still_errors(self): def test_sendmsg_errors_with_size_on_EMSGSIZE(self):
a, _ = jsoncomm.Socket.new_pair() a, _ = jsoncomm.Socket.new_pair()
ping = {"data": "1" * 1_000_000} serialized = json.dumps({"data": "1" * 1_000_000}).encode()
with pytest.raises(BufferError) as exc: with pytest.raises(BufferError) as exc:
a.send(ping) a._send_via_sendmsg(serialized, [])
assert str(exc.value) == "jsoncomm message size 1000012 is too big" assert str(exc.value) == "jsoncomm message size 1000012 is too big"
assert exc.value.__cause__.errno == errno.EMSGSIZE assert exc.value.__cause__.errno == errno.EMSGSIZE
def test_send_and_recv_tons_of_data_is_fine(self):
a, b = jsoncomm.Socket.new_pair()
ping = {"data": "tons" * 1_000_000}
a.send(ping)
pong, _, _ = b.send_and_recv(ping)
self.assertEqual(ping, pong)
pong, _, _ = a.recv()
self.assertEqual(ping, pong)
def test_send_small_data_via_sendmsg(self):
a, _ = jsoncomm.Socket.new_pair()
with patch.object(a, "_send_via_fd") as mock_send_via_fd, \
patch.object(a, "_send_via_sendmsg") as mock_send_via_sendmsg:
ping = {"data": "little"}
a.send(ping)
assert mock_send_via_fd.call_count == 0
assert mock_send_via_sendmsg.call_count == 1
def test_send_huge_data_via_fd(self):
a, _ = jsoncomm.Socket.new_pair()
with patch.object(a, "_send_via_fd") as mock_send_via_fd, \
patch.object(a, "_send_via_sendmsg") as mock_send_via_sendmsg:
ping = {"data": "tons" * 1_000_000}
a.send(ping)
assert mock_send_via_fd.call_count == 1
assert mock_send_via_sendmsg.call_count == 0