Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enanchment in CF Generation #259

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
9 changes: 7 additions & 2 deletions dice_ml/explainer_interfaces/dice_KD.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
# post-hoc operation on continuous features to enhance sparsity - only for public data
if posthoc_sparsity_param is not None and posthoc_sparsity_param > 0 and 'data_df' in self.data_interface.__dict__:
self.final_cfs_df_sparse = copy.deepcopy(self.final_cfs)
self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse, query_instance,
self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse,
query_instance,
posthoc_sparsity_param,
posthoc_sparsity_algorithm)
else:
Expand All @@ -260,10 +261,14 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
if total_cfs_found < total_CFs:
self.elapsed = timeit.default_timer() - start_time
m, s = divmod(self.elapsed, 60)
print('Only %d (required %d) ' % (total_cfs_found, self.total_CFs),
print('Only %d (required %d) ' % (total_cfs_found, total_CFs),
'Diverse Counterfactuals found for the given configuation, perhaps ',
'change the query instance or the features to vary...' '; total time taken: %02d' % m,
'min %02d' % s, 'sec')
elif total_cfs_found == 0:
print(
'No Counterfactuals found for the given configuration, perhaps try with different parameters...',
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
else:
print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec')

Expand Down
28 changes: 16 additions & 12 deletions dice_ml/explainer_interfaces/dice_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ def do_random_init(self, num_inits, features_to_vary, query_instance, desired_cl
def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desired_range):
cfs = self.label_encode(cfs)
cfs = cfs.reset_index(drop=True)

self.cfs = np.zeros((self.population_size, self.data_interface.number_of_features))
for kx in range(self.population_size):
row = []
for kx in range(self.population_size*5):
if kx >= len(cfs):
break
one_init = np.zeros(self.data_interface.number_of_features)
Expand All @@ -143,16 +142,18 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir
one_init[jx] = query_instance[jx]
else:
one_init[jx] = np.random.choice(self.feature_range[feature])
self.cfs[kx] = one_init
t = tuple(one_init)
if t not in row:
row.append(t)
if len(row) == self.population_size:
break
kx += 1
self.cfs = np.array(row)

new_array = [tuple(row) for row in self.cfs]
uniques = np.unique(new_array, axis=0)

if len(uniques) != self.population_size:
if len(self.cfs) != self.population_size:
remaining_cfs = self.do_random_init(
self.population_size - len(uniques), features_to_vary, query_instance, desired_class, desired_range)
self.cfs = np.concatenate([uniques, remaining_cfs])
self.population_size - len(self.cfs), features_to_vary, query_instance, desired_class, desired_range)
self.cfs = np.concatenate([self.cfs, remaining_cfs])

def do_cf_initializations(self, total_CFs, initialization, algorithm, features_to_vary, desired_range,
desired_class,
Expand Down Expand Up @@ -466,8 +467,8 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class,
if rest_members > 0:
new_generation_2 = np.zeros((rest_members, self.data_interface.number_of_features))
for new_gen_idx in range(rest_members):
parent1 = random.choice(population[:int(len(population) / 2)])
parent2 = random.choice(population[:int(len(population) / 2)])
parent1 = random.choice(population[:max(int(len(population) / 2), 1)])
parent2 = random.choice(population[:max(int(len(population) / 2), 1)])
child = self.mate(parent1, parent2, features_to_vary, query_instance)
new_generation_2[new_gen_idx] = child

Expand Down Expand Up @@ -514,6 +515,9 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class,
if len(self.final_cfs) == self.total_CFs:
print('Diverse Counterfactuals found! total time taken: %02d' %
m, 'min %02d' % s, 'sec')
elif len(self.final_cfs) == 0:
print('No Counterfactuals found for the given configuration, perhaps try with different parameters...',
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
else:
print('Only %d (required %d) ' % (len(self.final_cfs), self.total_CFs),
'Diverse Counterfactuals found for the given configuation, perhaps ',
Expand Down
10 changes: 8 additions & 2 deletions dice_ml/explainer_interfaces/dice_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,17 @@ class of query_instance for binary classification.
cfs_df = None
candidate_cfs = pd.DataFrame(
np.repeat(query_instance.values, sample_size, axis=0), columns=query_instance.columns)
# Loop to change one feature at a time, then two features, and so on.
# Loop to change one feature at a time ##->(NOT TRUE), then two features, and so on.
for num_features_to_vary in range(1, len(self.features_to_vary)+1):
# commented lines allow more values to change as num_features_to_vary increases, instead of .at you should use .loc
# is deliberately left commented out to let you choose.
# is slower, but more complete and still faster than genetic/KDtree
# selected_features = np.random.choice(self.features_to_vary, (sample_size, num_features_to_vary), replace=True)
selected_features = np.random.choice(self.features_to_vary, (sample_size, 1), replace=True)
for k in range(sample_size):
candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
candidate_cfs.at[k, selected_features[k][0]] = random_instances._get_value(k, selected_features[k][0])
# If you only want to change one feature, you should use _get_value
# candidate_cfs.iloc[k][selected_features[k]]=random_instances.iloc[k][selected_features[k]]
giandos200 marked this conversation as resolved.
Show resolved Hide resolved
scores = self.predict_fn(candidate_cfs)
validity = self.decide_cf_validity(scores)
if sum(validity) > 0:
Expand Down
62 changes: 41 additions & 21 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
desired_class="opposite", desired_range=None,
permitted_range=None, features_to_vary="all",
stopping_threshold=0.5, posthoc_sparsity_param=0.1,
posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
posthoc_sparsity_algorithm=None, verbose=False, **kwargs):
"""General method for generating counterfactuals.

:param query_instances: Input point(s) for which counterfactuals are to be generated.
Expand Down Expand Up @@ -81,11 +81,23 @@ def generate_counterfactuals(self, query_instances, total_CFs,
if total_CFs <= 0:
raise UserConfigValidationException(
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
if total_CFs > 10:
if posthoc_sparsity_algorithm is None:
posthoc_sparsity_algorithm = 'binary'
elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear':
import warnings
warnings.warn(
"The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
"'binary' search!".format(total_CFs))
elif posthoc_sparsity_algorithm is None:
posthoc_sparsity_algorithm = 'linear'

cf_examples_arr = []
query_instances_list = []
if isinstance(query_instances, pd.DataFrame):
for ix in range(query_instances.shape[0]):
query_instances_list.append(query_instances[ix:(ix+1)])
query_instances_list.append(query_instances[ix:(ix + 1)])
elif isinstance(query_instances, Iterable):
query_instances_list = query_instances

Expand Down Expand Up @@ -179,11 +191,14 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query

if feature not in features_to_vary and permitted_range is not None:
if feature in permitted_range and feature in self.data_interface.continuous_feature_names:
if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][1]:
raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][\
1]:
raise ValueError("Feature:", feature,
"is outside the permitted range and isn't allowed to vary.")
elif feature in permitted_range and feature in self.data_interface.categorical_feature_names:
if query_instance[feature].values[0] not in self.feature_range[feature]:
raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
raise ValueError("Feature:", feature,
"is outside the permitted range and isn't allowed to vary.")

def local_feature_importance(self, query_instances, cf_examples_list=None,
total_CFs=10,
Expand Down Expand Up @@ -429,12 +444,13 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
cfs_preds_sparse = []

for cf_ix in list(final_cfs_sparse.index):
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
for feature in features_sorted:
# current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
# feat_ix = self.data_interface.continuous_feature_names.index(feature)
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
if(abs(diff) <= quantiles[feature]):
if (abs(diff) <= quantiles[feature]):
if posthoc_sparsity_algorithm == "linear":
final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix,
feature, final_cfs_sparse, current_pred)
Expand All @@ -455,13 +471,14 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
query_instance greedily until the prediction class changes."""

old_diff = diff
change = (10**-decimal_prec[feature]) # the minimal possible change for a feature
change = (10 ** -decimal_prec[feature]) # the minimal possible change for a feature
current_pred = current_pred_orig
if self.model.model_type == ModelTypes.Classifier:
while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and self.is_cf_valid(current_pred)):
while ((abs(diff) > 10e-4) and (np.sign(diff * old_diff) > 0) and self.is_cf_valid(current_pred)):
old_val = int(final_cfs_sparse.at[cf_ix, feature])
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff) * change
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
old_diff = diff

if not self.is_cf_valid(current_pred):
Expand Down Expand Up @@ -494,11 +511,12 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
right = query_instance[feature].iat[0]

while left <= right:
current_val = left + ((right - left)/2)
current_val = left + ((right - left) / 2)
current_val = round(current_val, decimal_prec[feature])

final_cfs_sparse.at[cf_ix, feature] = current_val
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])

if current_val == right or current_val == left:
break
Expand All @@ -513,19 +531,20 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
right = int(final_cfs_sparse.at[cf_ix, feature])

while right >= left:
current_val = right - ((right - left)/2)
current_val = right - ((right - left) / 2)
current_val = round(current_val, decimal_prec[feature])

final_cfs_sparse.at[cf_ix, feature] = current_val
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])

if current_val == right or current_val == left:
break

if self.is_cf_valid(current_pred):
right = current_val - (10**-decimal_prec[feature])
right = current_val - (10 ** -decimal_prec[feature])
else:
left = current_val + (10**-decimal_prec[feature])
left = current_val + (10 ** -decimal_prec[feature])

return final_cfs_sparse

Expand Down Expand Up @@ -567,7 +586,7 @@ def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_
raise UserConfigValidationException("Desired class not present in training data!")
else:
raise UserConfigValidationException("The target class for {0} could not be identified".format(
desired_class_input))
desired_class_input))

def infer_target_cfs_range(self, desired_range_input):
target_range = None
Expand All @@ -586,7 +605,7 @@ def decide_cf_validity(self, model_outputs):
pred = model_outputs[i]
if self.model.model_type == ModelTypes.Classifier:
if self.num_output_nodes == 2: # binary
pred_1 = pred[self.num_output_nodes-1]
pred_1 = pred[self.num_output_nodes - 1]
validity[i] = 1 if \
((self.target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
(self.target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else 0
Expand Down Expand Up @@ -623,7 +642,7 @@ def is_cf_valid(self, model_score):
(target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False
return validity
if self.num_output_nodes == 2: # binary
pred_1 = model_score[self.num_output_nodes-1]
pred_1 = model_score[self.num_output_nodes - 1]
validity = True if \
((target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
(target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False
Expand Down Expand Up @@ -699,7 +718,8 @@ def round_to_precision(self):
for ix, feature in enumerate(self.data_interface.continuous_feature_names):
self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix])
if self.final_cfs_df_sparse is not None:
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix])
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(
precisions[ix])

def _check_any_counterfactuals_computed(self, cf_examples_arr):
"""Check if any counterfactuals were generated for any query point."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,4 +329,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
1 change: 1 addition & 0 deletions requirements-linting.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
flake8==3.9.2
flake8-bugbear==21.11.29
flake8-blind-except==0.1.1
flake8-breakpoint
flake8-builtins==1.5.3
flake8-logging-format==0.6.0
flake8-nb==0.3.0
Expand Down