prepare_mask_data.py
# Copyright (C) 2022 ByteDance Inc.
# All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
# The software is made available under Creative Commons BY-NC-SA 4.0 license
# by ByteDance Inc. You can use, redistribute, and adapt it
# for non-commercial purposes, as long as you (a) give appropriate credit
# by citing our paper, (b) indicate any changes that you've made,
# and (c) distribute any derivative works under the same license.
# THE AUTHORS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
# OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# summary of changes (GC-GAN):
# 25/05/2024: add prepare_conditional for gaze-labeled data
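#
# Example invocation (illustrative only; the paths below are placeholders, not
# paths from this repository):
#   python prepare_mask_data.py --out gaze_lmdb --size 256 --n_worker 8 \
#       /path/to/images /path/to/gaze_masks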
import os
from glob import glob
import argparse
from io import BytesIO
import multiprocessing
from functools import partial
from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn
import numpy as np


def resize_and_convert(img, size, format, resample):
    # Resize the shorter edge to `size`, center-crop to a square,
    # and return the encoded image bytes.
    img = trans_fn.resize(img, size, resample)
    img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format=format, quality=100)
    val = buffer.getvalue()
    return val


def resize_worker(img_file, size, use_rgb, format, resample):
    # Multiprocessing worker: open one image, optionally force RGB,
    # then resize/crop/encode it.
    i, file = img_file
    img = Image.open(file)
    if use_rgb:
        img = img.convert("RGB")
    img = resize_and_convert(img, size, format, resample)
    print(img_file)
    return i, img, file


def find_images(path):
    # `path` is either a text file listing image paths (one per line) or a
    # directory that is searched recursively for common image extensions.
    if os.path.isfile(path):
        with open(path, "r") as f:
            files = [line.strip() for line in f.readlines()]
    else:
        files = list()
        IMAGE_EXTENSIONS = {'jpg', 'png', 'jpeg', 'webp'}
        IMAGE_EXTENSIONS = IMAGE_EXTENSIONS.union({f.upper() for f in IMAGE_EXTENSIONS})
        for ext in IMAGE_EXTENSIONS:
            files += glob(f'{path}/**/*.{ext}', recursive=True)
    files = sorted(files)
    return files


def prepare(env, files, n_worker, size, prefix, use_rgb, format, resample):
    # Encode every image in parallel and store it under the key
    # "{prefix}-{index:07d}"; the total count is stored under "{prefix}-length".
    resize_fn = partial(resize_worker, size=size, use_rgb=use_rgb, format=format, resample=resample)
    total = 0
    with env.begin(write=True) as txn:
        with multiprocessing.Pool(n_worker) as pool:
            for i, img, img_filename in tqdm(pool.imap_unordered(resize_fn, enumerate(files))):
                txn.put(f"{prefix}-{str(i).zfill(7)}".encode("utf-8"), img)
                total += 1
        txn.put(f"{prefix}-length".encode("utf-8"), str(total).encode("utf-8"))


def prepare_conditional(env, files, n_worker, size, prefix, use_rgb, format, resample):
    # Same as prepare(), but the gaze label is parsed from the file name and
    # appended to the LMDB key; the index-to-label mapping is also saved as a
    # .npy file inside the LMDB directory so the keys can be reconstructed later.
    resize_fn = partial(resize_worker, size=size, use_rgb=use_rgb, format=format, resample=resample)
    total = 0
    with env.begin(write=True) as txn:
        with multiprocessing.Pool(n_worker) as pool:
            labels = {}
            for i, img, img_filename in tqdm(pool.imap_unordered(resize_fn, enumerate(files))):
                # Gaze label: fields 4 and 5 of the underscore-separated file name.
                label_data = os.path.basename(img_filename).split('_')
                label = label_data[4] + '_' + label_data[5]
                labels[i] = label
                txn.put(f"{prefix}-{str(i).zfill(7)}-{label}".encode("utf-8"), img)
                total += 1
        txn.put(f"{prefix}-length".encode("utf-8"), str(total).encode("utf-8"))
    np.save(env.path() + "/labels_gaze", labels)
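

# The sketch below is illustrative and not part of the original pipeline: it shows how a
# record written by prepare_conditional could be read back, assuming the key layout used
# above ("{prefix}-{index:07d}-{label}" plus the labels dict saved next to the data as
# labels_gaze.npy). The function name and signature are hypothetical.
def read_conditional_example(lmdb_path, index, prefix='image'):
    # Recover the index-to-label mapping saved by prepare_conditional.
    labels = np.load(os.path.join(lmdb_path, "labels_gaze.npy"), allow_pickle=True).item()
    label = labels[index]
    with lmdb.open(lmdb_path, readonly=True, lock=False) as env:
        with env.begin(write=False) as txn:
            key = f"{prefix}-{str(index).zfill(7)}-{label}".encode("utf-8")
            img = Image.open(BytesIO(txn.get(key)))
            img.load()
    return img, label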


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess images for model training")
    parser.add_argument("--out", type=str, help="filename of the resulting lmdb dataset")
    parser.add_argument(
        "--size",
        type=int,
        default=256,
        help="resolution of the images in the dataset",
    )
    parser.add_argument(
        "--n_worker",
        type=int,
        default=8,
        help="number of workers for preparing the dataset",
    )
    parser.add_argument("image_path", type=str, help="path to the image files")
    parser.add_argument("label_path", type=str, help="path to the label files")
    args = parser.parse_args()

    images = find_images(args.image_path)
    labels = find_images(args.label_path)
    # Match each image to its label file by basename so the two lists stay aligned.
    get_key = lambda fpath: os.path.splitext(os.path.basename(fpath))[0]
    label_dict = {get_key(label): label for label in labels}
    labels = [label_dict[get_key(image)] for image in images]
    print(f"Number of images: {len(images)}")

    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare_conditional(env, images, args.n_worker, args.size, 'image', use_rgb=True, format='png', resample=Image.LANCZOS)
        prepare_conditional(env, labels, args.n_worker, args.size, 'label', use_rgb=False, format='png', resample=Image.NEAREST)
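
# For reference, the script writes into the --out LMDB directory: keys of the form
# "image-XXXXXXX-<gaze_label>" and "label-XXXXXXX-<gaze_label>", the record counts under
# "image-length" and "label-length", and the index-to-label mapping in labels_gaze.npy.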