From d9a0aac700875284b908e63848677adcefa62262 Mon Sep 17 00:00:00 2001 From: joaopfonseca Date: Sat, 2 Mar 2024 01:06:13 +0000 Subject: [PATCH] FIX data displayed in waterfall plot (#32) --- examples/plot_basic_usage.py | 24 +++++++++++++++++++----- sharp/utils/_checks.py | 2 +- sharp/visualization/_visualization.py | 8 ++++---- sharp/visualization/_waterfall.py | 24 +++++++++++++++--------- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/examples/plot_basic_usage.py b/examples/plot_basic_usage.py index cc5a777..3437ac1 100644 --- a/examples/plot_basic_usage.py +++ b/examples/plot_basic_usage.py @@ -12,6 +12,7 @@ import matplotlib.pyplot as plt from sklearn.utils import check_random_state from sharp import ShaRP +from sharp.utils import scores_to_ordering # Set up some envrionment variables RNG_SEED = 42 @@ -31,6 +32,7 @@ def score_function(X): [rng.normal(size=(N_SAMPLES, 1)), rng.binomial(1, 0.5, size=(N_SAMPLES, 1))], axis=1 ) y = score_function(X) +rank = scores_to_ordering(y) ###################################################################################### @@ -65,17 +67,29 @@ def score_function(X): ###################################################################################### # We can also turn these into visualizations: +plt.style.use("seaborn-v0_8-whitegrid") + # Visualization of feature contributions print("Sample 2 feature values:", X[2]) print("Sample 3 feature values:", X[3]) -fig, axes = plt.subplots(1, 2) +fig, axes = plt.subplots(1, 2, figsize=(13.5, 4.5), layout="constrained") # Bar plot comparing two points -xai.plot.bar(pair_scores, ax=axes[0]) -axes[0].set_title("Pairwise comparison (Sample 2 vs 3)") +xai.plot.bar(pair_scores, ax=axes[0], color="#ff0051") +axes[0].set_title( + f"Pairwise comparison - Sample 2 (rank {rank[2]}) vs 3 (rank {rank[3]})", + fontsize=12, + y=-0.2, +) +axes[0].set_xlabel("") +axes[0].set_ylabel("Contribution to rank", fontsize=12) +axes[0].tick_params(axis="both", which="major", labelsize=12) # Waterfall explaining rank for sample 2 -axes[1] = xai.plot.waterfall(individual_scores) -axes[1].suptitle("Rank explanation for Sample 9") +axes[1] = xai.plot.waterfall( + individual_scores, feature_values=X[9], mean_target_value=rank.mean() +) +ax = axes[1].gca() +ax.set_title("Rank explanation for Sample 9", fontsize=12, y=-0.2) plt.show() diff --git a/sharp/utils/_checks.py b/sharp/utils/_checks.py index da9de4a..822b0c3 100644 --- a/sharp/utils/_checks.py +++ b/sharp/utils/_checks.py @@ -12,7 +12,7 @@ def check_feature_names(X): feature_names = _get_feature_names(X) if feature_names is None: - feature_names = np.indices([X.shape[1]]).squeeze() + feature_names = np.array([f"Feature {i}" for i in range(X.shape[1])]) return feature_names diff --git a/sharp/visualization/_visualization.py b/sharp/visualization/_visualization.py index 9fc94bc..b3ce9d5 100644 --- a/sharp/visualization/_visualization.py +++ b/sharp/visualization/_visualization.py @@ -28,7 +28,7 @@ def bar(self, scores, ax=None, **kwargs): return ax - def waterfall(self, scores, mean_shapley_value=0): + def waterfall(self, contributions, feature_values=None, mean_target_value=0): """ TODO: refactor waterfall plot code. """ @@ -38,11 +38,11 @@ def waterfall(self, scores, mean_shapley_value=0): rank_dict = { "upper_bounds": None, "lower_bounds": None, - "features": None, # pd.Series(feature_names), + "features": feature_values, # pd.Series(feature_names), "data": None, # pd.Series(ind_values, index=feature_names), - "base_values": mean_shapley_value, + "base_values": mean_target_value, "feature_names": feature_names, - "values": pd.Series(scores, index=feature_names), + "values": pd.Series(contributions, index=feature_names), } return _waterfall(rank_dict, max_display=10) diff --git a/sharp/visualization/_waterfall.py b/sharp/visualization/_waterfall.py index d7ae69a..676f65b 100644 --- a/sharp/visualization/_waterfall.py +++ b/sharp/visualization/_waterfall.py @@ -30,7 +30,11 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa plt.ioff() base_values = float(shap_values["base_values"]) - features = shap_values["values"] + features = ( + np.array(shap_values["features"]) + if shap_values["features"] is not None + else np.array(shap_values["values"]) + ) feature_names = shap_values["feature_names"] # lower_bounds = shap_values["lower_bounds"] # upper_bounds = shap_values["upper_bounds"] @@ -38,7 +42,7 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa # init variables we use for tracking the plot locations num_features = min(max_display, len(values)) - row_height = 0.5 + # row_height = 0.5 rng = range(num_features - 1, -1, -1) order = np.argsort(-np.abs(values)) pos_lefts = [] @@ -55,7 +59,7 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa yticklabels = ["" for _ in range(num_features + 1)] # size the plot based on how many features we are plotting - plt.gcf().set_size_inches(8, num_features * row_height + 1.5) + # plt.gcf().set_size_inches(8, num_features * row_height + 1.5) # see how many individual (vs. grouped at the end) features we are plotting if num_features == len(values): @@ -66,7 +70,7 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa # compute the locations of the individual features and plot the dashed connecting # lines for i in range(num_individual): - sval = values[order[i]] + sval = values.iloc[order.iloc[i]] loc -= sval if sval >= 0: pos_inds.append(rng[i]) @@ -92,17 +96,19 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa zorder=-1, ) if features is None: - yticklabels[rng[i]] = feature_names[order[i]] + yticklabels[rng[i]] = feature_names[order.iloc[i]] else: - if np.issubdtype(type(features[order[i]]), np.number): + if np.issubdtype(type(features[order.iloc[i]]), np.number): yticklabels[rng[i]] = ( - format_value(float(features[order[i]]), "%0.03f") + format_value(float(features[order.iloc[i]]), "%0.03f") + " = " - + feature_names[order[i]] + + feature_names[order.iloc[i]] ) else: yticklabels[rng[i]] = ( - str(features[order[i]]) + " = " + str(feature_names[order[i]]) + str(features[order.iloc[i]]) + + " = " + + str(feature_names[order.iloc[i]]) ) # add a last grouped feature to represent the impact of all the features we didn't