Skip to content

Commit

Permalink
Fixes and tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
jbothma committed Jan 17, 2025
1 parent 3a3c6e7 commit 7a06b46
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 23 deletions.
10 changes: 5 additions & 5 deletions nomenklatura/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def match_command(
stream = path_entities(entities, Entity)
for proxy in match(enricher, resolver, stream):
write_entity(fh, proxy)
finally:
resolver.commit()
finally:
enricher.close()


Expand All @@ -214,8 +214,8 @@ def enrich_command(
stream = path_entities(entities, Entity)
for proxy in enrich(enricher, resolver, stream):
write_entity(fh, proxy)
finally:
resolver.commit()
finally:
enricher.close()


Expand Down Expand Up @@ -292,7 +292,7 @@ def statements_aggregate(
write_entity(outfh, entity)


@cli.command("load-resolver", help="Load file-based resolver info to database")
@cli.command("load-resolver", help="Load resolver decisions from file into database")
@click.argument("source", type=InPath)
def load_resolver(source: Path) -> None:
resolver = Resolver[Entity].make_default()
Expand All @@ -301,13 +301,13 @@ def load_resolver(source: Path) -> None:
resolver.commit()


@cli.command("dump-resolver", help="Load file-based resolver info to database")
@cli.command("dump-resolver", help="Dump resolver decisions from database to file")
@click.argument("target", type=OutPath)
def dump_resolver(target: Path) -> None:
resolver = Resolver[Entity].make_default()
resolver.begin()
resolver.save(target)
resolver.commit()
resolver.rollback()


@cli.command("bench", help="Benchmark a matching algorithm")
Expand Down
14 changes: 6 additions & 8 deletions nomenklatura/resolver/resolver.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import getpass
import logging
from functools import lru_cache
from sqlalchemy import MetaData, or_, alias, func
from sqlalchemy import Table, Column, Unicode, Float
from sqlalchemy.engine import Engine, Connection, Transaction
from sqlalchemy.sql.expression import select, delete, update
from threading import RLock
from typing import Dict, Generator, Optional, Set, Tuple
from urllib.parse import urlunparse
import getpass
import logging

from followthemoney.types import registry

from rigour.ids.wikidata import is_qid
from rigour.time import utc_now
from sqlalchemy import Column, Float, MetaData, Table, Unicode, alias, func, or_
from sqlalchemy.engine import Connection, Engine, Transaction
from sqlalchemy.sql.expression import delete, select, update

from nomenklatura.db import get_upsert_func, get_engine
from nomenklatura.db import get_engine, get_upsert_func
from nomenklatura.entity import CE
from nomenklatura.judgement import Judgement
from nomenklatura.resolver.edge import Edge
Expand Down
16 changes: 6 additions & 10 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,6 @@ def test_linker():


def test_resolver_store():
engine = get_engine()
metadata = get_metadata()

with NamedTemporaryFile("w") as fh:
path = Path(fh.name)
resolver = Resolver.make_default()
Expand All @@ -111,9 +108,9 @@ def test_resolver_store():
resolver.suggest("a1", "c1", 7.0)
resolver.save(path)

other = Resolver(
engine=engine, metadata=metadata, table_name="other", create=True
)
get_engine.cache_clear()
get_metadata.cache_clear()
other = Resolver(engine=get_engine(), metadata=get_metadata(), create=True)
other.begin()
other.load(path)
assert len(other.get_edges()) == len(resolver.get_edges())
Expand Down Expand Up @@ -143,9 +140,6 @@ def test_resolver_candidates():


def test_resolver_statements():
engine = get_engine()
metadata = get_metadata()

resolver = Resolver.make_default()
resolver.begin()
canon = resolver.decide("a1", "a2", Judgement.POSITIVE)
Expand All @@ -157,7 +151,9 @@ def test_resolver_statements():
assert stmt.value == "b2"
resolver.commit()

other = Resolver(engine=engine, metadata=metadata, table_name="other", create=True)
get_engine.cache_clear()
get_metadata.cache_clear()
other = Resolver(engine=get_engine(), metadata=get_metadata(), create=True)
other.begin()
stmt = other.apply_statement(stmt)
assert stmt.canonical_id == "a1"
Expand Down

0 comments on commit 7a06b46

Please sign in to comment.