Skip to content

Commit

Permalink
Merge pull request #66 from MSDLLCpapers/plotting
Browse files Browse the repository at this point in the history
Minor changes to plotting - highlight the best or reference point
  • Loading branch information
xuyuting authored Sep 18, 2024
2 parents 805c9ab + d0f2c37 commit eff5652
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions obsidian/plotting/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,34 @@ def visualize_inputs(campaign: Campaign) -> Figure:
+ ['Correlation Matrix']
+ [X.columns[i] for i in range(cols, n_dim)]
)
# if campaign.optimizer is fitted, then X_best_f_idx is identified
if 'X_best_f_idx' in dir(campaign.optimizer):
marker_shapes = ['diamond' if rowInd in [campaign.optimizer.X_best_f_idx] else 'circle' for rowInd in range(campaign.X.shape[0])]
else:
marker_shapes = ['circle']*campaign.X.shape[0]

for i, param in enumerate(X.columns):
row_i = i // cols + 1
col_i = i % cols + 1
fig.add_trace(go.Scatter(x=X.index, y=X[param],
mode='markers', name=param,
marker=dict(color=color_list[i]),
marker=dict(color=color_list[i], symbol=marker_shapes),
showlegend=False),
row=row_i, col=col_i)
fig.update_xaxes(tickvals=np.around(np.linspace(0, campaign.m_exp, 5)),
row=row_i, col=col_i)

# Add note to explain the shape of markers
if hasattr(campaign.optimizer, 'X_best_f_idx'):
fig.add_annotation(
text="Note: The diamond markers denote samples that achieve the best sum of targets.",
showarrow=False,
xref="paper", yref="paper",
x=0,
y=-0.2,
font=dict(style="italic")
)

# Calculate the correlation matrix
X_u = campaign.X_space.unit_map(X)
corr_matrix = X_u.corr()
Expand Down Expand Up @@ -325,8 +341,9 @@ def factor_plot(optimizer: Optimizer,
Y_mu_ref = Y_pred_ref[y_name+('_t (pred)' if f_transform else ' (pred)')].values
fig.add_trace(go.Scatter(x=X_ref.iloc[:, feature_id].values, y=Y_mu_ref,
mode='markers',
marker=dict(symbol='diamond'),
line={'color': obsidian_colors.teal},
name='Ref'),
name='Reference'),
)
fig.update_xaxes(title_text=X_name)
fig.update_yaxes(title_text=y_name)
Expand Down Expand Up @@ -544,7 +561,19 @@ def optim_progress(campaign: Campaign,
marker=marker_dict,
customdata=campaign.data[X_names],
name='Data'))


# Highlight the best samples
if hasattr(campaign.optimizer, 'X_best_f_idx'):
fig.add_trace(go.Scatter(x=pd.Series(out_exp.iloc[campaign.optimizer.X_best_f_idx, 0]),
y=pd.Series(out_exp.iloc[campaign.optimizer.X_best_f_idx, 1]),
mode='markers',
marker=dict(symbol='diamond-open', size=14),
line={'color': 'black'},
legendgroup='marker_shape', showlegend=True,
name='Best')
)
fig.update_layout(showlegend=True)

template = ["<b>"+str(param.name)+"</b>: "+" %{customdata["+str(i)+"]"
+ (":.3G}"if isinstance(param, Param_Continuous) else "}") + "<br>"
for i, param in enumerate(campaign.X_space)]
Expand Down

0 comments on commit eff5652

Please sign in to comment.