jsoncomm: transparently handle huge messages via fds
The existing jsoncomm is a work of beauty. For very big arguments however the used `SOCK_SEQPACKET` hits the limitations of the kernel network buffer size (see also [0]). This led to various workarounds in #824,#1331,#1836 where parts of the request are encoded as part of the json method call and parts are done via a side-channel using fd-passing. This commit changes the code so that the fd channel is automatically and transparently created and the workarounds are removed. A test is added that ensures that very big messages can be passed. [0] https://github.com/osbuild/osbuild/pull/1833
This commit is contained in:
parent
d67fa48c17
commit
0abdfb9041
6 changed files with 130 additions and 87 deletions
|
|
@ -17,17 +17,15 @@ osbuild is the path. The input options are just passed to the
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import contextlib
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
from osbuild import host
|
from osbuild import host
|
||||||
from osbuild.util.types import PathLike
|
from osbuild.util.types import PathLike
|
||||||
|
|
||||||
from .objectstore import ObjectStore, StoreClient, StoreServer
|
from .objectstore import StoreClient, StoreServer
|
||||||
|
|
||||||
|
|
||||||
class Input:
|
class Input:
|
||||||
|
|
@ -67,7 +65,7 @@ class InputManager:
|
||||||
self.root = root
|
self.root = root
|
||||||
self.inputs: Dict[str, Input] = {}
|
self.inputs: Dict[str, Input] = {}
|
||||||
|
|
||||||
def map(self, ip: Input, store: ObjectStore) -> Tuple[str, Dict]:
|
def map(self, ip: Input) -> Tuple[str, Dict]:
|
||||||
|
|
||||||
target = os.path.join(self.root, ip.name)
|
target = os.path.join(self.root, ip.name)
|
||||||
os.makedirs(target)
|
os.makedirs(target)
|
||||||
|
|
@ -88,12 +86,8 @@ class InputManager:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
with make_args_and_reply_files(store.tmp, args) as (fd_args, fd_reply):
|
client = self.service_manager.start(f"input/{ip.name}", ip.info.path)
|
||||||
fds = [fd_args, fd_reply]
|
reply = client.call("map", args)
|
||||||
client = self.service_manager.start(f"input/{ip.name}", ip.info.path)
|
|
||||||
_, _ = client.call_with_fds("map", {}, fds)
|
|
||||||
with os.fdopen(os.dup(fd_reply)) as f:
|
|
||||||
reply = json.loads(f.read())
|
|
||||||
|
|
||||||
path = reply["path"]
|
path = reply["path"]
|
||||||
|
|
||||||
|
|
@ -107,15 +101,6 @@ class InputManager:
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def make_args_and_reply_files(tmp, args):
|
|
||||||
with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f_args, \
|
|
||||||
tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f_reply:
|
|
||||||
json.dump(args, f_args)
|
|
||||||
f_args.seek(0)
|
|
||||||
yield f_args.fileno(), f_reply.fileno()
|
|
||||||
|
|
||||||
|
|
||||||
class InputService(host.Service):
|
class InputService(host.Service):
|
||||||
"""Input host service"""
|
"""Input host service"""
|
||||||
|
|
||||||
|
|
@ -129,21 +114,14 @@ class InputService(host.Service):
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.unmap()
|
self.unmap()
|
||||||
|
|
||||||
def dispatch(self, method: str, _, fds):
|
def dispatch(self, method: str, args, fds):
|
||||||
if method == "map":
|
if method == "map":
|
||||||
# map() sends fd[0] to read the arguments from and fd[1] to
|
|
||||||
# write the reply back. This avoids running into EMSGSIZE
|
|
||||||
with os.fdopen(fds.steal(0)) as f:
|
|
||||||
args = json.load(f)
|
|
||||||
store = StoreClient(connect_to=args["api"]["store"])
|
store = StoreClient(connect_to=args["api"]["store"])
|
||||||
r = self.map(store,
|
r = self.map(store,
|
||||||
args["origin"],
|
args["origin"],
|
||||||
args["refs"],
|
args["refs"],
|
||||||
args["target"],
|
args["target"],
|
||||||
args["options"])
|
args["options"])
|
||||||
with os.fdopen(fds.steal(1), "w") as f:
|
return r, None
|
||||||
f.write(json.dumps(r))
|
|
||||||
f.seek(0)
|
|
||||||
return "{}", None
|
|
||||||
|
|
||||||
raise host.ProtocolError("Unknown method")
|
raise host.ProtocolError("Unknown method")
|
||||||
|
|
|
||||||
|
|
@ -210,7 +210,7 @@ class Stage:
|
||||||
|
|
||||||
ipmgr = InputManager(mgr, storeapi, inputs_tmpdir)
|
ipmgr = InputManager(mgr, storeapi, inputs_tmpdir)
|
||||||
for key, ip in self.inputs.items():
|
for key, ip in self.inputs.items():
|
||||||
data = ipmgr.map(ip, store)
|
data = ipmgr.map(ip)
|
||||||
inputs[key] = data
|
inputs[key] = data
|
||||||
|
|
||||||
devmgr = DeviceManager(mgr, build_root.dev, tree)
|
devmgr = DeviceManager(mgr, build_root.dev, tree)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import abc
|
import abc
|
||||||
import contextlib
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
@ -29,6 +28,7 @@ class Source:
|
||||||
cache = os.path.join(store.store, "sources")
|
cache = os.path.join(store.store, "sources")
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
|
"items": self.items,
|
||||||
"options": self.options,
|
"options": self.options,
|
||||||
"cache": cache,
|
"cache": cache,
|
||||||
"output": None,
|
"output": None,
|
||||||
|
|
@ -36,20 +36,10 @@ class Source:
|
||||||
}
|
}
|
||||||
|
|
||||||
client = mgr.start(f"source/{source}", self.info.path)
|
client = mgr.start(f"source/{source}", self.info.path)
|
||||||
|
reply = client.call("download", args)
|
||||||
with self.make_items_file(store.tmp) as fd:
|
|
||||||
fds = [fd]
|
|
||||||
reply = client.call_with_fds("download", args, fds)
|
|
||||||
|
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def make_items_file(self, tmp):
|
|
||||||
with tempfile.TemporaryFile("w+", dir=tmp, encoding="utf-8") as f:
|
|
||||||
json.dump(self.items, f)
|
|
||||||
f.seek(0)
|
|
||||||
yield f.fileno()
|
|
||||||
|
|
||||||
# "name", "id", "stages", "results" is only here to make it looks like a
|
# "name", "id", "stages", "results" is only here to make it looks like a
|
||||||
# pipeline for the monitor. This should be revisited at some point
|
# pipeline for the monitor. This should be revisited at some point
|
||||||
# and maybe the monitor should get first-class support for
|
# and maybe the monitor should get first-class support for
|
||||||
|
|
@ -103,12 +93,6 @@ class SourceService(host.Service):
|
||||||
"""Returns True if the item to download is in cache. """
|
"""Returns True if the item to download is in cache. """
|
||||||
return os.path.isfile(f"{self.cache}/{checksum}")
|
return os.path.isfile(f"{self.cache}/{checksum}")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_items(fds):
|
|
||||||
with os.fdopen(fds.steal(0)) as f:
|
|
||||||
items = json.load(f)
|
|
||||||
return items
|
|
||||||
|
|
||||||
def setup(self, args):
|
def setup(self, args):
|
||||||
self.cache = os.path.join(args["cache"], self.content_type)
|
self.cache = os.path.join(args["cache"], self.content_type)
|
||||||
os.makedirs(self.cache, exist_ok=True)
|
os.makedirs(self.cache, exist_ok=True)
|
||||||
|
|
@ -118,7 +102,7 @@ class SourceService(host.Service):
|
||||||
if method == "download":
|
if method == "download":
|
||||||
self.setup(args)
|
self.setup(args)
|
||||||
with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir:
|
with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir:
|
||||||
self.fetch_all(SourceService.load_items(fds))
|
self.fetch_all(args["items"])
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
raise host.ProtocolError("Unknown method")
|
raise host.ProtocolError("Unknown method")
|
||||||
|
|
|
||||||
|
|
@ -5,18 +5,31 @@ serialization. It uses unix-domain-datagram-sockets and provides a simple
|
||||||
unicast message transmission.
|
unicast message transmission.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import array
|
import array
|
||||||
import contextlib
|
import contextlib
|
||||||
import errno
|
import errno
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
from typing import Any, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from .types import PathLike
|
from .types import PathLike
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def memfd(name):
|
||||||
|
fd = os.memfd_create(name, 0)
|
||||||
|
try:
|
||||||
|
yield fd
|
||||||
|
finally:
|
||||||
|
os.close(fd)
|
||||||
|
|
||||||
|
|
||||||
|
# this marker is used when the arguments are passed via a filedescriptor
|
||||||
|
# because they exceed the allowed size for a network package
|
||||||
|
ARGS_VIA_FD_MARKER = b"<args-via-fd>"
|
||||||
|
|
||||||
|
|
||||||
class FdSet:
|
class FdSet:
|
||||||
"""File-Descriptor Set
|
"""File-Descriptor Set
|
||||||
|
|
||||||
|
|
@ -92,6 +105,19 @@ class FdSet:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def wmem_max() -> int:
|
||||||
|
""" Return the kernels maximum send socket buffer size in bytes
|
||||||
|
|
||||||
|
	When /proc is not mounted, return a conservative estimate (64 kB).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open("/proc/sys/net/core/wmem_max", encoding="utf8") as wmem_file:
|
||||||
|
return int(wmem_file.read().strip())
|
||||||
|
except FileNotFoundError:
|
||||||
|
# conservative estimate for systems that have no /proc mounted
|
||||||
|
return 64_000
|
||||||
|
|
||||||
|
|
||||||
class Socket(contextlib.AbstractContextManager):
|
class Socket(contextlib.AbstractContextManager):
|
||||||
"""Communication Socket
|
"""Communication Socket
|
||||||
|
|
||||||
|
|
@ -353,7 +379,17 @@ class Socket(contextlib.AbstractContextManager):
|
||||||
if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
|
if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
|
||||||
assert len(data) % fds.itemsize == 0
|
assert len(data) % fds.itemsize == 0
|
||||||
fds.frombytes(data)
|
fds.frombytes(data)
|
||||||
fdset = FdSet(rawfds=fds)
|
# Next we need to check if the serialized data comes via an FD
|
||||||
|
# or via the message. FDs are used if the data size is big to
|
||||||
|
# avoid running into errno.EMSGSIZE
|
||||||
|
if msg[0] == ARGS_VIA_FD_MARKER:
|
||||||
|
fd_payload = fds[0]
|
||||||
|
fdset = FdSet(rawfds=fds[1:])
|
||||||
|
with os.fdopen(fd_payload) as f:
|
||||||
|
serialized = f.read()
|
||||||
|
else:
|
||||||
|
fdset = FdSet(rawfds=fds)
|
||||||
|
serialized = msg[0]
|
||||||
|
|
||||||
# Check the returned message flags. If the message was truncated, we
|
# Check the returned message flags. If the message was truncated, we
|
||||||
# have to discard it. This shouldn't happen, but there is no harm in
|
# have to discard it. This shouldn't happen, but there is no harm in
|
||||||
|
|
@ -364,13 +400,38 @@ class Socket(contextlib.AbstractContextManager):
|
||||||
raise BufferError
|
raise BufferError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = json.loads(msg[0])
|
payload = json.loads(serialized)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise BufferError from e
|
raise BufferError from e
|
||||||
|
|
||||||
return (payload, fdset, msg[3])
|
return (payload, fdset, msg[3])
|
||||||
|
|
||||||
def send(self, payload: object, *, fds: Optional[list] = None):
|
def _send_via_fd(self, serialized: bytes, fds: List[int]):
|
||||||
|
assert self._socket is not None
|
||||||
|
with memfd("jsoncomm/payload") as fd_payload:
|
||||||
|
os.write(fd_payload, serialized)
|
||||||
|
os.lseek(fd_payload, 0, 0)
|
||||||
|
cmsg = []
|
||||||
|
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd_payload] + fds)))
|
||||||
|
n = self._socket.sendmsg([ARGS_VIA_FD_MARKER], cmsg, 0)
|
||||||
|
assert n == len(ARGS_VIA_FD_MARKER)
|
||||||
|
|
||||||
|
def _send_via_sendmsg(self, serialized: bytes, fds: List[int]):
|
||||||
|
assert self._socket is not None
|
||||||
|
cmsg = []
|
||||||
|
if fds:
|
||||||
|
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
|
||||||
|
try:
|
||||||
|
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, len(serialized))
|
||||||
|
n = self._socket.sendmsg([serialized], cmsg, 0)
|
||||||
|
except OSError as exc:
|
||||||
|
if exc.errno == errno.EMSGSIZE:
|
||||||
|
raise BufferError(
|
||||||
|
f"jsoncomm message size {len(serialized)} is too big") from exc
|
||||||
|
raise exc
|
||||||
|
assert n == len(serialized)
|
||||||
|
|
||||||
|
def send(self, payload: object, *, fds: Optional[list] = None) -> None:
|
||||||
"""Send Message
|
"""Send Message
|
||||||
|
|
||||||
Send a new message via this socket. This operation is synchronous. The
|
Send a new message via this socket. This operation is synchronous. The
|
||||||
|
|
@ -399,19 +460,14 @@ class Socket(contextlib.AbstractContextManager):
|
||||||
if not self._socket:
|
if not self._socket:
|
||||||
raise RuntimeError("Tried to send without socket.")
|
raise RuntimeError("Tried to send without socket.")
|
||||||
|
|
||||||
serialized = json.dumps(payload).encode()
|
if not fds:
|
||||||
cmsg = []
|
fds = []
|
||||||
if fds:
|
|
||||||
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
|
|
||||||
|
|
||||||
try:
|
serialized = json.dumps(payload).encode()
|
||||||
n = self._socket.sendmsg([serialized], cmsg, 0)
|
if len(serialized) > wmem_max():
|
||||||
except OSError as exc:
|
self._send_via_fd(serialized, fds)
|
||||||
if exc.errno == errno.EMSGSIZE:
|
else:
|
||||||
raise BufferError(
|
self._send_via_sendmsg(serialized, fds)
|
||||||
f"jsoncomm message size {len(serialized)} is too big") from exc
|
|
||||||
raise exc
|
|
||||||
assert n == len(serialized)
|
|
||||||
|
|
||||||
def send_and_recv(self, payload: object, *, fds: Optional[list] = None):
|
def send_and_recv(self, payload: object, *, fds: Optional[list] = None):
|
||||||
"""Send a message and wait for a reply
|
"""Send a message and wait for a reply
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import patch
|
from unittest.mock import call, patch
|
||||||
|
|
||||||
from osbuild import inputs
|
from osbuild import inputs
|
||||||
from osbuild.util.jsoncomm import FdSet
|
|
||||||
|
|
||||||
|
|
||||||
class FakeInputService(inputs.InputService):
|
class FakeInputService(inputs.InputService):
|
||||||
|
|
@ -11,7 +9,7 @@ class FakeInputService(inputs.InputService):
|
||||||
# do not call "super().__init__()" here to make it testable
|
# do not call "super().__init__()" here to make it testable
|
||||||
self.map_calls = []
|
self.map_calls = []
|
||||||
|
|
||||||
def map(self, _store, origin, refs, target, options):
|
def map(self, store, origin, refs, target, options):
|
||||||
self.map_calls.append([origin, refs, target, options])
|
self.map_calls.append([origin, refs, target, options])
|
||||||
return "complex", 2, "reply"
|
return "complex", 2, "reply"
|
||||||
|
|
||||||
|
|
@ -20,8 +18,6 @@ def test_inputs_dispatches_map(tmp_path):
|
||||||
store_api_path = tmp_path / "api-store"
|
store_api_path = tmp_path / "api-store"
|
||||||
store_api_path.write_text("")
|
store_api_path.write_text("")
|
||||||
|
|
||||||
args_path = tmp_path / "args"
|
|
||||||
reply_path = tmp_path / "reply"
|
|
||||||
args = {
|
args = {
|
||||||
"api": {
|
"api": {
|
||||||
"store": os.fspath(store_api_path),
|
"store": os.fspath(store_api_path),
|
||||||
|
|
@ -31,17 +27,14 @@ def test_inputs_dispatches_map(tmp_path):
|
||||||
"target": "some-target",
|
"target": "some-target",
|
||||||
"options": "some-options",
|
"options": "some-options",
|
||||||
}
|
}
|
||||||
args_path.write_text(json.dumps(args))
|
|
||||||
reply_path.write_text("")
|
|
||||||
|
|
||||||
with args_path.open() as f_args, reply_path.open("w") as f_reply:
|
fake_service = FakeInputService(args="some")
|
||||||
fd_args, fd_reply = os.dup(f_args.fileno()), os.dup(f_reply.fileno())
|
with patch.object(inputs, "StoreClient") as mocked_store_client_klass:
|
||||||
fds = FdSet.from_list([fd_args, fd_reply])
|
r = fake_service.dispatch("map", args, None)
|
||||||
fake_service = FakeInputService(args="some")
|
assert mocked_store_client_klass.call_args_list == [
|
||||||
with patch.object(inputs, "StoreClient"):
|
call(connect_to=os.fspath(store_api_path)),
|
||||||
r = fake_service.dispatch("map", None, fds)
|
]
|
||||||
assert r == ('{}', None)
|
assert fake_service.map_calls == [
|
||||||
assert fake_service.map_calls == [
|
["some-origin", "some-refs", "some-target", "some-options"],
|
||||||
["some-origin", "some-refs", "some-target", "some-options"],
|
]
|
||||||
]
|
assert r == (("complex", 2, "reply"), None)
|
||||||
assert reply_path.read_text() == '["complex", 2, "reply"]'
|
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,17 @@
|
||||||
# Tests for the 'osbuild.util.jsoncomm' module.
|
# Tests for the 'osbuild.util.jsoncomm' module.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import errno
|
import errno
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -220,11 +224,39 @@ class TestUtilJsonComm(unittest.TestCase):
|
||||||
pong, _, _ = a.recv()
|
pong, _, _ = a.recv()
|
||||||
self.assertEqual(ping, pong)
|
self.assertEqual(ping, pong)
|
||||||
|
|
||||||
def test_send_and_recv_tons_of_data_still_errors(self):
|
def test_sendmsg_errors_with_size_on_EMSGSIZE(self):
|
||||||
a, _ = jsoncomm.Socket.new_pair()
|
a, _ = jsoncomm.Socket.new_pair()
|
||||||
|
|
||||||
ping = {"data": "1" * 1_000_000}
|
serialized = json.dumps({"data": "1" * 1_000_000}).encode()
|
||||||
with pytest.raises(BufferError) as exc:
|
with pytest.raises(BufferError) as exc:
|
||||||
a.send(ping)
|
a._send_via_sendmsg(serialized, [])
|
||||||
assert str(exc.value) == "jsoncomm message size 1000012 is too big"
|
assert str(exc.value) == "jsoncomm message size 1000012 is too big"
|
||||||
assert exc.value.__cause__.errno == errno.EMSGSIZE
|
assert exc.value.__cause__.errno == errno.EMSGSIZE
|
||||||
|
|
||||||
|
def test_send_and_recv_tons_of_data_is_fine(self):
|
||||||
|
a, b = jsoncomm.Socket.new_pair()
|
||||||
|
|
||||||
|
ping = {"data": "tons" * 1_000_000}
|
||||||
|
a.send(ping)
|
||||||
|
pong, _, _ = b.send_and_recv(ping)
|
||||||
|
self.assertEqual(ping, pong)
|
||||||
|
pong, _, _ = a.recv()
|
||||||
|
self.assertEqual(ping, pong)
|
||||||
|
|
||||||
|
def test_send_small_data_via_sendmsg(self):
|
||||||
|
a, _ = jsoncomm.Socket.new_pair()
|
||||||
|
with patch.object(a, "_send_via_fd") as mock_send_via_fd, \
|
||||||
|
patch.object(a, "_send_via_sendmsg") as mock_send_via_sendmsg:
|
||||||
|
ping = {"data": "little"}
|
||||||
|
a.send(ping)
|
||||||
|
assert mock_send_via_fd.call_count == 0
|
||||||
|
assert mock_send_via_sendmsg.call_count == 1
|
||||||
|
|
||||||
|
def test_send_huge_data_via_fd(self):
|
||||||
|
a, _ = jsoncomm.Socket.new_pair()
|
||||||
|
with patch.object(a, "_send_via_fd") as mock_send_via_fd, \
|
||||||
|
patch.object(a, "_send_via_sendmsg") as mock_send_via_sendmsg:
|
||||||
|
ping = {"data": "tons" * 1_000_000}
|
||||||
|
a.send(ping)
|
||||||
|
assert mock_send_via_fd.call_count == 1
|
||||||
|
assert mock_send_via_sendmsg.call_count == 0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue