diff --git a/osbuild/inputs.py b/osbuild/inputs.py index 24c580e6..8ad682e9 100644 --- a/osbuild/inputs.py +++ b/osbuild/inputs.py @@ -17,17 +17,15 @@ osbuild is the path. The input options are just passed to the """ import abc -import contextlib import hashlib import json import os -import tempfile from typing import Any, Dict, Optional, Tuple from osbuild import host from osbuild.util.types import PathLike -from .objectstore import ObjectStore, StoreClient, StoreServer +from .objectstore import StoreClient, StoreServer class Input: @@ -67,7 +65,7 @@ class InputManager: self.root = root self.inputs: Dict[str, Input] = {} - def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]: + def map(self, ip: Input) -> Tuple[str, Dict]: target = os.path.join(self.root, ip.name) os.makedirs(target) @@ -88,12 +86,8 @@ class InputManager: } } - with make_args_and_reply_files(store.tmp, args) as (fd_args, fd_reply): - fds = [fd_args, fd_reply] - client = self.service_manager.start(f"input/{ip.name}", ip.info.path) - _, _ = client.call_with_fds("map", {}, fds) - with os.fdopen(os.dup(fd_reply)) as f: - reply = json.loads(f.read()) + client = self.service_manager.start(f"input/{ip.name}", ip.info.path) + reply = client.call("map", args) path = reply["path"] @@ -107,15 +101,6 @@ class InputManager: return reply -@contextlib.contextmanager -def make_args_and_reply_files(tmp, args): - with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f_args, \ - tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f_reply: - json.dump(args, f_args) - f_args.seek(0) - yield f_args.fileno(), f_reply.fileno() - - class InputService(host.Service): """Input host service""" @@ -129,21 +114,14 @@ class InputService(host.Service): def stop(self): self.unmap() - def dispatch(self, method: str, _, fds): + def dispatch(self, method: str, args, fds): if method == "map": - # map() sends fd[0] to read the arguments from and fd[1] to - # write the reply back. This avoids running into EMSGSIZE - with os.fdopen(fds.steal(0)) as f: - args = json.load(f) store = StoreClient(connect_to=args["api"]["store"]) r = self.map(store, args["origin"], args["refs"], args["target"], args["options"]) - with os.fdopen(fds.steal(1), "w") as f: - f.write(json.dumps(r)) - f.seek(0) - return "{}", None + return r, None raise host.ProtocolError("Unknown method") diff --git a/osbuild/pipeline.py b/osbuild/pipeline.py index 437a9001..84013740 100644 --- a/osbuild/pipeline.py +++ b/osbuild/pipeline.py @@ -210,7 +210,7 @@ class Stage: ipmgr = InputManager(mgr, storeapi, inputs_tmpdir) for key, ip in self.inputs.items(): - data = ipmgr.map(ip, store) + data = ipmgr.map(ip) inputs[key] = data devmgr = DeviceManager(mgr, build_root.dev, tree) diff --git a/osbuild/sources.py b/osbuild/sources.py index b09597e9..029dc6b8 100644 --- a/osbuild/sources.py +++ b/osbuild/sources.py @@ -1,5 +1,4 @@ import abc -import contextlib import hashlib import json import os @@ -29,6 +28,7 @@ class Source: cache = os.path.join(store.store, "sources") args = { + "items": self.items, "options": self.options, "cache": cache, "output": None, @@ -36,20 +36,10 @@ class Source: } client = mgr.start(f"source/{source}", self.info.path) - - with self.make_items_file(store.tmp) as fd: - fds = [fd] - reply = client.call_with_fds("download", args, fds) + reply = client.call("download", args) return reply - @contextlib.contextmanager - def make_items_file(self, tmp): - with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f: - json.dump(self.items, f) - f.seek(0) - yield f.fileno() - # "name", "id", "stages", "results" is only here to make it looks like a # pipeline for the monitor. This should be revisited at some point # and maybe the monitor should get first-class support for @@ -103,12 +93,6 @@ class SourceService(host.Service): """Returns True if the item to download is in cache. """ return os.path.isfile(f"{self.cache}/{checksum}") - @staticmethod - def load_items(fds): - with os.fdopen(fds.steal(0)) as f: - items = json.load(f) - return items - def setup(self, args): self.cache = os.path.join(args["cache"], self.content_type) os.makedirs(self.cache, exist_ok=True) @@ -118,7 +102,7 @@ class SourceService(host.Service): if method == "download": self.setup(args) with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir: - self.fetch_all(SourceService.load_items(fds)) + self.fetch_all(args["items"]) return None, None raise host.ProtocolError("Unknown method") diff --git a/osbuild/util/jsoncomm.py b/osbuild/util/jsoncomm.py index 8a0cbb7c..7180b24f 100644 --- a/osbuild/util/jsoncomm.py +++ b/osbuild/util/jsoncomm.py @@ -5,18 +5,31 @@ serialization. It uses unix-domain-datagram-sockets and provides a simple unicast message transmission. """ - import array import contextlib import errno import json import os import socket -from typing import Any, Optional +from typing import Any, List, Optional from .types import PathLike +@contextlib.contextmanager +def memfd(name): + fd = os.memfd_create(name, 0) + try: + yield fd + finally: + os.close(fd) + + +# this marker is used when the arguments are passed via a filedescriptor +# because they exceed the allowed size for a network package +ARGS_VIA_FD_MARKER = b"" + + class FdSet: """File-Descriptor Set @@ -92,6 +105,19 @@ class FdSet: return v +def wmem_max() -> int: + """ Return the kernels maximum send socket buffer size in bytes + + When /proc is not mounted return a conservative estimate (64kb). + """ + try: + with open("/proc/sys/net/core/wmem_max", encoding="utf8") as wmem_file: + return int(wmem_file.read().strip()) + except FileNotFoundError: + # conservative estimate for systems that have no /proc mounted + return 64_000 + + class Socket(contextlib.AbstractContextManager): """Communication Socket @@ -353,7 +379,17 @@ class Socket(contextlib.AbstractContextManager): if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS: assert len(data) % fds.itemsize == 0 fds.frombytes(data) - fdset = FdSet(rawfds=fds) + # Next we need to check if the serialzed data comes via an FD + # or via the message. FDs are used if the data size is big to + # avoid running into errno.EMSGSIZE + if msg[0] == ARGS_VIA_FD_MARKER: + fd_payload = fds[0] + fdset = FdSet(rawfds=fds[1:]) + with os.fdopen(fd_payload) as f: + serialized = f.read() + else: + fdset = FdSet(rawfds=fds) + serialized = msg[0] # Check the returned message flags. If the message was truncated, we # have to discard it. This shouldn't happen, but there is no harm in @@ -364,13 +400,38 @@ class Socket(contextlib.AbstractContextManager): raise BufferError try: - payload = json.loads(msg[0]) + payload = json.loads(serialized) except json.JSONDecodeError as e: raise BufferError from e return (payload, fdset, msg[3]) - def send(self, payload: object, *, fds: Optional[list] = None): + def _send_via_fd(self, serialized: bytes, fds: List[int]): + assert self._socket is not None + with memfd("jsoncomm/payload") as fd_payload: + os.write(fd_payload, serialized) + os.lseek(fd_payload, 0, 0) + cmsg = [] + cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd_payload] + fds))) + n = self._socket.sendmsg([ARGS_VIA_FD_MARKER], cmsg, 0) + assert n == len(ARGS_VIA_FD_MARKER) + + def _send_via_sendmsg(self, serialized: bytes, fds: List[int]): + assert self._socket is not None + cmsg = [] + if fds: + cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))) + try: + self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, len(serialized)) + n = self._socket.sendmsg([serialized], cmsg, 0) + except OSError as exc: + if exc.errno == errno.EMSGSIZE: + raise BufferError( + f"jsoncomm message size {len(serialized)} is too big") from exc + raise exc + assert n == len(serialized) + + def send(self, payload: object, *, fds: Optional[list] = None) -> None: """Send Message Send a new message via this socket. This operation is synchronous. The @@ -399,19 +460,14 @@ class Socket(contextlib.AbstractContextManager): if not self._socket: raise RuntimeError("Tried to send without socket.") - serialized = json.dumps(payload).encode() - cmsg = [] - if fds: - cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))) + if not fds: + fds = [] - try: - n = self._socket.sendmsg([serialized], cmsg, 0) - except OSError as exc: - if exc.errno == errno.EMSGSIZE: - raise BufferError( - f"jsoncomm message size {len(serialized)} is too big") from exc - raise exc - assert n == len(serialized) + serialized = json.dumps(payload).encode() + if len(serialized) > wmem_max(): + self._send_via_fd(serialized, fds) + else: + self._send_via_sendmsg(serialized, fds) def send_and_recv(self, payload: object, *, fds: Optional[list] = None): """Send a message and wait for a reply diff --git a/test/mod/test_inputs.py b/test/mod/test_inputs.py index 915d1313..fc2dc291 100644 --- a/test/mod/test_inputs.py +++ b/test/mod/test_inputs.py @@ -1,9 +1,7 @@ -import json import os -from unittest.mock import patch +from unittest.mock import call, patch from osbuild import inputs -from osbuild.util.jsoncomm import FdSet class FakeInputService(inputs.InputService): @@ -11,7 +9,7 @@ class FakeInputService(inputs.InputService): # do not call "super().__init__()" here to make it testable self.map_calls = [] - def map(self, _store, origin, refs, target, options): + def map(self, store, origin, refs, target, options): self.map_calls.append([origin, refs, target, options]) return "complex", 2, "reply" @@ -20,8 +18,6 @@ def test_inputs_dispatches_map(tmp_path): store_api_path = tmp_path / "api-store" store_api_path.write_text("") - args_path = tmp_path / "args" - reply_path = tmp_path / "reply" args = { "api": { "store": os.fspath(store_api_path), @@ -31,17 +27,14 @@ def test_inputs_dispatches_map(tmp_path): "target": "some-target", "options": "some-options", } - args_path.write_text(json.dumps(args)) - reply_path.write_text("") - with args_path.open() as f_args, reply_path.open("w") as f_reply: - fd_args, fd_reply = os.dup(f_args.fileno()), os.dup(f_reply.fileno()) - fds = FdSet.from_list([fd_args, fd_reply]) - fake_service = FakeInputService(args="some") - with patch.object(inputs, "StoreClient"): - r = fake_service.dispatch("map", None, fds) - assert r == ('{}', None) - assert fake_service.map_calls == [ - ["some-origin", "some-refs", "some-target", "some-options"], - ] - assert reply_path.read_text() == '["complex", 2, "reply"]' + fake_service = FakeInputService(args="some") + with patch.object(inputs, "StoreClient") as mocked_store_client_klass: + r = fake_service.dispatch("map", args, None) + assert mocked_store_client_klass.call_args_list == [ + call(connect_to=os.fspath(store_api_path)), + ] + assert fake_service.map_calls == [ + ["some-origin", "some-refs", "some-target", "some-options"], + ] + assert r == (("complex", 2, "reply"), None) diff --git a/test/mod/test_util_jsoncomm.py b/test/mod/test_util_jsoncomm.py index 93d62a13..53c406d3 100644 --- a/test/mod/test_util_jsoncomm.py +++ b/test/mod/test_util_jsoncomm.py @@ -2,13 +2,17 @@ # Tests for the 'osbuild.util.jsoncomm' module. # +# pylint: disable=protected-access + import asyncio import errno +import json import os import pathlib import tempfile import unittest from concurrent import futures +from unittest.mock import patch import pytest @@ -220,11 +224,39 @@ class TestUtilJsonComm(unittest.TestCase): pong, _, _ = a.recv() self.assertEqual(ping, pong) - def test_send_and_recv_tons_of_data_still_errors(self): + def test_sendmsg_errors_with_size_on_EMSGSIZE(self): a, _ = jsoncomm.Socket.new_pair() - ping = {"data": "1" * 1_000_000} + serialized = json.dumps({"data": "1" * 1_000_000}).encode() with pytest.raises(BufferError) as exc: - a.send(ping) + a._send_via_sendmsg(serialized, []) assert str(exc.value) == "jsoncomm message size 1000012 is too big" assert exc.value.__cause__.errno == errno.EMSGSIZE + + def test_send_and_recv_tons_of_data_is_fine(self): + a, b = jsoncomm.Socket.new_pair() + + ping = {"data": "tons" * 1_000_000} + a.send(ping) + pong, _, _ = b.send_and_recv(ping) + self.assertEqual(ping, pong) + pong, _, _ = a.recv() + self.assertEqual(ping, pong) + + def test_send_small_data_via_sendmsg(self): + a, _ = jsoncomm.Socket.new_pair() + with patch.object(a, "_send_via_fd") as mock_send_via_fd, \ + patch.object(a, "_send_via_sendmsg") as mock_send_via_sendmsg: + ping = {"data": "little"} + a.send(ping) + assert mock_send_via_fd.call_count == 0 + assert mock_send_via_sendmsg.call_count == 1 + + def test_send_huge_data_via_fd(self): + a, _ = jsoncomm.Socket.new_pair() + with patch.object(a, "_send_via_fd") as mock_send_via_fd, \ + patch.object(a, "_send_via_sendmsg") as mock_send_via_sendmsg: + ping = {"data": "tons" * 1_000_000} + a.send(ping) + assert mock_send_via_fd.call_count == 1 + assert mock_send_via_sendmsg.call_count == 0