Skip to content

Commit

Permalink
Small viz improvements (#21)
Browse files Browse the repository at this point in the history
* Update config viz

* Syntax replace not by is None

* Fix test randomness
  • Loading branch information
Aremaki authored Jul 28, 2023
1 parent 2babbc5 commit 0ca0b9c
Show file tree
Hide file tree
Showing 50 changed files with 547 additions and 495 deletions.
2 changes: 1 addition & 1 deletion edsteva/io/synthetic/note.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _generate_note_step(
):
t_end = visit_care_site[date_col].max()
t0 = generator.integers(t0_visit, t_end)
c_before = generator.uniform(0, 0.2)
c_before = generator.uniform(0, 0.01)
c_after = generator.uniform(0.8, 1)
note_before_t0_visit = (
visit_care_site[visit_care_site[date_col] <= t0_visit][[id_visit_col, date_col]]
Expand Down
6 changes: 5 additions & 1 deletion edsteva/io/synthetic/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def __post_init__(self):

def generate(self):
if self.seed:
np.random.seed(self.seed)
self.generator = np.random.default_rng(self.seed)
else:
self.generator = np.random.default_rng()
Expand Down Expand Up @@ -630,7 +631,10 @@ def _generate_measurement(
]
measurement[self.id_visit_col] = (
visit_care_site[self.id_visit_col]
.sample(n=measurement.shape[0], replace=True)
.sample(
n=measurement.shape[0],
replace=True,
)
.reset_index(drop=True)
)
measurement["value_as_number"] = [None] * missing_value + list(
Expand Down
22 changes: 21 additions & 1 deletion edsteva/models/rectangle_function/viz_configs/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def get_t_1_selection(predictor: DataFrame):


normalized_probe_line = dict(
legend_title="Mean",
encode=dict(
strokeDash=alt.StrokeDash(
"legend_predictor",
Expand All @@ -81,7 +82,7 @@ def get_t_1_selection(predictor: DataFrame):
orient="top",
),
)
)
),
)
probe_line = dict(
encode=dict(
Expand Down Expand Up @@ -155,6 +156,25 @@ def get_t_1_selection(predictor: DataFrame):
filters=[dict(filter=alt.datum.t_0 == alt.datum.max_t0)],
)

error_line = dict(
legend_title="Standard deviation",
mark_errorband=dict(
extent="stdev",
),
encode=dict(
stroke=alt.Stroke(
"legend_error_band",
title="Error band",
legend=alt.Legend(
symbolType="square",
orient="top",
labelFontSize=12,
labelFontStyle="bold",
),
),
),
)

horizontal_min_c0 = dict(
x=alt.X(
"min(c_0):Q",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from edsteva.utils.typing import DataFrame

from .defaults import (
error_line,
get_c_0_min_selection,
get_error_max_selection,
get_t_0_selection,
Expand Down Expand Up @@ -30,6 +31,7 @@ def get_normalized_probe_dashboard_config(self, predictor: DataFrame):
c_0_min_filter,
error_max_filter,
],
error_line=error_line,
probe_line=normalized_probe_line,
model_line=normalized_model_line,
extra_horizontal_bar_charts=[horizontal_min_c0, horizontal_max_error],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from edsteva.utils.typing import DataFrame

from .defaults import (
error_line,
get_c_0_min_selection,
get_error_max_selection,
get_t_0_selection,
Expand Down Expand Up @@ -28,6 +29,7 @@ def get_normalized_probe_plot_config(self, predictor: DataFrame):
c_0_min_filter,
error_max_filter,
],
probe_line=normalized_probe_line,
error_line=error_line,
model_line=normalized_model_line,
probe_line=normalized_probe_line,
)
24 changes: 23 additions & 1 deletion edsteva/models/step_function/viz_configs/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def get_t_0_selection(predictor: DataFrame):


normalized_probe_line = dict(
legend_title="Mean",
encode=dict(
strokeDash=alt.StrokeDash(
"legend_predictor",
Expand All @@ -66,8 +67,9 @@ def get_t_0_selection(predictor: DataFrame):
orient="top",
),
)
)
),
)

probe_line = dict(
encode=dict(
strokeDash=alt.StrokeDash(
Expand All @@ -83,6 +85,7 @@ def get_t_0_selection(predictor: DataFrame):
)
)
)

normalized_model_line = dict(
mark_line=dict(
color="black",
Expand Down Expand Up @@ -140,6 +143,25 @@ def get_t_0_selection(predictor: DataFrame):
filters=[dict(filter=alt.datum.t_0 == alt.datum.max_t0)],
)

error_line = dict(
legend_title="Standard deviation",
mark_errorband=dict(
extent="stdev",
),
encode=dict(
stroke=alt.Stroke(
"legend_error_band",
title="Error band",
legend=alt.Legend(
symbolType="square",
orient="top",
labelFontSize=12,
labelFontStyle="bold",
),
),
),
)

horizontal_min_c0 = dict(
x=alt.X(
"min(c_0):Q",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from edsteva.utils.typing import DataFrame

from .defaults import (
error_line,
get_c_0_min_selection,
get_error_max_selection,
get_t_0_selection,
Expand All @@ -18,6 +19,7 @@ def get_normalized_probe_dashboard_config(self, predictor: DataFrame):
return dict(
estimates_selections=[c_0_min_selection, t_0_selection, error_max_selection],
estimates_filters=[c_0_min_filter, t_0_min_filter, error_max_filter],
error_line=error_line,
probe_line=normalized_probe_line,
model_line=normalized_model_line,
extra_horizontal_bar_charts=[horizontal_min_c0, horizontal_max_error],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from edsteva.utils.typing import DataFrame

from .defaults import (
error_line,
get_c_0_min_selection,
get_error_max_selection,
get_t_0_selection,
Expand All @@ -16,6 +17,7 @@ def get_normalized_probe_plot_config(self, predictor: DataFrame):
return dict(
estimates_selections=[c_0_min_selection, t_0_selection, error_max_selection],
estimates_filters=[c_0_min_filter, t_0_min_filter, error_max_filter],
error_line=error_line,
probe_line=normalized_probe_line,
model_line=normalized_model_line,
)
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
)

normalized_main_chart = dict(
legend_title="Mean",
encode=dict(
x=alt.X(
"normalized_date:Q",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .defaults import (
chart_style,
error_line,
get_horizontal_bar_charts,
normalized_main_chart,
normalized_time_line,
Expand All @@ -16,7 +15,6 @@ def get_normalized_probe_dashboard_config(self):
chart_style=chart_style,
main_chart=normalized_main_chart,
time_line=normalized_time_line,
error_line=error_line,
vertical_bar_charts=vertical_bar_charts,
horizontal_bar_charts=horizontal_bar_charts,
)
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .defaults import chart_style, error_line, normalized_main_chart
from .defaults import chart_style, normalized_main_chart


def get_normalized_probe_plot_config(self):
return dict(
chart_style=chart_style,
main_chart=normalized_main_chart,
error_line=error_line,
)
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
),
],
calculates=[
dict(completeness=alt.datum.sum_measurement / alt.datum.max_measurement),
dict(c=alt.datum.sum_measurement / alt.datum.max_measurement),
],
encode=dict(
x=alt.X(
Expand All @@ -113,7 +113,7 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
axis=alt.Axis(tickCount="month", labelAngle=0, grid=True),
),
y=alt.Y(
"completeness:Q",
"c:Q",
title="Completeness predictor c(t)",
axis=alt.Axis(grid=True),
),
Expand All @@ -125,7 +125,7 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
tooltip=[
alt.Tooltip("value:N", title="Index"),
alt.Tooltip("yearmonth(date):T", title="Date"),
alt.Tooltip("completeness:Q", title="c(t)", format=".2f"),
alt.Tooltip("c:Q", title="c(t)", format=".2f"),
],
),
properties=dict(
Expand All @@ -135,7 +135,6 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
)

normalized_main_chart = dict(
legend_title="Mean",
encode=dict(
x=alt.X(
"normalized_date:Q",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .defaults import (
chart_style,
error_line,
get_horizontal_bar_charts,
normalized_main_chart,
normalized_time_line,
Expand All @@ -16,7 +15,6 @@ def get_normalized_probe_dashboard_config(self):
chart_style=chart_style,
main_chart=normalized_main_chart,
time_line=normalized_time_line,
error_line=error_line,
vertical_bar_charts=vertical_bar_charts,
horizontal_bar_charts=horizontal_bar_charts,
)
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .defaults import chart_style, error_line, normalized_main_chart
from .defaults import chart_style, normalized_main_chart


def get_normalized_probe_plot_config(self):
return dict(
chart_style=chart_style,
main_chart=normalized_main_chart,
error_line=error_line,
)
7 changes: 3 additions & 4 deletions edsteva/probes/biology/viz_configs/per_visit/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
),
],
calculates=[
dict(completeness=alt.datum.sum_visit_with_measurement / alt.datum.sum_visit),
dict(c=alt.datum.sum_visit_with_measurement / alt.datum.sum_visit),
],
encode=dict(
x=alt.X(
Expand All @@ -113,7 +113,7 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
axis=alt.Axis(tickCount="month", labelAngle=0, grid=True),
),
y=alt.Y(
"completeness:Q",
"c:Q",
title="Completeness predictor c(t)",
axis=alt.Axis(grid=True),
),
Expand All @@ -125,7 +125,7 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
tooltip=[
alt.Tooltip("value:N", title="Index"),
alt.Tooltip("yearmonth(date):T", title="Date"),
alt.Tooltip("completeness:Q", title="c(t)", format=".2f"),
alt.Tooltip("c:Q", title="c(t)", format=".2f"),
],
),
properties=dict(
Expand All @@ -135,7 +135,6 @@ def get_horizontal_bar_charts(standard_terminologies: List[str]):
)

normalized_main_chart = dict(
legend_title="Mean",
encode=dict(
x=alt.X(
"normalized_date:Q",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .defaults import (
chart_style,
error_line,
get_horizontal_bar_charts,
normalized_main_chart,
normalized_time_line,
Expand All @@ -16,7 +15,6 @@ def get_normalized_probe_dashboard_config(self):
chart_style=chart_style,
main_chart=normalized_main_chart,
time_line=normalized_time_line,
error_line=error_line,
vertical_bar_charts=vertical_bar_charts,
horizontal_bar_charts=horizontal_bar_charts,
)
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .defaults import chart_style, error_line, normalized_main_chart
from .defaults import chart_style, normalized_main_chart


def get_normalized_probe_plot_config(self):
return dict(
chart_style=chart_style,
main_chart=normalized_main_chart,
error_line=error_line,
)
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@
)

normalized_main_chart = dict(
legend_title="Mean",
encode=dict(
x=alt.X(
"normalized_date:Q",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .defaults import (
chart_style,
error_line,
horizontal_bar_charts,
normalized_main_chart,
normalized_time_line,
Expand All @@ -13,7 +12,6 @@ def get_normalized_probe_dashboard_config(self):
chart_style=chart_style,
main_chart=normalized_main_chart,
time_line=normalized_time_line,
error_line=error_line,
vertical_bar_charts=vertical_bar_charts,
horizontal_bar_charts=horizontal_bar_charts,
)
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .defaults import chart_style, error_line, normalized_main_chart
from .defaults import chart_style, normalized_main_chart


def get_normalized_probe_plot_config(self):
return dict(
chart_style=chart_style,
main_chart=normalized_main_chart,
error_line=error_line,
)
Loading

0 comments on commit 0ca0b9c

Please sign in to comment.