-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path data_preprocessing.py
43 lines (33 loc) · 1.57 KB
/
data_preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import shutil
from tqdm import tqdm
from pathlib import Path
# import matplotlib.pyplot as plt
# %matplotlib inline
def transform_data(generated_dataset_path, new_dataset_path):
    """
    Transform raw data generated by StyleGAN into pix2pix format.

    Files directly inside ``generated_dataset_path`` are split into three
    groups by their file name:

    * source:   name does not contain ``'tr'``             -> ``train_A``
    * positive: name contains ``'tr'`` and no ``'-'``      -> ``train_B``
    * negative: name contains ``'tr'`` and a ``'-'``       -> ``train_C``

    Each group is sorted, the groups are zipped together, and the i-th triple
    is copied to ``train_A/{i}.png``, ``train_B/{i}.png`` and
    ``train_C/{i}.png``.

    Args:
        generated_dataset_path: directory holding the generated images
            (``str`` or ``pathlib.Path``).
        new_dataset_path: output root; its ``train_A``, ``train_B`` and
            ``train_C`` sub-directories are created if missing
            (``str`` or ``pathlib.Path``).

    Returns:
        True once all triples have been copied.
    """
    import warnings

    generated_dataset_path = Path(generated_dataset_path)
    new_dataset_path = Path(new_dataset_path)

    train_A_path = new_dataset_path / 'train_A'
    train_B_path = new_dataset_path / 'train_B'
    train_C_path = new_dataset_path / 'train_C'
    for out_dir in (train_A_path, train_B_path, train_C_path):
        out_dir.mkdir(parents=True, exist_ok=True)

    # Only regular files are candidates; a sub-directory would make
    # shutil.copy fail.
    all_paths = [p for p in generated_dataset_path.iterdir() if p.is_file()]

    # Classify on the file *name* only. Testing the whole path would drop
    # every source file whenever a parent directory happens to contain 'tr'.
    source_paths = [p for p in all_paths if 'tr' not in p.name]
    neg_paths = [p for p in all_paths if 'tr' in p.name and '-' in p.name]
    pos_paths = [p for p in all_paths if 'tr' in p.name and '-' not in p.name]

    counts = (len(source_paths), len(pos_paths), len(neg_paths))
    n_triples = min(counts)
    if len(set(counts)) > 1:
        # Pairing is purely positional, so unequal groups usually mean the
        # triples are misaligned; zip() below keeps only the shortest length.
        warnings.warn(
            f'unequal group sizes (source={counts[0]}, pos={counts[1]}, '
            f'neg={counts[2]}); only the first {n_triples} sorted triples '
            f'are copied',
            stacklevel=2,
        )

    # Appending '_' to the source stem changes how stems that are prefixes of
    # each other sort (e.g. '10_' < '1_'); presumably this mirrors names like
    # '<stem>_tr...' in the other groups -- NOTE(review): confirm against the
    # generator's file naming.
    triples = zip(
        sorted(source_paths, key=lambda x: x.name.split('.')[0] + '_'),
        sorted(neg_paths),
        sorted(pos_paths),
    )
    for i, (source, neg, pos) in tqdm(enumerate(triples), total=n_triples):
        shutil.copy(source, train_A_path / f'{i}.png')
        shutil.copy(pos, train_B_path / f'{i}.png')
        shutil.copy(neg, train_C_path / f'{i}.png')
    return True
def main():
    """Script entry point: convert the generated images into pix2pix format.

    Reads from ``generated_ffhq_age/images`` and writes to
    ``ffhq_smile_pix2pixHD``; both paths are relative to the current working
    directory.
    """
    # NOTE(review): the input folder name mentions "age" while the output
    # mentions "smile" -- confirm these two paths are meant to go together.
    raw_images_dir = Path('generated_ffhq_age/images')
    output_root = Path('ffhq_smile_pix2pixHD')

    # Convert the data into pix2pix format.
    transform_data(raw_images_dir, output_root)


if __name__ == "__main__":
    # Run the conversion only when executed as a script, not on import.
    main()