From a8dd46909157f1b559799db67d50431c3e54c6e7 Mon Sep 17 00:00:00 2001 From: Jana Cupova Date: Tue, 15 Nov 2022 10:14:29 +0100 Subject: [PATCH] Replace _multiRow, _singleRow, _singleValue with QP Fixes: https://pagure.io/koji/issue/3581 --- hub/kojihub.py | 175 +++++++++++--------- koji/db.py | 6 + tests/test_hub/test_add_archivetype.py | 15 +- tests/test_hub/test_get_last_host_update.py | 22 ++- tests/test_hub/test_new_build.py | 1 - tests/test_hub/test_repos.py | 7 +- tests/test_hub/test_user_groups.py | 4 +- 7 files changed, 138 insertions(+), 92 deletions(-) diff --git a/hub/kojihub.py b/hub/kojihub.py index d1276727..1561b92b 100644 --- a/hub/kojihub.py +++ b/hub/kojihub.py @@ -75,7 +75,7 @@ from koji.util import ( multi_fnmatch, safer_move, ) -from koji.db import ( +from koji.db import ( # noqa: F401 BulkInsertProcessor, DeleteProcessor, InsertProcessor, @@ -85,11 +85,12 @@ from koji.db import ( _applyQueryOpts, _dml, _fetchSingle, - _multiRow, - _singleRow, + _multiRow, # needed for backward compatibility, will removed in Koji 1.36 + _singleRow, # needed for backward compatibility, will removed in Koji 1.36 _singleValue, get_event, nextval, + currval, ) @@ -736,7 +737,7 @@ def make_task(method, arglist, **opts): idata['host_id'] = opts['assign'] insert = InsertProcessor('task', data=idata) insert.execute() - task_id = _singleValue("SELECT currval('task_id_seq')", strict=True) + task_id = currval('task_id_seq') opts['id'] = task_id koji.plugin.run_callbacks( 'postTaskStateChange', attribute='state', old=None, new='FREE', info=opts) @@ -2285,15 +2286,15 @@ def get_tag_groups(tag, event=None, inherit=True, incl_pkgs=True, incl_reqs=True evcondition = eventCondition(event) # First get the list of groups - fields = ('name', 'group_id', 'tag_id', 'blocked', 'exported', 'display_name', - 'is_default', 'uservisible', 'description', 'langonly', 'biarchonly',) - q = """ - SELECT %s FROM group_config JOIN groups ON group_id = id - WHERE %s AND tag_id = %%(tagid)s - """ % (",".join(fields), evcondition) + columns = ['name', 'group_id', 'tag_id', 'blocked', 'exported', 'display_name', + 'is_default', 'uservisible', 'description', 'langonly', 'biarchonly'] groups = {} for tagid in taglist: - for group in _multiRow(q, locals(), fields): + query = QueryProcessor(tables=['group_config'], columns=columns, + joins=['groups ON group_id = id'], + clauses=[evcondition, 'tag_id = %(tagid)s'], + values={'tagid': tagid}) + for group in query.execute(): grp_id = group['group_id'] # we only take the first entry for group as we go through inheritance groups.setdefault(grp_id, group) @@ -2301,13 +2302,12 @@ def get_tag_groups(tag, event=None, inherit=True, incl_pkgs=True, incl_reqs=True if incl_pkgs: for group in groups.values(): group['packagelist'] = {} - fields = ('group_id', 'tag_id', 'package', 'blocked', 'type', 'basearchonly', 'requires') - q = """ - SELECT %s FROM group_package_listing - WHERE %s AND tag_id = %%(tagid)s - """ % (",".join(fields), evcondition) + columns = ['group_id', 'tag_id', 'package', 'blocked', 'type', 'basearchonly', 'requires'] for tagid in taglist: - for grp_pkg in _multiRow(q, locals(), fields): + query = QueryProcessor(tables=['group_package_listing'], columns=columns, + clauses=[evcondition, 'tag_id = %(tagid)s'], + values={'tagid': tagid}) + for grp_pkg in query.execute(): grp_id = grp_pkg['group_id'] if grp_id not in groups: # tag does not have this group @@ -2323,12 +2323,13 @@ def get_tag_groups(tag, event=None, inherit=True, incl_pkgs=True, incl_reqs=True # and now the group reqs for group in groups.values(): group['grouplist'] = {} - fields = ('group_id', 'tag_id', 'req_id', 'blocked', 'type', 'is_metapkg', 'name') - q = """SELECT %s FROM group_req_listing JOIN groups on req_id = id - WHERE %s AND tag_id = %%(tagid)s - """ % (",".join(fields), evcondition) + columns = ['group_id', 'tag_id', 'req_id', 'blocked', 'type', 'is_metapkg', 'name'] for tagid in taglist: - for grp_req in _multiRow(q, locals(), fields): + query = QueryProcessor(tables=['group_req_listing'], columns=columns, + joins=['groups on req_id = id'], + clauses=[evcondition, 'tag_id = %(tagid)s'], + values={'tagid': tagid}) + for grp_req in query.execute(): grp_id = grp_req['group_id'] if grp_id not in groups: # tag does not have this group @@ -2992,7 +2993,7 @@ def repo_info(repo_id, strict=False): :returns: dict (id, state, create_event, creation_time, tag_id, tag_name, dist) """ - fields = ( + fields = [ ('repo.id', 'id'), ('repo.state', 'state'), ('repo.task_id', 'task_id'), @@ -3002,12 +3003,12 @@ def repo_info(repo_id, strict=False): ('repo.tag_id', 'tag_id'), ('tag.name', 'tag_name'), ('repo.dist', 'dist'), - ) - q = """SELECT %s FROM repo - JOIN tag ON tag_id=tag.id - JOIN events ON repo.create_event = events.id - WHERE repo.id = %%(repo_id)s""" % ','.join([f[0] for f in fields]) - return _singleRow(q, locals(), [f[1] for f in fields], strict=strict) + ] + columns, aliases = zip(*fields) + joins = ['tag ON tag_id=tag.id', 'events ON repo.create_event = events.id'] + query = QueryProcessor(tables=['repo'], columns=columns, aliases=aliases, joins=joins, + clauses=['repo.id = %(repo_id)s'], values={'repo_id': repo_id}) + return query.executeOne(strict=strict) def repo_ready(repo_id): @@ -7414,9 +7415,9 @@ def add_archive_type(name, description, extensions, compression_type=None): for ext in extensions.split(' '): if not ext.replace('.', '').isalnum(): raise koji.GenericError(f'No such {ext} file extension') - select = r"""SELECT id FROM archivetypes - WHERE extensions ~* E'(\\s|^)%s(\\s|$)'""" % ext - results = _multiRow(select, {}, ('id',)) + query = QueryProcessor(tables=['archivetypes'], columns=['id'], + clauses=[f"extensions ~* E'(\\s|^){ext}(\\s|$)'"], values={}) + results = query.execute() if len(results) > 0: raise koji.GenericError(f'file extension {ext} already exists') insert = InsertProcessor('archivetypes', data=data) @@ -8337,10 +8338,16 @@ def build_references(build_id, limit=None, lazy=False): ret = {} # find tags - q = """SELECT tag_id, tag.name FROM tag_listing JOIN tag on tag_id = tag.id - WHERE build_id = %(build_id)i AND active = TRUE""" - ret['tags'] = _multiRow(q, locals(), ('id', 'name')) - + fields = { + 'tag_id': 'tag_id', + 'tag.name': 'name', + } + columns, aliases = zip(*fields.items()) + query = QueryProcessor(tables=['tag_listing'], columns=columns, aliases=aliases, + joins=['tag on tag_id = tag.id'], + clauses=['build_id = %(build_id)i', 'active = TRUE'], + values={'build_id': build_id}) + ret['tags'] = query.execute() if lazy and ret['tags']: return ret @@ -8404,20 +8411,29 @@ def build_references(build_id, limit=None, lazy=False): return ret # find archives whose buildroots we were in - fields = ('id', 'type_id', 'type_name', 'build_id', 'filename') + fields = { + 'archiveinfo.id': 'id', + 'archiveinfo.type_id': 'type_id', + 'archivetypes.name': 'type_name', + 'archiveinfo.build_id': 'build_id', + 'archiveinfo.filename': 'filename', + } + columns, aliases = zip(*fields.items()) idx = {} - q = """SELECT archiveinfo.id, archiveinfo.type_id, archivetypes.name, archiveinfo.build_id, - archiveinfo.filename - FROM buildroot_archives - JOIN archiveinfo ON archiveinfo.buildroot_id = buildroot_archives.buildroot_id - JOIN build ON archiveinfo.build_id = build.id - JOIN archivetypes ON archivetypes.id = archiveinfo.type_id - WHERE buildroot_archives.archive_id = %(archive_id)i - AND build.state = %(st_complete)i""" + opts = {} if limit is not None: - q += "\nLIMIT %(limit)i" + opts = {'limit': limit} for (archive_id,) in build_archive_ids: - for row in _multiRow(q, locals(), fields): + query = QueryProcessor(tables=['buildroot_archives'], columns=columns, aliases=aliases, + joins=['archiveinfo ON archiveinfo.buildroot_id = ' + 'buildroot_archives.buildroot_id', + 'build ON archiveinfo.build_id = build.id', + 'archivetypes ON archivetypes.id = archiveinfo.type_id'], + clauses=['buildroot_archives.archive_id = %(archive_id)i', + 'build.state = %(st_complete)i'], + values={'archive_id': archive_id, 'st_complete': st_complete}, + opts=opts) + for row in query.execute(): idx.setdefault(row['id'], row) if limit is not None and len(idx) > limit: break @@ -10351,10 +10367,14 @@ class RootExports(object): If no event with the given id exists, an error will be raised. """ - fields = ('id', 'ts') - values = {'id': id} - q = """SELECT id, date_part('epoch', time) FROM events WHERE id = %(id)i""" - return _singleRow(q, values, fields, strict=True) + fields = { + 'id': 'id', + "date_part('epoch', time)": 'ts', + } + columns, aliases = zip(*fields.items()) + query = QueryProcessor(tables=['events'], columns=columns, aliases=aliases, + clauses=['id = %(id)i'], values={'id': id}) + return query.executeOne(strict=True) def getLastEvent(self, before=None): """ @@ -10373,18 +10393,24 @@ class RootExports(object): When trying to find information about a specific event, the getEvent() method should be used. """ - fields = ('id', 'ts') + fields = { + 'id': 'id', + "date_part('epoch', time)": 'ts', + } + columns, aliases = zip(*fields.items()) values = {} - q = """SELECT id, date_part('epoch', time) FROM events""" + clauses = [] if before is not None: if not isinstance(before, NUMERIC_TYPES): raise koji.GenericError('Invalid type for before: %s' % type(before)) # use the repr() conversion because it retains more precision than the # string conversion - q += """ WHERE date_part('epoch', time) < %(before)r""" + clauses = ["date_part('epoch', time) < %(before)r"] values['before'] = before - q += """ ORDER BY id DESC LIMIT 1""" - return _singleRow(q, values, fields, strict=True) + opts = {'order': '-id', 'limit': 1} + query = QueryProcessor(tables=['events'], columns=columns, aliases=aliases, + clauses=clauses, values=values, opts=opts) + return query.executeOne(strict=True) evalPolicy = staticmethod(eval_policy) @@ -11904,19 +11930,17 @@ class RootExports(object): st_complete = koji.BUILD_STATES['COMPLETE'] # we need to filter out builds without tasks (imports) as they'll reduce # time average. CG imported builds often contain *_koji_task_id instead. - query = """SELECT date_part('epoch', avg(build.completion_time - events.time)) - FROM build - JOIN events ON build.create_event = events.id - WHERE build.pkg_id = %(packageID)i - AND build.state = %(st_complete)i - AND ( - build.task_id IS NOT NULL OR - build.extra LIKE '%%' || 'koji_task_id' || '%%' - )""" + clauses = ['build.pkg_id = %(packageID)i', 'build.state = %(st_complete)i', + "build.task_id IS NOT NULL OR build.extra LIKE '%' || 'koji_task_id' || '%'"] if age is not None: - query += " AND build.completion_time > NOW() - '%s months'::interval" % int(age) - - return _singleValue(query, locals()) + clauses.append(f"build.completion_time > NOW() - '{int(age)} months'::interval") + query = QueryProcessor(tables=['build'], + columns=["date_part('epoch', " + "avg(build.completion_time - events.time))"], + joins=['events ON build.create_event = events.id'], + clauses=clauses, + values={'packageID': packageID, 'st_complete': st_complete}) + return query.singleValue() packageListAdd = staticmethod(pkglist_add) packageListRemove = staticmethod(pkglist_remove) @@ -13095,14 +13119,13 @@ class RootExports(object): The timestamp represents the last time the host with the given ID contacted the hub. Returns None if the host has never contacted - the hub.""" - query = """SELECT update_time FROM sessions - JOIN host ON sessions.user_id = host.user_id - WHERE host.id = %(hostID)i - ORDER BY update_time DESC - LIMIT 1 - """ - date = _singleValue(query, locals(), strict=False) + the hub.""" + opts = {'order': '-update_time', 'limit': 1} + query = QueryProcessor(tables=['sessions'], columns=['update_time'], + joins=['host ON sessions.user_id = host.user_id'], + clauses=['host.id = %(hostID)i'], values={'hostID': hostID}, + opts=opts) + date = query.singleValue(strict=False) if ts and date is not None: return date.timestamp() else: diff --git a/koji/db.py b/koji/db.py index 2e08a04e..478c30fa 100644 --- a/koji/db.py +++ b/koji/db.py @@ -315,6 +315,12 @@ def nextval(sequence): return _singleValue("SELECT nextval(%(sequence)s)", data, strict=True) +def currval(sequence): + """Get the current value for the given sequence""" + data = {'sequence': sequence} + return _singleValue("SELECT currval(%(sequence)s)", data, strict=True) + + class Savepoint(object): def __init__(self, name): diff --git a/tests/test_hub/test_add_archivetype.py b/tests/test_hub/test_add_archivetype.py index 41dd944a..e37e0d9c 100644 --- a/tests/test_hub/test_add_archivetype.py +++ b/tests/test_hub/test_add_archivetype.py @@ -6,10 +6,10 @@ import koji import kojihub IP = kojihub.InsertProcessor +QP = kojihub.QueryProcessor class TestAddArchiveType(unittest.TestCase): - def setUp(self): self.context = mock.patch('kojihub.context').start() @@ -23,7 +23,10 @@ class TestAddArchiveType(unittest.TestCase): self.insert_execute = mock.MagicMock() self.verify_name_internal = mock.patch('kojihub.verify_name_internal').start() self.get_archive_type = mock.patch('kojihub.get_archive_type').start() - self._multiRow = mock.patch('kojihub._multiRow').start() + self.QueryProcessor = mock.patch('kojihub.QueryProcessor', + side_effect=self.getQuery).start() + self.queries = [] + self.query_execute = mock.MagicMock() def tearDown(self): mock.patch.stopall() @@ -34,7 +37,14 @@ class TestAddArchiveType(unittest.TestCase): self.inserts.append(insert) return insert + def getQuery(self, *args, **kwargs): + query = QP(*args, **kwargs) + query.execute = self.query_execute + self.queries.append(query) + return query + def test_add_archive_type_valid_empty_compression_type(self): + self.query_execute.side_effect = [[]] self.verify_name_internal.return_value = None self.get_archive_type.return_value = None kojihub.add_archive_type('deb', 'Debian package', 'deb') @@ -50,6 +60,7 @@ class TestAddArchiveType(unittest.TestCase): self.context.session.assertPerm.assert_called_with('admin') def test_add_archive_type_valid_with_compression_type(self): + self.query_execute.side_effect = [[]] self.verify_name_internal.return_value = None self.get_archive_type.return_value = None kojihub.add_archive_type('jar', 'Jar package', 'jar', 'zip') diff --git a/tests/test_hub/test_get_last_host_update.py b/tests/test_hub/test_get_last_host_update.py index ee2a6b63..46827d16 100644 --- a/tests/test_hub/test_get_last_host_update.py +++ b/tests/test_hub/test_get_last_host_update.py @@ -5,14 +5,24 @@ import sys import kojihub +QP = kojihub.QueryProcessor + class TestGetLastHostUpdate(unittest.TestCase): + def getQuery(self, *args, **kwargs): + query = QP(*args, **kwargs) + query.singleValue = self.query_singleValue + self.queries.append(query) + return query def setUp(self): self.exports = kojihub.RootExports() + self.QueryProcessor = mock.patch('kojihub.QueryProcessor', + side_effect=self.getQuery).start() + self.queries = [] + self.query_singleValue = mock.MagicMock() - @mock.patch('kojihub._singleValue') - def test_valid_ts(self, _singleValue): + def test_valid_ts(self): expected = 1615875554.862938 if sys.version_info[1] <= 6: dt = datetime.datetime.strptime( @@ -20,19 +30,17 @@ class TestGetLastHostUpdate(unittest.TestCase): else: dt = datetime.datetime.strptime( "2021-03-16T06:19:14.862938+00:00", "%Y-%m-%dT%H:%M:%S.%f%z") - _singleValue.return_value = dt + self.query_singleValue.return_value = dt rv = self.exports.getLastHostUpdate(1, ts=True) self.assertEqual(rv, expected) - @mock.patch('kojihub._singleValue') - def test_valid_datetime(self, _singleValue): + def test_valid_datetime(self): if sys.version_info[1] <= 6: dt = datetime.datetime.strptime( "2021-03-16T06:19:14.862938+0000", "%Y-%m-%dT%H:%M:%S.%f%z") else: dt = datetime.datetime.strptime( "2021-03-16T06:19:14.862938+00:00", "%Y-%m-%dT%H:%M:%S.%f%z") - expected = "2021-03-16T06:19:14.862938+00:00" - _singleValue.return_value = dt + self.query_singleValue.return_value = dt rv = self.exports.getLastHostUpdate(1) self.assertEqual(rv, dt) diff --git a/tests/test_hub/test_new_build.py b/tests/test_hub/test_new_build.py index bfc8e1c6..3b215b2a 100644 --- a/tests/test_hub/test_new_build.py +++ b/tests/test_hub/test_new_build.py @@ -23,7 +23,6 @@ class TestNewBuild(unittest.TestCase): self.get_build = mock.patch('kojihub.get_build').start() self.recycle_build = mock.patch('kojihub.recycle_build').start() self.context = mock.patch('kojihub.context').start() - self._singleValue = mock.patch('kojihub._singleValue').start() def tearDown(self): mock.patch.stopall() diff --git a/tests/test_hub/test_repos.py b/tests/test_hub/test_repos.py index 2e2faf41..97f7fbd7 100644 --- a/tests/test_hub/test_repos.py +++ b/tests/test_hub/test_repos.py @@ -27,6 +27,7 @@ class TestRepoFunctions(unittest.TestCase): self._dml = mock.patch('kojihub._dml').start() self.exports = kojihub.RootExports() self.get_tag = mock.patch('kojihub.get_tag').start() + self.query_executeOne = mock.MagicMock() def tearDown(self): mock.patch.stopall() @@ -34,6 +35,7 @@ class TestRepoFunctions(unittest.TestCase): def getQuery(self, *args, **kwargs): query = QP(*args, **kwargs) query.execute = mock.MagicMock() + query.executeOne = self.query_executeOne self.queries.append(query) return query @@ -78,8 +80,7 @@ class TestRepoFunctions(unittest.TestCase): if 'dist = %(dist)s' not in update.clauses: raise Exception('Missing dist condition') - @mock.patch('kojihub._singleRow') - def test_repo_info(self, _singleRow): + def test_repo_info(self): repo_row = {'id': 10, 'state': 0, 'task_id': 15, @@ -90,7 +91,7 @@ class TestRepoFunctions(unittest.TestCase): 'tag_id': 3, 'tag_name': 'test-tag', 'dist': False} - _singleRow.return_value = repo_row + self.query_executeOne.return_value = repo_row rv = kojihub.repo_info(3) self.assertEqual(rv, repo_row) diff --git a/tests/test_hub/test_user_groups.py b/tests/test_hub/test_user_groups.py index 78c09bb8..2dd67d35 100644 --- a/tests/test_hub/test_user_groups.py +++ b/tests/test_hub/test_user_groups.py @@ -259,8 +259,7 @@ class TestGrouplist(unittest.TestCase): self.assertEqual(u.values['user_id'], uid) self.assertEqual(u.values['group_id'], gid) - @mock.patch('kojihub._multiRow') - def test_get_group_members(self, _multiRow): + def test_get_group_members(self): group, gid = 'test_group', 1 # no permission @@ -300,4 +299,3 @@ class TestGrouplist(unittest.TestCase): self.assertEqual(len(self.queries), 1) self.assertEqual(len(self.inserts), 0) self.assertEqual(len(self.updates), 0) - _multiRow.assert_not_called()