diff --git a/README.md b/README.md index 9707cbdf2..6dfa58024 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ To quickly play around with the TruLens library, check out the following CoLab n * PyTorch: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1n77IGrPDO2XpeIVo_LQW0gY78enV-tY9?usp=sharing) * Tensorflow 2 / Keras: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1f-ETsdlppODJGQCdMXG-jmGmfyWyW2VD?usp=sharing) * NLP with PyTorch: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18GcjsYMkRbxPDDS3J6BEbKnb7AY-1-Wa?usp=sharing) +* NLP with Tensorflow 2 / Keras: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1K09IvN7cMTkzsnb-uAeA0YQNfDU7Ibhs?usp=sharing) # Installation diff --git a/notebooks/nlp_demo_tf2.ipynb b/notebooks/nlp_demo_tf2.ipynb new file mode 100644 index 000000000..a5ead4053 --- /dev/null +++ b/notebooks/nlp_demo_tf2.ipynb @@ -0,0 +1,457 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MltE4YvRByjg" + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pwhMpDraXSk4" + }, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.insert(0, \"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xBf59EWqCC19", + "outputId": "e811afaf-fadc-48b4-c13a-6d178e8a313b" + }, + "outputs": [], + "source": [ + "# Install Trulens\n", + "!{sys.executable} -m pip install git+https://github.com/truera/trulens.git\n", + "!{sys.executable} -m pip install -U tensorflow-text tensorflow-hub tf-models-official gdown opencv-python" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AVeydDoFByjj", + "outputId": "210796ee-cc1e-408a-e2b4-1a2f2c55b95d" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import tensorflow_text\n", + "import tensorflow as tf\n", + "import pandas as pd\n", + "from IPython.display import HTML as html_print\n", + "import plotly.express as px\n", + "\n", + "from trulens.nn.models import get_model_wrapper\n", + "from trulens.nn.attribution import InternalInfluence\n", + "from trulens.nn.slices import OutputCut, Slice, Cut\n", + "from trulens.nn.quantities import MaxClassQoI\n", + "from trulens.nn.distributions import LinearDoi\n", + "import gdown\n", + "\n", + "tf.get_logger().setLevel('ERROR')\n", + "print(tf.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gdown.download(id=\"1-bVFx-qU_kD7gGqV2E8ucRrV0LKFxHzB\", output=\"resources.zip\", quiet=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HbOWc0jBTuSz", + "outputId": "75cd4e61-c629-497c-b206-f7fceedfca12" + }, + "outputs": [], + "source": [ + "# Download notebook resources.\n", + "!mkdir -p resources\n", + "!unzip -o -d resources resources.zip\n", + "!rm resources.zip" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f1R17GTA_iK9" + }, + "source": [ + "# Loading the model\n", + "The 
notebook resources include a model checkpoint. The model uses the TensorFlow Hub [Text Preprocessing layer](https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3) and [Small BERT layer](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2) followed by several convolutional and fully connected layers.\n", + "\n", + "The model has already been trained on a sentiment analysis task with the Covid-19 Tweets dataset.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fKQrtVC-Byjk", + "outputId": "0e5a27e2-c731-45aa-e1df-422defe8f4a5" + }, + "outputs": [], + "source": [ + "model_name = 'classifierbert-cnn'\n", + "\n", + "model = tf.keras.models.load_model('./resources/' + model_name, compile=False)\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8PjHCUd2AWCf" + }, + "source": [ + "## Model Vocabulary\n", + "We also load the vocabulary behind the model. This helps us translate our token IDs back into tokenized words." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "usLQZbmcoaNl" + }, + "outputs": [], + "source": [ + "vocab_file = f'./resources/{model_name}/assets/vocab.txt'\n", + "with open(vocab_file) as f:\n", + " vocab = f.read().splitlines()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V8kSvuJyAlxg" + }, + "source": [ + "This model classifies the sentiment of tweets into 5 classes: extremely negative, negative, neutral, positive, or extremely positive. Let's try it out on some examples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nv6eeKyYzad6" + }, + "outputs": [], + "source": [ + "sentences = [\n", + " \"Fill up the fridge with enough food, ready medical supplies, water, avoid crowd, be updated with news, dont panic, work from home if feasible, boost immune system by drinking vitamins and always wash hands. Stay safe and healthy! #COVID2019 #metroManilaCovid\",\n", + " \"Big thanks to all the retail, supermarket workers & nurses out there. This is mental and the subsequent panic buying and rise in cases shows just how important they are #Covid_19\",\n", + " \"I understand food being out of stock, but why toilet paper? what's up with that? #covid_19 #coronavirus\",\n", + " \"This Friday the 13th is a nightmare for supermarket employees. People are panic buying a day after Duterte announced an NCR lockdown. Carts are filled w/ all sorts of noodles. 
I guess these Metro Manila residents will be on pancit canton/bihon diet for a month #Covid_19 https://t.co/33Bw2ZKnds\",\n", + " \"Food, emergency supply stores struggle to meet demand #coronavirus #yzf https://t.co/wZ6yLBU2rl https://t.co/2Ef6Fy9u8y\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fpwcN-32Byjl", + "outputId": "19af9c52-075a-47a9-f450-65a5260603ba" + }, + "outputs": [], + "source": [ + "classes = ['Extremely Negative', 'Negative', 'Neutral', 'Positive', 'Extremely Positive']\n", + "\n", + "predictions = model(tf.constant(sentences)).numpy()\n", + "for sentence, pred in zip(sentences, predictions):\n", + " print(f\"Predicted {classes[np.argmax(pred)]}: '{sentence}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zvwxAUF-30s5" + }, + "source": [ + "# Model Wrapper\n", + "\n", + "As in the prior notebooks, we need to wrap the model with the appropriate TruLens functionality. Since this is a tf.keras model, `get_model_wrapper` selects the TensorFlow 2 backend automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VGh0i0NAByjm", + "outputId": "c1222a92-6c78-46fc-9a21-fe32d26b4d23" + }, + "outputs": [], + "source": [ + "k_model = get_model_wrapper(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lDplfgg54I8u" + }, + "source": [ + "# Attributions\n", + "\n", + "The model takes in text as input, which gets tokenized in the `preprocessing` layer and translated into embeddings in the `BERT_encoder` layer. Since we cannot take the gradient with respect to the raw input text or tokenized text directly, we must use the embedding representation of our inputs.\n", + "\n", + "Below, we can inspect the available layers in our model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yFmaRMzBj8N0", + "outputId": "4da74970-13b1-4316-8b8d-c35739bb3950" + }, + "outputs": [], + "source": [ + "[layer_name for layer_name in k_model._layers]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nzCNLa7Y4ZKc" + }, + "source": [ + "## Parameters\n", + "\n", + "Above, `BERT_encoder/bert_encoder/word_embeddings` is the layer that produces a continuous representation of each input token, so we will use it as the layer defining the **distribution of interest**. While most neural NLP models contain a token embedding, the layer name will differ.\n", + "\n", + "The second thing to note is the form of the model outputs: they are structures that contain a 'logits' attribute storing the model scores.\n", + "\n", + "Putting these things together, we instantiate `InternalInfluence` to attribute each embedding dimension to the maximum class (i.e. the predicted class)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c5GSy4vWByjm" + }, + "outputs": [], + "source": [ + "embedding_layer_name = 'BERT_encoder/bert_encoder/word_embeddings'\n", + "\n", + "infl = InternalInfluence(\n", + " model=k_model,\n", + " cuts=Slice(Cut(embedding_layer_name, anchor='out'), OutputCut()),\n", + " qoi=MaxClassQoI(),\n", + " doi=LinearDoi(resolution=10, cut=Cut(embedding_layer_name, anchor='in'))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7VB_idn9CXDR" + }, + "source": [ + "We apply the preprocessing step to tokenize our input text. Using the model vocabulary, the token IDs (`sentence_encodings`) can be translated back into tokenized words (`tokens`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yoDj2m9WmVkQ" + }, + "outputs": [], + "source": [ + "# Define preprocessor\n", + "inp = model.input\n", + "preprocessing_layer = model.get_layer('preprocessing').get_output_at(-1)\n", + "pp_func = tf.keras.backend.function(inp, preprocessing_layer)\n", + "\n", + "sentence_encodings = pp_func(tf.constant(sentences))['input_word_ids']\n", + "tokens = [[vocab[i] for i in sentence] for sentence in sentence_encodings]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ChHEop0AC1S7" + }, + "source": [ + "Getting attributions uses the same call as model evaluation and returns a tensor. We can aggregate the attributions across the embedding dimension to get an approximate look at the influence of each token." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BYkuVIY6552X" + }, + "outputs": [], + "source": [ + "attrs_internal = infl.attributions(np.array(sentences))\n", + "total_attrs = attrs_internal.sum(axis=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OXqB8d5Q-eIL" + }, + "source": [ + "# Visualizing Influences\n", + "Here we display visualizations that describe the influence of each token on the final prediction. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xEqsVKn3Byjq" + }, + "outputs": [], + "source": [ + "def rgb_str(r,g,b):\n", + " return \"rgb(%d,%d,%d)\" % (r,g,b)\n", + "\n", + "def cstr(s, color='black', background='white'):\n", + " return \"<text style='color:{};background-color:{}'>{}</text>\".format(color, background, s)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 225 + }, + "id": "rrKTLEegByjs", + "outputId": "13f8ae4f-eafc-44c1-cc7e-eb5e80323fcb" + }, + "outputs": [], + "source": [ + "html=''\n", + "for sentence_idx in range(len(sentences)):\n", + " html += classes[np.argmax(predictions[sentence_idx])] + \": \"\n", + "\n", + " # Define the coloring for each token. Green=positive, Red=negative.\n", + " # Color intensity describes the magnitude of the influence in either direction. \n", + " max_imp = max(abs(total_attrs[sentence_idx]))\n", + " rgbs=[]\n", + " for imp in total_attrs[sentence_idx]:\n", + " normed_imp = int(imp/max_imp*255)\n", + " intensity = abs(normed_imp)\n", + " if normed_imp > 0: # green\n", + " rgbs.append(rgb_str(255-intensity, 255, 255-intensity))\n", + " else: # red\n", + " rgbs.append(rgb_str(255, 255-intensity, 255-intensity))\n", + "\n", + " for i, token in enumerate(tokens[sentence_idx]):\n", + " if token != \"[PAD]\":\n", + " html += cstr(token, 'black', rgbs[i]) + ' '\n", + " html += \"<br><br>
\"\n", + "html_print(html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "l4T0BkZeByjs", + "outputId": "9f7181a0-60dc-415e-c526-28f46882a0de" + }, + "outputs": [], + "source": [ + "for sentence_idx in range(len(sentences)):\n", + " df = pd.DataFrame({'Tokens': tokens[sentence_idx],'Importance': total_attrs[sentence_idx]})\n", + " fig = px.bar(df, x='Tokens', y='Importance')\n", + " fig.update_layout(width=1200, height=300,)\n", + " fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LN7epOxfByjs" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "anaconda-cloud": {}, + "colab": { + "collapsed_sections": [], + "name": "nlp_demo_tf2.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3.9.12 ('demo3')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "vscode": { + "interpreter": { + "hash": "ce4e35a76a569399d57219f9877d3cff9bc99a439b1c8dd709c903be401418f7" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tests/keras/unit/attribution_axioms_test.py b/tests/keras/unit/attribution_axioms_test.py index 9261e7886..713c051a1 100644 --- a/tests/keras/unit/attribution_axioms_test.py +++ b/tests/keras/unit/attribution_axioms_test.py @@ -54,6 +54,7 @@ def setUp(self): self.layer2 = 2 self.layer3 = 3 + class NestedAxiomsTest(AxiomsTestBase, TestCase): def setUp(self): diff --git a/tests/keras/unit/doi_test.py b/tests/keras/unit/doi_test.py index c76c72567..cf125dd48 100644 --- a/tests/keras/unit/doi_test.py +++ b/tests/keras/unit/doi_test.py @@ -31,6 +31,7 @@ def setUp(self): self.layer1 = 1 self.layer2 = 2 + class NestedDoiTest(DoiTestBase, TestCase): def setUp(self): @@ -39,7 +40,7 @@ def setUp(self): l0 = Input((1,)) l1 = Lambda(lambda input: self.l1_coeff * (input**self.l1_exp))(l0) nested_model = Model(l0, l1) - + l0 = Input((1,)) l1 = nested_model(l0) l2 = Lambda(lambda input: self.l2_coeff * (input**self.l2_exp))(l1) diff --git a/tests/keras/unit/keras_model_test.py b/tests/keras/unit/keras_model_test.py index 243007999..1bb274bf3 100644 --- a/tests/keras/unit/keras_model_test.py +++ b/tests/keras/unit/keras_model_test.py @@ -55,32 +55,34 @@ def test_wrong_keras_version(self): with self.assertRaises(ValueError): KerasModelWrapper(tf_keras_model) + class NestedModelWrapperTest(ModelWrapperTestBase, TestCase): - def setUp(self): - super(NestedModelWrapperTest, self).setUp() - n_x = Input((2,)) - n_y = Dense(2, activation='relu')(n_x) - nested_model = Model([n_x], [n_y]) - - x = Input((2,)) - z = nested_model(x) - z = Dense(2, activation='relu')(z) - y = Dense(1, name='logits')(z) - model = Model(x, y) - - self.model = KerasModelWrapper(model) - - self.model._model.set_weights( - [ - self.layer1_weights, self.internal_bias, self.layer2_weights, - self.internal_bias, self.layer3_weights, self.bias - ] - ) - - self.layer0 = 0 - self.layer1 = 1 - self.layer2 = 2 - self.out = 'logits' + + def setUp(self): + super(NestedModelWrapperTest, self).setUp() + n_x = Input((2,)) + n_y = Dense(2, activation='relu')(n_x) + nested_model = Model([n_x], [n_y]) + + x = Input((2,)) + z = nested_model(x) + z = Dense(2, 
activation='relu')(z) + y = Dense(1, name='logits')(z) + model = Model(x, y) + + self.model = KerasModelWrapper(model) + + self.model._model.set_weights( + [ + self.layer1_weights, self.internal_bias, self.layer2_weights, + self.internal_bias, self.layer3_weights, self.bias + ] + ) + + self.layer0 = 0 + self.layer1 = 1 + self.layer2 = 2 + self.out = 'logits' if __name__ == '__main__': diff --git a/tests/notebooks/requirements.txt b/tests/notebooks/requirements.txt index fc2c0997e..138bb53e8 100644 --- a/tests/notebooks/requirements.txt +++ b/tests/notebooks/requirements.txt @@ -1,10 +1,16 @@ jinja2==3.0.3 nbformat==5.0.8 nbconvert==6.0.7 +pandas==1.3.5 matplotlib scipy -tensorflow==2.6.4 +tensorflow-hub==0.12.0 +tensorflow-text==2.8.2 +tensorflow==2.8.2 +tf-models-official==2.8.0 torch==1.6.0 torchvision==0.7.0 ipywidgets -protobuf==3.20.* +plotly +gdown +protobuf==3.19.* diff --git a/tests/notebooks/requirements_latest.txt b/tests/notebooks/requirements_latest.txt index fef798429..c39ae935f 100644 --- a/tests/notebooks/requirements_latest.txt +++ b/tests/notebooks/requirements_latest.txt @@ -2,8 +2,15 @@ jinja2==3.0.3 nbformat==5.0.8 nbconvert==6.0.7 matplotlib +pandas +tensorflow-hub +tf-models-official +tensorflow-text scipy tensorflow torch torchvision ipywidgets +plotly +gdown +protobuf==3.20.* \ No newline at end of file diff --git a/trulens/nn/attribution.py b/trulens/nn/attribution.py index 91dfcd9a1..a8ec95c88 100644 --- a/trulens/nn/attribution.py +++ b/trulens/nn/attribution.py @@ -381,7 +381,7 @@ def _attributions(self, model_inputs: ModelInputs) -> AttributionResult: attribution_cut=None, # InputCut(), intervention=model_inputs )[0] - + doi_val = nested_map(doi_val, B.as_array) D = self.doi._wrap_public_call(doi_val, model_inputs=model_inputs) if self._return_doi: diff --git a/trulens/nn/backend/tf_backend/tf.py b/trulens/nn/backend/tf_backend/tf.py index 4455fb723..31092c9de 100644 --- a/trulens/nn/backend/tf_backend/tf.py +++ b/trulens/nn/backend/tf_backend/tf.py @@ -454,4 +454,7 @@ def is_tensor(x): ---------- x : backend.Tensor or other """ - return isinstance(x, Tensor) + try: + return isinstance(x, Tensor) or tf.keras.backend.is_keras_tensor(x) + except: + return False diff --git a/trulens/nn/models/keras_utils.py b/trulens/nn/models/keras_utils.py index 5bff4c688..ed65b939d 100644 --- a/trulens/nn/models/keras_utils.py +++ b/trulens/nn/models/keras_utils.py @@ -91,10 +91,15 @@ def path_from_tensor(a): for node in model._nodes_by_depth[depth]: # add layer output to layer_outputs layer = node.outbound_layer - layer_outputs[layer.name] = recurse_outputs(om_of_many(node.output_tensors), [layer.name]) + layer_outputs[ + layer.name + ] = recurse_outputs(om_of_many(node.output_tensors), [layer.name]) if node.inbound_layers: # Get input tensor paths for next layer from from prev layer's outputs - prev_layers = set(many_of_om(node.inbound_layers)) + if isinstance(node.inbound_layers, dict): + prev_layers = set(node.inbound_layers.values()) + else: + prev_layers = set(many_of_om(node.inbound_layers)) try: args = many_of_om(node.call_args) @@ -103,7 +108,7 @@ def path_from_tensor(a): # call_args, call_kwargs attributes don't exist in older Keras versions args = many_of_om(node.input_tensors) kwargs = {} - + arg_paths = [get_arg_path(prev_layers, arg) for arg in args] kwarg_paths = { key: get_arg_path(prev_layers, arg) @@ -252,7 +257,9 @@ def prop_through_layer(depth, dirty=False): """ if depth < 0: nodes = model._nodes_by_depth[0] - return 
om_of_many([layer_outputs[node.outbound_layer.name] for node in nodes]) + return om_of_many( + [layer_outputs[node.outbound_layer.name] for node in nodes] + ) nodes = model._nodes_by_depth[depth] if not dirty and all( @@ -261,7 +268,8 @@ def prop_through_layer(depth, dirty=False): for node in nodes): # no prior modifications, no nested models, and no replacements at this depth, continue on for node in nodes: - layer_outputs[node.outbound_layer.name] = node.outbound_layer.get_output_at(-1) + layer_outputs[node.outbound_layer.name + ] = om_of_many(node.output_tensors) return prop_through_layer(depth=depth - 1, dirty=dirty) diff --git a/trulens/nn/models/tensorflow_v2.py b/trulens/nn/models/tensorflow_v2.py index 429670749..67f69b973 100644 --- a/trulens/nn/models/tensorflow_v2.py +++ b/trulens/nn/models/tensorflow_v2.py @@ -13,7 +13,9 @@ from trulens.utils import tru_logger from trulens.utils.typing import DATA_CONTAINER_TYPE from trulens.utils.typing import Inputs +from trulens.utils.typing import many_of_om from trulens.utils.typing import ModelInputs +from trulens.utils.typing import nested_cast from trulens.utils.typing import nested_map from trulens.utils.typing import om_of_many from trulens.utils.typing import Outputs @@ -122,25 +124,31 @@ def _get_output_layer(self): for output in self._model.outputs: for layer in self._layers.values(): try: - if layer is output or layer.output is output: + if layer is output or self._get_layer_output(layer + ) is output: output_layers.append(layer) except: - # layer.output may not be instantiated when using model subclassing, - # but it is not a problem because self._model.outputs is only autoselected as output_layer.output - # when not subclassing. + # layer output may not be instantiated when using model subclassing, + # but it is not a problem because self._model.outputs is only autoselected as + # the output_layer output when not subclassing. continue return output_layers def _is_input_layer(self, layer): if (self._model.inputs is not None): - return any([inpt is layer.output for inpt in self._model.inputs]) + return any( + [ + inpt is self._get_layer_output(layer) + for inpt in self._model.inputs + ] + ) else: return False def _input_layer_index(self, layer): for i, inpt in enumerate(self._model.inputs): - if inpt is layer.output: + if inpt is self._get_layer_output(layer): return i return None @@ -183,10 +191,19 @@ def _fprop( ) for layer, x_i in zip(from_layers, intervention.args): + + def intervention_fn(x): + nonlocal x_i + x, x_i = many_of_om(x), many_of_om(x_i) + for i, (_x, _x_i) in enumerate(zip(x, x_i)): + if _x.dtype != _x_i.dtype: + x_i[i] = tf.cast(_x_i, _x.dtype) + return om_of_many(x_i) + if doi_cut.anchor == 'in': - layer.input_intervention = lambda _: x_i + layer.input_intervention = intervention_fn else: - layer.output_intervention = lambda _: x_i + layer.output_intervention = intervention_fn else: model_inputs = intervention
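---

Not part of the diff itself, but for orientation: the sketch below condenses the attribution flow the new notebook exercises into a self-contained toy example. The model, the `embedding` layer name, and the shapes are invented for illustration; only the `get_model_wrapper`/`InternalInfluence` API usage mirrors the notebook above.

```python
import numpy as np
import tensorflow as tf

from trulens.nn.attribution import InternalInfluence
from trulens.nn.distributions import LinearDoi
from trulens.nn.models import get_model_wrapper
from trulens.nn.quantities import MaxClassQoI
from trulens.nn.slices import Cut, OutputCut, Slice

# Toy 5-class text classifier over integer token IDs; 'embedding' plays the
# role of 'BERT_encoder/bert_encoder/word_embeddings' in the notebook.
inp = tf.keras.Input(shape=(16,), dtype=tf.int32)
emb = tf.keras.layers.Embedding(input_dim=100, output_dim=8, name='embedding')(inp)
pooled = tf.keras.layers.GlobalAveragePooling1D()(emb)
logits = tf.keras.layers.Dense(5, name='logits')(pooled)
wrapper = get_model_wrapper(tf.keras.Model(inp, logits))

# Token IDs are not differentiable, so the attribution slice starts at the
# embedding layer's output; LinearDoi interpolates activations at that cut.
infl = InternalInfluence(
    model=wrapper,
    cuts=Slice(Cut('embedding', anchor='out'), OutputCut()),
    qoi=MaxClassQoI(),
    doi=LinearDoi(resolution=10),
)

token_ids = np.random.randint(0, 100, size=(2, 16))
attrs = infl.attributions(token_ids)  # one attribution per embedding dimension
token_attrs = attrs.sum(axis=2)       # aggregate to a per-token influence
print(token_attrs.shape)              # expected: (2, 16)
```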