TensorFlow/Keras implementation of RISE

Original paper by Petsiuk et al.: https://arxiv.org/abs/1806.07421

This repo was inspired by these great XAI repositories:


Basic usage is as simple as:

from RISE_tf.explain_image import RISE

explainer = RISE()
explanation = explainer.explain(image, model)

More options are available, including some experimental perturbations you can try.

This is an adaptation of "RISE: Randomized Input Sampling for Explanation of Black-box Models" (Petsiuk, V., Das, A., & Saenko, K., 2018), with some additions and changes.

A notebook will be included to show some simple usages. The notebook uses VGG19 from tf.keras.applications, but it should be easy to adapt to other models. RISE is a model-agnostic technique: it can be used with any function that receives an image input and outputs predictions.

RISE generates explanation heatmaps as a weighted average of random perturbation masks sampled with a Monte Carlo approach. Each mask is weighted by the probability the model assigns to the class being explained when it sees the masked image, so regions whose visibility keeps that probability high end up with high importance.
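The idea can be sketched in a few lines of NumPy. This is a simplified illustration, not the code in this repo: the function names (generate_masks, rise_saliency), the parameters, and the toy predictor are all made up for the example. It also upsamples the coarse grids with nearest-neighbour repetition, whereas the paper uses bilinear upsampling with random shifts to get smooth masks.

```python
import numpy as np


def generate_masks(n_masks, grid_size, image_hw, keep_prob, rng):
    """Sample coarse binary grids and upsample them to the image size."""
    height, width = image_hw
    # Cell size, rounded up so the upsampled grid covers the whole image.
    cell_h = -(-height // grid_size)
    cell_w = -(-width // grid_size)
    grids = (rng.random((n_masks, grid_size, grid_size)) < keep_prob).astype(np.float32)
    # Nearest-neighbour upsampling: repeat every grid cell, then crop.
    masks = grids.repeat(cell_h, axis=1).repeat(cell_w, axis=2)
    return masks[:, :height, :width]


def rise_saliency(predict_fn, image, class_index, n_masks=2000, grid_size=4,
                  keep_prob=0.5, batch_size=100, seed=0):
    """Weighted average of random masks; weights are the class scores."""
    rng = np.random.default_rng(seed)
    masks = generate_masks(n_masks, grid_size, image.shape[:2], keep_prob, rng)
    scores = []
    for start in range(0, n_masks, batch_size):
        batch_masks = masks[start:start + batch_size]
        # Masked-out pixels are set to 0 (black-pixel perturbation).
        perturbed = image[None] * batch_masks[..., None]
        scores.append(predict_fn(perturbed)[:, class_index])
    scores = np.concatenate(scores)
    # Sum of masks weighted by score, normalized by the expected
    # number of masks that keep any given pixel visible.
    return np.tensordot(scores, masks, axes=1) / (n_masks * keep_prob)


def toy_predict(batch):
    """Toy 2-class 'model': class 1 score is the visible share of the top-left corner."""
    score = batch[:, :8, :8, :].mean(axis=(1, 2, 3))
    return np.stack([1.0 - score, score], axis=1)


image = np.ones((16, 16, 3), dtype=np.float32)
saliency = rise_saliency(toy_predict, image, class_index=1)
```

With this toy predictor, the top-left corner of the heatmap should come out more important than the rest, because only masks that leave it visible receive a high weight.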


A note on image preprocessing and explanations

The explain() function receives an image and a model.

The image to be explained is expected to match the input_shape used by the model, without any color normalization or further preprocessing applied yet. This makes it easier to stay consistent across models that may have wildly different preprocessing routines.

Ideally, any color normalization/preprocessing is included within the model class/function. TensorFlow 2 makes this easy with preprocessing layers (https://www.tensorflow.org/guide/keras/preprocessing_layers).
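As an illustration of bundling preprocessing into the model, the sketch below builds a toy Keras model whose first layer rescales raw 0-255 pixels to [-1, 1]. The architecture, the 224x224 input size, and the scaling constants are placeholders chosen for the example, not something this repo prescribes. It assumes a TensorFlow version where tf.keras.layers.Rescaling is available (2.6 or later; earlier 2.x releases exposed it under tf.keras.layers.experimental.preprocessing).

```python
import numpy as np
import tensorflow as tf

# Toy model: the Rescaling layer does the color normalization, so callers
# can pass raw 0-255 images straight to predict().
inputs = tf.keras.Input(shape=(224, 224, 3))
x = tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

# Raw, un-normalized images go in; class probabilities come out.
rng = np.random.default_rng(0)
raw_batch = rng.integers(0, 256, size=(2, 224, 224, 3)).astype("float32")
probs = model.predict(raw_batch, verbose=0)
```

A model built this way can be handed over without any wrapper, since the normalization travels with it.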

However, many users have separate preprocessing routines that are applied before the image is sent to the model for classification. A quick fix to make this work with this implementation is to define a model class that receives your model and adds the preprocessing step to the predict() method:

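class MyModel:
    """Wraps a model so that preprocessing runs inside predict()."""

    def __init__(self, model):
        self.model = model
        # explain() needs the input_shape attribute
        self.input_shape = model.input_shape
        self.output_shape = model.output_shape

    def predict(self, batch_images):
        # your_preprocessing is a placeholder for your own routine
        batch_images = your_preprocessing(batch_images)
        return self.model.predict(batch_images)


model_with_preprocessing = MyModel(your_model)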

You can customize this class as much as necessary. The explain() function only needs the predict() method and the input_shape attribute.
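To make the wrapper idea concrete, here is a small self-contained check that uses NumPy only. Everything in it is hypothetical: StubClassifier stands in for a trained model that expects inputs in [0, 1], scale_to_unit_range stands in for your preprocessing routine, and ModelWithPreprocessing is a variant of the wrapper that takes the preprocessing function as an argument.

```python
import numpy as np


class StubClassifier:
    """Hypothetical stand-in for a trained model; expects inputs in [0, 1]."""
    input_shape = (None, 32, 32, 3)
    output_shape = (None, 2)

    def predict(self, batch_images):
        # Class 1 score is simply the mean brightness of each image.
        brightness = batch_images.mean(axis=(1, 2, 3))
        return np.stack([1.0 - brightness, brightness], axis=1)


def scale_to_unit_range(batch_images):
    """Hypothetical preprocessing: map raw 0-255 pixels to [0, 1]."""
    return np.asarray(batch_images, dtype=np.float32) / 255.0


class ModelWithPreprocessing:
    """Exposes predict() and input_shape, applying preprocessing first."""

    def __init__(self, model, preprocess_fn):
        self.model = model
        self.preprocess_fn = preprocess_fn
        self.input_shape = model.input_shape

    def predict(self, batch_images):
        return self.model.predict(self.preprocess_fn(batch_images))


wrapped = ModelWithPreprocessing(StubClassifier(), scale_to_unit_range)

# Raw white images: after preprocessing they are all 1.0, so class 1 wins.
raw_batch = np.full((4, 32, 32, 3), 255, dtype=np.uint8)
predictions = wrapped.predict(raw_batch)
```

Running a quick check like this on raw images before calling explain() is a cheap way to confirm the preprocessing is applied exactly once.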


Adaptations/changes from the original repo/paper

The main adaptation is a RISE class that uses a TensorFlow 2.0/Keras model for predictions, along with some small changes in how batches are passed to the predictor. The original work and repo use only black-pixel perturbations. This repo adds other options (Gaussian blur, black-and-white vs. colored noise), but these are included only as experiments. The source code should also make it easy to implement new types of perturbations if desired. The default remains the black-pixel perturbation used in the original repo and paper.
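One way to think about the different perturbations is as different "baseline" images that replace the masked-out pixels. The sketch below shows that idea with NumPy. It is not this repo's API: the function names and the 0-255 value range are assumptions made for the example, and Gaussian blur is left out to keep the example free of extra dependencies.

```python
import numpy as np


def black_baseline(image, rng):
    """Masked pixels become black, as in the original paper."""
    return np.zeros_like(image)


def gray_noise_baseline(image, rng):
    """Masked pixels get one random value shared by all channels."""
    noise = rng.uniform(0, 255, size=image.shape[:2] + (1,))
    return np.broadcast_to(noise, image.shape).astype(image.dtype)


def color_noise_baseline(image, rng):
    """Masked pixels get an independent random value per channel."""
    return rng.uniform(0, 255, size=image.shape).astype(image.dtype)


def perturb(image, mask, baseline_fn, rng):
    """Keep the image where mask == 1, use the baseline where mask == 0."""
    baseline = baseline_fn(image, rng)
    mask = mask[..., None]  # broadcast the 2-D mask over the channel axis
    return mask * image + (1.0 - mask) * baseline


rng = np.random.default_rng(0)
image = np.full((8, 8, 3), 200.0, dtype=np.float32)
mask = np.zeros((8, 8), dtype=np.float32)
mask[:, :4] = 1.0  # keep the left half, perturb the right half

black = perturb(image, mask, black_baseline, rng)
gray = perturb(image, mask, gray_noise_baseline, rng)
color = perturb(image, mask, color_noise_baseline, rng)
```

Under this framing, a new perturbation type only needs a new baseline function with the same signature.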

This repo is a work in progress.