testutil: add AtomicCounter() as a threadsafe counter

The existing code in the reqs counting is not really thread safe,
this commit fixes that.
This commit is contained in:
Michael Vogt 2024-04-04 17:18:57 +02:00 committed by Ondřej Budai
parent b90a5027dc
commit b9b296a7e5
4 changed files with 39 additions and 8 deletions

View file

@ -0,0 +1,29 @@
#!/usr/bin/python3
"""
thread/atomic related utilities
"""
import threading
class AtomicCounter:
""" A thread-safe counter """
def __init__(self, count: int = 0) -> None:
self._count = count
self._lock = threading.Lock()
def inc(self) -> None:
""" increase the count """
with self._lock:
self._count += 1
def dec(self) -> None:
""" decrease the count """
with self._lock:
self._count -= 1
@property
def count(self) -> int:
""" get the current count """
with self._lock:
return self._count

View file

@ -7,6 +7,8 @@ import http.server
import socket
import threading
from .atomic import AtomicCounter
def _get_free_port():
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -23,13 +25,13 @@ class DirHTTPServer(http.server.ThreadingHTTPServer):
def __init__(self, *args, directory=None, simulate_failures=0, **kwargs):
super().__init__(*args, **kwargs)
self.directory = directory
self.simulate_failures = simulate_failures
self.reqs = 0
self.simulate_failures = AtomicCounter(simulate_failures)
self.reqs = AtomicCounter()
def finish_request(self, request, client_address):
self.reqs += 1 # racy on non GIL systems
if self.simulate_failures > 0:
self.simulate_failures -= 1 # racy on non GIL systems
self.reqs.inc()
if self.simulate_failures.count > 0:
self.simulate_failures.dec()
SilentHTTPRequestHandler(
request, client_address, self, directory="does-not-exists")
return

View file

@ -130,7 +130,7 @@ def test_curl_download_many_with_retry(tmp_path, sources_service):
sources_service.cache.mkdir()
sources_service.fetch_all(test_sources)
# we simulated N failures and we need to fetch K files
assert httpd.reqs == simulate_failures + len(test_sources)
assert httpd.reqs.count == simulate_failures + len(test_sources)
# double downloads happend in the expected format
for chksum in test_sources:
assert (sources_service.cache / chksum).exists()
@ -165,5 +165,5 @@ def test_curl_download_many_retries(tmp_path, sources_service):
with pytest.raises(RuntimeError) as exp:
sources_service.fetch_all(test_sources)
# curl will retry 10 times
assert httpd.reqs == 10 * len(test_sources)
assert httpd.reqs.count == 10 * len(test_sources)
assert "curl: error downloading http://localhost:" in str(exp.value)

View file

@ -28,7 +28,7 @@ setenv =
LINTABLES_EXCLUDES = "*.json,*.sh"
LINTABLES_EXCLUDES_RE = ".*\.json$,.*\.sh"
TYPEABLES = osbuild
TYPEABLES_STRICT = ./osbuild/main_cli.py ./osbuild/util/parsing.py
TYPEABLES_STRICT = ./osbuild/main_cli.py ./osbuild/util/parsing.py ./osbuild/testutil/atomic.py
passenv =
TEST_CATEGORY