debian-forge/osbuild/api.py
Christian Kellner 0c7284572e osbuild: auto-generate socket addresses for APIs
Rely on the ability of `BaseAPI` to auto-generate socket addresses
when none was provided. The `BuildRoot` no longer relies on the
sockets being created in the `BuildRoot.api` directory and
will instead bind-mount each individual socket address to the well-known
location via the `BaseAPI.endpoint` identifier.
Convert all API providers to take the `socket_address` as an
optional keyword argument.
2020-07-27 12:50:38 +01:00

174 lines
5.2 KiB
Python

import abc
import asyncio
import codecs
import io
import json
import os
import sys
import tempfile
import threading
from typing import Optional

from .util.types import PathLike
from .util import jsoncomm
# Names exported by `from osbuild.api import *`
__all__ = ["API"]
class BaseAPI(abc.ABC):
    """Base class for all API providers

    This base class provides the basic scaffolding for setting
    up API endpoints, normally to be used for bi-directional
    communication from and to the sandbox. It is to be used as
    a context manager. The communication channel will only be
    established on entering the context and will be shut down
    when the context is left.

    The `_dispatch` method needs to be implemented by child
    classes, and is called for incoming messages.

    Optionally, the `_cleanup` method can be implemented, to
    clean up resources after the context is left and the
    communication channel shut down.

    If no `socket_address` is given, a temporary directory below
    `/run/osbuild` is created when the context is entered and the
    socket is placed inside it, named after `endpoint`. The directory
    is removed, and `socket_address` is reset to `None`, when the
    context is left.
    """

    def __init__(self, socket_address: Optional[PathLike] = None):
        self.socket_address = socket_address
        self.barrier = threading.Barrier(2)
        self.event_loop = None
        self.thread = None
        self._socketdir = None
        # Exception raised by the server thread while creating the
        # socket, handed over to `__enter__`
        self._thread_error = None

    @property
    @classmethod
    @abc.abstractmethod
    def endpoint(cls):
        """The name of the API endpoint"""

    @abc.abstractmethod
    def _dispatch(self, server):
        """Called for incoming messages on the socket"""

    def _cleanup(self):
        """Called after the event loop is shut down"""

    @classmethod
    def _make_socket_dir(cls):
        """Called to create the temporary socket dir"""
        return tempfile.TemporaryDirectory(prefix="api-", dir="/run/osbuild")

    def _release_socket_dir(self):
        """Remove the auto-generated socket directory, if there is one"""
        if self._socketdir:
            self._socketdir.cleanup()
            self._socketdir = None
            # The address pointed into the removed directory; forget it
            # so that the next `__enter__` generates a fresh one.
            self.socket_address = None

    def _run_event_loop(self):
        """Server thread: create the socket and serve until stopped"""
        try:
            sock = jsoncomm.Socket.new_server(self.socket_address)
        except Exception as e:  # pylint: disable=broad-except
            # Hand the error to `__enter__` and break the barrier it is
            # waiting on; without this it would block forever.
            self._thread_error = e
            self.barrier.abort()
            return

        with sock as server:
            self.barrier.wait()
            self.event_loop.add_reader(server, self._dispatch, server)
            asyncio.set_event_loop(self.event_loop)
            self.event_loop.run_forever()
            self.event_loop.remove_reader(server)

    def __enter__(self):
        # We are not re-entrant, so complain if re-entered.
        assert self.event_loop is None

        if not self.socket_address:
            self._socketdir = self._make_socket_dir()
            address = os.path.join(self._socketdir.name, self.endpoint)
            self.socket_address = address

        self._thread_error = None
        self.event_loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self._run_event_loop)

        self.barrier.reset()
        self.thread.start()
        try:
            self.barrier.wait()
        except threading.BrokenBarrierError:
            # The server thread failed to create the socket: undo the
            # setup done above and report the thread's original error.
            self.thread.join()
            self.event_loop.close()
            self.thread = None
            self.event_loop = None
            self._release_socket_dir()
            error, self._thread_error = self._thread_error, None
            if error is None:
                raise
            raise error from None
        return self

    def __exit__(self, *args):
        self.event_loop.call_soon_threadsafe(self.event_loop.stop)
        self.thread.join()
        self.event_loop.close()

        try:
            # Give deriving classes a chance to clean themselves up
            self._cleanup()
        finally:
            # Reset state and remove the socket dir even if the
            # deriving class' cleanup failed
            self.thread = None
            self.event_loop = None
            self._release_socket_dir()
class API(BaseAPI):
    """The main OSBuild API

    Serves the `osbuild` endpoint. On a `setup-stdio` request the
    client is sent three file descriptors: a read-only descriptor for
    a temporary file containing the JSON-encoded `args` (stdin), and
    the write end of a pipe (used for both stdout and stderr).
    Everything written to that pipe is accumulated in `output` and
    forwarded to `monitor.log`.
    """

    endpoint = "osbuild"

    def __init__(self, args, monitor, *, socket_address=None):
        super().__init__(socket_address)
        self.input = args
        self._output_data = io.StringIO()
        self._output_pipe = None
        # Incremental UTF-8 decoder for the pipe, created per pipe
        self._output_decoder = None
        self.monitor = monitor

    @property
    def output(self):
        """All text read from the stdout/stderr pipe so far"""
        return self._output_data.getvalue()

    def _prepare_input(self):
        """Return a read-only file object holding `self.input` as JSON"""
        with tempfile.TemporaryFile() as fd:
            fd.write(json.dumps(self.input).encode('utf-8'))
            # Push the buffered data into the file before a second,
            # independent descriptor for it is handed out.
            fd.flush()
            # re-open the file to get a read-only file descriptor
            return open(f"/proc/self/fd/{fd.fileno()}", "r")

    def _prepare_output(self):
        """Create the output pipe and start watching its read end

        Returns the write end as a file object; the caller passes on its
        descriptor and closes it.
        """
        r, w = os.pipe()
        self._output_pipe = r
        # A UTF-8 sequence can straddle two reads, so decode
        # incrementally; undecodable bytes become U+FFFD.
        decoder_cls = codecs.getincrementaldecoder("utf-8")
        self._output_decoder = decoder_cls(errors="replace")
        self._output_data.truncate(0)
        self._output_data.seek(0)
        self.event_loop.add_reader(r, self._output_ready)
        return os.fdopen(w, "w")

    def _output_ready(self):
        """Event loop callback: move available pipe data into `output`"""
        raw = os.read(self._output_pipe, 4096)
        if raw:
            data = self._output_decoder.decode(raw)
        else:
            # EOF: all writers closed their end. A pipe at EOF stays
            # readable forever, so stop watching it instead of spinning.
            self.event_loop.remove_reader(self._output_pipe)
            data = self._output_decoder.decode(b"", final=True)

        if data:
            self._output_data.write(data)
            self.monitor.log(data)

    def _setup_stdio(self, server, addr):
        """Send stdin/stdout/stderr descriptors to the client at `addr`"""
        with self._prepare_input() as stdin, \
             self._prepare_output() as stdout:
            # stderr shares the stdout pipe; `msg` maps each stream to
            # its index in `fds`
            fds = [stdin.fileno(), stdout.fileno(), stdout.fileno()]
            msg = {"stdin": 0, "stdout": 1, "stderr": 2}
            server.send(msg, fds=fds, destination=addr)

    def _dispatch(self, server):
        """Handle one incoming request on the API socket"""
        msg, _, addr = server.recv()
        if msg["method"] == 'setup-stdio':
            self._setup_stdio(server, addr)

    def _cleanup(self):
        """Close the read end of the output pipe, if one was created"""
        # `is not None`: descriptor 0 is a valid (falsy) fd number
        if self._output_pipe is not None:
            os.close(self._output_pipe)
            self._output_pipe = None
        self._output_decoder = None
def setup_stdio(path="/run/osbuild/api/osbuild"):
    """Replace standard i/o with the ones provided by the API

    Connects to the API socket at `path`, sends a `setup-stdio`
    request and duplicates the received file descriptors onto this
    process' stdin, stdout and stderr. The received descriptors are
    closed afterwards, even if duplicating one of them fails.
    """
    with jsoncomm.Socket.new_client(path) as client:
        req = {"method": "setup-stdio"}
        client.send(req)
        msg, fds, _ = client.recv()
        try:
            for sio in ["stdin", "stdout", "stderr"]:
                target = getattr(sys, sio)
                source = fds[msg[sio]]
                os.dup2(source, target.fileno())
        finally:
            fds.close()