jsoncomm: transparently handle huge messages via fds

The existing jsoncomm is a work of beauty. For very big arguments,
however, the `SOCK_SEQPACKET` socket it uses hits the limitations of
the kernel network buffer size (see also [0]). This led to various
workarounds in #824, #1331, #1836 where parts of the request were
encoded as part of the JSON method call and parts were passed via
an fd-passing side channel.

This commit changes the code so that the fd channel is automatically
and transparently created and the workarounds are removed. A test
is added that ensures that very big messages can be passed.

[0] https://github.com/osbuild/osbuild/pull/1833
This commit is contained in:
Michael Vogt 2024-08-07 16:49:30 +02:00 committed by Achilleas Koutsou
parent d67fa48c17
commit 0abdfb9041
6 changed files with 130 additions and 87 deletions

View file

@ -17,17 +17,15 @@ osbuild is the path. The input options are just passed to the
"""
import abc
import contextlib
import hashlib
import json
import os
import tempfile
from typing import Any, Dict, Optional, Tuple
from osbuild import host
from osbuild.util.types import PathLike
from .objectstore import ObjectStore, StoreClient, StoreServer
from .objectstore import StoreClient, StoreServer
class Input:
@ -67,7 +65,7 @@ class InputManager:
self.root = root
self.inputs: Dict[str, Input] = {}
def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]:
def map(self, ip: Input) -> Tuple[str, Dict]:
target = os.path.join(self.root, ip.name)
os.makedirs(target)
@ -88,12 +86,8 @@ class InputManager:
}
}
with make_args_and_reply_files(store.tmp, args) as (fd_args, fd_reply):
fds = [fd_args, fd_reply]
client = self.service_manager.start(f"input/{ip.name}", ip.info.path)
_, _ = client.call_with_fds("map", {}, fds)
with os.fdopen(os.dup(fd_reply)) as f:
reply = json.loads(f.read())
reply = client.call("map", args)
path = reply["path"]
@ -107,15 +101,6 @@ class InputManager:
return reply
@contextlib.contextmanager
def make_args_and_reply_files(tmp, args):
    """Yield a pair ``(args_fd, reply_fd)`` of temporary file descriptors.

    The first descriptor refers to a file holding *args* serialized as
    JSON, rewound to the start so the callee can read it directly; the
    second refers to an empty file the callee can write its reply into.
    Both files live in *tmp* and are removed when the context exits.
    """
    with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as args_file:
        with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as reply_file:
            args_file.write(json.dumps(args))
            args_file.seek(0)
            yield args_file.fileno(), reply_file.fileno()
class InputService(host.Service):
"""Input host service"""
@ -129,21 +114,14 @@ class InputService(host.Service):
def stop(self):
self.unmap()
def dispatch(self, method: str, _, fds):
def dispatch(self, method: str, args, fds):
if method == "map":
# map() sends fd[0] to read the arguments from and fd[1] to
# write the reply back. This avoids running into EMSGSIZE
with os.fdopen(fds.steal(0)) as f:
args = json.load(f)
store = StoreClient(connect_to=args["api"]["store"])
r = self.map(store,
args["origin"],
args["refs"],
args["target"],
args["options"])
with os.fdopen(fds.steal(1), "w") as f:
f.write(json.dumps(r))
f.seek(0)
return "{}", None
return r, None
raise host.ProtocolError("Unknown method")

View file

@ -210,7 +210,7 @@ class Stage:
ipmgr = InputManager(mgr, storeapi, inputs_tmpdir)
for key, ip in self.inputs.items():
data = ipmgr.map(ip, store)
data = ipmgr.map(ip)
inputs[key] = data
devmgr = DeviceManager(mgr, build_root.dev, tree)

View file

@ -1,5 +1,4 @@
import abc
import contextlib
import hashlib
import json
import os
@ -29,6 +28,7 @@ class Source:
cache = os.path.join(store.store, "sources")
args = {
"items": self.items,
"options": self.options,
"cache": cache,
"output": None,
@ -36,20 +36,10 @@ class Source:
}
client = mgr.start(f"source/{source}", self.info.path)
with self.make_items_file(store.tmp) as fd:
fds = [fd]
reply = client.call_with_fds("download", args, fds)
reply = client.call("download", args)
return reply
@contextlib.contextmanager
def make_items_file(self, tmp):
    """Yield the fd of a temporary file containing ``self.items`` as JSON.

    The file lives in *tmp*, is rewound to the start before the fd is
    yielded, and is removed again when the context exits.
    """
    with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as items_file:
        items_file.write(json.dumps(self.items))
        items_file.seek(0)
        yield items_file.fileno()
# "name", "id", "stages", "results" is only here to make it looks like a
# pipeline for the monitor. This should be revisited at some point
# and maybe the monitor should get first-class support for
@ -103,12 +93,6 @@ class SourceService(host.Service):
"""Returns True if the item to download is in cache. """
return os.path.isfile(f"{self.cache}/{checksum}")
@staticmethod
def load_items(fds):
    """Read and deserialize the JSON item list from the first fd in *fds*.

    Steals fd 0 from the set, so the descriptor is owned (and closed)
    by this function.
    """
    with os.fdopen(fds.steal(0)) as items_file:
        return json.load(items_file)
def setup(self, args):
self.cache = os.path.join(args["cache"], self.content_type)
os.makedirs(self.cache, exist_ok=True)
@ -118,7 +102,7 @@ class SourceService(host.Service):
if method == "download":
self.setup(args)
with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir:
self.fetch_all(SourceService.load_items(fds))
self.fetch_all(args["items"])
return None, None
raise host.ProtocolError("Unknown method")

View file

@ -5,18 +5,31 @@ serialization. It uses unix-domain-datagram-sockets and provides a simple
unicast message transmission.
"""
import array
import contextlib
import errno
import json
import os
import socket
from typing import Any, Optional
from typing import Any, List, Optional
from .types import PathLike
@contextlib.contextmanager
def memfd(name):
fd = os.memfd_create(name, 0)
try:
yield fd
finally:
os.close(fd)
# This marker is used when the arguments are passed via a file descriptor
# because they exceed the allowed size for a network packet.
ARGS_VIA_FD_MARKER = b"<args-via-fd>"
class FdSet:
"""File-Descriptor Set
@ -92,6 +105,19 @@ class FdSet:
return v
def wmem_max() -> int:
    """Return the kernel's maximum send socket buffer size in bytes.

    Falls back to a conservative estimate (64 kB) when the value cannot
    be determined: /proc not mounted, the file unreadable (e.g. inside
    a restricted sandbox), or its content unexpectedly malformed.
    """
    try:
        with open("/proc/sys/net/core/wmem_max", encoding="utf8") as wmem_file:
            return int(wmem_file.read().strip())
    except (OSError, ValueError):
        # Previously only FileNotFoundError was handled; a permission
        # error or garbled content would have crashed the caller instead
        # of degrading to the conservative default.
        return 64_000
class Socket(contextlib.AbstractContextManager):
"""Communication Socket
@ -353,7 +379,17 @@ class Socket(contextlib.AbstractContextManager):
if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
assert len(data) % fds.itemsize == 0
fds.frombytes(data)
# Next we need to check if the serialized data comes via an FD
# or via the message. FDs are used if the data size is big, to
# avoid running into errno.EMSGSIZE.
if msg[0] == ARGS_VIA_FD_MARKER:
fd_payload = fds[0]
fdset = FdSet(rawfds=fds[1:])
with os.fdopen(fd_payload) as f:
serialized = f.read()
else:
fdset = FdSet(rawfds=fds)
serialized = msg[0]
# Check the returned message flags. If the message was truncated, we
# have to discard it. This shouldn't happen, but there is no harm in
@ -364,13 +400,38 @@ class Socket(contextlib.AbstractContextManager):
raise BufferError
try:
payload = json.loads(msg[0])
payload = json.loads(serialized)
except json.JSONDecodeError as e:
raise BufferError from e
return (payload, fdset, msg[3])
def send(self, payload: object, *, fds: Optional[list] = None):
def _send_via_fd(self, serialized: bytes, fds: List[int]):
    """Send *serialized* out-of-band through an anonymous memfd.

    Only the short ARGS_VIA_FD_MARKER travels in the datagram itself;
    the real payload is written to a memfd that is passed (ahead of any
    caller-supplied *fds*) via SCM_RIGHTS. This sidesteps EMSGSIZE for
    payloads bigger than the socket buffer allows.
    """
    assert self._socket is not None
    with memfd("jsoncomm/payload") as payload_fd:
        os.write(payload_fd, serialized)
        os.lseek(payload_fd, 0, 0)
        rights = array.array("i", [payload_fd] + fds)
        ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, rights)]
        sent = self._socket.sendmsg([ARGS_VIA_FD_MARKER], ancillary, 0)
        assert sent == len(ARGS_VIA_FD_MARKER)
def _send_via_sendmsg(self, serialized: bytes, fds: List[int]):
assert self._socket is not None
cmsg = []
if fds:
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
try:
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, len(serialized))
n = self._socket.sendmsg([serialized], cmsg, 0)
except OSError as exc:
if exc.errno == errno.EMSGSIZE:
raise BufferError(
f"jsoncomm message size {len(serialized)} is too big") from exc
raise exc
assert n == len(serialized)
def send(self, payload: object, *, fds: Optional[list] = None) -> None:
"""Send Message
Send a new message via this socket. This operation is synchronous. The
@ -399,19 +460,14 @@ class Socket(contextlib.AbstractContextManager):
if not self._socket:
raise RuntimeError("Tried to send without socket.")
serialized = json.dumps(payload).encode()
cmsg = []
if fds:
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
if not fds:
fds = []
try:
n = self._socket.sendmsg([serialized], cmsg, 0)
except OSError as exc:
if exc.errno == errno.EMSGSIZE:
raise BufferError(
f"jsoncomm message size {len(serialized)} is too big") from exc
raise exc
assert n == len(serialized)
serialized = json.dumps(payload).encode()
if len(serialized) > wmem_max():
self._send_via_fd(serialized, fds)
else:
self._send_via_sendmsg(serialized, fds)
def send_and_recv(self, payload: object, *, fds: Optional[list] = None):
"""Send a message and wait for a reply

View file

@ -1,9 +1,7 @@
import json
import os
from unittest.mock import patch
from unittest.mock import call, patch
from osbuild import inputs
from osbuild.util.jsoncomm import FdSet
class FakeInputService(inputs.InputService):
@ -11,7 +9,7 @@ class FakeInputService(inputs.InputService):
# do not call "super().__init__()" here to make it testable
self.map_calls = []
def map(self, _store, origin, refs, target, options):
def map(self, store, origin, refs, target, options):
self.map_calls.append([origin, refs, target, options])
return "complex", 2, "reply"
@ -20,8 +18,6 @@ def test_inputs_dispatches_map(tmp_path):
store_api_path = tmp_path / "api-store"
store_api_path.write_text("")
args_path = tmp_path / "args"
reply_path = tmp_path / "reply"
args = {
"api": {
"store": os.fspath(store_api_path),
@ -31,17 +27,14 @@ def test_inputs_dispatches_map(tmp_path):
"target": "some-target",
"options": "some-options",
}
args_path.write_text(json.dumps(args))
reply_path.write_text("")
with args_path.open() as f_args, reply_path.open("w") as f_reply:
fd_args, fd_reply = os.dup(f_args.fileno()), os.dup(f_reply.fileno())
fds = FdSet.from_list([fd_args, fd_reply])
fake_service = FakeInputService(args="some")
with patch.object(inputs, "StoreClient"):
r = fake_service.dispatch("map", None, fds)
assert r == ('{}', None)
with patch.object(inputs, "StoreClient") as mocked_store_client_klass:
r = fake_service.dispatch("map", args, None)
assert mocked_store_client_klass.call_args_list == [
call(connect_to=os.fspath(store_api_path)),
]
assert fake_service.map_calls == [
["some-origin", "some-refs", "some-target", "some-options"],
]
assert reply_path.read_text() == '["complex", 2, "reply"]'
assert r == (("complex", 2, "reply"), None)

View file

@ -2,13 +2,17 @@
# Tests for the 'osbuild.util.jsoncomm' module.
#
# pylint: disable=protected-access
import asyncio
import errno
import json
import os
import pathlib
import tempfile
import unittest
from concurrent import futures
from unittest.mock import patch
import pytest
@ -220,11 +224,39 @@ class TestUtilJsonComm(unittest.TestCase):
pong, _, _ = a.recv()
self.assertEqual(ping, pong)
def test_send_and_recv_tons_of_data_still_errors(self):
def test_sendmsg_errors_with_size_on_EMSGSIZE(self):
a, _ = jsoncomm.Socket.new_pair()
ping = {"data": "1" * 1_000_000}
serialized = json.dumps({"data": "1" * 1_000_000}).encode()
with pytest.raises(BufferError) as exc:
a.send(ping)
a._send_via_sendmsg(serialized, [])
assert str(exc.value) == "jsoncomm message size 1000012 is too big"
assert exc.value.__cause__.errno == errno.EMSGSIZE
# ~4 MB of JSON payload: far beyond a SOCK_SEQPACKET datagram limit,
# so this exercises the transparent fd-passing fallback end to end.
def test_send_and_recv_tons_of_data_is_fine(self):
a, b = jsoncomm.Socket.new_pair()
ping = {"data": "tons" * 1_000_000}
# huge message a -> b
a.send(ping)
# b sends its own copy and waits for a reply, picking up a's message
pong, _, _ = b.send_and_recv(ping)
self.assertEqual(ping, pong)
# a in turn receives b's copy
pong, _, _ = a.recv()
self.assertEqual(ping, pong)
def test_send_small_data_via_sendmsg(self):
    """A small payload stays inline and never takes the fd path."""
    sock, _ = jsoncomm.Socket.new_pair()
    with patch.object(sock, "_send_via_fd") as via_fd, \
            patch.object(sock, "_send_via_sendmsg") as via_sendmsg:
        sock.send({"data": "little"})
    assert via_sendmsg.call_count == 1
    assert via_fd.call_count == 0
def test_send_huge_data_via_fd(self):
    """A payload above wmem_max must be routed through the memfd path."""
    sock, _ = jsoncomm.Socket.new_pair()
    with patch.object(sock, "_send_via_fd") as via_fd, \
            patch.object(sock, "_send_via_sendmsg") as via_sendmsg:
        sock.send({"data": "tons" * 1_000_000})
    assert via_fd.call_count == 1
    assert via_sendmsg.call_count == 0