Skip to content

Commit

Permalink
adding processing table functions (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
svittoz authored Jun 27, 2024
1 parent 5539046 commit 77f8598
Show file tree
Hide file tree
Showing 3 changed files with 322 additions and 0 deletions.
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Changelog

## Unreleased

### Added
- Functions tag_table_with_age, tag_table_period_length, tag_table_by_type

### Fixed
- Quartiles computed from plot_concepts_set no longer depend on value selection

Expand Down
187 changes: 187 additions & 0 deletions eds_scikit/utils/process_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from datetime import timedelta
from typing import Dict, List, Union

import numpy as np
from loguru import logger

from eds_scikit.utils.checks import check_columns
from eds_scikit.utils.typing import DataFrame


def tag_table_by_type(
    table: DataFrame,
    type_groups: Union[str, Dict],
    source_col: str,
    target_col: str,
    filter_table: bool = False,
):
    """Add a tag column to a table from regex matches on one of its columns.

    Example: tag ``condition_occurrence`` rows as "DIABETES_TYPE_I" or
    "DIABETES_TYPE_II" from their ``condition_source_value``.

    Parameters
    ----------
    table : DataFrame
        Table to tag (must contain column `source_col`). The input table is
        not modified: the tag column is added to a new DataFrame.
    type_groups : Union[str, Dict]
        Either a single regex (used both as tag name and as pattern), or a
        dict mapping each tag name to its regex. Matching is case-insensitive
        and is done on the string representation of `source_col`. When several
        patterns match the same row, the last one in the dict wins.
    source_col : str
        Column on which the tagging is applied.
    target_col : str
        Name of the tag column to create.
    filter_table : bool
        If True, remove untagged rows (rows whose tag is "OTHER").

    Returns
    -------
    DataFrame
        Copy of the input dataframe with the tag column `target_col`.
        Rows matching no pattern, and rows whose `source_col` is missing,
        are tagged "OTHER".

    Output
    -------
    | person_id | condition_source_value | DIABETIC_CONDITION |
    |:---------:|-----------------------:|:------------------:|
    | 001       | E100                   | DIABETES_TYPE_I    |
    | 002       | E101                   | DIABETES_TYPE_I    |
    | 003       | E110                   | DIABETES_TYPE_II   |
    | 004       | E113                   | DIABETES_TYPE_II   |
    | 005       | A001                   | OTHER              |
    """
    if isinstance(type_groups, str):
        type_groups = {type_groups: type_groups}

    # `assign` returns a new frame, so the caller's table is left untouched.
    table = table.assign(**{target_col: "OTHER"})

    # `astype(str)` turns missing values into literal strings ("nan", "None")
    # that a pattern such as "^N" would match, so they are excluded explicitly.
    has_value = table[source_col].notna()

    for type_name, type_value in type_groups.items():
        is_match = (
            table[source_col]
            .astype(str)
            .str.contains(
                type_value,
                case=False,
                regex=True,
                na=False,
            )
        )
        table.loc[is_match & has_value, target_col] = type_name

    logger.debug(
        "The following {} : {} have been tagged on table.",
        target_col,
        type_groups,
    )

    if filter_table:
        table = table[table[target_col] != "OTHER"]

    return table


def tag_table_period_length(
    table: DataFrame,
    length_of_stays: List[float],
    start_date_col: str = "visit_start_datetime",
    end_date_col: str = "visit_end_datetime",
) -> DataFrame:
    """Tag table by length of stays (can be applied to visit_occurrence table)

    The length is ``end_date_col - start_date_col`` expressed in days. Bins are
    built from the sorted thresholds and do not overlap: with thresholds
    ``[a, b, c]`` the labels are ``"<= a days"`` (length <= a),
    ``"a days - b days"`` (a < length < b), ``"b days - c days"``
    (b <= length < c) and ``">= c days"`` (length >= c).

    Example : length_of_stays = [7, 14]

    Output
    -------
    | person_id | visit_start_datetime | visit_end_datetime | length_of_stay     |
    |:---------:|---------------------:|:------------------:|:------------------:|
    | 001       | 2020-04-01           | 2020-04-12         | "7 days - 14 days" |
    | 002       | 2020-04-01           | 2020-04-03         | "<= 7 days"        |
    | 003       | 2020-04-01           | 2020-04-20         | ">= 14 days"       |
    | 004       | 2020-04-01           | NaT                | "Incomplete stay"  |

    Parameters
    ----------
    table : DataFrame
        Must contain `start_date_col` and `end_date_col`. A column named
        ``length`` is used as scratch space: an existing column with that
        name is overwritten and dropped from the result.
    length_of_stays : List[float]
        Thresholds in days, in any order. Example : [7 , 14].
        If empty, no length bin is applied.
    start_date_col : str, optional
        by default "visit_start_datetime"
    end_date_col : str, optional
        by default "visit_end_datetime"

    Returns
    -------
    DataFrame
        New dataframe with an extra ``length_of_stay`` column. Rows with a
        missing end date are tagged "Incomplete stay"; rows that fall in no
        bin (e.g. missing start date) keep "Not specified".
    """
    # Sort a copy so that unordered thresholds still give meaningful bins and
    # the caller's list is left as is.
    thresholds = sorted(length_of_stays)

    table = table.assign(
        length=(table[end_date_col] - table[start_date_col])
        / np.timedelta64(timedelta(days=1))
    )

    # Incomplete stays
    table = table.assign(length_of_stay="Not specified")
    table["length_of_stay"] = table["length_of_stay"].mask(
        table[end_date_col].isna(),
        "Incomplete stay",
    )

    # Complete stays (comparisons are False for a missing length, so the
    # labels set above are preserved for those rows).
    if thresholds:
        for min_length, max_length in zip(thresholds[:-1], thresholds[1:]):
            table["length_of_stay"] = table["length_of_stay"].mask(
                (table["length"] >= min_length) & (table["length"] < max_length),
                "{} days - {} days".format(min_length, max_length),
            )

        # Outer bins are applied last so that a stay of exactly the smallest
        # threshold gets the "<=" label that claims it.
        min_duration = thresholds[0]
        max_duration = thresholds[-1]
        table["length_of_stay"] = table["length_of_stay"].mask(
            (table["length"] <= min_duration),
            "<= {} days".format(min_duration),
        )
        table["length_of_stay"] = table["length_of_stay"].mask(
            (table["length"] >= max_duration),
            ">= {} days".format(max_duration),
        )

    table = table.drop(columns="length")

    return table


def tag_table_with_age(
    table: DataFrame, date_col: str, person: DataFrame, age_ranges: List[int] = None
):
    """Tag table with person age

    The age is the time elapsed between ``birth_datetime`` and `date_col`,
    converted to years (365.25 days per year) and truncated to an integer.

    Parameters
    ----------
    table : DataFrame
        must contain person_id and date_col
    date_col: str
        date column from table on which to compute age
    person : DataFrame
        must contain person_id and birth_datetime
    age_ranges : List[int]
        if None, simply compute age.
        Bounds may be given in any order; the list is not modified.
        example : None, [18], [18, 60]

    Returns
    -------
    DataFrame
        `table` joined with ``birth_datetime``, plus an integer ``age`` column
        and an ``age_range`` label column ("Not specified" when `age_ranges`
        is None or empty). The join on ``person_id`` is an inner join, so rows
        whose person is absent from `person` are dropped.
    """
    check_columns(df=person, required_columns=["person_id", "birth_datetime"])
    check_columns(df=table, required_columns=[date_col, "person_id"])

    table = table.merge(person[["person_id", "birth_datetime"]], on="person_id")

    # Elapsed time in days first, then in years: 365.25 accounts for leap years.
    table["age"] = (table[date_col] - table["birth_datetime"]) / np.timedelta64(
        timedelta(days=1)
    )
    table["age"] = table["age"] / 365.25
    table["age"] = table["age"].astype(int)

    table["age_range"] = "Not specified"
    if age_ranges:
        # Work on a sorted copy so the caller's list keeps its order.
        bounds = sorted(age_ranges)
        table.loc[table.age <= bounds[0], "age_range"] = f"age <= {bounds[0]}"

        for age_min, age_max in zip(bounds[:-1], bounds[1:]):
            in_range = (table.age > age_min) & (table.age <= age_max)
            table.loc[in_range, "age_range"] = f"{age_min} < age <= {age_max}"

        table.loc[table.age > bounds[-1], "age_range"] = f"age > {bounds[-1]}"

    return table
131 changes: 131 additions & 0 deletions tests/test_process_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import pandas as pd
import pytest

from eds_scikit.utils import framework
from eds_scikit.utils.process_table import (
tag_table_by_type,
tag_table_period_length,
tag_table_with_age,
)

# Fixed visit-level fixture shared by the tests below: one visit per person,
# with dates already parsed as datetimes.
num_rows = 1000
table = pd.DataFrame(
    {
        "condition_source_value": ["E100", "E101", "E110", "A001", "B002"],
        "visit_start_datetime": pd.to_datetime(
            ["2021-05-16", "2018-08-16", "2023-03-14", "2023-05-09", "2022-07-17"]
        ),
        "visit_end_datetime": pd.to_datetime(
            ["2021-05-26", "2018-09-16", "2023-03-15", "2023-10-10", "2022-07-18"]
        ),
        "person_id": [0, 1, 2, 3, 4],
    }
)

# Fixed person-level fixture: birth date of each person referenced above.
person = pd.DataFrame(
    {
        "person_id": [0, 1, 2, 3, 4],
        "birth_datetime": pd.to_datetime(
            ["2000-03-29", "1990-04-08", "1975-09-28", "1970-04-28", "1975-10-03"]
        ),
    }
)


@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_tag_table_with_age(module):
    """Check age labels and integer ages computed at ``visit_start_datetime``."""

    person_fr = framework.to(module, person)
    table_fr = framework.to(module, table)

    table_with_age = tag_table_with_age(
        table_fr, "visit_start_datetime", person_fr, age_ranges=[24, 30, 40]
    )
    table_with_age = framework.to("pandas", table_with_age)
    assert (
        table_with_age["age_range"]
        == pd.Series(
            ["age <= 24", "24 < age <= 30", "age > 40", "age > 40", "age > 40"],
            name="age_range",
        )
    ).all()

    table_with_age = tag_table_with_age(
        table_fr, "visit_start_datetime", person_fr, age_ranges=None
    )
    table_with_age = framework.to("pandas", table_with_age)
    # Expected values are the completed years between birth and visit start,
    # e.g. born 1990-04-08 and seen on 2018-08-16 -> 28 years old.
    assert (table_with_age["age"] == pd.Series([21, 28, 47, 53, 46], name="age")).all()


@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_table_by_type(module):
    """Check regex tagging of condition codes, with and without filtering."""

    diabetes_patterns = {"DIABETES_TYPE_I": r"^E10", "DIABETES_TYPE_II": r"^E11"}
    table_fr = framework.to(module, table)

    # Without filtering, unmatched rows are kept and labelled "OTHER".
    tagged = framework.to(
        "pandas",
        tag_table_by_type(
            table_fr,
            type_groups=diabetes_patterns,
            source_col="condition_source_value",
            target_col="tag",
        ),
    )
    expected_all = pd.Series(
        [
            "DIABETES_TYPE_I",
            "DIABETES_TYPE_I",
            "DIABETES_TYPE_II",
            "OTHER",
            "OTHER",
        ],
        name="tag",
    )
    assert (tagged["tag"] == expected_all).all()

    # With filtering, only the tagged rows remain.
    kept = framework.to(
        "pandas",
        tag_table_by_type(
            table_fr,
            type_groups=diabetes_patterns,
            source_col="condition_source_value",
            target_col="tag",
            filter_table=True,
        ),
    )
    expected_kept = pd.Series(
        ["DIABETES_TYPE_I", "DIABETES_TYPE_I", "DIABETES_TYPE_II"], name="tag"
    )
    assert (kept["tag"] == expected_kept).all()


@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_tag_table_period_length(module):
    """Check the length-of-stay labels obtained with thresholds of 7 and 14 days."""

    expected_labels = pd.Series(
        ["7 days - 14 days", ">= 14 days", "<= 7 days", ">= 14 days", "<= 7 days"],
        name="tag",
    )

    tagged = tag_table_period_length(
        framework.to(module, table), length_of_stays=[7, 14]
    )
    labels = framework.to("pandas", tagged)["length_of_stay"]

    assert (labels == expected_labels).all()

0 comments on commit 77f8598

Please sign in to comment.