Replace _multiRow, _singleRow, _singleValue with QP

Fixes: https://pagure.io/koji/issue/3581
This commit is contained in:
Jana Cupova 2022-11-15 10:14:29 +01:00 committed by Tomas Kopecek
parent 769ac0e178
commit a8dd469091
7 changed files with 138 additions and 92 deletions

View file

@ -75,7 +75,7 @@ from koji.util import (
multi_fnmatch,
safer_move,
)
from koji.db import (
from koji.db import ( # noqa: F401
BulkInsertProcessor,
DeleteProcessor,
InsertProcessor,
@ -85,11 +85,12 @@ from koji.db import (
_applyQueryOpts,
_dml,
_fetchSingle,
_multiRow,
_singleRow,
_multiRow, # needed for backward compatibility, will removed in Koji 1.36
_singleRow, # needed for backward compatibility, will removed in Koji 1.36
_singleValue,
get_event,
nextval,
currval,
)
@ -736,7 +737,7 @@ def make_task(method, arglist, **opts):
idata['host_id'] = opts['assign']
insert = InsertProcessor('task', data=idata)
insert.execute()
task_id = _singleValue("SELECT currval('task_id_seq')", strict=True)
task_id = currval('task_id_seq')
opts['id'] = task_id
koji.plugin.run_callbacks(
'postTaskStateChange', attribute='state', old=None, new='FREE', info=opts)
@ -2285,15 +2286,15 @@ def get_tag_groups(tag, event=None, inherit=True, incl_pkgs=True, incl_reqs=True
evcondition = eventCondition(event)
# First get the list of groups
fields = ('name', 'group_id', 'tag_id', 'blocked', 'exported', 'display_name',
'is_default', 'uservisible', 'description', 'langonly', 'biarchonly',)
q = """
SELECT %s FROM group_config JOIN groups ON group_id = id
WHERE %s AND tag_id = %%(tagid)s
""" % (",".join(fields), evcondition)
columns = ['name', 'group_id', 'tag_id', 'blocked', 'exported', 'display_name',
'is_default', 'uservisible', 'description', 'langonly', 'biarchonly']
groups = {}
for tagid in taglist:
for group in _multiRow(q, locals(), fields):
query = QueryProcessor(tables=['group_config'], columns=columns,
joins=['groups ON group_id = id'],
clauses=[evcondition, 'tag_id = %(tagid)s'],
values={'tagid': tagid})
for group in query.execute():
grp_id = group['group_id']
# we only take the first entry for group as we go through inheritance
groups.setdefault(grp_id, group)
@ -2301,13 +2302,12 @@ def get_tag_groups(tag, event=None, inherit=True, incl_pkgs=True, incl_reqs=True
if incl_pkgs:
for group in groups.values():
group['packagelist'] = {}
fields = ('group_id', 'tag_id', 'package', 'blocked', 'type', 'basearchonly', 'requires')
q = """
SELECT %s FROM group_package_listing
WHERE %s AND tag_id = %%(tagid)s
""" % (",".join(fields), evcondition)
columns = ['group_id', 'tag_id', 'package', 'blocked', 'type', 'basearchonly', 'requires']
for tagid in taglist:
for grp_pkg in _multiRow(q, locals(), fields):
query = QueryProcessor(tables=['group_package_listing'], columns=columns,
clauses=[evcondition, 'tag_id = %(tagid)s'],
values={'tagid': tagid})
for grp_pkg in query.execute():
grp_id = grp_pkg['group_id']
if grp_id not in groups:
# tag does not have this group
@ -2323,12 +2323,13 @@ def get_tag_groups(tag, event=None, inherit=True, incl_pkgs=True, incl_reqs=True
# and now the group reqs
for group in groups.values():
group['grouplist'] = {}
fields = ('group_id', 'tag_id', 'req_id', 'blocked', 'type', 'is_metapkg', 'name')
q = """SELECT %s FROM group_req_listing JOIN groups on req_id = id
WHERE %s AND tag_id = %%(tagid)s
""" % (",".join(fields), evcondition)
columns = ['group_id', 'tag_id', 'req_id', 'blocked', 'type', 'is_metapkg', 'name']
for tagid in taglist:
for grp_req in _multiRow(q, locals(), fields):
query = QueryProcessor(tables=['group_req_listing'], columns=columns,
joins=['groups on req_id = id'],
clauses=[evcondition, 'tag_id = %(tagid)s'],
values={'tagid': tagid})
for grp_req in query.execute():
grp_id = grp_req['group_id']
if grp_id not in groups:
# tag does not have this group
@ -2992,7 +2993,7 @@ def repo_info(repo_id, strict=False):
:returns: dict (id, state, create_event, creation_time, tag_id, tag_name,
dist)
"""
fields = (
fields = [
('repo.id', 'id'),
('repo.state', 'state'),
('repo.task_id', 'task_id'),
@ -3002,12 +3003,12 @@ def repo_info(repo_id, strict=False):
('repo.tag_id', 'tag_id'),
('tag.name', 'tag_name'),
('repo.dist', 'dist'),
)
q = """SELECT %s FROM repo
JOIN tag ON tag_id=tag.id
JOIN events ON repo.create_event = events.id
WHERE repo.id = %%(repo_id)s""" % ','.join([f[0] for f in fields])
return _singleRow(q, locals(), [f[1] for f in fields], strict=strict)
]
columns, aliases = zip(*fields)
joins = ['tag ON tag_id=tag.id', 'events ON repo.create_event = events.id']
query = QueryProcessor(tables=['repo'], columns=columns, aliases=aliases, joins=joins,
clauses=['repo.id = %(repo_id)s'], values={'repo_id': repo_id})
return query.executeOne(strict=strict)
def repo_ready(repo_id):
@ -7414,9 +7415,9 @@ def add_archive_type(name, description, extensions, compression_type=None):
for ext in extensions.split(' '):
if not ext.replace('.', '').isalnum():
raise koji.GenericError(f'No such {ext} file extension')
select = r"""SELECT id FROM archivetypes
WHERE extensions ~* E'(\\s|^)%s(\\s|$)'""" % ext
results = _multiRow(select, {}, ('id',))
query = QueryProcessor(tables=['archivetypes'], columns=['id'],
clauses=[f"extensions ~* E'(\\s|^){ext}(\\s|$)'"], values={})
results = query.execute()
if len(results) > 0:
raise koji.GenericError(f'file extension {ext} already exists')
insert = InsertProcessor('archivetypes', data=data)
@ -8337,10 +8338,16 @@ def build_references(build_id, limit=None, lazy=False):
ret = {}
# find tags
q = """SELECT tag_id, tag.name FROM tag_listing JOIN tag on tag_id = tag.id
WHERE build_id = %(build_id)i AND active = TRUE"""
ret['tags'] = _multiRow(q, locals(), ('id', 'name'))
fields = {
'tag_id': 'tag_id',
'tag.name': 'name',
}
columns, aliases = zip(*fields.items())
query = QueryProcessor(tables=['tag_listing'], columns=columns, aliases=aliases,
joins=['tag on tag_id = tag.id'],
clauses=['build_id = %(build_id)i', 'active = TRUE'],
values={'build_id': build_id})
ret['tags'] = query.execute()
if lazy and ret['tags']:
return ret
@ -8404,20 +8411,29 @@ def build_references(build_id, limit=None, lazy=False):
return ret
# find archives whose buildroots we were in
fields = ('id', 'type_id', 'type_name', 'build_id', 'filename')
fields = {
'archiveinfo.id': 'id',
'archiveinfo.type_id': 'type_id',
'archivetypes.name': 'type_name',
'archiveinfo.build_id': 'build_id',
'archiveinfo.filename': 'filename',
}
columns, aliases = zip(*fields.items())
idx = {}
q = """SELECT archiveinfo.id, archiveinfo.type_id, archivetypes.name, archiveinfo.build_id,
archiveinfo.filename
FROM buildroot_archives
JOIN archiveinfo ON archiveinfo.buildroot_id = buildroot_archives.buildroot_id
JOIN build ON archiveinfo.build_id = build.id
JOIN archivetypes ON archivetypes.id = archiveinfo.type_id
WHERE buildroot_archives.archive_id = %(archive_id)i
AND build.state = %(st_complete)i"""
opts = {}
if limit is not None:
q += "\nLIMIT %(limit)i"
opts = {'limit': limit}
for (archive_id,) in build_archive_ids:
for row in _multiRow(q, locals(), fields):
query = QueryProcessor(tables=['buildroot_archives'], columns=columns, aliases=aliases,
joins=['archiveinfo ON archiveinfo.buildroot_id = '
'buildroot_archives.buildroot_id',
'build ON archiveinfo.build_id = build.id',
'archivetypes ON archivetypes.id = archiveinfo.type_id'],
clauses=['buildroot_archives.archive_id = %(archive_id)i',
'build.state = %(st_complete)i'],
values={'archive_id': archive_id, 'st_complete': st_complete},
opts=opts)
for row in query.execute():
idx.setdefault(row['id'], row)
if limit is not None and len(idx) > limit:
break
@ -10351,10 +10367,14 @@ class RootExports(object):
If no event with the given id exists, an error will be raised.
"""
fields = ('id', 'ts')
values = {'id': id}
q = """SELECT id, date_part('epoch', time) FROM events WHERE id = %(id)i"""
return _singleRow(q, values, fields, strict=True)
fields = {
'id': 'id',
"date_part('epoch', time)": 'ts',
}
columns, aliases = zip(*fields.items())
query = QueryProcessor(tables=['events'], columns=columns, aliases=aliases,
clauses=['id = %(id)i'], values={'id': id})
return query.executeOne(strict=True)
def getLastEvent(self, before=None):
"""
@ -10373,18 +10393,24 @@ class RootExports(object):
When trying to find information about a specific event, the getEvent() method
should be used.
"""
fields = ('id', 'ts')
fields = {
'id': 'id',
"date_part('epoch', time)": 'ts',
}
columns, aliases = zip(*fields.items())
values = {}
q = """SELECT id, date_part('epoch', time) FROM events"""
clauses = []
if before is not None:
if not isinstance(before, NUMERIC_TYPES):
raise koji.GenericError('Invalid type for before: %s' % type(before))
# use the repr() conversion because it retains more precision than the
# string conversion
q += """ WHERE date_part('epoch', time) < %(before)r"""
clauses = ["date_part('epoch', time) < %(before)r"]
values['before'] = before
q += """ ORDER BY id DESC LIMIT 1"""
return _singleRow(q, values, fields, strict=True)
opts = {'order': '-id', 'limit': 1}
query = QueryProcessor(tables=['events'], columns=columns, aliases=aliases,
clauses=clauses, values=values, opts=opts)
return query.executeOne(strict=True)
evalPolicy = staticmethod(eval_policy)
@ -11904,19 +11930,17 @@ class RootExports(object):
st_complete = koji.BUILD_STATES['COMPLETE']
# we need to filter out builds without tasks (imports) as they'll reduce
# time average. CG imported builds often contain *_koji_task_id instead.
query = """SELECT date_part('epoch', avg(build.completion_time - events.time))
FROM build
JOIN events ON build.create_event = events.id
WHERE build.pkg_id = %(packageID)i
AND build.state = %(st_complete)i
AND (
build.task_id IS NOT NULL OR
build.extra LIKE '%%' || 'koji_task_id' || '%%'
)"""
clauses = ['build.pkg_id = %(packageID)i', 'build.state = %(st_complete)i',
"build.task_id IS NOT NULL OR build.extra LIKE '%' || 'koji_task_id' || '%'"]
if age is not None:
query += " AND build.completion_time > NOW() - '%s months'::interval" % int(age)
return _singleValue(query, locals())
clauses.append(f"build.completion_time > NOW() - '{int(age)} months'::interval")
query = QueryProcessor(tables=['build'],
columns=["date_part('epoch', "
"avg(build.completion_time - events.time))"],
joins=['events ON build.create_event = events.id'],
clauses=clauses,
values={'packageID': packageID, 'st_complete': st_complete})
return query.singleValue()
packageListAdd = staticmethod(pkglist_add)
packageListRemove = staticmethod(pkglist_remove)
@ -13095,14 +13119,13 @@ class RootExports(object):
The timestamp represents the last time the host with the given
ID contacted the hub. Returns None if the host has never contacted
the hub."""
query = """SELECT update_time FROM sessions
JOIN host ON sessions.user_id = host.user_id
WHERE host.id = %(hostID)i
ORDER BY update_time DESC
LIMIT 1
"""
date = _singleValue(query, locals(), strict=False)
the hub."""
opts = {'order': '-update_time', 'limit': 1}
query = QueryProcessor(tables=['sessions'], columns=['update_time'],
joins=['host ON sessions.user_id = host.user_id'],
clauses=['host.id = %(hostID)i'], values={'hostID': hostID},
opts=opts)
date = query.singleValue(strict=False)
if ts and date is not None:
return date.timestamp()
else:

View file

@ -315,6 +315,12 @@ def nextval(sequence):
return _singleValue("SELECT nextval(%(sequence)s)", data, strict=True)
def currval(sequence):
"""Get the current value for the given sequence"""
data = {'sequence': sequence}
return _singleValue("SELECT currval(%(sequence)s)", data, strict=True)
class Savepoint(object):
def __init__(self, name):

View file

@ -6,10 +6,10 @@ import koji
import kojihub
IP = kojihub.InsertProcessor
QP = kojihub.QueryProcessor
class TestAddArchiveType(unittest.TestCase):
def setUp(self):
self.context = mock.patch('kojihub.context').start()
@ -23,7 +23,10 @@ class TestAddArchiveType(unittest.TestCase):
self.insert_execute = mock.MagicMock()
self.verify_name_internal = mock.patch('kojihub.verify_name_internal').start()
self.get_archive_type = mock.patch('kojihub.get_archive_type').start()
self._multiRow = mock.patch('kojihub._multiRow').start()
self.QueryProcessor = mock.patch('kojihub.QueryProcessor',
side_effect=self.getQuery).start()
self.queries = []
self.query_execute = mock.MagicMock()
def tearDown(self):
mock.patch.stopall()
@ -34,7 +37,14 @@ class TestAddArchiveType(unittest.TestCase):
self.inserts.append(insert)
return insert
def getQuery(self, *args, **kwargs):
query = QP(*args, **kwargs)
query.execute = self.query_execute
self.queries.append(query)
return query
def test_add_archive_type_valid_empty_compression_type(self):
self.query_execute.side_effect = [[]]
self.verify_name_internal.return_value = None
self.get_archive_type.return_value = None
kojihub.add_archive_type('deb', 'Debian package', 'deb')
@ -50,6 +60,7 @@ class TestAddArchiveType(unittest.TestCase):
self.context.session.assertPerm.assert_called_with('admin')
def test_add_archive_type_valid_with_compression_type(self):
self.query_execute.side_effect = [[]]
self.verify_name_internal.return_value = None
self.get_archive_type.return_value = None
kojihub.add_archive_type('jar', 'Jar package', 'jar', 'zip')

View file

@ -5,14 +5,24 @@ import sys
import kojihub
QP = kojihub.QueryProcessor
class TestGetLastHostUpdate(unittest.TestCase):
def getQuery(self, *args, **kwargs):
query = QP(*args, **kwargs)
query.singleValue = self.query_singleValue
self.queries.append(query)
return query
def setUp(self):
self.exports = kojihub.RootExports()
self.QueryProcessor = mock.patch('kojihub.QueryProcessor',
side_effect=self.getQuery).start()
self.queries = []
self.query_singleValue = mock.MagicMock()
@mock.patch('kojihub._singleValue')
def test_valid_ts(self, _singleValue):
def test_valid_ts(self):
expected = 1615875554.862938
if sys.version_info[1] <= 6:
dt = datetime.datetime.strptime(
@ -20,19 +30,17 @@ class TestGetLastHostUpdate(unittest.TestCase):
else:
dt = datetime.datetime.strptime(
"2021-03-16T06:19:14.862938+00:00", "%Y-%m-%dT%H:%M:%S.%f%z")
_singleValue.return_value = dt
self.query_singleValue.return_value = dt
rv = self.exports.getLastHostUpdate(1, ts=True)
self.assertEqual(rv, expected)
@mock.patch('kojihub._singleValue')
def test_valid_datetime(self, _singleValue):
def test_valid_datetime(self):
if sys.version_info[1] <= 6:
dt = datetime.datetime.strptime(
"2021-03-16T06:19:14.862938+0000", "%Y-%m-%dT%H:%M:%S.%f%z")
else:
dt = datetime.datetime.strptime(
"2021-03-16T06:19:14.862938+00:00", "%Y-%m-%dT%H:%M:%S.%f%z")
expected = "2021-03-16T06:19:14.862938+00:00"
_singleValue.return_value = dt
self.query_singleValue.return_value = dt
rv = self.exports.getLastHostUpdate(1)
self.assertEqual(rv, dt)

View file

@ -23,7 +23,6 @@ class TestNewBuild(unittest.TestCase):
self.get_build = mock.patch('kojihub.get_build').start()
self.recycle_build = mock.patch('kojihub.recycle_build').start()
self.context = mock.patch('kojihub.context').start()
self._singleValue = mock.patch('kojihub._singleValue').start()
def tearDown(self):
mock.patch.stopall()

View file

@ -27,6 +27,7 @@ class TestRepoFunctions(unittest.TestCase):
self._dml = mock.patch('kojihub._dml').start()
self.exports = kojihub.RootExports()
self.get_tag = mock.patch('kojihub.get_tag').start()
self.query_executeOne = mock.MagicMock()
def tearDown(self):
mock.patch.stopall()
@ -34,6 +35,7 @@ class TestRepoFunctions(unittest.TestCase):
def getQuery(self, *args, **kwargs):
query = QP(*args, **kwargs)
query.execute = mock.MagicMock()
query.executeOne = self.query_executeOne
self.queries.append(query)
return query
@ -78,8 +80,7 @@ class TestRepoFunctions(unittest.TestCase):
if 'dist = %(dist)s' not in update.clauses:
raise Exception('Missing dist condition')
@mock.patch('kojihub._singleRow')
def test_repo_info(self, _singleRow):
def test_repo_info(self):
repo_row = {'id': 10,
'state': 0,
'task_id': 15,
@ -90,7 +91,7 @@ class TestRepoFunctions(unittest.TestCase):
'tag_id': 3,
'tag_name': 'test-tag',
'dist': False}
_singleRow.return_value = repo_row
self.query_executeOne.return_value = repo_row
rv = kojihub.repo_info(3)
self.assertEqual(rv, repo_row)

View file

@ -259,8 +259,7 @@ class TestGrouplist(unittest.TestCase):
self.assertEqual(u.values['user_id'], uid)
self.assertEqual(u.values['group_id'], gid)
@mock.patch('kojihub._multiRow')
def test_get_group_members(self, _multiRow):
def test_get_group_members(self):
group, gid = 'test_group', 1
# no permission
@ -300,4 +299,3 @@ class TestGrouplist(unittest.TestCase):
self.assertEqual(len(self.queries), 1)
self.assertEqual(len(self.inserts), 0)
self.assertEqual(len(self.updates), 0)
_multiRow.assert_not_called()