diff --git a/osbuild/__init__.py b/osbuild/__init__.py index 3a31d842..2ae3229d 100644 --- a/osbuild/__init__.py +++ b/osbuild/__init__.py @@ -374,20 +374,15 @@ class Assembler: class Pipeline: - def __init__(self, base=None): + def __init__(self, build=None, base=None): self.base = base - self.build = None + self.build = build self.stages = [] self.assembler = None def get_id(self): return self.stages[-1].id if self.stages else self.base - def set_build(self, pipeline): - if self.stages: - raise ValueError("Must set build before stages.") - self.build = pipeline - def add_stage(self, name, options=None): build = self.build.get_id() if self.build else None stage = Stage(name, build, self.get_id(), options or {}) @@ -474,11 +469,12 @@ class Pipeline: def load(description): - pipeline = Pipeline(description.get("base")) - - b = description.get("build") - if b: - pipeline.set_build(load(b)) + build_description = description.get("build") + if build_description: + build = load(build_description) + else: + build = None + pipeline = Pipeline(build, description.get("base")) for s in description.get("stages", []): pipeline.add_stage(s["name"], s.get("options", {})) diff --git a/test/test_osbuild.py b/test/test_osbuild.py index 1142bb28..63d2024c 100644 --- a/test/test_osbuild.py +++ b/test/test_osbuild.py @@ -47,8 +47,7 @@ class TestDescriptions(unittest.TestCase): build = osbuild.Pipeline() build.add_stage("org.osbuild.test", { "one": 1 }) - pipeline = osbuild.Pipeline() - pipeline.set_build(build) + pipeline = osbuild.Pipeline(build) pipeline.add_stage("org.osbuild.test", { "one": 2 }) pipeline.set_assembler("org.osbuild.test")