jsoncomm: transparently handle huge messages via fds

The existing jsoncomm is a work of beautiy. For very big arguments
however the used `SOCK_SEQPACKET` hits the limitations of the
kernel network buffer size (see also [0]). This lead to various
workarounds in #824,#1331,#1836 where parts of the request are
encoded as part of the json method call and parts are done via
a side-channel via fd-passing.

This commit changes the code so that the fd channel is automatically
and transparently created and the workarounds are removed. A test
is added that ensures that very big messages can be passed.

[0] https://github.com/osbuild/osbuild/pull/1833
This commit is contained in:
Michael Vogt 2024-08-07 16:49:30 +02:00 committed by Achilleas Koutsou
parent d67fa48c17
commit 0abdfb9041
6 changed files with 130 additions and 87 deletions

View file

@ -5,18 +5,31 @@ serialization. It uses unix-domain-datagram-sockets and provides a simple
unicast message transmission.
"""
import array
import contextlib
import errno
import json
import os
import socket
from typing import Any, Optional
from typing import Any, List, Optional
from .types import PathLike
@contextlib.contextmanager
def memfd(name):
fd = os.memfd_create(name, 0)
try:
yield fd
finally:
os.close(fd)
# this marker is used when the arguments are passed via a filedescriptor
# because they exceed the allowed size for a network package
ARGS_VIA_FD_MARKER = b"<args-via-fd>"
class FdSet:
"""File-Descriptor Set
@ -92,6 +105,19 @@ class FdSet:
return v
def wmem_max() -> int:
""" Return the kernels maximum send socket buffer size in bytes
When /proc is not mounted return a conservative estimate (64kb).
"""
try:
with open("/proc/sys/net/core/wmem_max", encoding="utf8") as wmem_file:
return int(wmem_file.read().strip())
except FileNotFoundError:
# conservative estimate for systems that have no /proc mounted
return 64_000
class Socket(contextlib.AbstractContextManager):
"""Communication Socket
@ -353,7 +379,17 @@ class Socket(contextlib.AbstractContextManager):
if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
assert len(data) % fds.itemsize == 0
fds.frombytes(data)
fdset = FdSet(rawfds=fds)
# Next we need to check if the serialzed data comes via an FD
# or via the message. FDs are used if the data size is big to
# avoid running into errno.EMSGSIZE
if msg[0] == ARGS_VIA_FD_MARKER:
fd_payload = fds[0]
fdset = FdSet(rawfds=fds[1:])
with os.fdopen(fd_payload) as f:
serialized = f.read()
else:
fdset = FdSet(rawfds=fds)
serialized = msg[0]
# Check the returned message flags. If the message was truncated, we
# have to discard it. This shouldn't happen, but there is no harm in
@ -364,13 +400,38 @@ class Socket(contextlib.AbstractContextManager):
raise BufferError
try:
payload = json.loads(msg[0])
payload = json.loads(serialized)
except json.JSONDecodeError as e:
raise BufferError from e
return (payload, fdset, msg[3])
def send(self, payload: object, *, fds: Optional[list] = None):
def _send_via_fd(self, serialized: bytes, fds: List[int]):
assert self._socket is not None
with memfd("jsoncomm/payload") as fd_payload:
os.write(fd_payload, serialized)
os.lseek(fd_payload, 0, 0)
cmsg = []
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd_payload] + fds)))
n = self._socket.sendmsg([ARGS_VIA_FD_MARKER], cmsg, 0)
assert n == len(ARGS_VIA_FD_MARKER)
def _send_via_sendmsg(self, serialized: bytes, fds: List[int]):
assert self._socket is not None
cmsg = []
if fds:
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
try:
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, len(serialized))
n = self._socket.sendmsg([serialized], cmsg, 0)
except OSError as exc:
if exc.errno == errno.EMSGSIZE:
raise BufferError(
f"jsoncomm message size {len(serialized)} is too big") from exc
raise exc
assert n == len(serialized)
def send(self, payload: object, *, fds: Optional[list] = None) -> None:
"""Send Message
Send a new message via this socket. This operation is synchronous. The
@ -399,19 +460,14 @@ class Socket(contextlib.AbstractContextManager):
if not self._socket:
raise RuntimeError("Tried to send without socket.")
serialized = json.dumps(payload).encode()
cmsg = []
if fds:
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
if not fds:
fds = []
try:
n = self._socket.sendmsg([serialized], cmsg, 0)
except OSError as exc:
if exc.errno == errno.EMSGSIZE:
raise BufferError(
f"jsoncomm message size {len(serialized)} is too big") from exc
raise exc
assert n == len(serialized)
serialized = json.dumps(payload).encode()
if len(serialized) > wmem_max():
self._send_via_fd(serialized, fds)
else:
self._send_via_sendmsg(serialized, fds)
def send_and_recv(self, payload: object, *, fds: Optional[list] = None):
"""Send a message and wait for a reply