diff --git a/osbuild/api.py b/osbuild/api.py index 7b47571b..f7e8aa06 100644 --- a/osbuild/api.py +++ b/osbuild/api.py @@ -69,13 +69,24 @@ class BaseAPI(abc.ABC): def _dispatch(self, sock: jsoncomm.Socket): """Called when data is available on the socket""" msg, fds, addr = sock.recv() + if msg is None: + # Peer closed the connection + self.event_loop.remove_reader(sock) + return self._message(msg, fds, sock, addr) fds.close() + def _accept(self, server): + client = server.accept() + if client: + self.event_loop.add_reader(client, self._dispatch, client) + def _run_event_loop(self): with jsoncomm.Socket.new_server(self.socket_address) as server: + server.blocking = False + server.listen() self.barrier.wait() - self.event_loop.add_reader(server, self._dispatch, server) + self.event_loop.add_reader(server, self._accept, server) asyncio.set_event_loop(self.event_loop) self.event_loop.run_forever() self.event_loop.remove_reader(server) diff --git a/osbuild/util/jsoncomm.py b/osbuild/util/jsoncomm.py index 59dbc33f..06fb5bdf 100644 --- a/osbuild/util/jsoncomm.py +++ b/osbuild/util/jsoncomm.py @@ -96,8 +96,8 @@ class Socket(contextlib.AbstractContextManager): """Communication Socket This socket object represents a communication channel. It allows sending - and receiving JSON-encoded messages. It uses unix-domain-datagram sockets - as underlying transport. + and receiving JSON-encoded messages. It uses unix-domain sequenced-packet + sockets as underlying transport. """ _socket = None @@ -194,7 +194,7 @@ class Socket(contextlib.AbstractContextManager): sock = None try: - sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + sock = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET) # Trigger an auto-bind. If you do not do this, you might end up with # an unbound unix socket, which cannot receive messages. @@ -217,7 +217,8 @@ class Socket(contextlib.AbstractContextManager): def new_server(cls, bind_to: PathLike): """Create Server - Create a new listener socket. + Create a new listener socket. Returned socket is in non-blocking + mode by default. See `blocking` property. Parameters ---------- @@ -240,9 +241,10 @@ class Socket(contextlib.AbstractContextManager): # creation and open. But then your entire socket creation is racy # as well. We do not guarantee atomicity, so you better make sure # you do not rely on it. - sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + sock = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET) sock.bind(os.fspath(bind_to)) unlink = os.open(os.path.join(".", path[0]), os.O_CLOEXEC | os.O_PATH) + sock.setblocking(False) except: if unlink is not None: os.close(unlink) @@ -264,9 +266,12 @@ class Socket(contextlib.AbstractContextManager): A tuple consisting of the deserialized message payload, the auxiliary file-descriptor set, and the socket-address of the sender is returned. + + In case the peer closed the connection, A tuple of `None` values is + returned. """ - # On `SOCK_DGRAM`, packets might be arbitrarily sized. There is no + # On `SOCK_SEQPACKET`, packets might be arbitrarily sized. There is no # hard-coded upper limit, since it is only restricted by the size of # the kernel write buffer on sockets (which itself can be modified via # sysctl). The only real maximum is probably something like 2^31-1, @@ -279,6 +284,9 @@ class Socket(contextlib.AbstractContextManager): size = 4096 while True: peek = self._socket.recvmsg(size, 0, socket.MSG_PEEK) + if not peek[0]: + # Connection was closed + return None, None, None if not (peek[2] & socket.MSG_TRUNC): break size *= 2 diff --git a/test/mod/test_util_jsoncomm.py b/test/mod/test_util_jsoncomm.py index 96e58184..72a192b0 100644 --- a/test/mod/test_util_jsoncomm.py +++ b/test/mod/test_util_jsoncomm.py @@ -7,20 +7,31 @@ import os import pathlib import tempfile import unittest +from concurrent import futures from osbuild.util import jsoncomm class TestUtilJsonComm(unittest.TestCase): def setUp(self): + # Prepare a bi-directional connection between a `client` + # and `server`; nb: the nomenclature is a bit unusual in + # the sense that the serving socket is called `listener` self.dir = tempfile.TemporaryDirectory() self.address = pathlib.Path(self.dir.name, "listener") - self.server = jsoncomm.Socket.new_server(self.address) - self.client = jsoncomm.Socket.new_client(self.address) + self.listener = jsoncomm.Socket.new_server(self.address) + self.listener.blocking = True # We want `accept` to block + self.listener.listen() + + with futures.ThreadPoolExecutor() as executor: + future = executor.submit(self.listener.accept) + self.client = jsoncomm.Socket.new_client(self.address) + self.server = future.result() def tearDown(self): self.client.close() self.server.close() + self.listener.close() self.dir.cleanup() def test_fdset(self):