Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add check images script #292

Merged
merged 8 commits into from
Nov 30, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions python/smqtk/bin/check_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Validate a list of images returning the filepaths and UUIDs of only the
valid images, or optionally, only the invalid images.
"""
import itertools
import logging
import os
import sys

from smqtk.utils.bin_utils import (
basic_cli_parser,
initialize_logging,
)
from smqtk.representation.data_element.file_element import DataFileElement
from smqtk.utils.image_utils import is_valid_element
from smqtk.utils import parallel


__author__ = '[email protected]'


def get_cli_parser():
    """
    Build the command line parser for the image checking script.

    The parser starts from ``basic_cli_parser`` (using this module's docstring
    as the description, without a configuration option group) and adds the
    ``--invert`` flag and the ``--file-list`` option.

    :return: The configured parser.

    """
    cli = basic_cli_parser(__doc__, configuration_group=False)

    cli.add_argument(
        '-i', '--invert',
        action='store_true', default=False,
        help='Invert results, showing only invalid images.'
    )

    # Listed under its own heading in the help output.
    required_group = cli.add_argument_group("Required Arguments")
    required_group.add_argument(
        '-f', '--file-list',
        metavar='PATH', type=str, default=None,
        help='Path to a file that lists data file paths.'
    )

    return cli


def main():
    """
    Entry point of the image checking script.

    Reads image file paths (one per line) from the file given by
    ``--file-list``, checks each one, and prints ``<filepath>,<uuid>`` to
    standard output for every valid image, or for every invalid image when
    ``--invert`` is given.  Paths that do not exist on disk are logged as
    warnings and never printed, regardless of ``--invert``.

    Exits with status 1 (after printing help) when run without arguments, and
    with status 103 when the file list is missing or does not exist.

    """
    parser = get_cli_parser()

    # Print help and exit if no arguments were passed
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    llevel = logging.DEBUG if args.verbose else logging.INFO
    initialize_logging(logging.getLogger('smqtk'), llevel)
    initialize_logging(logging.getLogger('__main__'), llevel)

    log = logging.getLogger(__name__)
    log.debug('Showing debug messages.')

    # ``--file-list`` defaults to None; reject that here as well, otherwise
    # the ``open`` call below fails with an unhelpful TypeError.
    if args.file_list is None or not os.path.exists(args.file_list):
        log.error('Invalid file list path: %s', args.file_list)
        sys.exit(103)

    def check_image(image_path):
        """
        Check a single path.

        :return: ``(is_valid, element)`` where ``element`` is None when the
            path does not exist on disk.
        """
        if not os.path.exists(image_path):
            log.warning('Invalid image path given (does not exist): %s',
                        image_path)
            return (False, None)

        dfe = DataFileElement(image_path)
        return (is_valid_element(dfe, check_image=True), dfe)

    with open(args.file_list) as infile:
        # Strip line endings lazily and drop blank lines so they are not
        # reported as non-existent paths.
        stripped_lines = (line.strip() for line in infile)
        image_paths = (path for path in stripped_lines if path)

        checked_images = parallel.parallel_map(check_image,
                                               image_paths,
                                               name='check-image-validity',
                                               use_multiprocessing=True)

        # Consume the results while the list file is still open, since the
        # paths are read from it lazily.
        for (is_valid, dfe) in checked_images:
            if dfe is None:  # in the case of a non-existent file
                continue
            # Print valid images normally, invalid ones when inverted.
            if bool(is_valid) != bool(args.invert):
                # noinspection PyProtectedMember
                print('%s,%s' % (dfe._filepath, dfe.uuid()))


if __name__ == '__main__':
main()
39 changes: 11 additions & 28 deletions python/smqtk/bin/compute_many_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
"""
import collections
import csv
import io
import logging
import os

import PIL.Image

from smqtk.algorithms import get_descriptor_generator_impls
from smqtk.compute_functions import compute_many_descriptors
from smqtk.representation import (
Expand All @@ -25,6 +22,7 @@
report_progress,
basic_cli_parser,
)
from smqtk.utils.image_utils import is_valid_element
from smqtk.utils import plugin, parallel


Expand Down Expand Up @@ -95,39 +93,24 @@ def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=None,
generator = plugin.from_plugin_config(c['descriptor_generator'],
get_descriptor_generator_impls())

def test_image_load(dfe):
try:
PIL.Image.open(io.BytesIO(dfe.get_bytes()))
return True
except IOError, ex:
# noinspection PyProtectedMember
log.warn("Failed to convert '%s' bytes into an image "
"(error: %s). Skipping",
dfe._filepath, str(ex))
return False

def is_valid_element(fp):
dfe = DataFileElement(fp)
ct = dfe.content_type()
if ct in generator.valid_content_types():
if not check_image or test_image_load(dfe):
def iter_valid_elements():
def is_valid(file_path):
dfe = DataFileElement(file_path)

if is_valid_element(dfe,
valid_content_types=generator.valid_content_types(),
check_image=check_image):
return dfe
else:
return None
else:
log.debug("Skipping file (invalid content) type for "
"descriptor generator (fp='%s', ct=%s)",
str(fp), ct)
return None
return False

def iter_valid_elements():
data_elements = collections.deque()
valid_files_filter = parallel.parallel_map(is_valid_element,
valid_files_filter = parallel.parallel_map(is_valid,
file_paths,
name="check-file-type",
use_multiprocessing=True)
for dfe in valid_files_filter:
if dfe is not None:
if dfe:
yield dfe
if data_set is not None:
data_elements.append(dfe)
Expand Down
110 changes: 110 additions & 0 deletions python/smqtk/tests/utils/test_image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import mock
import os
import sys
import tempfile
import unittest

import nose.tools as ntools
from StringIO import StringIO

from smqtk.bin.check_images import main as check_images_main
from smqtk.representation.data_element.file_element import DataFileElement
from smqtk.tests import TEST_DATA_DIR
from smqtk.utils.image_utils import is_loadable_image, is_valid_element


class TestIsLoadableImage(unittest.TestCase):
    """Tests for ``is_loadable_image``."""

    def setUp(self):
        # One file that is a real image and one that is not.
        self.good_image = DataFileElement(os.path.join(TEST_DATA_DIR,
                                                       'Lenna.png'))
        self.non_image = DataFileElement(os.path.join(TEST_DATA_DIR,
                                                      'test_file.dat'))

    def test_non_data_element_raises_exception(self):
        # should throw:
        # AttributeError: 'bool' object has no attribute 'get_bytes'
        self.assertRaises(AttributeError, is_loadable_image, False)

    def test_unloadable_image_returns_false(self):
        self.assertFalse(is_loadable_image(self.non_image))

    def test_loadable_image_returns_true(self):
        self.assertTrue(is_loadable_image(self.good_image))


class TestIsValidElement(unittest.TestCase):
    """Tests for ``is_valid_element``."""

    def setUp(self):
        # One file that is a real image and one that is not.
        self.good_image = DataFileElement(os.path.join(TEST_DATA_DIR,
                                                       'Lenna.png'))
        self.non_image = DataFileElement(os.path.join(TEST_DATA_DIR,
                                                      'test_file.dat'))

    def test_non_data_element(self):
        # Anything that is not a data element is reported as invalid.
        self.assertFalse(is_valid_element(False))

    def test_invalid_content_type(self):
        # An empty whitelist rejects every content type.
        self.assertFalse(is_valid_element(self.good_image,
                                          valid_content_types=[]))

    def test_valid_content_type(self):
        self.assertTrue(is_valid_element(self.good_image,
                                         valid_content_types=['image/png']))

    def test_invalid_image_returns_false(self):
        self.assertFalse(is_valid_element(self.non_image, check_image=True))


class TestCheckImageCli(unittest.TestCase):
    """Tests for the ``check_images`` command line entry point."""

    def check_images(self):
        """
        Run the script's ``main`` with stdout/stderr captured.

        ``SystemExit`` raised by the script is swallowed so that exit paths
        can be tested.  The real streams are always restored.

        :return: Tuple of the stripped captured stdout and stderr text.
        :rtype: (str, str)
        """
        saved_stdout, saved_stderr = sys.stdout, sys.stderr
        out, err = StringIO(), StringIO()
        sys.stdout, sys.stderr = out, err

        try:
            check_images_main()
        except SystemExit:
            pass
        finally:
            sys.stdout, sys.stderr = saved_stdout, saved_stderr

        return out.getvalue().strip(), err.getvalue().strip()

    def test_base_case(self):
        # Without arguments the script prints its help text.
        with mock.patch.object(sys, 'argv', ['']):
            out, _ = self.check_images()

        self.assertIn('Validate a list of images returning the filepaths',
                      out)

    def test_check_images(self):
        # Create test file with a valid, invalid, and non-existent image
        fd, filename = tempfile.mkstemp()
        # mkstemp hands back an open descriptor; close it and make sure the
        # file is removed again whatever the outcome of the test.
        os.close(fd)
        self.addCleanup(os.remove, filename)

        with open(filename, 'w') as outfile:
            outfile.write(os.path.join(TEST_DATA_DIR, 'Lenna.png') + '\n')
            outfile.write(os.path.join(TEST_DATA_DIR, 'test_file.dat') + '\n')
            outfile.write(os.path.join(TEST_DATA_DIR,
                                       'non-existent-file.jpeg'))

        with mock.patch.object(sys, 'argv', ['', '--file-list', filename]):
            out, err = self.check_images()

        self.assertEqual(
            out,
            ','.join([os.path.join(TEST_DATA_DIR, 'Lenna.png'),
                      '3ee0d360dc12003c0d43e3579295b52b64906e85']))
        self.assertNotIn('non-existent-file.jpeg', out)

        with mock.patch.object(sys, 'argv',
                               ['', '--file-list', filename, '--invert']):
            out, err = self.check_images()

        self.assertEqual(
            out,
            ','.join([os.path.join(TEST_DATA_DIR, 'test_file.dat'),
                      'da39a3ee5e6b4b0d3255bfef95601890afd80709']))
        self.assertNotIn('non-existent-file.jpeg', out)
62 changes: 62 additions & 0 deletions python/smqtk/utils/image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import io
import logging
import PIL.Image

from smqtk.representation.data_element.file_element import DataElement


def is_loadable_image(data_element):
    """
    Determine if an image is able to be loaded by PIL.

    The element's bytes are handed to ``PIL.Image.open``; an ``IOError``
    raised while fetching the bytes or opening them is logged at debug level
    and reported as "not loadable".

    :param data_element: A data element to check
    :type data_element: DataElement

    :return: Whether or not the image is loadable
    :rtype: bool

    :raises AttributeError: If the given object has no ``get_bytes`` method.

    """
    log = logging.getLogger(__name__)

    try:
        image_bytes = data_element.get_bytes()
        PIL.Image.open(io.BytesIO(image_bytes))
    except IOError as ex:
        log.debug("Failed to convert '%s' bytes into an image "
                  "(error: %s). Skipping", data_element, str(ex))
        return False

    return True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define a log = logging.getLogger(__name__) and reinstate the removed warning here.



def is_valid_element(data_element, valid_content_types=None, check_image=False):
    """
    Determines if a given data element is valid.

    An object that is not a ``DataElement`` is never valid.  A data element is
    valid when its content type passes the optional whitelist and, if
    requested, its bytes can be opened as an image.

    :param data_element: Data element
    :type data_element: DataElement

    :param valid_content_types: List of valid content types, or None to skip
        content type checking.
    :type valid_content_types: iterable | None

    :param check_image: Whether or not to try loading the image with PIL. This
        often catches issues that content type can't, such as corrupt images.
    :type check_image: bool

    :return: Whether or not the data element is valid
    :rtype: bool

    """
    log = logging.getLogger(__name__)

    # Check the type first so that non-elements yield False instead of an
    # AttributeError from the checks below.
    if not isinstance(data_element, DataElement):
        return False

    if valid_content_types is not None:
        content_type = data_element.content_type()
        if content_type not in valid_content_types:
            log.debug("Skipping file (invalid content) type for "
                      "descriptor generator (data_element='%s', ct=%s)",
                      data_element, content_type)
            return False

    if check_image and not is_loadable_image(data_element):
        return False

    return True
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def list_directory_files(dirpath, exclude_dirs=(), exclude_files=()):
'runApplication = smqtk.bin.runApplication:main',
'summarizePlugins = smqtk.bin.summarizePlugins:main',
'train_itq = smqtk.bin.train_itq:main',
'smqtk-nearest-neighbors = smqtk.bin.nearest_neighbors:main'
'smqtk-nearest-neighbors = smqtk.bin.nearest_neighbors:main',
'smqtk-check-images = smqtk.bin.check_images:main'
]
}
)