Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Training reset #40

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,7 @@ def on_preference_changed(

logging.info("(Re-)launching Application..")

app.run(host="0.0.0.0", port="8050", threaded=True, debug="--debug" in args)
try:
app.run(host="0.0.0.0", port="8050", threaded=True, debug="--debug" in args)
except Exception as e:
logging.error(e)
171 changes: 63 additions & 108 deletions app/pages/1-training.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@
dcc.Store(id="training-page-storage", storage_type="session"),
dcc.Store(id="training-log-storage", storage_type="session"),
dcc.Store(id="training-log-hist-storage", storage_type="session"),
dcc.Interval(
id="interval-component",
interval=1 * 1000, # in milliseconds
n_intervals=0,
),
html.Div(
[
html.Div(
Expand Down Expand Up @@ -162,7 +157,6 @@
dbc.Button(
"Start Training",
id="training-start-button",
disabled="true",
),
],
),
Expand Down Expand Up @@ -228,38 +222,27 @@
)


@callback(
[
Output("training-page-storage", "data", allow_duplicate=True),
Output("training-start-button", "disabled", allow_duplicate=True),
],
Input("main-storage", "modified_timestamp"),
State("main-storage", "data"),
State("training-page-storage", "data"),
prevent_initial_call=True,
)
def update_page_data(_, main_data, page_data):
if main_data["circuit_type"] is None or main_data["circuit_type"] == "No_Ansatz":
return page_data, True

return page_data, False


@callback(
[
Output("training-page-storage", "data"),
Output("training-log-storage", "data", allow_duplicate=True),
Output("training-log-hist-storage", "data"),
Output("training-start-button", "children", allow_duplicate=True),
],
[
Input("main-storage", "modified_timestamp"),
Input("training-bit-flip-prob-slider", "value"),
Input("training-phase-flip-prob-slider", "value"),
Input("training-amplitude-damping-prob-slider", "value"),
Input("training-phase-damping-prob-slider", "value"),
Input("training-depolarization-prob-slider", "value"),
Input("training-steps-numeric-input", "value"),
Input("training-start-button", "n_clicks"),
],
State("training-start-button", "children"),
prevent_initial_call="initial_duplicate",
)
def on_preference_changed(bf, pf, ad, pd, dp, steps):
def on_preference_changed(_, bf, pf, ad, pd, dp, steps, n, state):

# Give a default data dict with 0 clicks if there's no data.
# page_data = dict(bf=bf, pf=pf, ad=ad, pd=pd, dp=dp, steps=steps)
Expand All @@ -272,10 +255,65 @@ def on_preference_changed(bf, pf, ad, pd, dp, steps):
"Depolarization": dp,
},
"steps": steps,
"running": state != "Reset Training",
}
page_log = {"loss": [], "params": [], "ent_cap": []}
page_log_hist = {"x": [], "y": [], "z": []}

return page_data, page_log_hist
if state == "Reset Training":
return [page_data, page_log, page_log_hist, "Start Training"]
else:
return [page_data, page_log, page_log_hist, "Reset Training"]


@callback(
Output("training-metric-figure", "figure"),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def update_loss(n, page_log_training, data):
fig_expval = go.Figure()
if (
page_log_training is not None
and len(page_log_training["loss"]) > 0
and data is not None
):
fig_expval.add_scatter(y=page_log_training["loss"])

fig_expval.update_layout(
title="Loss",
template="simple_white",
xaxis_title="Step",
yaxis_title="Loss",
xaxis_range=[0, data["steps"] if data is not None else DEFAULT_N_STEPS],
autosize=False,
)

return fig_expval


@callback(
Output("training-log-storage", "data", allow_duplicate=True),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def pong(_, page_log_training, page_data):
if (
page_log_training is None
or page_data is None
or len(page_log_training["loss"]) > page_data["steps"]
or not page_data["running"]
):
raise PreventUpdate()
return page_log_training


@callback(
Expand Down Expand Up @@ -422,89 +460,6 @@ def update_ent_cap(n, page_log_training, data):
return fig_ent_cap


@callback(
Output("training-metric-figure", "figure"),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def update_loss(n, page_log_training, data):
fig_expval = go.Figure()
if (
page_log_training is not None
and len(page_log_training["loss"]) > 0
and data is not None
):
fig_expval.add_scatter(y=page_log_training["loss"])

fig_expval.update_layout(
title="Loss",
template="simple_white",
xaxis_title="Step",
yaxis_title="Loss",
xaxis_range=[0, data["steps"] if data is not None else DEFAULT_N_STEPS],
autosize=False,
)

return fig_expval


@callback(
[
Output("training-log-storage", "data", allow_duplicate=True),
Output("training-start-button", "disabled", allow_duplicate=True),
],
Input("training-start-button", "n_clicks"),
prevent_initial_call=True,
)
def trigger_training(_):
page_log = {"loss": [], "params": [], "ent_cap": []}

return [page_log, True]


@callback(
Output("training-start-button", "disabled", allow_duplicate=True),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def stop_training(_, page_log_training, page_data):
if (
page_log_training is not None
and page_data is not None
and len(page_log_training["loss"]) <= page_data["steps"]
):
raise PreventUpdate()

return False


@callback(
Output("training-log-storage", "data", allow_duplicate=True),
Input("training-log-storage", "modified_timestamp"),
[
State("training-log-storage", "data"),
State("training-page-storage", "data"),
],
prevent_initial_call=True,
)
def pong(_, page_log_training, page_data):
if (
page_log_training is None
or page_data is None
or len(page_log_training["loss"]) > page_data["steps"]
):
raise PreventUpdate()
return page_log_training


@callback(
Output("training-log-storage", "data"),
[
Expand Down