diff --git a/eds_scikit/period/stays.py b/eds_scikit/period/stays.py index 13f76033..5d3532e4 100644 --- a/eds_scikit/period/stays.py +++ b/eds_scikit/period/stays.py @@ -7,7 +7,9 @@ from eds_scikit.utils.datetime_helpers import substract_datetime from eds_scikit.utils.framework import get_framework from eds_scikit.utils.typing import DataFrame +from eds_scikit.utils.sort_first_koalas import sort_values_first_koalas +import pandas as pd def cleaning( vo, @@ -69,7 +71,6 @@ def cleaning( return vo[~mask], vo[mask] - @concept_checker(concepts=["STAY_ID", "CONTIGUOUS_STAY_ID"]) def merge_visits( vo: DataFrame, @@ -290,28 +291,23 @@ def get_first( right_index=True, how="inner", ) - - # Getting the corresponding first visit - first_visit = ( - merged.sort_values( - by=[flag_name, "visit_start_datetime_1"], ascending=[False, False] - ) - .groupby("visit_occurrence_id_2") - .first()["visit_occurrence_id_1"] - .reset_index() - .rename( - columns={ - "visit_occurrence_id_1": f"{concept_prefix}STAY_ID", - "visit_occurrence_id_2": "visit_occurrence_id", - } - ) - ) - + + first_visit = sort_values_first_koalas(merged, + by_cols=["visit_occurrence_id_2"], + cols=[flag_name, "visit_start_datetime_1"], + ascending=False + ).rename( + columns={ + "visit_occurrence_id_1": f"{concept_prefix}STAY_ID", + "visit_occurrence_id_2": "visit_occurrence_id", + })[[f"{concept_prefix}STAY_ID", "visit_occurrence_id"]] + + return merged, first_visit - merged, first_contiguous_visit = get_first(merged, contiguous_only=True) + merged, first_contiguous_visit = get_first(merged, contiguous_only=True) merged, first_visit = get_first(merged, contiguous_only=False) - + # Concatenating merge visits with previously discarded ones results = fw.concat( [ diff --git a/eds_scikit/utils/sort_first_koalas.py b/eds_scikit/utils/sort_first_koalas.py new file mode 100644 index 00000000..22355982 --- /dev/null +++ b/eds_scikit/utils/sort_first_koalas.py @@ -0,0 +1,6 @@ +def sort_values_first_koalas(dataframe, by_cols, cols, ascending=True): + for col in cols: + dataframe_min_max = dataframe.groupby(by_cols, as_index=False)[col] + dataframe_min_max = dataframe_min_max.min() if ascending else dataframe_min_max.max() + dataframe = dataframe.merge(dataframe_min_max, on=[*by_cols, col], how="right") + return dataframe diff --git a/tests/test_sort_first_koalas.py b/tests/test_sort_first_koalas.py new file mode 100644 index 00000000..bdb4cdef --- /dev/null +++ b/tests/test_sort_first_koalas.py @@ -0,0 +1,37 @@ +import pandas as pd +from eds_scikit.utils.sort_first_koalas import sort_values_first_koalas +from numpy import array +import pytest +from eds_scikit.utils import framework +from eds_scikit.utils.test_utils import assert_equal_no_order + +data = { + 'A': array(['X', 'Y', 'X', 'Y', 'Y', 'Z', 'X', 'Z', 'X', 'X', 'X', 'Z', 'Y', 'Z', 'Z', 'X', 'Y', 'Y', 'Y', 'Y'], dtype='