Implement GraphRegularization.save().
PiperOrigin-RevId: 458086350
csferng authored and tensorflow-copybara committed Jun 29, 2022
1 parent df810dd commit 8a3317d
Showing 3 changed files with 112 additions and 9 deletions.
50 changes: 50 additions & 4 deletions neural_structured_learning/keras/adversarial_regularization_test.py
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 import collections
+import os
 
 from absl.testing import parameterized
 import neural_structured_learning.configs as configs
@@ -64,10 +65,12 @@ def build_linear_keras_functional_model(input_shape,
 def build_linear_keras_subclassed_model(input_shape, weights, dynamic=False):
   del input_shape
 
-  class LinearModel(tf.keras.Model):
+  class CustomLinearModel(tf.keras.Model):
 
-    def __init__(self):
-      super(LinearModel, self).__init__(dynamic=dynamic)
+    def __init__(self, weights, name=None, dynamic=False):
+      super(CustomLinearModel, self).__init__(name=name, dynamic=dynamic)
+      self.init_weights = weights
+      self.init_dynamic = dynamic
       self.dense = tf.keras.layers.Dense(
           weights.shape[-1],
           use_bias=False,
@@ -77,7 +80,14 @@ def __init__(self):
     def call(self, inputs):
       return self.dense(inputs['feature'])
 
-  return LinearModel()
+    def get_config(self):
+      return {
+          'name': self.name,
+          'weights': self.init_weights,
+          'dynamic': self.init_dynamic
+      }
+
+  return CustomLinearModel(weights, dynamic=dynamic)
 
 
 def build_linear_keras_dynamic_model(input_shape, weights):
@@ -728,6 +738,42 @@ def test_perturb_on_batch_pgd(self, model_fn):
     self.assertAllClose(x_adv, adv_inputs['feature'])
     self.assertAllClose(y0, adv_inputs['label'])
 
+  def _test_adv_model_save(self, model_fn):
+    """Template for testing model saving and loading."""
+    w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression()
+    model = model_fn(input_shape=(2,), weights=w)
+    adv_model = adversarial_regularization.AdversarialRegularization(
+        model, label_keys=['label'], adv_config=adv_config)
+    adv_model.compile(optimizer=tf.keras.optimizers.SGD(lr), loss=['MAE'])
+
+    # Run the model before saving it. This is necessary for subclassed models.
+    inputs = {'feature': x0, 'label': y0}
+    adv_model.evaluate(inputs, steps=1)
+
+    saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
+    adv_model.save(saved_model_dir)
+
+    loaded_model = tf.keras.models.load_model(saved_model_dir)
+    self.assertEqual(
+        len(loaded_model.trainable_weights), len(adv_model.trainable_weights))
+    for w_loaded, w_adv in zip(loaded_model.trainable_weights,
+                               adv_model.trainable_weights):
+      self.assertAllClose(
+          tf.keras.backend.get_value(w_loaded),
+          tf.keras.backend.get_value(w_adv))
+
+  @parameterized.named_parameters([
+      ('sequential', build_linear_keras_sequential_model),
+      ('functional', build_linear_keras_functional_model),
+  ])
+  def test_adv_model_save(self, model_fn):
+    self._test_adv_model_save(model_fn)
+
+  # Saving subclassed models is only supported in TF v2.
+  @test_util.run_v2_only
+  def test_adv_model_save_subclassed(self):
+    self._test_adv_model_save(build_linear_keras_subclassed_model)
+
 
 if __name__ == '__main__':
   tf.test.main()
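Note on the subclassed-model changes above: a subclassed Keras model can only be saved after it has been built by running it on real inputs (hence the evaluate() call before save()), and restoring it relies on get_config() returning the constructor arguments. The snippet below is a minimal, self-contained sketch of that pattern, not code from this commit; the TinyLinear class, feature name, and paths are illustrative only.

import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyLinear(tf.keras.Model):
  """Minimal subclassed model mirroring the pattern used in the test."""

  def __init__(self, units=1, name=None):
    super(TinyLinear, self).__init__(name=name)
    self.units = units
    self.dense = tf.keras.layers.Dense(units, use_bias=False)

  def call(self, inputs):
    return self.dense(inputs['feature'])

  def get_config(self):
    # Return constructor kwargs so Keras can re-create the model when loading.
    return {'units': self.units, 'name': self.name}


model = TinyLinear(units=1)
# Build the model's variables by running it once before saving.
model.predict({'feature': np.array([[1.0, 2.0]], dtype=np.float32)})

saved_model_dir = os.path.join(tempfile.mkdtemp(), 'saved_model')
model.save(saved_model_dir)

loaded = tf.keras.models.load_model(saved_model_dir)
print([w.shape for w in loaded.trainable_weights])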
6 changes: 6 additions & 0 deletions neural_structured_learning/keras/graph_regularization.py
@@ -139,3 +139,9 @@ def call(self, inputs, training=False, **kwargs):
     self.add_loss(scaled_graph_loss)
 
     return base_output
+
+  def save(self, *args, **kwargs):
+    """Saves the base model. See base class for details of the interface."""
+    # Graph regularization doesn't introduce new model variables, so saving the
+    # base model can capture all variables in the model.
+    self.base_model.save(*args, **kwargs)
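In practice, the new GraphRegularization.save() means a graph-regularized model is exported as a plain SavedModel of its base model, which can then be served without the NSL wrapper or neighbor features. A rough usage sketch follows; the model, config values, and paths are illustrative, not from this commit.

import os
import tempfile

import neural_structured_learning as nsl
import tensorflow as tf

# Illustrative base model; any built Keras model would do.
base_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, use_bias=False, input_shape=(2,))])
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=1, multiplier=0.1)
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)
graph_reg_model.compile(optimizer='sgd', loss='MSE')
# ... train with graph_reg_model.fit(...) on examples packed with neighbors ...

# With this change, save() writes the base model as a regular SavedModel;
# graph regularization adds no extra variables, so nothing is lost.
saved_model_dir = os.path.join(tempfile.mkdtemp(), 'saved_model')
graph_reg_model.save(saved_model_dir)

# The result loads as an ordinary Keras model and can be used for inference
# without the NSL wrapper or neighbor features.
serving_model = tf.keras.models.load_model(saved_model_dir)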
61 changes: 56 additions & 5 deletions neural_structured_learning/keras/graph_regularization_test.py
@@ -17,10 +17,11 @@
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from absl.testing import parameterized
 import neural_structured_learning.configs as configs
 from neural_structured_learning.keras import graph_regularization
-
 import numpy as np
 import tensorflow as tf
 
@@ -88,10 +89,12 @@ def build_linear_functional_model(input_shape, weights, num_output=1):
 def build_linear_subclass_model(input_shape, weights, num_output=1):
   del input_shape
 
-  class LinearModel(tf.keras.Model):
+  class CustomLinearModel(tf.keras.Model):
 
-    def __init__(self):
-      super(LinearModel, self).__init__()
+    def __init__(self, weights, num_output, name=None):
+      super(CustomLinearModel, self).__init__(name=name)
+      self.init_weights = weights
+      self.num_output = num_output
       self.dense = tf.keras.layers.Dense(
           num_output,
           use_bias=False,
@@ -101,7 +104,14 @@ def __init__(self):
     def call(self, inputs):
       return self.dense(inputs[FEATURE_NAME])
 
-  return LinearModel()
+    def get_config(self):
+      return {
+          'name': self.name,
+          'weights': self.init_weights,
+          'num_output': self.num_output
+      }
+
+  return CustomLinearModel(weights, num_output)
 
 
 def make_dataset(example_proto, input_shape, training, max_neighbors):
@@ -481,6 +491,47 @@ def test_graph_reg_model_evaluate(self, model_fn):
         weight=w,
         distributed_strategy=None)
 
+  def _test_graph_reg_model_save(self, model_fn):
+    """Template for testing model saving and loading."""
+    w = np.array([[4.0], [-3.0]])
+    base_model = model_fn((2,), w)
+    graph_reg_config = configs.make_graph_reg_config(
+        max_neighbors=1, multiplier=1)
+    graph_reg_model = graph_regularization.GraphRegularization(
+        base_model, graph_reg_config)
+    graph_reg_model.compile(
+        optimizer=tf.keras.optimizers.SGD(LEARNING_RATE),
+        loss='MSE',
+        metrics=['accuracy'])
+
+    # Run the model before saving it. This is necessary for subclassed models.
+    inputs = {FEATURE_NAME: tf.constant([[5.0, 3.0]])}
+    graph_reg_model.predict(inputs, steps=1, batch_size=1)
+    saved_model_dir = os.path.join(self.get_temp_dir(), 'saved_model')
+    graph_reg_model.save(saved_model_dir)
+
+    loaded_model = tf.keras.models.load_model(saved_model_dir)
+    self.assertEqual(
+        len(loaded_model.trainable_weights),
+        len(graph_reg_model.trainable_weights))
+    for w_loaded, w_graph_reg in zip(loaded_model.trainable_weights,
+                                     graph_reg_model.trainable_weights):
+      self.assertAllClose(
+          tf.keras.backend.get_value(w_loaded),
+          tf.keras.backend.get_value(w_graph_reg))
+
+  @parameterized.named_parameters([
+      ('_sequential', build_linear_sequential_model),
+      ('_functional', build_linear_functional_model),
+  ])
+  def test_graph_reg_model_save(self, model_fn):
+    self._test_graph_reg_model_save(model_fn)
+
+  # Saving subclassed models is only supported in TF v2.
+  @test_util.run_v2_only
+  def test_graph_reg_model_save_subclass(self):
+    self._test_graph_reg_model_save(build_linear_subclass_model)
+
 
 if __name__ == '__main__':
   tf.test.main()
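One caveat worth noting for the subclassed cases: tf.keras.models.load_model without custom_objects revives a generic Keras model object, which is enough for the weight comparisons in these tests. If the original Python class should be restored instead, it has to be available at load time, roughly as sketched below (assuming a CustomLinearModel class like the one defined in the test and a saved_model_dir written by save(); this is not part of the commit):

import tensorflow as tf

# Assumes `CustomLinearModel` (with get_config(), as in the test above) is
# importable in the loading program, and `saved_model_dir` points at the
# directory written by model.save().
loaded = tf.keras.models.load_model(
    saved_model_dir, custom_objects={'CustomLinearModel': CustomLinearModel})
# By default Keras rebuilds the object via CustomLinearModel(**config), so
# get_config() must return valid constructor arguments.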
