PR#4224: Match longest extension first for archivetype

Merges #4224
https://pagure.io/koji/pull-request/4224

Fixes: #4291
https://pagure.io/koji/issue/4291
Match longest extension first for archivetype
This commit is contained in:
Tomas Kopecek 2025-04-29 16:22:13 +02:00
commit fd1c383909
2 changed files with 26 additions and 10 deletions

View file

@@ -7807,15 +7807,18 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False):
else:
raise koji.GenericError('one of filename, type_name, or type_id must be specified')
parts = filename.split('.')
# otherwise match the filename
query = QueryProcessor(
tables=['archivetypes'],
columns=['id', 'name', 'description', 'extensions', 'compression_type'],
clauses=['extensions ~* %(pattern)s'],
clauses=[r"%(ext)s IN (SELECT lower(s)"
r" FROM unnest(regexp_split_to_array(extensions, '\s+')) AS s)"],
)
for start in range(len(parts) - 1, -1, -1):
# match longest extension first. e.g. .tar.gz before .gz
parts = filename.lower().split('.')
for start in range(len(parts)):
ext = '.'.join(parts[start:])
query.values['pattern'] = r'(\s|^)%s(\s|$)' % ext
query.values['ext'] = ext
results = query.execute()
if len(results) == 1:
@@ -7825,7 +7828,7 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False):
raise koji.GenericError('multiple matches for file extension: %s' % ext)
# otherwise
if strict:
raise koji.GenericError('unsupported file extension: %s' % ext)
raise koji.GenericError('unsupported file extension: %s' % filename)
else:
return None

View file

@@ -61,7 +61,8 @@ class TestGetArchiveType(DBQueryTestCase):
archive_info = [{'id': 1, 'name': 'archive-type-1', 'extensions': 'ext'},
{'id': 2, 'name': 'archive-type-2', 'extensions': 'ext'}]
filename = 'test-filename.ext'
self.qp_execute_return_value = archive_info
self.qp_execute_side_effect = [[], archive_info]
# no matches for full name, multiple matches for .ext
with self.assertRaises(koji.GenericError) as ex:
kojihub.get_archive_type(filename=filename)
self.assertEqual("multiple matches for file extension: ext", str(ex.exception))
@@ -70,7 +71,10 @@ class TestGetArchiveType(DBQueryTestCase):
query = self.queries[0]
self.assertEqual(query.tables, ['archivetypes'])
self.assertEqual(query.joins, None)
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s'])
_clauses = [
"%(ext)s IN (SELECT lower(s) FROM "
"unnest(regexp_split_to_array(extensions, '\\s+')) AS s)"]
self.assertEqual(query.clauses, _clauses)
self.assertEqual(query.columns,
['compression_type', 'description', 'extensions', 'id', 'name'])
get_archive_type_by_name.assert_not_called()
@@ -91,7 +95,10 @@ class TestGetArchiveType(DBQueryTestCase):
query = self.queries[0]
self.assertEqual(query.tables, ['archivetypes'])
self.assertEqual(query.joins, None)
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s'])
_clauses = [
"%(ext)s IN (SELECT lower(s) FROM "
"unnest(regexp_split_to_array(extensions, '\\s+')) AS s)"]
self.assertEqual(query.clauses, _clauses)
self.assertEqual(query.columns,
['compression_type', 'description', 'extensions', 'id', 'name'])
get_archive_type_by_name.assert_not_called()
@@ -111,7 +118,10 @@ class TestGetArchiveType(DBQueryTestCase):
query = self.queries[0]
self.assertEqual(query.tables, ['archivetypes'])
self.assertEqual(query.joins, None)
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s'])
_clauses = [
"%(ext)s IN (SELECT lower(s) FROM "
"unnest(regexp_split_to_array(extensions, '\\s+')) AS s)"]
self.assertEqual(query.clauses, _clauses)
self.assertEqual(query.columns,
['compression_type', 'description', 'extensions', 'id', 'name'])
get_archive_type_by_name.assert_not_called()
@@ -129,7 +139,10 @@ class TestGetArchiveType(DBQueryTestCase):
query = self.queries[0]
self.assertEqual(query.tables, ['archivetypes'])
self.assertEqual(query.joins, None)
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s'])
_clauses = [
"%(ext)s IN (SELECT lower(s) FROM "
"unnest(regexp_split_to_array(extensions, '\\s+')) AS s)"]
self.assertEqual(query.clauses, _clauses)
self.assertEqual(query.columns,
['compression_type', 'description', 'extensions', 'id', 'name'])
get_archive_type_by_name.assert_not_called()