osbuild: fix optional-types

Optional types were provided in places but were not always correct. Add
mypy checking and fix those that fail(ed).
This commit is contained in:
Simon de Vlieger 2022-07-06 10:54:37 +02:00 committed by Christian Kellner
parent 6e66c69608
commit 3fd864e5a9
29 changed files with 209 additions and 111 deletions

View file

@ -75,7 +75,8 @@ class BaseAPI(abc.ABC):
msg, fds, _ = sock.recv()
if msg is None:
# Peer closed the connection
self.event_loop.remove_reader(sock)
if self.event_loop:
self.event_loop.remove_reader(sock)
return
self._message(msg, fds, sock)
fds.close()

View file

@ -16,8 +16,9 @@ import subprocess
import tempfile
import time
from typing import Optional
from typing import Optional, Set
from osbuild.api import BaseAPI
from osbuild.util import linux
@ -57,7 +58,7 @@ class ProcOverrides:
def __init__(self, path) -> None:
self.path = path
self.overrides = set()
self.overrides: Set["str"] = set()
@property
def cmdline(self) -> str:
@ -167,7 +168,7 @@ class BuildRoot(contextlib.AbstractContextManager):
self._exitstack.close()
self._exitstack = None
def register_api(self, api: "BaseAPI"):
def register_api(self, api: BaseAPI):
"""Register an API endpoint.
The context of the API endpoint will be bound to the context of

View file

@ -17,7 +17,7 @@ import hashlib
import json
import os
from typing import Dict, Optional
from typing import Dict, Optional, Any
from osbuild import host
@ -52,11 +52,11 @@ class DeviceManager:
Uses a `host.ServiceManager` to open `Device` instances.
"""
def __init__(self, mgr: host.ServiceManager, devpath: str, tree: str) -> Dict:
def __init__(self, mgr: host.ServiceManager, devpath: str, tree: str) -> None:
self.service_manager = mgr
self.devpath = devpath
self.tree = tree
self.devices = {}
self.devices: Dict[str, Dict[str, Any]] = {}
def device_relpath(self, dev: Optional[Device]) -> Optional[str]:
if dev is None:

View file

@ -7,7 +7,7 @@ the created tree into an artefact. The pipeline can have any
number of nested build pipelines. A sources section is used
to fetch resources.
"""
from typing import Dict
from typing import Dict, Any
from osbuild.meta import Index, ValidationResult
from ..pipeline import BuildResult, Manifest, Pipeline, detect_host_runner
@ -15,9 +15,9 @@ from ..pipeline import BuildResult, Manifest, Pipeline, detect_host_runner
VERSION = "1"
def describe(manifest: Manifest, *, with_id=False) -> Dict:
def describe(manifest: Manifest, *, with_id=False) -> Dict[str, Any]:
"""Create the manifest description for the pipeline"""
def describe_stage(stage):
def describe_stage(stage) -> Dict[str, Any]:
description = {"name": stage.name}
if stage.options:
description["options"] = stage.options
@ -25,8 +25,8 @@ def describe(manifest: Manifest, *, with_id=False) -> Dict:
description["id"] = stage.id
return description
def describe_pipeline(pipeline: Pipeline) -> Dict:
description = {}
def describe_pipeline(pipeline: Pipeline) -> Dict[str, Any]:
description: Dict[str, Any] = {}
if pipeline.build:
build = manifest[pipeline.build]
description["build"] = {

View file

@ -2,7 +2,7 @@
Second, and current, version of the manifest description
"""
from typing import Dict
from typing import Dict, Any
from osbuild.meta import Index, ModuleInfo, ValidationResult
from ..inputs import Input
from ..pipeline import Manifest, Pipeline, Stage, detect_host_runner
@ -120,7 +120,7 @@ def describe(manifest: Manifest, *, with_id=False) -> Dict:
return desc
def describe_pipeline(p: Pipeline):
desc = {
desc: Dict[str, Any] = {
"name": p.name
}
@ -158,7 +158,7 @@ def describe(manifest: Manifest, *, with_id=False) -> Dict:
for source in manifest.sources
}
description = {
description: Dict[str, Any] = {
"version": VERSION,
"pipelines": pipelines
}
@ -384,6 +384,8 @@ def load(description: Dict, index: Index) -> Manifest:
def output(manifest: Manifest, res: Dict) -> Dict:
"""Convert a result into the v2 format"""
result: Dict[str, Any] = {}
if not res["success"]:
last = list(res.keys())[-1]
failed = res[last]["stages"][-1]
@ -412,7 +414,7 @@ def output(manifest: Manifest, res: Dict) -> Dict:
# gather all the metadata
for p in manifest.pipelines.values():
data = {}
data: Dict[str, Any] = {}
r = res.get(p.id, {})
for stage in r.get("stages", []):
md = stage.metadata

View file

@ -48,7 +48,7 @@ import sys
import threading
import traceback
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Callable
from typing import Any, Dict, List, Optional, Tuple, Callable, Iterable, Union
from osbuild.util.jsoncomm import FdSet, Socket
@ -92,7 +92,7 @@ class ServiceProtocol:
return t, d
@staticmethod
def encode_method(name: str, arguments: List):
def encode_method(name: str, arguments: Union[List[str], Dict[str, Any]]):
msg = {
"type": "method",
"data": {
@ -349,14 +349,21 @@ class ServiceClient:
return ret
def call_with_fds(self, method: str,
args: Optional[Any] = None,
fds: Optional[List] = None,
on_signal: Callable[[Any, FdSet], None] = None) -> Tuple[Any, FdSet]:
args: Optional[Union[List[str], Dict[str, Any]]] = None,
fds: Optional[List[int]] = None,
on_signal: Callable[[Any, Optional[Iterable[int]]], None] = None
) -> Tuple[Any, Optional[Iterable[int]]]:
"""
Remotely call a method and return the result, including file
descriptors.
"""
if args is None:
args = []
if fds is None:
fds = []
msg = self.protocol.encode_method(method, args)
self.sock.send(msg, fds=fds)
@ -366,7 +373,9 @@ class ServiceClient:
kind, data = self.protocol.decode_message(ret)
if kind == "signal":
ret = self.protocol.decode_reply(data)
on_signal(ret, fds)
if on_signal:
on_signal(ret, fds)
if kind == "reply":
ret = self.protocol.decode_reply(data)
return ret, fds
@ -471,6 +480,9 @@ class ServiceManager:
self.services[uid] = service
ours = None
if proc.stdout is None:
raise RuntimeError("No stdout.")
stdout = io.TextIOWrapper(proc.stdout,
encoding="utf-8",
line_buffering=True)

View file

@ -21,7 +21,7 @@ import hashlib
import json
import os
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Any
from osbuild import host
from osbuild.util.types import PathLike
@ -37,7 +37,7 @@ class Input:
self.name = name
self.info = info
self.origin = origin
self.refs = {}
self.refs: Dict[str, Dict[str, Any]] = {}
self.options = options or {}
self.id = self.calc_id()
@ -60,11 +60,11 @@ class Input:
class InputManager:
def __init__(self, mgr: host.ServiceManager, storeapi: StoreServer, root: PathLike) -> Dict:
def __init__(self, mgr: host.ServiceManager, storeapi: StoreServer, root: PathLike) -> None:
self.service_manager = mgr
self.storeapi = storeapi
self.root = root
self.inputs = {}
self.inputs: Dict[str, Input] = {}
def map(self, ip: Input) -> Tuple[str, Dict]:

View file

@ -29,7 +29,7 @@ import pkgutil
import json
import sys
from collections import deque
from typing import Dict, Iterable, List, Optional
from typing import Dict, Sequence, List, Optional, Union, Set, Deque, Any, Tuple
import jsonschema
@ -50,7 +50,7 @@ class ValidationError:
def __init__(self, message: str):
self.message = message
self.path = deque()
self.path: Deque[Union[int, str]] = deque()
@classmethod
def from_exception(cls, ex):
@ -88,7 +88,7 @@ class ValidationError:
"path": list(self.path)
}
def rebase(self, path: Iterable[str]):
def rebase(self, path: Sequence[str]):
"""Prepend the `path` to `self.path`"""
rev = reversed(path)
self.path.extendleft(rev)
@ -96,7 +96,7 @@ class ValidationError:
def __hash__(self):
return hash((self.id, self.message))
def __eq__(self, other: "ValidationError"):
def __eq__(self, other: object):
if not isinstance(other, ValidationError):
raise ValueError("Need ValidationError")
@ -119,7 +119,7 @@ class ValidationResult:
def __init__(self, origin: Optional[str]):
self.origin = origin
self.errors = set()
self.errors: Set[ValidationError] = set()
def fail(self, msg: str) -> ValidationError:
"""Add a new `ValidationError` with `msg` as message"""
@ -218,7 +218,7 @@ class Schema:
def __init__(self, schema: str, name: Optional[str] = None):
self.data = schema
self.name = name
self._validator = None
self._validator: Optional[jsonschema.Draft4Validator] = None
def check(self) -> ValidationResult:
"""Validate the `schema` data itself"""
@ -258,9 +258,13 @@ class Schema:
with 'missing schema information' as the reason.
"""
res = self.check()
if not res:
return res
if not self._validator:
raise RuntimeError("Trying to validate without validator.")
for error in self._validator.iter_errors(target):
res += ValidationError.from_exception(error)
@ -426,7 +430,7 @@ class ModuleInfo:
tree = ast.parse(data, name)
docstring = ast.get_docstring(tree)
doclist = docstring.split("\n")
doclist = docstring.split("\n") if docstring else []
assigns = filter_type(tree.body, ast.Assign)
values = {
@ -489,15 +493,19 @@ class Index:
def __init__(self, path: str):
self.path = path
self._module_info = {}
self._format_info = {}
self._schemata = {}
self._module_info: Dict[Tuple[str, Any], Any] = {}
self._format_info: Dict[Tuple[str, Any], Any] = {}
self._schemata: Dict[Tuple[str, Any, str], Schema] = {}
@staticmethod
def list_formats() -> List[str]:
"""List all known formats for manifest descriptions"""
base = "osbuild.formats"
spec = importlib.util.find_spec(base)
if not spec:
raise RuntimeError(f"Could not find spec for {base!r}")
locations = spec.submodule_search_locations
modinfo = [
mod for mod in pkgutil.walk_packages(locations)
@ -555,9 +563,10 @@ class Index:
that case the actual schema data for `Schema` will be
`None` and any validation will fail.
"""
schema = self._schemata.get((klass, name, version))
if schema is not None:
return schema
cached_schema: Optional[Schema] = self._schemata.get((klass, name, version))
if cached_schema is not None:
return cached_schema
if klass == "Manifest":
path = f"{self.path}/schemas/osbuild{version}.json"

View file

@ -13,7 +13,7 @@ import json
import os
import subprocess
from typing import Dict
from typing import Dict, List
from osbuild import host
from osbuild.devices import DeviceManager
@ -55,7 +55,7 @@ class MountManager:
def __init__(self, devices: DeviceManager, root: str) -> None:
self.devices = devices
self.root = root
self.mounts = {}
self.mounts: Dict[str, Dict[str, Mount]] = {}
def mount(self, mount: Mount) -> Dict:
@ -77,7 +77,7 @@ class MountManager:
path = client.call("mount", args)
if not path:
res = {}
res: Dict[str, Mount] = {}
self.mounts[mount.name] = res
return res
@ -123,7 +123,7 @@ class FileSystemMountService(MountService):
self.check = False
@abc.abstractmethod
def translate_options(self, options: Dict):
def translate_options(self, options: Dict) -> List:
return []
def mount(self, args: Dict):
@ -134,14 +134,15 @@ class FileSystemMountService(MountService):
options = args["options"]
mountpoint = os.path.join(root, target.lstrip("/"))
args = self.translate_options(options)
options = self.translate_options(options)
os.makedirs(mountpoint, exist_ok=True)
self.mountpoint = mountpoint
subprocess.run(
["mount"] +
args + [
options + [
"--source", source,
"--target", mountpoint
],

View file

@ -3,7 +3,7 @@ import os
import subprocess
import tempfile
import uuid
from typing import Optional
from typing import Optional, Iterator, Set
from osbuild.util.types import PathLike
from osbuild.util import jsoncomm, rmrf
@ -55,10 +55,10 @@ class Object:
self._init = True
self._readers = 0
self._writer = False
self._base = None
self._base: Optional[str] = None
self._workdir = None
self._tree = None
self.id = None
self.id: Optional[str] = None
self.store = store
self.reset()
@ -85,7 +85,7 @@ class Object:
self.id = base_id
@property
def _path(self) -> str:
def _path(self) -> Optional[str]:
if self._base and not self._init:
path = self.store.resolve_ref(self._base)
else:
@ -93,7 +93,7 @@ class Object:
return path
@contextlib.contextmanager
def write(self) -> str:
def write(self) -> Iterator[str]:
"""Return a path that can be written to"""
self._check_writable()
self._check_readers()
@ -110,13 +110,13 @@ class Object:
self._writer = False
@contextlib.contextmanager
def read(self) -> str:
def read(self) -> Iterator[PathLike]:
with self.tempdir("reader") as target:
with self.read_at(target) as path:
yield path
@contextlib.contextmanager
def read_at(self, target: PathLike, path: str = "/") -> str:
def read_at(self, target: PathLike, path: str = "/") -> Iterator[PathLike]:
"""Read the object or a part of it at given location
Map the tree or a part of it specified via `path` at the
@ -125,6 +125,9 @@ class Object:
self._check_writable()
self._check_writer()
if self._path is None:
raise RuntimeError("read_at with no path.")
path = os.path.join(self._path, path.lstrip("/"))
mount(path, target)
@ -260,7 +263,7 @@ class ObjectStore(contextlib.AbstractContextManager):
os.makedirs(self.objects, exist_ok=True)
os.makedirs(self.refs, exist_ok=True)
os.makedirs(self.tmp, exist_ok=True)
self._objs = set()
self._objs: Set[Object] = set()
def _get_floating(self, object_id: str) -> Optional[Object]:
"""Internal: get a non-committed object"""
@ -349,7 +352,13 @@ class ObjectStore(contextlib.AbstractContextManager):
with self.tempdir() as tmp:
link = f"{tmp}/link"
os.symlink(f"../objects/{object_name}", link)
os.replace(link, self.resolve_ref(object_id))
ref = self.resolve_ref(object_id)
if not ref:
raise RuntimeError("commit with unresolvable ref")
os.replace(link, ref)
# the reference that is pointing to `object_name` is now the base
# of `obj`. It is not actively initialized but any subsequent calls

View file

@ -15,6 +15,7 @@ from .inputs import Input, InputManager
from .mounts import Mount, MountManager
from .sources import Source
from .util import osrelease
from .objectstore import ObjectStore
DEFAULT_CAPABILITIES = {
@ -244,7 +245,7 @@ class Pipeline:
self.name = name
self.build = build
self.runner = runner
self.stages = []
self.stages: List[Stage] = []
self.assembler = None
self.source_epoch = source_epoch
@ -370,7 +371,9 @@ class Manifest:
self.pipelines = collections.OrderedDict()
self.sources: List[Source] = []
def add_pipeline(self, name: str, runner: str, build: str, source_epoch: Optional[int] = None) -> Pipeline:
def add_pipeline(
self, name: str, runner: Optional[str], build: Optional[str] = None, source_epoch: Optional[int] = None
) -> Pipeline:
pipeline = Pipeline(name, runner, build, source_epoch)
if name in self.pipelines:
raise ValueError(f"Name {name} already exists")
@ -387,7 +390,7 @@ class Manifest:
for source in self.sources:
source.download(mgr, store, libdir)
def depsolve(self, store, targets: Iterable[str]) -> List[str]:
def depsolve(self, store: ObjectStore, targets: Iterable[str]) -> List[str]:
"""Return the list of pipelines that need to be built
Given a list of target pipelines, return the names
@ -404,6 +407,9 @@ class Manifest:
while check:
pl = check.pop() # get the last(!) item
if not pl:
raise RuntimeError("Could not find pipeline.")
if store.contains(pl.id):
continue

View file

@ -75,10 +75,11 @@ class SourceService(host.Service):
return checksum, desc
def download(self, items: Dict) -> None:
items = filter(lambda i: not self.exists(i[0], i[1]), items.items()) # discards items already in cache
items = map(lambda i: self.transform(i[0], i[1]), items) # prepare each item to be downloaded
filtered = filter(lambda i: not self.exists(i[0], i[1]), items.items()) # discards items already in cache
transformed = map(lambda i: self.transform(i[0], i[1]), filtered) # prepare each item to be downloaded
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
for _ in executor.map(self.fetch_one, *zip(*items)):
for _ in executor.map(self.fetch_one, *zip(*transformed)):
pass
@property
@ -102,6 +103,7 @@ class SourceService(host.Service):
if method == "download":
self.setup(args)
with tempfile.TemporaryDirectory(prefix=".unverified-", dir=self.cache) as self.tmpdir:
return self.download(SourceService.load_items(fds)), None
self.download(SourceService.load_items(fds))
return None, None
raise host.ProtocolError("Unknown method")

View file

@ -127,13 +127,20 @@ class Socket(contextlib.AbstractContextManager):
@blocking.setter
def blocking(self, value: bool):
"""Set the blocking mode of the socket."""
self._socket.setblocking(value)
if self._socket:
self._socket.setblocking(value)
else:
raise RuntimeError("Tried to set blocking mode without socket.")
def accept(self) -> Optional["Socket"]:
"""Accept a new connection on the socket.
See python's `socket.accept` for more information.
"""
if not self._socket:
raise RuntimeError("Tried to accept without socket.")
# Since, in the kernel, for AF_UNIX, new connection requests,
# i.e. clients connecting, are directly put on the receive
# queue of the listener socket, accept here *should* always
@ -151,6 +158,9 @@ class Socket(contextlib.AbstractContextManager):
See python's `socket.listen` for details.
"""
if not self._socket:
raise RuntimeError("Tried to listen without socket.")
# `Socket.listen` accepts an `int` or no argument, but not `None`
args = [backlog] if backlog is not None else []
self._socket.listen(*args)
@ -386,6 +396,9 @@ class Socket(contextlib.AbstractContextManager):
If the payload cannot be serialized, a type error is raised.
"""
if not self._socket:
raise RuntimeError("Tried to send without socket.")
serialized = json.dumps(payload).encode()
cmsg = []
if fds:

View file

@ -141,25 +141,25 @@ class LibCap:
get_bound = lib.cap_get_bound
get_bound.argtypes = (self.cap_value_t,)
get_bound.restype = ctypes.c_int
get_bound.errcheck = self._check_result
get_bound.errcheck = self._check_result # type: ignore
self._get_bound = get_bound
from_name = lib.cap_from_name
from_name.argtypes = (ctypes.c_char_p, ctypes.POINTER(self.cap_value_t),)
from_name.restype = ctypes.c_int
from_name.errcheck = self._check_result
from_name.errcheck = self._check_result # type: ignore
self._from_name = from_name
to_name = lib.cap_to_name
to_name.argtypes = (ctypes.c_int,)
to_name.restype = ctypes.POINTER(ctypes.c_char)
to_name.errcheck = self._check_result
to_name.errcheck = self._check_result # type: ignore
self._to_name = to_name
free = lib.cap_free
free.argtypes = (ctypes.c_void_p,)
free.restype = ctypes.c_int
free.errcheck = self._check_result
free.errcheck = self._check_result # type: ignore
self._free = free
@staticmethod
@ -210,6 +210,10 @@ class LibCap:
"""Translate from the capability's integer value to the its symbolic name"""
raw = self._to_name(value)
val = ctypes.cast(raw, ctypes.c_char_p).value
if val is None:
raise RuntimeError("Failed to cast.")
res = str(val, encoding="utf-8")
self._free(raw)
return res.upper()

View file

@ -16,6 +16,8 @@ import shlex
import shutil
import subprocess
from typing import Dict, Any
import mako.template
@ -44,7 +46,7 @@ def rglob(pathname, *, fatal=False):
class Script:
# all built-in commands in a name to method map
commands = {}
commands: Dict[str, Any] = {}
# helper decorator to register builtin methods
class command:

View file

@ -28,7 +28,7 @@ import struct
import sys
from collections import OrderedDict
from typing import BinaryIO, Dict, Union
from typing import BinaryIO, Dict, Union, List
PathLike = Union[str, bytes, os.PathLike]
@ -118,7 +118,7 @@ class Header:
@property
@classmethod
@abc.abstractmethod
def struct(cls) -> struct.Struct:
def struct(cls) -> Union[struct.Struct, CStruct]:
"""Definition of the underlying struct data"""
def __init__(self, data):
@ -146,6 +146,10 @@ class Header:
def __str__(self) -> str:
msg = f"{self.__class__.__name__}:"
if not isinstance(self.struct, CStruct):
raise RuntimeError("No field support on Struct")
for f in self.struct.fields:
msg += f"\n\t{f.name}: {self[f.name]}"
return msg
@ -402,8 +406,7 @@ class Metadata:
@classmethod
def decode(cls, data: bytes) -> "Metadata":
data = data.decode("utf-8")
name, md = Metadata.decode_data(data)
name, md = Metadata.decode_data(data.decode("utf8"))
return cls(name, md)
def encode(self) -> bytes:
@ -535,8 +538,7 @@ class Disk:
self.lbl_hdr = None
self.pv_hdr = None
self.ma_headers = []
self.metadata = None
self.ma_headers: List[MDAHeader] = []
try:
self._init_headers()
@ -568,7 +570,7 @@ class Disk:
self.metadata = md
@classmethod
def open(cls, path: PathLike, *, read_only=False) -> None:
def open(cls, path: PathLike, *, read_only: bool = False) -> "Disk":
mode = "rb"
if not read_only:
mode += "+"

View file

@ -7,7 +7,7 @@ import sys
import tempfile
import typing
from typing import List
from typing import List, Any
from .types import PathLike
@ -116,6 +116,9 @@ def rev_parse(repo: PathLike, ref: str) -> str:
repo = os.fspath(repo)
if isinstance(repo, bytes):
repo = repo.decode("utf8")
r = subprocess.run(["ostree", "rev-parse", ref, f"--repo={repo}"],
encoding="utf-8",
stdout=subprocess.PIPE,
@ -134,6 +137,9 @@ def show(repo: PathLike, checksum: str) -> str:
repo = os.fspath(repo)
if isinstance(repo, bytes):
repo = repo.decode("utf8")
r = subprocess.run(["ostree", "show", f"--repo={repo}", checksum],
encoding="utf-8",
stdout=subprocess.PIPE,
@ -216,7 +222,7 @@ class SubIdsDB:
"""
def __init__(self) -> None:
self.db = collections.OrderedDict()
self.db: collections.OrderedDict[str, Any] = collections.OrderedDict()
def read(self, fp) -> int:
idx = 0

View file

@ -1,10 +1,7 @@
"""Path handling utility functions"""
import os.path
import os
from .types import PathLike
def in_tree(path: PathLike, tree: PathLike, must_exist=False) -> bool:
def in_tree(path: str, tree: str, must_exist: bool = False) -> bool:
"""Return whether the canonical location of 'path' is under 'tree'.
If 'must_exist' is True, the file must also exist for the check to succeed.
"""

View file

@ -2,10 +2,5 @@
# Define some useful typing abbreviations
#
import os
from typing import Union
#: Represents a file system path. See also `os.fspath`.
PathLike = Union[str, bytes, os.PathLike]
PathLike = str