diff --git a/notebooks/per_token_plot.ipynb b/notebooks/per_token_plot.ipynb index ae00c68c..198057c7 100644 --- a/notebooks/per_token_plot.ipynb +++ b/notebooks/per_token_plot.ipynb @@ -2,13 +2,13 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "6825ba6543ba45438b18cf0949576020", + "model_id": "696575431f65420e9dc22c3b3476bfbb", "version_major": 2, "version_minor": 0 }, @@ -16,7 +16,7 @@ "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…" ] }, - "execution_count": 1, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -47,7 +47,32 @@ " performance_data[model][cat] = (-means, err_low, err_hi)\n", "\n", "\n", - "visualize_per_token_category(performance_data)" + "visualize_per_token_category(performance_data, log_scale=True, bg_color='LightGrey', line_color=\"Red\", marker_color='Orange', bar_color='Green')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb3af5248a4a40118c36a527c927289d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "visualize_per_token_category(performance_data, log_scale=False)" ] } ], diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index 70856902..a8e269fe 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -1,10 +1,11 @@ import ipywidgets import numpy as np import plotly.graph_objects as go -from beartype.typing import Dict -def visualize_per_token_category(input: Dict[str, Dict[str, tuple]]) -> ipywidgets.VBox: +def visualize_per_token_category( + input: dict[str, dict[str, tuple]], log_scale=False, **kwargs: str +) -> ipywidgets.VBox: model_names = list(input.keys()) categories = list(input[model_names[0]].keys()) category = categories[0] @@ -24,10 +25,21 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: symmetric=False, array=err_hi, arrayminus=err_low, - color="purple", + color=kwargs.get("bar_color", "purple"), ), + marker=dict( + color=kwargs.get("marker_color", "SkyBlue"), + size=15, + line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2), + ), + ), + layout=go.Layout( + yaxis=dict( + title="Loss", + type="log" if log_scale else "linear", + ), + plot_bgcolor=kwargs.get("bg_color", "AliceBlue"), ), - layout=go.Layout(yaxis=dict(title="Loss")), ) selected_category = ipywidgets.Dropdown( @@ -38,12 +50,11 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ) def response(change): - if selected_category.value: - means, err_lo, err_hi = get_plot_values(selected_category.value) - with g.batch_update(): - g.data[0].y = means - g.data[0].error_y["array"] = err_hi - g.data[0].error_y["arrayminus"] = err_lo + means, err_lo, err_hi = get_plot_values(selected_category.value) + with g.batch_update(): + g.data[0].y = means + g.data[0].error_y["array"] = err_hi + g.data[0].error_y["arrayminus"] = err_lo selected_category.observe(response, names="value")