♻️ Refactor webui to live render results #29

Merged · 1 commit merged on Nov 27, 2024
aide/webui/app.py — 185 changes: 76 additions & 109 deletions
@@ -100,8 +100,6 @@ def run(self):
input_col, results_col = st.columns([1, 3])
with input_col:
self.render_input_section(results_col)
with results_col:
self.render_results_section()

def render_sidebar(self):
"""
@@ -273,17 +271,46 @@ def run_aide(self, files, goal_text, eval_text, num_steps, results_col):
return None

experiment = self.initialize_experiment(input_dir, goal_text, eval_text)
placeholders = self.create_results_placeholders(results_col, experiment)

# Create separate placeholders for progress and config
progress_placeholder = results_col.empty()
config_placeholder = results_col.empty()
results_placeholder = results_col.empty()

for step in range(num_steps):
st.session_state.current_step = step + 1
progress = (step + 1) / num_steps
self.update_results_placeholders(placeholders, progress)

# Update progress
with progress_placeholder.container():
st.markdown(
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
)
st.progress(progress)

# Show config only for first step
if step == 0:
with config_placeholder.container():
st.markdown("### 📋 Configuration")
st.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")

experiment.run(steps=1)

self.clear_run_state(placeholders)
# Show results
with results_placeholder.container():
self.render_live_results(experiment)

# Clear config after first step
if step == 0:
config_placeholder.empty()

return self.collect_results(experiment)
# Clear progress after all steps
progress_placeholder.empty()

# Update session state
st.session_state.is_running = False
st.session_state.results = self.collect_results(experiment)
return st.session_state.results

except Exception as e:
st.session_state.is_running = False
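
The rewritten loop above drives live rendering purely through Streamlit placeholders: a progress placeholder and a results placeholder are created once from `results_col`, then rewritten in place on every step instead of appending to the page. Below is a minimal standalone sketch of that pattern, assuming only `streamlit`; `do_one_step()` is a hypothetical stand-in for `experiment.run(steps=1)`.

```python
# Minimal sketch of the placeholder-based live-render loop, assuming only streamlit.
# do_one_step() is a hypothetical stand-in for experiment.run(steps=1).
import time

import streamlit as st


def do_one_step(step: int) -> str:
    """Fake one unit of work so the live updates are visible."""
    time.sleep(0.5)
    return f"result of step {step}"


progress_placeholder = st.empty()
results_placeholder = st.empty()

num_steps = 5
for step in range(num_steps):
    # Re-render the progress header and bar in place on every iteration.
    with progress_placeholder.container():
        st.markdown(f"### Running Step {step + 1}/{num_steps}")
        st.progress((step + 1) / num_steps)

    result = do_one_step(step + 1)

    # Overwrite the previous step's output instead of stacking new elements.
    with results_placeholder.container():
        st.write(result)

# Clear the progress area once all steps are done, as the PR does.
progress_placeholder.empty()
```
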
@@ -355,70 +382,6 @@ def initialize_experiment(input_dir, goal_text, eval_text):
experiment = Experiment(data_dir=str(input_dir), goal=goal_text, eval=eval_text)
return experiment

@staticmethod
def create_results_placeholders(results_col, experiment):
"""
Create placeholders in the results column for dynamic content.

Args:
results_col (st.delta_generator.DeltaGenerator): The results column.
experiment (Experiment): The Experiment object.

Returns:
dict: Dictionary of placeholders.
"""
with results_col:
status_placeholder = st.empty()
step_placeholder = st.empty()
config_title_placeholder = st.empty()
config_placeholder = st.empty()
progress_placeholder = st.empty()

step_placeholder.markdown(
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
)
config_title_placeholder.markdown("### 📋 Configuration")
config_placeholder.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")
progress_placeholder.progress(0)

placeholders = {
"status": status_placeholder,
"step": step_placeholder,
"config_title": config_title_placeholder,
"config": config_placeholder,
"progress": progress_placeholder,
}
return placeholders

@staticmethod
def update_results_placeholders(placeholders, progress):
"""
Update the placeholders with the current progress.

Args:
placeholders (dict): Dictionary of placeholders.
progress (float): Current progress value.
"""
placeholders["step"].markdown(
f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
)
placeholders["progress"].progress(progress)

@staticmethod
def clear_run_state(placeholders):
"""
Clear the running state and placeholders after the experiment.

Args:
placeholders (dict): Dictionary of placeholders.
"""
st.session_state.is_running = False
placeholders["status"].empty()
placeholders["step"].empty()
placeholders["config_title"].empty()
placeholders["config"].empty()
placeholders["progress"].empty()

@staticmethod
def collect_results(experiment):
"""
@@ -454,41 +417,6 @@ def collect_results(experiment):
}
return results

def render_results_section(self):
"""
Render the results section with tabs for different outputs.
"""
st.header("Results")
if st.session_state.get("results"):
results = st.session_state.results

tabs = st.tabs(
[
"Tree Visualization",
"Best Solution",
"Config",
"Journal",
"Validation Plot",
]
)

with tabs[0]:
self.render_tree_visualization(results)
with tabs[1]:
self.render_best_solution(results)
with tabs[2]:
self.render_config(results)
with tabs[3]:
self.render_journal(results)
with tabs[4]:
# Display best score before the plot
best_metric = self.get_best_metric(results)
if best_metric is not None:
st.metric("Best Validation Score", f"{best_metric:.4f}")
self.render_validation_plot(results)
else:
st.info("No results to display. Please run an experiment.")

@staticmethod
def render_tree_visualization(results):
"""
@@ -576,9 +504,13 @@ def get_best_metric(results):
return None

@staticmethod
def render_validation_plot(results):
def render_validation_plot(results, step):
"""
Render the validation score plot.

Args:
results (dict): The results dictionary
step (int): Current step number for unique key generation
"""
try:
journal_data = json.loads(results["journal"])
@@ -619,12 +551,47 @@ def render_validation_plot(results):
paper_bgcolor="rgba(0,0,0,0)",
)

st.plotly_chart(fig, use_container_width=True)
# Only keep the key for plotly_chart
st.plotly_chart(fig, use_container_width=True, key=f"plot_{step}")
else:
st.info("No validation metrics available to plot.")
st.info("No validation metrics available to plot")

except (json.JSONDecodeError, KeyError):
st.error("Could not parse validation metrics data.")
st.error("Could not parse validation metrics data")

def render_live_results(self, experiment):
"""
Render live results.

Args:
experiment (Experiment): The Experiment object
"""
results = self.collect_results(experiment)

# Create tabs for different result views
tabs = st.tabs(
[
"Tree Visualization",
"Best Solution",
"Config",
"Journal",
"Validation Plot",
]
)

with tabs[0]:
self.render_tree_visualization(results)
with tabs[1]:
self.render_best_solution(results)
with tabs[2]:
self.render_config(results)
with tabs[3]:
self.render_journal(results)
with tabs[4]:
best_metric = self.get_best_metric(results)
if best_metric is not None:
st.metric("Best Validation Score", f"{best_metric:.4f}")
self.render_validation_plot(results, step=st.session_state.current_step)


if __name__ == "__main__":
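
The `step` argument added to `render_validation_plot` exists only to build a unique `key` for `st.plotly_chart`, since the chart is recreated inside the results placeholder on every step. A minimal sketch of that usage follows, assuming `streamlit` and `plotly`; the scores are illustrative only, not taken from a real run.

```python
# Minimal sketch of per-step chart keys, assuming streamlit and plotly.
# The scores are illustrative; the real app derives them from the journal.
import plotly.graph_objects as go
import streamlit as st

placeholder = st.empty()
scores = [0.61, 0.68, 0.74]

for step in range(1, len(scores) + 1):
    fig = go.Figure(
        go.Scatter(y=scores[:step], mode="lines+markers", name="Validation Score")
    )
    # A per-step key keeps each rendered chart distinct when the plot is
    # recreated inside the same placeholder on every iteration.
    with placeholder.container():
        st.plotly_chart(fig, use_container_width=True, key=f"plot_{step}")
```
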