debian-forge/test/mod/test_host.py
Christian Kellner 662fe0feb9 test/host: checks for invalid fd handling
Check that we properly handle `dispatch` returning invalid fds and
that all fds are properly closed.
2021-09-24 12:14:04 +01:00

144 lines
3.7 KiB
Python
Executable file

#!/usr/bin/python3
#
# Runtime Tests for Host Services
#
import errno
import os
import sys
import tempfile
from typing import Any
import pytest
from osbuild import host
from osbuild.util.jsoncomm import FdSet
class ServiceTest(host.Service):
def __init__(self, args):
super().__init__(args)
self.fds = []
def register_fds(self, fds):
self.fds.extend(fds)
def dispatch(self, method: str, args: Any, fds: FdSet):
ret = None
if method == "exception":
raise ValueError("Remote Exception")
if method == "echo":
ret = args
elif method == "echo-fd":
ret = args
with tempfile.TemporaryFile("w+") as f:
with os.fdopen(fds.steal(0)) as d:
f.write(d.read())
f.seek(0)
fds = [os.dup(f.fileno())]
self.register_fds(fds)
elif method == "identify":
ret = self.id
elif method == "invalid-fd":
ret = []
with tempfile.TemporaryFile("w+") as f:
valid_fd = os.dup(f.fileno())
invalid_fd = valid_fd + 10
fds = [valid_fd, invalid_fd]
self.register_fds([valid_fd])
elif method == "check-fds-are-closed":
while self.fds:
fd = self.fds.pop()
try:
os.close(fd)
except OSError as e:
if e.errno == errno.EBADF:
print(f"fd '{fd}' was closed")
continue
raise
raise ValueError(f"fd '{fd}' was not closed")
else:
raise host.ProtocolError("unknown method:", method)
return ret, fds
def test_basic():
with host.ServiceManager() as mgr:
for i in range(3):
client = mgr.start(str(i), __file__)
args = ["an", "argument"]
res = client.call("echo", args)
assert args == res
remote_id = client.call("identify")
assert remote_id == str(i)
with pytest.raises(ValueError, match=f"{str(i)}"):
_ = mgr.start(str(i), __file__)
for i in range(3):
client = mgr.services[str(i)]
client.stop()
def test_pass_fd():
with host.ServiceManager() as mgr:
for i in range(3):
client = mgr.start(str(i), __file__)
args = ["an", "argument"]
data = "osbuild\n"
with tempfile.TemporaryFile("w+") as f:
f.write(data)
f.seek(0)
res, fds = client.call_with_fds("echo-fd", args, fds=[f.fileno()])
assert args == res
with os.fdopen(fds.steal(0)) as d:
assert data == d.read()
client.call_with_fds("check-fds-are-closed")
remote_id = client.call("identify")
assert remote_id == str(i)
with pytest.raises(ValueError, match=f"{str(i)}"):
_ = mgr.start(str(i), __file__)
for i in range(3):
client = mgr.services[str(i)]
client.stop()
def test_pass_fd_invalid():
with host.ServiceManager() as mgr:
client = mgr.start(str("test-invalid-fd"), __file__)
with pytest.raises(host.RemoteError):
client.call_with_fds("invalid-fd")
client.call_with_fds("check-fds-are-closed")
def test_exception():
with host.ServiceManager() as mgr:
client = mgr.start("exception", __file__)
with pytest.raises(host.RemoteError, match=r"Remote Exception"):
client.call("exception")
def main():
service = ServiceTest.from_args(sys.argv[1:])
service.main()
if __name__ == "__main__":
main()