Dev #110
Commits (titles; descriptions truncated by the page scrape):

- Handles edge cases in task caching
- Meds eval schema compliance
- savez_compressed
- Task specific tabularization
Walkthrough

This pull request introduces comprehensive enhancements to the MEDS-Tab library's tabularization workflows. The changes span multiple files, focusing on improving flexibility in data processing, adding optional label handling, and enhancing documentation. Key modifications include updating the README with clearer tabularization instructions, introducing optional label-directory processing across various scripts, and refining model prediction and evaluation capabilities. The changes aim to provide more granular control over data tabularization and improve the overall usability of the library.
Actionable comments posted: 8
🧹 Nitpick comments (14)
src/MEDS_tabular_automl/scripts/tabularize_time_series.py (1)

76-79: Simplify nested if statements. Combine the nested if statements for better readability.

```diff
-    if cfg.input_label_dir:
-        if not Path(cfg.input_label_dir).is_dir():
-            raise ValueError(f"input_label_dir: {cfg.input_label_dir} is not a directory.")
+    if cfg.input_label_dir and not Path(cfg.input_label_dir).is_dir():
+        raise ValueError(f"input_label_dir: {cfg.input_label_dir} is not a directory.")
```

🧰 Tools: 🪛 Ruff (0.8.2) — 77-78: Use a single `if` statement instead of nested `if` statements; combine `if` statements using `and` (SIM102)
src/MEDS_tabular_automl/scripts/tabularize_static.py (1)

94-96: Simplify nested if statements. Combine the nested if statements for better readability.

```diff
-    if cfg.input_label_dir:
-        if not Path(cfg.input_label_dir).is_dir():
-            raise ValueError(f"input_label_dir: {cfg.input_label_dir} is not a directory.")
+    if cfg.input_label_dir and not Path(cfg.input_label_dir).is_dir():
+        raise ValueError(f"input_label_dir: {cfg.input_label_dir} is not a directory.")
```

🧰 Tools: 🪛 Ruff (0.8.2) — 94-95: Use a single `if` statement instead of nested `if` statements; combine `if` statements using `and` (SIM102)
src/MEDS_tabular_automl/generate_summarized_reps.py (1)

161-176: Consider optimizing the label-based window computation. The current implementation performs multiple data transformations that could be optimized:

- Converting to a lazy DataFrame and back to eager multiple times
- Redundant window computation when labels are not provided

Consider caching the window computation results and reducing eager/lazy conversions.
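The caching idea can be sketched with `functools.lru_cache`. The function name and parameters below are illustrative stand-ins (the real computation would be driven by the shard and window configuration), and the sketch assumes the cache key can be reduced to hashable values:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def window_bounds(window_size: str, n_rows: int) -> tuple:
    # Illustrative stand-in for an expensive window computation; the real
    # inputs would come from the shard, not these toy parameters.
    return (0, n_rows, window_size)


window_bounds("30d", 100)  # computed on the first call
window_bounds("30d", 100)  # served from the cache on the second call
print(window_bounds.cache_info().hits)
```

DataFrames themselves are not hashable, so in practice the cache key would have to be derived from cheap identifiers such as the shard name and the window size.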
src/MEDS_tabular_automl/scripts/cache_task.py (1)

112-122: Consider optimizing the matrix generation logic. The current implementation creates a CSR matrix and then converts it back to COO format. This could be optimized by:

- Staying in CSR format if possible
- Using direct indexing on COO format

```diff
-    csr: sp.csr_array = sp.csr_array(matrix)
-    valid_ids = label_df.select(pl.col("event_id")).collect().to_series().to_numpy()
-    csr = csr[valid_ids, :]
-    indices_with_no_past_data = valid_ids == -1
-    if indices_with_no_past_data.any().item():
-        csr[indices_with_no_past_data] = 0
-        csr.eliminate_zeros()
-    return sp.coo_array(csr)
+    valid_ids = label_df.select(pl.col("event_id")).collect().to_series().to_numpy()
+    if isinstance(matrix, sp.coo_array):
+        # Direct indexing on COO format
+        mask = np.isin(matrix.row, valid_ids)
+        row_map = {old: new for new, old in enumerate(valid_ids)}
+        new_rows = np.array([row_map[r] for r in matrix.row[mask]])
+        result = sp.coo_array(
+            (matrix.data[mask], (new_rows, matrix.col[mask])),
+            shape=(len(valid_ids), matrix.shape[1]),
+        )
+        # Handle no past data
+        indices_with_no_past_data = valid_ids == -1
+        if indices_with_no_past_data.any().item():
+            result = result.tocsr()
+            result[indices_with_no_past_data] = 0
+            result.eliminate_zeros()
+            result = result.tocoo()
+        return result
+    else:
+        # Fallback to current implementation for non-COO matrices
+        csr = sp.csr_array(matrix)
+        csr = csr[valid_ids, :]
+        indices_with_no_past_data = valid_ids == -1
+        if indices_with_no_past_data.any().item():
+            csr[indices_with_no_past_data] = 0
+            csr.eliminate_zeros()
+        return sp.coo_array(csr)
```

src/MEDS_tabular_automl/xgboost_model.py (1)

187-189: Improve error message for label mismatch. The error message could be more descriptive to help with debugging.

```diff
-    mismatched_labels = predictions_df["boolean_value"] == labels["label"]
-    raise ValueError(f"Label mismatch: {sum(mismatched_labels)} incorrect predictions")
+    mismatched_count = (~(predictions_df["boolean_value"] == labels["label"])).sum()
+    total_count = len(predictions_df)
+    raise ValueError(
+        f"Label mismatch: {mismatched_count}/{total_count} predictions do not match "
+        f"the original labels. This might indicate data corruption or processing errors."
+    )
```

src/MEDS_tabular_automl/evaluation_callback.py (2)
11-19: Consider enhancing the mock implementation. The MockModelLauncher could be improved to better simulate real model behavior and aid testing.

```diff
 class MockModelLauncher:  # pragma: no cover
+    def __init__(self):
+        self.model_loaded = False
+        self.is_built = False
+
     def load_model(self, model_path):
-        pass
+        self.model_loaded = True

     def _build(self):
-        pass
+        if not self.model_loaded:
+            raise ValueError("Model must be loaded before building")
+        self.is_built = True

     def predict(self, split):
+        if not self.is_built:
+            raise ValueError("Model must be built before prediction")
+        valid_splits = {"train", "tuning", "held_out"}
+        if split not in valid_splits:
+            raise ValueError(f"Invalid split: {split}")
         return pl.DataFrame({"predictions": [0.1, 0.2]})
```

116-163: Consider adding retry logic for model operations. The store_predictions method could benefit from retry logic for model operations to handle transient failures.

```diff
+from functools import wraps
+import time
+
+def retry_on_error(max_retries=3, delay=1):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            for attempt in range(max_retries):
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    if attempt == max_retries - 1:
+                        raise
+                    logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
+                    time.sleep(delay)
+            return None
+        return wrapper
+    return decorator
+
 def store_predictions(self, best_trial_dir, splits):
     config = Path(best_trial_dir) / "config.log"
     xgboost_fp = Path(best_trial_dir) / "xgboost.json"
     if not xgboost_fp.exists():
         logger.warning("Prediction parquets not stored, we only support storing them for xgboost models.")
         return
     cfg = OmegaConf.load(config)
     model_launcher = instantiate(cfg.model_launcher)
-    model_launcher.load_model(xgboost_fp)
-    model_launcher._build()
+
+    @retry_on_error()
+    def load_and_build():
+        model_launcher.load_model(xgboost_fp)
+        model_launcher._build()
+
+    load_and_build()
     for split in splits:
-        pred_df = model_launcher.predict(split)
-        pred_df.write_parquet(Path(best_trial_dir) / f"{split}_predictions.parquet")
+        @retry_on_error()
+        def predict_and_save():
+            pred_df = model_launcher.predict(split)
+            pred_df.write_parquet(Path(best_trial_dir) / f"{split}_predictions.parquet")
+
+        predict_and_save()
```

tests/test_tabularize_task.py (2)
145-232: Consider parameterizing the test data. The test could be more maintainable by moving the test data to separate fixture files.

```diff
+import pytest
+from typing import Dict
+
+@pytest.fixture
+def test_data() -> Dict[str, str]:
+    return {
+        "train/0": MEDS_TRAIN_0,
+        "train/1": MEDS_TRAIN_1,
+        "held_out/0": MEDS_HELD_OUT_0,
+        "tuning/0": MEDS_TUNING_0,
+    }
+
+@pytest.fixture
+def temp_median() -> float:
+    return 99.8
+
-def test_tabularize(tmp_path):
+def test_tabularize(tmp_path, test_data, temp_median):
```

290-326: Add more comprehensive assertion messages. The test assertions could benefit from more descriptive error messages.

```diff
-    assert len(output_files) == 1
+    assert len(output_files) == 1, f"Expected exactly one output file, but found {len(output_files)}"
     log_dir = Path(cfg.path.sweep_results_dir)
     log_files = list(log_dir.glob("**/*.log"))
-    assert len(log_files) == 2
+    assert len(log_files) == 2, (
+        f"Expected exactly two log files (config and performance), but found {len(log_files)}: "
+        f"{[f.name for f in log_files]}"
+    )
```

tests/test_integration.py (1)
311-315: LGTM! Consider reducing code duplication. The assertions correctly verify the existence of critical model output files. However, these assertions are duplicated in two locations.

Consider extracting the assertions into a helper function to avoid duplication:

```diff
+def assert_model_outputs(time_output_dir):
+    assert (time_output_dir / "best_trial/held_out_predictions.parquet").exists()
+    assert (time_output_dir / "best_trial/tuning_predictions.parquet").exists()
+    assert (time_output_dir / "sweep_results_summary.parquet").exists()
+
 # Replace both assertion blocks with:
-    time_output_dir = next(output_model_dir.iterdir())
-    assert (time_output_dir / "best_trial/held_out_predictions.parquet").exists()
-    assert (time_output_dir / "best_trial/tuning_predictions.parquet").exists()
-    assert (time_output_dir / "sweep_results_summary.parquet").exists()
+    assert_model_outputs(next(output_model_dir.iterdir()))
```

Also applies to: 349-353
src/MEDS_tabular_automl/tabular_dataset.py (1)

137-139: LGTM! Consider adding docstring updates. The changes improve robustness by handling missing event IDs and making label column selection more flexible.

Consider updating the method's docstring to reflect these changes:

```diff
 def _load_ids_and_labels(
     self, load_ids: bool = True, load_labels: bool = True
 ) -> tuple[Mapping[int, list], Mapping[int, list]]:
-    """Loads valid event ids and labels for each shard.
+    """Loads valid event ids and labels for each shard.
+
+    If the "event_id" column is missing, it will be generated using row indices.
+    For labels, it first checks for "boolean_value" column, then falls back to "label".
```

Also applies to: 144-145
tests/test_tabularize.py (1)

14-15: LGTM! Consider enhancing prediction verification. The new test section properly verifies model prediction functionality.

Consider adding more assertions to verify the prediction output structure:

```diff
 predictions_df = xgboost_model.predict("held_out")
 assert isinstance(predictions_df, pl.DataFrame)
+# Verify required columns exist
+assert "subject_id" in predictions_df.columns
+assert "prediction_time" in predictions_df.columns
+assert "prediction" in predictions_df.columns
+# Verify predictions are within expected range for binary classification
+assert predictions_df["prediction"].is_between(0, 1).all()
```

Also applies to: 372-384
MIMICIV_TUTORIAL/task_tabularize_meds.sh (1)

34-41: Consider using named arguments for better maintainability. While the script works correctly, using positional arguments can be error-prone. Consider a more robust argument-parsing approach, for example `getopts` or `argparse`:

```diff
-# Assign arguments to variables
-MIMICIV_MEDS_DIR="$1"
-OUTPUT_TABULARIZATION_DIR="$2"
-TASKS="$3"
-TASKS_DIR="$4"
-OUTPUT_MODEL_DIR="$5"
-N_PARALLEL_WORKERS="$6"
+while getopts "m:o:t:d:r:w:" opt; do
+    case $opt in
+        m) MIMICIV_MEDS_DIR="$OPTARG" ;;
+        o) OUTPUT_TABULARIZATION_DIR="$OPTARG" ;;
+        t) TASKS="$OPTARG" ;;
+        d) TASKS_DIR="$OPTARG" ;;
+        r) OUTPUT_MODEL_DIR="$OPTARG" ;;
+        w) N_PARALLEL_WORKERS="$OPTARG" ;;
+        \?) echo "Invalid option -$OPTARG" >&2; exit 1 ;;
+    esac
+done
```

MIMICIV_TUTORIAL/README.MD (1)
102-102: Fix grammar in the storage-intensive description. Address the grammar issues identified by static analysis.

```diff
-This is an incredibly storage and memory-intensive operation
+This is an extremely storage- and memory-intensive operation
```

🧰 Tools: 🪛 LanguageTool
- [grammar] ~102: When "all-time" is used as a modifier, it is usually spelled with a hyphen. Context: "...ation of every unique subject_id across all time points..." (ALL_TIME_HYPHEN)
- [grammar] ~102: You used an adverb ("incredibly") instead of an adjective. Context: "...This is an incredibly storage and memory-intensive operation..." (A_RB_NN)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (17)

- MIMICIV_TUTORIAL/README.MD (1 hunks)
- MIMICIV_TUTORIAL/task_tabularize_meds.sh (1 hunks)
- pyproject.toml (2 hunks)
- src/MEDS_tabular_automl/configs/launch_model.yaml (1 hunks)
- src/MEDS_tabular_automl/configs/tabularization.yaml (1 hunks)
- src/MEDS_tabular_automl/evaluation_callback.py (2 hunks)
- src/MEDS_tabular_automl/generate_static_features.py (3 hunks)
- src/MEDS_tabular_automl/generate_summarized_reps.py (6 hunks)
- src/MEDS_tabular_automl/scripts/cache_task.py (2 hunks)
- src/MEDS_tabular_automl/scripts/tabularize_static.py (3 hunks)
- src/MEDS_tabular_automl/scripts/tabularize_time_series.py (3 hunks)
- src/MEDS_tabular_automl/tabular_dataset.py (1 hunks)
- src/MEDS_tabular_automl/utils.py (1 hunks)
- src/MEDS_tabular_automl/xgboost_model.py (3 hunks)
- tests/test_integration.py (2 hunks)
- tests/test_tabularize.py (2 hunks)
- tests/test_tabularize_task.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/MEDS_tabular_automl/utils.py
🧰 Additional context used

🪛 Ruff (0.8.2)

src/MEDS_tabular_automl/scripts/tabularize_static.py
- 94-95: Use a single `if` statement instead of nested `if` statements; combine `if` statements using `and` (SIM102)
- 163-163: Function definition does not bind loop variable `label_df` (B023)

src/MEDS_tabular_automl/scripts/tabularize_time_series.py
- 77-78: Use a single `if` statement instead of nested `if` statements; combine `if` statements using `and` (SIM102)
- 122-122: Function definition does not bind loop variable `label_df` (B023)

src/MEDS_tabular_automl/xgboost_model.py
- 168-168: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)
🪛 LanguageTool

MIMICIV_TUTORIAL/README.MD
- [grammar] ~102: When "all-time" is used as a modifier, it is usually spelled with a hyphen. Context: "...ation of every unique subject_id across all time points..." (ALL_TIME_HYPHEN)
- [grammar] ~102: You used an adverb ("incredibly") instead of an adjective. Context: "...This is an incredibly storage and memory-intensive operation..." (A_RB_NN)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: run_tests_ubuntu (3.12)
- GitHub Check: code-quality
- GitHub Check: run_tests_ubuntu (3.11)
🔇 Additional comments (12)

src/MEDS_tabular_automl/scripts/tabularize_time_series.py (1)

99-103: LGTM! Robust label handling. The code correctly handles both cases where labels are present or absent.

src/MEDS_tabular_automl/scripts/tabularize_static.py (1)

146-150: LGTM! Robust label handling. The code correctly handles both cases where labels are present or absent.

src/MEDS_tabular_automl/generate_summarized_reps.py (2)

131-132: LGTM! Well-documented parameter addition. The new `label_df` parameter is properly typed and documented. Also applies to: 146-146

188-189: LGTM! Consistent parameter propagation. The `label_df` parameter is correctly propagated through the function chain. Also applies to: 199-199, 226-226

src/MEDS_tabular_automl/generate_static_features.py (1)

169-169: LGTM! Well-documented parameter addition. The new `label_df` parameter is properly typed and documented. Also applies to: 181-181

src/MEDS_tabular_automl/scripts/cache_task.py (2)

61-109: Great job on the comprehensive docstring with examples! The docstring effectively demonstrates the function's behavior with various test cases, including edge cases like empty labels and events with no history.

207-215: Good handling of null event IDs. The code properly handles cases where labels have no prior patient data by:
- Logging a warning message
- Filling null event IDs with -1
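As scipy background for the cache_task.py matrix handling reviewed above: COO arrays do not support row slicing, which is why code typically converts to CSR, slices rows, and converts back. A stand-alone sketch with toy data (none of these values come from the PR):

```python
import numpy as np
import scipy.sparse as sp

# Toy 4x3 matrix in COO format
m = sp.coo_array(np.array([[1, 0, 2],
                           [0, 0, 0],
                           [3, 0, 0],
                           [0, 4, 0]]))
valid_ids = np.array([0, 2])  # rows to keep (stand-in for label event ids)

# Row slicing is supported on CSR but not COO, so convert, slice, convert back
csr = sp.csr_array(m)
subset = sp.coo_array(csr[valid_ids, :])
print(subset.toarray())
```

The reviewer's COO-indexing alternative trades this format round-trip for manual filtering of the `row`/`col`/`data` arrays.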
src/MEDS_tabular_automl/configs/tabularization.yaml (1)

12-14: LGTM! Clear and well-documented configuration change. The new `input_label_dir` parameter is properly documented and maintains backward compatibility with its null default.

src/MEDS_tabular_automl/configs/launch_model.yaml (1)

25-28: LGTM! Well-structured prediction split configuration. The addition of explicit prediction splits for `held_out` and `tuning` sets follows ML best practices and aligns with the PR's objective to improve prediction output handling.

pyproject.toml (2)

37-37: LGTM! Enhanced test coverage configuration support. The update to `pytest-cov[toml]` enables TOML-based configuration for coverage reporting.

55-61: LGTM! Well-defined coverage exclusions. The coverage exclusions are appropriately configured to ignore standard patterns like `pragma: nocover`, `NotImplementedError`, and main-module checks.

MIMICIV_TUTORIAL/README.MD (1)

85-99: LGTM! Clear documentation of the recommended approach. The documentation clearly explains the selective task tabularization approach and its benefits.
```diff
@@ -109,6 +119,7 @@ def compute_fn(shard_df):
             sparse_matrix,
             window_size,
             agg,
+            label_df,
```
Fix potential loop variable binding issue.

The `label_df` variable is used in a closure, which could lead to unexpected behavior due to late binding.

```diff
-            label_df,
+            label_df.clone() if label_df is not None else None,
```
🧰 Tools
🪛 Ruff (0.8.2)
122-122: Function definition does not bind loop variable label_df
(B023)
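The B023 warning above can be reproduced with a minimal stand-alone example (names here are illustrative, not from the PR): closures created in a loop capture the variable itself, not its value at definition time, so every closure ends up seeing the final value. Binding the current value as a default argument is the standard fix:

```python
# Late binding: every closure sees the loop variable's final value.
funcs = [lambda: label for label in ["a", "b", "c"]]
print([f() for f in funcs])  # ['c', 'c', 'c']

# Binding the current value as a default argument freezes it per iteration.
funcs_fixed = [lambda label=label: label for label in ["a", "b", "c"]]
print([f() for f in funcs_fixed])  # ['a', 'b', 'c']
```

This is why the suggested change snapshots `label_df` before it is captured by the per-shard compute function.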
```diff
@@ -150,6 +160,7 @@ def compute_fn(shard_df):
         agg=agg,
         feature_columns=feature_columns,
         shard_df=shard_df,
+        label_df=label_df,
```
Fix potential loop variable binding issue.

The `label_df` variable is used in a closure, which could lead to unexpected behavior due to late binding.

```diff
-        label_df=label_df,
+        label_df=label_df.clone() if label_df is not None else None,
```
🧰 Tools
🪛 Ruff (0.8.2)
163-163: Function definition does not bind loop variable label_df
(B023)
```python
    if label_df is not None:
        event_df = label_df.rename({"prediction_time": "time"})
    else:
        event_df = get_unique_time_events_df(get_events_df(shard_df, feature_columns))

    matrix = get_sparse_static_rep(static_features, static_measurements.lazy(), event_df, feature_columns)
```
💡 Codebase verification
Add safety checks for label-based event generation
The verification reveals that the direct usage of the label DataFrame bypasses critical safety checks that were present in `get_unique_time_events_df`:
- No verification of time column's non-null values
- No guarantee of unique time values
- No validation of proper sorting by subject_id and time
These checks are crucial as they ensure alignment with the ts_matrix generated later in the pipeline. Consider adding equivalent checks when using the label DataFrame.
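A minimal sketch of the equivalent checks, using plain tuples in place of the real polars DataFrame (the column names and guarantees are taken from the review above, not verified against the codebase):

```python
def validate_event_rows(rows: list[tuple]) -> list[tuple]:
    """Mirror the guarantees attributed to get_unique_time_events_df on
    (subject_id, time) tuples: non-null times, unique events, and sorting
    by subject_id then time."""
    assert all(t is not None for _, t in rows), "time column contains nulls"
    assert len(rows) == len(set(rows)), "duplicate (subject_id, time) events"
    assert rows == sorted(rows), "rows not sorted by subject_id, time"
    return rows


validate_event_rows([(1, 1), (1, 2), (2, 1)])  # passes all three checks
```

Running equivalent assertions on the label DataFrame before it is substituted for the generated event frame would preserve the alignment guarantee with the downstream ts_matrix.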
🔗 Analysis chain
Verify the impact of label-based event generation.
The change from using `get_unique_time_events_df` to direct label DataFrame usage might affect the ordering and uniqueness of events.
Run the following script to verify the event generation behavior:
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for potential differences in event generation between label-based and standard approaches
# Search for usages of get_unique_time_events_df to understand the expected behavior
rg -A 5 "get_unique_time_events_df"
# Search for potential ordering requirements in the codebase
ast-grep --pattern 'sort(by="time")'
Length of output: 6316
Script:
#!/bin/bash
# Check the implementation details and requirements
# View the implementation of get_unique_time_events_df
rg -A 10 "def get_unique_time_events_df" src/MEDS_tabular_automl/utils.py
# Check get_sparse_static_rep implementation for any ordering requirements
rg -A 10 "def get_sparse_static_rep" src/MEDS_tabular_automl/generate_static_features.py
# Look for any sorting operations on event_df
rg -B 2 -A 2 "event_df.*sort" src/MEDS_tabular_automl/
Length of output: 1257
Script:
#!/bin/bash
# Find complete implementations and operations
# Find get_unique_time_events_df implementation
ast-grep --pattern 'def get_unique_time_events_df($_) -> $_:
$$$'
# Find get_sparse_static_rep implementation
ast-grep --pattern 'def get_sparse_static_rep($_) -> $_:
$$$'
# Look for distinct/unique operations
rg -A 3 "\.distinct\(\)|\.unique\(\)" src/MEDS_tabular_automl/utils.py
# Look for sort operations
rg -A 3 "\.sort\(" src/MEDS_tabular_automl/utils.py
Length of output: 2186
```python
    def load_model(self, xgboost_json_fp: Path):
        self.model = xgb.Booster()
        self.model.load_model(str(xgboost_json_fp))
```
🛠️ Refactor suggestion
Consider adding error handling for model loading.
The model loading code should handle potential file access and format errors.
```diff
 def load_model(self, xgboost_json_fp: Path):
+    if not xgboost_json_fp.exists():
+        raise FileNotFoundError(f"Model file not found: {xgboost_json_fp}")
     self.model = xgb.Booster()
-    self.model.load_model(str(xgboost_json_fp))
+    try:
+        self.model.load_model(str(xgboost_json_fp))
+    except Exception as e:
+        raise ValueError(f"Failed to load model from {xgboost_json_fp}: {str(e)}")
```
```python
    def _predict(self, split="held_out") -> tuple[np.ndarray, np.ndarray]:
        """Helper Function that retrieves model predictions and labels."""
        if split == "tuning":
            y_pred = self.model.predict(self.dtuning)
            y_true = self.dtuning.get_label()
        elif split == "held_out":
            y_pred = self.model.predict(self.dheld_out)
            y_true = self.dheld_out.get_label()
        elif split == "train":
            y_pred = self.model.predict(self.dtrain)
            y_true = self.dtrain.get_label()
        else:
            raise ValueError(f"Invalid split for evaluation: {split}")
        return y_true, y_pred
```
🛠️ Refactor suggestion
Consider adding input validation for the split parameter.
The split parameter should be validated against a predefined set of valid values.
```diff
 def _predict(self, split="held_out") -> tuple[np.ndarray, np.ndarray]:
     """Helper Function that retrieves model predictions and labels."""
+    valid_splits = {"tuning", "held_out", "train"}
+    if split not in valid_splits:
+        raise ValueError(f"Invalid split '{split}'. Must be one of {valid_splits}")
+    if self.model is None:
+        raise ValueError("Model not loaded. Call load_model() first.")
```
```python
    tabularize_static.main(cfg)

    output_dir = Path(cfg.output_dir) / "tabularize"

    output_files = list(output_dir.glob("**/static/**/*.npz"))
    actual_files = [get_shard_prefix(output_dir, each) + ".npz" for each in output_files]
    assert set(actual_files) == set(EXPECTED_STATIC_FILES)

    # Validate output matrices have the same number of columns as labels
    for f in output_files:
        static_matrix = load_matrix(f)
        assert static_matrix.shape[0] > 0, "Static Data Tabular Dataframe Should not be Empty!"

        # Get the expected number of columns from feature names
        expected_num_cols = len(get_feature_names(f"static/{f.stem}", feature_columns))
        assert static_matrix.shape[1] == expected_num_cols, (
            f"Static Data Tabular Dataframe Should have {expected_num_cols} "
            f"Columns but has {static_matrix.shape[1]}!"
        )

        label_df = pl.read_parquet(
            Path(cfg.input_label_dir) / f.relative_to(output_dir).parent.parent.parent.with_suffix(".parquet")
        )
        assert static_matrix.shape[0] == label_df.height, (
            f"Static Data Tabular Dataframe Should have {label_df.height} "
            f"Rows but has {static_matrix.shape[0]}!"
        )

    tabularize_time_series.main(cfg)

    # confirm summary files exist:
    output_files = list_subdir_files(str(output_dir.resolve()), "npz")
    actual_files = [
        get_shard_prefix(output_dir, each) + ".npz" for each in output_files if "none/static" not in str(each)
    ]
    assert len(actual_files) > 0

    # Validate output matrices have the same number of columns as labels
    for f in output_files:
        static_matrix = load_matrix(f)
        assert static_matrix.shape[0] > 0, "Static Data Tabular Dataframe Should not be Empty!"

        # Get the expected number of columns from feature names
        expected_num_cols = len(get_feature_names(f"{f.parent.stem}/{f.stem}", feature_columns))
        assert static_matrix.shape[1] == expected_num_cols, (
            f"Static Data Tabular Dataframe Should have {expected_num_cols} "
            f"Columns but has {static_matrix.shape[1]}!"
        )

        label_df = pl.read_parquet(
            Path(cfg.input_label_dir) / f.relative_to(output_dir).parent.parent.parent.with_suffix(".parquet")
        )
        assert static_matrix.shape[0] == label_df.height, (
            f"Static Data Tabular Dataframe Should have {label_df.height} "
            f"Rows but has {static_matrix.shape[0]}!"
        )
```
🛠️ Refactor suggestion
Add cleanup for temporary files.
The test should clean up temporary files even if assertions fail.
```diff
+import contextlib
+import shutil
+
+@contextlib.contextmanager
+def temporary_files(*paths):
+    try:
+        yield
+    finally:
+        for path in paths:
+            if isinstance(path, (str, Path)):
+                path = Path(path)
+            if path.exists():
+                if path.is_dir():
+                    shutil.rmtree(path)
+                else:
+                    path.unlink()
+
-    tabularize_static.main(cfg)
+    with temporary_files(output_dir, input_label_dir):
+        tabularize_static.main(cfg)
```
Committable suggestion skipped: line range outside the PR's diff.
```shell
        worker="range(0,$N_PARALLEL_WORKERS)" \
        hydra/launcher=joblib \
```
🛠️ Refactor suggestion
Validate N_PARALLEL_WORKERS value.
Add validation to ensure N_PARALLEL_WORKERS is a positive integer and doesn't exceed system resources.
```diff
+# Validate N_PARALLEL_WORKERS
+if ! [[ "$N_PARALLEL_WORKERS" =~ ^[0-9]+$ ]] || [ "$N_PARALLEL_WORKERS" -lt 1 ]; then
+    echo "Error: N_PARALLEL_WORKERS must be a positive integer"
+    exit 1
+fi
+
+# Get available CPU cores
+AVAILABLE_CORES=$(nproc)
+if [ "$N_PARALLEL_WORKERS" -gt "$AVAILABLE_CORES" ]; then
+    echo "Warning: N_PARALLEL_WORKERS ($N_PARALLEL_WORKERS) exceeds available CPU cores ($AVAILABLE_CORES)"
+fi
+
 worker="range(0,$N_PARALLEL_WORKERS)" \
 hydra/launcher=joblib \
```
```shell
for TASK in "${TASK_ARRAY[@]}"
do
    mkdir -p "${OUTPUT_TABULARIZATION_DIR}/${TASK}"
    rsync -r "${OUTPUT_TABULARIZATION_DIR}/metadata/" "${OUTPUT_TABULARIZATION_DIR}/${TASK}/metadata"
```
🛠️ Refactor suggestion
Add directory existence validation.
The script should validate that required directories exist before proceeding.
```diff
+# Validate input directories
+for dir in "$MIMICIV_MEDS_DIR" "$TASKS_DIR"; do
+    if [ ! -d "$dir" ]; then
+        echo "Error: Directory $dir does not exist."
+        exit 1
+    fi
+done
+
 for TASK in "${TASK_ARRAY[@]}"
 do
     mkdir -p "${OUTPUT_TABULARIZATION_DIR}/${TASK}"
     rsync -r "${OUTPUT_TABULARIZATION_DIR}/metadata/" "${OUTPUT_TABULARIZATION_DIR}/${TASK}/metadata"
```
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #110   +/-   ##
=======================================
+ Coverage   81.96%   84.22%   +2.26%
=======================================
  Files          20       20
  Lines        1253     1325     +72
=======================================
+ Hits         1027     1116     +89
+ Misses        226      209     -17
```

☔ View full report in Codecov by Sentry.
Adds meds eval label schema adherence tests for generated prediction …
Summary by CodeRabbit
Release Notes
New Features
Improvements
Testing
Documentation
These changes enhance the flexibility and robustness of the MEDS tabular automation toolkit, providing more precise control over data processing and model evaluation.