Merge pull request #111 from mmcdermott/meds-eval-schema
Adds meds eval label schema adherence tests for generated prediction …
Oufattole authored Jan 30, 2025
2 parents 21992c2 + 47a345f commit dfdd0eb
Showing 4 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -17,6 +17,7 @@ classifiers = [
dependencies = [
"polars>=1.6.0,<=1.17.1", "pyarrow", "loguru", "hydra-core==1.3.2", "numpy", "scipy<1.14.0", "pandas", "tqdm", "xgboost",
"scikit-learn", "hydra-optuna-sweeper", "hydra-joblib-launcher", "ml-mixins", "meds>=0.3.3", "meds-transforms>=0.0.7",
"meds-evaluation",
]

[tool.setuptools_scm]
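For reference, this new dependency is what provides the schema helpers used in the rest of this diff. A quick, hedged smoke test of the two names the PR relies on (the module path is taken from the imports in the hunks below):

from meds_evaluation.schema import (
    BINARY_CLASSIFICATION_SCHEMA_DICT,  # used below as a dict of prediction columns to polars dtypes
    validate_binary_classification_schema,  # used below as a validator for prediction DataFrames
)

# BINARY_CLASSIFICATION_SCHEMA_DICT is unpacked with ** later in this diff,
# so it is assumed to be a mapping; list its expected column names.
print(sorted(BINARY_CLASSIFICATION_SCHEMA_DICT))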
4 changes: 3 additions & 1 deletion src/MEDS_tabular_automl/xgboost_model.py
@@ -6,6 +6,7 @@
import scipy.sparse as sp
import xgboost as xgb
from loguru import logger
from meds_evaluation.schema import BINARY_CLASSIFICATION_SCHEMA_DICT
from omegaconf import DictConfig, OmegaConf
from sklearn.metrics import roc_auc_score

@@ -182,7 +183,8 @@ def predict(self, split="held_out") -> pl.DataFrame:
"predicted_boolean_value": y_pred.round(),
"predicted_boolean_probability": y_pred,
"event_id": labels["event_id"],
}
},
schema={**BINARY_CLASSIFICATION_SCHEMA_DICT, "event_id": pl.Int64},
)
if not (predictions_df["boolean_value"] == labels["label"]).all():
mismatched_labels = predictions_df["boolean_value"] == labels["label"]
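For context on the schema argument: passing an explicit schema mapping to the polars DataFrame constructor declares each column's dtype up front, so a malformed predictions frame fails at construction rather than at evaluation time. A minimal sketch of the mechanism, using a hypothetical stand-in for BINARY_CLASSIFICATION_SCHEMA_DICT (the real dict is assumed to cover the MEDS binary-classification prediction columns):

import polars as pl

# Hypothetical stand-in for BINARY_CLASSIFICATION_SCHEMA_DICT.
example_schema = {
    "boolean_value": pl.Boolean,
    "predicted_boolean_value": pl.Boolean,
    "predicted_boolean_probability": pl.Float64,
}

predictions_df = pl.DataFrame(
    {
        "boolean_value": [True, False],
        "predicted_boolean_value": [True, False],
        "predicted_boolean_probability": [0.91, 0.07],
        "event_id": [101, 102],
    },
    # Task-specific columns are merged into the shared schema, mirroring the
    # {**BINARY_CLASSIFICATION_SCHEMA_DICT, "event_id": pl.Int64} pattern above.
    schema={**example_schema, "event_id": pl.Int64},
)
assert predictions_df.schema["event_id"] == pl.Int64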
7 changes: 7 additions & 0 deletions tests/test_integration.py
@@ -11,6 +11,7 @@

import polars as pl
from hydra import compose, initialize
from meds_evaluation.schema import validate_binary_classification_schema

from MEDS_tabular_automl.describe_codes import get_feature_columns
from MEDS_tabular_automl.file_name import list_subdir_files
@@ -313,6 +314,9 @@ def test_integration(tmp_path):
assert (time_output_dir / "best_trial/held_out_predictions.parquet").exists()
assert (time_output_dir / "best_trial/tuning_predictions.parquet").exists()
assert (time_output_dir / "sweep_results_summary.parquet").exists()
validate_binary_classification_schema(
pl.read_parquet(time_output_dir / "best_trial/held_out_predictions.parquet")
)
else:
assert len(glob.glob(str(output_model_dir / "*/sweep_results/**/*.pkl"))) == 2
assert len(glob.glob(str(output_model_dir / "*/best_trial/*.pkl"))) == 1
@@ -351,6 +355,9 @@ def test_integration(tmp_path):
assert (time_output_dir / "best_trial/held_out_predictions.parquet").exists()
assert (time_output_dir / "best_trial/tuning_predictions.parquet").exists()
assert (time_output_dir / "sweep_results_summary.parquet").exists()
validate_binary_classification_schema(
pl.read_parquet(time_output_dir / "best_trial/held_out_predictions.parquet")
)
else:
assert len(glob.glob(str(output_model_dir / "*/sweep_results/**/*.pkl"))) == 2
assert len(glob.glob(str(output_model_dir / "*/best_trial/*.pkl"))) == 1
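The same check can be run ad hoc against any predictions file a sweep writes out. A minimal sketch, assuming (as the test above does) that validate_binary_classification_schema accepts a polars DataFrame and raises if it does not conform:

from pathlib import Path

import polars as pl
from meds_evaluation.schema import validate_binary_classification_schema


def check_best_trial_predictions(time_output_dir: Path) -> None:
    # Mirrors the integration-test assertions: both best-trial prediction files
    # should exist and adhere to the binary-classification prediction schema.
    for split in ("held_out", "tuning"):
        fp = time_output_dir / "best_trial" / f"{split}_predictions.parquet"
        assert fp.exists(), f"missing predictions file: {fp}"
        validate_binary_classification_schema(pl.read_parquet(fp))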
3 changes: 3 additions & 0 deletions tests/test_tabularize_task.py
@@ -323,4 +323,7 @@ def test_tabularize(tmp_path):
xgboost_model.load_model(xgboost_json_fp)
xgboost_model._build()
predictions_df = xgboost_model.predict("held_out")
from meds_evaluation.schema import validate_binary_classification_schema

validate_binary_classification_schema(predictions_df)
assert isinstance(predictions_df, pl.DataFrame)
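
As a usage note, the same validation can sit next to any call site that produces predictions. A minimal sketch, assuming the model API exercised above (load_model / _build / predict) and that validate_binary_classification_schema raises on a non-conforming frame:

import polars as pl
from meds_evaluation.schema import validate_binary_classification_schema


def predict_and_validate(model, split: str = "held_out") -> pl.DataFrame:
    # Produce predictions for the requested split and fail fast if the frame
    # does not adhere to the MEDS binary-classification prediction schema.
    predictions_df = model.predict(split)
    validate_binary_classification_schema(predictions_df)
    return predictions_df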
