Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add great_tables renderer #2846

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 318 additions & 8 deletions plugins/flytekit-pandera/flytekitplugins/pandera/pandas_renderer.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,335 @@
from typing import TYPE_CHECKING
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

from flytekit import lazy_module

if TYPE_CHECKING:
import great_tables as gt
import pandas

import pandera
else:
gt = lazy_module("great_tables")
pandas = lazy_module("pandas")
pandera = lazy_module("pandera")


@dataclass
class PandasReport:
summary: pandas.DataFrame
data_preview: pandas.DataFrame
schema_error_df: Optional[pandas.DataFrame] = None
data_error_df: Optional[pandas.DataFrame] = None


SCHEMA_ERROR_KEY = "SCHEMA"
DATA_ERROR_KEY = "DATA"
SCHEMA_ERROR_COLUMNS = ["schema", "column", "error_code", "check", "failure_case", "error"]
DATA_ERROR_COLUMNS = ["schema", "column", "error_code", "check", "index", "failure_case", "error"]
DATA_ERROR_DISPLAY_ORDER = ["column", "error_code", "percent_valid", "check", "failure_cases", "error"]

DATA_PREVIEW_HEAD = 5
FAILURE_CASE_LIMIT = 10
ERROR_COLUMN_MAX_WIDTH = 200


class PandasReportRenderer:
def __init__(self, title: str = "Pandera Error Report"):
self._title = title

def to_html(self, error: pandera.errors.SchemaErrors) -> str:
error.failure_cases.groupby(["schema_context"])
html = (
error.failure_cases.set_index(["schema_context", "column", "check"])
.drop(["check_number"], axis="columns")[["index", "failure_case"]]
.to_html()
def _create_success_report(self, data: "pandas.DataFrame", schema: pandera.DataFrameSchema):
summary = pandas.DataFrame(
[
{"Metadata": "Schema Name", "Value": schema.name},
{"Metadata": "Shape", "Value": f"{data.shape[0]} rows x {data.shape[1]} columns"},
{"Metadata": "Total schema errors", "Value": 0},
{"Metadata": "Total data errors", "Value": 0},
{"Metadata": "Schema Object", "Value": f"```\n{schema.__repr__()}\n```"},
]
)

return PandasReport(
summary=summary,
data_preview=data.head(DATA_PREVIEW_HEAD),
schema_error_df=None,
data_error_df=None,
)

@staticmethod
def _reshape_long_failure_cases(long_failure_cases: "pandas.DataFrame"):
return (
long_failure_cases.pivot(
index=["schema_context", "check", "index"], columns="column", values="failure_case"
)
.apply(lambda s: s.to_dict(), axis="columns")
.rename("failure_case")
.reset_index(["index", "check"])
.reset_index(drop=True)[["check", "index", "failure_case"]]
)

def _prepare_data_error_df(self, data: "pandas.DataFrame", data_errors: dict, failure_cases: "pandas.DataFrame"):
def num_failure_cases(series):
return len(series)

def _failure_cases(series):
series = series.astype(str)
out = ", ".join(str(x) for x in series.iloc[:FAILURE_CASE_LIMIT])
if len(series) > FAILURE_CASE_LIMIT:
out += f" ... (+{len(series) - FAILURE_CASE_LIMIT} more)"
return out

data_errors = pandas.concat(pandas.DataFrame(v).assign(error_code=k) for k, v in data_errors.items())
# this is a bit of a hack to get the failure cases into the same format as the data errors
long_failure_case_selector = (failure_cases["schema_context"] == "DataFrameSchema") & (
failure_cases["column"].notna()
)
return html
long_failure_cases = failure_cases[long_failure_case_selector]

data_error_df = [data_errors.merge(failure_cases, how="inner", on=["column", "check"])]
if long_failure_cases.shape[0] > 0:
reshaped_failure_cases = self._reshape_long_failure_cases(long_failure_cases)
long_data_errors = data_errors.assign(
column=data_errors.column.where(~(data_errors.column == data_errors.schema), "NA")
)
data_error_df.append(
long_data_errors.merge(reshaped_failure_cases, how="inner", on=["check"]).assign(column="NA")
)

data_error_df = pandas.concat(data_error_df)
out_df = (
data_error_df[DATA_ERROR_COLUMNS]
.groupby(["column", "error_code", "check", "error"])
.failure_case.agg([num_failure_cases, _failure_cases])
.reset_index()
.rename(columns={"_failure_cases": "failure_cases"})
.assign(percent_valid=lambda df: 1 - (df["num_failure_cases"] / data.shape[0]))
)
return out_df

def _create_error_report(
self,
data: "pandas.DataFrame",
schema: pandera.DataFrameSchema,
error: "pandera.errors.SchemaErrors",
):
failure_cases = error.failure_cases
error_dict = error.args[0]

schema_errors = error_dict.get(SCHEMA_ERROR_KEY)
data_errors = error_dict.get(DATA_ERROR_KEY)

if schema_errors is None:
schema_error_df = None
total_schema_errors = 0
else:
schema_error_df = (
pandas.concat(pandas.DataFrame(v).assign(error_code=k) for k, v in schema_errors.items())
.merge(failure_cases, how="left", on=["column", "check"])[SCHEMA_ERROR_COLUMNS]
.drop(["schema"], axis="columns")
)
total_schema_errors = schema_error_df.shape[0]

if data_errors is None:
data_error_df = None
total_data_errors = 0
else:
data_error_df = self._prepare_data_error_df(data, data_errors, failure_cases)
total_data_errors = data_error_df.shape[0]

summary = pandas.DataFrame(
[
{"Metadata": "Schema Name", "Value": schema.name},
{"Metadata": "Data Shape", "Value": f"{data.shape[0]} rows x {data.shape[1]} columns"},
{"Metadata": "Total schema errors", "Value": total_schema_errors},
{"Metadata": "Total data errors", "Value": total_data_errors},
{"Metadata": "Schema Object", "Value": f"```\n{error.schema.__repr__()}\n```"},
]
)

return PandasReport(
summary=summary,
data_preview=data.head(DATA_PREVIEW_HEAD),
schema_error_df=schema_error_df,
data_error_df=data_error_df,
)

def _format_summary_df(self, df: "pandas.DataFrame") -> str:
return (
gt.GT(df)
.tab_header(
title=gt.md("**Summary**"),
subtitle="A high-level overview of the schema errors found in the DataFrame.",
)
.cols_width(
cases={
"Metadata": "20%",
"Value": "80%",
}
)
.fmt_markdown(["Value"])
.tab_stub(rowname_col="Metadata")
.tab_stubhead(label="Metadata")
.tab_style(style=gt.style.text(align="left"), locations=gt.loc.header())
.tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header())
.tab_style(
style=gt.style.text(weight="bold"),
locations=[gt.loc.column_labels(), gt.loc.stubhead(), gt.loc.stub()],
)
.as_raw_html()
)

def _format_data_preview_df(self, df: "pandas.DataFrame") -> str:
return (
gt.GT(df)
.tab_header(
title=gt.md("**Data Preview**"),
subtitle=f"A preview of the first {min(DATA_PREVIEW_HEAD, df.shape[0])} rows of the data.",
)
.tab_style(style=gt.style.text(align="left"), locations=gt.loc.header())
.tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header())
.tab_style(style=gt.style.text(weight="bold"), locations=gt.loc.column_labels())
.tab_style(style=gt.style.text(align="left"), locations=[gt.loc.body(), gt.loc.column_labels()])
.as_raw_html()
)

@staticmethod
def _format_error(x: str) -> str:
if len(x) > ERROR_COLUMN_MAX_WIDTH:
x = f"{x[:ERROR_COLUMN_MAX_WIDTH]}..."
return f"```\n{x}\n```"

def _format_schema_error_df(self, df: "pandas.DataFrame") -> str:
df = df.assign(
error=lambda df: df["error"].map(self._format_error),
error_code=lambda df: df["error_code"].map(lambda x: f"`{x}`"),
check=lambda df: df["check"].map(lambda x: f"`{x}`"),
)
return (
gt.GT(df)
.tab_header(
title=gt.md("**Schema-level Errors**"),
subtitle="Schema-level metadata errors, e.g. column names, dtypes.",
)
.fmt_markdown(["error_code", "check", "error"])
.tab_style(style=gt.style.text(align="left"), locations=gt.loc.header())
.tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header())
.tab_style(style=gt.style.text(weight="bold"), locations=gt.loc.column_labels())
.as_raw_html()
)

def _format_data_error_df(self, df: "pandas.DataFrame") -> str:
df = df.assign(
error=lambda df: df["error"].map(self._format_error),
error_code=lambda df: df["error_code"].map(lambda x: f"`{x}`"),
check=lambda df: df["check"].map(lambda x: f"`{x}`"),
)[DATA_ERROR_DISPLAY_ORDER]

return (
gt.GT(df)
.tab_header(
title=gt.md("**Data-level Errors**"),
subtitle="Data-level value errors, e.g. null values, out-of-range values.",
)
.fmt_markdown(["error_code", "check", "error"])
.fmt_percent("percent_valid", decimals=2)
.data_color(columns=["percent_valid"], palette="RdYlGn", domain=[0, 1], alpha=0.2)
.tab_stub(groupname_col="column", rowname_col="error_code")
.tab_stubhead(label="column")
.tab_style(
style=gt.style.text(align="left"), locations=[gt.loc.header(), gt.loc.column_labels(), gt.loc.body()]
)
.tab_style(style=gt.style.text(align="center"), locations=gt.loc.body(columns="percent_valid"))
.tab_style(style=gt.style.fill(color="#f2fae2"), locations=gt.loc.header())
.tab_style(
style=gt.style.text(weight="bold"),
locations=[gt.loc.column_labels(), gt.loc.stubhead(), gt.loc.row_groups()],
)
.tab_style(
style=gt.style.fill(color="#f4f4f4"),
locations=[gt.loc.row_groups(), gt.loc.stub()],
)
.as_raw_html()
)

def to_html(
self,
data: "pandas.DataFrame",
schema: pandera.DataFrameSchema,
error: Optional[pandera.errors.SchemaErrors] = None,
) -> str:
error_segments = ""
if error is None:
report_dfs = self._create_success_report(data, schema)
top_message = "✅ Data validation succeeded."
else:
report_dfs = self._create_error_report(data, schema, error)
top_message = "❌ Data validation failed."

error_segments = ""
if report_dfs.schema_error_df is not None:
error_segments += f"""
<br>

{self._format_schema_error_df(report_dfs.schema_error_df)}
"""

if report_dfs.data_error_df is not None:
error_segments += f"""
<br>

{self._format_data_error_df(report_dfs.data_error_df)}
"""

return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Pandera Report</title>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-HSMxcRTRxnN+Bdg0JdbxYKrThecOKuH5zCYotlSAcp1+c8xmyTe9GYg1l9a69psu" crossorigin="anonymous">
<style>
.pandera-report {{
min-width: 60%;
margin: 0 auto;
padding: 20px;
}}

.report-title h1 {{
display: inline;
line-height: 60px;
padding-left: 10px;
}}

.report-title img {{
display: inline;
width: 60px;
float: left;
}}

table.gt_table {{
width: 100% !important;
margin-left: 0 !important;
}}
</style>
</head>
<body>
<div class="pandera-report">
<div class="report-title">
<img src="https://raw.githubusercontent.com/pandera-dev/pandera/main/docs/source/_static/pandera-logo.png" alt="Pandera Logo">
<h1>Pandera Report</h1>
</div>

<h3>{top_message}</h3>

{self._format_summary_df(report_dfs.summary)}

<br>

{self._format_data_preview_df(report_dfs.data_preview)}

{error_segments}
</div>
</body>
</html>
"""
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,19 @@ def to_literal(
try:
val = schema.validate(python_val, lazy=True)
except (pandera.errors.SchemaError, pandera.errors.SchemaErrors) as exc:
html = renderer.to_html(python_val, schema, exc)
val = python_val
if config.on_error == "raise":
# render the deck before raising the error
raise exc
elif config.on_error == "warn":
logger.warning(str(exc))
html = renderer.to_html(exc)
Deck(renderer._title, html)
else:
raise ValueError(f"Invalid on_error value: {config.on_error}")
val = python_val
else:
html = renderer.to_html(val, schema)
finally:
Deck(renderer._title, html)

lv = self._sd_transformer.to_literal(ctx, val, pandas.DataFrame, expected)

Expand All @@ -123,16 +127,18 @@ def to_python_value(
try:
val = schema.validate(df, lazy=True)
except (pandera.errors.SchemaError, pandera.errors.SchemaErrors) as exc:
html = renderer.to_html(df, schema, exc)
val = df
if config.on_error == "raise":
raise exc
elif config.on_error == "warn":
print(ctx.execution_state)
logger.warning(str(exc))
html = renderer.to_html(exc)
Deck(renderer._title, html)
else:
raise ValueError(f"Invalid on_error value: {config.on_error}")
val = df
else:
html = renderer.to_html(val, schema)
finally:
Deck(renderer._title, html)

return val

Expand Down
Loading