diff --git a/opencage/batch.py b/opencage/batch.py index b6bd9fb..98d9c70 100644 --- a/opencage/batch.py +++ b/opencage/batch.py @@ -60,10 +60,21 @@ async def read_one_line(self, row, line_number): if len(row) == 0: raise Exception(f"Empty line in input file at line number {line_number}, aborting") - if self.options.reverse: - return {'id': row[0], 'address': f"{row[1]},{row[2]}"} + if self.options.has_headers and line_number == 1: + return None + + if self.options.input_columns: + # input_columns option uses 1-based indexing + address = [row[i-1] for i in self.options.input_columns] + elif self.options.reverse: + address = [row[1], row[2]] else: - return {'id': row[0], 'address': row[1]} + address = [row[1]] + + if self.options.reverse and len(address) != 2: + self.log(f"Expected two comma-separated values for reverse geocoding, got {address}") + + return { 'id': row[0], 'address': ','.join(address) } async def worker(self, queue, progress): while True: @@ -94,8 +105,8 @@ async def _geocode_one_address(): try: if self.options.reverse: - lon_lat = address.split(',') - geocoding_results = await geocoder.reverse_geocode_async(lon_lat[0], lon_lat[1], **params) + lon, lat = address.split(',') + geocoding_results = await geocoder.reverse_geocode_async(lon, lat, **params) else: geocoding_results = await geocoder.geocode_async(address, **params) except OpenCageGeocodeError as exc: @@ -124,18 +135,16 @@ async def write_one_geocoding_result(self, result, address, id): id, result['geometry']['lat'], result['geometry']['lng'], - # Any of the components might be empty: - result['components'].get('_type', ''), - result['components'].get('country', ''), - result['components'].get('county', ''), - result['components'].get('city', ''), - result['components'].get('postcode', ''), - result['components'].get('road', ''), - result['components'].get('house_number', ''), - result['confidence'], - result['formatted'] ] + for column in self.options.add_columns: + if column in result: + row.append(result[column]) + elif column in result['components']: + row.append(result['components'][column]) + else: + row.append('') + self.outfile.writerow(row) def blank_row(self, id): diff --git a/opencage/command_line.py b/opencage/command_line.py index 3203ced..ea872fb 100644 --- a/opencage/command_line.py +++ b/opencage/command_line.py @@ -26,7 +26,7 @@ def parse_args(args): parser.add_argument("--limit", type=int, default=0, help="number of lines to read from the input file") parser.add_argument("--has-headers", action="store_true", help="if the first row should be treated as a header row") - parser.add_argument("--input-columns", type=comma_separated_type(int), default="1", help="comma separated list of integers") + parser.add_argument("--input-columns", type=comma_separated_type(int), default="", help="comma separated list of integers") parser.add_argument("--add-columns", type=comma_separated_type(str), default="_type,country,county,city,postcode,road,house_number,confidence,formatted", help="comma separated list of output columns") parser.add_argument("--workers", type=ranged_type(int, 1, 20), default=1, help="number of parallel workers") parser.add_argument("--timeout", type=ranged_type(int, 1, 60), default=1, help="timeout in seconds") @@ -65,6 +65,9 @@ def range_checker(arg: str): def comma_separated_type(value_type): def comma_separated(arg: str): + if not arg: + return [] + return [value_type(x) for x in arg.split(',')] return comma_separated diff --git a/test/fixtures/forward.csv b/test/fixtures/forward.csv index 84fcab1..5b29159 100644 --- a/test/fixtures/forward.csv +++ b/test/fixtures/forward.csv @@ -1,3 +1,3 @@ -Rathausmarkt 1, 20095 Hamburg, Germany -10 Downing Street, London, SW1A 2AA -C/ de Mallorca, 401, 08013 Barcelona, Spain \ No newline at end of file +1,Rathausmarkt 1, 20095 Hamburg, Germany +2,10 Downing Street, London, SW1A 2AA +3,C/ de Mallorca 401, 08013 Barcelona, Spain \ No newline at end of file diff --git a/test/test_cli_args.py b/test/test_cli_args.py index 33e798d..0f918ab 100644 --- a/test/test_cli_args.py +++ b/test/test_cli_args.py @@ -120,7 +120,7 @@ def test_defaults(): assert args.forward is False assert args.limit == 0 assert args.has_headers is False - assert args.input_columns == [1] + assert args.input_columns == [] assert args.add_columns == ["_type", "country", "county", "city", "postcode", "road", "house_number", "confidence", "formatted"] assert args.workers == 1 assert args.timeout == 1 diff --git a/test/test_cli_run.py b/test/test_cli_run.py index 990b40c..78b018d 100644 --- a/test/test_cli_run.py +++ b/test/test_cli_run.py @@ -14,6 +14,7 @@ def test_main(): "--input", "test/fixtures/forward.csv", "--output", "test/fixtures/output.txt", "--forward", + "--input-columns", "2,3,4", ]) assert pathlib.Path("test/fixtures/output.txt").exists() @@ -21,4 +22,4 @@ def test_main(): with open("test/fixtures/output.txt", "r") as f: lines = f.readlines() assert len(lines) == 3 - assert lines[0].strip() == "Rathausmarkt 1,51.9526622,7.6324709,social_facility,Germany,,Münster,48153,Friedrich-Ebert-Straße,7,9,\"Chance e.V., Friedrich-Ebert-Straße 7, 48153 Münster, Germany\"" + assert lines[0].strip() == '1,51.9526622,7.6324709,social_facility,Germany,,Münster,48153,Friedrich-Ebert-Straße,7,9,"Chance e.V., Friedrich-Ebert-Straße 7, 48153 Münster, Germany"'