From 0e13e6554c496bc99b85890a80c0a591391d3506 Mon Sep 17 00:00:00 2001 From: Samuel Scully Date: Fri, 26 Jul 2024 20:38:42 +0100 Subject: [PATCH] Copy batch.py code to CLI and adjust to make it work --- opencage/command_line.py | 195 ++++++++++++++++++++++++- test/fixtures/forward.csv | 3 + test/{test_cli.py => test_cli_args.py} | 22 +-- test/test_cli_run.py | 24 +++ 4 files changed, 229 insertions(+), 15 deletions(-) create mode 100644 test/fixtures/forward.csv rename test/{test_cli.py => test_cli_args.py} (84%) create mode 100644 test/test_cli_run.py diff --git a/opencage/command_line.py b/opencage/command_line.py index 3f25e28..fbce379 100644 --- a/opencage/command_line.py +++ b/opencage/command_line.py @@ -1,22 +1,204 @@ import argparse import sys import re +import csv +import re +import ssl +import asyncio +import traceback +import backoff +import certifi +from tqdm import tqdm +from opencage.geocoder import OpenCageGeocode, OpenCageGeocodeError + + +def main(args=sys.argv[1:]): + options = parse_args(args) + assert sys.version_info >= (3, 7), "Script requires Python 3.7 or newer" + + BatchGeocoder.quiet = options.quiet + BatchGeocoder.max_time = options.timeout + BatchGeocoder.max_tries = options.retries + + BatchGeocoder(options)() + + +class BatchGeocoder(): + max_time = 1 + max_tries = 10 + quiet = False + + def on_backoff(details): + if not BatchGeocoder.quiet: + sys.stderr.write("Backing off {wait:0.1f} seconds afters {tries} tries " + "calling function {target} with args {args} and kwargs " + "{kwargs}\n".format(**details)) + + def __init__(self, options): + self.options = options + self.progress = options.no_progress or tqdm(total=0, position=0, desc="Addresses geocoded", dynamic_ncols=True) + self.queue = asyncio.Queue(maxsize=options.limit) + self.sslcontext = ssl.create_default_context(cafile=certifi.where()) + self.outfile = csv.writer(options.output) + + def __call__(self): + asyncio.run(self.geocode()) + self.options.output.close() + + async def geocode(self): + with self.options.input as infile: + csv_reader = csv.reader(infile, strict=True, skipinitialspace=True) + + for row in csv_reader: + if len(row) == 0: + raise Exception(f"Empty line in input file at line number {csv_reader.line_num}, aborting") + + if self.options.reverse: + work_item = {'id': row[0], 'address': f"{row[1]},{row[2]}"} + else: + work_item = {'id': row[0], 'address': row[1]} + + await self.queue.put(work_item) + if self.queue.full(): + break + + self.log(f"{self.queue.qsize()} work_items in queue") + + if self.progress: + self.progress.total = self.queue.qsize() + self.progress.refresh() + + ## 2. Create tasks workers. That is coroutines, each taks take work_items + ## from the queue until it's empty. Tasks run in parallel + ## + ## https://docs.python.org/3/library/asyncio-task.html#creating-tasks + ## https://docs.python.org/3/library/asyncio-task.html#coroutine + ## + self.log(f"Creating {self.options.workers} task workers...") + tasks = [] + for i in range(self.options.workers): + task = asyncio.create_task(self.work(i)) + tasks.append(task) -def main(): - print(parse_args(sys.argv[1:])) + ## 3. Now workers do the geocoding + ## + self.log("Waiting for workers to finish processing queue...") + await self.queue.join() + + ## 4. Cleanup + ## + for task in tasks: + task.cancel() + + if self.progress: + self.progress.close() + + self.log("All done.\n") + + async def work(self, number): + self.log(f"Worker {number} starts...") + + while True: + item = await self.queue.get() + address_id = item['id'] + address = item['address'] + await self.geocode_one_address(address, address_id) + + if self.progress: + self.progress.update(1) + + self.queue.task_done() + + @backoff.on_exception(backoff.expo, + asyncio.TimeoutError, + max_time=max_time, + max_tries=max_tries, + on_backoff=on_backoff) + async def geocode_one_address(self, address, address_id): + async with OpenCageGeocode(self.options.api_key, domain=self.options.api_domain, sslcontext=self.sslcontext) as geocoder: + geocoding_results = None + + try: + if self.options.reverse: + # Reverse: + # coordinates -> address, e.g. '40.78,-73.97' => '101, West 91st Street, New York' + lon_lat = address.split(',') + geocoding_results = await geocoder.reverse_geocode_async(lon_lat[0], lon_lat[1], no_annotations=1) + else: + # Forward: + # address -> coordinates + # note: you may also want to set other optional parameters like + # countrycode, language, etc + # see the full list: https://opencagedata.com/api#forward-opt + geocoding_results = await geocoder.geocode_async(address, no_annotations=1) + except OpenCageGeocodeError as exc: + sys.stderr.write(str(exc) + "\n") + except Exception as exc: + traceback.print_exception(exc, file=sys.stderr) + + try: + if geocoding_results is not None and len(geocoding_results): + geocoding_result = geocoding_results[0] + else: + geocoding_result = None + + await self.write_one_geocoding_result(geocoding_result, address, address_id) + except Exception as exc: + traceback.print_exception(exc, file=sys.stderr) + + async def write_one_geocoding_result(self, geocoding_result, address, address_id): + if geocoding_result is not None: + row = [ + address_id, + geocoding_result['geometry']['lat'], + geocoding_result['geometry']['lng'], + # Any of the components might be empty: + geocoding_result['components'].get('_type', ''), + geocoding_result['components'].get('country', ''), + geocoding_result['components'].get('county', ''), + geocoding_result['components'].get('city', ''), + geocoding_result['components'].get('postcode', ''), + geocoding_result['components'].get('road', ''), + geocoding_result['components'].get('house_number', ''), + geocoding_result['confidence'], + geocoding_result['formatted'] + ] + else: + self.log(f"not found, writing empty result: {address}") + row = [ + address_id, + 0, # not to be confused with https://en.wikipedia.org/wiki/Null_Island + 0, + '', + '', + '', + '', + '', + '', + '', + -1, # confidence values are 1-10 (lowest to highest), use -1 for unknown + '' + ] + + self.outfile.writerow(row) + + def log(self, message): + if not self.options.quiet: + sys.stderr.write(f"{message}\n") def parse_args(args): parser = argparse.ArgumentParser(description="opencage") - parser.add_argument("--apikey", required=True, type=api_key_type, help="your OpenCage API key") - parser.add_argument("--input", required=True, type=argparse.FileType('r'), help="input file name (one query per line)") - parser.add_argument("--output", required=True, type=argparse.FileType('x'), help="output file name") + parser.add_argument("--api-key", required=True, type=api_key_type, help="your OpenCage API key") + parser.add_argument("--input", required=True, type=argparse.FileType('r', encoding='utf-8'), help="input file name (one query per line)") + parser.add_argument("--output", required=True, type=argparse.FileType('x', encoding='utf-8'), help="output file name") group = parser.add_mutually_exclusive_group(required=True) group.add_argument("--forward", action="store_true", help="use forward geocoding") group.add_argument("--reverse", action="store_true", help="use reverse geocoding") + parser.add_argument("--limit", type=int, default=0, help="number of lines to read from the input file") parser.add_argument("--has-headers", action="store_true", help="if the first row should be treated as a header row") parser.add_argument("--input-columns", type=comma_separated_type(int), default="1", help="comma separated list of integers") parser.add_argument("--add-columns", type=comma_separated_type(str), default="_type,country,county,city,postcode,road,house_number,confidence,formatted", help="comma separated list of output columns") @@ -31,6 +213,7 @@ def parse_args(args): return parser.parse_args(args) + def api_key_type(apikey): pattern = re.compile(r"^(oc_gc_)?[0-9a-f]{32}$") @@ -39,6 +222,7 @@ def api_key_type(apikey): return apikey + def ranged_type(value_type, min_value, max_value): def range_checker(arg: str): try: @@ -52,6 +236,7 @@ def range_checker(arg: str): # Return function handle to checking function return range_checker + def comma_separated_type(value_type): def comma_separated(arg: str): return [value_type(x) for x in arg.split(',')] diff --git a/test/fixtures/forward.csv b/test/fixtures/forward.csv new file mode 100644 index 0000000..84fcab1 --- /dev/null +++ b/test/fixtures/forward.csv @@ -0,0 +1,3 @@ +Rathausmarkt 1, 20095 Hamburg, Germany +10 Downing Street, London, SW1A 2AA +C/ de Mallorca, 401, 08013 Barcelona, Spain \ No newline at end of file diff --git a/test/test_cli.py b/test/test_cli_args.py similarity index 84% rename from test/test_cli.py rename to test/test_cli_args.py index be1335a..e2d5e11 100644 --- a/test/test_cli.py +++ b/test/test_cli_args.py @@ -1,8 +1,7 @@ import pathlib import pytest -import httpretty -from opencage.command_line import main, parse_args +from opencage.command_line import parse_args @pytest.fixture(autouse=True) def around(): @@ -19,14 +18,14 @@ def assert_parse_args_error(args, message, capfd): def test_required_arguments(capfd): assert_parse_args_error( [], - 'the following arguments are required: --apikey, --input, --output', + 'the following arguments are required: --api-key, --input, --output', capfd ) def test_invalid_api_key(capfd): assert_parse_args_error( [ - "--apikey", "invalid", + "--api-key", "invalid", "--input", "test/fixtures/input.txt", "--output", "test/fixtures/output.txt", "--forward" @@ -38,7 +37,7 @@ def test_invalid_api_key(capfd): def test_existing_output_file(capfd): assert_parse_args_error( [ - "--apikey", "oc_gc_12345678901234567890123456789012", + "--api-key", "oc_gc_12345678901234567890123456789012", "--input", "test/fixtures/input.txt", "--output", "test/fixtures/input.txt", "--forward" @@ -50,7 +49,7 @@ def test_existing_output_file(capfd): def test_requires_forward_or_reverse(capfd): assert_parse_args_error( [ - "--apikey", "oc_gc_12345678901234567890123456789012", + "--api-key", "oc_gc_12345678901234567890123456789012", "--input", "test/fixtures/input.txt", "--output", "test/fixtures/output.txt", ], @@ -61,7 +60,7 @@ def test_requires_forward_or_reverse(capfd): def test_argument_range(capfd): assert_parse_args_error( [ - "--apikey", "oc_gc_12345678901234567890123456789012", + "--api-key", "oc_gc_12345678901234567890123456789012", "--input", "test/fixtures/input.txt", "--output", "test/fixtures/output.txt", "--forward", @@ -73,13 +72,14 @@ def test_argument_range(capfd): def test_full_argument_list(): args = parse_args([ - "--apikey", "oc_gc_12345678901234567890123456789012", + "--api-key", "oc_gc_12345678901234567890123456789012", "--input", "test/fixtures/input.txt", "--output", "test/fixtures/output.txt", "--reverse", "--has-header", "--input-columns", "1,2", "--add-columns", "city,postcode", + "--limit", "4", "--workers", "3", "--timeout", "2", "--retries", "1", @@ -90,7 +90,7 @@ def test_full_argument_list(): "--quiet" ]) - assert args.apikey == "oc_gc_12345678901234567890123456789012" + assert args.api_key == "oc_gc_12345678901234567890123456789012" assert args.input.name == "test/fixtures/input.txt" assert args.output.name == "test/fixtures/output.txt" assert args.reverse is True @@ -98,6 +98,7 @@ def test_full_argument_list(): assert args.has_headers is True assert args.input_columns == [1, 2] assert args.add_columns == ["city", "postcode"] + assert args.limit == 4 assert args.workers == 3 assert args.timeout == 2 assert args.retries == 1 @@ -109,7 +110,7 @@ def test_full_argument_list(): def test_defaults(): args = parse_args([ - "--apikey", "12345678901234567890123456789012", + "--api-key", "12345678901234567890123456789012", "--input", "test/fixtures/input.txt", "--output", "test/fixtures/output.txt", "--reverse" @@ -117,6 +118,7 @@ def test_defaults(): assert args.reverse is True assert args.forward is False + assert args.limit == 0 assert args.has_headers is False assert args.input_columns == [1] assert args.add_columns == ["_type", "country", "county", "city", "postcode", "road", "house_number", "confidence", "formatted"] diff --git a/test/test_cli_run.py b/test/test_cli_run.py new file mode 100644 index 0000000..3f9cd62 --- /dev/null +++ b/test/test_cli_run.py @@ -0,0 +1,24 @@ +import pathlib +import pytest + +from opencage.command_line import main + +@pytest.fixture(autouse=True) +def around(): + yield + pathlib.Path("test/fixtures/output.txt").unlink(True) + +def test_main(): + main([ + "--api-key", "5d3c52eae14b4d4a8372c818e076de82", + "--input", "test/fixtures/forward.csv", + "--output", "test/fixtures/output.txt", + "--forward", + ]) + + assert pathlib.Path("test/fixtures/output.txt").exists() + + with open("test/fixtures/output.txt", "r") as f: + lines = f.readlines() + assert len(lines) == 3 + assert lines[0].strip() == "Rathausmarkt 1,53.5439,10.0133,postcode,Germany,,,20095,,,7,\"20095, Germany\""