better ordering support in QueryView
This commit is contained in:
parent
92298fb943
commit
374f4e3793
1 changed files with 56 additions and 16 deletions
|
|
@ -617,13 +617,15 @@ class QueryProcessor(object):
|
|||
group: a column or alias name to use in the 'GROUP BY' clause
|
||||
(controlled by enable_group)
|
||||
- enable_group: if True, opts.group will be enabled
|
||||
- order_map: (optional) a name:expression map of allowed orders. Otherwise any column or alias
|
||||
is allowed
|
||||
"""
|
||||
|
||||
iterchunksize = 1000
|
||||
|
||||
def __init__(self, columns=None, aliases=None, tables=None,
|
||||
joins=None, clauses=None, values=None, transform=None,
|
||||
opts=None, enable_group=False):
|
||||
opts=None, enable_group=False, order_map=None):
|
||||
self.columns = columns
|
||||
self.aliases = aliases
|
||||
if columns and aliases:
|
||||
|
|
@ -656,6 +658,7 @@ class QueryProcessor(object):
|
|||
self.opts = opts
|
||||
else:
|
||||
self.opts = {}
|
||||
self.order_map = order_map
|
||||
self.enable_group = enable_group
|
||||
self.logger = logging.getLogger('koji.db')
|
||||
|
||||
|
|
@ -754,14 +757,19 @@ SELECT %(col_str)s
|
|||
else:
|
||||
direction = ''
|
||||
# Check if we're ordering by alias first
|
||||
orderCol = self.colsByAlias.get(order)
|
||||
if orderCol:
|
||||
pass
|
||||
elif order in self.columns:
|
||||
orderCol = order
|
||||
if self.order_map is not None:
|
||||
# order should only be a key in the map
|
||||
expr = self.order_map.get(order)
|
||||
if not expr:
|
||||
raise koji.ParameterError(f'Invalid order term: {order}')
|
||||
else:
|
||||
raise Exception('Invalid order: ' + order)
|
||||
order_exprs.append(orderCol + direction)
|
||||
expr = self.colsByAlias.get(order)
|
||||
if not expr:
|
||||
if order in self.columns:
|
||||
expr = order
|
||||
else:
|
||||
raise Exception('Invalid order: ' + order)
|
||||
order_exprs.append(expr + direction)
|
||||
return 'ORDER BY ' + ', '.join(order_exprs)
|
||||
else:
|
||||
return ''
|
||||
|
|
@ -881,18 +889,37 @@ class QueryView:
|
|||
default_fields = ()
|
||||
|
||||
def __init__(self, clauses=None, fields=None, opts=None):
|
||||
self.clauses = clauses
|
||||
self.fields = fields
|
||||
self.opts = opts
|
||||
self._query = None
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
if self._query is not None:
|
||||
return self._query
|
||||
else:
|
||||
return self.get_query()
|
||||
|
||||
def get_query(self):
|
||||
self.extra_joins = []
|
||||
self.values = {}
|
||||
self.order_map = {}
|
||||
|
||||
self.check_opts()
|
||||
|
||||
tables = list(self.tables) # copy
|
||||
fields = self.get_fields(fields)
|
||||
fields, aliases = zip(*fields.items())
|
||||
clauses = self.get_clauses(clauses)
|
||||
fields = self.get_fields(self.fields)
|
||||
columns, aliases = zip(*fields.items())
|
||||
clauses = self.get_clauses()
|
||||
joins = self.get_joins()
|
||||
self.query = QueryProcessor(
|
||||
columns=fields, aliases=aliases,
|
||||
self._query = QueryProcessor(
|
||||
columns=columns, aliases=aliases,
|
||||
tables=tables, joins=joins,
|
||||
clauses=clauses, values=self.values,
|
||||
opts=opts)
|
||||
opts=self.opts, order_map=self.order_map)
|
||||
|
||||
return self._query
|
||||
|
||||
def get_fields(self, fields):
|
||||
fields = fields or self.default_fields
|
||||
|
|
@ -901,6 +928,19 @@ class QueryView:
|
|||
|
||||
return {self.map_field(f): f for f in fields}
|
||||
|
||||
def check_opts(self):
|
||||
# some options may trigger joins
|
||||
if self.opts is None:
|
||||
return
|
||||
if 'order' in self.opts:
|
||||
for key in self.opts['order'].split(','):
|
||||
if key.startswith('-'):
|
||||
key = key[1:]
|
||||
self.order_map[key] = self.map_field(key)
|
||||
if 'group' in self.opts:
|
||||
for key in self.opts['group'].split(','):
|
||||
self.map_field(key)
|
||||
|
||||
def map_field(self, field):
|
||||
f_info = self.fieldmap.get(field)
|
||||
if f_info is None:
|
||||
|
|
@ -912,10 +952,10 @@ class QueryView:
|
|||
# duplicates removed later
|
||||
return fullname
|
||||
|
||||
def get_clauses(self, clauses):
|
||||
def get_clauses(self):
|
||||
# for now, just a very simple implementation
|
||||
result = []
|
||||
clauses = clauses or []
|
||||
clauses = self.clauses or []
|
||||
for n, clause in enumerate(clauses):
|
||||
# TODO checks check checks
|
||||
if len(clause) == 2:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue