jsoncomm: transparently handle huge messages via fds
The existing jsoncomm is a work of beauty. For very big arguments, however, the `SOCK_SEQPACKET` socket it uses hits the kernel's network buffer size limit (see also [0]). This led to various workarounds in #824,#1331,#1836 where parts of the request are encoded as part of the json method call and parts are passed via a side-channel using fd-passing. This commit changes the code so that the fd channel is created automatically and transparently, and the workarounds are removed. A test is added that ensures that very big messages can be passed. [0] https://github.com/osbuild/osbuild/pull/1833
This commit is contained in:
parent
d67fa48c17
commit
0abdfb9041
6 changed files with 130 additions and 87 deletions
|
|
@ -17,17 +17,15 @@ osbuild is the path. The input options are just passed to the
|
|||
"""
|
||||
|
||||
import abc
|
||||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from osbuild import host
|
||||
from osbuild.util.types import PathLike
|
||||
|
||||
from .objectstore import ObjectStore, StoreClient, StoreServer
|
||||
from .objectstore import StoreClient, StoreServer
|
||||
|
||||
|
||||
class Input:
|
||||
|
|
@ -67,7 +65,7 @@ class InputManager:
|
|||
self.root = root
|
||||
self.inputs: Dict[str, Input] = {}
|
||||
|
||||
def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]:
|
||||
def map(self, ip: Input) -> Tuple[str, Dict]:
|
||||
|
||||
target = os.path.join(self.root, ip.name)
|
||||
os.makedirs(target)
|
||||
|
|
@ -88,12 +86,8 @@ class InputManager:
|
|||
}
|
||||
}
|
||||
|
||||
with make_args_and_reply_files(store.tmp, args) as (fd_args, fd_reply):
|
||||
fds = [fd_args, fd_reply]
|
||||
client = self.service_manager.start(f"input/{ip.name}", ip.info.path)
|
||||
_, _ = client.call_with_fds("map", {}, fds)
|
||||
with os.fdopen(os.dup(fd_reply)) as f:
|
||||
reply = json.loads(f.read())
|
||||
reply = client.call("map", args)
|
||||
|
||||
path = reply["path"]
|
||||
|
||||
|
|
@ -107,15 +101,6 @@ class InputManager:
|
|||
return reply
|
||||
|
||||
|
||||
@contextlib.contextmanager
def make_args_and_reply_files(tmp, args):
    """Yield a pair of file descriptors ``(args_fd, reply_fd)``.

    ``args_fd`` is positioned at the start of *args* serialized as JSON;
    ``reply_fd`` is an empty writable file for the peer's reply. Both
    backing files are created in *tmp* and removed when the context exits.
    """
    with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f_args:
        with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f_reply:
            # serialize the arguments and rewind so the receiver can
            # read them from the very beginning of the file
            f_args.write(json.dumps(args))
            f_args.seek(0)
            yield f_args.fileno(), f_reply.fileno()
|
||||
|
||||
|
||||
class InputService(host.Service):
|
||||
"""Input host service"""
|
||||
|
||||
|
|
@ -129,21 +114,14 @@ class InputService(host.Service):
|
|||
def stop(self):
|
||||
self.unmap()
|
||||
|
||||
def dispatch(self, method: str, _, fds):
|
||||
def dispatch(self, method: str, args, fds):
|
||||
if method == "map":
|
||||
# map() sends fd[0] to read the arguments from and fd[1] to
|
||||
# write the reply back. This avoids running into EMSGSIZE
|
||||
with os.fdopen(fds.steal(0)) as f:
|
||||
args = json.load(f)
|
||||
store = StoreClient(connect_to=args["api"]["store"])
|
||||
r = self.map(store,
|
||||
args["origin"],
|
||||
args["refs"],
|
||||
args["target"],
|
||||
args["options"])
|
||||
with os.fdopen(fds.steal(1), "w") as f:
|
||||
f.write(json.dumps(r))
|
||||
f.seek(0)
|
||||
return "{}", None
|
||||
return r, None
|
||||
|
||||
raise host.ProtocolError("Unknown method")
|
||||
|
|
|
|||
|
|
@ -210,7 +210,7 @@ class Stage:
|
|||
|
||||
ipmgr = InputManager(mgr, storeapi, inputs_tmpdir)
|
||||
for key, ip in self.inputs.items():
|
||||
data = ipmgr.map(ip, store)
|
||||
data = ipmgr.map(ip)
|
||||
inputs[key] = data
|
||||
|
||||
devmgr = DeviceManager(mgr, build_root.dev, tree)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import abc
|
||||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
|
@ -29,6 +28,7 @@ class Source:
|
|||
cache = os.path.join(store.store, "sources")
|
||||
|
||||
args = {
|
||||
"items": self.items,
|
||||
"options": self.options,
|
||||
"cache": cache,
|
||||
"output": None,
|
||||
|
|
@ -36,20 +36,10 @@ class Source:
|
|||
}
|
||||
|
||||
client = mgr.start(f"source/{source}", self.info.path)
|
||||
|
||||
with self.make_items_file(store.tmp) as fd:
|
||||
fds = [fd]
|
||||
reply = client.call_with_fds("download", args, fds)
|
||||
reply = client.call("download", args)
|
||||
|
||||
return reply
|
||||
|
||||
@contextlib.contextmanager
|
||||
def make_items_file(self, tmp):
|
||||
with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f:
|
||||
json.dump(self.items, f)
|
||||
f.seek(0)
|
||||
yield f.fileno()
|
||||
|
||||
# "name", "id", "stages", "results" is only here to make it looks like a
|
||||
# pipeline for the monitor. This should be revisited at some point
|
||||
# and maybe the monitor should get first-class support for
|
||||
|
|
@ -103,12 +93,6 @@ class SourceService(host.Service):
|
|||
"""Returns True if the item to download is in cache. """
|
||||
return os.path.isfile(f"{self.cache}/{checksum}")
|
||||
|
||||
@staticmethod
|
||||
def load_items(fds):
|
||||
with os.fdopen(fds.steal(0)) as f:
|
||||
items = json.load(f)
|
||||
return items
|
||||
|
||||
def setup(self, args):
|
||||
self.cache = os.path.join(args["cache"], self.content_type)
|
||||
os.makedirs(self.cache, exist_ok=True)
|
||||
|
|
@ -118,7 +102,7 @@ class SourceService(host.Service):
|
|||
if method == "download":
|
||||
self.setup(args)
|
||||
with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir:
|
||||
self.fetch_all(SourceService.load_items(fds))
|
||||
self.fetch_all(args["items"])
|
||||
return None, None
|
||||
|
||||
raise host.ProtocolError("Unknown method")
|
||||
|
|
|
|||
|
|
@ -5,18 +5,31 @@ serialization. It uses unix-domain-datagram-sockets and provides a simple
|
|||
unicast message transmission.
|
||||
"""
|
||||
|
||||
|
||||
import array
|
||||
import contextlib
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
from typing import Any, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from .types import PathLike
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def memfd(name):
|
||||
fd = os.memfd_create(name, 0)
|
||||
try:
|
||||
yield fd
|
||||
finally:
|
||||
os.close(fd)
|
||||
|
||||
|
||||
# this marker is used when the arguments are passed via a filedescriptor
|
||||
# because they exceed the allowed size for a network package
|
||||
ARGS_VIA_FD_MARKER = b"<args-via-fd>"
|
||||
|
||||
|
||||
class FdSet:
|
||||
"""File-Descriptor Set
|
||||
|
||||
|
|
@ -92,6 +105,19 @@ class FdSet:
|
|||
return v
|
||||
|
||||
|
||||
def wmem_max() -> int:
    """Return the kernel's maximum send socket buffer size in bytes.

    Falls back to a conservative estimate (64kb) on systems where
    /proc is not mounted.
    """
    try:
        with open("/proc/sys/net/core/wmem_max", encoding="utf8") as proc_file:
            contents = proc_file.read()
        return int(contents.strip())
    except FileNotFoundError:
        # no /proc available: assume a small, safe buffer size
        return 64_000
|
||||
|
||||
|
||||
class Socket(contextlib.AbstractContextManager):
|
||||
"""Communication Socket
|
||||
|
||||
|
|
@ -353,7 +379,17 @@ class Socket(contextlib.AbstractContextManager):
|
|||
if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
|
||||
assert len(data) % fds.itemsize == 0
|
||||
fds.frombytes(data)
|
||||
# Next we need to check if the serialized data comes via an FD
|
||||
# or via the message. FDs are used if the data size is big to
|
||||
# avoid running into errno.EMSGSIZE
|
||||
if msg[0] == ARGS_VIA_FD_MARKER:
|
||||
fd_payload = fds[0]
|
||||
fdset = FdSet(rawfds=fds[1:])
|
||||
with os.fdopen(fd_payload) as f:
|
||||
serialized = f.read()
|
||||
else:
|
||||
fdset = FdSet(rawfds=fds)
|
||||
serialized = msg[0]
|
||||
|
||||
# Check the returned message flags. If the message was truncated, we
|
||||
# have to discard it. This shouldn't happen, but there is no harm in
|
||||
|
|
@ -364,13 +400,38 @@ class Socket(contextlib.AbstractContextManager):
|
|||
raise BufferError
|
||||
|
||||
try:
|
||||
payload = json.loads(msg[0])
|
||||
payload = json.loads(serialized)
|
||||
except json.JSONDecodeError as e:
|
||||
raise BufferError from e
|
||||
|
||||
return (payload, fdset, msg[3])
|
||||
|
||||
def send(self, payload: object, *, fds: Optional[list] = None):
|
||||
def _send_via_fd(self, serialized: bytes, fds: List[int]):
    """Send *serialized* through a memfd instead of the socket buffer.

    Only a small marker message travels over the socket itself; the
    payload fd is prepended to *fds* in the SCM_RIGHTS ancillary data
    so the receiver reads the real payload from that descriptor. This
    sidesteps the kernel's send-buffer limit (EMSGSIZE).
    """
    assert self._socket is not None
    with memfd("jsoncomm/payload") as payload_fd:
        os.write(payload_fd, serialized)
        # rewind so the peer reads from the start of the payload
        os.lseek(payload_fd, 0, 0)
        ancillary = [
            (socket.SOL_SOCKET, socket.SCM_RIGHTS,
             array.array("i", [payload_fd] + fds)),
        ]
        sent = self._socket.sendmsg([ARGS_VIA_FD_MARKER], ancillary, 0)
        assert sent == len(ARGS_VIA_FD_MARKER)
|
||||
|
||||
def _send_via_sendmsg(self, serialized: bytes, fds: List[int]):
    """Transmit *serialized* directly over the socket via sendmsg().

    Any *fds* are attached as SCM_RIGHTS ancillary data. Raises
    BufferError (including the payload size) when the kernel rejects
    the message as too large (EMSGSIZE).
    """
    assert self._socket is not None
    ancillary = []
    if fds:
        ancillary.append(
            (socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
    try:
        # grow the send buffer to the payload size so messages up to
        # the kernel's wmem_max limit go through
        self._socket.setsockopt(
            socket.SOL_SOCKET, socket.SO_SNDBUF, len(serialized))
        sent = self._socket.sendmsg([serialized], ancillary, 0)
    except OSError as exc:
        if exc.errno == errno.EMSGSIZE:
            raise BufferError(
                f"jsoncomm message size {len(serialized)} is too big") from exc
        raise exc
    assert sent == len(serialized)
|
||||
|
||||
def send(self, payload: object, *, fds: Optional[list] = None) -> None:
|
||||
"""Send Message
|
||||
|
||||
Send a new message via this socket. This operation is synchronous. The
|
||||
|
|
@ -399,19 +460,14 @@ class Socket(contextlib.AbstractContextManager):
|
|||
if not self._socket:
|
||||
raise RuntimeError("Tried to send without socket.")
|
||||
|
||||
serialized = json.dumps(payload).encode()
|
||||
cmsg = []
|
||||
if fds:
|
||||
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
|
||||
if not fds:
|
||||
fds = []
|
||||
|
||||
try:
|
||||
n = self._socket.sendmsg([serialized], cmsg, 0)
|
||||
except OSError as exc:
|
||||
if exc.errno == errno.EMSGSIZE:
|
||||
raise BufferError(
|
||||
f"jsoncomm message size {len(serialized)} is too big") from exc
|
||||
raise exc
|
||||
assert n == len(serialized)
|
||||
serialized = json.dumps(payload).encode()
|
||||
if len(serialized) > wmem_max():
|
||||
self._send_via_fd(serialized, fds)
|
||||
else:
|
||||
self._send_via_sendmsg(serialized, fds)
|
||||
|
||||
def send_and_recv(self, payload: object, *, fds: Optional[list] = None):
|
||||
"""Send a message and wait for a reply
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import call, patch
|
||||
|
||||
from osbuild import inputs
|
||||
from osbuild.util.jsoncomm import FdSet
|
||||
|
||||
|
||||
class FakeInputService(inputs.InputService):
|
||||
|
|
@ -11,7 +9,7 @@ class FakeInputService(inputs.InputService):
|
|||
# do not call "super().__init__()" here to make it testable
|
||||
self.map_calls = []
|
||||
|
||||
def map(self, _store, origin, refs, target, options):
|
||||
def map(self, store, origin, refs, target, options):
|
||||
self.map_calls.append([origin, refs, target, options])
|
||||
return "complex", 2, "reply"
|
||||
|
||||
|
|
@ -20,8 +18,6 @@ def test_inputs_dispatches_map(tmp_path):
|
|||
store_api_path = tmp_path / "api-store"
|
||||
store_api_path.write_text("")
|
||||
|
||||
args_path = tmp_path / "args"
|
||||
reply_path = tmp_path / "reply"
|
||||
args = {
|
||||
"api": {
|
||||
"store": os.fspath(store_api_path),
|
||||
|
|
@ -31,17 +27,14 @@ def test_inputs_dispatches_map(tmp_path):
|
|||
"target": "some-target",
|
||||
"options": "some-options",
|
||||
}
|
||||
args_path.write_text(json.dumps(args))
|
||||
reply_path.write_text("")
|
||||
|
||||
with args_path.open() as f_args, reply_path.open("w") as f_reply:
|
||||
fd_args, fd_reply = os.dup(f_args.fileno()), os.dup(f_reply.fileno())
|
||||
fds = FdSet.from_list([fd_args, fd_reply])
|
||||
fake_service = FakeInputService(args="some")
|
||||
with patch.object(inputs, "StoreClient"):
|
||||
r = fake_service.dispatch("map", None, fds)
|
||||
assert r == ('{}', None)
|
||||
with patch.object(inputs, "StoreClient") as mocked_store_client_klass:
|
||||
r = fake_service.dispatch("map", args, None)
|
||||
assert mocked_store_client_klass.call_args_list == [
|
||||
call(connect_to=os.fspath(store_api_path)),
|
||||
]
|
||||
assert fake_service.map_calls == [
|
||||
["some-origin", "some-refs", "some-target", "some-options"],
|
||||
]
|
||||
assert reply_path.read_text() == '["complex", 2, "reply"]'
|
||||
assert r == (("complex", 2, "reply"), None)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,17 @@
|
|||
# Tests for the 'osbuild.util.jsoncomm' module.
|
||||
#
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
from concurrent import futures
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -220,11 +224,39 @@ class TestUtilJsonComm(unittest.TestCase):
|
|||
pong, _, _ = a.recv()
|
||||
self.assertEqual(ping, pong)
|
||||
|
||||
def test_send_and_recv_tons_of_data_still_errors(self):
|
||||
def test_sendmsg_errors_with_size_on_EMSGSIZE(self):
|
||||
a, _ = jsoncomm.Socket.new_pair()
|
||||
|
||||
ping = {"data": "1" * 1_000_000}
|
||||
serialized = json.dumps({"data": "1" * 1_000_000}).encode()
|
||||
with pytest.raises(BufferError) as exc:
|
||||
a.send(ping)
|
||||
a._send_via_sendmsg(serialized, [])
|
||||
assert str(exc.value) == "jsoncomm message size 1000012 is too big"
|
||||
assert exc.value.__cause__.errno == errno.EMSGSIZE
|
||||
|
||||
def test_send_and_recv_tons_of_data_is_fine(self):
|
||||
a, b = jsoncomm.Socket.new_pair()
|
||||
|
||||
ping = {"data": "tons" * 1_000_000}
|
||||
a.send(ping)
|
||||
pong, _, _ = b.send_and_recv(ping)
|
||||
self.assertEqual(ping, pong)
|
||||
pong, _, _ = a.recv()
|
||||
self.assertEqual(ping, pong)
|
||||
|
||||
def test_send_small_data_via_sendmsg(self):
|
||||
a, _ = jsoncomm.Socket.new_pair()
|
||||
with patch.object(a, "_send_via_fd") as mock_send_via_fd, \
|
||||
patch.object(a, "_send_via_sendmsg") as mock_send_via_sendmsg:
|
||||
ping = {"data": "little"}
|
||||
a.send(ping)
|
||||
assert mock_send_via_fd.call_count == 0
|
||||
assert mock_send_via_sendmsg.call_count == 1
|
||||
|
||||
def test_send_huge_data_via_fd(self):
|
||||
a, _ = jsoncomm.Socket.new_pair()
|
||||
with patch.object(a, "_send_via_fd") as mock_send_via_fd, \
|
||||
patch.object(a, "_send_via_sendmsg") as mock_send_via_sendmsg:
|
||||
ping = {"data": "tons" * 1_000_000}
|
||||
a.send(ping)
|
||||
assert mock_send_via_fd.call_count == 1
|
||||
assert mock_send_via_sendmsg.call_count == 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue