From b8d07168f7cacf7bde3a20ec534c81d6c70d68dc Mon Sep 17 00:00:00 2001 From: svittoz Date: Fri, 14 Jun 2024 10:28:00 +0000 Subject: [PATCH] Fix merge_visits --- changelog.md | 5 +++ eds_scikit/period/stays.py | 67 +++++++++++++++++++++++++++---------- tests/test_visit_merging.py | 15 +++++---- 3 files changed, 63 insertions(+), 24 deletions(-) diff --git a/changelog.md b/changelog.md index 712b4e57..fd7b69f0 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased + +### Fixed +- Fix merge_visits sort_values.groupby.first + ## v0.1.8 (2024-06-13) ### Fixed diff --git a/eds_scikit/period/stays.py b/eds_scikit/period/stays.py index 13f76033..3e5b05ea 100644 --- a/eds_scikit/period/stays.py +++ b/eds_scikit/period/stays.py @@ -73,10 +73,10 @@ def cleaning( @concept_checker(concepts=["STAY_ID", "CONTIGUOUS_STAY_ID"]) def merge_visits( vo: DataFrame, + open_stay_end_datetime: Optional[datetime], remove_deleted_visits: bool = True, long_stay_threshold: timedelta = timedelta(days=365), long_stay_filtering: Optional[str] = "all", - open_stay_end_datetime: Optional[datetime] = None, max_timedelta: timedelta = timedelta(days=2), merge_different_hospitals: bool = False, merge_different_source_values: Union[bool, List[str]] = ["hospitalisés", "urgence"], @@ -108,6 +108,11 @@ def merge_visits( - care_site_id (if ``merge_different_hospitals == True``) - visit_source_value (if ``merge_different_source_values != False``) - row_status_source_value (if ``remove_deleted_visits= True``) + open_stay_end_datetime: Optional[datetime] + Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in + order to compute stay duration and to filter long stays. + You might provide the extraction date of your data or datetime.now() + (be aware it will produce non-deterministic outputs). remove_deleted_visits: bool Wether to remove deleted visits from the merging procedure. 
Deleted visits are extracted via the `row_status_source_value` column @@ -126,10 +131,6 @@ def merge_visits( Long stays are determined by the ``long_stay_threshold`` value. long_stay_threshold : timedelta Minimum visit duration value to consider a visit as candidate for "long visits filtering" - open_stay_end_datetime: Optional[datetime] - Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in - order to compute stay duration and to filter long stays. If not provided `datetime.now()` will be used. - You might provide the extraction date of your data here. max_timedelta : timedelta Maximum time difference between the end of a visit and the start of another to consider them as belonging to the same stay. This duration is internally converted in seconds before @@ -292,20 +293,50 @@ def get_first( ) # Getting the corresponding first visit - first_visit = ( - merged.sort_values( - by=[flag_name, "visit_start_datetime_1"], ascending=[False, False] - ) - .groupby("visit_occurrence_id_2") - .first()["visit_occurrence_id_1"] - .reset_index() - .rename( - columns={ - "visit_occurrence_id_1": f"{concept_prefix}STAY_ID", - "visit_occurrence_id_2": "visit_occurrence_id", - } - ) + # Replacement for : + # first_visit = merged.sort_values(by=[flag_name, "visit_start_datetime_1"], + # ascending=[False, False]) + # .groupby(visit_occurrence_id_2).first()["visit_occurrence_id_1"] + # which is not deterministic in Koalas + + flagged = ( + merged[merged[flag_name]] + .groupby("visit_occurrence_id_2", as_index=False)[ + ["visit_start_datetime_1"] + ] + .max() + ) + flagged = merged[merged[flag_name]].merge( + flagged, on=["visit_occurrence_id_2", "visit_start_datetime_1"], how="right" + ) + flagged["flagged"] = True + unflagged = ( + merged[~merged[flag_name]] + .groupby("visit_occurrence_id_2", as_index=False)[ + ["visit_start_datetime_1"] + ] + .max() + ) + unflagged = merged[~merged[flag_name]].merge( + unflagged, + 
on=["visit_occurrence_id_2", "visit_start_datetime_1"], + how="right", + ) + unflagged = unflagged.merge( + flagged[["visit_occurrence_id_2", "flagged"]], + on="visit_occurrence_id_2", + how="left", + ) + unflagged = unflagged[unflagged.flagged.isna()] + first_visit = fw.concat((flagged, unflagged), axis=0) + + first_visit = first_visit.rename( + columns={ + "visit_occurrence_id_1": f"{concept_prefix}STAY_ID", + "visit_occurrence_id_2": "visit_occurrence_id", + } ) + first_visit = first_visit[["visit_occurrence_id", f"{concept_prefix}STAY_ID"]] return merged, first_visit diff --git a/tests/test_visit_merging.py b/tests/test_visit_merging.py index 0251d568..e5693046 100644 --- a/tests/test_visit_merging.py +++ b/tests/test_visit_merging.py @@ -1,3 +1,5 @@ +from datetime import datetime + import pandas as pd import pytest @@ -43,7 +45,10 @@ ] -@pytest.mark.parametrize("module", ["pandas", "koalas"]) +@pytest.mark.parametrize( + "module", + ["pandas", "koalas"], +) @pytest.mark.parametrize( "params, results", [(params, results) for params, results in zip(all_params, all_results)], @@ -53,9 +58,7 @@ def test_visit_merging(module, params, results): results = framework.to(module, results) vo = framework.to(module, ds.visit_occurrence) - merged = merge_visits(vo, **params) + merged = merge_visits(vo, datetime(2023, 1, 1), **params) merged = framework.pandas(merged) - - assert_equal_no_order( - merged[["visit_occurrence_id", "STAY_ID", "CONTIGUOUS_STAY_ID"]], results - ) + merged = merged[results.columns] + assert_equal_no_order(merged, results, check_dtype=False)