PR#3269: return 400 codes when client fails to send a full request

Merges #3269
https://pagure.io/koji/pull-request/3269

Fixes: #3268
https://pagure.io/koji/issue/3268
Some network errors cannot be retried
This commit is contained in:
Tomas Kopecek 2022-03-21 14:34:37 +01:00
commit f5fe7a550e
3 changed files with 58 additions and 15 deletions

View file

@ -62,6 +62,7 @@ from koji.tasks import parse_task_params
import koji.xmlrpcplus
from koji.context import context
from koji.daemon import SCM
from koji.server import BadRequest, RequestTimeout
from koji.util import (
base64encode,
decode_bytes,
@ -15361,7 +15362,17 @@ def handle_upload(environ):
os.ftruncate(fd, offset)
os.lseek(fd, offset, 0)
while True:
chunk = inf.read(65536)
try:
chunk = inf.read(65536)
except OSError as e:
str_e = str(e)
logger.error(f"Error reading upload. Offset {offset}+{size}, path {fn}")
if 'timeout' in str_e:
logger.exception("Timed out reading input stream")
raise RequestTimeout(str_e)
else:
logger.exception("Error reading input stream")
raise BadRequest(str_e)
if not chunk:
break
size += len(chunk)

View file

@ -38,6 +38,7 @@ import koji.policy
import koji.util
from koji.context import context
# import xmlrpclib functions from koji to use tweaked Marshaller
from koji.server import ServerError, BadRequest, RequestTimeout
from koji.xmlrpcplus import ExtendedMarshaller, Fault, dumps, getparser
@ -213,7 +214,16 @@ class ModXMLRPCRequestHandler(object):
rlen = 0
maxlen = opts.get('MaxRequestLength', None)
while True:
chunk = stream.read(8192)
try:
chunk = stream.read(8192)
except OSError as e:
str_e = str(e)
if 'timeout' in str_e:
self.logger.exception("Timed out reading input stream")
raise RequestTimeout(str_e)
else:
self.logger.exception("Error reading input stream")
raise BadRequest(str_e)
if not chunk:
break
rlen += len(chunk)
@ -255,6 +265,9 @@ class ModXMLRPCRequestHandler(object):
# wrap response in a singleton tuple
response = (response,)
response = dumps(response, methodresponse=1, marshaller=Marshaller)
except ServerError:
raise
# these are handled higher up
except Fault as fault:
self.traceback = True
response = dumps(fault, marshaller=Marshaller)
@ -382,6 +395,18 @@ def offline_reply(start_response, msg=None):
return [response]
def error_reply(start_response, status, response, extra_headers=None):
response = response.encode()
headers = [
('Content-Length', str(len(response))),
('Content-Type', "text/plain"),
]
if extra_headers:
headers.extend(extra_headers)
start_response(status, headers)
return [response]
def load_config(environ):
"""Load configuration options
@ -744,18 +769,12 @@ def application(environ, start_response):
firstcall = False
# XMLRPC uses POST only. Reject anything else
if environ['REQUEST_METHOD'] != 'POST':
headers = [
extra_headers = [
('Allow', 'POST'),
]
start_response('405 Method Not Allowed', headers)
response = "Method Not Allowed\n" \
"This is an XML-RPC server. Only POST requests are accepted."
response = response.encode()
headers = [
('Content-Length', str(len(response))),
('Content-Type', "text/plain"),
]
return [response]
"This is an XML-RPC server. Only POST requests are accepted.\n"
return error_reply(start_response, '405 Method Not Allowed', response, extra_headers)
if opts.get('ServerOffline'):
return offline_reply(start_response, msg=opts.get("OfflineMessage", None))
# XXX check request length
@ -776,10 +795,15 @@ def application(environ, start_response):
except Exception:
return offline_reply(start_response, msg="database outage")
h = ModXMLRPCRequestHandler(registry)
if environ.get('CONTENT_TYPE') == 'application/octet-stream':
response = h._wrap_handler(h.handle_upload, environ)
else:
response = h._wrap_handler(h.handle_rpc, environ)
try:
if environ.get('CONTENT_TYPE') == 'application/octet-stream':
response = h._wrap_handler(h.handle_upload, environ)
else:
response = h._wrap_handler(h.handle_rpc, environ)
except BadRequest as e:
return error_reply(start_response, '400 Bad Request', str(e) + '\n')
except RequestTimeout as e:
return error_reply(start_response, '408 Request Timeout', str(e) + '\n')
response = response.encode()
headers = [
('Content-Length', str(len(response))),

View file

@ -26,3 +26,11 @@ class ServerError(Exception):
class ServerRedirect(ServerError):
"""Used to handle redirects"""
class BadRequest(ServerError):
"""Used to trigger an http 400 error"""
class RequestTimeout(ServerError):
"""Used to trigger an http 408 error"""