From 374f4e37935e07ce7748558474811cdb3df2b83f Mon Sep 17 00:00:00 2001 From: Mike McLean Date: Mon, 19 Feb 2024 07:59:48 -0500 Subject: [PATCH] better ordering support in QueryView --- kojihub/db.py | 72 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/kojihub/db.py b/kojihub/db.py index a7cd1cf5..40c121df 100644 --- a/kojihub/db.py +++ b/kojihub/db.py @@ -617,13 +617,15 @@ class QueryProcessor(object): group: a column or alias name to use in the 'GROUP BY' clause (controlled by enable_group) - enable_group: if True, opts.group will be enabled + - order_map: (optional) a name:expression map of allowed orders. Otherwise any column or alias + is allowed """ iterchunksize = 1000 def __init__(self, columns=None, aliases=None, tables=None, joins=None, clauses=None, values=None, transform=None, - opts=None, enable_group=False): + opts=None, enable_group=False, order_map=None): self.columns = columns self.aliases = aliases if columns and aliases: @@ -656,6 +658,7 @@ class QueryProcessor(object): self.opts = opts else: self.opts = {} + self.order_map = order_map self.enable_group = enable_group self.logger = logging.getLogger('koji.db') @@ -754,14 +757,19 @@ SELECT %(col_str)s else: direction = '' # Check if we're ordering by alias first - orderCol = self.colsByAlias.get(order) - if orderCol: - pass - elif order in self.columns: - orderCol = order + if self.order_map is not None: + # order should only be a key in the map + expr = self.order_map.get(order) + if not expr: + raise koji.ParameterError(f'Invalid order term: {order}') else: - raise Exception('Invalid order: ' + order) - order_exprs.append(orderCol + direction) + expr = self.colsByAlias.get(order) + if not expr: + if order in self.columns: + expr = order + else: + raise Exception('Invalid order: ' + order) + order_exprs.append(expr + direction) return 'ORDER BY ' + ', '.join(order_exprs) else: return '' @@ -881,18 +889,37 @@ class QueryView: default_fields = () def __init__(self, clauses=None, fields=None, opts=None): + self.clauses = clauses + self.fields = fields + self.opts = opts + self._query = None + + @property + def query(self): + if self._query is not None: + return self._query + else: + return self.get_query() + + def get_query(self): self.extra_joins = [] self.values = {} + self.order_map = {} + + self.check_opts() + tables = list(self.tables) # copy - fields = self.get_fields(fields) - fields, aliases = zip(*fields.items()) - clauses = self.get_clauses(clauses) + fields = self.get_fields(self.fields) + columns, aliases = zip(*fields.items()) + clauses = self.get_clauses() joins = self.get_joins() - self.query = QueryProcessor( - columns=fields, aliases=aliases, + self._query = QueryProcessor( + columns=columns, aliases=aliases, tables=tables, joins=joins, clauses=clauses, values=self.values, - opts=opts) + opts=self.opts, order_map=self.order_map) + + return self._query def get_fields(self, fields): fields = fields or self.default_fields @@ -901,6 +928,19 @@ class QueryView: return {self.map_field(f): f for f in fields} + def check_opts(self): + # some options may trigger joins + if self.opts is None: + return + if 'order' in self.opts: + for key in self.opts['order'].split(','): + if key.startswith('-'): + key = key[1:] + self.order_map[key] = self.map_field(key) + if 'group' in self.opts: + for key in self.opts['group'].split(','): + self.map_field(key) + def map_field(self, field): f_info = self.fieldmap.get(field) if f_info is None: @@ -912,10 +952,10 @@ class QueryView: # duplicates removed later return fullname - def get_clauses(self, clauses): + def get_clauses(self): # for now, just a very simple implementation result = [] - clauses = clauses or [] + clauses = self.clauses or [] for n, clause in enumerate(clauses): # TODO checks check checks if len(clause) == 2: