osbuild: Add validation for source references

Validate source references while loading manifests so that a bad
reference would result in a meaningful error message instead of a
hard-to-understand Python exception.
This commit is contained in:
Diaa Sami 2021-07-16 17:07:29 +02:00 committed by Christian Kellner
parent 9e599fca17
commit 02ceb02d2a
2 changed files with 57 additions and 6 deletions

View file

@ -184,7 +184,7 @@ def load_device(name: str, description: Dict, index: Index, stage: Stage):
stage.add_device(name, info, options)
def load_input(name: str, description: Dict, index: Index, stage: Stage, manifest: Manifest):
def load_input(name: str, description: Dict, index: Index, stage: Stage, manifest: Manifest, source_refs: set):
input_type = description["type"]
origin = description["origin"]
options = description.get("options", {})
@ -205,6 +205,10 @@ def load_input(name: str, description: Dict, index: Index, stage: Stage, manifes
target = resolve_ref(r, manifest)
resolved[target] = desc
refs = resolved
elif origin == "org.osbuild.source":
unknown_refs = set(refs.keys()) - source_refs
if unknown_refs:
raise ValueError(f"Unknown source reference(s) {unknown_refs}")
for r, desc in refs.items():
ip.add_reference(r, desc)
@ -226,7 +230,7 @@ def load_mount(name: str, description: Dict, index: Index, stage: Stage):
stage.add_mount(name, info, device, target, options)
def load_stage(description: Dict, index: Index, pipeline: Pipeline, manifest: Manifest):
def load_stage(description: Dict, index: Index, pipeline: Pipeline, manifest: Manifest, source_refs):
stage_type = description["type"]
opts = description.get("options", {})
info = index.get_module_info("Stage", stage_type)
@ -239,7 +243,7 @@ def load_stage(description: Dict, index: Index, pipeline: Pipeline, manifest: Ma
ips = description.get("inputs", {})
for name, desc in ips.items():
load_input(name, desc, index, stage, manifest)
load_input(name, desc, index, stage, manifest, source_refs)
mounts = description.get("mounts", {})
for name, desc in mounts.items():
@ -248,7 +252,7 @@ def load_stage(description: Dict, index: Index, pipeline: Pipeline, manifest: Ma
return stage
def load_pipeline(description: Dict, index: Index, manifest: Manifest):
def load_pipeline(description: Dict, index: Index, manifest: Manifest, source_refs: set):
name = description["name"]
build = description.get("build")
runner = description.get("runner")
@ -260,7 +264,7 @@ def load_pipeline(description: Dict, index: Index, manifest: Manifest):
pl = manifest.add_pipeline(name, runner, build)
for desc in description.get("stages", []):
load_stage(desc, index, pl, manifest)
load_stage(desc, index, pl, manifest, source_refs)
def load(description: Dict, index: Index) -> Manifest:
@ -270,6 +274,7 @@ def load(description: Dict, index: Index) -> Manifest:
pipelines = description.get("pipelines", [])
manifest = Manifest()
source_refs = set()
# load the sources
for name, desc in sources.items():
@ -277,9 +282,10 @@ def load(description: Dict, index: Index) -> Manifest:
items = desc.get("items", {})
options = desc.get("options", {})
manifest.add_source(info, items, options)
source_refs.update(items.keys())
for desc in pipelines:
load_pipeline(desc, index, manifest)
load_pipeline(desc, index, manifest, source_refs)
# The "runner" property in the manifest format is the
# runner to the run the pipeline with. In osbuild the

View file

@ -72,6 +72,37 @@ BASIC_PIPELINE = {
]
}
BAD_SHA = "sha256:15a654d32efaa75b5df3e2481939d0393fe1746696cc858ca094ccf8b76073cd"
BAD_REF_PIPELINE = {
"version": "2",
"sources": {
"org.osbuild.curl": {
"items": {
"sha256:c540ca8c5e21ba5f063286c94a088af2aac0b15bc40df6fd562d40154c10f4a1": "",
}
}
},
"pipelines": [
{
"name": "build",
"stages": [
{
"type": "org.osbuild.rpm",
"inputs": {
"packages": {
"type": "org.osbuild.files",
"origin": "org.osbuild.source",
"references": {
BAD_SHA: {}
}
}
}
}
]
}
]
}
class TestFormatV1(unittest.TestCase):
def setUp(self):
@ -147,3 +178,17 @@ class TestFormatV1(unittest.TestCase):
res = fmt.validate(desc, self.index)
self.assert_validation(res)
def test_load_bad_ref_manifest(self):
desc = BAD_REF_PIPELINE
info = self.index.detect_format_info(desc)
self.assertIsNotNone(info)
fmt = info.module
self.assertIsNotNone(fmt)
with self.assertRaises(ValueError) as ex:
fmt.load(desc, self.index)
self.assertTrue(str(ex.exception).find(BAD_SHA) > -1,
"The unknown source reference is not included in the exception")