diff --git a/cli/koji_cli/lib.py b/cli/koji_cli/lib.py index 04ef329b..e5821d8a 100644 --- a/cli/koji_cli/lib.py +++ b/cli/koji_cli/lib.py @@ -568,13 +568,18 @@ def download_file(url, relpath, quiet=False, noprogress=False, size=None, else: print("Downloading: %s" % relpath) + if not filesize: + response = requests.head(url, timeout=10) + if response.status_code == 200 and response.headers.get('Content-Length'): + filesize = int(response.headers['Content-Length']) + pos = 0 headers = {} if filesize: # append the file f = open(relpath, 'ab') pos = f.tell() - if pos: + if pos != 0: if filesize == pos: if not quiet: print("File %s already downloaded, skipping" % relpath) @@ -588,7 +593,7 @@ def download_file(url, relpath, quiet=False, noprogress=False, size=None, try: # closing needs to be used for requests < 2.18.0 - with closing(requests.get(url, headers=headers, stream=True)) as response: + with closing(koji.request_with_retry().get(url, headers=headers, stream=True)) as response: if response.status_code in (200, 416): # full content provided or reaching behind EOF # rewrite in such case f.close() diff --git a/tests/test_cli/test_download_file.py b/tests/test_cli/test_download_file.py index 894398f6..de3268f8 100644 --- a/tests/test_cli/test_download_file.py +++ b/tests/test_cli/test_download_file.py @@ -8,8 +8,10 @@ import requests_mock import requests import unittest + from koji_cli.lib import download_file, _download_progress + def mock_open(): """Return the right patch decorator for open""" if six.PY2: @@ -27,16 +29,16 @@ class TestDownloadFile(unittest.TestCase): self.stdout.truncate() self.stderr.seek(0) self.stderr.truncate() - self.requests_get.reset_mock() + self.request_with_retry.reset_mock() def setUp(self): self.tempdir = tempfile.mkdtemp() self.filename = self.tempdir + "/filename" self.stdout = mock.patch('sys.stdout', new_callable=six.StringIO).start() self.stderr = mock.patch('sys.stderr', new_callable=six.StringIO).start() - self.requests_get = mock.patch('requests.get', create=True, name='requests.get').start() - # will work when contextlib.closing will be removed in future - #self.requests_get = self.requests_get.return_value.__enter__ + self.request_with_retry = mock.patch('koji.request_with_retry').start() + self.get_mock = self.request_with_retry.return_value.get + self.head = mock.patch('requests.head').start() def tearDown(self): mock.patch.stopall() @@ -54,53 +56,63 @@ class TestDownloadFile(unittest.TestCase): else: self.assertEqual(cm.exception.args, (21, 'Is a directory')) + @mock.patch('os.unlink') @mock_open() - def test_handle_download_file(self, m_open): + def test_handle_download_file(self, m_open, os_unlink): self.reset_mock() + m_open.return_value.tell.return_value = 0 + rsp_head = self.head.return_value + rsp_head.status_code = 200 + rsp_head.headers = {'Content-Length': '5'} response = mock.MagicMock() - self.requests_get.return_value = response - response.headers.get.return_value = '5' # content-length + self.get_mock.return_value = response + response.headers.get.return_value = '5' # content-length response.iter_content.return_value = ['abcde'] rv = download_file("http://url", self.filename) actual = self.stdout.getvalue() - expected = 'Downloading: %s\n[====================================] 100%% 5.00 B\r\n' % self.filename + expected = 'Downloading: %s\n[====================================] 100%% 5.00 B / 5.00 B\r\n' % self.filename self.assertMultiLineEqual(actual, expected) - self.requests_get.assert_called_once() + self.get_mock.assert_called_once() m_open.assert_called_once() - response.headers.get.assert_called_once() + response.headers.get.assert_not_called() response.iter_content.assert_called_once() self.assertIsNone(rv) + @mock.patch('os.unlink') @mock_open() - def test_handle_download_file_undefined_length(self, m_open): + def test_handle_download_file_undefined_length(self, m_open, os_unlink): self.reset_mock() + m_open.return_value.tell.return_value = 0 + rsp_head = self.head.return_value + rsp_head.status_code = 200 + rsp_head.headers = {'Content-Length': str(65536 * 2)} response = mock.MagicMock() - self.requests_get.return_value = response - response.headers.get.return_value = None # content-length + self.get_mock.return_value = response + response.headers.get.return_value = None # content-length response.iter_content.return_value = ['a' * 65536, 'b' * 65536] rv = download_file("http://url", self.filename) actual = self.stdout.getvalue() - expected = 'Downloading: %s\n[ ] ???%% 64.00 KiB\r[ ] ???%% 128.00 KiB\r[====================================] 100%% 128.00 KiB\r\n' % self.filename + print(repr(actual)) + expected = 'Downloading: %s\n[================== ] 50%% 64.00 KiB / 128.00 KiB\r[====================================] 100%% 128.00 KiB / 128.00 KiB\r\n' % self.filename self.assertMultiLineEqual(actual, expected) - self.requests_get.assert_called_once() + self.get_mock.assert_called_once() m_open.assert_called_once() - response.headers.get.assert_called_once() + response.headers.get.assert_not_called() response.iter_content.assert_called_once() self.assertIsNone(rv) - def test_handle_download_file_with_size(self): rv = download_file("http://url", self.filename, size=10, num=8) actual = self.stdout.getvalue() expected = 'Downloading [8/10]: %s\n\n' % self.filename self.assertMultiLineEqual(actual, expected) - self.requests_get.assert_called_once() + self.get_mock.assert_called_once() self.assertIsNone(rv) def test_handle_download_file_quiet_noprogress(self): @@ -127,6 +139,7 @@ class TestDownloadFile(unittest.TestCase): - http vs https ''' + class TestDownloadProgress(unittest.TestCase): # Show long diffs in error output... maxDiff = None @@ -158,6 +171,7 @@ class TestDownloadFileError(unittest.TestCase): @requests_mock.Mocker() def test_handle_download_file_error_404(self, m): + m.head('http://url') m.get("http://url", text='Not Found\n', status_code=404) with self.assertRaises(requests.HTTPError): download_file("http://url", self.filename) @@ -168,6 +182,7 @@ class TestDownloadFileError(unittest.TestCase): @requests_mock.Mocker() def test_handle_download_file_error_500(self, m): + m.head('http://url') m.get("http://url", text='Internal Server Error\n', status_code=500) with self.assertRaises(requests.HTTPError): download_file("http://url", self.filename) @@ -176,5 +191,6 @@ class TestDownloadFileError(unittest.TestCase): except Exception: pass + if __name__ == '__main__': unittest.main()