-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a basic example workflow. #16
Changes from 11 commits
b28726e
0e15c17
19ae1d9
8e9acab
3afe7bf
3811f40
9e0c0e4
791d18f
5ffef7d
5756038
c2df073
b2affac
26b439f
07ab03c
e974c1f
f66880d
beba3ee
ae6f1d4
ae744c6
94cf216
96c9bac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
defaults: | ||
- _ACES_MD | ||
- _self_ | ||
- override hydra/hydra_logging: disabled | ||
|
||
cohort_predictions_dir: "${oc.env:MEDS_ROOT_DIR}/task_predictions" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#!/bin/bash | ||
|
||
export MEDS_ROOT_DIR=$1 | ||
export MEDS_DATASET_NAME=$2 | ||
export MEDS_TASK_NAME=$3 | ||
|
||
shift 3 | ||
|
||
MEDS_DEV_REPO_DIR=$(python -c "from importlib.resources import files; print(files(\"MEDS_DEV\"))") | ||
export MEDS_DEV_REPO_DIR | ||
|
||
# TODO improve efficiency of prediction generator by using this | ||
# SHARDS=$(expand_shards "$MEDS_ROOT_DIR"/data) | ||
mmcdermott marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
python -m MEDS_DEV.helpers.generate_random_predictions --config-path="$MEDS_DEV_REPO_DIR"/configs \ | ||
--config-name="predictions" "hydra.searchpath=[pkg://aces.configs]" "$@" |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,68 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from importlib.resources import files | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import hydra | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import polars as pl | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from omegaconf import DictConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
SUBJECT_ID = "subject_id" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
PREDICTION_TIME = "prediction_time" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
BOOLEAN_VALUE_COLUMN = "boolean_value" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
PREDICTED_BOOLEAN_VALUE_COLUMN = "predicted_boolean_value" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
PREDICTED_BOOLEAN_PROBABILITY_COLUMN = "predicted_boolean_probability" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
CONFIG = files("MEDS_DEV").joinpath("configs/predictions.yaml") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@hydra.main(version_base=None, config_path=str(CONFIG.parent.resolve()), config_name=CONFIG.stem) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def generate_random_predictions(cfg: DictConfig) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cohort_dir = cfg.cohort_dir # cohort_dir: "${oc.env:MEDS_ROOT_DIR}/task_labels" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cohort_name = cfg.cohort_name # cohort_name: ${task_name}; task_name: ${oc.env:MEDS_TASK_NAME} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cohort_dir = Path(cohort_dir) / cohort_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cohort_predictions_dir = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cfg.cohort_predictions_dir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) # cohort_predictions_dir: "${oc.env:MEDS_ROOT_DIR}/task_predictions" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# TODO: use expand_shards helper from the script to access sharded dataframes directly | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for split in cohort_dir.iterdir(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if split.is_dir() and split.name in {"train", "tuning", "held_out"}: # train | tuning | held_out | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for file in split.iterdir(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if file.is_file(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataframe = pl.read_parquet(file) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictions = _generate_random_predictions(dataframe) # sharded dataframes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+35
to
+36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling for file I/O operations. Reading parquet files without error handling may cause the script to crash if a file is missing or corrupted. Incorporate try-except blocks around file I/O operations to handle exceptions gracefully and provide informative error messages. Apply this diff to add error handling: for split in cohort_dir.iterdir():
if split.is_dir() and split.name in {"train", "tuning", "held_out"}:
for file in split.iterdir():
if file.is_file():
- dataframe = pl.read_parquet(file)
- predictions = _generate_random_predictions(dataframe)
+ try:
+ dataframe = pl.read_parquet(file)
+ predictions = _generate_random_predictions(dataframe)
+ except Exception as e:
+ print(f"Error processing {file}: {e}")
+ continue
# Rest of the code...
elif split.is_file():
- dataframe = pl.read_parquet(split)
- predictions = _generate_random_predictions(dataframe)
+ try:
+ dataframe = pl.read_parquet(split)
+ predictions = _generate_random_predictions(dataframe)
+ except Exception as e:
+ print(f"Error processing {split}: {e}")
+ return Also applies to: 44-45 |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# $MEDS_ROOT_DIR/task_predictions/$TASK_NAME/<split>/<file>.parquet | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictions_path = Path(cohort_predictions_dir) / cohort_name / split.name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.makedirs(predictions_path, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictions.write_parquet(predictions_path / file.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif split.is_file(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataframe = pl.read_parquet(split) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictions = _generate_random_predictions(dataframe) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictions_path = Path(cohort_predictions_dir) / cohort_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
os.makedirs(predictions_path, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
predictions.write_parquet(predictions_path / split.name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+31
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Refactor duplicated code to improve maintainability. The code segments handling directory and file inputs contain duplicated logic, particularly in reading dataframes and writing predictions. Refactoring this section by extracting common functionality into helper functions will reduce code duplication and enhance readability. Apply this refactor to consolidate duplicated code: def generate_random_predictions(cfg: DictConfig) -> None:
# Existing code above...
+ def process_file(file, predictions_path):
+ dataframe = pl.read_parquet(file)
+ predictions = _generate_random_predictions(dataframe)
+ os.makedirs(predictions_path, exist_ok=True)
+ predictions.write_parquet(predictions_path / file.name)
+
for split in cohort_dir.iterdir():
if split.is_dir() and split.name in {"train", "tuning", "held_out"}:
for file in split.iterdir():
if file.is_file():
- dataframe = pl.read_parquet(file)
- predictions = _generate_random_predictions(dataframe)
-
- predictions_path = Path(cohort_predictions_dir) / cohort_name / split.name
- os.makedirs(predictions_path, exist_ok=True)
-
- predictions.write_parquet(predictions_path / file.name)
+ predictions_path = Path(cohort_predictions_dir) / cohort_name / split.name
+ process_file(file, predictions_path)
elif split.is_file():
- dataframe = pl.read_parquet(split)
- predictions = _generate_random_predictions(dataframe)
-
predictions_path = Path(cohort_predictions_dir) / cohort_name
- os.makedirs(predictions_path, exist_ok=True)
-
- predictions.write_parquet(predictions_path / split.name)
+ process_file(split, predictions_path) 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _generate_random_predictions(dataframe: pl.DataFrame) -> pl.DataFrame: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Creates a new dataframe with the same subject_id and boolean_value columns as in the input dataframe, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
along with predictions.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output = dataframe.select([SUBJECT_ID, PREDICTION_TIME, BOOLEAN_VALUE_COLUMN]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
probabilities = np.random.uniform(0, 1, len(dataframe)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# TODO: meds-evaluation currently cares about the order of columns and types, so the new columns have to | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# be inserted at the correct position and cast to the correct type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output.insert_column(3, pl.Series(PREDICTED_BOOLEAN_VALUE_COLUMN, probabilities.round()).cast(pl.Boolean)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Generate boolean predictions using a threshold comparison. Rounding uniform probabilities may not produce an unbiased random boolean outcome. Using a threshold comparison ensures a fair distribution of Apply this diff to update the boolean value generation: - output.insert_column(3, pl.Series(PREDICTED_BOOLEAN_VALUE_COLUMN, probabilities.round()).cast(pl.Boolean))
+ predicted_values = probabilities > 0.5
+ output.insert_column(3, pl.Series(PREDICTED_BOOLEAN_VALUE_COLUMN, predicted_values)) Alternatively, use - probabilities = rng.uniform(0, 1, len(dataframe))
+ predicted_values = rng.choice([True, False], size=len(dataframe)) 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
output.insert_column(4, pl.Series(PREDICTED_BOOLEAN_PROBABILITY_COLUMN, probabilities)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mmcdermott marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
generate_random_predictions() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this TODO for? I'm not sure I understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At the moment I'm reimplementing recursive search through the directory to find shards for which to generate the predictions, but I think this could be improved with the
expand_shards
helper you have implemented.