diff --git a/osbuild/util/jsoncomm.py b/osbuild/util/jsoncomm.py new file mode 100644 index 00000000..1591642b --- /dev/null +++ b/osbuild/util/jsoncomm.py @@ -0,0 +1,306 @@ +"""JSON Communication + +This module implements a client/server communication method based on JSON +serialization. It uses unix-domain-datagram-sockets and provides a simple +unicast message transmission. +""" + + +import array +import contextlib +import errno +import json +import os +import socket +from typing import Any +from typing import Optional + + +class FdSet(): + """File-Descriptor Set + + This object wraps an array of file-descriptors. Unlike a normal integer + array, this object owns the file-descriptors and therefore closes them once + the object is released. + + File-descriptor sets are initialized once. From then one, the only allowed + operation is to query it for information, or steal file-descriptors from + it. If you close a set, all remaining file-descriptors are closed and + removed from the set. It will then be an empty set. + """ + + _fds = array.array("i") + + def __init__(self, *, rawfds): + for i in rawfds: + if not isinstance(i, int) or i < 0: + raise ValueError() + + self._fds = rawfds + + def __del__(self): + self.close() + + def close(self): + """Close All Entries + + This closes all stored file-descriptors and clears the set. Once this + returns, the set will be empty. It is safe to call this multiple times. + Note that a set is automatically closed when it is garbage collected. + """ + + for i in self._fds: + if i >= 0: + os.close(i) + + self._fds = array.array("i") + + @classmethod + def from_list(cls, l: list): + """Create new Set from List + + This creates a new file-descriptor set initialized to the same entries + as in the given list. This consumes the file-descriptors. The caller + must not assume ownership anymore. + """ + + fds = array.array("i") + fds.fromlist(l) + return cls(rawfds=fds) + + def __len__(self): + return len(self._fds) + + def __getitem__(self, key: Any): + if self._fds[key] < 0: + raise IndexError + return self._fds[key] + + def steal(self, key: Any): + """Steal Entry + + Retrieve the entry at the given position, but drop it from the internal + file-descriptor set. The caller will now own the file-descriptor and it + can no longer be accessed through the set. + + Note that this does not reshuffle the set. All indices stay constant. + """ + + v = self[key] + self._fds[key] = -1 + return v + + +class Socket(contextlib.AbstractContextManager): + """Communication Socket + + This socket object represents a communication channel. It allows sending + and receiving JSON-encoded messages. It uses unix-domain-datagram sockets + as underlying transport. + """ + + _socket = None + _unlink = None + + def __init__(self, sock, unlink): + self._socket = sock + self._unlink = unlink + + def __del__(self): + self.close() + + def __exit__(self, exc_type, exc_value, exc_tb): + self.close() + return False + + def close(self): + """Close Socket + + Close the socket and all underlying resources. This can be called + multiple times. + """ + + # close the socket if it is set + if self._socket is not None: + self._socket.close() + self._socket = None + + # unlink the file-system entry, if pinned + if self._unlink is not None: + try: + os.unlink(self._unlink[1], dir_fd=self._unlink[0]) + except OSError as e: + if e.errno != errno.ENOENT: + raise + + os.close(self._unlink[0]) + self._unlink = None + + @classmethod + def new_client(cls, connect_to: Optional[str] = None): + """Create Client + + Create a new client socket. + + Parameters + ---------- + connect_to + If not `None`, the client will use the specified address as the + default destination for all send operations. + """ + + sock = None + + try: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + + # Trigger an auto-bind. If you do not do this, you might end up with + # an unbound unix socket, which cannot receive messages. + # Alternatively, you can also set `SO_PASSCRED`, but this has + # side-effects. + sock.bind("") + + # Connect the socket. This has no effect other than specifying the + # default destination for send operations. + if connect_to is not None: + sock.connect(connect_to) + except: + if sock is not None: + sock.close() + raise + + return cls(sock, None) + + @classmethod + def new_server(cls, bind_to: str): + """Create Server + + Create a new listener socket. + + Parameters + ---------- + bind_to + The socket-address to listen on for incoming client requests. + """ + + sock = None + unlink = None + path = os.path.split(bind_to) + + try: + # We bind the socket and then open a directory-fd on the target + # socket. This allows us to properly unlink the socket when the + # server is closed. Note that sockets are never automatically + # cleaned up on linux, nor can you bind to existing sockets. + # We use a dirfd to guarantee this works even when you change + # your mount points in-between. + # Yeah, this is racy when mount-points change between the socket + # creation and open. But then your entire socket creation is racy + # as well. We do not guarantee atomicity, so you better make sure + # you do not rely on it. + sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + sock.bind(bind_to) + unlink = os.open(os.path.join(".", path[0]), os.O_CLOEXEC | os.O_PATH) + except: + if unlink is not None: + os.close(unlink) + if sock is not None: + sock.close() + raise + + return cls(sock, (unlink, path[1])) + + def fileno(self) -> int: + assert self._socket is not None + return self._socket.fileno() + + def recv(self): + """Receive a Message + + This receives the next pending message from the socket. This operation + is synchronous. + + A tuple consisting of the deserialized message payload, the auxiliary + file-descriptor set, and the socket-address of the sender is returned. + """ + + # On `SOCK_DGRAM`, packets might be arbitrarily sized. There is no + # hard-coded upper limit, since it is only restricted by the size of + # the kernel write buffer on sockets (which itself can be modified via + # sysctl). The only real maximum is probably something like 2^31-1, + # since that is the maximum of that sysctl datatype. + # Anyway, `MSG_TRUNC+MSG_PEEK` usually allows us to easily peek at the + # incoming buffer. Unfortunately, the python `recvmsg()` wrapper + # discards the return code and we cannot use that. Instead, we simply + # loop until we know the size. This is slightly awkward, but seems fine + # as long as you do not put this into a hot-path. + size = 4096 + while True: + peek = self._socket.recvmsg(size, 0, socket.MSG_PEEK) + if not (peek[2] & socket.MSG_TRUNC): + break + size *= 2 + + # Fetch a packet from the socket. On linux, the maximum SCM_RIGHTS array + # size is hard-coded to 253. This allows us to size the ancillary buffer + # big enough to receive any possible message. + fds = array.array("i") + msg = self._socket.recvmsg(size, socket.CMSG_LEN(253 * fds.itemsize)) + + # First thing we do is always to fetch the CMSG FDs into an FdSet. This + # guarantees that we do not leak FDs in case the message handling fails + # for other reasons. + for level, ty, data in msg[1]: + if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS: + assert len(data) % fds.itemsize == 0 + fds.frombytes(data) + fdset = FdSet(rawfds=fds) + + # Check the returned message flags. If the message was truncated, we + # have to discard it. This shouldn't happen, but there is no harm in + # handling it. However, `CTRUNC` can happen, since it is also triggered + # when LSMs reject FD transmission. Treat it the same as a parser error. + flags = msg[2] + if flags & (socket.MSG_TRUNC | socket.MSG_CTRUNC): + raise BufferError + + try: + payload = json.loads(msg[0]) + except json.JSONDecodeError: + raise BufferError + + return (payload, fdset, msg[3]) + + def send(self, payload: object, *, destination: Optional[str] = None, fds: list = []): + """Send Message + + Send a new message via this socket. This operation is synchronous. The + maximum message size depends on the configured send-buffer on the + socket. An `OSError` with `EMSGSIZE` is raised when it is exceeded. + + Parameters + ---------- + payload + A python object to serialize as JSON and send via this socket. See + `json.dump()` for details about the serialization involved. + destination + The destination to send to. If `None`, the default destination is + used (if none is set, this will raise an `OSError`). + fds + A list of file-descriptors to send with the message. + + Raises + ------ + OSError + If the socket cannot be written, a matching `OSError` is raised. + TypeError + If the payload cannot be serialized, a type error is raised. + """ + + serialized = json.dumps(payload).encode() + cmsg = [] + if len(fds) > 0: + cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))) + + n = self._socket.sendmsg([serialized], cmsg, 0, destination) + assert n == len(serialized) diff --git a/test/test_util_jsoncomm.py b/test/test_util_jsoncomm.py new file mode 100644 index 00000000..310f26dd --- /dev/null +++ b/test/test_util_jsoncomm.py @@ -0,0 +1,159 @@ +# +# Tests for the 'osbuild.util.jsoncomm' module. +# + + +import asyncio +import os +import tempfile +import unittest +from osbuild.util import jsoncomm + + +class TestUtilJsonComm(unittest.TestCase): + def setUp(self): + self.dir = tempfile.TemporaryDirectory() + self.address = os.path.join(self.dir.name, "listener") + self.server = jsoncomm.Socket.new_server(self.address) + self.client = jsoncomm.Socket.new_client(self.address) + + def tearDown(self): + self.client.close() + self.server.close() + self.dir.cleanup() + + def test_fdset(self): + # + # Test the FdSet implementation. Create a simple FD array and verify + # that the FdSet correctly indexes them. Furthermore, verify that a + # close actually closes the Fds so a following FdSet will get the same + # FD numbers assigned. + # + + v1 = [os.dup(0), os.dup(0), os.dup(0), os.dup(0)] + s = jsoncomm.FdSet.from_list(v1) + assert len(s) == 4 + for i in range(4): + assert s[i] == v1[i] + with self.assertRaises(IndexError): + _ = s[128] + s.close() + + v2 = [os.dup(0), os.dup(0), os.dup(0), os.dup(0)] + assert v1 == v2 + s = jsoncomm.FdSet.from_list(v2) + s.close() + + def test_fdset_init(self): + # + # Test FdSet initializations. This includes common edge-cases like empty + # initializers, invalid array values, or invalid types. + # + + s = jsoncomm.FdSet.from_list([]) + s.close() + + with self.assertRaises(ValueError): + v1 = [-1] + s = jsoncomm.FdSet.from_list(v1) + + with self.assertRaises(ValueError): + v1 = ["foobar"] + s = jsoncomm.FdSet(rawfds=v1) + + def test_ping_pong(self): + # + # Test sending messages through the client/server connection. + # + + data = {"key": "value"} + self.client.send(data) + msg = self.server.recv() + assert msg[0] == data + assert len(msg[1]) == 0 + + self.server.send(data, destination=msg[2]) + msg = self.client.recv() + assert msg[0] == data + assert len(msg[1]) == 0 + + def test_scm_rights(self): + # + # Test FD transmission. Create a file, send a file-descriptor through + # the communication channel, and then verify that the file-contents + # can be read. + # + + with tempfile.TemporaryFile() as f1: + f1.write(b"foobar") + f1.seek(0) + + self.client.send({}, fds=[f1.fileno()]) + + msg = self.server.recv() + assert msg[0] == {} + assert len(msg[1]) == 1 + with os.fdopen(msg[1].steal(0)) as f2: + assert f2.read() == "foobar" + + def test_listener_cleanup(self): + # + # Verify that only a single server can listen on a specified address. + # Then make sure closing a server will correctly unlink its socket. + # + + addr = os.path.join(self.dir.name, "foobar") + srv1 = jsoncomm.Socket.new_server(addr) + with self.assertRaises(OSError): + srv2 = jsoncomm.Socket.new_server(addr) + srv1.close() + srv2 = jsoncomm.Socket.new_server(addr) + srv2.close() + + def test_contextlib(self): + # + # Verify the context-manager of sockets. Make sure they correctly close + # the socket, and they correctly propagate exceptions. + # + + assert self.client.fileno() >= 0 + with self.client as client: + assert client == self.client + assert client.fileno() >= 0 + with self.assertRaises(AssertionError): + self.client.fileno() + + assert self.server.fileno() >= 0 + with self.assertRaises(SystemError): + with self.server as server: + assert server.fileno() >= 0 + raise SystemError + raise AssertionError + with self.assertRaises(AssertionError): + self.server.fileno() + + def test_asyncio(self): + # + # Test integration with asyncio-eventloops. Use a trivial echo server + # and test a simple ping/pong roundtrip. + # + + loop = asyncio.new_event_loop() + + def echo(socket): + msg = socket.recv() + socket.send(msg[0], destination=msg[2]) + loop.stop() + + self.client.send({}) + + loop.add_reader(self.server, echo, self.server) + loop.run_forever() + loop.close() + + msg = self.client.recv() + assert msg[0] == {} + + +if __name__ == "__main__": + unittest.main()