diff --git a/neural_structured_learning/keras/adversarial_regularization_test.py b/neural_structured_learning/keras/adversarial_regularization_test.py index 2deebb6..6f6f13e 100644 --- a/neural_structured_learning/keras/adversarial_regularization_test.py +++ b/neural_structured_learning/keras/adversarial_regularization_test.py @@ -18,6 +18,7 @@ from __future__ import print_function import collections +import os from absl.testing import parameterized import neural_structured_learning.configs as configs @@ -64,10 +65,12 @@ def build_linear_keras_functional_model(input_shape, def build_linear_keras_subclassed_model(input_shape, weights, dynamic=False): del input_shape - class LinearModel(tf.keras.Model): + class CustomLinearModel(tf.keras.Model): - def __init__(self): - super(LinearModel, self).__init__(dynamic=dynamic) + def __init__(self, weights, name=None, dynamic=False): + super(CustomLinearModel, self).__init__(name=name, dynamic=dynamic) + self.init_weights = weights + self.init_dynamic = dynamic self.dense = tf.keras.layers.Dense( weights.shape[-1], use_bias=False, @@ -77,7 +80,14 @@ def __init__(self): def call(self, inputs): return self.dense(inputs['feature']) - return LinearModel() + def get_config(self): + return { + 'name': self.name, + 'weights': self.init_weights, + 'dynamic': self.init_dynamic + } + + return CustomLinearModel(weights, dynamic=dynamic) def build_linear_keras_dynamic_model(input_shape, weights): @@ -728,6 +738,42 @@ def test_perturb_on_batch_pgd(self, model_fn): self.assertAllClose(x_adv, adv_inputs['feature']) self.assertAllClose(y0, adv_inputs['label']) + def _test_adv_model_save(self, model_fn): + """Template for testing model saving and loading.""" + w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression() + model = model_fn(input_shape=(2,), weights=w) + adv_model = adversarial_regularization.AdversarialRegularization( + model, label_keys=['label'], adv_config=adv_config) + adv_model.compile(optimizer=tf.keras.optimizers.SGD(lr), loss=['MAE']) + + # Run the model before saving it. This is necessary for subclassed models. + inputs = {'feature': x0, 'label': y0} + adv_model.evaluate(inputs, steps=1) + + saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model') + adv_model.save(saved_model_dir) + + loaded_model = tf.keras.models.load_model(saved_model_dir) + self.assertEqual( + len(loaded_model.trainable_weights), len(adv_model.trainable_weights)) + for w_loaded, w_adv in zip(loaded_model.trainable_weights, + adv_model.trainable_weights): + self.assertAllClose( + tf.keras.backend.get_value(w_loaded), + tf.keras.backend.get_value(w_adv)) + + @parameterized.named_parameters([ + ('sequential', build_linear_keras_sequential_model), + ('functional', build_linear_keras_functional_model), + ]) + def test_adv_model_save(self, model_fn): + self._test_adv_model_save(model_fn) + + # Saving subclassed models are only supported in TF v2. + @test_util.run_v2_only + def test_adv_model_save_subclassed(self): + self._test_adv_model_save(build_linear_keras_subclassed_model) + if __name__ == '__main__': tf.test.main() diff --git a/neural_structured_learning/keras/graph_regularization.py b/neural_structured_learning/keras/graph_regularization.py index 136b831..c62bd69 100644 --- a/neural_structured_learning/keras/graph_regularization.py +++ b/neural_structured_learning/keras/graph_regularization.py @@ -139,3 +139,9 @@ def call(self, inputs, training=False, **kwargs): self.add_loss(scaled_graph_loss) return base_output + + def save(self, *args, **kwargs): + """Saves the base model. See base class for details of the interface.""" + # Graph regularization doesn't introduce new model variables, so saving the + # base model can capture all variables in the model. + self.base_model.save(*args, **kwargs) diff --git a/neural_structured_learning/keras/graph_regularization_test.py b/neural_structured_learning/keras/graph_regularization_test.py index 7ec4248..c52c41c 100644 --- a/neural_structured_learning/keras/graph_regularization_test.py +++ b/neural_structured_learning/keras/graph_regularization_test.py @@ -17,10 +17,11 @@ from __future__ import division from __future__ import print_function +import os + from absl.testing import parameterized import neural_structured_learning.configs as configs from neural_structured_learning.keras import graph_regularization - import numpy as np import tensorflow as tf @@ -88,10 +89,12 @@ def build_linear_functional_model(input_shape, weights, num_output=1): def build_linear_subclass_model(input_shape, weights, num_output=1): del input_shape - class LinearModel(tf.keras.Model): + class CustomLinearModel(tf.keras.Model): - def __init__(self): - super(LinearModel, self).__init__() + def __init__(self, weights, num_output, name=None): + super(CustomLinearModel, self).__init__(name=name) + self.init_weights = weights + self.num_output = num_output self.dense = tf.keras.layers.Dense( num_output, use_bias=False, @@ -101,7 +104,14 @@ def __init__(self): def call(self, inputs): return self.dense(inputs[FEATURE_NAME]) - return LinearModel() + def get_config(self): + return { + 'name': self.name, + 'weights': self.init_weights, + 'num_output': self.num_output + } + + return CustomLinearModel(weights, num_output) def make_dataset(example_proto, input_shape, training, max_neighbors): @@ -481,6 +491,47 @@ def test_graph_reg_model_evaluate(self, model_fn): weight=w, distributed_strategy=None) + def _test_graph_reg_model_save(self, model_fn): + """Template for testing model saving and loading.""" + w = np.array([[4.0], [-3.0]]) + base_model = model_fn((2,), w) + graph_reg_config = configs.make_graph_reg_config( + max_neighbors=1, multiplier=1) + graph_reg_model = graph_regularization.GraphRegularization( + base_model, graph_reg_config) + graph_reg_model.compile( + optimizer=tf.keras.optimizers.SGD(LEARNING_RATE), + loss='MSE', + metrics=['accuracy']) + + # Run the model before saving it. This is necessary for subclassed models. + inputs = {FEATURE_NAME: tf.constant([[5.0, 3.0]])} + graph_reg_model.predict(inputs, steps=1, batch_size=1) + saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model') + graph_reg_model.save(saved_model_dir) + + loaded_model = tf.keras.models.load_model(saved_model_dir) + self.assertEqual( + len(loaded_model.trainable_weights), + len(graph_reg_model.trainable_weights)) + for w_loaded, w_graph_reg in zip(loaded_model.trainable_weights, + graph_reg_model.trainable_weights): + self.assertAllClose( + tf.keras.backend.get_value(w_loaded), + tf.keras.backend.get_value(w_graph_reg)) + + @parameterized.named_parameters([ + ('_sequential', build_linear_sequential_model), + ('_functional', build_linear_functional_model), + ]) + def test_graph_reg_model_save(self, model_fn): + self._test_graph_reg_model_save(model_fn) + + # Saving subclassed models are only supported in TF v2. + @test_util.run_v2_only + def test_graph_reg_model_save_subclass(self): + self._test_graph_reg_model_save(build_linear_subclass_model) + if __name__ == '__main__': tf.test.main()