From 610d1c45d5af5cc38297805a80d0f1ca2e7b4c32 Mon Sep 17 00:00:00 2001 From: Christian Kellner Date: Fri, 7 May 2021 17:59:16 +0000 Subject: [PATCH] util/jsoncomm: ability to create socket from fd Add a new constructor method that allows creating a `Socket` from an existing file-descriptor of a socket. This might be need when the socket was passed to a child process. Add a simple test for the new constructor method. --- osbuild/util/jsoncomm.py | 21 +++++++++++++++++++++ test/mod/test_util_jsoncomm.py | 17 +++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/osbuild/util/jsoncomm.py b/osbuild/util/jsoncomm.py index e06e3147..6330f16f 100644 --- a/osbuild/util/jsoncomm.py +++ b/osbuild/util/jsoncomm.py @@ -272,6 +272,27 @@ class Socket(contextlib.AbstractContextManager): return cls(a, None), cls(b, None) + @classmethod + def new_from_fd(cls, fd: int, *, blocking=True, close_fd=True): + """Create a socket for an existing file descriptor + + Duplicate the file descriptor and return a `Socket` for it. + The blocking mode can be set via `blocking`. If `close_fd` + is True (the default) `fd` will be closed. + + Parameters + ---------- + fd + The file descriptor to use. + blocking + The blocking mode for the socket pair. + """ + sock = socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_SEQPACKET) + sock.setblocking(blocking) + if close_fd: + os.close(fd) + return cls(sock, None) + def fileno(self) -> int: assert self._socket is not None return self._socket.fileno() diff --git a/test/mod/test_util_jsoncomm.py b/test/mod/test_util_jsoncomm.py index 04cd071b..dadaefbd 100644 --- a/test/mod/test_util_jsoncomm.py +++ b/test/mod/test_util_jsoncomm.py @@ -186,3 +186,20 @@ class TestUtilJsonComm(unittest.TestCase): a.send(ping) pong, _, _ = b.recv() self.assertEqual(ping, pong) + + def test_from_fd(self): + # + # Test creating a Socket from an existing file descriptor + a, x = jsoncomm.Socket.new_pair() + fd = os.dup(x.fileno()) + + b = jsoncomm.Socket.new_from_fd(fd) + + # x should be closed and thus raise "Bad file descriptor" + with self.assertRaises(OSError): + os.write(fd, b"test") + + ping = {"osbuild": "yes"} + a.send(ping) + pong, _, _ = b.recv() + self.assertEqual(ping, pong)