simplify extension match
this avoids errors if the ext value contains special characters
This commit is contained in:
parent
6afde19a8a
commit
1b3990cec9
2 changed files with 6 additions and 6 deletions
|
|
@ -7811,13 +7811,13 @@ def get_archive_type(filename=None, type_name=None, type_id=None, strict=False):
|
||||||
query = QueryProcessor(
|
query = QueryProcessor(
|
||||||
tables=['archivetypes'],
|
tables=['archivetypes'],
|
||||||
columns=['id', 'name', 'description', 'extensions', 'compression_type'],
|
columns=['id', 'name', 'description', 'extensions', 'compression_type'],
|
||||||
clauses=['extensions ~* %(pattern)s'],
|
clauses=[r"%(ext)s = ANY(regexp_split_to_array(extensions, '\s+'))"],
|
||||||
)
|
)
|
||||||
# match longest extension first. e.g. .tar.gz before .gz
|
# match longest extension first. e.g. .tar.gz before .gz
|
||||||
parts = filename.split('.')
|
parts = filename.split('.')
|
||||||
for start in range(len(parts)):
|
for start in range(len(parts)):
|
||||||
ext = '.'.join(parts[start:])
|
ext = '.'.join(parts[start:])
|
||||||
query.values['pattern'] = r'(\s|^)%s(\s|$)' % ext
|
query.values['ext'] = ext
|
||||||
results = query.execute()
|
results = query.execute()
|
||||||
|
|
||||||
if len(results) == 1:
|
if len(results) == 1:
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ class TestGetArchiveType(DBQueryTestCase):
|
||||||
query = self.queries[0]
|
query = self.queries[0]
|
||||||
self.assertEqual(query.tables, ['archivetypes'])
|
self.assertEqual(query.tables, ['archivetypes'])
|
||||||
self.assertEqual(query.joins, None)
|
self.assertEqual(query.joins, None)
|
||||||
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s'])
|
self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"])
|
||||||
self.assertEqual(query.columns,
|
self.assertEqual(query.columns,
|
||||||
['compression_type', 'description', 'extensions', 'id', 'name'])
|
['compression_type', 'description', 'extensions', 'id', 'name'])
|
||||||
get_archive_type_by_name.assert_not_called()
|
get_archive_type_by_name.assert_not_called()
|
||||||
|
|
@ -92,7 +92,7 @@ class TestGetArchiveType(DBQueryTestCase):
|
||||||
query = self.queries[0]
|
query = self.queries[0]
|
||||||
self.assertEqual(query.tables, ['archivetypes'])
|
self.assertEqual(query.tables, ['archivetypes'])
|
||||||
self.assertEqual(query.joins, None)
|
self.assertEqual(query.joins, None)
|
||||||
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s'])
|
self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"])
|
||||||
self.assertEqual(query.columns,
|
self.assertEqual(query.columns,
|
||||||
['compression_type', 'description', 'extensions', 'id', 'name'])
|
['compression_type', 'description', 'extensions', 'id', 'name'])
|
||||||
get_archive_type_by_name.assert_not_called()
|
get_archive_type_by_name.assert_not_called()
|
||||||
|
|
@ -112,7 +112,7 @@ class TestGetArchiveType(DBQueryTestCase):
|
||||||
query = self.queries[0]
|
query = self.queries[0]
|
||||||
self.assertEqual(query.tables, ['archivetypes'])
|
self.assertEqual(query.tables, ['archivetypes'])
|
||||||
self.assertEqual(query.joins, None)
|
self.assertEqual(query.joins, None)
|
||||||
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s'])
|
self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"])
|
||||||
self.assertEqual(query.columns,
|
self.assertEqual(query.columns,
|
||||||
['compression_type', 'description', 'extensions', 'id', 'name'])
|
['compression_type', 'description', 'extensions', 'id', 'name'])
|
||||||
get_archive_type_by_name.assert_not_called()
|
get_archive_type_by_name.assert_not_called()
|
||||||
|
|
@ -130,7 +130,7 @@ class TestGetArchiveType(DBQueryTestCase):
|
||||||
query = self.queries[0]
|
query = self.queries[0]
|
||||||
self.assertEqual(query.tables, ['archivetypes'])
|
self.assertEqual(query.tables, ['archivetypes'])
|
||||||
self.assertEqual(query.joins, None)
|
self.assertEqual(query.joins, None)
|
||||||
self.assertEqual(query.clauses, ['extensions ~* %(pattern)s'])
|
self.assertEqual(query.clauses, ["%(ext)s = ANY(regexp_split_to_array(extensions, '\\s+'))"])
|
||||||
self.assertEqual(query.columns,
|
self.assertEqual(query.columns,
|
||||||
['compression_type', 'description', 'extensions', 'id', 'name'])
|
['compression_type', 'description', 'extensions', 'id', 'name'])
|
||||||
get_archive_type_by_name.assert_not_called()
|
get_archive_type_by_name.assert_not_called()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue