-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding processing table functions (#71)
- Loading branch information
Showing
3 changed files
with
322 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
from datetime import timedelta | ||
from typing import Dict, List, Union | ||
|
||
import numpy as np | ||
from loguru import logger | ||
|
||
from eds_scikit.utils.checks import check_columns | ||
from eds_scikit.utils.typing import DataFrame | ||
|
||
|
||
def tag_table_by_type(
    table: DataFrame,
    type_groups: Union[str, Dict],
    source_col: str,
    target_col: str,
    filter_table: bool = False,
) -> DataFrame:
    """Add a tag column to a table based on regex matches on one of its columns.

    Example: tag `condition_occurrence` rows as "DIABETES_TYPE_I",
    "DIABETES_TYPE_II" or "OTHER" from their `condition_source_value`.

    Every row starts with the tag "OTHER". The patterns of `type_groups` are
    then applied one after the other, so when a value matches several
    patterns the last matching entry wins. Matching is case-insensitive and
    is done on the string representation of `source_col`.

    The tag column is added to a new DataFrame; the `table` object passed by
    the caller is not modified.

    Parameters
    ----------
    table : DataFrame
        Table to tag (must contain column `source_col`).
    type_groups : Union[str, Dict]
        Either a dict mapping each tag name to its regex, or a single regex.
        A single regex is used both as the pattern and as the tag name.
    source_col : str
        Column on which the tagging is applied.
    target_col : str
        Name of the tag column to create.
    filter_table : bool
        If True, remove untagged rows (rows whose tag is still "OTHER").

    Returns
    -------
    DataFrame
        Input dataframe with the additional tag column `target_col`.

    Output
    -------
    | person_id | condition_source_value | DIABETIC_CONDITION |
    |:---------:|-----------------------:|:------------------:|
    | 001       | E100                   | DIABETES_TYPE_I    |
    | 002       | E101                   | DIABETES_TYPE_I    |
    | 003       | E110                   | DIABETES_TYPE_II   |
    | 004       | E113                   | DIABETES_TYPE_II   |
    | 005       | A001                   | OTHER              |
    """
    if isinstance(type_groups, str):
        type_groups = {type_groups: type_groups}

    # `assign` returns a new frame, so the caller's table is left untouched.
    table = table.assign(**{target_col: "OTHER"})

    for type_name, type_value in type_groups.items():
        is_match = (
            table[source_col]
            .astype(str)
            .str.contains(
                type_value,
                case=False,
                regex=True,
                na=False,
            )
        )
        table.loc[is_match, target_col] = type_name

    logger.debug(
        "The following {} : {} have been tagged on table.",
        target_col,
        type_groups,
    )

    if filter_table:
        table = table[table[target_col] != "OTHER"]

    return table
|
||
|
||
def tag_table_period_length(
    table: DataFrame,
    length_of_stays: List[float],
    start_date_col: str = "visit_start_datetime",
    end_date_col: str = "visit_end_datetime",
) -> DataFrame:
    """Tag table by length of stay (can be applied to the visit_occurrence table).

    The length of each row is `end_date_col - start_date_col`, expressed in
    days. It is compared with the thresholds of `length_of_stays` (sorted
    internally) to fill a new `length_of_stay` column:

    - "Incomplete stay" when `end_date_col` is missing,
    - "<= {min} days" when the length is at most the smallest threshold,
    - ">= {max} days" when the length is at least the largest threshold,
    - "{a} days - {b} days" when the length lies between two consecutive
      thresholds,
    - "Not specified" otherwise (e.g. no threshold given).

    Example : length_of_stays = [7, 14]

    Output
    -------
    | person_id | visit_start_datetime | visit_end_datetime | length_of_stay     |
    |:---------:|---------------------:|:------------------:|:------------------:|
    | 001       | 2020-04-01           | 2020-04-12         | "7 days - 14 days" |
    | 002       | 2020-04-01           | 2020-04-03         | "<= 7 days"        |
    | 003       | 2020-04-01           | 2020-04-20         | ">= 14 days"       |

    Parameters
    ----------
    table : DataFrame
        Must contain `start_date_col` and `end_date_col`.
    length_of_stays : List[float]
        Thresholds in days. Example : [7 , 14]
    start_date_col : str, optional
        by default "visit_start_datetime"
    end_date_col : str, optional
        by default "visit_end_datetime"

    Returns
    -------
    DataFrame
        Copy of `table` with the additional column `length_of_stay`.
    """
    # Temporary helper column holding the stay duration in (fractional) days.
    table = table.assign(
        length=(table[end_date_col] - table[start_date_col])
        / np.timedelta64(timedelta(days=1))
    )

    # Incomplete stays
    table = table.assign(length_of_stay="Not specified")
    table["length_of_stay"] = table["length_of_stay"].mask(
        table[end_date_col].isna(),
        "Incomplete stay",
    )

    # Complete stays. Work on a sorted copy so that an unsorted argument
    # still yields consistent labels and the caller's list is not modified.
    thresholds = sorted(length_of_stays)
    if thresholds:
        # Intermediate ranges first: the two extreme labels are applied
        # afterwards so that a length exactly equal to the smallest threshold
        # ends up in "<= min", as that label announces.
        for lower, upper in zip(thresholds[:-1], thresholds[1:]):
            table["length_of_stay"] = table["length_of_stay"].mask(
                (table["length"] >= lower) & (table["length"] < upper),
                "{} days - {} days".format(lower, upper),
            )

        min_duration = thresholds[0]
        max_duration = thresholds[-1]
        table["length_of_stay"] = table["length_of_stay"].mask(
            (table["length"] <= min_duration),
            "<= {} days".format(min_duration),
        )
        table["length_of_stay"] = table["length_of_stay"].mask(
            (table["length"] >= max_duration),
            ">= {} days".format(max_duration),
        )

    table = table.drop(columns="length")

    return table
|
||
|
||
def tag_table_with_age(
    table: DataFrame,
    date_col: str,
    person: DataFrame,
    age_ranges: Union[List[int], None] = None,
) -> DataFrame:
    """Tag table with person age.

    `table` is merged with `person` on `person_id` (default merge, i.e. rows
    without a matching person are not kept). Two columns are then added:

    - `age`: number of years between `birth_datetime` and `date_col`,
      truncated to an integer (a year is counted as 365.25 days),
    - `age_range`: label of the range the age falls in, or "Not specified"
      when no `age_ranges` are given.

    Parameters
    ----------
    table : DataFrame
        must contain person_id and date_col
    date_col : str
        date column from table on which to compute age
    person : DataFrame
        must contain person_id and birth_datetime
    age_ranges : List[int]
        Age thresholds, in years. If None (or empty), simply compute age.
        example : None, [18], [18, 60]

    Returns
    -------
    DataFrame
        Merged table with the additional columns `birth_datetime`, `age`
        and `age_range`.
    """
    check_columns(df=person, required_columns=["person_id", "birth_datetime"])
    check_columns(df=table, required_columns=[date_col, "person_id"])

    table = table.merge(person[["person_id", "birth_datetime"]], on="person_id")

    # Elapsed time in days, converted to years with the mean year length.
    table["age"] = (
        (table[date_col] - table["birth_datetime"])
        / np.timedelta64(timedelta(days=1))
    ) / 365.25
    table["age"] = table["age"].astype(int)

    table["age_range"] = "Not specified"
    if age_ranges:
        # Sorted copy: the caller's list must not be reordered in place.
        bounds = sorted(age_ranges)
        table.loc[table.age <= bounds[0], "age_range"] = f"age <= {bounds[0]}"

        for age_min, age_max in zip(bounds[:-1], bounds[1:]):
            in_range = (table.age > age_min) & (table.age <= age_max)
            table.loc[in_range, "age_range"] = f"{age_min} < age <= {age_max}"

        table.loc[table.age > bounds[-1], "age_range"] = f"age > {bounds[-1]}"

    return table
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from eds_scikit.utils import framework | ||
from eds_scikit.utils.process_table import ( | ||
tag_table_by_type, | ||
tag_table_period_length, | ||
tag_table_with_age, | ||
) | ||
|
||
# Hand-written fixtures shared by every test of this module.
num_rows = 1000

# Visit-level table: one diagnosis code and one stay per person.
table = pd.DataFrame(
    {
        "condition_source_value": ["E100", "E101", "E110", "A001", "B002"],
        "visit_start_datetime": pd.to_datetime(
            [
                "2021-05-16",
                "2018-08-16",
                "2023-03-14",
                "2023-05-09",
                "2022-07-17",
            ]
        ),
        "visit_end_datetime": pd.to_datetime(
            [
                "2021-05-26",
                "2018-09-16",
                "2023-03-15",
                "2023-10-10",
                "2022-07-18",
            ]
        ),
        "person_id": [0, 1, 2, 3, 4],
    }
)

# Person-level table: birth date of each person referenced above.
person = pd.DataFrame(
    {
        "person_id": [0, 1, 2, 3, 4],
        "birth_datetime": [
            "2000-03-29",
            "1990-04-08",
            "1975-09-28",
            "1970-04-28",
            "1975-10-03",
        ],
    }
)
person["birth_datetime"] = pd.to_datetime(person["birth_datetime"])
|
||
|
||
@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_tag_table_with_age(module):
    """Check the `age_range` labels and the raw `age` computed from the fixtures."""
    person_fr = framework.to(module, person)
    table_fr = framework.to(module, table)

    # With thresholds: every person falls in exactly one labelled range.
    table_with_age = tag_table_with_age(
        table_fr, "visit_start_datetime", person_fr, age_ranges=[24, 30, 40]
    )
    table_with_age = framework.to("pandas", table_with_age)
    assert (
        table_with_age["age_range"]
        == pd.Series(
            ["age <= 24", "24 < age <= 30", "age > 40", "age > 40", "age > 40"],
            name="age_range",
        )
    ).all()

    # Without thresholds: only the age itself is checked. Expected values are
    # the completed years between birth_datetime and visit_start_datetime
    # (e.g. born 1990-04-08, visit on 2018-08-16 -> 28).
    table_with_age = tag_table_with_age(
        table_fr, "visit_start_datetime", person_fr, age_ranges=None
    )
    table_with_age = framework.to("pandas", table_with_age)
    assert (table_with_age["age"] == pd.Series([21, 28, 47, 53, 46], name="age")).all()
|
||
|
||
@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_table_by_type(module):
    """Check regex tagging of diagnosis codes, with and without row filtering."""
    frame = framework.to(module, table)
    tagging_args = dict(
        type_groups={"DIABETES_TYPE_I": r"^E10", "DIABETES_TYPE_II": r"^E11"},
        source_col="condition_source_value",
        target_col="tag",
    )

    # Default call: unmatched rows are kept and tagged "OTHER".
    tagged = framework.to("pandas", tag_table_by_type(frame, **tagging_args))
    expected_all = pd.Series(
        [
            "DIABETES_TYPE_I",
            "DIABETES_TYPE_I",
            "DIABETES_TYPE_II",
            "OTHER",
            "OTHER",
        ],
        name="tag",
    )
    assert (tagged["tag"] == expected_all).all()

    # filter_table=True: only the rows that matched a pattern remain.
    filtered = framework.to(
        "pandas", tag_table_by_type(frame, filter_table=True, **tagging_args)
    )
    expected_kept = pd.Series(
        ["DIABETES_TYPE_I", "DIABETES_TYPE_I", "DIABETES_TYPE_II"], name="tag"
    )
    assert (filtered["tag"] == expected_kept).all()
|
||
|
||
@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_tag_table_period_length(module):
    """Check the `length_of_stay` labels obtained with thresholds of 7 and 14 days."""
    frame = framework.to(module, table)

    tagged = framework.to(
        "pandas", tag_table_period_length(frame, length_of_stays=[7, 14])
    )

    # Fixture stays last 10, 31, 1, 154 and 1 days respectively.
    expected_labels = pd.Series(
        ["7 days - 14 days", ">= 14 days", "<= 7 days", ">= 14 days", "<= 7 days"],
        name="tag",
    )
    assert (tagged["length_of_stay"] == expected_labels).all()