diff --git a/.travis.yml b/.travis.yml
index ad6f7d80..71a10429 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -6,7 +6,8 @@ python:
- 3.4
- 3.5
install:
-- pip install requests mock nose nose-cov python-coveralls aliyun-python-sdk-sts
+- pip install requests nose nose-cov python-coveralls aliyun-python-sdk-sts
+- pip install --upgrade mock
script:
- nosetests unittests/ --with-cov
- if [ -n "$OSS_TEST_ACCESS_KEY_ID" ]; then nosetests tests/ --with-cov; fi
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 58175629..09984a77 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -6,9 +6,11 @@ The version numbers of the Python SDK follow the `Semantic Versioning `_ rules
Version 2.1.0
-------------
-- Added: the maximum number of connections in the connection pool can be set via `defaults.connection_pool_size`.
-- Added: concurrent upload is supported via the `num_threads` parameter of the `resumable_upload` function.
-- Fixed: fixed some documentation bugs
+- Added: the maximum number of connections in the connection pool can be set via `oss2.defaults.connection_pool_size`.
+- Added: concurrent upload is supported by specifying the number of threads via the `num_threads` parameter of the `oss2.resumable_upload` function.
+- Added: the resumable download function `oss2.resumable_download`.
+- Fixed: the name of the checkpoint file should be derived from the "normalized" local file name; delete the checkpoint file when its format is not JSON.
+- Fixed: fixed some documentation bugs.
Version 2.0.6
-------------
diff --git a/README.rst b/README.rst
index ca03d96b..e38fb4cb 100644
--- a/README.rst
+++ b/README.rst
@@ -108,12 +108,11 @@ Python 2.6, 2.7, 3.3, 3.4, 3.5
$ export OSS_TEST_STS_ARN=
-Then run the tests in one of the following ways:
+Then run the tests as follows:
.. code-block:: bash
- $ python -m unittest discover tests # if Python version >= 2.7
- $ nosetests # if nose is installed
+ $ nosetests # install nose first
More usage
----------
diff --git a/doc/easy.rst b/doc/easy.rst
index b4bd938d..40740af2 100644
--- a/doc/easy.rst
+++ b/doc/easy.rst
@@ -16,11 +16,11 @@
.. autoclass:: oss2.PartIterator
-Resumable upload
-~~~~~~~~~~~~~~~~
+Resumable upload and download
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: oss2.resumable_upload
-
+.. autofunction:: oss2.resumable_download
FileObject adapter
~~~~~~~~~~~~~~~~~~
diff --git a/oss2/__init__.py b/oss2/__init__.py
index 51509423..45b7e35e 100644
--- a/oss2/__init__.py
+++ b/oss2/__init__.py
@@ -11,7 +11,9 @@
MultipartUploadIterator, ObjectUploadIterator, PartIterator)
-from .resumable import resumable_upload, ResumableStore, determine_part_size
+from .resumable import resumable_upload, resumable_download, ResumableStore, ResumableDownloadStore, determine_part_size
+from .resumable import make_upload_store, make_download_store
+
from .compat import to_bytes, to_string, to_unicode, urlparse, urlquote, urlunquote
diff --git a/oss2/defaults.py b/oss2/defaults.py
index aafa2dc1..07271bc6 100644
--- a/oss2/defaults.py
+++ b/oss2/defaults.py
@@ -33,4 +33,14 @@ def get(value, default_value):
#: Connection pool size of each session
-connection_pool_size = 10
\ No newline at end of file
+connection_pool_size = 10
+
+
+#: For resumable download, parallel download (multiget) is used when the OSS object is larger than this value
+multiget_threshold = 100 * 1024 * 1024
+
+#: Default number of threads for parallel download (multiget)
+multiget_num_threads = 4
+
+#: Default part size for parallel download (multiget)
+multiget_part_size = 10 * 1024 * 1024
\ No newline at end of file
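
These three defaults drive the download path added in oss2/resumable.py below: an object at least `multiget_threshold` bytes long is fetched as `multiget_part_size`-byte ranges by `multiget_num_threads` worker threads. A minimal tuning sketch; the endpoint, credentials and object names are placeholders, not part of this change:

    import oss2

    # Assumed placeholder credentials; only the defaults below are introduced by this patch.
    oss2.defaults.multiget_threshold = 50 * 1024 * 1024  # parallel download above 50 MB
    oss2.defaults.multiget_part_size = 5 * 1024 * 1024   # 5 MB per range request
    oss2.defaults.multiget_num_threads = 8               # 8 concurrent range readers

    bucket = oss2.Bucket(oss2.Auth('<key-id>', '<key-secret>'), '<endpoint>', '<bucket>')
    oss2.resumable_download(bucket, '<remote-key>', '<local-file>')
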
diff --git a/oss2/exceptions.py b/oss2/exceptions.py
index b522fff7..243e9764 100644
--- a/oss2/exceptions.py
+++ b/oss2/exceptions.py
@@ -44,7 +44,9 @@ def __init__(self, status, headers, body, details):
self.message = self.details.get('Message', '')
def __str__(self):
- return str(self.details)
+ error = {'status': self.status,
+ 'details': self.details}
+ return str(error)
class ClientError(OssError):
@@ -52,7 +54,9 @@ def __init__(self, message):
OssError.__init__(self, OSS_CLIENT_ERROR_STATUS, {}, 'ClientError: ' + message, {})
def __str__(self):
- return self.body
+ error = {'status': self.status,
+ 'details': self.body}
+ return str(error)
class RequestError(OssError):
@@ -61,7 +65,9 @@ def __init__(self, e):
self.exception = e
def __str__(self):
- return self.body
+ error = {'status': self.status,
+ 'details': self.body}
+ return str(error)
class ServerError(OssError):
@@ -147,6 +153,11 @@ class ObjectNotAppendable(Conflict):
code = 'ObjectNotAppendable'
+class PreconditionFailed(ServerError):
+ status = 412
+ code = 'PreconditionFailed'
+
+
class NotModified(ServerError):
status = 304
code = ''
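
With this change, `str()` on an OSS exception carries the HTTP status next to the parsed details, and the 412 answer that OSS returns when the `If-Match`/`If-Unmodified-Since` headers sent by the downloader no longer hold maps to the new `PreconditionFailed` class. A hedged sketch, assuming a configured `bucket` and a made-up ETag:

    import oss2

    try:
        bucket.get_object('some-key', headers={'If-Match': '"stale-etag"'})
    except oss2.exceptions.PreconditionFailed as e:
        # str(e) now reads roughly like "{'status': 412, 'details': {...}}"
        print(e)
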
diff --git a/oss2/resumable.py b/oss2/resumable.py
index de4be357..adc934d6 100644
--- a/oss2/resumable.py
+++ b/oss2/resumable.py
@@ -22,6 +22,9 @@
import logging
import functools
import threading
+import shutil
+import random
+import string
_MAX_PART_COUNT = 10000
@@ -37,7 +40,11 @@ def resumable_upload(bucket, key, filename,
num_threads=None):
"""断点上传本地文件。
- 缺省条件下,该函数会在用户HOME目录下保存断点续传的信息。当待上传的本地文件没有发生变化,
+ 实现中采用分片上传方式上传本地文件,缺省的并发数是 `oss2.defaults.multipart_num_threads` ,并且在
+ 本地磁盘保存已经上传的分片信息。如果因为某种原因上传被中断,下次上传同样的文件,即源文件和目标文件路径都
+ 一样,就只会上传缺失的分片。
+
+ 缺省条件下,该函数会在用户 `HOME` 目录下保存断点续传的信息。当待上传的本地文件没有发生变化,
且目标文件名没有变化时,会根据本地保存的信息,从断点开始上传。
:param bucket: :class:`Bucket ` 对象
@@ -48,7 +55,7 @@ def resumable_upload(bucket, key, filename,
:param multipart_threshold: if the file size is larger than this value, multipart upload is used.
:param part_size: the size of each part for the multipart upload. If not specified, it is computed automatically.
:param progress_callback: upload progress callback. See :ref:`progress_callback` .
- :param num_threads: number of threads for concurrent upload; if not specified, `oss2.defaults.multipart_threshold` is used.
+ :param num_threads: number of threads for concurrent upload; if not specified, `oss2.defaults.multipart_num_threads` is used.
"""
size = os.path.getsize(filename)
multipart_threshold = defaults.get(multipart_threshold, defaults.multipart_threshold)
@@ -67,9 +74,65 @@ def resumable_upload(bucket, key, filename,
progress_callback=progress_callback)
+def resumable_download(bucket, key, filename,
+ multiget_threshold=None,
+ part_size=None,
+ progress_callback=None,
+ num_threads=None,
+ store=None):
+ """断点下载。
+
+ 实现的方法是:
+ #. 在本地创建一个临时文件,文件名由原始文件名加上一个随机的后缀组成;
+ #. 通过指定请求的 `Range` 头按照范围并发读取OSS文件,并写入到临时文件里对应的位置;
+ #. 全部完成之后,把临时文件重名为目标文件 (即 `filename` )
+
+ 在上述过程中,断点信息,即已经完成的范围,会保存在磁盘上。因为某种原因下载中断,后续如果下载
+ 同样的文件,也就是源文件和目标文件一样,就会先读取断点信息,然后只下载缺失的部分。
+
+ 缺省设置下,断点信息保存在 `HOME` 目录的一个子目录下。可以通过 `store` 参数更改保存位置。
+
+ 使用该函数应注意如下细节:
+ #. 对同样的源文件、目标文件,避免多个程序(线程)同时调用该函数。因为断点信息会在磁盘上互相覆盖,或临时文件名会冲突。
+ #. 避免使用太小的范围(分片),即 `part_size` 不宜过小,建议大于或等于 `oss2.defaults.multiget_part_size` 。
+ #. 如果目标文件已经存在,那么该函数会覆盖此文件。
+
+
+ :param bucket: :class:`Bucket ` 对象。
+ :param str key: 待下载的远程文件名。
+ :param str filename: 本地的目标文件名。
+ :param int multiget_threshold: 文件长度大于该值时,则使用断点下载。
+ :param int part_size: 指定期望的分片大小,即每个请求获得的字节数,实际的分片大小可能有所不同。
+ :param progress_callback: 下载进度回调函数。参见 :ref:`progress_callback` 。
+ :param num_threads: 并发下载的线程数,如不指定则使用 `oss2.defaults.multiget_num_threads` 。
+
+ :param store: 用来保存断点信息的持久存储,可以指定断点信息所在的目录。
+ :type store: `ResumableDownloadStore`
+
+ :raises: 如果OSS文件不存在,则抛出 :class:`NotFound ` ;也有可能抛出其他因下载文件而产生的异常。
+ """
+
+ multiget_threshold = defaults.get(multiget_threshold, defaults.multiget_threshold)
+
+ result = bucket.head_object(key)
+ if result.content_length >= multiget_threshold:
+ downloader = _ResumableDownloader(bucket, key, filename, _ObjectInfo.make(result),
+ part_size=part_size,
+ progress_callback=progress_callback,
+ num_threads=num_threads,
+ store=store)
+ downloader.download()
+ else:
+ bucket.get_object_to_file(key, filename,
+ progress_callback=progress_callback)
+
+
+_MAX_MULTIGET_PART_COUNT = 100
+
+
def determine_part_size(total_size,
preferred_size=None):
- """确定分片大小。
+ """确定分片上传是分片的大小。
:param int total_size: 总共需要上传的长度
:param int preferred_size: 用户期望的分片大小。如果不指定则采用defaults.part_size
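
A worked example of the sizing rule implemented by `_determine_part_size_internal` below, using the multiget cap of 100 parts; the numbers are illustrative only:

    # 1 GB total with a 100 KB preference: 100 KB * 100 parts cannot cover 1 GB,
    # so the part size is total_size/max_count rounded up instead.
    total_size = 1024 * 1024 * 1024
    preferred_size = 100 * 1024
    max_count = 100

    if preferred_size * max_count < total_size:
        part_size = total_size // max_count + (1 if total_size % max_count else 0)
    else:
        part_size = preferred_size

    print(part_size)  # 10737419 bytes (about 10.2 MB), giving exactly 100 parts
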
@@ -79,19 +142,231 @@ def determine_part_size(total_size,
if not preferred_size:
preferred_size = defaults.part_size
+ return _determine_part_size_internal(total_size, preferred_size, _MAX_PART_COUNT)
+
+
+def _determine_part_size_internal(total_size, preferred_size, max_count):
if total_size < preferred_size:
return total_size
- if preferred_size * _MAX_PART_COUNT < total_size:
- if total_size % _MAX_PART_COUNT:
- return total_size // _MAX_PART_COUNT + 1
+ if preferred_size * max_count < total_size:
+ if total_size % max_count:
+ return total_size // max_count + 1
else:
- return total_size // _MAX_PART_COUNT
+ return total_size // max_count
else:
return preferred_size
-class _ResumableUploader(object):
+def _split_to_parts(total_size, part_size):
+ parts = []
+ num_parts = utils.how_many(total_size, part_size)
+
+ for i in range(num_parts):
+ if i == num_parts - 1:
+ start = i * part_size
+ end = total_size
+ else:
+ start = i * part_size
+ end = part_size + start
+
+ parts.append(_PartToProcess(i + 1, start, end))
+
+ return parts
+
+
+class _ResumableOperation(object):
+ def __init__(self, bucket, key, filename, size, store,
+ progress_callback=None):
+ self.bucket = bucket
+ self.key = key
+ self.filename = filename
+ self.size = size
+
+ self._abspath = os.path.abspath(filename)
+
+ self.__store = store
+ self.__record_key = self.__store.make_store_key(bucket.bucket_name, key, self._abspath)
+ logging.info('key is {0}'.format(self.__record_key))
+
+ # protect self.__progress_callback
+ self.__plock = threading.Lock()
+ self.__progress_callback = progress_callback
+
+ def _del_record(self):
+ self.__store.delete(self.__record_key)
+
+ def _put_record(self, record):
+ self.__store.put(self.__record_key, record)
+
+ def _get_record(self):
+ return self.__store.get(self.__record_key)
+
+ def _report_progress(self, consumed_size):
+ if self.__progress_callback:
+ with self.__plock:
+ self.__progress_callback(consumed_size, self.size)
+
+
+class _ObjectInfo(object):
+ def __init__(self):
+ self.size = None
+ self.etag = None
+ self.mtime = None
+
+ @staticmethod
+ def make(head_object_result):
+ objectInfo = _ObjectInfo()
+ objectInfo.size = head_object_result.content_length
+ objectInfo.etag = head_object_result.etag
+ objectInfo.mtime = head_object_result.last_modified
+
+ return objectInfo
+
+
+class _ResumableDownloader(_ResumableOperation):
+ def __init__(self, bucket, key, filename, objectInfo,
+ part_size=None,
+ store=None,
+ progress_callback=None,
+ num_threads=None):
+ super(_ResumableDownloader, self).__init__(bucket, key, filename, objectInfo.size,
+ store or ResumableDownloadStore(),
+ progress_callback=progress_callback)
+ self.objectInfo = objectInfo
+
+ self.__part_size = defaults.get(part_size, defaults.multiget_part_size)
+ self.__part_size = _determine_part_size_internal(self.size, self.__part_size, _MAX_MULTIGET_PART_COUNT)
+
+ self.__tmp_file = None
+ self.__num_threads = defaults.get(num_threads, defaults.multiget_num_threads)
+ self.__finished_parts = None
+ self.__finished_size = None
+
+ # protect record
+ self.__lock = threading.Lock()
+ self.__record = None
+
+ def download(self):
+ self.__load_record()
+
+ parts_to_download = self.__get_parts_to_download()
+
+ # create the tmp file if it does not exist
+ open(self.__tmp_file, 'a').close()
+
+ q = TaskQueue(functools.partial(self.__producer, parts_to_download=parts_to_download),
+ [self.__consumer] * self.__num_threads)
+ q.run()
+
+ utils.force_rename(self.__tmp_file, self.filename)
+
+ self._report_progress(self.size)
+ self._del_record()
+
+ def __producer(self, q, parts_to_download=None):
+ for part in parts_to_download:
+ q.put(part)
+
+ def __consumer(self, q):
+ while q.ok():
+ part = q.get()
+ if part is None:
+ break
+
+ self.__download_part(part)
+
+ def __download_part(self, part):
+ self._report_progress(self.__finished_size)
+
+ with open(self.__tmp_file, 'rb+') as f:
+ f.seek(part.start, os.SEEK_SET)
+
+ headers = {'If-Match': self.objectInfo.etag,
+ 'If-Unmodified-Since': utils.http_date(self.objectInfo.mtime)}
+ result = self.bucket.get_object(self.key, byte_range=(part.start, part.end - 1), headers=headers)
+ shutil.copyfileobj(result, f)
+
+ self.__finish_part(part)
+
+ def __load_record(self):
+ record = self._get_record()
+
+ if record and not self.is_record_sane(record):
+ self._del_record()
+ record = None
+
+ if record and self.__is_remote_changed(record):
+ utils.silently_remove(self.filename + record['tmp_suffix'])
+ self._del_record()
+ record = None
+
+ if not record:
+ record = {'mtime': self.objectInfo.mtime, 'etag': self.objectInfo.etag, 'size': self.objectInfo.size,
+ 'bucket': self.bucket.bucket_name, 'key': self.key, 'part_size': self.__part_size,
+ 'tmp_suffix': self.__gen_tmp_suffix(), 'abspath': self._abspath,
+ 'parts': []}
+ self._put_record(record)
+
+ self.__tmp_file = self.filename + record['tmp_suffix']
+ self.__part_size = record['part_size']
+ self.__finished_parts = list(_PartToProcess(p['part_number'], p['start'], p['end']) for p in record['parts'])
+ self.__finished_size = sum(p.size for p in self.__finished_parts)
+ self.__record = record
+
+ def __get_parts_to_download(self):
+ assert self.__record
+
+ all_set = set(_split_to_parts(self.size, self.__part_size))
+ finished_set = set(self.__finished_parts)
+
+ return sorted(list(all_set - finished_set), key=lambda p: p.part_number)
+
+ @staticmethod
+ def is_record_sane(record):
+ try:
+ for key in ('etag', 'tmp_suffix', 'abspath', 'bucket', 'key'):
+ if not isinstance(record[key], str):
+ logging.info('{0} is not a string: {1}, but {2}'.format(key, record[key], record[key].__class__))
+ return False
+
+ for key in ('part_size', 'size', 'mtime'):
+ if not isinstance(record[key], int):
+ logging.info('{0} is not an integer: {1}, but {2}'.format(key, record[key], record[key].__class__))
+ return False
+
+ for key in ('parts',):
+ if not isinstance(record['parts'], list):
+ logging.info('{0} is not a list: {1}, but {2}'.format(key, record[key], record[key].__class__))
+ return False
+ except KeyError as e:
+ logging.info('Key not found: {0}'.format(e.args))
+ return False
+
+ return True
+
+ def __is_remote_changed(self, record):
+ return (record['mtime'] != self.objectInfo.mtime or
+ record['size'] != self.objectInfo.size or
+ record['etag'] != self.objectInfo.etag)
+
+ def __finish_part(self, part):
+ logging.debug('finishing part: part_number={0}, start={1}, end={2}'.format(part.part_number, part.start, part.end))
+
+ with self.__lock:
+ self.__finished_parts.append(part)
+ self.__finished_size += part.size
+
+ self.__record['parts'].append({'part_number': part.part_number,
+ 'start': part.start,
+ 'end': part.end})
+ self._put_record(self.__record)
+
+ def __gen_tmp_suffix(self):
+ return '.tmp-' + ''.join(random.choice(string.ascii_lowercase) for _ in range(12))
+
+
+class _ResumableUploader(_ResumableOperation):
"""以断点续传方式上传文件。
:param bucket: :class:`Bucket ` 对象
@@ -110,50 +385,39 @@ def __init__(self, bucket, key, filename, size,
part_size=None,
progress_callback=None,
num_threads=None):
- self.bucket = bucket
- self.key = key
- self.filename = filename
- self.size = size
+ super(_ResumableUploader, self).__init__(bucket, key, filename, size,
+ store or ResumableStore(),
+ progress_callback=progress_callback)
- self.__store = store or ResumableStore()
self.__headers = headers
self.__part_size = defaults.get(part_size, defaults.part_size)
- self.__abspath = os.path.abspath(filename)
self.__mtime = os.path.getmtime(filename)
- # protect self.__progress_callback
- self.__plock = threading.Lock()
- self.__progress_callback = progress_callback
-
self.__num_threads = defaults.get(num_threads, defaults.multipart_num_threads)
- self.__store_key = self.__store.make_store_key(bucket.bucket_name, key, self.__abspath)
self.__upload_id = None
# protect below fields
self.__lock = threading.Lock()
self.__record = None
- self.__size_uploaded = 0
- self.__parts = None
+ self.__finished_size = 0
+ self.__finished_parts = None
def upload(self):
self.__load_record()
- parts_uploaded = self.__recorded_parts()
- parts_to_upload, self.__parts = self.__get_parts_to_upload(parts_uploaded)
+ parts_to_upload = self.__get_parts_to_upload(self.__finished_parts)
parts_to_upload = sorted(parts_to_upload, key=lambda p: p.part_number)
- self.__size_uploaded = sum(p.size for p in self.__parts)
-
q = TaskQueue(functools.partial(self.__producer, parts_to_upload=parts_to_upload),
[self.__consumer] * self.__num_threads)
q.run()
- self.__report_progress(self.size)
+ self._report_progress(self.size)
- self.bucket.complete_multipart_upload(self.key, self.__upload_id, self.__parts)
- self.__store_delete()
+ self.bucket.complete_multipart_upload(self.key, self.__upload_id, self.__finished_parts)
+ self._del_record()
def __producer(self, q, parts_to_upload=None):
for part in parts_to_upload:
@@ -169,7 +433,7 @@ def __consumer(self, q):
def __upload_part(self, part):
with open(to_unicode(self.filename), 'rb') as f:
- self.__report_progress(self.__size_uploaded)
+ self._report_progress(self.__finished_size)
f.seek(part.start, os.SEEK_SET)
result = self.bucket.upload_part(self.key, self.__upload_id, part.part_number,
@@ -179,63 +443,49 @@ def __upload_part(self, part):
def __finish_part(self, part_info):
with self.__lock:
- self.__parts.append(part_info)
- self.__size_uploaded += part_info.size
+ self.__finished_parts.append(part_info)
+ self.__finished_size += part_info.size
self.__record['parts'].append({'part_number': part_info.part_number, 'etag': part_info.etag})
- self.__store_put(self.__record)
-
- def __report_progress(self, uploaded_size):
- if self.__progress_callback:
- with self.__plock:
- self.__progress_callback(uploaded_size, self.size)
-
- def __store_get(self):
- return self.__store.get(self.__store_key)
-
- def __store_put(self, record):
- return self.__store.put(self.__store_key, record)
-
- def __store_delete(self):
- return self.__store.delete(self.__store_key)
+ self._put_record(self.__record)
def __load_record(self):
- record = self.__store_get()
+ record = self._get_record()
if record and not _is_record_sane(record):
- self.__store_delete()
+ self._del_record()
record = None
if record and self.__file_changed(record):
logging.debug('{0} was changed, clear the record.'.format(self.filename))
- self.__store_delete()
+ self._del_record()
record = None
if record and not self.__upload_exists(record['upload_id']):
logging.debug('{0} upload does not exist, clear the record.'.format(record['upload_id']))
- self.__store_delete()
+ self._del_record()
record = None
- if record:
- self.__record = record
- else:
+ if not record:
part_size = determine_part_size(self.size, self.__part_size)
upload_id = self.bucket.init_multipart_upload(self.key, headers=self.__headers).upload_id
record = {'upload_id': upload_id, 'mtime': self.__mtime, 'size': self.size, 'parts': [],
- 'abspath': self.__abspath, 'key': self.key,
+ 'abspath': self._abspath, 'bucket': self.bucket.bucket_name, 'key': self.key,
'part_size': part_size}
logging.debug('put new record upload_id={0} part_size={1}'.format(upload_id, part_size))
- self.__store_put(record)
+ self._put_record(record)
- self.__part_size = record['part_size']
- self.__upload_id = record['upload_id']
self.__record = record
+ self.__part_size = self.__record['part_size']
+ self.__upload_id = self.__record['upload_id']
+ self.__finished_parts = self.__get_finished_parts()
+ self.__finished_size = sum(p.size for p in self.__finished_parts)
- def __recorded_parts(self):
+ def __get_finished_parts(self):
last_part_number = utils.how_many(self.size, self.__part_size)
- parts_uploaded = []
+ parts = []
for p in self.__record['parts']:
part_info = PartInfo(int(p['part_number']), p['etag'])
@@ -244,9 +494,9 @@ def __recorded_parts(self):
else:
part_info.size = self.__part_size
- parts_uploaded.append(part_info)
+ parts.append(part_info)
- return parts_uploaded
+ return parts
def __upload_exists(self, upload_id):
try:
@@ -260,51 +510,25 @@ def __file_changed(self, record):
return record['mtime'] != self.__mtime or record['size'] != self.size
def __get_parts_to_upload(self, parts_uploaded):
- num_parts = utils.how_many(self.size, self.__part_size)
- uploaded_map = {}
- to_upload_map = {}
-
- for uploaded in parts_uploaded:
- uploaded_map[uploaded.part_number] = uploaded
-
- for i in range(num_parts):
- if i == num_parts - 1:
- start = i * self.__part_size
- end = self.size
- else:
- start = i * self.__part_size
- end = self.__part_size + start
-
- to_upload_map[i + 1] = _PartToUpload(i + 1, start, end)
-
+ all_parts = _split_to_parts(self.size, self.__part_size)
if not parts_uploaded:
- return to_upload_map.values(), []
+ return all_parts
- kept_parts = []
+ all_parts_map = dict((p.part_number, p) for p in all_parts)
for uploaded in parts_uploaded:
- if uploaded.part_number in to_upload_map:
- del to_upload_map[uploaded.part_number]
- kept_parts.append(uploaded)
+ if uploaded.part_number in all_parts_map:
+ del all_parts_map[uploaded.part_number]
- return to_upload_map.values(), kept_parts
+ return all_parts_map.values()
_UPLOAD_TEMP_DIR = '.py-oss-upload'
+_DOWNLOAD_TEMP_DIR = '.py-oss-download'
-class ResumableStore(object):
- """操作续传信息的类。
-
- 每次上传的信息会保存在root/dir/下面的某个文件里。
-
- :param str root: 父目录,缺省为HOME
- :param str dir: 自目录,缺省为_UPLOAD_TEMP_DIR
- """
- def __init__(self, root=None, dir=None):
- root = root or os.path.expanduser('~')
- dir = dir or _UPLOAD_TEMP_DIR
-
+class _ResumableStoreBase(object):
+ def __init__(self, root, dir):
self.dir = os.path.join(root, dir)
if os.path.isdir(self.dir):
@@ -312,11 +536,6 @@ def __init__(self, root=None, dir=None):
utils.makedir_p(self.dir)
- @staticmethod
- def make_store_key(bucket_name, key, filename):
- oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
- return utils.md5_string(oss_pathname) + '-' + utils.md5_string(filename)
-
def get(self, key):
pathname = self.__path(key)
@@ -327,8 +546,15 @@ def get(self, key):
# json.load() always returns unicode; for Python 2 we convert it
# to str.
- with open(to_unicode(pathname), 'r') as f:
- return stringify(json.load(f))
+
+ try:
+ with open(to_unicode(pathname), 'r') as f:
+ content = json.load(f)
+ except ValueError:
+ os.remove(pathname)
+ return None
+ else:
+ return stringify(content)
def put(self, key, value):
pathname = self.__path(key)
@@ -348,8 +574,54 @@ def __path(self, key):
return os.path.join(self.dir, key)
-def make_upload_store():
- return ResumableStore(dir=_UPLOAD_TEMP_DIR)
+def _normalize_path(path):
+ return os.path.normpath(os.path.normcase(path))
+
+
+class ResumableStore(_ResumableStoreBase):
+ """保存断点上传断点信息的类。
+
+ 每次上传的信息会保存在 `root/dir/` 下面的某个文件里。
+
+ :param str root: 父目录,缺省为HOME
+ :param str dir: 子目录,缺省为 `_UPLOAD_TEMP_DIR`
+ """
+ def __init__(self, root=None, dir=None):
+ super(ResumableStore, self).__init__(root or os.path.expanduser('~'), dir or _UPLOAD_TEMP_DIR)
+
+ @staticmethod
+ def make_store_key(bucket_name, key, filename):
+ filepath = _normalize_path(filename)
+
+ oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
+ return utils.md5_string(oss_pathname) + '-' + utils.md5_string(filepath)
+
+
+class ResumableDownloadStore(_ResumableStoreBase):
+ """保存断点下载断点信息的类。
+
+ 每次下载的断点信息会保存在 `root/dir/` 下面的某个文件里。
+
+ :param str root: 父目录,缺省为HOME
+ :param str dir: 子目录,缺省为 `_DOWNLOAD_TEMP_DIR`
+ """
+ def __init__(self, root=None, dir=None):
+ super(ResumableDownloadStore, self).__init__(root or os.path.expanduser('~'), dir or _DOWNLOAD_TEMP_DIR)
+
+ @staticmethod
+ def make_store_key(bucket_name, key, filename):
+ filepath = _normalize_path(filename)
+
+ oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
+ return utils.md5_string(oss_pathname) + '-' + utils.md5_string(filepath) + '-download'
+
+
+def make_upload_store(root=None, dir=None):
+ return ResumableStore(root=root, dir=dir)
+
+
+def make_download_store(root=None, dir=None):
+ return ResumableDownloadStore(root=root, dir=dir)
def _rebuild_record(filename, store, bucket, key, upload_id, part_size=None):
@@ -399,7 +671,7 @@ def _is_record_sane(record):
return True
-class _PartToUpload(object):
+class _PartToProcess(object):
def __init__(self, part_number, start, end):
self.part_number = part_number
self.start = start
@@ -408,3 +680,12 @@ def __init__(self, part_number, start, end):
@property
def size(self):
return self.end - self.start
+
+ def __hash__(self):
+ return hash(self.__key())
+
+ def __eq__(self, other):
+ return self.__key() == other.__key()
+
+ def __key(self):
+ return (self.part_number, self.start, self.end)
\ No newline at end of file
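
Putting the new pieces together, a typical call into the download path added above might look like the following sketch; the bucket object, key and file names are placeholder assumptions:

    import oss2

    def report(consumed, total):
        print('{0}/{1} bytes'.format(consumed, total))

    # Checkpoint files live under ~/.py-oss-download unless a custom store is passed.
    store = oss2.make_download_store()

    oss2.resumable_download(bucket, 'backups/big.bin', 'big.bin',
                            multiget_threshold=100 * 1024 * 1024,
                            part_size=10 * 1024 * 1024,
                            num_threads=4,
                            progress_callback=report,
                            store=store)
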
diff --git a/oss2/utils.py b/oss2/utils.py
index a3dad71e..b5fc7695 100644
--- a/oss2/utils.py
+++ b/oss2/utils.py
@@ -347,3 +347,22 @@ def makedir_p(dirpath):
if e.errno != errno.EEXIST:
raise
+
+def silently_remove(filename):
+ """删除文件,如果文件不存在也不报错。"""
+ try:
+ os.remove(filename)
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise
+
+
+def force_rename(src, dst):
+ try:
+ os.rename(src, dst)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ silently_remove(dst)
+ os.rename(src, dst)
+ else:
+ raise
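
`force_rename` exists because `os.rename` only overwrites an existing destination on POSIX; on Windows it fails with `OSError` (errno `EEXIST`), which the helper handles by removing the destination and retrying. A small sketch of the intended behaviour, with arbitrary file names:

    import oss2.utils

    with open('src.txt', 'w') as f:
        f.write('new')
    with open('dst.txt', 'w') as f:
        f.write('old')

    oss2.utils.force_rename('src.txt', 'dst.txt')  # overwrites dst on either platform

    with open('dst.txt') as f:
        assert f.read() == 'new'
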
diff --git a/tests/common.py b/tests/common.py
index 48089510..eda6b877 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -4,6 +4,7 @@
import unittest
import time
import tempfile
+import errno
import oss2
@@ -38,6 +39,11 @@ def delete_keys(bucket, key_list):
bucket.batch_delete_objects(g)
+class NonlocalObject(object):
+ def __init__(self, value):
+ self.var = value
+
+
def wait_meta_sync():
if os.environ.get('TRAVIS'):
time.sleep(15)
@@ -53,11 +59,18 @@ def __init__(self, *args, **kwargs):
self.default_connect_timeout = oss2.defaults.connect_timeout
self.default_multipart_threshold = oss2.defaults.multipart_threshold
+ self.default_multiget_threshold = 1024 * 1024
+ self.default_multiget_part_size = 100 * 1024
+
def setUp(self):
oss2.defaults.connect_timeout = self.default_connect_timeout
oss2.defaults.multipart_threshold = self.default_multipart_threshold
oss2.defaults.multipart_num_threads = random.randint(1, 5)
+ oss2.defaults.multiget_threshold = self.default_multiget_threshold
+ oss2.defaults.multiget_part_size = self.default_multiget_part_size
+ oss2.defaults.multiget_num_threads = random.randint(1, 5)
+
self.bucket = oss2.Bucket(oss2.Auth(OSS_ID, OSS_SECRET), OSS_ENDPOINT, OSS_BUCKET)
self.bucket.create_bucket()
self.key_list = []
@@ -65,7 +78,8 @@ def setUp(self):
def tearDown(self):
for temp_file in self.temp_files:
- os.remove(temp_file)
+ oss2.utils.silently_remove(temp_file)
+
delete_keys(self.bucket, self.key_list)
def random_key(self, suffix=''):
@@ -74,6 +88,12 @@ def random_key(self, suffix=''):
return key
+ def random_filename(self):
+ filename = random_string(16)
+ self.temp_files.append(filename)
+
+ return filename
+
def _prepare_temp_file(self, content):
fd, pathname = tempfile.mkstemp(suffix='test-upload')
@@ -91,3 +111,9 @@ def retry_assert(self, func):
time.sleep(i+2)
self.assertTrue(False)
+
+ def assertFileContent(self, filename, content):
+ with open(filename, 'rb') as f:
+ read = f.read()
+ self.assertEqual(len(read), len(content))
+ self.assertEqual(read, content)
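
`NonlocalObject` above is a Python 2 friendly stand-in for the `nonlocal` statement: test closures mutate an attribute instead of rebinding an outer name. A sketch of the pattern, not taken from the suite itself:

    from common import NonlocalObject

    counter = NonlocalObject(0)

    def hook():
        counter.var += 1  # rebinding `counter` itself would only create a local name

    hook()
    hook()
    assert counter.var == 2
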
diff --git a/tests/test_download.py b/tests/test_download.py
new file mode 100644
index 00000000..0cdeddaf
--- /dev/null
+++ b/tests/test_download.py
@@ -0,0 +1,496 @@
+# -*- coding: utf-8 -*-
+
+import unittest
+import oss2
+import os
+import sys
+import time
+import copy
+import tempfile
+
+from mock import patch
+from functools import partial
+
+from common import *
+
+
+def modify_one(store, store_key, r, key=None, value=None):
+ r[key] = value
+ store.put(store_key, r)
+
+
+class TestDownload(OssTestCase):
+ def __prepare(self, file_size):
+ content = random_bytes(file_size)
+ key = self.random_key()
+ filename = self.random_filename()
+
+ self.bucket.put_object(key, content)
+
+ return key, filename, content
+
+ def __record(self, key, filename, store=None):
+ store = store or oss2.resumable.make_download_store()
+ store_key = store.make_store_key(self.bucket.bucket_name, key, os.path.abspath(filename))
+ return store.get(store_key)
+
+ def __test_normal(self, file_size):
+ key, filename, content = self.__prepare(file_size)
+ oss2.resumable_download(self.bucket, key, filename)
+
+ self.assertFileContent(filename, content)
+
+ def test_small(self):
+ oss2.defaults.multiget_threshold = 1024 * 1024
+
+ self.__test_normal(1023)
+
+ def test_large_single_threaded(self):
+ oss2.defaults.multiget_threshold = 1024 * 1024
+ oss2.defaults.multiget_part_size = 100 * 1024 + 1
+ oss2.defaults.multiget_num_threads = 1
+
+ self.__test_normal(2 * 1024 * 1024 + 1)
+
+ def test_large_multi_threaded(self):
+ """多线程,线程数少于分片数"""
+
+ oss2.defaults.multiget_threshold = 1024 * 1024
+ oss2.defaults.multiget_part_size = 100 * 1024
+ oss2.defaults.multiget_num_threads = 7
+
+ self.__test_normal(2 * 1024 * 1024)
+
+ def test_large_many_threads(self):
+ """线程数多余分片数"""
+
+ oss2.defaults.multiget_threshold = 1024 * 1024
+ oss2.defaults.multiget_part_size = 100 * 1024
+ oss2.defaults.multiget_num_threads = 10
+
+ self.__test_normal(512 * 1024 - 1)
+
+ def __test_resume(self, file_size, failed_parts, modify_func_record=None):
+ total = NonlocalObject(0)
+
+ orig_download_part = oss2.resumable._ResumableDownloader._ResumableDownloader__download_part
+
+ def mock_download_part(self, part, failed_parts=None):
+ if part.part_number in failed_parts:
+ raise RuntimeError("Fail download_part for part: {0}".format(part.part_number))
+ else:
+ total.var += 1
+ orig_download_part(self, part)
+
+ key, filename, content = self.__prepare(file_size)
+
+ with patch.object(oss2.resumable._ResumableDownloader, '_ResumableDownloader__download_part',
+ side_effect=partial(mock_download_part, failed_parts=failed_parts),
+ autospec=True):
+ self.assertRaises(RuntimeError, oss2.resumable_download, self.bucket, key, filename)
+
+ store = oss2.resumable.make_download_store()
+ store_key = store.make_store_key(self.bucket.bucket_name, key, os.path.abspath(filename))
+ record = store.get(store_key)
+
+ tmp_file = filename + record['tmp_suffix']
+ self.assertTrue(os.path.exists(tmp_file))
+ self.assertTrue(not os.path.exists(filename))
+
+ with patch.object(oss2.resumable._ResumableDownloader, '_ResumableDownloader__download_part',
+ side_effect=partial(mock_download_part, failed_parts=[]),
+ autospec=True):
+ oss2.resumable_download(self.bucket, key, filename)
+
+ self.assertEqual(total.var, oss2.utils.how_many(file_size, oss2.defaults.multiget_part_size))
+ self.assertTrue(not os.path.exists(tmp_file))
+ self.assertFileContent(filename, content)
+
+ def test_resume_hole_start(self):
+ """第一个part失败"""
+
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 500
+ oss2.defaults.multiget_num_threads = 3
+
+ self.__test_resume(500 * 10 + 16, [1])
+
+ def test_resume_hole_end(self):
+ """最后一个part失败"""
+
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 500
+ oss2.defaults.multiget_num_threads = 2
+
+ self.__test_resume(500 * 10 + 16, [11])
+
+ def test_resume_hole_mid(self):
+ """中间part失败"""
+
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 500
+ oss2.defaults.multiget_num_threads = 3
+
+ self.__test_resume(500 * 10 + 16, [3])
+
+ def test_resume_rename_failed(self):
+ size = 500 * 10
+ part_size = 499
+
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = part_size
+ oss2.defaults.multiget_num_threads = 3
+
+ key, filename, content = self.__prepare(size)
+
+ with patch.object(os, 'rename', side_effect=RuntimeError(), autospec=True):
+ self.assertRaises(RuntimeError, oss2.resumable_download, self.bucket, key, filename)
+
+ r = self.__record(key, filename)
+
+ # assert record fields are valid
+ head_object_result = self.bucket.head_object(key)
+
+ self.assertEqual(r['size'], size)
+ self.assertEqual(r['mtime'], head_object_result.last_modified)
+ self.assertEqual(r['etag'], head_object_result.etag)
+
+ self.assertEqual(r['bucket'], self.bucket.bucket_name)
+ self.assertEqual(r['key'], key)
+ self.assertEqual(r['part_size'], part_size)
+
+ self.assertTrue(os.path.exists(filename + r['tmp_suffix']))
+ self.assertFileContent(filename + r['tmp_suffix'], content)
+
+ self.assertTrue(not os.path.exists(filename))
+
+ self.assertEqual(r['abspath'], os.path.abspath(filename))
+
+ self.assertEqual(len(r['parts']), oss2.utils.how_many(size, part_size))
+
+ parts = sorted(r['parts'], key=lambda p: p['part_number'])
+ for i, p in enumerate(parts):
+ self.assertEqual(p['part_number'], i+1)
+ self.assertEqual(p['start'], part_size * i)
+ self.assertEqual(p['end'], min(part_size*(i+1), size))
+
+ with patch.object(oss2.resumable._ResumableDownloader, '_ResumableDownloader__download_part',
+ side_effect=RuntimeError(),
+ autospec=True):
+ oss2.resumable_download(self.bucket, key, filename)
+
+ self.assertTrue(not os.path.exists(filename + r['tmp_suffix']))
+ self.assertFileContent(filename, content)
+ self.assertEqual(self.__record(key, filename), None)
+
+ def __test_insane_record(self, file_size, modify_record_func, old_tmp_exists=True):
+ orig_rename = os.rename
+
+ obj = NonlocalObject({})
+
+ key, filename, content = self.__prepare(file_size)
+
+ def mock_rename(src, dst):
+ obj.var = self.__record(key, filename)
+ orig_rename(src, dst)
+
+ with patch.object(os, 'rename', side_effect=RuntimeError(), autospec=True):
+ self.assertRaises(RuntimeError, oss2.resumable_download, self.bucket, key, filename)
+
+ store = oss2.resumable.make_download_store()
+ store_key = store.make_store_key(self.bucket.bucket_name, key, os.path.abspath(filename))
+ r = store.get(store_key)
+
+ modify_record_func(store, store_key, copy.deepcopy(r))
+
+ with patch.object(os, 'rename', side_effect=mock_rename, autospec=True):
+ oss2.resumable_download(self.bucket, key, filename)
+
+ new_r = obj.var
+
+ self.assertTrue(new_r['tmp_suffix'] != r['tmp_suffix'])
+
+ self.assertEqual(new_r['size'], r['size'])
+ self.assertEqual(new_r['mtime'], r['mtime'])
+ self.assertEqual(new_r['etag'], r['etag'])
+ self.assertEqual(new_r['part_size'], r['part_size'])
+
+ self.assertEqual(os.path.exists(filename + r['tmp_suffix']), old_tmp_exists)
+ self.assertTrue(not os.path.exists(filename + new_r['tmp_suffix']))
+
+ oss2.utils.silently_remove(filename + r['tmp_suffix'])
+
+ def test_insane_record_modify(self):
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 128
+ oss2.defaults.multiget_num_threads = 3
+
+ self.__test_insane_record(400, partial(modify_one, key='size', value='123'))
+ self.__test_insane_record(400, partial(modify_one, key='mtime', value='123'))
+ self.__test_insane_record(400, partial(modify_one, key='etag', value=123))
+
+ self.__test_insane_record(400, partial(modify_one, key='part_size', value={}))
+ self.__test_insane_record(400, partial(modify_one, key='tmp_suffix', value={1:2}))
+ self.__test_insane_record(400, partial(modify_one, key='parts', value={1:2}))
+
+ self.__test_insane_record(400, partial(modify_one, key='abspath', value=123))
+ self.__test_insane_record(400, partial(modify_one, key='bucket', value=123))
+ self.__test_insane_record(400, partial(modify_one, key='key', value=1.2))
+
+ def test_insane_record_missing(self):
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 128
+ oss2.defaults.multiget_num_threads = 3
+
+ def missing_one(store, store_key, r, key=None):
+ del r[key]
+ store.put(store_key, r)
+
+ self.__test_insane_record(400, partial(missing_one, key='key'))
+ self.__test_insane_record(400, partial(missing_one, key='mtime'))
+ self.__test_insane_record(400, partial(missing_one, key='parts'))
+
+ def test_insane_record_deleted(self):
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 128
+ oss2.defaults.multiget_num_threads = 3
+
+ def delete_record(store, store_key, r):
+ store.delete(store_key)
+
+ self.__test_insane_record(400, delete_record)
+
+ def test_insane_record_not_json(self):
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 128
+ oss2.defaults.multiget_num_threads = 3
+
+ def corrupt_record(store, store_key, r):
+ pathname = store._ResumableStoreBase__path(store_key)
+ with open(oss2.to_unicode(pathname), 'w') as f:
+ f.write('hello}')
+
+ self.__test_insane_record(400, corrupt_record)
+
+ def test_remote_changed_before_start(self):
+ """在开始下载之前,OSS上的文件就已经被修改了"""
+ oss2.defaults.multiget_threshold = 1
+
+ # reuse __test_insane_record to simulate
+ self.__test_insane_record(400, partial(modify_one, key='etag', value='BABEF00D123456789'), old_tmp_exists=False)
+ self.__test_insane_record(400, partial(modify_one, key='size', value=1024), old_tmp_exists=False)
+ self.__test_insane_record(400, partial(modify_one, key='mtime', value=1024), old_tmp_exists=False)
+
+ def test_remote_changed_during_download(self):
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 100
+ oss2.defaults.multiget_num_threads = 2
+
+ orig_download_part = oss2.resumable._ResumableDownloader._ResumableDownloader__download_part
+ orig_rename = os.rename
+
+ file_size = 1000
+ key, filename, content = self.__prepare(file_size)
+
+ old_context = {}
+ new_context = {}
+
+ def mock_download_part(downloader, part, part_number=None):
+ if part.part_number == part_number:
+ r = self.__record(key, filename)
+
+ old_context['tmp_suffix'] = r['tmp_suffix']
+ old_context['etag'] = r['etag']
+ old_context['content'] = random_bytes(file_size)
+
+ self.bucket.put_object(key, old_context['content'])
+
+ orig_download_part(downloader, part)
+
+ def mock_rename(src, dst):
+ r = self.__record(key, filename)
+
+ new_context['tmp_suffix'] = r['tmp_suffix']
+ new_context['etag'] = r['etag']
+
+ orig_rename(src, dst)
+
+ with patch.object(oss2.resumable._ResumableDownloader, '_ResumableDownloader__download_part',
+ side_effect=partial(mock_download_part, part_number=5),
+ autospec=True):
+ self.assertRaises(oss2.exceptions.PreconditionFailed, oss2.resumable_download, self.bucket, key, filename)
+
+ with patch.object(os, 'rename', side_effect=mock_rename):
+ oss2.resumable_download(self.bucket, key, filename)
+
+ self.assertTrue(new_context['tmp_suffix'] != old_context['tmp_suffix'])
+ self.assertTrue(new_context['etag'] != old_context['etag'])
+
+ def test_two_downloaders(self):
+ """两个downloader同时跑,但是store的目录不一样。"""
+
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 100
+ oss2.defaults.multiget_num_threads = 2
+
+ store1 = oss2.make_download_store()
+ store2 = oss2.make_download_store(dir='.another-py-oss-download')
+
+ file_size = 1000
+ key, filename, content = self.__prepare(file_size)
+
+ context1a = {}
+ context1b = {}
+ context2 = {}
+
+ def mock_rename(src, dst, ctx=None, store=None):
+ r = self.__record(key, filename, store=store)
+
+ ctx['tmp_suffix'] = r['tmp_suffix']
+ ctx['etag'] = r['etag']
+ ctx['mtime'] = r['mtime']
+
+ raise RuntimeError('intentional')
+
+ with patch.object(os, 'rename', side_effect=partial(mock_rename, ctx=context1a, store=store1), autospec=True):
+ self.assertRaises(RuntimeError, oss2.resumable_download, self.bucket, key, filename, store=store1)
+
+ with patch.object(os, 'rename', side_effect=partial(mock_rename, ctx=context1b, store=store1), autospec=True):
+ self.assertRaises(RuntimeError, oss2.resumable_download, self.bucket, key, filename, store=store1)
+
+ with patch.object(os, 'rename', side_effect=partial(mock_rename, ctx=context2, store=store2), autospec=True):
+ self.assertRaises(RuntimeError, oss2.resumable_download, self.bucket, key, filename, store=store2)
+
+ self.assertEqual(context1a['tmp_suffix'], context1b['tmp_suffix'])
+ self.assertEqual(context1a['etag'], context1b['etag'])
+ self.assertEqual(context1a['mtime'], context1b['mtime'])
+
+ self.assertNotEqual(context1a['tmp_suffix'], context2['tmp_suffix'])
+ self.assertEqual(context1a['etag'], context2['etag'])
+ self.assertEqual(context1a['mtime'], context2['mtime'])
+
+ self.assertTrue(os.path.exists(filename + context1a['tmp_suffix']))
+ self.assertTrue(os.path.exists(filename + context2['tmp_suffix']))
+
+ oss2.resumable_download(self.bucket, key, filename, store=store1)
+ self.assertTrue(not os.path.exists(filename + context1a['tmp_suffix']))
+ self.assertTrue(os.path.exists(filename + context2['tmp_suffix']))
+
+ oss2.resumable_download(self.bucket, key, filename, store=store2)
+ self.assertTrue(not os.path.exists(filename + context2['tmp_suffix']))
+
+ def test_progress(self):
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 100
+ oss2.defaults.multiget_num_threads = 1
+
+ stats = {'previous': -1, 'called':0}
+
+ def progress_callback(bytes_consumed, total_bytes):
+ self.assertTrue(bytes_consumed <= total_bytes)
+ self.assertTrue(bytes_consumed > stats['previous'])
+
+ stats['previous'] = bytes_consumed
+ stats['called'] += 1
+
+ file_size = 100 * 5 + 1
+ key, filename, content = self.__prepare(file_size)
+
+ oss2.resumable_download(self.bucket, key, filename, progress_callback=progress_callback)
+
+ self.assertEqual(stats['previous'], file_size)
+ self.assertEqual(stats['called'], oss2.utils.how_many(file_size, oss2.defaults.multiget_part_size) + 1)
+
+ def test_parameters(self):
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 100
+ oss2.defaults.multiget_num_threads = 5
+
+ context = {}
+
+ orig_download = oss2.resumable._ResumableDownloader.download
+
+ def mock_download(downloader):
+ context['part_size'] = downloader._ResumableDownloader__part_size
+ context['num_threads'] = downloader._ResumableDownloader__num_threads
+
+ raise RuntimeError()
+
+ file_size = 123 * 3 + 1
+ key, filename, content = self.__prepare(file_size)
+
+ with patch.object(oss2.resumable._ResumableDownloader, 'download',
+ side_effect=mock_download, autospec=True):
+ self.assertRaises(RuntimeError, oss2.resumable_download, self.bucket, key, filename,
+ part_size=123, num_threads=3)
+
+ self.assertEqual(context['part_size'], 123)
+ self.assertEqual(context['num_threads'], 3)
+
+ def test_relpath_and_abspath(self):
+ """测试绝对、相对路径"""
+ # testing steps:
+ # 1. first use abspath, and fail one part
+ # 2. then use relpath to continue
+
+ if os.name == 'nt':
+ os.chdir('C:\\')
+
+ oss2.defaults.multiget_threshold = 1
+ oss2.defaults.multiget_part_size = 100
+ oss2.defaults.multiget_num_threads = 5
+
+ fd, abspath = tempfile.mkstemp()
+ os.close(fd)
+
+ relpath = os.path.relpath(abspath)
+
+ self.assertNotEqual(abspath, relpath)
+
+ file_size = 1000
+ key = self.random_key()
+ content = random_bytes(file_size)
+
+ self.bucket.put_object(key, content)
+
+ orig_download_part = oss2.resumable._ResumableDownloader._ResumableDownloader__download_part
+ orig_rename = os.rename
+
+ context1 = {}
+ context2 = {}
+
+ def mock_download_part(downloader, part, part_number=None):
+ if part.part_number == part_number:
+ r = self.__record(key, abspath)
+
+ context1['abspath'] = r['abspath']
+ context1['tmp_suffix'] = r['tmp_suffix']
+
+ raise RuntimeError("Fail download_part for part: {0}".format(part_number))
+ else:
+ orig_download_part(downloader, part)
+
+ def mock_rename(src, dst):
+ r = self.__record(key, relpath)
+
+ context2['abspath'] = r['abspath']
+ context2['tmp_suffix'] = r['tmp_suffix']
+
+ orig_rename(src, dst)
+
+ with patch.object(oss2.resumable._ResumableDownloader, '_ResumableDownloader__download_part',
+ side_effect=partial(mock_download_part, part_number=3),
+ autospec=True):
+ self.assertRaises(RuntimeError, oss2.resumable_download, self.bucket, key, abspath)
+
+ with patch.object(os, 'rename', side_effect=mock_rename):
+ oss2.resumable_download(self.bucket, key, relpath)
+
+ self.assertEqual(context1['abspath'], context2['abspath'])
+ self.assertEqual(context1['tmp_suffix'], context2['tmp_suffix'])
+
+ oss2.utils.silently_remove(abspath)
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file
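
A note on the patching pattern these tests rely on: double-underscore methods are name-mangled, so the attribute to patch is spelled `_ResumableDownloader__download_part`, and `autospec=True` makes `mock` pass the downloader instance through to the `side_effect`, preserving the original `(self, part)` signature. The minimal shape of the pattern, as a sketch:

    import oss2
    from mock import patch

    with patch.object(oss2.resumable._ResumableDownloader,
                      '_ResumableDownloader__download_part',
                      side_effect=RuntimeError('injected'),
                      autospec=True):
        pass  # any resumable_download call here now fails inside __download_part
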
diff --git a/tests/test_upload.py b/tests/test_upload.py
index f37997aa..644f4232 100644
--- a/tests/test_upload.py
+++ b/tests/test_upload.py
@@ -8,6 +8,8 @@
from common import *
+from mock import patch
+
class TestUpload(OssTestCase):
def test_upload_small(self):
@@ -53,26 +55,30 @@ def test_concurrency(self):
self.assertEqual(result.headers['x-oss-object-type'], 'Multipart')
def test_progress(self):
- stats = {'previous': -1}
+ stats = {'previous': -1, 'ncalled': 0}
def progress_callback(bytes_consumed, total_bytes):
self.assertTrue(bytes_consumed <= total_bytes)
- self.assertTrue(bytes_consumed >= stats['previous'])
+ self.assertTrue(bytes_consumed > stats['previous'])
stats['previous'] = bytes_consumed
+ stats['ncalled'] += 1
key = random_string(16)
content = random_bytes(5 * 100 * 1024 + 100)
pathname = self._prepare_temp_file(content)
+ part_size = 100 * 1024
oss2.resumable_upload(self.bucket, key, pathname,
multipart_threshold=200 * 1024,
- part_size=100 * 1024,
- progress_callback=progress_callback)
+ part_size=part_size,
+ progress_callback=progress_callback,
+ num_threads=1)
self.assertEqual(stats['previous'], len(content))
+ self.assertEqual(stats['ncalled'], oss2.utils.how_many(len(content), part_size) + 1)
- stats = {'previous': -1}
+ stats = {'previous': -1, 'ncalled': 0}
oss2.resumable_upload(self.bucket, key, pathname,
multipart_threshold=len(content) + 100,
progress_callback=progress_callback)
@@ -113,9 +119,6 @@ def __test_resume(self, content_size, uploaded_parts, expected_unfinished=0):
def test_resume_empty(self):
self.__test_resume(250 * 1024, [])
- def test_resume_empty(self):
- self.__test_resume(250 * 1024, [])
-
def test_resume_continuous(self):
self.__test_resume(500 * 1024, [1, 2])
@@ -141,13 +144,10 @@ def upload_part(self, key, upload_id, part_number, data):
pathname = self._prepare_temp_file(content)
- from unittest.mock import patch
with patch.object(oss2.Bucket, 'upload_part', side_effect=upload_part, autospec=True) as mock_upload_part:
- try:
- oss2.resumable_upload(self.bucket, key, pathname, multipart_threshold=0,
- part_size=100 * 1024)
- except RuntimeError:
- pass
+ self.assertRaises(RuntimeError, oss2.resumable_upload, self.bucket, key, pathname,
+ multipart_threshold=0,
+ part_size=100 * 1024)
if modify_record_func:
modify_record_func(oss2.resumable.make_upload_store(), self.bucket.bucket_name, key, pathname)
@@ -156,48 +156,45 @@ def upload_part(self, key, upload_id, part_number, data):
self.assertEqual(len(list(oss2.ObjectUploadIterator(self.bucket, key))), expected_unfinished)
- if sys.version_info >= (3, 3):
- def test_interrupt_empty(self):
- self.__test_interrupt(310 * 1024, 1)
+ def test_interrupt_empty(self):
+ self.__test_interrupt(310 * 1024, 1)
- def test_interrupt_mid(self):
- self.__test_interrupt(510 * 1024, 3)
+ def test_interrupt_mid(self):
+ self.__test_interrupt(510 * 1024, 3)
- def test_interrupt_last(self):
- self.__test_interrupt(500 * 1024 - 1, 5)
+ def test_interrupt_last(self):
+ self.__test_interrupt(500 * 1024 - 1, 5)
- def test_record_bad_size(self):
- self.__test_interrupt(500 * 1024, 3,
- modify_record_func=self.__make_corrupt_record('size', 'hello'),
- expected_unfinished=1)
+ def test_record_bad_size(self):
+ self.__test_interrupt(500 * 1024, 3,
+ modify_record_func=self.__make_corrupt_record('size', 'hello'),
+ expected_unfinished=1)
- def test_record_no_such_upload_id(self):
- self.__test_interrupt(500 * 1024, 3,
- modify_record_func=self.__make_corrupt_record('upload_id', 'ABCD1234'),
- expected_unfinished=1)
+ def test_record_no_such_upload_id(self):
+ self.__test_interrupt(500 * 1024, 3,
+ modify_record_func=self.__make_corrupt_record('upload_id', 'ABCD1234'),
+ expected_unfinished=1)
- def test_file_changed_mtime(self):
- def change_mtime(store, bucket_name, key, pathname):
- time.sleep(2)
- os.utime(pathname, (time.time(), time.time()))
+ def test_file_changed_mtime(self):
+ def change_mtime(store, bucket_name, key, pathname):
+ time.sleep(2)
+ os.utime(pathname, (time.time(), time.time()))
- self.__test_interrupt(500 * 1024, 3,
- modify_record_func=change_mtime,
- expected_unfinished=1)
+ self.__test_interrupt(500 * 1024, 3,
+ modify_record_func=change_mtime,
+ expected_unfinished=1)
- def test_file_changed_size(self):
- def change_size(store, bucket_name, key, pathname):
- mtime = os.path.getmtime(pathname)
+ def test_file_changed_size(self):
+ def change_size(store, bucket_name, key, pathname):
+ mtime = os.path.getmtime(pathname)
- with open(pathname, 'w') as f:
- f.write('hello world')
+ with open(pathname, 'w') as f:
+ f.write('hello world')
- os.utime(pathname, (mtime, mtime))
- self.__test_interrupt(500 * 1024, 3,
- modify_record_func=change_size,
- expected_unfinished=1)
- else:
- print('skip error injection cases for Python version < 3.3')
+ os.utime(pathname, (mtime, mtime))
+ self.__test_interrupt(500 * 1024, 3,
+ modify_record_func=change_size,
+ expected_unfinished=1)
def __make_corrupt_record(self, name, value):
def corrupt_record(store, bucket_name, key, pathname):