From 1b3990cec98d699b72bbe59b1e3daa57fa30a56f Mon Sep 17 00:00:00 2001 From: Mike McLean Date: Thu, 10 Oct 2024 11:13:14 -0400 Subject: [PATCH] simplify extension match this avoids errors if the ext value contains special characters --- kojihub/kojihub.py | 4 ++-- tests/test_hub/test_get_archive_type.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/kojihub/kojihub.py b/kojihub/kojihub.py index ce61cebd..96ab17e3 100644 --- a/kojihub/kojihub.py +++ b/kojihub/kojihub.py @@ -7811,13 +7811,13 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False): query = QueryProcessor( tables=['archivetypes'], columns=['id', 'name', 'description', 'extensions', 'compression_type'], - clauses=['extensions ~* %(pattern)s'], + clauses=[r"%(ext)s = ANY(regexp_split_to_array(extensions, '\s+'))"], ) # match longest extension first. e.g. .tar.gz before .gz parts = filename.split('.') for start in range(len(parts)): ext = '.'.join(parts[start:]) - query.values['pattern'] = r'(\s|^)%s(\s|$)' % ext + query.values['ext'] = ext results = query.execute() if len(results) == 1: diff --git a/tests/test_hub/test_get_archive_type.py b/tests/test_hub/test_get_archive_type.py index 5185003f..6480bc22 100644 --- a/tests/test_hub/test_get_archive_type.py +++ b/tests/test_hub/test_get_archive_type.py @@ -71,7 +71,7 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) + self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -92,7 +92,7 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) + self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -112,7 +112,7 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) + self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called() @@ -130,7 +130,7 @@ class TestGetArchiveType(DBQueryTestCase): query = self.queries[0] self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.joins, None) - self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) + self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"]) self.assertEqual(query.columns, ['compression_type', 'description', 'extensions', 'id', 'name']) get_archive_type_by_name.assert_not_called()