# python library # db utilities for koji # Copyright (c) 2005-2014 Red Hat, Inc. # # Koji is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; # version 2.1 of the License. # # This software is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this software; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA # # Authors: # Mike McLean from __future__ import absolute_import import datetime import logging import koji import os # import psycopg2.extensions # # don't convert timestamp fields to DateTime objects # del psycopg2.extensions.string_types[1114] # del psycopg2.extensions.string_types[1184] # del psycopg2.extensions.string_types[1082] # del psycopg2.extensions.string_types[1083] # del psycopg2.extensions.string_types[1266] import re import sys import time import traceback import psycopg2 from dateutil.tz import tzutc import koji.context context = koji.context.context POSITIONAL_RE = re.compile(r'%[a-z]') NAMED_RE = re.compile(r'%\(([^\)]+)\)[a-z]') ## Globals ## _DBopts = None # A persistent connection to the database. # A new connection will be created whenever # Apache forks a new worker, and that connection # will be used to service all requests handled # by that worker. # This probably doesn't need to be a ThreadLocal # since Apache is not using threading, # but play it safe anyway. 
_DBconn = koji.context.ThreadLocal()

logger = logging.getLogger('koji.db')


class DBWrapper:
    """Wrap a database connection so it can be invalidated safely.

    After close() is called, any further use of the wrapper raises an
    exception instead of touching the (possibly recycled) connection.
    """

    def __init__(self, cnx):
        self.cnx = cnx

    def __getattr__(self, key):
        if not self.cnx:
            raise Exception('connection is closed')
        return getattr(self.cnx, key)

    def cursor(self, *args, **kw):
        if not self.cnx:
            raise Exception('connection is closed')
        return CursorWrapper(self.cnx.cursor(*args, **kw))

    def close(self):
        # Rollback any uncommitted changes and clear the connection so
        # this DBWrapper is no longer usable after close()
        if not self.cnx:
            raise Exception('connection is closed')
        self.cnx.cursor().execute('ROLLBACK')
        # We do this rather than cnx.rollback to avoid opening a new transaction
        # If our connection gets recycled cnx.rollback will be called then.
        self.cnx = None


class CursorWrapper:
    """Wrap a database cursor, adding debug logging and query preformatting."""

    def __init__(self, cursor):
        self.cursor = cursor
        self.logger = logging.getLogger('koji.db')

    def __getattr__(self, key):
        return getattr(self.cursor, key)

    def _timed_call(self, method, args, kwargs):
        # run the underlying cursor method and log how long it took
        start = time.time()
        ret = getattr(self.cursor, method)(*args, **kwargs)
        self.logger.debug("%s operation completed in %.4f seconds",
                          method, time.time() - start)
        return ret

    def fetchone(self, *args, **kwargs):
        return self._timed_call('fetchone', args, kwargs)

    def fetchall(self, *args, **kwargs):
        return self._timed_call('fetchall', args, kwargs)

    def quote(self, operation, parameters):
        """Return the query with parameters interpolated, for logging purposes"""
        if hasattr(self.cursor, "mogrify"):
            quote = self.cursor.mogrify
        else:
            def quote(a, b):
                return a % b
        try:
            sql = quote(operation, parameters)
            if isinstance(sql, bytes):
                try:
                    sql = koji.util.decode_bytes(sql)
                except Exception:
                    # fall back to logging the bytes value
                    pass
            return sql
        except Exception:
            self.logger.exception(
                'Unable to quote query:\n%s\nParameters: %s', operation, parameters)
            return "INVALID QUERY"

    def preformat(self, sql, params):
        """psycopg2 requires all variable placeholders to use the string (%s) datatype,
        regardless of the actual type of the data. Format the sql string to be compliant.

        It also requires IN parameters to be in tuple rather than list format."""
        sql = POSITIONAL_RE.sub(r'%s', sql)
        sql = NAMED_RE.sub(r'%(\1)s', sql)
        if isinstance(params, dict):
            for name, value in params.items():
                if isinstance(value, list):
                    params[name] = tuple(value)
        else:
            if isinstance(params, tuple):
                params = list(params)
            for i, item in enumerate(params):
                if isinstance(item, list):
                    params[i] = tuple(item)
        return sql, params

    def execute(self, operation, parameters=(), log_errors=True):
        """Execute the query, with optional logging of errors and timing

        If log_errors is False, failures are not logged (useful for queries
        that are expected to fail, e.g. LOCK ... NOWAIT)."""
        debug = self.logger.isEnabledFor(logging.DEBUG)
        operation, parameters = self.preformat(operation, parameters)
        if debug:
            self.logger.debug(self.quote(operation, parameters))
        start = time.time()
        try:
            ret = self.cursor.execute(operation, parameters)
        except Exception:
            if log_errors:
                self.logger.error('Query failed. Query was: %s',
                                  self.quote(operation, parameters))
            raise
        if debug:
            self.logger.debug("Execute operation completed in %.4f seconds",
                              time.time() - start)
        return ret


## Functions ##

def provideDBopts(**opts):
    """Set the db options, but only if they are not already set"""
    global _DBopts
    if _DBopts is None:
        _DBopts = dict([i for i in opts.items() if i[1] is not None])


def setDBopts(**opts):
    """Set the db options unconditionally"""
    global _DBopts
    _DBopts = opts


def getDBopts():
    """Return the current db options"""
    return _DBopts


def connect():
    """Return a wrapped connection to the database

    A cached per-worker connection (_DBconn) is reused when it is still
    healthy; otherwise a fresh connection is created and cached."""
    logger = logging.getLogger('koji.db')
    global _DBconn
    if hasattr(_DBconn, 'conn'):
        # Make sure the previous transaction has been
        # closed.  This is safe to call multiple times.
        conn = _DBconn.conn
        try:
            # Under normal circumstances, the last use of this connection
            # will have issued a raw ROLLBACK to close the transaction. To
            # avoid 'no transaction in progress' warnings (depending on postgres
            # configuration) we open a new one here.
            # Should there somehow be a transaction in progress, a second
            # BEGIN will be a harmless no-op, though there may be a warning.
            conn.cursor().execute('BEGIN')
            conn.rollback()
            return DBWrapper(conn)
        except psycopg2.Error:
            del _DBconn.conn
    # create a fresh connection
    opts = _DBopts
    if opts is None:
        opts = {}
    try:
        if 'dsn' in opts:
            conn = psycopg2.connect(dsn=opts['dsn'])
        else:
            conn = psycopg2.connect(**opts)
        conn.set_client_encoding('UTF8')
    except Exception:
        logger.error(''.join(traceback.format_exception(*sys.exc_info())))
        raise
    _DBconn.conn = conn
    return DBWrapper(conn)


def _dml(operation, values, log_errors=True):
    """Run an insert, update, or delete. Return number of rows affected

    If log_errors is False, errors will not be logged. It makes sense only for
    queries which are expected to fail (LOCK NOWAIT)
    """
    c = context.cnx.cursor()
    c.execute(operation, values, log_errors=log_errors)
    ret = c.rowcount
    logger.debug("Operation affected %s row(s)", ret)
    c.close()
    context.commit_pending = True
    return ret


def _fetchMulti(query, values):
    """Run the query and return all rows"""
    c = context.cnx.cursor()
    c.execute(query, values)
    results = c.fetchall()
    c.close()
    return results


def _fetchSingle(query, values, strict=False):
    """Run the query and return a single row

    If strict is true, raise an error if the query returns more or less than
    one row."""
    results = _fetchMulti(query, values)
    numRows = len(results)
    if numRows == 0:
        if strict:
            raise koji.GenericError('query returned no rows')
        else:
            return None
    elif strict and numRows > 1:
        raise koji.GenericError('multiple rows returned for a single row query')
    else:
        return results[0]


def _singleValue(query, values=None, strict=True):
    """Perform a query that returns a single value.

    Note that unless strict is True a return value of None could mean either
    a single NULL value or zero rows returned."""
    if values is None:
        values = {}
    row = _fetchSingle(query, values, strict)
    if row:
        if strict and len(row) > 1:
            raise koji.GenericError('multiple fields returned for a single value query')
        return row[0]
    else:
        # don't need to check strict here, since that was already handled by _fetchSingle()
        return None


def _multiRow(query, values, fields):
    """Return all rows from "query".  Named query parameters
    can be specified using the "values" map.  Results will be returned
    as a list of maps.  Each map in the list will have a key for each
    element in the "fields" list.  If there are no results, an empty
    list will be returned."""
    return [dict(zip(fields, row)) for row in _fetchMulti(query, values)]


def _singleRow(query, values, fields, strict=False):
    """Return a single row from "query".  Named parameters can be
    specified using the "values" map.  The result will be returned as
    as map.  The map will have a key for each element in the "fields"
    list.  If more than one row is returned and "strict" is true, a
    GenericError will be raised.  If no rows are returned, and "strict"
    is True, a GenericError will be raised.  Otherwise None will be
    returned."""
    row = _fetchSingle(query, values, strict)
    if row:
        return dict(zip(fields, row))
    else:
        # strict enforced by _fetchSingle
        return None


def convert_timestamp(ts):
    """Convert a numeric timestamp to a string suitable for a datetimetz field"""
    return datetime.datetime.fromtimestamp(ts, tzutc()).isoformat(' ')


def get_event():
    """Get an event id for this transaction

    We cache the result in context, so subsequent calls in the same transaction will
    get the same event.
    This cache is cleared between the individual calls in a multicall.
    See: https://pagure.io/koji/pull-request/74
    """
    if hasattr(context, 'event_id'):
        return context.event_id
    event_id = _singleValue("SELECT get_event()")
    context.event_id = event_id
    return event_id


def nextval(sequence):
    """Get the next value for the given sequence"""
    data = {'sequence': sequence}
    return _singleValue("SELECT nextval(%(sequence)s)", data, strict=True)


def currval(sequence):
    """Get the current value for the given sequence"""
    data = {'sequence': sequence}
    return _singleValue("SELECT currval(%(sequence)s)", data, strict=True)


def db_lock(name, wait=True):
    """Obtain lock for name

    The named lock must exist in the locks table

    :param string name: the lock name
    :param bool wait: whether to wait for the lock (default: True)
    :return: True if locked, False otherwise

    This function is implemented using db row locks and the locks table
    """
    # attempt to lock the row
    data = {"name": name}
    if wait:
        query = "SELECT name FROM locks WHERE name=%(name)s FOR UPDATE"
    else:
        # using SKIP LOCKED rather than NOWAIT to avoid error messages
        query = "SELECT name FROM locks WHERE name=%(name)s FOR UPDATE SKIP LOCKED"
    rows = _fetchMulti(query, data)
    if rows:
        # we have the lock
        return True
    if not wait:
        # in the no-wait case, this could mean either that the row is already locked, or that
        # the lock does not exist, so we check
        query = "SELECT name FROM locks WHERE name=%(name)s"
        rows = _fetchMulti(query, data)
        if rows:
            # the lock exists, but we did not acquire it
            return False
    # otherwise, the lock does not exist
    raise koji.LockError(f"Lock not defined: {name}")


class Savepoint(object):
    """Create a savepoint on construction; rollback() returns to it"""

    def __init__(self, name):
        self.name = name
        _dml("SAVEPOINT %s" % name, {})

    def rollback(self):
        _dml("ROLLBACK TO SAVEPOINT %s" % self.name, {})


class InsertProcessor(object):
    """Build an insert statement

    table - the table to insert into
    data - a dictionary of data to insert (keys = row names)
    rawdata - data to insert specified as sql expressions rather than python values

    does not support query inserts of "DEFAULT VALUES"
    """

    def __init__(self, table, data=None, rawdata=None):
        self.table = table
        self.data = {}
        if data:
            self.data.update(data)
        self.rawdata = {}
        if rawdata:
            self.rawdata.update(rawdata)

    def __str__(self):
        if not self.data and not self.rawdata:
            return "-- incomplete update: no assigns"
        parts = ['INSERT INTO %s ' % self.table]
        columns = sorted(list(self.data.keys()) + list(self.rawdata.keys()))
        parts.append("(%s) " % ', '.join(columns))
        values = []
        for key in columns:
            if key in self.data:
                values.append("%%(%s)s" % key)
            else:
                values.append("(%s)" % self.rawdata[key])
        parts.append("VALUES (%s)" % ', '.join(values))
        return ''.join(parts)

    def __repr__(self):
        # fixed: format text was missing, making repr useless
        return "<InsertProcessor: %r>" % vars(self)

    def set(self, **kwargs):
        """Set data via keyword args"""
        self.data.update(kwargs)

    def rawset(self, **kwargs):
        """Set rawdata via keyword args"""
        self.rawdata.update(kwargs)

    def make_create(self, event_id=None, user_id=None):
        """Add the standard versioned-table creation fields"""
        if event_id is None:
            event_id = get_event()
        if user_id is None:
            context.session.assertLogin()
            user_id = context.session.user_id
        self.data['create_event'] = event_id
        self.data['creator_id'] = user_id

    def dup_check(self):
        """Check to see if the insert duplicates an existing row"""
        if self.rawdata:
            logger.warning("Can't perform duplicate check")
            return None
        data = self.data.copy()
        if 'create_event' in self.data:
            # versioned table
            data['active'] = True
            del data['create_event']
            del data['creator_id']
        clauses = ["%s = %%(%s)s" % (k, k) for k in data]
        query = QueryProcessor(columns=list(data.keys()), tables=[self.table],
                               clauses=clauses, values=data)
        if query.execute():
            return True
        return False

    def execute(self):
        return _dml(str(self), self.data)


class UpsertProcessor(InsertProcessor):
    """Build a basic upsert statement

    table - the table to insert into
    data - a dictionary of data to insert (keys = row names)
    rawdata - data to insert specified as sql expressions rather than python values
    keys - the rows that are the unique keys
    skip_dup - if set to true, do nothing on conflict
    """

    def __init__(self, table, data=None, rawdata=None, keys=None, skip_dup=False):
        super(UpsertProcessor, self).__init__(table, data=data, rawdata=rawdata)
        self.keys = keys
        self.skip_dup = skip_dup
        if not keys and not skip_dup:
            raise ValueError('either keys or skip_dup must be set')

    def __repr__(self):
        # fixed: format text was missing, making repr useless
        return "<UpsertProcessor: %r>" % vars(self)

    def __str__(self):
        insert = super(UpsertProcessor, self).__str__()
        parts = [insert]
        if self.skip_dup:
            parts.append(' ON CONFLICT DO NOTHING')
        else:
            parts.append(f' ON CONFLICT ({",".join(self.keys)}) DO UPDATE SET ')
            # filter out conflict keys from data
            data = {k: self.data[k] for k in self.data if k not in self.keys}
            rawdata = {k: self.rawdata[k] for k in self.rawdata if k not in self.keys}
            assigns = [f"{key} = %({key})s" for key in data]
            # fixed: iterate the *filtered* rawdata so that conflict keys are
            # excluded from the DO UPDATE assignments (previously iterated
            # self.rawdata, silently ignoring the filter above)
            assigns.extend([f"{key} = ({rawdata[key]})" for key in rawdata])
            parts.append(', '.join(sorted(assigns)))
        return ''.join(parts)


class UpdateProcessor(object):
    """Build an update statement

    table - the table to insert into
    data - a dictionary of data to insert (keys = row names)
    rawdata - data to insert specified as sql expressions rather than python values
    clauses - a list of where clauses which will be ANDed together
    values - dict of values used in clauses

    does not support the FROM clause
    """

    def __init__(self, table, data=None, rawdata=None, clauses=None, values=None):
        self.table = table
        self.data = {}
        if data:
            self.data.update(data)
        self.rawdata = {}
        if rawdata:
            self.rawdata.update(rawdata)
        self.clauses = []
        if clauses:
            self.clauses.extend(clauses)
        self.values = {}
        if values:
            self.values.update(values)

    def __str__(self):
        if not self.data and not self.rawdata:
            return "-- incomplete update: no assigns"
        parts = ['UPDATE %s SET ' % self.table]
        assigns = ["%s = %%(data.%s)s" % (key, key) for key in self.data]
        assigns.extend(["%s = (%s)" % (key, self.rawdata[key]) for key in self.rawdata])
        parts.append(', '.join(sorted(assigns)))
        if self.clauses:
            parts.append('\nWHERE ')
            parts.append(' AND '.join(["( %s )" % c for c in sorted(self.clauses)]))
        return ''.join(parts)

    def __repr__(self):
        # fixed: format text was missing, making repr useless
        return "<UpdateProcessor: %r>" % vars(self)

    def get_values(self):
        """Returns unified values dict, including data"""
        ret = {}
        ret.update(self.values)
        for key in self.data:
            # data values are namespaced to avoid clashing with clause values
            ret["data." + key] = self.data[key]
        return ret

    def set(self, **kwargs):
        """Set data via keyword args"""
        self.data.update(kwargs)

    def rawset(self, **kwargs):
        """Set rawdata via keyword args"""
        self.rawdata.update(kwargs)

    def make_revoke(self, event_id=None, user_id=None):
        """Add standard revoke options to the update"""
        if event_id is None:
            event_id = get_event()
        if user_id is None:
            context.session.assertLogin()
            user_id = context.session.user_id
        self.data['revoke_event'] = event_id
        self.data['revoker_id'] = user_id
        self.rawdata['active'] = 'NULL'
        self.clauses.append('active = TRUE')

    def execute(self):
        return _dml(str(self), self.get_values())


class DeleteProcessor(object):
    """Build an delete statement

    table - the table to delete
    clauses - a list of where clauses which will be ANDed together
    values - dict of values used in clauses
    """

    def __init__(self, table, clauses=None, values=None):
        self.table = table
        self.clauses = []
        if clauses:
            self.clauses.extend(clauses)
        self.values = {}
        if values:
            self.values.update(values)

    def __str__(self):
        parts = ['DELETE FROM %s ' % self.table]
        if self.clauses:
            parts.append('\nWHERE ')
            parts.append(' AND '.join(["( %s )" % c for c in sorted(self.clauses)]))
        return ''.join(parts)

    def __repr__(self):
        # fixed: format text was missing, making repr useless
        return "<DeleteProcessor: %r>" % vars(self)

    def get_values(self):
        """Returns unified values dict, including data"""
        ret = {}
        ret.update(self.values)
        return ret

    def execute(self):
        return _dml(str(self), self.get_values())


class QueryProcessor(object):
    """
    Build a query from its components.
    - columns, aliases, tables: lists of the column names to retrieve,
      the tables to retrieve them from, and the key names to use when
      returning values as a map, respectively
    - joins: a list of joins in the form 'table1 ON table1.col1 = table2.col2', 'JOIN' will be
             prepended automatically; if extended join syntax (LEFT, OUTER, etc.) is required,
             it can be specified, and 'JOIN' will not be prepended
    - clauses: a list of where clauses in the form 'table1.col1 OPER table2.col2-or-variable';
               each clause will be surrounded by parentheses and all will be AND'ed together
    - values: the map that will be used to replace any substitution expressions in the query
    - transform: a function that will be called on each row (not compatible with
                 countOnly or singleValue)
    - opts: a map of query options; currently supported options are:
        countOnly: if True, return an integer indicating how many results would
                   have been returned, rather than the actual query results
        order: a column or alias name to use in the 'ORDER BY' clause
        offset: an integer to use in the 'OFFSET' clause
        limit: an integer to use in the 'LIMIT' clause
        asList: if True, return results as a list of lists, where each list contains the
                column values in query order, rather than the usual list of maps
        rowlock: if True, use "FOR UPDATE" to lock the queried rows
        group: a column or alias name to use in the 'GROUP BY' clause
               (controlled by enable_group)
    - enable_group: if True, opts.group will be enabled
    - order_map: (optional) a name:expression map of allowed orders.
                 Otherwise any column or alias is allowed
    """

    iterchunksize = 1000

    def __init__(self, columns=None, aliases=None, tables=None,
                 joins=None, clauses=None, values=None, transform=None,
                 opts=None, enable_group=False, order_map=None):
        self.columns = columns
        self.aliases = aliases
        if columns and aliases:
            if len(columns) != len(aliases):
                raise Exception('column and alias lists must be the same length')
            # reorder
            alias_table = sorted(zip(aliases, columns))
            self.aliases = [x[0] for x in alias_table]
            self.columns = [x[1] for x in alias_table]
            self.colsByAlias = dict(alias_table)
        else:
            self.colsByAlias = {}
            if columns:
                self.columns = sorted(columns)
            if aliases:
                self.aliases = sorted(aliases)
        self.tables = tables
        self.joins = joins
        if clauses:
            self.clauses = sorted(clauses)
        else:
            self.clauses = clauses
        self.cursors = 0
        if values:
            self.values = values
        else:
            self.values = {}
        self.transform = transform
        if opts:
            self.opts = opts
        else:
            self.opts = {}
        self.order_map = order_map
        self.enable_group = enable_group
        self.logger = logging.getLogger('koji.db')

    def countOnly(self, count):
        self.opts['countOnly'] = count

    def __str__(self):
        query = """
SELECT %(col_str)s
  FROM %(table_str)s
%(join_str)s
%(clause_str)s
 %(group_str)s
 %(order_str)s
%(offset_str)s %(limit_str)s
"""
        if self.opts.get('countOnly'):
            if self.opts.get('offset') \
                    or self.opts.get('limit') \
                    or (self.enable_group and self.opts.get('group')):
                # If we're counting with an offset and/or limit, we need
                # to wrap the offset/limited query and then count the results,
                # rather than trying to offset/limit the single row returned
                # by count(*).  Because we're wrapping the query, we don't care
                # about the column values.
                col_str = '1'
            else:
                col_str = 'count(*)'
        else:
            col_str = self._seqtostr(self.columns)
        table_str = self._seqtostr(self.tables, sort=True)
        join_str = self._joinstr()
        clause_str = self._seqtostr(self.clauses, sep=')\n   AND (')
        if clause_str:
            clause_str = ' WHERE (' + clause_str + ')'
        if self.enable_group:
            group_str = self._group()
        else:
            group_str = ''
        order_str = self._order()
        offset_str = self._optstr('offset')
        limit_str = self._optstr('limit')

        query = query % locals()
        if self.opts.get('countOnly') and \
                (self.opts.get('offset') or
                 self.opts.get('limit') or
                 (self.enable_group and self.opts.get('group'))):
            query = 'SELECT count(*)\nFROM (' + query + ') numrows'
        if self.opts.get('rowlock'):
            query += '\n FOR UPDATE'
        return query

    def __repr__(self):
        # fixed: format text was missing, which made this raise TypeError
        # ('' % non-empty-tuple is an error)
        return '<QueryProcessor: columns=%r, aliases=%r, tables=%r, joins=%r, ' \
               'clauses=%r, values=%r, opts=%r>' % \
               (self.columns, self.aliases, self.tables, self.joins, self.clauses,
                self.values, self.opts)

    def _seqtostr(self, seq, sep=', ', sort=False):
        if seq:
            if sort:
                seq = sorted(seq)
            return sep.join(seq)
        else:
            return ''

    def _joinstr(self):
        if not self.joins:
            return ''
        result = ''
        for join in self.joins:
            if result:
                result += '\n'
            if re.search(r'\bjoin\b', join, re.IGNORECASE):
                # The join clause already contains the word 'join',
                # so don't prepend 'JOIN' to it
                result += '  ' + join
            else:
                result += '  JOIN ' + join
        return result

    def _order(self):
        # Don't bother sorting if we're just counting
        if self.opts.get('countOnly'):
            return ''
        order_opt = self.opts.get('order')
        if order_opt:
            order_exprs = []
            for order in order_opt.split(','):
                # leading '-' requests descending order
                if order.startswith('-'):
                    order = order[1:]
                    direction = ' DESC'
                else:
                    direction = ''
                # Check if we're ordering by alias first
                if self.order_map is not None:
                    # order should only be a key in the map
                    expr = self.order_map.get(order)
                    if not expr:
                        raise koji.ParameterError(f'Invalid order term: {order}')
                else:
                    expr = self.colsByAlias.get(order)
                    if not expr:
                        if order in self.columns:
                            expr = order
                        else:
                            raise Exception('Invalid order: ' + order)
                order_exprs.append(expr + direction)
            return 'ORDER BY ' + ', '.join(order_exprs)
        else:
            return ''

    def _group(self):
        group_opt = self.opts.get('group')
        if group_opt:
            group_exprs = []
            for group in group_opt.split(','):
                if group:
                    group_exprs.append(group)
            return 'GROUP BY ' + ', '.join(group_exprs)
        else:
            return ''

    def _optstr(self, optname):
        optval = self.opts.get(optname)
        if optval:
            return '%s %i' % (optname.upper(), optval)
        else:
            return ''

    def singleValue(self, strict=True):
        # self.transform not applied here
        return _singleValue(str(self), self.values, strict=strict)

    def execute(self):
        query = str(self)
        if self.opts.get('countOnly'):
            return _singleValue(query, self.values, strict=True)
        elif self.opts.get('asList'):
            if self.transform is None:
                return _fetchMulti(query, self.values)
            else:
                # if we're transforming, generate the dicts so the transform can modify
                fields = self.aliases or self.columns
                data = _multiRow(query, self.values, fields)
                data = [self.transform(row) for row in data]
                # and then convert back to lists
                data = [[row[f] for f in fields] for row in data]
                return data
        else:
            data = _multiRow(query, self.values, (self.aliases or self.columns))
            if self.transform is not None:
                data = [self.transform(row) for row in data]
            return data

    def iterate(self):
        """Execute the query, iterating over results in chunks via a db cursor"""
        if self.opts.get('countOnly'):
            return self.execute()
        elif self.opts.get('limit') and self.opts['limit'] < self.iterchunksize:
            return self.execute()
        else:
            fields = self.aliases or self.columns
            fields = list(fields)
            cname = "qp_cursor_%s_%i_%i" % (id(self), os.getpid(), self.cursors)
            self.cursors += 1
            self.logger.debug('Setting up query iterator. cname=%r', cname)
            return self._iterate(cname, str(self), self.values.copy(), fields,
                                 self.iterchunksize, self.opts.get('asList'))

    def _iterate(self, cname, query, values, fields, chunksize, as_list=False):
        # We pass all this data into the generator so that the iterator works
        # from the snapshot when it was generated. Otherwise reuse of the processor
        # for similar queries could have unpredictable results.
        query = "DECLARE %s NO SCROLL CURSOR FOR %s" % (cname, query)
        c = context.cnx.cursor()
        c.execute(query, values)
        c.close()
        try:
            query = "FETCH %i FROM %s" % (chunksize, cname)
            while True:
                if as_list:
                    if self.transform is None:
                        buf = _fetchMulti(query, {})
                    else:
                        # if we're transforming, generate the dicts so the transform can modify
                        buf = _multiRow(query, self.values, fields)
                        buf = [self.transform(row) for row in buf]
                        # and then convert back to lists
                        buf = [[row[f] for f in fields] for row in buf]
                else:
                    buf = _multiRow(query, {}, fields)
                    if self.transform is not None:
                        buf = [self.transform(row) for row in buf]
                if not buf:
                    break
                for row in buf:
                    yield row
        finally:
            # always release the server-side cursor
            c = context.cnx.cursor()
            c.execute("CLOSE %s" % cname)
            c.close()

    def executeOne(self, strict=False):
        results = self.execute()
        if isinstance(results, list):
            if len(results) > 0:
                if strict and len(results) > 1:
                    raise koji.GenericError('multiple rows returned for a single row query')
                return results[0]
            elif strict:
                raise koji.GenericError('query returned no rows')
            else:
                return None
        return results


class QueryView:
    # abstract base class
    # subclasses should provide the following class attributes:
    #   tables, joins, joinmap, fieldmap, default_fields

    tables = []
    joins = []
    joinmap = {}
    fieldmap = {}
    default_fields = ()

    def __init__(self, clauses=None, fields=None, opts=None):
        self.clauses = clauses
        self.fields = fields
        self.opts = opts
        self._query = None

    @property
    def query(self):
        # lazily constructed QueryProcessor
        if self._query is not None:
            return self._query
        else:
            return self.get_query()

    def get_query(self):
        self._implicit_joins = []
        self._values = {}
        self._order_map = {}
        self.check_opts()
        tables = list(self.tables)  # copy
        clauses = self.get_clauses()
        # get_fields needs to be after clauses because it might consider other implicit joins
        fields = self.get_fields(self.fields)
        aliases, columns = zip(*fields.items())
        joins = self.get_joins()
        self._query = QueryProcessor(
            columns=columns, aliases=aliases,
            tables=tables, joins=joins,
            clauses=clauses, values=self._values,
            opts=self.opts,
            order_map=self._order_map)
        return self._query

    def get_fields(self, fields):
        fields = fields or self.default_fields or ['*']
        if isinstance(fields, str):
            fields = [fields]

        # handle special field names
        flist = []
        for field in fields:
            if field == '*':
                # all fields that don't require additional joins
                for f in self.fieldmap:
                    joinkey = self.fieldmap[f][1]
                    if joinkey is None or joinkey in self._implicit_joins:
                        flist.append(f)
            elif field == '**':
                # all fields
                flist.extend(self.fieldmap)
            else:
                flist.append(field)

        return {f: self.map_field(f) for f in set(flist)}

    def check_opts(self):
        # some options may trigger joins
        if self.opts is None:
            return
        if 'order' in self.opts:
            for key in self.opts['order'].split(','):
                if key.startswith('-'):
                    key = key[1:]
                self._order_map[key] = self.map_field(key)
        if 'group' in self.opts:
            for key in self.opts['group'].split(','):
                self.map_field(key)

    def map_field(self, field):
        f_info = self.fieldmap.get(field)
        if f_info is None:
            raise koji.ParameterError(f'Invalid field for query {field}')
        fullname, joinkey = f_info
        fullname = fullname or field
        if joinkey:
            self._implicit_joins.append(joinkey)
            # duplicates removed later
        return fullname

    def get_clauses(self):
        # for now, just a very simple implementation
        result = []
        clauses = self.clauses or []
        for n, clause in enumerate(clauses):
            # TODO checks check checks
            if len(clause) == 2:
                # implicit operator
                field, value = clause
                if isinstance(value, (list, tuple)):
                    op = 'IN'
                else:
                    op = '='
            elif len(clause) == 3:
                field, op, value = clause
                op = op.upper()
                if op not in ('IN', '=', '!=', '>', '<', '>=', '<=', 'IS', 'IS NOT',
                              '@>', '<@'):
                    raise koji.ParameterError(f'Invalid operator: {op}')
            else:
                raise koji.ParameterError(f'Invalid clause: {clause}')
            fullname = self.map_field(field)
            key = f'v_{field}_{n}'
            self._values[key] = value
            result.append(f'{fullname} {op} %({key})s')
        return result

    def get_joins(self):
        joins = list(self.joins)
        seen = set()
        # note we preserve the order that implicit joins were added
        for joinkey in self._implicit_joins:
            if joinkey in seen:
                continue
            seen.add(joinkey)
            joins.append(self.joinmap[joinkey])
        return joins

    def execute(self):
        return self.query.execute()

    def executeOne(self, strict=False):
        return self.query.executeOne(strict=strict)

    def iterate(self):
        return self.query.iterate()

    def singleValue(self, strict=True):
        return self.query.singleValue(strict=strict)


class BulkInsertProcessor(object):
    def __init__(self, table, data=None, columns=None, strict=True, batch=1000):
        """Do bulk inserts - it has some limitations compared to
        InsertProcessor (no rawset, dup_check).

        set() is replaced with add_record() to avoid confusion

        table   - name of the table
        data    - list of dict per record
        columns - list/set of names of used columns - makes sense
                  mainly with strict=True
        strict  - if True, all records must contain values for all columns.
                  if False, missing values will be inserted as NULLs
        batch   - batch size for inserts (one statement per batch)
        """
        self.table = table
        self.data = []
        if columns is None:
            self.columns = set()
        else:
            self.columns = set(columns)
        if data is not None:
            self.data = data
            for row in data:
                self.columns |= set(row.keys())
        self.strict = strict
        self.batch = batch

    def __str__(self):
        if not self.data:
            return "-- incomplete insert: no data"
        query, params = self._get_insert(self.data)
        return query

    def _get_insert(self, data):
        """
        Generate one insert statement for the given data

        :param list data: list of rows (dict format) to insert
        :returns: (query, params)
        """
        if not data:
            # should not happen
            raise ValueError('no data for insert')
        parts = ['INSERT INTO %s ' % self.table]
        columns = sorted(self.columns)
        parts.append("(%s) " % ', '.join(columns))

        prepared_data = {}
        values = []
        i = 0
        for row in data:
            row_values = []
            for key in columns:
                if key in row:
                    # suffix each key with the row index to keep params unique
                    row_key = '%s%d' % (key, i)
                    row_values.append("%%(%s)s" % row_key)
                    prepared_data[row_key] = row[key]
                elif self.strict:
                    raise koji.GenericError("Missing value %s in BulkInsert" % key)
                else:
                    row_values.append("NULL")
            values.append("(%s)" % ', '.join(row_values))
            i += 1
        parts.append("VALUES %s" % ', '.join(values))
        return ''.join(parts), prepared_data

    def __repr__(self):
        # fixed: format text was missing, making repr useless
        return "<BulkInsertProcessor: %r>" % vars(self)

    def add_record(self, **kwargs):
        """Set whole record via keyword args"""
        if not kwargs:
            raise koji.GenericError("Missing values in BulkInsert.add_record")
        self.data.append(kwargs)
        self.columns |= set(kwargs.keys())

    def execute(self):
        if not self.batch:
            self._one_insert(self.data)
        else:
            for i in range(0, len(self.data), self.batch):
                data = self.data[i:i + self.batch]
                self._one_insert(data)

    def _one_insert(self, data):
        query, params = self._get_insert(data)
        _dml(query, params)


def _applyQueryOpts(results, queryOpts):
    """
    Apply queryOpts to results in the same way QueryProcessor would.
    results is a list of maps.
    queryOpts is a map which may contain the following fields:
      countOnly
      order
      offset
      limit

    Note:
    - asList is supported by QueryProcessor but not by this method.
      We don't know the original query order, and so don't have a way to
      return a useful list.  asList should be handled by the caller.
    - group is supported by QueryProcessor but not by this method as well.
    """
    if queryOpts is None:
        queryOpts = {}
    if queryOpts.get('order'):
        order = queryOpts['order']
        reverse = False
        if order.startswith('-'):
            order = order[1:]
            reverse = True
        results.sort(key=lambda o: o[order], reverse=reverse)
    if queryOpts.get('offset'):
        results = results[queryOpts['offset']:]
    if queryOpts.get('limit'):
        results = results[:queryOpts['limit']]
    if queryOpts.get('countOnly'):
        return len(results)
    else:
        return results


class BulkUpdateProcessor(object):
    """Build a bulk update statement using a from clause

    table - the table to insert into
    data - list of dictionaries of update data (keys = row names)
    match_keys - the fields that are used to match

    The row data is provided as a list of dictionaries.
    Each entry must contain the same keys.

    The match_keys value indicate which keys are used to select the rows to
    update.  The remaining keys are the actual updates.  I.e. if you have
    data = [{'a': 1, 'b': 2}] with match_keys=['a'], this will set b=2 for
    rows where a=1
    """

    def __init__(self, table, data=None, match_keys=None):
        self.table = table
        self.data = data or []
        if match_keys is None:
            self.match_keys = []
        else:
            self.match_keys = list(match_keys)
        self._values = {}

    def __str__(self):
        return self.get_sql()

    def get_sql(self):
        if not self.data or not self.match_keys:
            return "-- incomplete bulk update"
        set_keys, all_keys = self.get_keys()
        match_keys = list(self.match_keys)
        match_keys.sort()
        utable = f'__kojibulk_{self.table}'
        # fixed: the replace() result was previously discarded, leaving an
        # invalid alias for schema-qualified table names
        utable = utable.replace('.', '_')  # in case schema qualified
        assigns = [f'{key} = {utable}.{key}' for key in set_keys]
        values = {}  # values for lookup
        fdata = []  # data for VALUES clause
        for n, row in enumerate(self.data):
            # each row is a dictionary with all keys
            parts = []
            for key in all_keys:
                v_key = f'val_{key}_{n}'
                values[v_key] = row[key]
                parts.append(f'%({v_key})s')
            fdata.append('(%s)' % ', '.join(parts))
        clauses = [f'{self.table}.{key} = {utable}.{key}' for key in match_keys]
        parts = [
            'UPDATE %s SET %s\n' % (self.table, ', '.join(assigns)),
            'FROM (VALUES %s)\nAS %s (%s)\n' % (
                ', '.join(fdata), utable, ', '.join(all_keys)),
            'WHERE (%s)' % ' AND '.join(clauses),
        ]
        self._values = values
        return ''.join(parts)

    def get_keys(self):
        """Return (set_keys, all_keys) for the update data, validating rows"""
        if not self.data:
            raise ValueError('no update data')
        all_keys = list(self.data[0].keys())
        for key in all_keys:
            if not isinstance(key, str):
                raise TypeError('update data must use string keys')
        all_keys.sort()
        set_keys = [k for k in all_keys if k not in self.match_keys]
        set_keys.sort()
        # also check that data is sane
        required = set(all_keys)
        for row in self.data:
            if set(row.keys()) != required:
                raise ValueError('mismatched update keys')
        return set_keys, all_keys

    def __repr__(self):
        # fixed: format text was missing, making repr useless
        return "<BulkUpdateProcessor: %r>" % vars(self)

    def execute(self):
        sql = self.get_sql()  # sets self._values
        return _dml(sql, self._values)


# the end