# debian-koji/kojihub/db.py
# python library
# db utilities for koji
# Copyright (c) 2005-2014 Red Hat, Inc.
#
# Koji is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation;
# version 2.1 of the License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this software; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#
# Authors:
# Mike McLean <mikem@redhat.com>
from __future__ import absolute_import
import datetime
import logging
import koji
import os
# import psycopg2.extensions
# # don't convert timestamp fields to DateTime objects
# del psycopg2.extensions.string_types[1114]
# del psycopg2.extensions.string_types[1184]
# del psycopg2.extensions.string_types[1082]
# del psycopg2.extensions.string_types[1083]
# del psycopg2.extensions.string_types[1266]
import re
import sys
import time
import traceback
import psycopg2
from dateutil.tz import tzutc
import koji.context
import koji.util
context = koji.context.context
POSITIONAL_RE = re.compile(r'%[a-z]')
NAMED_RE = re.compile(r'%\(([^\)]+)\)[a-z]')
## Globals ##
_DBopts = None
# A persistent connection to the database.
# A new connection will be created whenever
# Apache forks a new worker, and that connection
# will be used to service all requests handled
# by that worker.
# This probably doesn't need to be a ThreadLocal
# since Apache is not using threading,
# but play it safe anyway.
_DBconn = koji.context.ThreadLocal()
logger = logging.getLogger('koji.db')
class DBWrapper:
def __init__(self, cnx):
self.cnx = cnx
def __getattr__(self, key):
if not self.cnx:
raise Exception('connection is closed')
return getattr(self.cnx, key)
def cursor(self, *args, **kw):
if not self.cnx:
raise Exception('connection is closed')
return CursorWrapper(self.cnx.cursor(*args, **kw))
def close(self):
# Rollback any uncommitted changes and clear the connection so
# this DBWrapper is no longer usable after close()
if not self.cnx:
raise Exception('connection is closed')
self.cnx.cursor().execute('ROLLBACK')
# We do this rather than cnx.rollback to avoid opening a new transaction
# If our connection gets recycled cnx.rollback will be called then.
self.cnx = None
class CursorWrapper:
def __init__(self, cursor):
self.cursor = cursor
self.logger = logging.getLogger('koji.db')
def __getattr__(self, key):
return getattr(self.cursor, key)
def _timed_call(self, method, args, kwargs):
start = time.time()
ret = getattr(self.cursor, method)(*args, **kwargs)
self.logger.debug("%s operation completed in %.4f seconds", method, time.time() - start)
return ret
def fetchone(self, *args, **kwargs):
return self._timed_call('fetchone', args, kwargs)
def fetchall(self, *args, **kwargs):
return self._timed_call('fetchall', args, kwargs)
def quote(self, operation, parameters):
if hasattr(self.cursor, "mogrify"):
quote = self.cursor.mogrify
else:
def quote(a, b):
return a % b
try:
sql = quote(operation, parameters)
if isinstance(sql, bytes):
try:
sql = koji.util.decode_bytes(sql)
except Exception:
pass
return sql
except Exception:
self.logger.exception(
'Unable to quote query:\n%s\nParameters: %s', operation, parameters)
return "INVALID QUERY"
def preformat(self, sql, params):
"""psycopg2 requires all variable placeholders to use the string (%s) datatype,
regardless of the actual type of the data. Format the sql string to be compliant.
It also requires IN parameters to be in tuple rather than list format."""
sql = POSITIONAL_RE.sub(r'%s', sql)
sql = NAMED_RE.sub(r'%(\1)s', sql)
if isinstance(params, dict):
for name, value in params.items():
if isinstance(value, list):
params[name] = tuple(value)
else:
if isinstance(params, tuple):
params = list(params)
for i, item in enumerate(params):
if isinstance(item, list):
params[i] = tuple(item)
return sql, params
def execute(self, operation, parameters=(), log_errors=True):
debug = self.logger.isEnabledFor(logging.DEBUG)
operation, parameters = self.preformat(operation, parameters)
if debug:
self.logger.debug(self.quote(operation, parameters))
start = time.time()
try:
ret = self.cursor.execute(operation, parameters)
except Exception:
if log_errors:
self.logger.error('Query failed. Query was: %s', self.quote(operation, parameters))
raise
if debug:
self.logger.debug("Execute operation completed in %.4f seconds", time.time() - start)
return ret
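# Example (illustrative, not part of the original module): CursorWrapper.preformat
# rewrites koji-style placeholders into the %s / %(name)s forms psycopg2 expects
# and converts list parameters to tuples, e.g.
#   sql, params = cursor.preformat('SELECT * FROM t WHERE id IN %(ids)d',
#                                  {'ids': [1, 2, 3]})
#   # sql    -> 'SELECT * FROM t WHERE id IN %(ids)s'
#   # params -> {'ids': (1, 2, 3)}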
## Functions ##
def provideDBopts(**opts):
global _DBopts
if _DBopts is None:
_DBopts = dict([i for i in opts.items() if i[1] is not None])
def setDBopts(**opts):
global _DBopts
_DBopts = opts
def getDBopts():
return _DBopts
def connect():
logger = logging.getLogger('koji.db')
global _DBconn
if hasattr(_DBconn, 'conn'):
# Make sure the previous transaction has been
# closed. This is safe to call multiple times.
conn = _DBconn.conn
try:
# Under normal circumstances, the last use of this connection
# will have issued a raw ROLLBACK to close the transaction. To
# avoid 'no transaction in progress' warnings (depending on postgres
# configuration) we open a new one here.
# Should there somehow be a transaction in progress, a second
# BEGIN will be a harmless no-op, though there may be a warning.
conn.cursor().execute('BEGIN')
conn.rollback()
return DBWrapper(conn)
except psycopg2.Error:
del _DBconn.conn
# create a fresh connection
opts = _DBopts
if opts is None:
opts = {}
try:
if 'dsn' in opts:
conn = psycopg2.connect(dsn=opts['dsn'])
else:
conn = psycopg2.connect(**opts)
conn.set_client_encoding('UTF8')
except Exception:
logger.error(''.join(traceback.format_exception(*sys.exc_info())))
raise
# XXX test
# return conn
_DBconn.conn = conn
return DBWrapper(conn)
def _dml(operation, values, log_errors=True):
"""Run an insert, update, or delete. Return number of rows affected
If log is False, errors will not be logged. It makes sense only for
queries which are expected to fail (LOCK NOWAIT)
"""
c = context.cnx.cursor()
c.execute(operation, values, log_errors=log_errors)
ret = c.rowcount
logger.debug("Operation affected %s row(s)", ret)
c.close()
context.commit_pending = True
return ret
def _fetchMulti(query, values):
"""Run the query and return all rows"""
c = context.cnx.cursor()
c.execute(query, values)
results = c.fetchall()
c.close()
return results
def _fetchSingle(query, values, strict=False):
"""Run the query and return a single row
If strict is true, raise an error if the query returns more or less than
one row."""
results = _fetchMulti(query, values)
numRows = len(results)
if numRows == 0:
if strict:
raise koji.GenericError('query returned no rows')
else:
return None
elif strict and numRows > 1:
raise koji.GenericError('multiple rows returned for a single row query')
else:
return results[0]
def _singleValue(query, values=None, strict=True):
"""Perform a query that returns a single value.
Note that unless strict is True a return value of None could mean either
a single NULL value or zero rows returned."""
if values is None:
values = {}
row = _fetchSingle(query, values, strict)
if row:
if strict and len(row) > 1:
raise koji.GenericError('multiple fields returned for a single value query')
return row[0]
else:
# don't need to check strict here, since that was already handled by _fetchSingle()
return None
def _multiRow(query, values, fields):
"""Return all rows from "query". Named query parameters
can be specified using the "values" map. Results will be returned
as a list of maps. Each map in the list will have a key for each
element in the "fields" list. If there are no results, an empty
list will be returned."""
return [dict(zip(fields, row)) for row in _fetchMulti(query, values)]
def _singleRow(query, values, fields, strict=False):
"""Return a single row from "query". Named parameters can be
specified using the "values" map. The result will be returned as
as map. The map will have a key for each element in the "fields"
list. If more than one row is returned and "strict" is true, a
GenericError will be raised. If no rows are returned, and "strict"
is True, a GenericError will be raised. Otherwise None will be
returned."""
row = _fetchSingle(query, values, strict)
if row:
return dict(zip(fields, row))
else:
# strict enforced by _fetchSingle
return None
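# Example (illustrative, not part of the original module): typical use of the
# fetch helpers above, assuming the standard koji 'users' table:
#   fields = ['id', 'name']
#   user = _singleRow('SELECT id, name FROM users WHERE name = %(name)s',
#                     {'name': 'kojiadmin'}, fields, strict=True)
#   # user -> {'id': ..., 'name': 'kojiadmin'}; raises GenericError unless exactly one row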
def convert_timestamp(ts):
"""Convert a numeric timestamp to a string suitable for a datetimetz field"""
return datetime.datetime.fromtimestamp(ts, tzutc()).isoformat(' ')
def get_event():
"""Get an event id for this transaction
We cache the result in context, so subsequent calls in the same transaction will
get the same event.
This cache is cleared between the individual calls in a multicall.
See: https://pagure.io/koji/pull-request/74
"""
if hasattr(context, 'event_id'):
return context.event_id
event_id = _singleValue("SELECT get_event()")
context.event_id = event_id
return event_id
def nextval(sequence):
"""Get the next value for the given sequence"""
data = {'sequence': sequence}
return _singleValue("SELECT nextval(%(sequence)s)", data, strict=True)
def currval(sequence):
"""Get the current value for the given sequence"""
data = {'sequence': sequence}
return _singleValue("SELECT currval(%(sequence)s)", data, strict=True)
def db_lock(name, wait=True):
"""Obtain lock for name
The named lock must exist in the locks table
:param string name: the lock name
:param bool wait: whether to wait for the lock (default: True)
:return: True if locked, False otherwise
This function is implemented using db row locks and the locks table
"""
# attempt to lock the row
data = {"name": name}
if wait:
query = "SELECT name FROM locks WHERE name=%(name)s FOR UPDATE"
else:
# using SKIP LOCKED rather than NOWAIT to avoid error messages
query = "SELECT name FROM locks WHERE name=%(name)s FOR UPDATE SKIP LOCKED"
rows = _fetchMulti(query, data)
if rows:
# we have the lock
return True
if not wait:
# in the no-wait case, this could mean either that the row is already locked, or that
# the lock does not exist, so we check
query = "SELECT name FROM locks WHERE name=%(name)s"
rows = _fetchMulti(query, data)
if rows:
# the lock exists, but we did not acquire it
return False
# otherwise, the lock does not exist
raise koji.LockError(f"Lock not defined: {name}")
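# Example (illustrative, not part of the original module): skipping work when
# another worker already holds a named lock, assuming a row for 'example-lock'
# exists in the locks table:
#   if db_lock('example-lock', wait=False):
#       ...  # we hold the row lock until the transaction ends
#   else:
#       ...  # another connection holds it; skip this run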
class Savepoint(object):
def __init__(self, name):
self.name = name
_dml("SAVEPOINT %s" % name, {})
def rollback(self):
_dml("ROLLBACK TO SAVEPOINT %s" % self.name, {})
class InsertProcessor(object):
"""Build an insert statement
table - the table to insert into
data - a dictionary of data to insert (keys = column names)
rawdata - data to insert specified as sql expressions rather than python values
does not support query inserts or "DEFAULT VALUES"
"""
def __init__(self, table, data=None, rawdata=None):
self.table = table
self.data = {}
if data:
self.data.update(data)
self.rawdata = {}
if rawdata:
self.rawdata.update(rawdata)
def __str__(self):
if not self.data and not self.rawdata:
return "-- incomplete update: no assigns"
parts = ['INSERT INTO %s ' % self.table]
columns = sorted(list(self.data.keys()) + list(self.rawdata.keys()))
parts.append("(%s) " % ', '.join(columns))
values = []
for key in columns:
if key in self.data:
values.append("%%(%s)s" % key)
else:
values.append("(%s)" % self.rawdata[key])
parts.append("VALUES (%s)" % ', '.join(values))
return ''.join(parts)
def __repr__(self):
return "<InsertProcessor: %r>" % vars(self)
def set(self, **kwargs):
"""Set data via keyword args"""
self.data.update(kwargs)
def rawset(self, **kwargs):
"""Set rawdata via keyword args"""
self.rawdata.update(kwargs)
def make_create(self, event_id=None, user_id=None):
if event_id is None:
event_id = get_event()
if user_id is None:
context.session.assertLogin()
user_id = context.session.user_id
self.data['create_event'] = event_id
self.data['creator_id'] = user_id
def dup_check(self):
"""Check to see if the insert duplicates an existing row"""
if self.rawdata:
logger.warning("Can't perform duplicate check")
return None
data = self.data.copy()
if 'create_event' in self.data:
# versioned table
data['active'] = True
del data['create_event']
del data['creator_id']
clauses = ["%s = %%(%s)s" % (k, k) for k in data]
query = QueryProcessor(columns=list(data.keys()), tables=[self.table],
clauses=clauses, values=data)
if query.execute():
return True
return False
def execute(self):
return _dml(str(self), self.data)
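# Example (illustrative, not part of the original module): inserting a row into
# a hypothetical versioned 'widget_config' table with an authenticated session:
#   insert = InsertProcessor('widget_config', data={'widget_id': 1, 'option': 'x'})
#   insert.make_create()          # stamps create_event and creator_id
#   if insert.dup_check():
#       raise koji.GenericError('config already exists')
#   insert.execute()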
class UpsertProcessor(InsertProcessor):
"""Build a basic upsert statement
table - the table to insert into
data - a dictionary of data to insert (keys = column names)
rawdata - data to insert specified as sql expressions rather than python values
keys - the columns that form the unique (conflict) key
skip_dup - if set to true, do nothing on conflict
"""
def __init__(self, table, data=None, rawdata=None, keys=None, skip_dup=False):
super(UpsertProcessor, self).__init__(table, data=data, rawdata=rawdata)
self.keys = keys
self.skip_dup = skip_dup
if not keys and not skip_dup:
raise ValueError('either keys or skip_dup must be set')
def __repr__(self):
return "<UpsertProcessor: %r>" % vars(self)
def __str__(self):
insert = super(UpsertProcessor, self).__str__()
parts = [insert]
if self.skip_dup:
parts.append(' ON CONFLICT DO NOTHING')
else:
parts.append(f' ON CONFLICT ({",".join(self.keys)}) DO UPDATE SET ')
# filter out conflict keys from data
data = {k: self.data[k] for k in self.data if k not in self.keys}
rawdata = {k: self.rawdata[k] for k in self.rawdata if k not in self.keys}
assigns = [f"{key} = %({key})s" for key in data]
assigns.extend([f"{key} = ({rawdata[key]})" for key in rawdata])
parts.append(', '.join(sorted(assigns)))
return ''.join(parts)
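# Example (illustrative, not part of the original module): insert-or-update on a
# hypothetical 'widget_totals' table keyed by widget_id:
#   upsert = UpsertProcessor('widget_totals',
#                            data={'widget_id': 1, 'total': 10},
#                            keys=['widget_id'])
#   upsert.execute()   # INSERT ... ON CONFLICT (widget_id) DO UPDATE SET total = ...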
class UpdateProcessor(object):
"""Build an update statement
table - the table to update
data - a dictionary of data to update (keys = column names)
rawdata - update data specified as sql expressions rather than python values
clauses - a list of where clauses which will be ANDed together
values - dict of values used in clauses
does not support the FROM clause
"""
def __init__(self, table, data=None, rawdata=None, clauses=None, values=None):
self.table = table
self.data = {}
if data:
self.data.update(data)
self.rawdata = {}
if rawdata:
self.rawdata.update(rawdata)
self.clauses = []
if clauses:
self.clauses.extend(clauses)
self.values = {}
if values:
self.values.update(values)
def __str__(self):
if not self.data and not self.rawdata:
return "-- incomplete update: no assigns"
parts = ['UPDATE %s SET ' % self.table]
assigns = ["%s = %%(data.%s)s" % (key, key) for key in self.data]
assigns.extend(["%s = (%s)" % (key, self.rawdata[key]) for key in self.rawdata])
parts.append(', '.join(sorted(assigns)))
if self.clauses:
parts.append('\nWHERE ')
parts.append(' AND '.join(["( %s )" % c for c in sorted(self.clauses)]))
return ''.join(parts)
def __repr__(self):
return "<UpdateProcessor: %r>" % vars(self)
def get_values(self):
"""Returns unified values dict, including data"""
ret = {}
ret.update(self.values)
for key in self.data:
ret["data." + key] = self.data[key]
return ret
def set(self, **kwargs):
"""Set data via keyword args"""
self.data.update(kwargs)
def rawset(self, **kwargs):
"""Set rawdata via keyword args"""
self.rawdata.update(kwargs)
def make_revoke(self, event_id=None, user_id=None):
"""Add standard revoke options to the update"""
if event_id is None:
event_id = get_event()
if user_id is None:
context.session.assertLogin()
user_id = context.session.user_id
self.data['revoke_event'] = event_id
self.data['revoker_id'] = user_id
self.rawdata['active'] = 'NULL'
self.clauses.append('active = TRUE')
def execute(self):
return _dml(str(self), self.get_values())
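# Example (illustrative, not part of the original module): revoking the active
# row of a hypothetical versioned 'widget_config' table:
#   update = UpdateProcessor('widget_config',
#                            clauses=['widget_id = %(widget_id)s'],
#                            values={'widget_id': 1})
#   update.make_revoke()   # sets revoke_event/revoker_id, active = NULL,
#                          # and limits the update to rows where active = TRUE
#   update.execute()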
class DeleteProcessor(object):
"""Build an delete statement
table - the table to delete
clauses - a list of where clauses which will be ANDed together
values - dict of values used in clauses
"""
def __init__(self, table, clauses=None, values=None):
self.table = table
self.clauses = []
if clauses:
self.clauses.extend(clauses)
self.values = {}
if values:
self.values.update(values)
def __str__(self):
parts = ['DELETE FROM %s ' % self.table]
if self.clauses:
parts.append('\nWHERE ')
parts.append(' AND '.join(["( %s )" % c for c in sorted(self.clauses)]))
return ''.join(parts)
def __repr__(self):
return "<DeleteProcessor: %r>" % vars(self)
def get_values(self):
"""Returns unified values dict, including data"""
ret = {}
ret.update(self.values)
return ret
def execute(self):
return _dml(str(self), self.get_values())
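# Example (illustrative, not part of the original module): deleting rows from a
# hypothetical unversioned 'widget_tags' table:
#   delete = DeleteProcessor('widget_tags',
#                            clauses=['widget_id = %(widget_id)s'],
#                            values={'widget_id': 1})
#   n_deleted = delete.execute()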
class QueryProcessor(object):
"""
Build a query from its components.
- columns, aliases, tables: lists of the column names to retrieve,
the key names to use when returning values as a map, and the tables
to retrieve them from, respectively
- joins: a list of joins in the form 'table1 ON table1.col1 = table2.col2', 'JOIN' will be
prepended automatically; if extended join syntax (LEFT, OUTER, etc.) is required,
it can be specified, and 'JOIN' will not be prepended
- clauses: a list of where clauses in the form 'table1.col1 OPER table2.col2-or-variable';
each clause will be surrounded by parentheses and all will be AND'ed together
- values: the map that will be used to replace any substitution expressions in the query
- transform: a function that will be called on each row (not compatible with
countOnly or singleValue)
- opts: a map of query options; currently supported options are:
countOnly: if True, return an integer indicating how many results would have been
returned, rather than the actual query results
order: a column or alias name to use in the 'ORDER BY' clause
offset: an integer to use in the 'OFFSET' clause
limit: an integer to use in the 'LIMIT' clause
asList: if True, return results as a list of lists, where each list contains the
column values in query order, rather than the usual list of maps
rowlock: if True, use "FOR UPDATE" to lock the queried rows
group: a column or alias name to use in the 'GROUP BY' clause
(controlled by enable_group)
- enable_group: if True, opts.group will be enabled
- order_map: (optional) a name:expression map of allowed orders. Otherwise any column or alias
is allowed
"""
iterchunksize = 1000
def __init__(self, columns=None, aliases=None, tables=None,
joins=None, clauses=None, values=None, transform=None,
opts=None, enable_group=False, order_map=None):
self.columns = columns
self.aliases = aliases
if columns and aliases:
if len(columns) != len(aliases):
raise Exception('column and alias lists must be the same length')
# reorder
alias_table = sorted(zip(aliases, columns))
self.aliases = [x[0] for x in alias_table]
self.columns = [x[1] for x in alias_table]
self.colsByAlias = dict(alias_table)
else:
self.colsByAlias = {}
if columns:
self.columns = sorted(columns)
if aliases:
self.aliases = sorted(aliases)
self.tables = tables
self.joins = joins
if clauses:
self.clauses = sorted(clauses)
else:
self.clauses = clauses
self.cursors = 0
if values:
self.values = values
else:
self.values = {}
self.transform = transform
if opts:
self.opts = opts
else:
self.opts = {}
self.order_map = order_map
self.enable_group = enable_group
self.logger = logging.getLogger('koji.db')
def countOnly(self, count):
self.opts['countOnly'] = count
def __str__(self):
query = \
"""
SELECT %(col_str)s
FROM %(table_str)s
%(join_str)s
%(clause_str)s
%(group_str)s
%(order_str)s
%(offset_str)s
%(limit_str)s
"""
if self.opts.get('countOnly'):
if self.opts.get('offset') \
or self.opts.get('limit') \
or (self.enable_group and self.opts.get('group')):
# If we're counting with an offset and/or limit, we need
# to wrap the offset/limited query and then count the results,
# rather than trying to offset/limit the single row returned
# by count(*). Because we're wrapping the query, we don't care
# about the column values.
col_str = '1'
else:
col_str = 'count(*)'
else:
col_str = self._seqtostr(self.columns)
table_str = self._seqtostr(self.tables, sort=True)
join_str = self._joinstr()
clause_str = self._seqtostr(self.clauses, sep=')\n AND (')
if clause_str:
clause_str = ' WHERE (' + clause_str + ')'
if self.enable_group:
group_str = self._group()
else:
group_str = ''
order_str = self._order()
offset_str = self._optstr('offset')
limit_str = self._optstr('limit')
query = query % locals()
if self.opts.get('countOnly') and \
(self.opts.get('offset') or
self.opts.get('limit') or
(self.enable_group and self.opts.get('group'))):
query = 'SELECT count(*)\nFROM (' + query + ') numrows'
if self.opts.get('rowlock'):
query += '\n FOR UPDATE'
return query
def __repr__(self):
return '<QueryProcessor: ' \
'columns=%r, aliases=%r, tables=%r, joins=%r, clauses=%r, values=%r, opts=%r>' % \
(self.columns, self.aliases, self.tables, self.joins, self.clauses, self.values,
self.opts)
def _seqtostr(self, seq, sep=', ', sort=False):
if seq:
if sort:
seq = sorted(seq)
return sep.join(seq)
else:
return ''
def _joinstr(self):
if not self.joins:
return ''
result = ''
for join in self.joins:
if result:
result += '\n'
if re.search(r'\bjoin\b', join, re.IGNORECASE):
# The join clause already contains the word 'join',
# so don't prepend 'JOIN' to it
result += ' ' + join
else:
result += ' JOIN ' + join
return result
def _order(self):
# Don't bother sorting if we're just counting
if self.opts.get('countOnly'):
return ''
order_opt = self.opts.get('order')
if order_opt:
order_exprs = []
for order in order_opt.split(','):
if order.startswith('-'):
order = order[1:]
direction = ' DESC'
else:
direction = ''
# Check if we're ordering by alias first
if self.order_map is not None:
# order should only be a key in the map
expr = self.order_map.get(order)
if not expr:
raise koji.ParameterError(f'Invalid order term: {order}')
else:
expr = self.colsByAlias.get(order)
if not expr:
if order in self.columns:
expr = order
else:
raise Exception('Invalid order: ' + order)
order_exprs.append(expr + direction)
return 'ORDER BY ' + ', '.join(order_exprs)
else:
return ''
def _group(self):
group_opt = self.opts.get('group')
if group_opt:
group_exprs = []
for group in group_opt.split(','):
if group:
group_exprs.append(group)
return 'GROUP BY ' + ', '.join(group_exprs)
else:
return ''
def _optstr(self, optname):
optval = self.opts.get(optname)
if optval:
return '%s %i' % (optname.upper(), optval)
else:
return ''
def singleValue(self, strict=True):
# self.transform not applied here
return _singleValue(str(self), self.values, strict=strict)
def execute(self):
query = str(self)
if self.opts.get('countOnly'):
return _singleValue(query, self.values, strict=True)
elif self.opts.get('asList'):
if self.transform is None:
return _fetchMulti(query, self.values)
else:
# if we're transforming, generate the dicts so the transform can modify
fields = self.aliases or self.columns
data = _multiRow(query, self.values, fields)
data = [self.transform(row) for row in data]
# and then convert back to lists
data = [[row[f] for f in fields] for row in data]
return data
else:
data = _multiRow(query, self.values, (self.aliases or self.columns))
if self.transform is not None:
data = [self.transform(row) for row in data]
return data
def iterate(self):
if self.opts.get('countOnly'):
return self.execute()
elif self.opts.get('limit') and self.opts['limit'] < self.iterchunksize:
return self.execute()
else:
fields = self.aliases or self.columns
fields = list(fields)
cname = "qp_cursor_%s_%i_%i" % (id(self), os.getpid(), self.cursors)
self.cursors += 1
self.logger.debug('Setting up query iterator. cname=%r', cname)
return self._iterate(cname, str(self), self.values.copy(), fields,
self.iterchunksize, self.opts.get('asList'))
def _iterate(self, cname, query, values, fields, chunksize, as_list=False):
# We pass all this data into the generator so that the iterator works
# from the snapshot when it was generated. Otherwise reuse of the processor
# for similar queries could have unpredictable results.
query = "DECLARE %s NO SCROLL CURSOR FOR %s" % (cname, query)
c = context.cnx.cursor()
c.execute(query, values)
c.close()
try:
query = "FETCH %i FROM %s" % (chunksize, cname)
while True:
if as_list:
if self.transform is None:
buf = _fetchMulti(query, {})
else:
# if we're transforming, generate the dicts so the transform can modify
buf = _multiRow(query, {}, fields)
buf = [self.transform(row) for row in buf]
# and then convert back to lists
buf = [[row[f] for f in fields] for row in buf]
else:
buf = _multiRow(query, {}, fields)
if self.transform is not None:
buf = [self.transform(row) for row in buf]
if not buf:
break
for row in buf:
yield row
finally:
c = context.cnx.cursor()
c.execute("CLOSE %s" % cname)
c.close()
def executeOne(self, strict=False):
results = self.execute()
if isinstance(results, list):
if len(results) > 0:
if strict and len(results) > 1:
raise koji.GenericError('multiple rows returned for a single row query')
return results[0]
elif strict:
raise koji.GenericError('query returned no rows')
else:
return None
return results
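# Example (illustrative, not part of the original module): querying a
# hypothetical 'widget' table with ordering and a limit:
#   query = QueryProcessor(tables=['widget'],
#                          columns=['widget.id', 'widget.name'],
#                          aliases=['id', 'name'],
#                          clauses=['widget.owner_id = %(owner)s'],
#                          values={'owner': 1},
#                          opts={'order': '-id', 'limit': 50})
#   rows = query.execute()   # list of dicts keyed by the aliases
#   query.countOnly(True)
#   total = query.execute()  # now returns an integer count instead of rows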
class QueryView:
# abstract base class
# subclasses should provide...
tables = []
joins = []
joinmap = {}
fieldmap = {}
default_fields = ()
def __init__(self, clauses=None, fields=None, opts=None):
self.clauses = clauses
self.fields = fields
self.opts = opts
self._query = None
@property
def query(self):
if self._query is not None:
return self._query
else:
return self.get_query()
def get_query(self):
self._implicit_joins = []
self._values = {}
self._order_map = {}
self.check_opts()
tables = list(self.tables) # copy
clauses = self.get_clauses()
# get_fields needs to be after clauses because it might consider other implicit joins
fields = self.get_fields(self.fields)
aliases, columns = zip(*fields.items())
joins = self.get_joins()
self._query = QueryProcessor(
columns=columns, aliases=aliases,
tables=tables, joins=joins,
clauses=clauses, values=self._values,
opts=self.opts, order_map=self._order_map)
return self._query
def get_fields(self, fields):
fields = fields or self.default_fields or ['*']
if isinstance(fields, str):
fields = [fields]
# handle special field names
flist = []
for field in fields:
if field == '*':
# all fields that don't require additional joins
for f in self.fieldmap:
joinkey = self.fieldmap[f][1]
if joinkey is None or joinkey in self._implicit_joins:
flist.append(f)
elif field == '**':
# all fields
flist.extend(self.fieldmap)
else:
flist.append(field)
return {f: self.map_field(f) for f in set(flist)}
def check_opts(self):
# some options may trigger joins
if self.opts is None:
return
if 'order' in self.opts:
for key in self.opts['order'].split(','):
if key.startswith('-'):
key = key[1:]
self._order_map[key] = self.map_field(key)
if 'group' in self.opts:
for key in self.opts['group'].split(','):
self.map_field(key)
def map_field(self, field):
f_info = self.fieldmap.get(field)
if f_info is None:
raise koji.ParameterError(f'Invalid field for query {field}')
fullname, joinkey = f_info
fullname = fullname or field
if joinkey:
self._implicit_joins.append(joinkey)
# duplicates removed later
return fullname
def get_clauses(self):
# for now, just a very simple implementation
result = []
clauses = self.clauses or []
for n, clause in enumerate(clauses):
# TODO: more validation checks
if len(clause) == 2:
# implicit operator
field, value = clause
if isinstance(value, (list, tuple)):
op = 'IN'
else:
op = '='
elif len(clause) == 3:
field, op, value = clause
op = op.upper()
if op not in ('IN', '=', '!=', '>', '<', '>=', '<=', 'IS', 'IS NOT', '@>', '<@'):
raise koji.ParameterError(f'Invalid operator: {op}')
else:
raise koji.ParameterError(f'Invalid clause: {clause}')
fullname = self.map_field(field)
key = f'v_{field}_{n}'
self._values[key] = value
result.append(f'{fullname} {op} %({key})s')
return result
def get_joins(self):
joins = list(self.joins)
seen = set()
# note we preserve the order that implicit joins were added
for joinkey in self._implicit_joins:
if joinkey in seen:
continue
seen.add(joinkey)
joins.append(self.joinmap[joinkey])
return joins
def execute(self):
return self.query.execute()
def executeOne(self, strict=False):
return self.query.executeOne(strict=strict)
def iterate(self):
return self.query.iterate()
def singleValue(self, strict=True):
return self.query.singleValue(strict=strict)
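# Example (illustrative, not part of the original module): a minimal QueryView
# subclass for a hypothetical 'widget' table with an optional join to 'users':
#   class WidgetQuery(QueryView):
#       tables = ['widget']
#       joinmap = {'owner': 'users ON users.id = widget.owner_id'}
#       fieldmap = {
#           'id': ('widget.id', None),
#           'name': ('widget.name', None),
#           'owner_name': ('users.name', 'owner'),   # pulls in the 'owner' join
#       }
#       default_fields = ('id', 'name')
#   rows = WidgetQuery(clauses=[('name', 'mywidget')]).execute()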
class BulkInsertProcessor(object):
def __init__(self, table, data=None, columns=None, strict=True, batch=1000):
"""Do bulk inserts - it has some limitations compared to
InsertProcessor (no rawset, dup_check).
set() is replaced with add_record() to avoid confusion
table - name of the table
data - list of dict per record
columns - list/set of names of used columns - makes sense
mainly with strict=True
strict - if True, all records must contain values for all columns.
if False, missing values will be inserted as NULLs
batch - batch size for inserts (one statement per batch)
"""
self.table = table
self.data = []
if columns is None:
self.columns = set()
else:
self.columns = set(columns)
if data is not None:
self.data = data
for row in data:
self.columns |= set(row.keys())
self.strict = strict
self.batch = batch
def __str__(self):
if not self.data:
return "-- incomplete insert: no data"
query, params = self._get_insert(self.data)
return query
def _get_insert(self, data):
"""
Generate one insert statement for the given data
:param list data: list of rows (dict format) to insert
:returns: (query, params)
"""
if not data:
# should not happen
raise ValueError('no data for insert')
parts = ['INSERT INTO %s ' % self.table]
columns = sorted(self.columns)
parts.append("(%s) " % ', '.join(columns))
prepared_data = {}
values = []
i = 0
for row in data:
row_values = []
for key in columns:
if key in row:
row_key = '%s%d' % (key, i)
row_values.append("%%(%s)s" % row_key)
prepared_data[row_key] = row[key]
elif self.strict:
raise koji.GenericError("Missing value %s in BulkInsert" % key)
else:
row_values.append("NULL")
values.append("(%s)" % ', '.join(row_values))
i += 1
parts.append("VALUES %s" % ', '.join(values))
return ''.join(parts), prepared_data
def __repr__(self):
return "<BulkInsertProcessor: %r>" % vars(self)
def add_record(self, **kwargs):
"""Set whole record via keyword args"""
if not kwargs:
raise koji.GenericError("Missing values in BulkInsert.add_record")
self.data.append(kwargs)
self.columns |= set(kwargs.keys())
def execute(self):
if not self.batch:
self._one_insert(self.data)
else:
for i in range(0, len(self.data), self.batch):
data = self.data[i:i + self.batch]
self._one_insert(data)
def _one_insert(self, data):
query, params = self._get_insert(data)
_dml(query, params)
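# Example (illustrative, not part of the original module): batched inserts into
# a hypothetical 'widget_files' table; with strict=False missing columns become NULL:
#   bulk = BulkInsertProcessor('widget_files', strict=False)
#   bulk.add_record(widget_id=1, filename='a.txt')
#   bulk.add_record(widget_id=2, filename='b.txt', size=123)
#   bulk.execute()   # one INSERT statement per batch of 1000 records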
def _applyQueryOpts(results, queryOpts):
"""
Apply queryOpts to results in the same way QueryProcessor would.
results is a list of maps.
queryOpts is a map which may contain the following fields:
countOnly
order
offset
limit
Note:
- asList is supported by QueryProcessor but not by this method.
We don't know the original query order, and so don't have a way to
return a useful list. asList should be handled by the caller.
- group is likewise supported by QueryProcessor but not by this method.
"""
if queryOpts is None:
queryOpts = {}
if queryOpts.get('order'):
order = queryOpts['order']
reverse = False
if order.startswith('-'):
order = order[1:]
reverse = True
results.sort(key=lambda o: o[order], reverse=reverse)
if queryOpts.get('offset'):
results = results[queryOpts['offset']:]
if queryOpts.get('limit'):
results = results[:queryOpts['limit']]
if queryOpts.get('countOnly'):
return len(results)
else:
return results
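# Example (illustrative, not part of the original module): applying query options
# to an already-computed list of rows:
#   rows = [{'id': 3}, {'id': 1}, {'id': 2}]
#   _applyQueryOpts(rows, {'order': '-id', 'limit': 2})   # -> [{'id': 3}, {'id': 2}]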
class BulkUpdateProcessor(object):
"""Build a bulk update statement using a from clause
table - the table to update
data - list of dictionaries of update data (keys = column names)
match_keys - the fields that are used to match
The row data is provided as a list of dictionaries. Each entry
must contain the same keys.
The match_keys value indicates which keys are used to select the
rows to update. The remaining keys are the actual updates.
I.e. if you have data = [{'a':1, 'b':2}] with match_keys=['a'],
this will set b=2 for rows where a=1
"""
def __init__(self, table, data=None, match_keys=None):
self.table = table
self.data = data or []
if match_keys is None:
self.match_keys = []
else:
self.match_keys = list(match_keys)
self._values = {}
def __str__(self):
return self.get_sql()
def get_sql(self):
if not self.data or not self.match_keys:
return "-- incomplete bulk update"
set_keys, all_keys = self.get_keys()
match_keys = list(self.match_keys)
match_keys.sort()
utable = f'__kojibulk_{self.table}'
utable = utable.replace('.', '_')  # in case schema qualified
assigns = [f'{key} = {utable}.{key}' for key in set_keys]
values = {} # values for lookup
fdata = [] # data for VALUES clause
for n, row in enumerate(self.data):
# each row is a dictionary with all keys
parts = []
for key in all_keys:
v_key = f'val_{key}_{n}'
values[v_key] = row[key]
parts.append(f'%({v_key})s')
fdata.append('(%s)' % ', '.join(parts))
clauses = [f'{self.table}.{key} = {utable}.{key}' for key in match_keys]
parts = [
'UPDATE %s SET %s\n' % (self.table, ', '.join(assigns)),
'FROM (VALUES %s)\nAS %s (%s)\n' % (
', '.join(fdata), utable, ', '.join(all_keys)),
'WHERE (%s)' % ' AND '.join(clauses),
]
self._values = values
return ''.join(parts)
def get_keys(self):
if not self.data:
raise ValueError('no update data')
all_keys = list(self.data[0].keys())
for key in all_keys:
if not isinstance(key, str):
raise TypeError('update data must use string keys')
all_keys.sort()
set_keys = [k for k in all_keys if k not in self.match_keys]
set_keys.sort()
# also check that data is sane
required = set(all_keys)
for row in self.data:
if set(row.keys()) != required:
raise ValueError('mismatched update keys')
return set_keys, all_keys
def __repr__(self):
return "<BulkUpdateProcessor: %r>" % vars(self)
def execute(self):
sql = self.get_sql() # sets self._values
return _dml(sql, self._values)
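# Example (illustrative, not part of the original module): updating several rows
# of a hypothetical 'widget' table in one statement, matching on id:
#   updates = BulkUpdateProcessor('widget',
#                                 data=[{'id': 1, 'state': 2}, {'id': 5, 'state': 3}],
#                                 match_keys=['id'])
#   updates.execute()   # UPDATE widget SET state = ... FROM (VALUES ...) ...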
# the end