api: extract base class
Split out the part of `api.API` that is responsible for providing the server infrastructure for the API; i.e. setting up the server and the corresponding context manager and asynchronous event handling. This leaves `API` itself which just the implementation of the high level protocol and makes the API-server part re-usable. NB: pylint, for some reason, confuses `API` and `BaseAPI`, like in `test_monitor`. Annotate that accordingly.
This commit is contained in:
parent
2423bf12f0
commit
1ce517e595
2 changed files with 73 additions and 23 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import abc
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
|
|
@ -7,16 +8,81 @@ import threading
|
|||
from .util import jsoncomm
|
||||
|
||||
|
||||
class API:
|
||||
def __init__(self, socket_address, args, monitor):
|
||||
__all__ = [
|
||||
"API"
|
||||
]
|
||||
|
||||
|
||||
class BaseAPI(abc.ABC):
|
||||
"""Base class for all API providers
|
||||
|
||||
This base class provides the basic scaffolding for setting
|
||||
up API endpoints, normally to be used for bi-directional
|
||||
communication from and to the sandbox. It is to be used as
|
||||
a context manger. The communication channel will only be
|
||||
established on entering the context and will be shut down
|
||||
when the context is left.
|
||||
|
||||
The `_dispatch` method needs to be implemented by child
|
||||
classes, and is called for incoming messages.
|
||||
Optionally, the `_cleanup` method can be implemented, to
|
||||
clean up resources after the context is left and the
|
||||
communication channel shut down.
|
||||
"""
|
||||
def __init__(self, socket_address):
|
||||
self.socket_address = socket_address
|
||||
self.barrier = threading.Barrier(2)
|
||||
self.event_loop = None
|
||||
self.thread = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _dispatch(self, server):
|
||||
"""Called for incoming messages on the socket"""
|
||||
|
||||
def _cleanup(self):
|
||||
"""Called after the event loop is shut down"""
|
||||
|
||||
def _run_event_loop(self):
|
||||
with jsoncomm.Socket.new_server(self.socket_address) as server:
|
||||
self.barrier.wait()
|
||||
self.event_loop.add_reader(server, self._dispatch, server)
|
||||
asyncio.set_event_loop(self.event_loop)
|
||||
self.event_loop.run_forever()
|
||||
self.event_loop.remove_reader(server)
|
||||
|
||||
def __enter__(self):
|
||||
# We are not re-entrant, so complain if re-entered.
|
||||
assert self.event_loop is None
|
||||
|
||||
self.event_loop = asyncio.new_event_loop()
|
||||
self.thread = threading.Thread(target=self._run_event_loop)
|
||||
|
||||
self.barrier.reset()
|
||||
self.thread.start()
|
||||
self.barrier.wait()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.event_loop.call_soon_threadsafe(self.event_loop.stop)
|
||||
self.thread.join()
|
||||
self.event_loop.close()
|
||||
|
||||
# Give deriving classes a chance to clean themselves up
|
||||
self._cleanup()
|
||||
|
||||
self.thread = None
|
||||
self.event_loop = None
|
||||
|
||||
|
||||
class API(BaseAPI):
|
||||
"""The main OSBuild API"""
|
||||
def __init__(self, socket_address, args, monitor):
|
||||
super().__init__(socket_address)
|
||||
self.input = args
|
||||
self._output_data = io.StringIO()
|
||||
self._output_pipe = None
|
||||
self.monitor = monitor
|
||||
self.event_loop = asyncio.new_event_loop()
|
||||
self.thread = threading.Thread(target=self._run_event_loop)
|
||||
self.barrier = threading.Barrier(2)
|
||||
|
||||
@property
|
||||
def output(self):
|
||||
|
|
@ -61,23 +127,7 @@ class API:
|
|||
if msg["method"] == 'setup-stdio':
|
||||
self._setup_stdio(server, addr)
|
||||
|
||||
def _run_event_loop(self):
|
||||
with jsoncomm.Socket.new_server(self.socket_address) as server:
|
||||
self.barrier.wait()
|
||||
self.event_loop.add_reader(server, self._dispatch, server)
|
||||
asyncio.set_event_loop(self.event_loop)
|
||||
self.event_loop.run_forever()
|
||||
self.event_loop.remove_reader(server)
|
||||
|
||||
def __enter__(self):
|
||||
self.thread.start()
|
||||
self.barrier.wait()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.event_loop.call_soon_threadsafe(self.event_loop.stop)
|
||||
self.thread.join()
|
||||
self.event_loop.close()
|
||||
def _cleanup(self):
|
||||
if self._output_pipe:
|
||||
os.close(self._output_pipe)
|
||||
self._output_pipe = None
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ class TestMonitor(unittest.TestCase):
|
|||
p.start()
|
||||
p.join()
|
||||
self.assertEqual(p.exitcode, 0)
|
||||
output = api.output
|
||||
output = api.output # pylint: disable=no-member
|
||||
assert output
|
||||
|
||||
self.assertEqual(json.dumps(args), output)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue