jsoncomm: transparently handle huge messages via fds
The existing jsoncomm is a work of beauty. For very big arguments, however, the `SOCK_SEQPACKET` socket it uses hits the limits of the kernel network buffer size (see also [0]). This led to various workarounds in #824, #1331 and #1836, where parts of the request are encoded as part of the JSON method call and parts are passed through a side channel via fd-passing. This commit changes the code so that the fd channel is created automatically and transparently, and the workarounds are removed. A test is added that ensures that very big messages can be passed. [0] https://github.com/osbuild/osbuild/pull/1833
This commit is contained in:
parent
d67fa48c17
commit
0abdfb9041
6 changed files with 130 additions and 87 deletions
|
|
@ -5,18 +5,31 @@ serialization. It uses unix-domain-datagram-sockets and provides a simple
|
|||
unicast message transmission.
|
||||
"""
|
||||
|
||||
|
||||
import array
|
||||
import contextlib
|
||||
import errno
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
from typing import Any, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from .types import PathLike
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def memfd(name):
|
||||
fd = os.memfd_create(name, 0)
|
||||
try:
|
||||
yield fd
|
||||
finally:
|
||||
os.close(fd)
|
||||
|
||||
|
||||
# this marker is used when the arguments are passed via a file descriptor
|
||||
# because they exceed the allowed size for a network packet
|
||||
ARGS_VIA_FD_MARKER = b"<args-via-fd>"
|
||||
|
||||
|
||||
class FdSet:
|
||||
"""File-Descriptor Set
|
||||
|
||||
|
|
@ -92,6 +105,19 @@ class FdSet:
|
|||
return v
|
||||
|
||||
|
||||
def wmem_max() -> int:
    """Return the kernel's maximum send socket buffer size in bytes

    The value is read from `/proc/sys/net/core/wmem_max`. When that file
    cannot be read (for example because /proc is not mounted or access is
    denied) or does not contain an integer, a conservative estimate (64kb)
    is returned instead, so that sending never fails just because the limit
    could not be determined.
    """
    try:
        with open("/proc/sys/net/core/wmem_max", encoding="utf8") as wmem_file:
            return int(wmem_file.read().strip())
    except (OSError, ValueError):
        # conservative estimate for systems where the limit is unavailable:
        # no /proc mounted, /proc not readable, or unexpected file content
        return 64_000
|
||||
|
||||
|
||||
class Socket(contextlib.AbstractContextManager):
|
||||
"""Communication Socket
|
||||
|
||||
|
|
@ -353,7 +379,17 @@ class Socket(contextlib.AbstractContextManager):
|
|||
if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
|
||||
assert len(data) % fds.itemsize == 0
|
||||
fds.frombytes(data)
|
||||
fdset = FdSet(rawfds=fds)
|
||||
# Next we need to check if the serialized data comes via an FD
|
||||
# or via the message. FDs are used if the data size is big to
|
||||
# avoid running into errno.EMSGSIZE
|
||||
if msg[0] == ARGS_VIA_FD_MARKER:
|
||||
fd_payload = fds[0]
|
||||
fdset = FdSet(rawfds=fds[1:])
|
||||
with os.fdopen(fd_payload) as f:
|
||||
serialized = f.read()
|
||||
else:
|
||||
fdset = FdSet(rawfds=fds)
|
||||
serialized = msg[0]
|
||||
|
||||
# Check the returned message flags. If the message was truncated, we
|
||||
# have to discard it. This shouldn't happen, but there is no harm in
|
||||
|
|
@ -364,13 +400,38 @@ class Socket(contextlib.AbstractContextManager):
|
|||
raise BufferError
|
||||
|
||||
try:
|
||||
payload = json.loads(msg[0])
|
||||
payload = json.loads(serialized)
|
||||
except json.JSONDecodeError as e:
|
||||
raise BufferError from e
|
||||
|
||||
return (payload, fdset, msg[3])
|
||||
|
||||
def send(self, payload: object, *, fds: Optional[list] = None):
|
||||
def _send_via_fd(self, serialized: bytes, fds: List[int]):
    """Send a message whose payload travels in a memfd instead of the packet

    The serialized payload is written to an anonymous in-memory file. Only
    the `ARGS_VIA_FD_MARKER` is sent as message data; the payload file
    descriptor is passed as the first `SCM_RIGHTS` descriptor, followed by
    the caller supplied `fds`.

    Parameters
    ----------
    serialized
        The already serialized message payload.
    fds
        File descriptors to pass along with the message (may be empty).
    """
    assert self._socket is not None
    with memfd("jsoncomm/payload") as fd_payload:
        # os.write() may write fewer bytes than requested (a single write
        # is capped by the kernel), so keep writing until the whole payload
        # is stored; otherwise a huge message would be silently truncated.
        remaining = memoryview(serialized)
        while remaining:
            written = os.write(fd_payload, remaining)
            remaining = remaining[written:]
        # rewind so the payload can be read from the start of the file
        os.lseek(fd_payload, 0, os.SEEK_SET)
        cmsg = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd_payload] + fds))]
        n = self._socket.sendmsg([ARGS_VIA_FD_MARKER], cmsg, 0)
        assert n == len(ARGS_VIA_FD_MARKER)
|
||||
|
||||
def _send_via_sendmsg(self, serialized: bytes, fds: List[int]):
|
||||
assert self._socket is not None
|
||||
cmsg = []
|
||||
if fds:
|
||||
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
|
||||
try:
|
||||
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, len(serialized))
|
||||
n = self._socket.sendmsg([serialized], cmsg, 0)
|
||||
except OSError as exc:
|
||||
if exc.errno == errno.EMSGSIZE:
|
||||
raise BufferError(
|
||||
f"jsoncomm message size {len(serialized)} is too big") from exc
|
||||
raise exc
|
||||
assert n == len(serialized)
|
||||
|
||||
def send(self, payload: object, *, fds: Optional[list] = None) -> None:
|
||||
"""Send Message
|
||||
|
||||
Send a new message via this socket. This operation is synchronous. The
|
||||
|
|
@ -399,19 +460,14 @@ class Socket(contextlib.AbstractContextManager):
|
|||
if not self._socket:
|
||||
raise RuntimeError("Tried to send without socket.")
|
||||
|
||||
serialized = json.dumps(payload).encode()
|
||||
cmsg = []
|
||||
if fds:
|
||||
cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))
|
||||
if not fds:
|
||||
fds = []
|
||||
|
||||
try:
|
||||
n = self._socket.sendmsg([serialized], cmsg, 0)
|
||||
except OSError as exc:
|
||||
if exc.errno == errno.EMSGSIZE:
|
||||
raise BufferError(
|
||||
f"jsoncomm message size {len(serialized)} is too big") from exc
|
||||
raise exc
|
||||
assert n == len(serialized)
|
||||
serialized = json.dumps(payload).encode()
|
||||
if len(serialized) > wmem_max():
|
||||
self._send_via_fd(serialized, fds)
|
||||
else:
|
||||
self._send_via_sendmsg(serialized, fds)
|
||||
|
||||
def send_and_recv(self, payload: object, *, fds: Optional[list] = None):
|
||||
"""Send a message and wait for a reply
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue