diff --git a/osbuild/formats/v2.py b/osbuild/formats/v2.py index b179ac4f..ec660c0b 100644 --- a/osbuild/formats/v2.py +++ b/osbuild/formats/v2.py @@ -184,7 +184,7 @@ def load_device(name: str, description: Dict, index: Index, stage: Stage): stage.add_device(name, info, options) -def load_input(name: str, description: Dict, index: Index, stage: Stage, manifest: Manifest): +def load_input(name: str, description: Dict, index: Index, stage: Stage, manifest: Manifest, source_refs: set): input_type = description["type"] origin = description["origin"] options = description.get("options", {}) @@ -205,6 +205,10 @@ def load_input(name: str, description: Dict, index: Index, stage: Stage, manifes target = resolve_ref(r, manifest) resolved[target] = desc refs = resolved + elif origin == "org.osbuild.source": + unknown_refs = set(refs.keys()) - source_refs + if unknown_refs: + raise ValueError(f"Unknown source reference(s) {unknown_refs}") for r, desc in refs.items(): ip.add_reference(r, desc) @@ -226,7 +230,7 @@ def load_mount(name: str, description: Dict, index: Index, stage: Stage): stage.add_mount(name, info, device, target, options) -def load_stage(description: Dict, index: Index, pipeline: Pipeline, manifest: Manifest): +def load_stage(description: Dict, index: Index, pipeline: Pipeline, manifest: Manifest, source_refs): stage_type = description["type"] opts = description.get("options", {}) info = index.get_module_info("Stage", stage_type) @@ -239,7 +243,7 @@ def load_stage(description: Dict, index: Index, pipeline: Pipeline, manifest: Ma ips = description.get("inputs", {}) for name, desc in ips.items(): - load_input(name, desc, index, stage, manifest) + load_input(name, desc, index, stage, manifest, source_refs) mounts = description.get("mounts", {}) for name, desc in mounts.items(): @@ -248,7 +252,7 @@ def load_stage(description: Dict, index: Index, pipeline: Pipeline, manifest: Ma return stage -def load_pipeline(description: Dict, index: Index, manifest: Manifest): +def load_pipeline(description: Dict, index: Index, manifest: Manifest, source_refs: set): name = description["name"] build = description.get("build") runner = description.get("runner") @@ -260,7 +264,7 @@ def load_pipeline(description: Dict, index: Index, manifest: Manifest): pl = manifest.add_pipeline(name, runner, build) for desc in description.get("stages", []): - load_stage(desc, index, pl, manifest) + load_stage(desc, index, pl, manifest, source_refs) def load(description: Dict, index: Index) -> Manifest: @@ -270,6 +274,7 @@ def load(description: Dict, index: Index) -> Manifest: pipelines = description.get("pipelines", []) manifest = Manifest() + source_refs = set() # load the sources for name, desc in sources.items(): @@ -277,9 +282,10 @@ def load(description: Dict, index: Index) -> Manifest: items = desc.get("items", {}) options = desc.get("options", {}) manifest.add_source(info, items, options) + source_refs.update(items.keys()) for desc in pipelines: - load_pipeline(desc, index, manifest) + load_pipeline(desc, index, manifest, source_refs) # The "runner" property in the manifest format is the # runner to the run the pipeline with. In osbuild the diff --git a/test/mod/test_fmt_v2.py b/test/mod/test_fmt_v2.py index cf862830..6e5c8aaf 100644 --- a/test/mod/test_fmt_v2.py +++ b/test/mod/test_fmt_v2.py @@ -72,6 +72,37 @@ BASIC_PIPELINE = { ] } +BAD_SHA = "sha256:15a654d32efaa75b5df3e2481939d0393fe1746696cc858ca094ccf8b76073cd" + +BAD_REF_PIPELINE = { + "version": "2", + "sources": { + "org.osbuild.curl": { + "items": { + "sha256:c540ca8c5e21ba5f063286c94a088af2aac0b15bc40df6fd562d40154c10f4a1": "", + } + } + }, + "pipelines": [ + { + "name": "build", + "stages": [ + { + "type": "org.osbuild.rpm", + "inputs": { + "packages": { + "type": "org.osbuild.files", + "origin": "org.osbuild.source", + "references": { + BAD_SHA: {} + } + } + } + } + ] + } + ] +} class TestFormatV1(unittest.TestCase): def setUp(self): @@ -147,3 +178,17 @@ class TestFormatV1(unittest.TestCase): res = fmt.validate(desc, self.index) self.assert_validation(res) + + def test_load_bad_ref_manifest(self): + desc = BAD_REF_PIPELINE + + info = self.index.detect_format_info(desc) + self.assertIsNotNone(info) + fmt = info.module + self.assertIsNotNone(fmt) + + with self.assertRaises(ValueError) as ex: + fmt.load(desc, self.index) + + self.assertTrue(str(ex.exception).find(BAD_SHA) > -1, + "The unknown source reference is not included in the exception")