166 lines
5.1 KiB
Python
166 lines
5.1 KiB
Python
from __future__ import absolute_import
|
|
try:
|
|
from unittest import mock
|
|
except ImportError:
|
|
import mock
|
|
import json
|
|
import koji
|
|
|
|
from koji.xmlrpcplus import Fault, DateTime
|
|
|
|
|
|
class BaseFakeClientSession(koji.ClientSession):
    """Base for test client sessions that emulate multicall locally.

    Instead of sending the queued calls to a hub in one request,
    multiCall() runs each queued call through self._callMethod, which
    subclasses override to replay or record calls.
    """

    def __init__(self, *a, **kw):
        super(BaseFakeClientSession, self).__init__(*a, **kw)

    def multiCall(self, strict=False):
        """Execute all queued multicall calls one at a time.

        :param strict: if true, re-raise the error of the first failing
            call instead of reporting it as a fault entry
        :returns: a list with one entry per queued call, either a 1-tuple
            holding the call's result, or a dict with ``faultCode`` and
            ``faultString`` for a failed call
        :raises Exception: if the session is not in multicall mode
        """
        if not self.multicall:
            raise Exception("not in multicall")
        ret = []
        # leave multicall mode first so that _callMethod executes the
        # calls instead of queueing them again
        self.multicall = False
        calls = self._calls
        self._calls = []
        for call in calls:
            method = call['methodName']
            args, kwargs = koji.decode_args(*call['params'])
            try:
                result = self._callMethod(method, args, kwargs)
            except Fault as fault:
                if strict:
                    raise
                ret.append({'faultCode': fault.faultCode,
                            'faultString': fault.faultString})
            except koji.GenericError as e:
                # _callMethod implementations may raise koji errors (e.g. the
                # output of koji.convertFault) rather than a raw Fault; report
                # them as faults too instead of letting them escape
                if strict:
                    raise
                ret.append({'faultCode': e.faultCode,
                            'faultString': str(e)})
            else:
                # multicall wraps non-fault results in a singleton
                ret.append((result,))
        return ret
|
|
|
|
|
|
class FakeClientSession(BaseFakeClientSession):
    """Client session that replays previously recorded call data.

    Recorded calls are indexed by a hashable key built from the method
    name, args and kwargs. Calls made outside of multicall mode are
    answered from that data instead of being sent anywhere.
    """

    def __init__(self, *a, **kw):
        super(FakeClientSession, self).__init__(*a, **kw)
        # key -> list of recorded calls sharing that key, in load order
        self._calldata = {}
        # key -> index of the next recorded call to replay for that key
        self._offsets = {}

    def load_calls(self, data):
        """Load call data

        Data should be a list of dictionaries with keys:
          - method
          - args
          - kwargs
          - result (for successful calls)
          - fault (for errors)

        They represent recorded calls, e.g. as generated by
        RecordingClientSession.
        """
        for call in data:
            key = self._munge([call['method'], call['args'], call['kwargs']])
            self._calldata.setdefault(key, []).append(call)

    def load(self, filename):
        """Load call data from a json file (decoded with decode_data)."""
        with open(filename, 'rt') as fp:
            data = json.load(fp, object_hook=decode_data)
        self.load_calls(data)

    def _callMethod(self, name, args, kwargs=None, retry=True):
        """Replay the recorded outcome of a call.

        In multicall mode the call is handed to the parent implementation
        (presumably queued; multiCall later reads queued calls from
        self._calls). Otherwise the next recorded call with a matching key
        is replayed: its fault is raised via koji.convertFault, or its
        result is returned. Repeated calls step through the recorded series
        and keep replaying the last entry once the series is exhausted.
        Calls with no recorded data return a mock.MagicMock.
        """
        if self.multicall:
            return super(FakeClientSession, self)._callMethod(name, args,
                                                                kwargs, retry)
        key = self._munge([name, args, kwargs])
        # we may have a series of calls for each key
        calls = self._calldata.get(key)
        if not calls:
            # nothing recorded for this call; previously this crashed with
            # a TypeError when indexing None
            return mock.MagicMock()
        ofs = self._offsets.get(key, 0)
        call = calls[ofs]
        ofs += 1
        if ofs < len(calls):
            # don't go past the end: the last recorded call is repeated
            self._offsets[key] = ofs
        if not call:
            return mock.MagicMock()
        if 'fault' in call:
            fault = Fault(call['fault']['faultCode'],
                          call['fault']['faultString'])
            raise koji.convertFault(fault)
        return call['result']

    def _munge(self, data):
        """Convert call data into a hashable lookup key.

        Lists become tuples and dicts become tuples of (key, value) pairs
        sorted by key, so equal call data yields equal keys. The conversion
        is driven by koji.util.DataWalker (presumably applied to nested
        values as well — see koji.util).
        """
        def callback(value):
            if isinstance(value, list):
                return tuple(value)
            elif isinstance(value, dict):
                keys = sorted(value.keys())
                return tuple([(k, value[k]) for k in keys])
            else:
                return value
        walker = koji.util.DataWalker(data, callback)
        return walker.walk()
|
|
|
|
|
|
class RecordingClientSession(BaseFakeClientSession):
    """Client session that records the calls it makes.

    For every call made outside of multicall mode, the method, args,
    kwargs and either the result or the fault are kept. They can be
    written to a json file with dump() and fed back to FakeClientSession.
    """

    def __init__(self, *a, **kw):
        super(RecordingClientSession, self).__init__(*a, **kw)
        # recorded call entries, in the order the calls were made
        self._calldata = []

    def get_calls(self):
        """Return the list of call entries recorded so far."""
        return self._calldata

    def dump(self, filename):
        """Write recorded calls to ``filename`` as json, then reset them."""
        with open(filename, 'wt') as fp:
            json.dump(self._calldata, fp, indent=4, sort_keys=True,
                      default=encode_data)
        self._calldata = []

    def _callMethod(self, name, args, kwargs=None, retry=True):
        """Make a call through the parent session and record its outcome.

        In multicall mode the call is passed straight to the parent and is
        not recorded at that point. Otherwise an entry is stored before the
        call is made. On success the entry gets 'result'. If the call raises
        a Fault or koji.GenericError, the entry gets 'fault' and the error
        is re-raised.
        """
        parent = super(RecordingClientSession, self)
        if self.multicall:
            return parent._callMethod(name, args, kwargs, retry)
        entry = {'method': name, 'args': args, 'kwargs': kwargs}
        self._calldata.append(entry)
        try:
            outcome = parent._callMethod(name, args, kwargs, retry)
        except Fault as fault:
            entry['fault'] = {'faultCode': fault.faultCode,
                              'faultString': fault.faultString}
            raise
        except koji.GenericError as err:
            entry['fault'] = {'faultCode': err.faultCode,
                              'faultString': str(err)}
            raise
        entry['result'] = outcome
        return outcome
|
|
|
|
|
|
def encode_data(value):
    """Encode data for json

    Serves as the ``default`` hook for json.dump. A DateTime becomes a
    dict tagged with ``'__type': 'DateTime'`` that carries its ``value``.
    Any other type is rejected with TypeError.
    """
    if not isinstance(value, DateTime):
        raise TypeError('Unknown type for json encoding')
    return {'__type': 'DateTime', 'value': value.value}
|
|
|
|
|
|
def decode_data(value):
    """Decode data encoded for json

    Serves as a json ``object_hook``. A dict tagged with
    ``'__type': 'DateTime'`` is turned back into a DateTime built from its
    ``value``. Anything else is returned unchanged.
    """
    tagged_datetime = isinstance(value, dict) and value.get('__type') == 'DateTime'
    return DateTime(value['value']) if tagged_datetime else value
|
|
|
|
|
|
# the end
|