Skip to content

Commit

Permalink
Fix merge_visits
Browse files Browse the repository at this point in the history
  • Loading branch information
svittoz committed Jun 14, 2024
1 parent 9cd7f1a commit 10ba3db
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 25 deletions.
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## Unreleased

### Fixed
- Fix merge_visits sort_values.groupby.first

## v0.1.8 (2024-06-13)

### Fixed
Expand Down
37 changes: 18 additions & 19 deletions eds_scikit/period/stays.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from eds_scikit.utils.checks import MissingConceptError, algo_checker, concept_checker
from eds_scikit.utils.datetime_helpers import substract_datetime
from eds_scikit.utils.framework import get_framework
from eds_scikit.utils.sort_values_first import sort_values_first
from eds_scikit.utils.typing import DataFrame


Expand Down Expand Up @@ -73,10 +74,10 @@ def cleaning(
@concept_checker(concepts=["STAY_ID", "CONTIGUOUS_STAY_ID"])
def merge_visits(
vo: DataFrame,
open_stay_end_datetime: Optional[datetime],
remove_deleted_visits: bool = True,
long_stay_threshold: timedelta = timedelta(days=365),
long_stay_filtering: Optional[str] = "all",
open_stay_end_datetime: Optional[datetime] = None,
max_timedelta: timedelta = timedelta(days=2),
merge_different_hospitals: bool = False,
merge_different_source_values: Union[bool, List[str]] = ["hospitalisés", "urgence"],
Expand Down Expand Up @@ -108,6 +109,11 @@ def merge_visits(
- care_site_id (if ``merge_different_hospitals == True``)
- visit_source_value (if ``merge_different_source_values != False``)
- row_status_source_value (if ``remove_deleted_visits= True``)
open_stay_end_datetime: Optional[datetime]
Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in
order to compute stay duration and to filter long stays.
You might provide the extraction date of your data or datetime.now()
(be aware it will produce undeterministic outputs).
remove_deleted_visits: bool
Wether to remove deleted visits from the merging procedure.
Deleted visits are extracted via the `row_status_source_value` column
Expand All @@ -126,10 +132,6 @@ def merge_visits(
Long stays are determined by the ``long_stay_threshold`` value.
long_stay_threshold : timedelta
Minimum visit duration value to consider a visit as candidate for "long visits filtering"
open_stay_end_datetime: Optional[datetime]
Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in
order to compute stay duration and to filter long stays. If not provided `datetime.now()` will be used.
You might provide the extraction date of your data here.
max_timedelta : timedelta
Maximum time difference between the end of a visit and the start of another to consider
them as belonging to the same stay. This duration is internally converted in seconds before
Expand Down Expand Up @@ -291,21 +293,18 @@ def get_first(
how="inner",
)

# Getting the corresponding first visit
first_visit = (
merged.sort_values(
by=[flag_name, "visit_start_datetime_1"], ascending=[False, False]
)
.groupby("visit_occurrence_id_2")
.first()["visit_occurrence_id_1"]
.reset_index()
.rename(
columns={
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
"visit_occurrence_id_2": "visit_occurrence_id",
}
)
first_visit = sort_values_first(
merged,
by_cols=["visit_occurrence_id_2"],
cols=[flag_name, "visit_start_datetime_1", "visit_occurrence_id_1"],
)
first_visit = first_visit.rename(
columns={
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
"visit_occurrence_id_2": "visit_occurrence_id",
}
)
first_visit = first_visit[["visit_occurrence_id", f"{concept_prefix}STAY_ID"]]

return merged, first_visit

Expand Down
30 changes: 30 additions & 0 deletions eds_scikit/utils/sort_values_first.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import List

from eds_scikit.utils.typing import DataFrame


def sort_values_first(
df: DataFrame, by_cols: List[str], cols: List[str], ascending: bool = False
):
"""
Replace dataframe.sort_value(cols).groupby(by_cols).first()
Parameters
----------
df : DataFrame
by_cols : List[str]
columns to groupby
cols : List[str]
columns to sort
ascending : bool
"""

return (
df.groupby(by_cols)
.apply(
lambda group: group.sort_values(
by=cols, ascending=[ascending for i in cols]
).head(1)
)
.reset_index(drop=True)
)
52 changes: 52 additions & 0 deletions tests/test_sort_values_first.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
import pandas as pd
import pytest
from databricks import koalas as ks

from eds_scikit.utils import framework
from eds_scikit.utils.sort_values_first import sort_values_first
from eds_scikit.utils.test_utils import assert_equal_no_order

# Create a DataFrame
np.random.seed(0)
size = 10000
data = {
"A": np.random.choice(["X", "Y", "Z"], size),
"B": np.random.randint(1, 5, size),
"C": np.random.randint(1, 5, size),
"D": np.random.randint(1, 5, size),
"E": np.random.randint(1, 5, size),
}

inputs = pd.DataFrame(data)
inputs.loc[0, "B"] = 0
inputs.loc[0, "C"] = 4


@pytest.mark.parametrize(
"module",
["pandas", "koalas"],
)
def test_sort_values_first(module):

inputs_fr = framework.to(module, inputs)

computed = framework.pandas(
sort_values_first(inputs_fr, ["A"], ["B", "C"], ascending=True)
)
expected = (
inputs.sort_values(["B", "C"], ascending=True)
.groupby("A", as_index=False)
.first()
)
assert_equal_no_order(computed, expected)

computed = framework.pandas(
sort_values_first(inputs_fr, ["A"], ["B", "C"], ascending=False)
)
expected = (
inputs.sort_values(["B", "C"], ascending=False)
.groupby("A", as_index=False)
.first()
)
assert_equal_no_order(computed, expected)
15 changes: 9 additions & 6 deletions tests/test_visit_merging.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import pandas as pd
import pytest

Expand Down Expand Up @@ -43,7 +45,10 @@
]


@pytest.mark.parametrize("module", ["pandas", "koalas"])
@pytest.mark.parametrize(
"module",
["pandas", "koalas"],
)
@pytest.mark.parametrize(
"params, results",
[(params, results) for params, results in zip(all_params, all_results)],
Expand All @@ -53,9 +58,7 @@ def test_visit_merging(module, params, results):
results = framework.to(module, results)

vo = framework.to(module, ds.visit_occurrence)
merged = merge_visits(vo, **params)
merged = merge_visits(vo, datetime(2023, 1, 1), **params)
merged = framework.pandas(merged)

assert_equal_no_order(
merged[["visit_occurrence_id", "STAY_ID", "CONTIGUOUS_STAY_ID"]], results
)
merged = merged[results.columns]
assert_equal_no_order(merged, results, check_dtype=False)

0 comments on commit 10ba3db

Please sign in to comment.