From ff3b42bc8d927e4dc28c9ef40f2e02ae275efdeb Mon Sep 17 00:00:00 2001 From: svittoz Date: Fri, 14 Jun 2024 10:28:00 +0000 Subject: [PATCH] Fix merge_visits --- changelog.md | 5 +++ eds_scikit/period/stays.py | 37 ++++++++++--------- eds_scikit/utils/sort_values_first.py | 30 ++++++++++++++++ tests/test_sort_values_first.py | 51 +++++++++++++++++++++++++++ tests/test_visit_merging.py | 15 ++++---- 5 files changed, 113 insertions(+), 25 deletions(-) create mode 100644 eds_scikit/utils/sort_values_first.py create mode 100644 tests/test_sort_values_first.py diff --git a/changelog.md b/changelog.md index 712b4e57..fd7b69f0 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased + +### Fixed +- Fix merge_visits sort_values.groupby.first + ## v0.1.8 (2024-06-13) ### Fixed diff --git a/eds_scikit/period/stays.py b/eds_scikit/period/stays.py index 13f76033..1ab79859 100644 --- a/eds_scikit/period/stays.py +++ b/eds_scikit/period/stays.py @@ -6,6 +6,7 @@ from eds_scikit.utils.checks import MissingConceptError, algo_checker, concept_checker from eds_scikit.utils.datetime_helpers import substract_datetime from eds_scikit.utils.framework import get_framework +from eds_scikit.utils.sort_values_first import sort_values_first from eds_scikit.utils.typing import DataFrame @@ -73,10 +74,10 @@ def cleaning( @concept_checker(concepts=["STAY_ID", "CONTIGUOUS_STAY_ID"]) def merge_visits( vo: DataFrame, + open_stay_end_datetime: Optional[datetime], remove_deleted_visits: bool = True, long_stay_threshold: timedelta = timedelta(days=365), long_stay_filtering: Optional[str] = "all", - open_stay_end_datetime: Optional[datetime] = None, max_timedelta: timedelta = timedelta(days=2), merge_different_hospitals: bool = False, merge_different_source_values: Union[bool, List[str]] = ["hospitalisés", "urgence"], @@ -108,6 +109,11 @@ def merge_visits( - care_site_id (if ``merge_different_hospitals == True``) - visit_source_value (if ``merge_different_source_values != 
False``) - row_status_source_value (if ``remove_deleted_visits= True``) + open_stay_end_datetime: Optional[datetime] + Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in + order to compute stay duration and to filter long stays. + You might provide the extraction date of your data or datetime.now() + (be aware it will produce non-deterministic outputs). remove_deleted_visits: bool Wether to remove deleted visits from the merging procedure. Deleted visits are extracted via the `row_status_source_value` column @@ -126,10 +132,6 @@ def merge_visits( Long stays are determined by the ``long_stay_threshold`` value. long_stay_threshold : timedelta Minimum visit duration value to consider a visit as candidate for "long visits filtering" - open_stay_end_datetime: Optional[datetime] - Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in - order to compute stay duration and to filter long stays. If not provided `datetime.now()` will be used. - You might provide the extraction date of your data here. max_timedelta : timedelta Maximum time difference between the end of a visit and the start of another to consider them as belonging to the same stay. 
This duration is internally converted in seconds before @@ -291,21 +293,18 @@ def get_first( how="inner", ) - # Getting the corresponding first visit - first_visit = ( - merged.sort_values( - by=[flag_name, "visit_start_datetime_1"], ascending=[False, False] - ) - .groupby("visit_occurrence_id_2") - .first()["visit_occurrence_id_1"] - .reset_index() - .rename( - columns={ - "visit_occurrence_id_1": f"{concept_prefix}STAY_ID", - "visit_occurrence_id_2": "visit_occurrence_id", - } - ) + first_visit = sort_values_first( + merged, + by_cols=["visit_occurrence_id_2"], + cols=[flag_name, "visit_start_datetime_1", "visit_occurrence_id_1"], + ) + first_visit = first_visit.rename( + columns={ + "visit_occurrence_id_1": f"{concept_prefix}STAY_ID", + "visit_occurrence_id_2": "visit_occurrence_id", + } ) + first_visit = first_visit[["visit_occurrence_id", f"{concept_prefix}STAY_ID"]] return merged, first_visit diff --git a/eds_scikit/utils/sort_values_first.py b/eds_scikit/utils/sort_values_first.py new file mode 100644 index 00000000..a02b944d --- /dev/null +++ b/eds_scikit/utils/sort_values_first.py @@ -0,0 +1,30 @@ +from typing import List + +from eds_scikit.utils.typing import DataFrame + + +def sort_values_first( + df: DataFrame, by_cols: List[str], cols: List[str], ascending: bool = False +): + """ + Replace dataframe.sort_values(cols).groupby(by_cols).first() + + Parameters + ---------- + df : DataFrame + by_cols : List[str] + columns to groupby + cols : List[str] + columns to sort + ascending : bool + """ + + return ( + df.groupby(by_cols) + .apply( + lambda group: group.sort_values( + by=cols, ascending=[ascending for i in cols] + ).head(1) + ) + .reset_index(drop=True) + ) diff --git a/tests/test_sort_values_first.py b/tests/test_sort_values_first.py new file mode 100644 index 00000000..df41c189 --- /dev/null +++ b/tests/test_sort_values_first.py @@ -0,0 +1,51 @@ +import numpy as np +import pandas as pd +import pytest + +from eds_scikit.utils import framework +from 
eds_scikit.utils.sort_values_first import sort_values_first +from eds_scikit.utils.test_utils import assert_equal_no_order + +# Create a DataFrame +np.random.seed(0) +size = 10000 +data = { + "A": np.random.choice(["X", "Y", "Z"], size), + "B": np.random.randint(1, 5, size), + "C": np.random.randint(1, 5, size), + "D": np.random.randint(1, 5, size), + "E": np.random.randint(1, 5, size), +} + +inputs = pd.DataFrame(data) +inputs.loc[0, "B"] = 0 +inputs.loc[0, "C"] = 4 + + +@pytest.mark.parametrize( + "module", + ["pandas", "koalas"], +) +def test_sort_values_first(module): + + inputs_fr = framework.to(module, inputs) + + computed = framework.pandas( + sort_values_first(inputs_fr, ["A"], ["B", "C"], ascending=True) + ) + expected = ( + inputs.sort_values(["B", "C"], ascending=True) + .groupby("A", as_index=False) + .first() + ) + assert_equal_no_order(computed, expected) + + computed = framework.pandas( + sort_values_first(inputs_fr, ["A"], ["B", "C"], ascending=False) + ) + expected = ( + inputs.sort_values(["B", "C"], ascending=False) + .groupby("A", as_index=False) + .first() + ) + assert_equal_no_order(computed, expected) diff --git a/tests/test_visit_merging.py b/tests/test_visit_merging.py index 0251d568..e5693046 100644 --- a/tests/test_visit_merging.py +++ b/tests/test_visit_merging.py @@ -1,3 +1,5 @@ +from datetime import datetime + import pandas as pd import pytest @@ -43,7 +45,10 @@ ] -@pytest.mark.parametrize("module", ["pandas", "koalas"]) +@pytest.mark.parametrize( + "module", + ["pandas", "koalas"], +) @pytest.mark.parametrize( "params, results", [(params, results) for params, results in zip(all_params, all_results)], @@ -53,9 +58,7 @@ def test_visit_merging(module, params, results): results = framework.to(module, results) vo = framework.to(module, ds.visit_occurrence) - merged = merge_visits(vo, **params) + merged = merge_visits(vo, datetime(2023, 1, 1), **params) merged = framework.pandas(merged) - - assert_equal_no_order( - 
merged[["visit_occurrence_id", "STAY_ID", "CONTIGUOUS_STAY_ID"]], results - ) + merged = merged[results.columns] + assert_equal_no_order(merged, results, check_dtype=False)