diff --git a/MANIFEST.in b/MANIFEST.in index d0469c3..491f4fe 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -10,7 +10,6 @@ recursive-exclude * *.py[co] recursive-include docs *.rst conf.py Makefile make.bat -# package specific instructions -include mciutil/cli/mideu.yml +recursive-include *.yml include versioneer.py include mciutil/_version.py diff --git a/mciutil/cli/common.py b/mciutil/cli/common.py index 8be9c61..466306d 100644 --- a/mciutil/cli/common.py +++ b/mciutil/cli/common.py @@ -8,6 +8,7 @@ import sys import logging import csv +from pkg_resources import resource_filename LOGGER = logging.getLogger(__name__) @@ -86,12 +87,14 @@ def get_config_filename(config_filename): user_home_dir = os.path.expanduser("~") if os.path.isfile(current_dir + "/" + config_filename): - return current_dir + "/" + config_filename + config_filename = current_dir + "/" + config_filename elif os.path.isfile(user_home_dir + "/." + config_filename): - return user_home_dir + "/." + config_filename + config_filename = user_home_dir + "/." + config_filename else: - module_dir = os.path.dirname(os.path.abspath(__file__)) - return module_dir + "/" + config_filename + module_dir = resource_filename("mciutil", "cli") + config_filename = module_dir + "/" + config_filename + LOGGER.info("Using {0} config file".format(config_filename)) + return config_filename def add_to_csv(data_list, field_list, output_filename): @@ -103,20 +106,40 @@ def add_to_csv(data_list, field_list, output_filename): :param output_filename: filename for output CSV file :return: None """ + try: + instance_type = unicode + file_mode = "wb" + except NameError: + instance_type = str + file_mode = "w" + filtered_data_list = filter_data_list(data_list, field_list) - with open(output_filename, "w") as output_file: + with open(output_filename, file_mode) as output_file: writer = csv.DictWriter(output_file, fieldnames=field_list, extrasaction="ignore", lineterminator="\n") + # python 2.6 does not support writeheader() so skip if sys.version_info[0] == 2 and sys.version_info[1] == 6: pass else: writer.writeheader() - writer.writerows(filtered_data_list) + for item in filtered_data_list: + if file_mode == "w": + row = dict( + (k, v.decode('latin1') if not isinstance(v, instance_type) else v) + for k, v in item.items() + ) + else: + row = dict( + (k, v.encode('utf-8') if isinstance(v, instance_type) else v) + for k, v in item.items() + ) + writer.writerow(row) + LOGGER.info("%s records written", len(data_list)) @@ -145,7 +168,7 @@ def filter_dictionary(dictionary, field_list): return_dictionary = {} for item in dictionary: if item in field_list: - return_dictionary[item] = dictionary[item].decode() + return_dictionary[item] = dictionary[item] return return_dictionary diff --git a/setup.cfg b/setup.cfg index 3a09b53..5a4e536 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,5 +5,4 @@ universal = 1 VCS = git style = pep440 versionfile_source = mciutil/_version.py -versionfile_build = mciutil/_version.py tag_prefix = diff --git a/setup.py b/setup.py index 3dba435..ff50252 100755 --- a/setup.py +++ b/setup.py @@ -32,9 +32,10 @@ packages=[ 'mciutil', 'mciutil.cli' ], - package_dir={'mciutil': - 'mciutil'}, include_package_data=True, + package_data={ + 'mciutil.cli': ['*.yml'] + }, install_requires=requirements, entry_points={ 'console_scripts': [ diff --git a/tests/test_mideu.py b/tests/test_mideu.py index f9dd677..ee884e2 100644 --- a/tests/test_mideu.py +++ b/tests/test_mideu.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + from __future__ import absolute_import import os.path import sys @@ -143,7 +145,19 @@ class FilteredDictionaryTest(TestCase): def test_filter_dict(self): dict = {"a": b("123"), "b": b("456"), "c": b("789")} field_list = ["a", "c"] - expected_dict = {"a": "123", "c": "789"} + expected_dict = {"a": b("123"), "c": b("789")} + actual_dict = filter_dictionary(dict, field_list) + self.assertEqual(len(actual_dict), 2) + self.assertEqual(actual_dict, expected_dict) + + def test_filter_dict_with_latin1(self): + dict = { + "a": b("123\xc9"), + "b": b("456\xc9"), + "c": b("789\xc9") + } + field_list = ["a", "c"] + expected_dict = {"a": b("123\xc9"), "c": b("789\xc9")} actual_dict = filter_dictionary(dict, field_list) self.assertEqual(len(actual_dict), 2) self.assertEqual(actual_dict, expected_dict)