From 1ce517e595ccd71aa323508f403402aa0ed2cc77 Mon Sep 17 00:00:00 2001 From: Christian Kellner Date: Wed, 22 Jul 2020 15:24:33 +0200 Subject: [PATCH] api: extract base class Split out the part of `api.API` that is responsible for providing the server infrastructure for the API; i.e. setting up the server and the corresponding context manager and asynchronous event handling. This leaves `API` itself which just the implementation of the high level protocol and makes the API-server part re-usable. NB: pylint, for some reason, confuses `API` and `BaseAPI`, like in `test_monitor`. Annotate that accordingly. --- osbuild/api.py | 94 ++++++++++++++++++++++++++++++---------- test/mod/test_monitor.py | 2 +- 2 files changed, 73 insertions(+), 23 deletions(-) diff --git a/osbuild/api.py b/osbuild/api.py index 4d77d2b2..be29e8e9 100644 --- a/osbuild/api.py +++ b/osbuild/api.py @@ -1,3 +1,4 @@ +import abc import asyncio import io import json @@ -7,16 +8,81 @@ import threading from .util import jsoncomm -class API: - def __init__(self, socket_address, args, monitor): +__all__ = [ + "API" +] + + +class BaseAPI(abc.ABC): + """Base class for all API providers + + This base class provides the basic scaffolding for setting + up API endpoints, normally to be used for bi-directional + communication from and to the sandbox. It is to be used as + a context manger. The communication channel will only be + established on entering the context and will be shut down + when the context is left. + + The `_dispatch` method needs to be implemented by child + classes, and is called for incoming messages. + Optionally, the `_cleanup` method can be implemented, to + clean up resources after the context is left and the + communication channel shut down. + """ + def __init__(self, socket_address): self.socket_address = socket_address + self.barrier = threading.Barrier(2) + self.event_loop = None + self.thread = None + + @abc.abstractmethod + def _dispatch(self, server): + """Called for incoming messages on the socket""" + + def _cleanup(self): + """Called after the event loop is shut down""" + + def _run_event_loop(self): + with jsoncomm.Socket.new_server(self.socket_address) as server: + self.barrier.wait() + self.event_loop.add_reader(server, self._dispatch, server) + asyncio.set_event_loop(self.event_loop) + self.event_loop.run_forever() + self.event_loop.remove_reader(server) + + def __enter__(self): + # We are not re-entrant, so complain if re-entered. + assert self.event_loop is None + + self.event_loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=self._run_event_loop) + + self.barrier.reset() + self.thread.start() + self.barrier.wait() + + return self + + def __exit__(self, *args): + self.event_loop.call_soon_threadsafe(self.event_loop.stop) + self.thread.join() + self.event_loop.close() + + # Give deriving classes a chance to clean themselves up + self._cleanup() + + self.thread = None + self.event_loop = None + + +class API(BaseAPI): + """The main OSBuild API""" + def __init__(self, socket_address, args, monitor): + super().__init__(socket_address) self.input = args self._output_data = io.StringIO() self._output_pipe = None self.monitor = monitor - self.event_loop = asyncio.new_event_loop() - self.thread = threading.Thread(target=self._run_event_loop) - self.barrier = threading.Barrier(2) @property def output(self): @@ -61,23 +127,7 @@ class API: if msg["method"] == 'setup-stdio': self._setup_stdio(server, addr) - def _run_event_loop(self): - with jsoncomm.Socket.new_server(self.socket_address) as server: - self.barrier.wait() - self.event_loop.add_reader(server, self._dispatch, server) - asyncio.set_event_loop(self.event_loop) - self.event_loop.run_forever() - self.event_loop.remove_reader(server) - - def __enter__(self): - self.thread.start() - self.barrier.wait() - return self - - def __exit__(self, *args): - self.event_loop.call_soon_threadsafe(self.event_loop.stop) - self.thread.join() - self.event_loop.close() + def _cleanup(self): if self._output_pipe: os.close(self._output_pipe) self._output_pipe = None diff --git a/test/mod/test_monitor.py b/test/mod/test_monitor.py index 36a2ddd3..f13934ec 100644 --- a/test/mod/test_monitor.py +++ b/test/mod/test_monitor.py @@ -93,7 +93,7 @@ class TestMonitor(unittest.TestCase): p.start() p.join() self.assertEqual(p.exitcode, 0) - output = api.output + output = api.output # pylint: disable=no-member assert output self.assertEqual(json.dumps(args), output)