diff --git a/sources/org.osbuild.curl b/sources/org.osbuild.curl index 86ce1eae..230bb84f 100755 --- a/sources/org.osbuild.curl +++ b/sources/org.osbuild.curl @@ -102,29 +102,30 @@ class CurlSource(sources.SourceService): super().__init__(*args, **kwargs) self.subscriptions = None - def transform(self, checksum, desc): - url = desc - if not isinstance(url, dict): - url = {"url": url} + def amend_secrets(self, checksum, desc_or_url): + if not isinstance(desc_or_url, dict): + desc = {"url": desc_or_url} + else: + desc = desc_or_url - # check if url needs rhsm secrets - if url.get("secrets", {}).get("name") == "org.osbuild.rhsm": + # check if desc needs rhsm secrets + if desc.get("secrets", {}).get("name") == "org.osbuild.rhsm": # rhsm secrets only need to be retrieved once and can then be reused if self.subscriptions is None: self.subscriptions = Subscriptions.from_host_system() - url["secrets"] = self.subscriptions.get_secrets(url.get("url")) - elif url.get("secrets", {}).get("name") == "org.osbuild.mtls": + desc["secrets"] = self.subscriptions.get_secrets(desc.get("desc")) + elif desc.get("secrets", {}).get("name") == "org.osbuild.mtls": key = os.getenv("OSBUILD_SOURCES_CURL_SSL_CLIENT_KEY") cert = os.getenv("OSBUILD_SOURCES_CURL_SSL_CLIENT_CERT") if not (key and cert): raise RuntimeError(f"mtls secrets required but key ({key}) or cert ({cert}) not defined") - url["secrets"] = { + desc["secrets"] = { 'ssl_ca_cert': os.getenv("OSBUILD_SOURCES_CURL_SSL_CA_CERT"), 'ssl_client_cert': cert, 'ssl_client_key': key, } - return checksum, url + return checksum, desc @staticmethod def _quote_url(url: str) -> str: @@ -135,10 +136,10 @@ class CurlSource(sources.SourceService): def fetch_all(self, items: Dict) -> None: filtered = filter(lambda i: not self.exists(i[0], i[1]), items.items()) # discards items already in cache - transformed = map(lambda i: self.transform(i[0], i[1]), filtered) # prepare each item to be downloaded + amended = map(lambda i: self.amend_secrets(i[0], i[1]), filtered) with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor: - for _ in executor.map(self.fetch_one, *zip(*transformed)): + for _ in executor.map(self.fetch_one, *zip(*amended)): pass def fetch_one(self, checksum, desc): diff --git a/sources/test/test_curl_source.py b/sources/test/test_curl_source.py index 8ac4c8fd..a65a4018 100644 --- a/sources/test/test_curl_source.py +++ b/sources/test/test_curl_source.py @@ -36,7 +36,7 @@ def test_curl_source_exists(sources_module): assert curl_source.exists(checksum, desc) -def test_curl_source_transform(sources_module): +def test_curl_source_amend_secrets(sources_module): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) curl_source = sources_module.CurlSource.from_args(["--service-fd", str(sock.fileno())]) tmpdir = tempfile.TemporaryDirectory() @@ -58,13 +58,13 @@ def test_curl_source_transform(sources_module): cm.callback(cb) checksum = "sha256:1234567890123456789012345678901234567890909b14ffb032aa20fa23d9ad6" pathlib.Path(os.path.join(tmpdir.name, checksum)).touch() - new_desc = curl_source.transform(checksum, desc) + new_desc = curl_source.amend_secrets(checksum, desc) assert new_desc[1]["secrets"]["ssl_client_key"] == "key" assert new_desc[1]["secrets"]["ssl_client_cert"] == "cert" assert new_desc[1]["secrets"]["ssl_ca_cert"] is None -def test_curl_source_transform_fail(sources_module): +def test_curl_source_amend_secrets_fail(sources_module): sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) curl_source = sources_module.CurlSource.from_args(["--service-fd", str(sock.fileno())]) tmpdir = tempfile.TemporaryDirectory() @@ -78,5 +78,5 @@ def test_curl_source_transform_fail(sources_module): checksum = "sha256:1234567890123456789012345678901234567890909b14ffb032aa20fa23d9ad6" pathlib.Path(os.path.join(tmpdir.name, checksum)).touch() with pytest.raises(RuntimeError) as exc: - curl_source.transform(checksum, desc) + curl_source.amend_secrets(checksum, desc) assert "mtls secrets required" in str(exc)