chore: Fix some docstring types and apply some refactorings suggested by Sourcery
edgarrmondragon committed Nov 17, 2023
1 parent 80e31a2 commit d0b9586
Showing 17 changed files with 48 additions and 83 deletions.
8 changes: 2 additions & 6 deletions singer_sdk/_singerlib/catalog.py
@@ -31,11 +31,7 @@ def __missing__(self, breadcrumb: Breadcrumb) -> bool:
Returns:
True if the breadcrumb is selected, False otherwise.
"""
if len(breadcrumb) >= 2: # noqa: PLR2004
parent = breadcrumb[:-2]
return self[parent]

return True
return self[breadcrumb[:-2]] if len(breadcrumb) >= 2 else True # noqa: PLR2004


@dataclass
@@ -71,7 +67,7 @@ def from_dict(cls: type[Metadata], value: dict[str, t.Any]) -> Metadata:
)

def to_dict(self) -> dict[str, t.Any]:
"""Convert metadata to a JSON-encodeable dictionary.
"""Convert metadata to a JSON-encodable dictionary.
Returns:
Metadata object.
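For context, the simplified __missing__ keeps the original fall-back behaviour. A minimal, self-contained sketch of the pattern (a toy stand-in, not the SDK's real class, which works with its Breadcrumb type):

class SelectionMask(dict):
    """Toy stand-in for the SDK's selection mask."""

    def __missing__(self, breadcrumb: tuple) -> bool:
        # A missing breadcrumb inherits its parent's selection; the root
        # breadcrumb defaults to selected.
        return self[breadcrumb[:-2]] if len(breadcrumb) >= 2 else True

mask = SelectionMask({("properties", "a"): False})
print(mask[("properties", "a", "properties", "b")])  # False (inherited from parent)
print(mask[("properties", "x")])                     # True (falls back to the root)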
16 changes: 5 additions & 11 deletions singer_sdk/authenticators.py
@@ -462,9 +462,7 @@ def client_id(self) -> str | None:
Returns:
Optional client secret from stream config if it has been set.
"""
if self.config:
return self.config.get("client_id")
return None
return self.config.get("client_id") if self.config else None

@property
def client_secret(self) -> str | None:
@@ -473,9 +471,7 @@ def client_secret(self) -> str | None:
Returns:
Optional client secret from stream config if it has been set.
"""
if self.config:
return self.config.get("client_secret")
return None
return self.config.get("client_secret") if self.config else None

def is_token_valid(self) -> bool:
"""Check if token is valid.
@@ -487,9 +483,7 @@ def is_token_valid(self) -> bool:
return False
if not self.expires_in:
return True
if self.expires_in > (utc_now() - self.last_refreshed).total_seconds():
return True
return False
return self.expires_in > (utc_now() - self.last_refreshed).total_seconds()

# Authentication and refresh
def update_access_token(self) -> None:
@@ -520,7 +514,7 @@ def update_access_token(self) -> None:
self.expires_in = int(expiration) if expiration else None
if self.expires_in is None:
self.logger.debug(
"No expires_in receied in OAuth response and no "
"No expires_in received in OAuth response and no "
"default_expiration set. Token will be treated as if it never "
"expires.",
)
@@ -566,7 +560,7 @@ def oauth_request_body(self) -> dict:

@property
def oauth_request_payload(self) -> dict:
"""Return request paytload for OAuth request.
"""Return request payload for OAuth request.
Returns:
Payload object for OAuth.
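The is_token_valid change folds the final if/return pair into a single boolean expression. A rough standalone sketch of the same logic (attribute names taken from the diff but passed as parameters here; the SDK uses its own utc_now helper):

from datetime import datetime, timezone

def is_token_valid(access_token, expires_in, last_refreshed) -> bool:
    if not access_token:
        return False
    if not expires_in:
        return True  # no expiry info: token treated as never expiring
    # Valid while fewer than `expires_in` seconds have passed since the refresh.
    return expires_in > (datetime.now(timezone.utc) - last_refreshed).total_seconds()

print(is_token_valid("token", 3600, datetime.now(timezone.utc)))  # True
print(is_token_valid(None, 3600, datetime.now(timezone.utc)))     # False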
5 changes: 0 additions & 5 deletions singer_sdk/connectors/sql.py
@@ -902,9 +902,6 @@ def merge_sql_types(
if issubclass(
generic_type,
(sqlalchemy.types.String, sqlalchemy.types.Unicode),
) or issubclass(
generic_type,
(sqlalchemy.types.String, sqlalchemy.types.Unicode),
):
# If length None or 0 then is varchar max ?
if (
@@ -913,8 +910,6 @@ def merge_sql_types(
or (cur_len and (opt_len >= cur_len))
):
return opt
# If best conversion class is equal to current type
# return the best conversion class
elif str(opt) == str(current_type):
return opt

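The removed branch was an exact duplicate of the one above it: issubclass already accepts a tuple of classes and matches any of them, so a single check suffices. For example:

# A tuple as the second argument means "subclass of any of these".
print(issubclass(bool, (int, float)))  # True
print(issubclass(str, (int, float)))   # False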
12 changes: 7 additions & 5 deletions singer_sdk/helpers/_conformers.py
@@ -16,11 +16,13 @@ def snakecase(string: str) -> str:
"""
string = re.sub(r"[\-\.\s]", "_", string)
string = (
string[0].lower()
+ re.sub(
r"[A-Z]",
lambda matched: "_" + str(matched.group(0).lower()),
string[1:],
(
string[0].lower()
+ re.sub(
r"[A-Z]",
lambda matched: f"_{matched.group(0).lower()!s}",
string[1:],
)
)
if string
else string
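The snakecase helper only gains parentheses and an f-string here; its behaviour is unchanged. A self-contained copy for illustration:

import re

def snakecase(string: str) -> str:
    """Convert a string to snake_case (illustrative copy of the helper)."""
    string = re.sub(r"[\-\.\s]", "_", string)  # dashes, dots, spaces -> underscores
    return (
        string[0].lower()
        + re.sub(r"[A-Z]", lambda m: f"_{m.group(0).lower()!s}", string[1:])
        if string
        else string
    )

print(snakecase("CreatedAt"))    # created_at
print(snakecase("order.items"))  # order_items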
7 changes: 2 additions & 5 deletions singer_sdk/helpers/_flattening.py
@@ -465,12 +465,9 @@ def _should_jsondump_value(key: str, value: t.Any, flattened_schema=None) -> bool:
if isinstance(value, (dict, list)):
return True

if (
return bool(
flattened_schema
and key in flattened_schema
and "type" in flattened_schema[key]
and set(flattened_schema[key]["type"]) == {"null", "object", "array"}
):
return True

return False
)
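The bool() wrapper matters because `and` chains return one of their operands, not necessarily a bool: with no schema the old condition evaluated to None and relied on the trailing explicit return False. A quick illustration:

flattened_schema = None
condition = flattened_schema and "my_key" in flattened_schema
print(condition)        # None -- the short-circuited operand itself
print(bool(condition))  # False -- what a `-> bool` function should return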
15 changes: 4 additions & 11 deletions singer_sdk/helpers/_state.py
@@ -18,7 +18,7 @@
STARTING_MARKER = "starting_replication_value"


def get_state_if_exists( # noqa: PLR0911
def get_state_if_exists(
tap_state: dict,
tap_stream_id: str,
state_partition_context: dict | None = None,
@@ -47,9 +47,7 @@ def get_state_if_exists( # noqa: PLR0911

stream_state = tap_state["bookmarks"][tap_stream_id]
if not state_partition_context:
if key:
return stream_state.get(key, None)
return stream_state
return stream_state.get(key, None) if key else stream_state
if "partitions" not in stream_state:
return None # No partitions defined

@@ -59,9 +57,7 @@ def get_state_if_exists( # noqa: PLR0911
)
if matched_partition is None:
return None # Partition definition not present
if key:
return matched_partition.get(key, None)
return matched_partition
return matched_partition.get(key, None) if key else matched_partition


def get_state_partitions_list(tap_state: dict, tap_stream_id: str) -> list[dict] | None:
@@ -84,10 +80,7 @@ def _find_in_partitions_list(
f"{{state_partition_context}}.\nMatching state values were: {found!s}"
)
raise ValueError(msg)
if found:
return t.cast(dict, found[0])

return None
return t.cast(dict, found[0]) if found else None


def _create_in_partitions_list(
7 changes: 4 additions & 3 deletions singer_sdk/helpers/_typing.py
@@ -79,7 +79,7 @@ def is_secret_type(type_dict: dict) -> bool:
"""Return True if JSON Schema type definition appears to be a secret.
Will return true if either `writeOnly` or `secret` are true on this type
or any of the type's subproperties.
or any of the type's sub-properties.
Args:
type_dict: The JSON Schema type to check.
@@ -96,7 +96,7 @@ def is_secret_type(type_dict: dict) -> bool:
return True

if "properties" in type_dict:
# Recursively check subproperties and return True if any child is secret.
# Recursively check sub-properties and return True if any child is secret.
return any(
is_secret_type(child_type_dict)
for child_type_dict in type_dict["properties"].values()
@@ -388,6 +388,7 @@ def conform_record_data_types(
return rec


# TODO: This is in dire need of refactoring. It's a mess.
def _conform_record_data_types( # noqa: PLR0912
input_object: dict[str, t.Any],
schema: dict,
@@ -405,7 +406,7 @@ def _conform_record_data_types( # noqa: PLR0912
input_object: A single record
schema: JSON schema the given input_object is expected to meet
level: Specifies how recursive the conformance process should be
parent: '.' seperated path to this element from the object root (for logging)
parent: '.' separated path to this element from the object root (for logging)
"""
output_object: dict[str, t.Any] = {}
unmapped_properties: list[str] = []
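For reference, a compact standalone sketch of the recursive check that the is_secret_type docstring above describes (simplified; the key names mirror the diff):

def is_secret_type(type_dict: dict) -> bool:
    """Treat a type as secret if it, or any sub-property, is marked secret."""
    if type_dict.get("writeOnly") or type_dict.get("secret"):
        return True
    return any(
        is_secret_type(child) for child in type_dict.get("properties", {}).values()
    )

print(is_secret_type({"type": "string"}))                        # False
print(is_secret_type({"properties": {"pw": {"secret": True}}}))  # True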
2 changes: 1 addition & 1 deletion singer_sdk/io_base.py
@@ -48,7 +48,7 @@ def _assert_line_requires(line_dict: dict, requires: set[str]) -> None:
if not requires.issubset(line_dict):
missing = requires - set(line_dict)
msg = f"Line is missing required {', '.join(missing)} key(s): {line_dict}"
raise Exception(msg)
raise Exception(msg) # TODO: Raise a more specific exception

def deserialize_json(self, line: str) -> dict:
"""Deserialize a line of json.
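For context, the guard above (shown here as a standalone sketch) simply rejects Singer lines that lack the keys a message type requires; the new TODO notes that the bare Exception should eventually become something more specific:

def assert_line_requires(line_dict: dict, requires: set) -> None:
    if not requires.issubset(line_dict):
        missing = requires - set(line_dict)
        msg = f"Line is missing required {', '.join(missing)} key(s): {line_dict}"
        raise Exception(msg)  # TODO in the SDK: raise a more specific exception

assert_line_requires({"type": "RECORD", "stream": "users"}, requires={"type", "stream"})  # passes
# assert_line_requires({"type": "RECORD"}, requires={"type", "stream"})  # would raise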
7 changes: 2 additions & 5 deletions singer_sdk/mapper.py
@@ -273,10 +273,7 @@ def transform(self, record: dict) -> dict | None:
The transformed record.
"""
transformed_record = self._transform_fn(record)
if not transformed_record:
return None

return super().transform(transformed_record)
return super().transform(transformed_record) if transformed_record else None

def get_filter_result(self, record: dict) -> bool:
"""Return True to include or False to exclude.
@@ -291,7 +288,7 @@ def get_filter_result(self, record: dict) -> bool:

@property
def functions(self) -> dict[str, t.Callable]:
"""Get availabale transformation functions.
"""Get available transformation functions.
Returns:
Functions which should be available for expression evaluation.
7 changes: 3 additions & 4 deletions singer_sdk/metrics.py
@@ -263,10 +263,9 @@ def __exit__(
exc_tb: The exception traceback.
"""
if Tag.STATUS not in self.tags:
if exc_type is None:
self.tags[Tag.STATUS] = Status.SUCCEEDED
else:
self.tags[Tag.STATUS] = Status.FAILED
self.tags[Tag.STATUS] = (
Status.SUCCEEDED if exc_type is None else Status.FAILED
)
log(self.logger, Point("timer", self.metric, self.elapsed(), self.tags))

def elapsed(self) -> float:
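The __exit__ change only swaps an if/else for a conditional expression; the status tag still reflects whether the with-block raised. A minimal sketch of the pattern (toy class, plain-string tags):

class Timer:
    """Toy timer showing the status-tagging pattern from __exit__ above."""

    def __init__(self):
        self.tags = {}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if "status" not in self.tags:
            self.tags["status"] = "succeeded" if exc_type is None else "failed"

with Timer() as timer:
    pass
print(timer.tags)  # {'status': 'succeeded'}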
2 changes: 1 addition & 1 deletion singer_sdk/sinks/core.py
@@ -247,7 +247,7 @@ def _add_sdc_metadata_to_record(
tz=datetime.timezone.utc,
).isoformat()
record["_sdc_batched_at"] = (
context.get("batch_start_time", None)
context.get("batch_start_time")
or datetime.datetime.now(tz=datetime.timezone.utc)
).isoformat()
record["_sdc_deleted_at"] = record.get("_sdc_deleted_at")
8 changes: 1 addition & 7 deletions singer_sdk/sinks/sql.py
@@ -98,13 +98,7 @@ def schema_name(self) -> str | None:
if default_target_schema:
return default_target_schema

if len(parts) in {2, 3}:
# Stream name is a two-part or three-part identifier.
# Use the second-to-last part as the schema name.
return self.conform_name(parts[-2], "schema")

# Schema name not detected.
return None
return self.conform_name(parts[-2], "schema") if len(parts) in {2, 3} else None

@property
def database_name(self) -> str | None:
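The schema_name simplification keeps the same rule: a two- or three-part stream identifier contributes its second-to-last part as the schema name. Roughly (assuming `parts` is the split stream name, and ignoring the SDK's conform_name step):

def schema_name_from_parts(parts: list) -> str | None:
    # "db-schema-table" or "schema-table" style identifiers carry a schema;
    # a single-part name does not.
    return parts[-2] if len(parts) in {2, 3} else None

print(schema_name_from_parts(["mydb", "analytics", "orders"]))  # analytics
print(schema_name_from_parts(["orders"]))                       # None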
4 changes: 2 additions & 2 deletions singer_sdk/tap_base.py
@@ -121,10 +121,10 @@ def streams(self) -> dict[str, Stream]:
Returns:
A mapping of names to streams, using discovery or a provided catalog.
"""
input_catalog = self.input_catalog

if self._streams is None:
self._streams = {}
input_catalog = self.input_catalog

for stream in self.load_streams():
if input_catalog is not None:
stream.apply_catalog(input_catalog)
2 changes: 1 addition & 1 deletion singer_sdk/target_base.py
@@ -370,7 +370,7 @@ def _process_schema_message(self, message_dict: dict) -> None:

stream_name = message_dict["stream"]
schema = message_dict["schema"]
key_properties = message_dict.get("key_properties", None)
key_properties = message_dict.get("key_properties")
do_registration = False
if stream_name not in self.mapper.stream_maps:
do_registration = True
1 change: 1 addition & 0 deletions singer_sdk/testing/factory.py
@@ -144,6 +144,7 @@ def runner(self) -> TapTestRunner | TargetTestRunner:

return TapTestClass

# TODO: Refactor this. It's too long and nested.
def _annotate_test_class( # noqa: C901
self,
empty_test_class: type[BaseTestClass],
5 changes: 1 addition & 4 deletions singer_sdk/testing/legacy.py
@@ -150,10 +150,7 @@ def _get_tap_catalog(
# Test discovery
tap.run_discovery()
catalog_dict = tap.catalog_dict
if select_all:
return _select_all(catalog_dict)

return catalog_dict
return _select_all(catalog_dict) if select_all else catalog_dict


def _select_all(catalog_dict: dict) -> dict:
23 changes: 11 additions & 12 deletions singer_sdk/typing.py
@@ -965,12 +965,14 @@ def to_jsonschema_type(
msg = "Expected `str` or a SQLAlchemy `TypeEngine` object or type."
raise ValueError(msg)

# Look for the type name within the known SQL type names:
for sqltype, jsonschema_type in sqltype_lookup.items():
if sqltype.lower() in type_name.lower():
return jsonschema_type

return sqltype_lookup["string"] # safe failover to str
return next(
(
jsonschema_type
for sqltype, jsonschema_type in sqltype_lookup.items()
if sqltype.lower() in type_name.lower()
),
sqltype_lookup["string"], # safe failover to str
)


def _jsonschema_type_check(jsonschema_type: dict, type_check: tuple[str]) -> bool:
@@ -981,7 +983,7 @@ def _jsonschema_type_check(jsonschema_type: dict, type_check: tuple[str]) -> bool:
type_check: A tuple of type strings to look for.
Returns:
True if the schema suports the type.
True if the schema supports the type.
"""
if "type" in jsonschema_type:
if isinstance(jsonschema_type["type"], (list, tuple)):
@@ -991,12 +993,9 @@ def _jsonschema_type_check(jsonschema_type: dict, type_check: tuple[str]) -> bool:
elif jsonschema_type.get("type") in type_check:
return True

if any(
return any(
_jsonschema_type_check(t, type_check) for t in jsonschema_type.get("anyOf", ())
):
return True

return False
)


def to_sql_type( # noqa: PLR0911, C901
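The to_jsonschema_type rewrite leans on next() with a generator and a default: the first lookup entry whose SQL type name is contained in the given type name wins, otherwise the string failover applies. With a toy lookup table (the SDK's real one is larger):

sqltype_lookup = {
    "varchar": {"type": ["string"]},
    "integer": {"type": ["integer"]},
}

def lookup_jsonschema_type(type_name: str) -> dict:
    return next(
        (
            jsonschema_type
            for sqltype, jsonschema_type in sqltype_lookup.items()
            if sqltype.lower() in type_name.lower()
        ),
        sqltype_lookup["varchar"],  # safe failover, mirroring the "string" default
    )

print(lookup_jsonschema_type("INTEGER(11)"))  # {'type': ['integer']}
print(lookup_jsonschema_type("GEOGRAPHY"))    # {'type': ['string']} via failover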
