Skip to content

Commit

Permalink
Copy batch.py code to CLI and adjust to make it work
Browse files Browse the repository at this point in the history
  • Loading branch information
sbscully committed Jul 26, 2024
1 parent 53bcfaf commit 0e13e65
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 15 deletions.
195 changes: 190 additions & 5 deletions opencage/command_line.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,204 @@
import argparse
import sys
import re
import csv
import re
import ssl
import asyncio
import traceback
import backoff
import certifi
from tqdm import tqdm
from opencage.geocoder import OpenCageGeocode, OpenCageGeocodeError


def main(args=sys.argv[1:]):
options = parse_args(args)
assert sys.version_info >= (3, 7), "Script requires Python 3.7 or newer"

BatchGeocoder.quiet = options.quiet
BatchGeocoder.max_time = options.timeout
BatchGeocoder.max_tries = options.retries

BatchGeocoder(options)()


class BatchGeocoder():
max_time = 1
max_tries = 10
quiet = False

def on_backoff(details):
if not BatchGeocoder.quiet:
sys.stderr.write("Backing off {wait:0.1f} seconds afters {tries} tries "
"calling function {target} with args {args} and kwargs "
"{kwargs}\n".format(**details))

def __init__(self, options):
self.options = options
self.progress = options.no_progress or tqdm(total=0, position=0, desc="Addresses geocoded", dynamic_ncols=True)
self.queue = asyncio.Queue(maxsize=options.limit)
self.sslcontext = ssl.create_default_context(cafile=certifi.where())
self.outfile = csv.writer(options.output)

def __call__(self):
asyncio.run(self.geocode())
self.options.output.close()

async def geocode(self):
with self.options.input as infile:
csv_reader = csv.reader(infile, strict=True, skipinitialspace=True)

for row in csv_reader:
if len(row) == 0:
raise Exception(f"Empty line in input file at line number {csv_reader.line_num}, aborting")

if self.options.reverse:
work_item = {'id': row[0], 'address': f"{row[1]},{row[2]}"}
else:
work_item = {'id': row[0], 'address': row[1]}

await self.queue.put(work_item)
if self.queue.full():
break

self.log(f"{self.queue.qsize()} work_items in queue")

if self.progress:
self.progress.total = self.queue.qsize()
self.progress.refresh()

## 2. Create tasks workers. That is coroutines, each taks take work_items
## from the queue until it's empty. Tasks run in parallel
##
## https://docs.python.org/3/library/asyncio-task.html#creating-tasks
## https://docs.python.org/3/library/asyncio-task.html#coroutine
##
self.log(f"Creating {self.options.workers} task workers...")
tasks = []
for i in range(self.options.workers):
task = asyncio.create_task(self.work(i))
tasks.append(task)

def main():
print(parse_args(sys.argv[1:]))
## 3. Now workers do the geocoding
##
self.log("Waiting for workers to finish processing queue...")
await self.queue.join()

## 4. Cleanup
##
for task in tasks:
task.cancel()

if self.progress:
self.progress.close()

self.log("All done.\n")

async def work(self, number):
self.log(f"Worker {number} starts...")

while True:
item = await self.queue.get()
address_id = item['id']
address = item['address']
await self.geocode_one_address(address, address_id)

if self.progress:
self.progress.update(1)

self.queue.task_done()

@backoff.on_exception(backoff.expo,
asyncio.TimeoutError,
max_time=max_time,
max_tries=max_tries,
on_backoff=on_backoff)
async def geocode_one_address(self, address, address_id):
async with OpenCageGeocode(self.options.api_key, domain=self.options.api_domain, sslcontext=self.sslcontext) as geocoder:
geocoding_results = None

try:
if self.options.reverse:
# Reverse:
# coordinates -> address, e.g. '40.78,-73.97' => '101, West 91st Street, New York'
lon_lat = address.split(',')
geocoding_results = await geocoder.reverse_geocode_async(lon_lat[0], lon_lat[1], no_annotations=1)
else:
# Forward:
# address -> coordinates
# note: you may also want to set other optional parameters like
# countrycode, language, etc
# see the full list: https://opencagedata.com/api#forward-opt
geocoding_results = await geocoder.geocode_async(address, no_annotations=1)
except OpenCageGeocodeError as exc:
sys.stderr.write(str(exc) + "\n")
except Exception as exc:
traceback.print_exception(exc, file=sys.stderr)

try:
if geocoding_results is not None and len(geocoding_results):
geocoding_result = geocoding_results[0]
else:
geocoding_result = None

await self.write_one_geocoding_result(geocoding_result, address, address_id)
except Exception as exc:
traceback.print_exception(exc, file=sys.stderr)

async def write_one_geocoding_result(self, geocoding_result, address, address_id):
if geocoding_result is not None:
row = [
address_id,
geocoding_result['geometry']['lat'],
geocoding_result['geometry']['lng'],
# Any of the components might be empty:
geocoding_result['components'].get('_type', ''),
geocoding_result['components'].get('country', ''),
geocoding_result['components'].get('county', ''),
geocoding_result['components'].get('city', ''),
geocoding_result['components'].get('postcode', ''),
geocoding_result['components'].get('road', ''),
geocoding_result['components'].get('house_number', ''),
geocoding_result['confidence'],
geocoding_result['formatted']
]
else:
self.log(f"not found, writing empty result: {address}")
row = [
address_id,
0, # not to be confused with https://en.wikipedia.org/wiki/Null_Island
0,
'',
'',
'',
'',
'',
'',
'',
-1, # confidence values are 1-10 (lowest to highest), use -1 for unknown
''
]

self.outfile.writerow(row)

def log(self, message):
if not self.options.quiet:
sys.stderr.write(f"{message}\n")


def parse_args(args):
parser = argparse.ArgumentParser(description="opencage")

parser.add_argument("--apikey", required=True, type=api_key_type, help="your OpenCage API key")
parser.add_argument("--input", required=True, type=argparse.FileType('r'), help="input file name (one query per line)")
parser.add_argument("--output", required=True, type=argparse.FileType('x'), help="output file name")
parser.add_argument("--api-key", required=True, type=api_key_type, help="your OpenCage API key")
parser.add_argument("--input", required=True, type=argparse.FileType('r', encoding='utf-8'), help="input file name (one query per line)")
parser.add_argument("--output", required=True, type=argparse.FileType('x', encoding='utf-8'), help="output file name")

group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--forward", action="store_true", help="use forward geocoding")
group.add_argument("--reverse", action="store_true", help="use reverse geocoding")

parser.add_argument("--limit", type=int, default=0, help="number of lines to read from the input file")
parser.add_argument("--has-headers", action="store_true", help="if the first row should be treated as a header row")
parser.add_argument("--input-columns", type=comma_separated_type(int), default="1", help="comma separated list of integers")
parser.add_argument("--add-columns", type=comma_separated_type(str), default="_type,country,county,city,postcode,road,house_number,confidence,formatted", help="comma separated list of output columns")
Expand All @@ -31,6 +213,7 @@ def parse_args(args):

return parser.parse_args(args)


def api_key_type(apikey):
pattern = re.compile(r"^(oc_gc_)?[0-9a-f]{32}$")

Expand All @@ -39,6 +222,7 @@ def api_key_type(apikey):

return apikey


def ranged_type(value_type, min_value, max_value):
def range_checker(arg: str):
try:
Expand All @@ -52,6 +236,7 @@ def range_checker(arg: str):
# Return function handle to checking function
return range_checker


def comma_separated_type(value_type):
def comma_separated(arg: str):
return [value_type(x) for x in arg.split(',')]
Expand Down
3 changes: 3 additions & 0 deletions test/fixtures/forward.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Rathausmarkt 1, 20095 Hamburg, Germany
10 Downing Street, London, SW1A 2AA
C/ de Mallorca, 401, 08013 Barcelona, Spain
22 changes: 12 additions & 10 deletions test/test_cli.py → test/test_cli_args.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pathlib
import pytest
import httpretty

from opencage.command_line import main, parse_args
from opencage.command_line import parse_args

@pytest.fixture(autouse=True)
def around():
Expand All @@ -19,14 +18,14 @@ def assert_parse_args_error(args, message, capfd):
def test_required_arguments(capfd):
assert_parse_args_error(
[],
'the following arguments are required: --apikey, --input, --output',
'the following arguments are required: --api-key, --input, --output',
capfd
)

def test_invalid_api_key(capfd):
assert_parse_args_error(
[
"--apikey", "invalid",
"--api-key", "invalid",
"--input", "test/fixtures/input.txt",
"--output", "test/fixtures/output.txt",
"--forward"
Expand All @@ -38,7 +37,7 @@ def test_invalid_api_key(capfd):
def test_existing_output_file(capfd):
assert_parse_args_error(
[
"--apikey", "oc_gc_12345678901234567890123456789012",
"--api-key", "oc_gc_12345678901234567890123456789012",
"--input", "test/fixtures/input.txt",
"--output", "test/fixtures/input.txt",
"--forward"
Expand All @@ -50,7 +49,7 @@ def test_existing_output_file(capfd):
def test_requires_forward_or_reverse(capfd):
assert_parse_args_error(
[
"--apikey", "oc_gc_12345678901234567890123456789012",
"--api-key", "oc_gc_12345678901234567890123456789012",
"--input", "test/fixtures/input.txt",
"--output", "test/fixtures/output.txt",
],
Expand All @@ -61,7 +60,7 @@ def test_requires_forward_or_reverse(capfd):
def test_argument_range(capfd):
assert_parse_args_error(
[
"--apikey", "oc_gc_12345678901234567890123456789012",
"--api-key", "oc_gc_12345678901234567890123456789012",
"--input", "test/fixtures/input.txt",
"--output", "test/fixtures/output.txt",
"--forward",
Expand All @@ -73,13 +72,14 @@ def test_argument_range(capfd):

def test_full_argument_list():
args = parse_args([
"--apikey", "oc_gc_12345678901234567890123456789012",
"--api-key", "oc_gc_12345678901234567890123456789012",
"--input", "test/fixtures/input.txt",
"--output", "test/fixtures/output.txt",
"--reverse",
"--has-header",
"--input-columns", "1,2",
"--add-columns", "city,postcode",
"--limit", "4",
"--workers", "3",
"--timeout", "2",
"--retries", "1",
Expand All @@ -90,14 +90,15 @@ def test_full_argument_list():
"--quiet"
])

assert args.apikey == "oc_gc_12345678901234567890123456789012"
assert args.api_key == "oc_gc_12345678901234567890123456789012"
assert args.input.name == "test/fixtures/input.txt"
assert args.output.name == "test/fixtures/output.txt"
assert args.reverse is True
assert args.forward is False
assert args.has_headers is True
assert args.input_columns == [1, 2]
assert args.add_columns == ["city", "postcode"]
assert args.limit == 4
assert args.workers == 3
assert args.timeout == 2
assert args.retries == 1
Expand All @@ -109,14 +110,15 @@ def test_full_argument_list():

def test_defaults():
args = parse_args([
"--apikey", "12345678901234567890123456789012",
"--api-key", "12345678901234567890123456789012",
"--input", "test/fixtures/input.txt",
"--output", "test/fixtures/output.txt",
"--reverse"
])

assert args.reverse is True
assert args.forward is False
assert args.limit == 0
assert args.has_headers is False
assert args.input_columns == [1]
assert args.add_columns == ["_type", "country", "county", "city", "postcode", "road", "house_number", "confidence", "formatted"]
Expand Down
24 changes: 24 additions & 0 deletions test/test_cli_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pathlib
import pytest

from opencage.command_line import main

@pytest.fixture(autouse=True)
def around():
yield
pathlib.Path("test/fixtures/output.txt").unlink(True)

def test_main():
main([
"--api-key", "5d3c52eae14b4d4a8372c818e076de82",
"--input", "test/fixtures/forward.csv",
"--output", "test/fixtures/output.txt",
"--forward",
])

assert pathlib.Path("test/fixtures/output.txt").exists()

with open("test/fixtures/output.txt", "r") as f:
lines = f.readlines()
assert len(lines) == 3
assert lines[0].strip() == "Rathausmarkt 1,53.5439,10.0133,postcode,Germany,,,20095,,,7,\"20095, Germany\""

0 comments on commit 0e13e65

Please sign in to comment.