Handle explicit segmentation keys #1566

Merged · 12 commits · Sep 25, 2024
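
This PR adds an optional segment_key_values argument to why.log so that explicit key/value pairs can be attached to every segment produced by a segmented DatasetSchema. A minimal usage sketch, inferred from the tests below (the data and key names are illustrative, and the import paths are assumed from the repo layout):

import pandas as pd
import whylogs as why
from whylogs.core import DatasetSchema
from whylogs.core.segmentation_partition import segment_on_column

df = pd.DataFrame({"col3": ["x0", "x1", "x0"], "col2": [1.1, 2.2, 3.3]})

# Explicit values are stringified, sorted by key name, and appended after the
# column-derived values in each segment key, e.g. ("x0", "1", "foo") here.
results = why.log(
    df,
    schema=DatasetSchema(segments=segment_on_column("col3")),
    segment_key_values={"zzz": "foo", "ver": 1},
)

# Reusing a segmented column name as an explicit key raises ValueError
# (see test_throw_on_duplicate_keys below).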
89 changes: 89 additions & 0 deletions python/tests/api/logger/test_segments.py
@@ -95,6 +95,60 @@ def test_single_column_segment() -> None:
assert cardinality == 1.0


def test_single_column_and_manual_segment() -> None:
input_rows = 100
segment_column = "col3"
number_of_segments = 5
d = {
"col1": [i for i in range(input_rows)],
"col2": [i * i * 1.1 for i in range(input_rows)],
segment_column: [f"x{str(i%number_of_segments)}" for i in range(input_rows)],
}

df = pd.DataFrame(data=d)
test_segments = segment_on_column("col3")
results: SegmentedResultSet = why.log(
df, schema=DatasetSchema(segments=test_segments), segment_key_values={"zzz": "foo", "ver": 1}
)
assert results.count == number_of_segments
partitions = results.partitions
assert len(partitions) == 1
partition = partitions[0]
segments = results.segments_in_partition(partition)
assert len(segments) == number_of_segments

first_segment = next(iter(segments))
assert first_segment.key == ("x0", "1", "foo")
first_segment_profile = results.profile(first_segment)
assert first_segment_profile is not None
assert first_segment_profile._columns["col1"]._schema.dtype == np.int64
assert first_segment_profile._columns["col2"]._schema.dtype == np.float64
assert first_segment_profile._columns["col3"]._schema.dtype.name == "object"
segment_cardinality: CardinalityMetric = (
first_segment_profile.view().get_column(segment_column).get_metric("cardinality")
)
cardinality = segment_cardinality.estimate
assert cardinality is not None
assert cardinality == 1.0


def test_throw_on_duplicate_keys() -> None:
input_rows = 100
segment_column = "col3"
number_of_segments = 5
d = {
"col1": [i for i in range(input_rows)],
"col2": [i * i * 1.1 for i in range(input_rows)],
segment_column: [f"x{str(i%number_of_segments)}" for i in range(input_rows)],
}

df = pd.DataFrame(data=d)
test_segments = segment_on_column("col3")

with pytest.raises(ValueError):
why.log(df, schema=DatasetSchema(segments=test_segments), segment_key_values={segment_column: "foo"})


def test_single_column_segment_with_trace_id() -> None:
input_rows = 100
segment_column = "col3"
@@ -312,6 +366,41 @@ def test_multi_column_segment() -> None:
assert count == 1


def test_multicolumn_and_manual_segment() -> None:
input_rows = 100
d = {
"col1": [i for i in range(input_rows)],
"col2": [i * i * 1.1 for i in range(input_rows)],
"col3": [f"x{str(i%5)}" for i in range(input_rows)],
}

df = pd.DataFrame(data=d)
segmentation_partition = SegmentationPartition(
name="col1,col3", mapper=ColumnMapperFunction(col_names=["col1", "col3"])
)
test_segments = {segmentation_partition.name: segmentation_partition}
results: SegmentedResultSet = why.log(
df, schema=DatasetSchema(segments=test_segments), segment_key_values={"ver": 42, "zzz": "bar"}
)
segments = results.segments()
last_segment = segments[-1]

# Note: this segmentation is not useful in practice because there is only one datapoint
# per segment (100 rows and 100 segments). The segment key is a tuple of strings identifying the segment.
assert last_segment.key == ("99", "x4", "42", "bar")

last_segment_profile = results.profile(last_segment)

assert last_segment_profile._columns["col1"]._schema.dtype == np.int64
assert last_segment_profile._columns["col2"]._schema.dtype == np.float64
assert last_segment_profile._columns["col3"]._schema.dtype.name == "object"

segment_distribution: DistributionMetric = last_segment_profile.view().get_column("col1").get_metric("distribution")
count = segment_distribution.n
assert count is not None
assert count == 1


def test_multi_column_segment_serialization_roundtrip_v0(tmp_path: Any) -> None:
input_rows = 35
d = {
43 changes: 43 additions & 0 deletions python/tests/api/writer/test_whylabs_integration.py
@@ -424,6 +424,49 @@ def test_whylabs_writer_segmented(zipped: bool):
assert deserialized_view.get_columns().keys() == data.keys()


@pytest.mark.load
def test_whylabs_writer_explicit_segmented():
ORG_ID = _get_org()
MODEL_ID = os.environ.get("WHYLABS_DEFAULT_DATASET_ID")
why.init(reinit=True, force_local=True)
schema = DatasetSchema(segments=segment_on_column("col1"))
data = {"col1": [1, 2, 1, 3, 2, 2], "col2": ["foo", "bar", "wat", "foo", "baz", "wat"]}
df = pd.DataFrame(data)
trace_id = str(uuid4())
profile = why.log(df, schema=schema, trace_id=trace_id, segment_key_values={"version": "1.0.0"})

assert profile.count == 3
partitions = profile.partitions
assert len(partitions) == 1
partition = partitions[0]
segments = profile.segments_in_partition(partition)
assert len(segments) == 3

first_segment = next(iter(segments))
assert first_segment.key == ("1", "1.0.0")

writer = WhyLabsWriter()
success, status = writer.write(profile)
assert success
time.sleep(SLEEP_TIME) # platform needs time to become aware of the profile
dataset_api = DatasetProfileApi(writer._api_client)
response: ProfileTracesResponse = dataset_api.get_profile_traces(
org_id=ORG_ID,
dataset_id=MODEL_ID,
trace_id=trace_id,
)
assert len(response.get("traces")) == 3
for trace in response.get("traces"):
download_url = trace.get("download_url")
headers = {"Content-Type": "application/octet-stream"}
downloaded_profile = writer._s3_pool.request(
"GET", download_url, headers=headers, timeout=writer._timeout_seconds
)
deserialized_view = DatasetProfileView.deserialize(downloaded_profile.data)
assert deserialized_view._metadata["whylogs.tag.version"] == "1.0.0"
assert deserialized_view.get_columns().keys() == data.keys()


@pytest.mark.load
@pytest.mark.parametrize(
"segmented,zipped",
6 changes: 1 addition & 5 deletions python/whylogs/api/logger/logger.py
@@ -113,17 +113,13 @@ def log(

# If segments are defined use segment_processing to return a SegmentedResultSet
if active_schema and active_schema.segments:
if segment_key_values:
raise ValueError(
f"using explicit `segment_key_values` {segment_key_values} is not compatible "
f"with segmentation also defined in the DatasetSchema: {active_schema.segments}"
)
segmented_results: SegmentedResultSet = segment_processing(
schema=active_schema,
obj=obj,
pandas=pandas,
row=row,
segment_cache=self._segment_cache,
segment_key_values=segment_key_values,
)
# Update the existing segmented_results metadata with the trace_id and other keys if not present
_populate_common_profile_metadata(segmented_results.metadata, trace_id=trace_id, tags=tags)
26 changes: 21 additions & 5 deletions python/whylogs/api/logger/segment_processing.py
@@ -39,15 +39,15 @@ def _process_segment(
segments[segment_key] = profile


def _get_segment_from_group_key(group_key, partition_id) -> Tuple[str, ...]:
def _get_segment_from_group_key(group_key, partition_id, explicit_keys: Tuple[str, ...] = ()) -> Tuple[str, ...]:
if isinstance(group_key, str):
segment_tuple_key: Tuple[str, ...] = (group_key,)
elif isinstance(group_key, (List, Iterable, Iterator)):
segment_tuple_key = tuple(str(k) for k in group_key)
else:
segment_tuple_key = (str(group_key),)

return Segment(segment_tuple_key, partition_id)
return Segment(segment_tuple_key + explicit_keys, partition_id)


def _is_nan(x):
@@ -65,7 +65,11 @@ def _process_simple_partition(
pandas: Optional[pd.DataFrame] = None,
row: Optional[Mapping[str, Any]] = None,
segment_cache: Optional[SegmentCache] = None,
segment_key_values: Optional[Dict[str, str]] = None,
):
explicit_keys = (
tuple(str(segment_key_values[k]) for k in sorted(segment_key_values.keys())) if segment_key_values else tuple()
)
if pandas is not None:
# simple means we can segment on column values
grouped_data = pandas.groupby(columns)
@@ -81,11 +85,11 @@
pandas_segment = pandas[mask]
else:
pandas_segment = grouped_data.get_group(group)
segment_key = _get_segment_from_group_key(group, partition_id)
segment_key = _get_segment_from_group_key(group, partition_id, explicit_keys)
_process_segment(pandas_segment, segment_key, segments, schema, segment_cache)
elif row:
# TODO: consider if we need to combine with the column names
segment_key = Segment(tuple(str(row[element]) for element in columns), partition_id)
segment_key = Segment(tuple(str(row[element]) for element in columns) + explicit_keys, partition_id)
_process_segment(row, segment_key, segments, schema, segment_cache)


@@ -129,6 +133,7 @@ def _log_segment(
pandas: Optional[pd.DataFrame] = None,
row: Optional[Mapping[str, Any]] = None,
segment_cache: Optional[SegmentCache] = None,
segment_key_values: Optional[Dict[str, str]] = None,
) -> Dict[Segment, Any]:
segments: Dict[Segment, Any] = {}
pandas, row = _pandas_or_dict(obj, pandas, row)
Expand All @@ -137,7 +142,13 @@ def _log_segment(
if partition.simple:
columns = partition.mapper.col_names if partition.mapper else None
if columns:
_process_simple_partition(partition.id, schema, segments, columns, pandas, row, segment_cache)
_process_simple_partition(
partition.id, schema, segments, columns, pandas, row, segment_cache, segment_key_values
)
else:
logger.error(
"Segmented DatasetSchema defines no segments; use an unsegmented DatasetSchema or specify columns to segment on."
)
else:
raise NotImplementedError("custom mapped segments not yet implemented")
return segments
@@ -149,6 +160,7 @@ def segment_processing(
pandas: Optional[pd.DataFrame] = None,
row: Optional[Dict[str, Any]] = None,
segment_cache: Optional[SegmentCache] = None,
segment_key_values: Optional[Dict[str, str]] = None,
) -> SegmentedResultSet:
number_of_partitions = len(schema.segments)
logger.info(f"The specified schema defines segments with {number_of_partitions} partitions.")
@@ -160,6 +172,9 @@

for partition_name in schema.segments:
segment_partition = schema.segments[partition_name]
if segment_partition.mapper and segment_key_values:
segment_partition.mapper.set_explicit_names(segment_key_values.keys())

logger.info(f"Processing partition with name({partition_name})")
logger.debug(f"{partition_name}: is simple ({segment_partition.simple}), id ({segment_partition.id})")
if segment_partition.filter:
@@ -176,6 +191,7 @@
pandas=pandas,
row=row,
segment_cache=segment_cache,
segment_key_values=segment_key_values,
)
segmented_profiles[segment_partition.id] = partition_segments
segment_partitions.append(segment_partition)
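
For clarity, a sketch of how the explicit keys computed in _process_simple_partition combine with the column-derived group key (this mirrors the logic above rather than calling the library; the values are taken from test_multicolumn_and_manual_segment):

segment_key_values = {"ver": 42, "zzz": "bar"}
# Explicit values are sorted by key name and stringified.
explicit_keys = tuple(str(segment_key_values[k]) for k in sorted(segment_key_values.keys()))
group_key = (99, "x4")  # values of the segmented columns for one group
segment_tuple = tuple(str(k) for k in group_key) + explicit_keys
assert segment_tuple == ("99", "x4", "42", "bar")  # matches the test expectation above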
11 changes: 11 additions & 0 deletions python/whylogs/core/segmentation_partition.py
@@ -26,6 +26,17 @@ def __post_init__(self):
column_string = ",".join(sorted(self.col_names))
segment_hash = hashlib.sha512(bytes(column_string + mapper_string, encoding="utf8"))
self.id = segment_hash.hexdigest()
self.explicit_names = list()

def set_explicit_names(self, key_names: List[str] = []) -> None:
if self.col_names:
for name in key_names:
if name in self.col_names:
raise ValueError(
f"Cannot have segmentation key {name} as both a column name and explicit segment key"
)

self.explicit_names = sorted(key_names)


@dataclass
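
A short illustration of the collision check that set_explicit_names adds (a sketch with hypothetical column and key names; only the ValueError path is the library behavior shown in the diff):

from whylogs.core.segmentation_partition import ColumnMapperFunction

mapper = ColumnMapperFunction(col_names=["col3"])
mapper.set_explicit_names(["zzz", "ver"])  # stored sorted as ["ver", "zzz"]
try:
    mapper.set_explicit_names(["col3"])  # "col3" is already a segmented column
except ValueError as err:
    print(err)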
3 changes: 2 additions & 1 deletion python/whylogs/migration/converters.py
@@ -78,8 +78,9 @@ def _generate_segment_tags_metadata(

segment_tags = []
col_names = partition.mapper.col_names
explicit_names = partition.mapper.explicit_names

for index, column_name in enumerate(col_names):
for index, column_name in enumerate(col_names + explicit_names):
segment_tags.append(SegmentTag(key=_TAG_PREFIX + column_name, value=segment.key[index]))
else:
raise NotImplementedError(
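
The converter change above means explicit keys also surface as segment tags alongside the column-derived tags. A rough sketch of the resulting tags for the whylabs integration test's segment ("1", "1.0.0"); the "whylogs.tag." prefix is an assumption based on that test's metadata assertion, not a verbatim copy of _TAG_PREFIX:

_TAG_PREFIX = "whylogs.tag."  # assumed value of the prefix used in converters.py
col_names = ["col1"]
explicit_names = ["version"]
segment_key = ("1", "1.0.0")  # column value first, then the explicit value

tags = {_TAG_PREFIX + name: segment_key[i] for i, name in enumerate(col_names + explicit_names)}
assert tags == {"whylogs.tag.col1": "1", "whylogs.tag.version": "1.0.0"}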