diff --git a/opencage/batch.py b/opencage/batch.py index 5e33533..7af9e0f 100644 --- a/opencage/batch.py +++ b/opencage/batch.py @@ -3,17 +3,22 @@ import asyncio import traceback import threading -import backoff -import certifi import random -import re -from tqdm import tqdm -from urllib.parse import urlencode from contextlib import suppress +from urllib.parse import urlencode +from tqdm import tqdm +import certifi +import backoff from opencage.geocoder import OpenCageGeocode, OpenCageGeocodeError, _query_for_reverse_geocoding class OpenCageBatchGeocoder(): + + """ Called from command_line.py + init() receives the parsed command line parameters + geocode() receives an input and output CSV reader/writer and loops over the data + """ + def __init__(self, options): self.options = options self.sslcontext = ssl.create_default_context(cafile=certifi.where()) @@ -22,7 +27,7 @@ def __init__(self, options): def __call__(self, *args, **kwargs): asyncio.run(self.geocode(*args, **kwargs)) - async def geocode(self, input, output): + async def geocode(self, csv_input, csv_output): if not self.options.dry_run: test = await self.test_request() if test['error']: @@ -33,13 +38,13 @@ async def geocode(self, input, output): self.options.workers = 1 if self.options.headers: - header_columns = next(input, None) + header_columns = next(csv_input, None) if header_columns is None: return queue = asyncio.Queue(maxsize=self.options.limit) - read_warnings = await self.read_input(input, queue) + read_warnings = await self.read_input(csv_input, queue) if self.options.dry_run: if not read_warnings: @@ -47,14 +52,14 @@ async def geocode(self, input, output): return if self.options.headers: - output.writerow(header_columns + self.options.add_columns) + csv_output.writerow(header_columns + self.options.add_columns) progress_bar = not (self.options.no_progress or self.options.quiet) and \ tqdm(total=queue.qsize(), position=0, desc="Addresses geocoded", dynamic_ncols=True) tasks = [] for _ in
range(self.options.workers): - task = asyncio.create_task(self.worker(output, queue, progress_bar)) + task = asyncio.create_task(self.worker(csv_output, queue, progress_bar)) tasks.append(task) # This starts the workers and waits until all are finished @@ -80,9 +85,9 @@ async def test_request(self): except Exception as exc: return { 'error': exc } - async def read_input(self, input, queue): + async def read_input(self, csv_input, queue): any_warnings = False - for index, row in enumerate(input): + for index, row in enumerate(csv_input): line_number = index + 1 if len(row) == 0: @@ -138,12 +143,12 @@ async def read_one_line(self, row, row_id): return { 'row_id': row_id, 'address': ','.join(address), 'original_columns': row, 'warnings': warnings } - async def worker(self, output, queue, progress): + async def worker(self, csv_output, queue, progress): while True: item = await queue.get() try: - await self.geocode_one_address(output, item['row_id'], item['address'], item['original_columns']) + await self.geocode_one_address(csv_output, item['row_id'], item['address'], item['original_columns']) if progress: progress.update(1) @@ -152,7 +157,7 @@ async def worker(self, output, queue, progress): finally: queue.task_done() - async def geocode_one_address(self, output, row_id, address, original_columns): + async def geocode_one_address(self, csv_output, row_id, address, original_columns): def on_backoff(details): if not self.options.quiet: sys.stderr.write("Backing off {wait:0.1f} seconds afters {tries} tries " @@ -195,13 +200,13 @@ async def _geocode_one_address(): 'response': geocoding_result }) - await self.write_one_geocoding_result(output, row_id, address, geocoding_result, original_columns) + await self.write_one_geocoding_result(csv_output, row_id, geocoding_result, original_columns) except Exception as exc: traceback.print_exception(exc, file=sys.stderr) await _geocode_one_address() - async def write_one_geocoding_result(self, output, row_id, address, 
geocoding_result, original_columns = []): + async def write_one_geocoding_result(self, csv_output, row_id, geocoding_result, original_columns): row = original_columns for column in self.options.add_columns: @@ -227,11 +232,10 @@ async def write_one_geocoding_result(self, output, row_id, address, geocoding_re if self.options.verbose: self.log(f"Writing row {row_id}") - output.writerow(row) + csv_output.writerow(row) self.write_counter = self.write_counter + 1 def log(self, message): if not self.options.quiet: sys.stderr.write(f"{message}\n") - diff --git a/opencage/command_line.py b/opencage/command_line.py index 6086404..5bda93f 100644 --- a/opencage/command_line.py +++ b/opencage/command_line.py @@ -15,13 +15,12 @@ def main(args=sys.argv[1:]): geocoder = OpenCageBatchGeocoder(options) - with options.input as input: - output_io = io.StringIO() if options.dry_run else open(options.output, 'x', encoding='utf-8') - reader = csv.reader(input, strict=True, skipinitialspace=True) - writer = csv.writer(output_io) + with options.input as input_filename: + with (io.StringIO() if options.dry_run else open(options.output, 'x', encoding='utf-8')) as output_io: + reader = csv.reader(input_filename, strict=True, skipinitialspace=True) + writer = csv.writer(output_io) - geocoder(input=reader, output=writer) - output_io.close() + geocoder(csv_input=reader, csv_output=writer) def parse_args(args): if len(args) == 0: @@ -54,7 +53,7 @@ def parse_args(args): sys.exit(1) if 0 in options.input_columns: - print(f"Error: A column 0 in --input-columns does not exist. The lowest possible number is 1.", file=sys.stderr) + print("Error: A column 0 in --input-columns does not exist. 
The lowest possible number is 1.", file=sys.stderr) sys.exit(1) return options @@ -94,8 +93,8 @@ def ranged_type(value_type, min_value, max_value): def range_checker(arg: str): try: f = value_type(arg) - except ValueError: - raise argparse.ArgumentTypeError(f'must be a valid {value_type}') + except ValueError as exc: + raise argparse.ArgumentTypeError(f'must be a valid {value_type}') from exc if f < min_value or f > max_value: raise argparse.ArgumentTypeError(f'must be within [{min_value}, {max_value}]') return f @@ -120,5 +119,5 @@ def comma_separated_dict_type(arg): try: return dict([x.split('=') for x in arg.split(',')]) - except ValueError: - raise argparse.ArgumentTypeError("must be a valid comma separated list of key=value pairs") + except ValueError as exc: + raise argparse.ArgumentTypeError("must be a valid comma separated list of key=value pairs") from exc diff --git a/opencage/geocoder.py b/opencage/geocoder.py index d3b650b..4622ce5 100644 --- a/opencage/geocoder.py +++ b/opencage/geocoder.py @@ -186,8 +186,8 @@ def geocode(self, query, **kwargs): if raw_response: return response - else: - return floatify_latlng(response['results']) + + return floatify_latlng(response['results']) async def geocode_async(self, query, **kwargs): """ @@ -219,8 +219,8 @@ async def geocode_async(self, query, **kwargs): if raw_response: return response - else: - return floatify_latlng(response['results']) + + return floatify_latlng(response['results']) def reverse_geocode(self, lat, lng, **kwargs): """ @@ -285,18 +285,13 @@ def _opencage_request(self, params): return response_json def _opencage_headers(self, client): - if client == 'requests': - client_version = requests.__version__ - elif client == 'aiohttp': + client_version = requests.__version__ + if client == 'aiohttp': client_version = aiohttp.__version__ + py_version = '.'.join(str(x) for x in sys.version_info[0:3]) return { - 'User-Agent': 'opencage-python/%s Python/%s %s/%s' % ( - __version__, - '.'.join(str(x) for x 
in sys.version_info[0:3]), - client, - client_version - ) + 'User-Agent': f"opencage-python/{__version__} Python/{py_version} {client}/{client_version}" } async def _opencage_async_request(self, params):