-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #43 from CortexFoundation/wlt
Wlt
- Loading branch information
Showing
5 changed files
with
406 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,393 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from __future__ import print_function | ||
import os | ||
import sys | ||
|
||
curr_path = os.path.abspath(os.path.dirname(__file__)) | ||
sys.path.append(os.path.join(curr_path, "../python")) | ||
import mxnet as mx | ||
import random | ||
import argparse | ||
import cv2 | ||
import time | ||
import traceback | ||
|
||
try: | ||
import multiprocessing | ||
except ImportError: | ||
multiprocessing = None | ||
|
||
def list_image(root, recursive, exts):
    """Walk an image directory and lazily yield one entry per image file.

    Parameters
    ----------
    root: string
        Directory to scan; yielded paths are relative to it.
    recursive: bool
        When true, descend into sub-directories and give every directory
        that holds at least one image its own integer label (assigned in
        the order directories are first seen). When false, only the files
        directly under ``root`` are listed and all get label 0.
    exts: container of strings
        Accepted lower-case file extensions, including the leading dot.

    Returns
    -------
    generator of ``(index, relative_path, label)`` tuples
    """

    def _is_image(full_path, file_name):
        # The extension is compared in lower case, so '.JPG' matches '.jpg'.
        ext = os.path.splitext(file_name)[1].lower()
        return os.path.isfile(full_path) and (ext in exts)

    index = 0
    if not recursive:
        for name in sorted(os.listdir(root)):
            full_path = os.path.join(root, name)
            if _is_image(full_path, name):
                yield (index, os.path.relpath(full_path, root), 0)
                index += 1
        return

    labels = {}
    for folder, subdirs, names in os.walk(root, followlinks=True):
        # Sorting in place makes the traversal order deterministic.
        subdirs.sort()
        names.sort()
        for name in names:
            full_path = os.path.join(folder, name)
            if not _is_image(full_path, name):
                continue
            label = labels.setdefault(folder, len(labels))
            yield (index, os.path.relpath(full_path, root), label)
            index += 1
    # Once the walk is exhausted, report the directory -> label mapping.
    for folder, label in sorted(labels.items(), key=lambda pair: pair[1]):
        print(os.path.relpath(folder, root), label)
|
||
def write_list(path_out, image_list):
    """Helper that writes an image list to a .lst file.

    Each entry becomes one tab-separated line:
    ``image_index <TAB> label [<TAB> label ...] <TAB> path_to_image``
    where the index is written as an integer and each label as a float.

    Parameters
    ----------
    path_out: string
        Destination file; it is overwritten.
    image_list: iterable
        Entries laid out as ``(index, path, label, ...)``.
    """
    with open(path_out, 'w') as fout:
        for entry in image_list:
            columns = ['%d' % entry[0]]
            columns.extend('%f' % label for label in entry[2:])
            columns.append('%s\n' % entry[1])
            fout.write('\t'.join(columns))
|
||
def make_list(args):
    """Generate the .lst file(s) for the images found under ``args.root``.

    The image list is optionally shuffled with a fixed seed, cut into
    ``args.chunks`` pieces, and each piece is written either as one
    ``<prefix>.lst`` file (train ratio of exactly 1.0) or split into
    ``_test`` / ``_val`` / ``_train`` lists.

    Parameters
    ----------
    args: object that contains all the arguments
        Reads ``root``, ``recursive``, ``exts``, ``shuffle``, ``chunks``,
        ``train_ratio``, ``test_ratio`` and ``prefix``.
    """
    entries = list(list_image(args.root, args.recursive, args.exts))
    if args.shuffle is True:
        # Fixed seed: repeated runs produce the same ordering.
        random.seed(100)
        random.shuffle(entries)
    total = len(entries)
    # Ceiling division so that every entry lands in some chunk.
    chunk_size = (total + args.chunks - 1) // args.chunks
    for chunk_id in range(args.chunks):
        start = chunk_id * chunk_size
        chunk = entries[start:start + chunk_size]
        # The chunk number is only part of the file name when chunking.
        base = args.prefix + ('_%d' % chunk_id if args.chunks > 1 else '')
        # Split sizes are derived from the nominal chunk size.
        n_train = int(chunk_size * args.train_ratio)
        n_test = int(chunk_size * args.test_ratio)
        if args.train_ratio == 1.0:
            write_list(base + '.lst', chunk)
            continue
        if args.test_ratio:
            write_list(base + '_test.lst', chunk[:n_test])
        if args.train_ratio + args.test_ratio < 1.0:
            write_list(base + '_val.lst', chunk[n_test + n_train:])
        write_list(base + '_train.lst', chunk[n_test:n_test + n_train])
|
||
def read_list(path_in):
    """Read a .lst file and lazily yield its parsed entries.

    Every line is split on tabs. Lines with fewer than three fields, or
    whose index/labels cannot be converted to numbers, are reported on
    stdout and skipped.

    Parameters
    ----------
    path_in: string
        Path of the .lst file to read.

    Returns
    -------
    generator of lists laid out as ``[int_index, path, float_label, ...]``
    """
    with open(path_in) as fin:
        for raw in fin:
            fields = [part.strip() for part in raw.strip().split('\t')]
            n_fields = len(fields)
            # A valid line needs at least: index, one label, path.
            if n_fields < 3:
                print('lst should have at least has three parts, but only has %s parts for %s' % (n_fields, fields))
                continue
            try:
                index = int(fields[0])
                labels = [float(value) for value in fields[1:-1]]
            except Exception as e:
                print('Parsing lst met error for %s, detail: %s' % (fields, e))
                continue
            # The path is the last column on disk but second in the entry.
            yield [index, fields[-1]] + labels
|
||
def image_encode(args, i, item, q_out):
    """Read, preprocess and pack one image, then put the result on ``q_out``.

    Exactly one tuple ``(i, packed, item)`` is put on ``q_out`` per call;
    ``packed`` is ``None`` when the image could not be read or packed.

    Parameters
    ----------
    args: object
        Reads ``root``, ``pack_label``, ``pass_through``, ``color``,
        ``center_crop``, ``resize``, ``quality`` and ``encoding``.
    i: int
        Position of the item in the input list; echoed back unchanged.
    item: list
        ``item[0]`` is used as the record id, ``item[1]`` is the image path
        relative to ``args.root`` and ``item[2:]`` holds the label value(s).
    q_out: queue
        Any object with a ``put`` method that receives the result tuple.
    """
    fullpath = os.path.join(args.root, item[1])

    # NOTE(review): IRHeader fields are presumably (flag, label, id, id2) -
    # confirm against the mx.recordio documentation.
    if len(item) > 3 and args.pack_label:
        # Multi-value label: pack the whole label vector.
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    if args.pass_through:
        # Store the file bytes untouched, skipping decode/re-encode.
        try:
            with open(fullpath, 'rb') as fin:
                img = fin.read()
            s = mx.recordio.pack(header, img)
            q_out.put((i, s, item))
        except Exception as e:
            traceback.print_exc()
            print('pack_img error:', item[1], e)
            q_out.put((i, None, item))
        return

    try:
        img = cv2.imread(fullpath, args.color)
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt and
        # SystemExit must still be able to stop the process.
        traceback.print_exc()
        print('imread error trying to load file: %s ' % fullpath)
        q_out.put((i, None, item))
        return
    if img is None:
        print('imread read blank (None) image for file: %s' % fullpath)
        q_out.put((i, None, item))
        return
    if args.center_crop:
        # Trim the longer axis symmetrically down to the shorter one.
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin:margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin:margin + img.shape[0]]
    if args.resize:
        # Scale so that the shorter edge equals args.resize while keeping
        # the aspect ratio. newsize is ordered (shape[1]-axis, shape[0]-axis).
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
        img = cv2.resize(img, newsize)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((i, s, item))
    except Exception as e:
        traceback.print_exc()
        print('pack_img error on file: %s' % fullpath, e)
        q_out.put((i, None, item))
        return
|
||
def read_worker(args, q_in, q_out):
    """Worker loop: encode every task taken from ``q_in`` into ``q_out``.

    Tasks are ``(i, item)`` pairs that are forwarded to ``image_encode``.
    The loop ends when ``None`` is received from ``q_in``.

    Parameters
    ----------
    args: object
        Passed through to ``image_encode``.
    q_in: queue
        Source of ``(i, item)`` tasks, terminated by ``None``.
    q_out: queue
        Destination handed to ``image_encode`` for the encoded results.
    """
    task = q_in.get()
    while task is not None:
        index, entry = task
        image_encode(args, index, entry, q_out)
        task = q_in.get()
|
||
def write_worker(q_out, fname, working_dir):
    """Consume encoded images from ``q_out`` and write them to a .rec file.

    Results may arrive in any order; they are buffered and written strictly
    in increasing order of their list position. Entries whose payload is
    ``None`` (failed images) are skipped. A ``None`` taken from the queue
    means no more results will arrive.

    Parameters
    ----------
    q_out: queue
        Source of ``(i, packed, item)`` tuples, terminated by ``None``.
    fname: string
        Name of the .lst file; its base name (extension replaced) names the
        ``.rec`` and ``.idx`` outputs.
    working_dir: string
        Directory in which the ``.rec`` and ``.idx`` files are created.
    """
    pre_time = time.time()
    count = 0
    stem = os.path.splitext(os.path.basename(fname))[0]
    fname_rec = stem + '.rec'
    fname_idx = stem + '.idx'
    record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                           os.path.join(working_dir, fname_rec), 'w')
    buf = {}
    more = True
    try:
        while more:
            deq = q_out.get()
            if deq is not None:
                i, s, item = deq
                buf[i] = (s, item)
            else:
                more = False
            # Flush every result that is now contiguous with what has been
            # written so far; later positions stay buffered.
            while count in buf:
                s, item = buf.pop(count)
                if s is not None:
                    record.write_idx(item[0], s)

                if count % 1000 == 0:
                    cur_time = time.time()
                    print('time:', cur_time - pre_time, ' count:', count)
                    pre_time = cur_time
                count += 1
    finally:
        # Close explicitly so the outputs are finalized even on error,
        # instead of relying on garbage collection of the writer.
        record.close()
|
||
def parse_args(argv=None):
    """Define and parse all command-line arguments.

    Parameters
    ----------
    argv: list of strings, optional
        Arguments to parse. When ``None`` (the default) argparse falls back
        to the process command line.

    Returns
    -------
    args object that contains all the params; ``prefix`` and ``root`` are
    converted to absolute paths.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create an image list or '
                    'make a record database by reading from an image list')
    parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
    parser.add_argument('root', help='path to folder containing images.')

    cgroup = parser.add_argument_group('Options for creating image lists')
    cgroup.add_argument('--list', action='store_true',
                        help='If this is set im2rec will create image list(s) by traversing root folder '
                             'and output to <prefix>.lst. '
                             'Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
    cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
                        help='list of acceptable image extensions.')
    cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
    cgroup.add_argument('--train-ratio', type=float, default=1.0,
                        help='Ratio of images to use for training.')
    cgroup.add_argument('--test-ratio', type=float, default=0,
                        help='Ratio of images to use for testing.')
    cgroup.add_argument('--recursive', action='store_true',
                        help='If true recursively walk through subdirs and assign an unique label '
                             'to images in each folder. Otherwise only include images in the root folder '
                             'and give them label 0.')
    # store_false with dest='shuffle': shuffling is on unless the flag is given.
    cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false',
                        help='If this is passed, '
                             'im2rec will not randomize the image order in <prefix>.lst')

    rgroup = parser.add_argument_group('Options for creating database')
    rgroup.add_argument('--pass-through', action='store_true',
                        help='whether to skip transformation and save image as is')
    rgroup.add_argument('--resize', type=int, default=0,
                        help='resize the shorter edge of image to the newsize, original images will '
                             'be packed by default.')
    rgroup.add_argument('--center-crop', action='store_true',
                        help='specify whether to crop the center image to make it rectangular.')
    rgroup.add_argument('--quality', type=int, default=95,
                        help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
    rgroup.add_argument('--num-thread', type=int, default=1,
                        help='number of thread to use for encoding. order of images will be different '
                             'from the input list if >1. the input list will be modified to match the '
                             'resulting order.')
    rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
                        help='specify the color mode of the loaded image. '
                             '1: Loads a color image. Any transparency of image will be neglected. '
                             'It is the default flag. '
                             '0: Loads image in grayscale mode. '
                             '-1:Loads image as such including alpha channel.')
    rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
                        help='specify the encoding of the images.')
    rgroup.add_argument('--pack-label', action='store_true',
                        help='Whether to also pack multi dimensional label in the record file')
    args = parser.parse_args(argv)
    # Absolute paths keep the later prefix matching independent of the
    # current working directory.
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args
|
||
if __name__ == '__main__':
    args = parse_args()
    # With '--list', only generate the .lst file(s).
    if args.list:
        make_list(args)
    # Otherwise convert every matching .lst file into .rec/.idx files.
    else:
        if os.path.isdir(args.prefix):
            working_dir = args.prefix
        else:
            working_dir = os.path.dirname(args.prefix)
        files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                 if os.path.isfile(os.path.join(working_dir, fname))]
        count = 0
        for fname in files:
            if not (fname.startswith(args.prefix) and fname.endswith('.lst')):
                continue
            print('Creating .rec file from', fname, 'in', working_dir)
            count += 1
            image_list = read_list(fname)
            # -- write_record -- #
            if args.num_thread > 1 and multiprocessing is not None:
                # One bounded input queue per reader, one shared output queue.
                q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                q_out = multiprocessing.Queue(1024)
                read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out))
                                for i in range(args.num_thread)]
                for p in read_process:
                    p.start()
                # A single writer process avoids races on the .rec file.
                write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                write_process.start()
                # Distribute the items round-robin over the reader queues.
                for i, item in enumerate(image_list):
                    q_in[i % len(q_in)].put((i, item))
                # One None per reader tells it to stop.
                for q in q_in:
                    q.put(None)
                for p in read_process:
                    p.join()

                # All readers are done, so the writer can be told to stop.
                q_out.put(None)
                write_process.join()
            else:
                print('multiprocessing not available, fall back to single threaded encoding')
                try:
                    import Queue as queue
                except ImportError:
                    import queue
                q_out = queue.Queue()
                fname = os.path.basename(fname)
                fname_rec = os.path.splitext(fname)[0] + '.rec'
                fname_idx = os.path.splitext(fname)[0] + '.idx'
                record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                                       os.path.join(working_dir, fname_rec), 'w')
                cnt = 0
                pre_time = time.time()
                try:
                    for i, item in enumerate(image_list):
                        image_encode(args, i, item, q_out)
                        if q_out.empty():
                            continue
                        _, s, _ = q_out.get()
                        # image_encode reports a failed image with a None
                        # payload; skip it, as write_worker does, instead of
                        # passing None to the record writer.
                        if s is None:
                            continue
                        record.write_idx(item[0], s)
                        if cnt % 1000 == 0:
                            cur_time = time.time()
                            print('time:', cur_time - pre_time, ' count:', cnt)
                            pre_time = cur_time
                        cnt += 1
                finally:
                    # Close explicitly so the outputs are finalized before
                    # the next list file is processed.
                    record.close()
        if not count:
            print('Did not find any list file with prefix %s' % args.prefix)
Oops, something went wrong.