Skip to content

Commit

Permalink
Merge pull request #40 from cirKITers/training-reset
Browse files Browse the repository at this point in the history
Training reset
  • Loading branch information
majafranz authored Nov 12, 2024
2 parents 43111c4 + 2f8dcb0 commit 54a7d01
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 109 deletions.
5 changes: 4 additions & 1 deletion app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,7 @@ def on_preference_changed(

logging.info("(Re-)launching Application..")

app.run(host="0.0.0.0", port="8050", threaded=True, debug="--debug" in args)
try:
app.run(host="0.0.0.0", port="8050", threaded=True, debug="--debug" in args)
except Exception as e:
logging.error(e)
171 changes: 63 additions & 108 deletions app/pages/1-training.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@
dcc.Store(id="training-page-storage", storage_type="session"),
dcc.Store(id="training-log-storage", storage_type="session"),
dcc.Store(id="training-log-hist-storage", storage_type="session"),
dcc.Interval(
id="interval-component",
interval=1 * 1000, # in milliseconds
n_intervals=0,
),
html.Div(
[
html.Div(
Expand Down Expand Up @@ -162,7 +157,6 @@
dbc.Button(
"Start Training",
id="training-start-button",
disabled="true",
),
],
),
Expand Down Expand Up @@ -228,38 +222,27 @@
)


@callback(
[
Output("training-page-storage", "data", allow_duplicate=True),
Output("training-start-button", "disabled", allow_duplicate=True),
],
Input("main-storage", "modified_timestamp"),
State("main-storage", "data"),
State("training-page-storage", "data"),
prevent_initial_call=True,
)
def update_page_data(_, main_data, page_data):
if main_data["circuit_type"] is None or main_data["circuit_type"] == "No_Ansatz":
return page_data, True

return page_data, False


@callback(
[
Output("training-page-storage", "data"),
Output("training-log-storage", "data", allow_duplicate=True),
Output("training-log-hist-storage", "data"),
Output("training-start-button", "children", allow_duplicate=True),
],
[
Input("main-storage", "modified_timestamp"),
Input("training-bit-flip-prob-slider", "value"),
Input("training-phase-flip-prob-slider", "value"),
Input("training-amplitude-damping-prob-slider", "value"),
Input("training-phase-damping-prob-slider", "value"),
Input("training-depolarization-prob-slider", "value"),
Input("training-steps-numeric-input", "value"),
Input("training-start-button", "n_clicks"),
],
State("training-start-button", "children"),
prevent_initial_call="initial_duplicate",
)
def on_preference_changed(bf, pf, ad, pd, dp, steps):
def on_preference_changed(_, bf, pf, ad, pd, dp, steps, n, state):

# Give a default data dict with 0 clicks if there's no data.
# page_data = dict(bf=bf, pf=pf, ad=ad, pd=pd, dp=dp, steps=steps)
Expand All @@ -272,10 +255,65 @@ def on_preference_changed(bf, pf, ad, pd, dp, steps):
"Depolarization": dp,
},
"steps": steps,
"running": state != "Reset Training",
}
page_log = {"loss": [], "params": [], "ent_cap": []}
page_log_hist = {"x": [], "y": [], "z": []}

return page_data, page_log_hist
if state == "Reset Training":
return [page_data, page_log, page_log_hist, "Start Training"]
else:
return [page_data, page_log, page_log_hist, "Reset Training"]


@callback(
Output("training-metric-figure", "figure"),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def update_loss(n, page_log_training, data):
fig_expval = go.Figure()
if (
page_log_training is not None
and len(page_log_training["loss"]) > 0
and data is not None
):
fig_expval.add_scatter(y=page_log_training["loss"])

fig_expval.update_layout(
title="Loss",
template="simple_white",
xaxis_title="Step",
yaxis_title="Loss",
xaxis_range=[0, data["steps"] if data is not None else DEFAULT_N_STEPS],
autosize=False,
)

return fig_expval


@callback(
Output("training-log-storage", "data", allow_duplicate=True),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def pong(_, page_log_training, page_data):
if (
page_log_training is None
or page_data is None
or len(page_log_training["loss"]) > page_data["steps"]
or not page_data["running"]
):
raise PreventUpdate()
return page_log_training


@callback(
Expand Down Expand Up @@ -422,89 +460,6 @@ def update_ent_cap(n, page_log_training, data):
return fig_ent_cap


@callback(
Output("training-metric-figure", "figure"),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def update_loss(n, page_log_training, data):
fig_expval = go.Figure()
if (
page_log_training is not None
and len(page_log_training["loss"]) > 0
and data is not None
):
fig_expval.add_scatter(y=page_log_training["loss"])

fig_expval.update_layout(
title="Loss",
template="simple_white",
xaxis_title="Step",
yaxis_title="Loss",
xaxis_range=[0, data["steps"] if data is not None else DEFAULT_N_STEPS],
autosize=False,
)

return fig_expval


@callback(
[
Output("training-log-storage", "data", allow_duplicate=True),
Output("training-start-button", "disabled", allow_duplicate=True),
],
Input("training-start-button", "n_clicks"),
prevent_initial_call=True,
)
def trigger_training(_):
page_log = {"loss": [], "params": [], "ent_cap": []}

return [page_log, True]


@callback(
Output("training-start-button", "disabled", allow_duplicate=True),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def stop_training(_, page_log_training, page_data):
if (
page_log_training is not None
and page_data is not None
and len(page_log_training["loss"]) <= page_data["steps"]
):
raise PreventUpdate()

return False


@callback(
Output("training-log-storage", "data", allow_duplicate=True),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def pong(_, page_log_training, page_data):
if (
page_log_training is None
or page_data is None
or len(page_log_training["loss"]) > page_data["steps"]
):
raise PreventUpdate()
return page_log_training


@callback(
Output("training-log-storage", "data"),
[
Expand Down

0 comments on commit 54a7d01

Please sign in to comment.