Skip to content

Commit

Permalink
Merge pull request #86 from ziao-guo/main
Browse files (browse the repository at this point in the history)
Capture tarfile.ReadError in download and retry
  • Loading branch information
ziao-guo authored Nov 13, 2023
2 parents 56b5f10 + dc2dc7e commit dbfdad9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
7 changes: 6 additions & 1 deletion pygmtools/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -240,7 +240,12 @@ def download(self, url=None, name=None, retries=5):
os.remove(filename)
return self.download(url, name, retries - 1)

file_names = tar.getnames()
try:
file_names = tar.getnames()
except tarfile.ReadError as err:
print('Warning: Content error. Retrying...\n', err)
os.remove(filename)
return self.download(url, name, retries - 1)
print('Unzipping files...')
sleep(0.5)
for file_name in tqdm(file_names):
Expand Down
26 changes: 17 additions & 9 deletions tests/test_dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,6 +13,9 @@
from random import choice
import os

import platform
os_name = platform.system()

# Test dataset download and preprocess, and data fetch and evaluation
def _test_benchmark(name, sets, problem, filter, **ds_dict):
benchmark = pygm.benchmark.Benchmark(name=name, sets=sets, problem=problem, filter=filter, **ds_dict)
Expand Down Expand Up @@ -57,19 +60,14 @@ def _test_get_data(benchmark, num):

# Entry function
def test_dataset_and_benchmark():
dataset_name_list = ['PascalVOC', 'WillowObject', 'SPair71k', 'IMC_PT_SparseGM', 'CUB2011']
dataset_name_list = ['WillowObject', 'PascalVOC', 'SPair71k', 'IMC_PT_SparseGM', 'CUB2011']

if os_name == 'Darwin':
dataset_name_list = ['WillowObject', 'SPair71k', 'IMC_PT_SparseGM', 'CUB2011']
problem_type_list = ['2GM', 'MGM']
set_list = ['train', 'test']
filter_list = ['intersection', 'inclusion', 'unfiltered']
dict_list = []
voc_cfg_dict = dict()
voc_cfg_dict['KPT_ANNO_DIR'] = dataset_cfg.PascalVOC.KPT_ANNO_DIR
voc_cfg_dict['ROOT_DIR'] = dataset_cfg.PascalVOC.ROOT_DIR
voc_cfg_dict['SET_SPLIT'] = dataset_cfg.PascalVOC.SET_SPLIT
voc_cfg_dict['CLASSES'] = dataset_cfg.PascalVOC.CLASSES
voc_cfg_dict['CACHE_PATH'] = dataset_cfg.CACHE_PATH
voc_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1TLBN4dnf_THmN3kNINMO0DpHhcBDqzcI&export=download'
dict_list.append(voc_cfg_dict)

willow_cfg_dict = dict()
willow_cfg_dict['CLASSES'] = dataset_cfg.WillowObject.CLASSES
Expand All @@ -82,6 +80,16 @@ def test_dataset_and_benchmark():
willow_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=18AvGwkuhnih5bFDjfJK5NYM16LvDfwW_'
dict_list.append(willow_cfg_dict)

if os_name != 'Darwin':
voc_cfg_dict = dict()
voc_cfg_dict['KPT_ANNO_DIR'] = dataset_cfg.PascalVOC.KPT_ANNO_DIR
voc_cfg_dict['ROOT_DIR'] = dataset_cfg.PascalVOC.ROOT_DIR
voc_cfg_dict['SET_SPLIT'] = dataset_cfg.PascalVOC.SET_SPLIT
voc_cfg_dict['CLASSES'] = dataset_cfg.PascalVOC.CLASSES
voc_cfg_dict['CACHE_PATH'] = dataset_cfg.CACHE_PATH
voc_cfg_dict['URL'] = 'https://huggingface.co/datasets/ziaoguo/small_VOC/resolve/main/small_voc.tar?download=true'
dict_list.append(voc_cfg_dict)

spair_cfg_dict = dict()
spair_cfg_dict['TRAIN_DIFF_PARAMS'] = {'mirror': 0}
spair_cfg_dict['EVAL_DIFF_PARAMS'] = dataset_cfg.SPair.EVAL_DIFF_PARAMS
Expand Down

0 comments on commit dbfdad9

Please sign in to comment.