better ordering support in QueryView

This commit is contained in:
Mike McLean 2024-02-19 07:59:48 -05:00 committed by Tomas Kopecek
parent 92298fb943
commit 374f4e3793

View file

@ -617,13 +617,15 @@ class QueryProcessor(object):
group: a column or alias name to use in the 'GROUP BY' clause
(controlled by enable_group)
- enable_group: if True, opts.group will be enabled
- order_map: (optional) a name:expression map of allowed orders. Otherwise any column or alias
is allowed
"""
iterchunksize = 1000
def __init__(self, columns=None, aliases=None, tables=None,
joins=None, clauses=None, values=None, transform=None,
opts=None, enable_group=False):
opts=None, enable_group=False, order_map=None):
self.columns = columns
self.aliases = aliases
if columns and aliases:
@ -656,6 +658,7 @@ class QueryProcessor(object):
self.opts = opts
else:
self.opts = {}
self.order_map = order_map
self.enable_group = enable_group
self.logger = logging.getLogger('koji.db')
@ -754,14 +757,19 @@ SELECT %(col_str)s
else:
direction = ''
# Check if we're ordering by alias first
orderCol = self.colsByAlias.get(order)
if orderCol:
pass
elif order in self.columns:
orderCol = order
if self.order_map is not None:
# order should only be a key in the map
expr = self.order_map.get(order)
if not expr:
raise koji.ParameterError(f'Invalid order term: {order}')
else:
raise Exception('Invalid order: ' + order)
order_exprs.append(orderCol + direction)
expr = self.colsByAlias.get(order)
if not expr:
if order in self.columns:
expr = order
else:
raise Exception('Invalid order: ' + order)
order_exprs.append(expr + direction)
return 'ORDER BY ' + ', '.join(order_exprs)
else:
return ''
@ -881,18 +889,37 @@ class QueryView:
default_fields = ()
def __init__(self, clauses=None, fields=None, opts=None):
self.clauses = clauses
self.fields = fields
self.opts = opts
self._query = None
@property
def query(self):
if self._query is not None:
return self._query
else:
return self.get_query()
def get_query(self):
self.extra_joins = []
self.values = {}
self.order_map = {}
self.check_opts()
tables = list(self.tables) # copy
fields = self.get_fields(fields)
fields, aliases = zip(*fields.items())
clauses = self.get_clauses(clauses)
fields = self.get_fields(self.fields)
columns, aliases = zip(*fields.items())
clauses = self.get_clauses()
joins = self.get_joins()
self.query = QueryProcessor(
columns=fields, aliases=aliases,
self._query = QueryProcessor(
columns=columns, aliases=aliases,
tables=tables, joins=joins,
clauses=clauses, values=self.values,
opts=opts)
opts=self.opts, order_map=self.order_map)
return self._query
def get_fields(self, fields):
fields = fields or self.default_fields
@ -901,6 +928,19 @@ class QueryView:
return {self.map_field(f): f for f in fields}
def check_opts(self):
# some options may trigger joins
if self.opts is None:
return
if 'order' in self.opts:
for key in self.opts['order'].split(','):
if key.startswith('-'):
key = key[1:]
self.order_map[key] = self.map_field(key)
if 'group' in self.opts:
for key in self.opts['group'].split(','):
self.map_field(key)
def map_field(self, field):
f_info = self.fieldmap.get(field)
if f_info is None:
@ -912,10 +952,10 @@ class QueryView:
# duplicates removed later
return fullname
def get_clauses(self, clauses):
def get_clauses(self):
# for now, just a very simple implementation
result = []
clauses = clauses or []
clauses = self.clauses or []
for n, clause in enumerate(clauses):
# TODO checks check checks
if len(clause) == 2: