simplify extension match

this avoids errors if the ext value contains special characters
This commit is contained in:
Mike McLean 2024-10-10 11:13:14 -04:00 committed by Tomas Kopecek
parent 6afde19a8a
commit 1b3990cec9
2 changed files with 6 additions and 6 deletions

View file

@ -7811,13 +7811,13 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False):
query = QueryProcessor( query = QueryProcessor(
tables=['archivetypes'], tables=['archivetypes'],
columns=['id', 'name', 'description', 'extensions', 'compression_type'], columns=['id', 'name', 'description', 'extensions', 'compression_type'],
clauses=['extensions ~* %(pattern)s'], clauses=[r"%(ext)s = ANY(regexp_split_to_array(extensions, '\s+'))"],
) )
# match longest extension first. e.g. .tar.gz before .gz # match longest extension first. e.g. .tar.gz before .gz
parts = filename.split('.') parts = filename.split('.')
for start in range(len(parts)): for start in range(len(parts)):
ext = '.'.join(parts[start:]) ext = '.'.join(parts[start:])
query.values['pattern'] = r'(\s|^)%s(\s|$)' % ext query.values['ext'] = ext
results = query.execute() results = query.execute()
if len(results) == 1: if len(results) == 1:

View file

@ -71,7 +71,7 @@ class TestGetArchiveType(DBQueryTestCase):
query = self.queries[0] query = self.queries[0]
self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.tables, ['archivetypes'])
self.assertEqual(query.joins, None) self.assertEqual(query.joins, None)
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"])
self.assertEqual(query.columns, self.assertEqual(query.columns,
['compression_type', 'description', 'extensions', 'id', 'name']) ['compression_type', 'description', 'extensions', 'id', 'name'])
get_archive_type_by_name.assert_not_called() get_archive_type_by_name.assert_not_called()
@ -92,7 +92,7 @@ class TestGetArchiveType(DBQueryTestCase):
query = self.queries[0] query = self.queries[0]
self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.tables, ['archivetypes'])
self.assertEqual(query.joins, None) self.assertEqual(query.joins, None)
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"])
self.assertEqual(query.columns, self.assertEqual(query.columns,
['compression_type', 'description', 'extensions', 'id', 'name']) ['compression_type', 'description', 'extensions', 'id', 'name'])
get_archive_type_by_name.assert_not_called() get_archive_type_by_name.assert_not_called()
@ -112,7 +112,7 @@ class TestGetArchiveType(DBQueryTestCase):
query = self.queries[0] query = self.queries[0]
self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.tables, ['archivetypes'])
self.assertEqual(query.joins, None) self.assertEqual(query.joins, None)
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"])
self.assertEqual(query.columns, self.assertEqual(query.columns,
['compression_type', 'description', 'extensions', 'id', 'name']) ['compression_type', 'description', 'extensions', 'id', 'name'])
get_archive_type_by_name.assert_not_called() get_archive_type_by_name.assert_not_called()
@ -130,7 +130,7 @@ class TestGetArchiveType(DBQueryTestCase):
query = self.queries[0] query = self.queries[0]
self.assertEqual(query.tables, ['archivetypes']) self.assertEqual(query.tables, ['archivetypes'])
self.assertEqual(query.joins, None) self.assertEqual(query.joins, None)
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s']) self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"])
self.assertEqual(query.columns, self.assertEqual(query.columns,
['compression_type', 'description', 'extensions', 'id', 'name']) ['compression_type', 'description', 'extensions', 'id', 'name'])
get_archive_type_by_name.assert_not_called() get_archive_type_by_name.assert_not_called()