Skip to content

Commit

Permalink
FIX data displayed in waterfall plot (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
joaopfonseca committed Mar 3, 2024
1 parent f286725 commit d9a0aac
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 19 deletions.
24 changes: 19 additions & 5 deletions examples/plot_basic_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import matplotlib.pyplot as plt
from sklearn.utils import check_random_state
from sharp import ShaRP
from sharp.utils import scores_to_ordering

# Set up some envrionment variables
RNG_SEED = 42
Expand All @@ -31,6 +32,7 @@ def score_function(X):
[rng.normal(size=(N_SAMPLES, 1)), rng.binomial(1, 0.5, size=(N_SAMPLES, 1))], axis=1
)
y = score_function(X)
rank = scores_to_ordering(y)


######################################################################################
Expand Down Expand Up @@ -65,17 +67,29 @@ def score_function(X):
######################################################################################
# We can also turn these into visualizations:

plt.style.use("seaborn-v0_8-whitegrid")

# Visualization of feature contributions
print("Sample 2 feature values:", X[2])
print("Sample 3 feature values:", X[3])
fig, axes = plt.subplots(1, 2)
fig, axes = plt.subplots(1, 2, figsize=(13.5, 4.5), layout="constrained")

# Bar plot comparing two points
xai.plot.bar(pair_scores, ax=axes[0])
axes[0].set_title("Pairwise comparison (Sample 2 vs 3)")
xai.plot.bar(pair_scores, ax=axes[0], color="#ff0051")
axes[0].set_title(
f"Pairwise comparison - Sample 2 (rank {rank[2]}) vs 3 (rank {rank[3]})",
fontsize=12,
y=-0.2,
)
axes[0].set_xlabel("")
axes[0].set_ylabel("Contribution to rank", fontsize=12)
axes[0].tick_params(axis="both", which="major", labelsize=12)

# Waterfall explaining rank for sample 2
axes[1] = xai.plot.waterfall(individual_scores)
axes[1].suptitle("Rank explanation for Sample 9")
axes[1] = xai.plot.waterfall(
individual_scores, feature_values=X[9], mean_target_value=rank.mean()
)
ax = axes[1].gca()
ax.set_title("Rank explanation for Sample 9", fontsize=12, y=-0.2)

plt.show()
2 changes: 1 addition & 1 deletion sharp/utils/_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def check_feature_names(X):
feature_names = _get_feature_names(X)

if feature_names is None:
feature_names = np.indices([X.shape[1]]).squeeze()
feature_names = np.array([f"Feature {i}" for i in range(X.shape[1])])

return feature_names

Expand Down
8 changes: 4 additions & 4 deletions sharp/visualization/_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def bar(self, scores, ax=None, **kwargs):

return ax

def waterfall(self, scores, mean_shapley_value=0):
def waterfall(self, contributions, feature_values=None, mean_target_value=0):
"""
TODO: refactor waterfall plot code.
"""
Expand All @@ -38,11 +38,11 @@ def waterfall(self, scores, mean_shapley_value=0):
rank_dict = {
"upper_bounds": None,
"lower_bounds": None,
"features": None, # pd.Series(feature_names),
"features": feature_values, # pd.Series(feature_names),
"data": None, # pd.Series(ind_values, index=feature_names),
"base_values": mean_shapley_value,
"base_values": mean_target_value,
"feature_names": feature_names,
"values": pd.Series(scores, index=feature_names),
"values": pd.Series(contributions, index=feature_names),
}
return _waterfall(rank_dict, max_display=10)

Expand Down
24 changes: 15 additions & 9 deletions sharp/visualization/_waterfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa
plt.ioff()

base_values = float(shap_values["base_values"])
features = shap_values["values"]
features = (
np.array(shap_values["features"])
if shap_values["features"] is not None
else np.array(shap_values["values"])
)
feature_names = shap_values["feature_names"]
# lower_bounds = shap_values["lower_bounds"]
# upper_bounds = shap_values["upper_bounds"]
values = shap_values["values"]

# init variables we use for tracking the plot locations
num_features = min(max_display, len(values))
row_height = 0.5
# row_height = 0.5
rng = range(num_features - 1, -1, -1)
order = np.argsort(-np.abs(values))
pos_lefts = []
Expand All @@ -55,7 +59,7 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa
yticklabels = ["" for _ in range(num_features + 1)]

# size the plot based on how many features we are plotting
plt.gcf().set_size_inches(8, num_features * row_height + 1.5)
# plt.gcf().set_size_inches(8, num_features * row_height + 1.5)

# see how many individual (vs. grouped at the end) features we are plotting
if num_features == len(values):
Expand All @@ -66,7 +70,7 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa
# compute the locations of the individual features and plot the dashed connecting
# lines
for i in range(num_individual):
sval = values[order[i]]
sval = values.iloc[order.iloc[i]]
loc -= sval
if sval >= 0:
pos_inds.append(rng[i])
Expand All @@ -92,17 +96,19 @@ def _waterfall(shap_values, max_display=10, show=False): # noqa
zorder=-1,
)
if features is None:
yticklabels[rng[i]] = feature_names[order[i]]
yticklabels[rng[i]] = feature_names[order.iloc[i]]
else:
if np.issubdtype(type(features[order[i]]), np.number):
if np.issubdtype(type(features[order.iloc[i]]), np.number):
yticklabels[rng[i]] = (
format_value(float(features[order[i]]), "%0.03f")
format_value(float(features[order.iloc[i]]), "%0.03f")
+ " = "
+ feature_names[order[i]]
+ feature_names[order.iloc[i]]
)
else:
yticklabels[rng[i]] = (
str(features[order[i]]) + " = " + str(feature_names[order[i]])
str(features[order.iloc[i]])
+ " = "
+ str(feature_names[order.iloc[i]])
)

# add a last grouped feature to represent the impact of all the features we didn't
Expand Down

0 comments on commit d9a0aac

Please sign in to comment.