diff --git a/osbuild/host.py b/osbuild/host.py index d2a406a3..5d278afd 100644 --- a/osbuild/host.py +++ b/osbuild/host.py @@ -48,7 +48,7 @@ import sys import threading import traceback from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Callable from osbuild.util.jsoncomm import FdSet, Socket @@ -131,6 +131,16 @@ class ServiceProtocol: # method call, which can also be `None` return data + @staticmethod + def encode_signal(sig: Any): + msg = { + "type": "signal", + "data": { + "reply": sig + } + } + return msg + @staticmethod def encode_exception(value, tb): backtrace = "".join(traceback.format_tb(tb)) @@ -293,6 +303,10 @@ class Service(abc.ABC): return msg, fds + def emit_signal(self, data: Any, fds: Optional[list] = None): + self._check_fds(fds) + self.sock.send(self.protocol.encode_signal(data), fds=fds) + @staticmethod def _close_all(fds: Optional[List[int]]): if not fds: @@ -336,7 +350,8 @@ class ServiceClient: def call_with_fds(self, method: str, args: Optional[Any] = None, - fds: Optional[List] = None) -> Tuple[Any, FdSet]: + fds: Optional[List] = None, + on_signal: Callable[[Any, FdSet], None] = None) -> Tuple[Any, FdSet]: """ Remotely call a method and return the result, including file descriptors. @@ -344,16 +359,20 @@ class ServiceClient: msg = self.protocol.encode_method(method, args) - ret, fds, _ = self.sock.send_and_recv(msg, fds=fds) + self.sock.send(msg, fds=fds) - kind, data = self.protocol.decode_message(ret) - - if kind == "reply": - ret = self.protocol.decode_reply(data) - return ret, fds - if kind == "exception": - error = self.protocol.decode_exception(data) - raise error + while True: + ret, fds, _ = self.sock.recv() + kind, data = self.protocol.decode_message(ret) + if kind == "signal": + ret = self.protocol.decode_reply(data) + on_signal(ret, fds) + if kind == "reply": + ret = self.protocol.decode_reply(data) + return ret, fds + if kind == "exception": + error = self.protocol.decode_exception(data) + raise error raise ProtocolError(f"unknown message type: {kind}") diff --git a/test/mod/test_host.py b/test/mod/test_host.py index d2157a2b..56ea1807 100755 --- a/test/mod/test_host.py +++ b/test/mod/test_host.py @@ -61,6 +61,18 @@ class ServiceTest(host.Service): continue raise raise ValueError(f"fd '{fd}' was not closed") + elif method == "signal_me_3_times": + self.emit_signal(0) + self.emit_signal(1) + self.emit_signal(2) + elif method == "signal_me_on_fd": + with tempfile.TemporaryFile("w+") as f: + with os.fdopen(fds.steal(0)) as d: + f.write(d.read()) + f.seek(0) + fds = [os.dup(f.fileno())] + self.register_fds(fds) + self.emit_signal("that should do it", fds) else: raise host.ProtocolError("unknown method:", method) @@ -135,6 +147,42 @@ def test_exception(): client.call("exception") +def test_signals(): + with host.ServiceManager() as mgr: + exec_callback = 0 + + def check_value(item, _fds): + nonlocal exec_callback + assert item == exec_callback + exec_callback += 1 + client = mgr.start("test_signal_me_3_times", __file__) + client.call_with_fds("signal_me_3_times", on_signal=check_value) + assert exec_callback == 3 + + +def test_signals_on_separate_fd(): + with host.ServiceManager() as mgr: + + data = "osbuild\n" + exec_callback = False + + def check_value(item, fds): + nonlocal exec_callback + exec_callback = True + assert item == "that should do it" + with os.fdopen(fds.steal(0)) as d: + assert data == d.read() + + client = mgr.start("test_signal_me_on_fd", __file__) + + with tempfile.TemporaryFile("w+") as f: + f.write(data) + f.seek(0) + + client.call_with_fds("signal_me_on_fd", fds=[f.fileno()], on_signal=check_value) + assert exec_callback + + def main(): service = ServiceTest.from_args(sys.argv[1:]) service.main()