api: extract base class

Split out the part of `api.API` that is responsible for providing
the server infrastructure for the API; i.e. setting up the server
and the corresponding context manager and asynchronous event
handling. This leaves `API` itself which just the implementation
of the high level protocol and makes the API-server part re-usable.

NB: pylint, for some reason, confuses `API` and `BaseAPI`, like in
`test_monitor`. Annotate that accordingly.
This commit is contained in:
Christian Kellner 2020-07-22 15:24:33 +02:00 committed by Tom Gundersen
parent 2423bf12f0
commit 1ce517e595
2 changed files with 73 additions and 23 deletions

View file

@ -1,3 +1,4 @@
import abc
import asyncio
import io
import json
@ -7,16 +8,81 @@ import threading
from .util import jsoncomm
class API:
def __init__(self, socket_address, args, monitor):
__all__ = [
"API"
]
class BaseAPI(abc.ABC):
"""Base class for all API providers
This base class provides the basic scaffolding for setting
up API endpoints, normally to be used for bi-directional
communication from and to the sandbox. It is to be used as
a context manger. The communication channel will only be
established on entering the context and will be shut down
when the context is left.
The `_dispatch` method needs to be implemented by child
classes, and is called for incoming messages.
Optionally, the `_cleanup` method can be implemented, to
clean up resources after the context is left and the
communication channel shut down.
"""
def __init__(self, socket_address):
self.socket_address = socket_address
self.barrier = threading.Barrier(2)
self.event_loop = None
self.thread = None
@abc.abstractmethod
def _dispatch(self, server):
"""Called for incoming messages on the socket"""
def _cleanup(self):
"""Called after the event loop is shut down"""
def _run_event_loop(self):
with jsoncomm.Socket.new_server(self.socket_address) as server:
self.barrier.wait()
self.event_loop.add_reader(server, self._dispatch, server)
asyncio.set_event_loop(self.event_loop)
self.event_loop.run_forever()
self.event_loop.remove_reader(server)
def __enter__(self):
# We are not re-entrant, so complain if re-entered.
assert self.event_loop is None
self.event_loop = asyncio.new_event_loop()
self.thread = threading.Thread(target=self._run_event_loop)
self.barrier.reset()
self.thread.start()
self.barrier.wait()
return self
def __exit__(self, *args):
self.event_loop.call_soon_threadsafe(self.event_loop.stop)
self.thread.join()
self.event_loop.close()
# Give deriving classes a chance to clean themselves up
self._cleanup()
self.thread = None
self.event_loop = None
class API(BaseAPI):
"""The main OSBuild API"""
def __init__(self, socket_address, args, monitor):
super().__init__(socket_address)
self.input = args
self._output_data = io.StringIO()
self._output_pipe = None
self.monitor = monitor
self.event_loop = asyncio.new_event_loop()
self.thread = threading.Thread(target=self._run_event_loop)
self.barrier = threading.Barrier(2)
@property
def output(self):
@ -61,23 +127,7 @@ class API:
if msg["method"] == 'setup-stdio':
self._setup_stdio(server, addr)
def _run_event_loop(self):
with jsoncomm.Socket.new_server(self.socket_address) as server:
self.barrier.wait()
self.event_loop.add_reader(server, self._dispatch, server)
asyncio.set_event_loop(self.event_loop)
self.event_loop.run_forever()
self.event_loop.remove_reader(server)
def __enter__(self):
self.thread.start()
self.barrier.wait()
return self
def __exit__(self, *args):
self.event_loop.call_soon_threadsafe(self.event_loop.stop)
self.thread.join()
self.event_loop.close()
def _cleanup(self):
if self._output_pipe:
os.close(self._output_pipe)
self._output_pipe = None

View file

@ -93,7 +93,7 @@ class TestMonitor(unittest.TestCase):
p.start()
p.join()
self.assertEqual(p.exitcode, 0)
output = api.output
output = api.output # pylint: disable=no-member
assert output
self.assertEqual(json.dumps(args), output)