Skip to content

Commit

Permalink
feat: Allow SQL tap developers to auto-skip certain stream names from…
Browse files Browse the repository at this point in the history
… discovery
  • Loading branch information
edgarrmondragon committed Dec 3, 2024
1 parent da883d9 commit 8b3528e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 9 deletions.
33 changes: 25 additions & 8 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,8 @@ def discover_catalog_entry(
schema_name: str,
table_name: str,
is_view: bool, # noqa: FBT001
*,
reflect_indices: bool = True,
) -> CatalogEntry:
"""Create `CatalogEntry` object for the given table or a view.
Expand All @@ -890,6 +892,7 @@ def discover_catalog_entry(
schema_name: Schema name to inspect
table_name: Name of the table or a view
is_view: Flag whether this object is a view, returned by `get_object_names`
reflect_indices: Whether to reflect indices
Returns:
`CatalogEntry` object for the given table or a view
Expand All @@ -905,11 +908,12 @@ def discover_catalog_entry(

# An element of the columns list is ``None`` if it's an expression and is
# returned in the ``expressions`` list of the reflected index.
possible_primary_keys.extend(
index_def["column_names"] # type: ignore[misc]
for index_def in inspected.get_indexes(table_name, schema=schema_name)
if index_def.get("unique", False)
)
if reflect_indices:
possible_primary_keys.extend(
index_def["column_names"] # type: ignore[misc]
for index_def in inspected.get_indexes(table_name, schema=schema_name)
if index_def.get("unique", False)
)

key_properties = next(iter(possible_primary_keys), None)

Expand Down Expand Up @@ -960,16 +964,29 @@ def discover_catalog_entry(
replication_key=None, # Must be defined by user
)

def discover_catalog_entries(self) -> list[dict]:
def discover_catalog_entries(
self,
*,
skip_schemas: t.Sequence[str] = (),
reflect_indices: bool = True,
) -> list[dict]:
"""Return a list of catalog entries from discovery.
Args:
skip_schemas: A list of schema names to skip.
reflect_indices: Whether to reflect indices to detect potential primary
keys.
Returns:
The discovered catalog entries as a list.
"""
result: list[dict] = []
engine = self._engine
inspected = sa.inspect(engine)
for schema_name in self.get_schema_names(engine, inspected):
if schema_name in skip_schemas:
continue

# Iterate through each table and view
for table_name, is_view in self.get_object_names(
engine,
Expand All @@ -982,6 +999,7 @@ def discover_catalog_entries(self) -> list[dict]:
schema_name,
table_name,
is_view,
reflect_indices=reflect_indices,
)
result.append(catalog_entry.to_dict())

Expand Down Expand Up @@ -1217,8 +1235,7 @@ def prepare_schema(self, schema_name: str) -> None:
Args:
schema_name: The target schema name.
"""
schema_exists = self.schema_exists(schema_name)
if not schema_exists:
if not self.schema_exists(schema_name):
self.create_schema(schema_name)

def prepare_table(
Expand Down
7 changes: 6 additions & 1 deletion singer_sdk/tap_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,9 @@ class SQLTap(Tap):
querying a database's system tables).
"""

skip_schemas: t.Sequence[str] = []
"""Hard-coded list of stream names to skip when discovering the catalog."""

_tap_connector: SQLConnector | None = None

def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
Expand Down Expand Up @@ -700,7 +703,9 @@ def catalog_dict(self) -> dict:
connector = self.tap_connector

result: dict[str, list[dict]] = {"streams": []}
result["streams"].extend(connector.discover_catalog_entries())
result["streams"].extend(
connector.discover_catalog_entries(skip_schemas=self.skip_schemas),
)

self._catalog_dict = result
return self._catalog_dict
Expand Down
26 changes: 26 additions & 0 deletions tests/core/test_connector_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,32 @@ def test_adapt_column_type(self, connector: DuckDBConnector):
assert result.keys() == ["id", "name"]
assert result.cursor.description[1][1] == "STRING"

@pytest.mark.parametrize(
"skip_schemas,expected_streams",
[
([], 1),
(["memory.my_schema"], 0),
],
)
def test_discover_catalog_entries_skip_schemas(
self,
connector: DuckDBConnector,
skip_schemas: list[str],
expected_streams: int,
):
with connector._engine.connect() as conn, conn.begin():
conn.execute(sa.text("CREATE SCHEMA my_schema"))
conn.execute(
sa.text(
"CREATE TABLE my_schema.test_table (id INTEGER PRIMARY KEY, name STRING)", # noqa: E501
)
)
entries = connector.discover_catalog_entries(
skip_schemas=skip_schemas,
reflect_indices=False,
)
assert len(entries) == expected_streams


def test_adapter_without_json_serde():
registry.register(
Expand Down

0 comments on commit 8b3528e

Please sign in to comment.