Skip to content

Commit

Permalink
Ensure prediction explanations are json-serializable (#3262)
Browse files Browse the repository at this point in the history
* Add extra encoding to prediction explanations
  • Loading branch information
eccabay authored Jan 20, 2022
1 parent 992b877 commit a8d0dce
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Removed empty cell in text_input.ipynb :pr:`3234`
* Removed potential prediction explanations failure when pipelines predicted a class with probability 1 :pr:`3221`
* Dropped NaNs before partial dependence grid generation :pr:`3235`
* Allowed prediction explanations to be json-serializable :pr:`3262`
* Fixed bug where ``InvalidTargetDataCheck`` would not check time series regression targets :pr:`3251`
* Fixed bug in ``are_datasets_separated_by_gap_time_index`` :pr:`3256`
* Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def _make_rows(
feature_value = "{:.2f}".format(feature_value)
else:
feature_value = str(feature_value)

feature_value = _make_json_serializable(feature_value)

row = [feature_name, feature_value, display_text]
if include_explainer_values:
explainer_value = explainer_values[feature_name][0]
Expand Down Expand Up @@ -117,6 +120,8 @@ def _make_json_serializable(value):
value = int(value)
else:
value = float(value)
elif isinstance(value, pd.Timestamp):
value = str(value)

return value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,7 @@ def test_categories_aggregated_text(
"CUC",
"Mastercard",
24900,
pd.Timestamp("2019-01-01 00:12:26"),
str(pd.Timestamp("2019-01-01 00:12:26")),
}
assert explanation["drill_down"].keys() == {"currency", "provider", "datetime"}
assert (
Expand Down Expand Up @@ -1368,7 +1368,7 @@ def test_categories_aggregated_date_ohe(
"datetime",
}
assert set(explanation["feature_values"]) == {
pd.Timestamp("2019-01-01 00:12:26"),
str(pd.Timestamp("2019-01-01 00:12:26")),
"Mastercard",
"CUC",
24900,
Expand Down Expand Up @@ -1442,7 +1442,11 @@ def test_categories_aggregated_pca_dag(
assert all(
[
f in explanation["feature_values"]
for f in [pd.Timestamp("2019-01-01 00:12:26"), "Mastercard", "CUC"]
for f in [
str(pd.Timestamp("2019-01-01 00:12:26")),
"Mastercard",
"CUC",
]
]
)
assert explanation["drill_down"].keys() == {"currency", "provider", "datetime"}
Expand Down Expand Up @@ -1567,7 +1571,7 @@ def test_categories_aggregated_when_some_are_dropped(
"CUC",
"Mastercard",
24900,
pd.Timestamp("2019-01-01 00:12:26"),
str(pd.Timestamp("2019-01-01 00:12:26")),
}
assert explanation["drill_down"].keys() == {"currency", "provider", "datetime"}
assert (
Expand Down Expand Up @@ -2043,7 +2047,10 @@ def test_explain_predictions_report_shows_original_value_if_possible(
top_k_features=20,
algorithm=algorithm,
)
expected_feature_values = set(X.ww.iloc[0, :].tolist())
X_dt = X.copy()
X_dt.ww.init()
X_dt["datetime"] = X_dt["datetime"].astype(str)
expected_feature_values = set(X_dt.ww.iloc[0, :].tolist())
for explanation in report["explanations"][0]["explanations"]:
assert set(explanation["feature_names"]) == set(X.columns)
assert set(explanation["feature_values"]) == expected_feature_values
Expand Down Expand Up @@ -2106,11 +2113,14 @@ def test_explain_predictions_best_worst_report_shows_original_value_if_possible(
algorithm=algorithm,
)

X_dt = X.copy()
X_dt.ww.init()
X_dt["datetime"] = X_dt["datetime"].astype(str)
for index, explanation in enumerate(report["explanations"]):
for exp in explanation["explanations"]:
assert set(exp["feature_names"]) == set(X.columns)
assert set(exp["feature_values"]) == set(
X.ww.iloc[explanation["predicted_values"]["index_id"], :]
X_dt.ww.iloc[explanation["predicted_values"]["index_id"], :]
)

X_null = X.ww.copy()
Expand All @@ -2136,6 +2146,30 @@ def test_explain_predictions_best_worst_report_shows_original_value_if_possible(
assert np.isnan(feature_value)


@pytest.mark.parametrize("algorithm", algorithms)
def test_explain_predictions_best_worst_json(
    algorithm, fraud_100, has_minimal_dependencies
):
    """Verify the best/worst explanation report (dict output) round-trips through json.dumps."""
    # lime is an optional extra; skip when only core dependencies are installed
    if has_minimal_dependencies and algorithm == "lime":
        pytest.skip("Skipping because lime is a non-core dependency")

    binary_pipeline = BinaryClassificationPipeline(
        [
            "Natural Language Featurizer",
            "DateTime Featurizer",
            "One Hot Encoder",
            "Logistic Regression Classifier",
        ]
    )
    features, target = fraud_100
    binary_pipeline.fit(features, target)

    best_worst_report = explain_predictions_best_worst(
        binary_pipeline, features, target, algorithm=algorithm, output_format="dict"
    )
    # json.dumps raises TypeError on non-serializable values; success proves serializability
    serialized = json.dumps(best_worst_report)
    assert isinstance(serialized, str)


def test_explain_predictions_invalid_algorithm():
pipeline = MagicMock()
input_features = pd.DataFrame({"a": [5, 6, 1, 2, 3, 4, 5, 6, 7, 4]})
Expand Down

0 comments on commit a8d0dce

Please sign in to comment.