Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix undeterministic merge_visits #70

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## Unreleased

### Fixed
- Fix non-deterministic `merge_visits` results caused by `sort_values().groupby().first()`

## v0.1.8 (2024-06-13)

### Fixed
Expand Down
37 changes: 18 additions & 19 deletions eds_scikit/period/stays.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from eds_scikit.utils.checks import MissingConceptError, algo_checker, concept_checker
from eds_scikit.utils.datetime_helpers import substract_datetime
from eds_scikit.utils.framework import get_framework
from eds_scikit.utils.sort_values_first import sort_values_first
from eds_scikit.utils.typing import DataFrame


Expand Down Expand Up @@ -73,10 +74,10 @@ def cleaning(
@concept_checker(concepts=["STAY_ID", "CONTIGUOUS_STAY_ID"])
def merge_visits(
vo: DataFrame,
open_stay_end_datetime: Optional[datetime],
remove_deleted_visits: bool = True,
long_stay_threshold: timedelta = timedelta(days=365),
long_stay_filtering: Optional[str] = "all",
open_stay_end_datetime: Optional[datetime] = None,
max_timedelta: timedelta = timedelta(days=2),
merge_different_hospitals: bool = False,
merge_different_source_values: Union[bool, List[str]] = ["hospitalisés", "urgence"],
Expand Down Expand Up @@ -108,6 +109,11 @@ def merge_visits(
- care_site_id (if ``merge_different_hospitals == True``)
- visit_source_value (if ``merge_different_source_values != False``)
- row_status_source_value (if ``remove_deleted_visits= True``)
open_stay_end_datetime: Optional[datetime]
Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in
order to compute stay duration and to filter long stays.
You might provide the extraction date of your data or `datetime.now()`
(be aware that using `datetime.now()` makes the outputs non-deterministic).
remove_deleted_visits: bool
Whether to remove deleted visits from the merging procedure.
Deleted visits are extracted via the `row_status_source_value` column
Expand All @@ -126,10 +132,6 @@ def merge_visits(
Long stays are determined by the ``long_stay_threshold`` value.
long_stay_threshold : timedelta
Minimum visit duration value to consider a visit as candidate for "long visits filtering"
open_stay_end_datetime: Optional[datetime]
Datetime to use in order to fill the `visit_end_datetime` of open visits. This is necessary in
order to compute stay duration and to filter long stays. If not provided `datetime.now()` will be used.
You might provide the extraction date of your data here.
max_timedelta : timedelta
Maximum time difference between the end of a visit and the start of another to consider
them as belonging to the same stay. This duration is internally converted in seconds before
Expand Down Expand Up @@ -291,21 +293,18 @@ def get_first(
how="inner",
)

# Getting the corresponding first visit
first_visit = (
merged.sort_values(
by=[flag_name, "visit_start_datetime_1"], ascending=[False, False]
)
.groupby("visit_occurrence_id_2")
.first()["visit_occurrence_id_1"]
.reset_index()
.rename(
columns={
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
"visit_occurrence_id_2": "visit_occurrence_id",
}
)
first_visit = sort_values_first(
merged,
by_cols=["visit_occurrence_id_2"],
cols=[flag_name, "visit_start_datetime_1", "visit_occurrence_id_1"],
)
first_visit = first_visit.rename(
columns={
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
"visit_occurrence_id_2": "visit_occurrence_id",
}
)
first_visit = first_visit[["visit_occurrence_id", f"{concept_prefix}STAY_ID"]]

return merged, first_visit

Expand Down
30 changes: 30 additions & 0 deletions eds_scikit/utils/sort_values_first.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import List, Union

from eds_scikit.utils.typing import DataFrame


def sort_values_first(
    df: DataFrame,
    by_cols: List[str],
    cols: List[str],
    ascending: Union[bool, List[bool]] = False,
) -> DataFrame:
    """
    Keep, for each group of ``by_cols``, the first row after sorting by ``cols``.

    Meant as a replacement for
    ``df.sort_values(cols).groupby(by_cols).first()``: the sort is done
    inside each group instead of before the groupby. Note that a whole row
    is kept per group (``head(1)``), whereas ``groupby().first()`` picks the
    first non-null value of each column independently.

    Parameters
    ----------
    df : DataFrame
        Input dataframe.
    by_cols : List[str]
        Columns to group by.
    cols : List[str]
        Columns to sort on within each group, by decreasing priority.
    ascending : Union[bool, List[bool]]
        Sort order. A single boolean is applied to every column of ``cols``;
        a list gives one order per column and must have the same length
        as ``cols``. Defaults to ``False`` (descending).

    Returns
    -------
    DataFrame
        One row per group, with a fresh default index.
    """

    if isinstance(ascending, (list, tuple)):
        sort_orders = list(ascending)
    else:
        # A single flag applies to every sort column.
        sort_orders = [ascending] * len(cols)

    def _first_sorted_row(group):
        return group.sort_values(by=cols, ascending=sort_orders).head(1)

    firsts = df.groupby(by_cols).apply(_first_sorted_row)

    # Newer pandas releases do not hand the grouping columns to the applied
    # function; they then only survive as index levels. Move them back into
    # columns before the index is discarded below.
    missing = [col for col in by_cols if col not in firsts.columns]
    if missing:
        firsts = firsts.reset_index(level=missing)

    return firsts.reset_index(drop=True)
51 changes: 51 additions & 0 deletions tests/test_sort_values_first.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import pandas as pd
import pytest

from eds_scikit.utils import framework
from eds_scikit.utils.sort_values_first import sort_values_first
from eds_scikit.utils.test_utils import assert_equal_no_order

# Shared fixture: a seeded random frame with one group column ("A") and four
# small-integer columns. The draw order (A, then B to E) is kept so the seeded
# values stay the same.
np.random.seed(0)
size = 10000
data = {"A": np.random.choice(["X", "Y", "Z"], size)}
data.update((name, np.random.randint(1, 5, size)) for name in ("B", "C", "D", "E"))

inputs = pd.DataFrame(data)
# Overwrite two cells of the first row: B gets 0 (outside the 1..4 range drawn
# above) and C gets 4.
inputs.at[0, "B"] = 0
inputs.at[0, "C"] = 4


@pytest.mark.parametrize("module", ["pandas", "koalas"])
def test_sort_values_first(module):
    """Check `sort_values_first` against pandas' sort_values/groupby/first.

    The comparison is run in ascending order first, then in descending order,
    on the module-level `inputs` frame converted to the requested backend.
    """
    inputs_fr = framework.to(module, inputs)
    group_cols = ["A"]
    sort_cols = ["B", "C"]

    for order in (True, False):
        result = sort_values_first(inputs_fr, group_cols, sort_cols, ascending=order)
        computed = framework.pandas(result)

        # Reference result computed with plain pandas on the original frame.
        presorted = inputs.sort_values(["B", "C"], ascending=order)
        expected = presorted.groupby("A", as_index=False).first()

        assert_equal_no_order(computed, expected)
15 changes: 9 additions & 6 deletions tests/test_visit_merging.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

import pandas as pd
import pytest

Expand Down Expand Up @@ -43,7 +45,10 @@
]


@pytest.mark.parametrize("module", ["pandas", "koalas"])
@pytest.mark.parametrize(
"module",
["pandas", "koalas"],
)
@pytest.mark.parametrize(
"params, results",
[(params, results) for params, results in zip(all_params, all_results)],
Expand All @@ -53,9 +58,7 @@ def test_visit_merging(module, params, results):
results = framework.to(module, results)

vo = framework.to(module, ds.visit_occurrence)
merged = merge_visits(vo, **params)
merged = merge_visits(vo, datetime(2023, 1, 1), **params)
merged = framework.pandas(merged)

assert_equal_no_order(
merged[["visit_occurrence_id", "STAY_ID", "CONTIGUOUS_STAY_ID"]], results
)
merged = merged[results.columns]
assert_equal_no_order(merged, results, check_dtype=False)
Loading