Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
svittoz committed Apr 17, 2024
1 parent 46ace4c commit a57b167
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 29 deletions.
29 changes: 16 additions & 13 deletions eds_scikit/period/stays.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from eds_scikit.utils.checks import MissingConceptError, algo_checker, concept_checker
from eds_scikit.utils.datetime_helpers import substract_datetime
from eds_scikit.utils.framework import get_framework
from eds_scikit.utils.typing import DataFrame
from eds_scikit.utils.sort_first_koalas import sort_values_first_koalas
from eds_scikit.utils.typing import DataFrame

import pandas as pd

def cleaning(
vo,
Expand Down Expand Up @@ -71,6 +70,7 @@ def cleaning(

return vo[~mask], vo[mask]


@concept_checker(concepts=["STAY_ID", "CONTIGUOUS_STAY_ID"])
def merge_visits(
vo: DataFrame,
Expand Down Expand Up @@ -291,23 +291,26 @@ def get_first(
right_index=True,
how="inner",
)

first_visit = sort_values_first_koalas(merged,
by_cols=["visit_occurrence_id_2"],
cols=[flag_name, "visit_start_datetime_1"],
ascending=False
).rename(

first_visit = sort_values_first_koalas(
merged,
by_cols=["visit_occurrence_id_2"],
cols=[flag_name, "visit_start_datetime_1"],
ascending=False,
).rename(
columns={
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
"visit_occurrence_id_2": "visit_occurrence_id",
})[[f"{concept_prefix}STAY_ID", "visit_occurrence_id"]]


}
)[
[f"{concept_prefix}STAY_ID", "visit_occurrence_id"]
]

return merged, first_visit

merged, first_contiguous_visit = get_first(merged, contiguous_only=True)
merged, first_contiguous_visit = get_first(merged, contiguous_only=True)
merged, first_visit = get_first(merged, contiguous_only=False)

# Concatenating merge visits with previously discarded ones
results = fw.concat(
[
Expand Down
4 changes: 3 additions & 1 deletion eds_scikit/utils/sort_first_koalas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
def sort_values_first_koalas(dataframe, by_cols, cols, ascending=True):
for col in cols:
dataframe_min_max = dataframe.groupby(by_cols, as_index=False)[col]
dataframe_min_max = dataframe_min_max.min() if ascending else dataframe_min_max.max()
dataframe_min_max = (
dataframe_min_max.min() if ascending else dataframe_min_max.max()
)
dataframe = dataframe.merge(dataframe_min_max, on=[*by_cols, col], how="right")
return dataframe
61 changes: 46 additions & 15 deletions tests/test_sort_first_koalas.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,68 @@
import pandas as pd
from eds_scikit.utils.sort_first_koalas import sort_values_first_koalas
from numpy import array
import pytest
from numpy import array

from eds_scikit.utils import framework
from eds_scikit.utils.sort_first_koalas import sort_values_first_koalas
from eds_scikit.utils.test_utils import assert_equal_no_order

data = {
'A': array(['X', 'Y', 'X', 'Y', 'Y', 'Z', 'X', 'Z', 'X', 'X', 'X', 'Z', 'Y', 'Z', 'Z', 'X', 'Y', 'Y', 'Y', 'Y'], dtype='<U1'),
'B': array([9, 5, 4, 1, 4, 6, 1, 3, 4, 9, 2, 4, 4, 4, 8, 1, 2, 1, 5, 8]),
'C': array([4, 3, 8, 3, 1, 1, 5, 6, 6, 7, 9, 5, 2, 5, 9, 2, 2, 8, 4, 7]),
'D': array([8, 3, 1, 4, 6, 5, 5, 7, 5, 5, 4, 5, 5, 9, 5, 4, 8, 6, 6, 1]),
'E': array([2, 6, 4, 1, 6, 1, 2, 3, 5, 3, 1, 4, 3, 1, 8, 6, 1, 3, 8, 3])
"A": array(
[
"X",
"Y",
"X",
"Y",
"Y",
"Z",
"X",
"Z",
"X",
"X",
"X",
"Z",
"Y",
"Z",
"Z",
"X",
"Y",
"Y",
"Y",
"Y",
],
dtype="<U1",
),
"B": array([9, 5, 4, 1, 4, 6, 1, 3, 4, 9, 2, 4, 4, 4, 8, 1, 2, 1, 5, 8]),
"C": array([4, 3, 8, 3, 1, 1, 5, 6, 6, 7, 9, 5, 2, 5, 9, 2, 2, 8, 4, 7]),
"D": array([8, 3, 1, 4, 6, 5, 5, 7, 5, 5, 4, 5, 5, 9, 5, 4, 8, 6, 6, 1]),
"E": array([2, 6, 4, 1, 6, 1, 2, 3, 5, 3, 1, 4, 3, 1, 8, 6, 1, 3, 8, 3]),
}

all_inputs = [
pd.DataFrame(data)
]
all_inputs = [pd.DataFrame(data)]

all_params = dict(ascending=[True, False], cols=["A", ["A", "B"]], by_cols=[["B", "C", "D"], ["C", "D"]])
all_params = dict(
ascending=[True, False],
cols=["A", ["A", "B"]],
by_cols=[["B", "C", "D"], ["C", "D"]],
)
all_params = pd.DataFrame(all_params).to_dict("records")


@pytest.mark.parametrize("module", ["pandas", "koalas"])
@pytest.mark.parametrize(
"params, inputs",
[(params, inputs) for params, inputs in zip(all_params, all_inputs)],
)
def test_visit_merging(module, params, inputs):
print(params)
expected_results = inputs.sort_values(params["cols"]).groupby(params["by_cols"]).first().reset_index()
expected_results = (
inputs.sort_values(params["cols"])
.groupby(params["by_cols"])
.first()
.reset_index()
)
inputs = framework.to(module, inputs)
results = sort_values_first_koalas(inputs, **params)
results = framework.pandas(results)

assert_equal_no_order(
results, expected_results
)
assert_equal_no_order(results, expected_results)

0 comments on commit a57b167

Please sign in to comment.