From 635b5a015fdd160d84b2a2e684fb045da3e502bb Mon Sep 17 00:00:00 2001 From: OverLordGoldDragon <16495490+OverLordGoldDragon@users.noreply.github.com> Date: Mon, 26 Oct 2020 19:59:30 +0400 Subject: [PATCH] TF 2.3.1 compatibility - Fixed 'L1' object has no attribute 'l2' (and vice versa for non-`l1_l2` objects) - Moved testing to TF2.3.1 --- .travis.yml | 4 ++-- keras_adamw/__init__.py | 2 +- keras_adamw/utils.py | 20 ++++++++++++-------- tests/test_optimizers.py | 4 ++-- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/.travis.yml b/.travis.yml index ec099f6..67544ca 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,8 +12,8 @@ env: - TF_VERSION="1.14.0" KERAS_VERSION="2.2.5" TF_KERAS="1" - TF_VERSION="2.2.0" KERAS_VERSION="2.3.0" TF_EAGER="1" - TF_VERSION="2.2.0" KERAS_VERSION="2.3.0" - - TF_VERSION="2.3.0" KERAS_VERSION="2.3.0" TF_KERAS="1" TF_EAGER="1" - - TF_VERSION="2.3.0" KERAS_VERSION="2.3.0" TF_KERAS="1" + - TF_VERSION="2.3.1" KERAS_VERSION="2.3.0" TF_KERAS="1" TF_EAGER="1" + - TF_VERSION="2.3.1" KERAS_VERSION="2.3.0" TF_KERAS="1" notifications: email: false diff --git a/keras_adamw/__init__.py b/keras_adamw/__init__.py index 211d577..a23df3e 100644 --- a/keras_adamw/__init__.py +++ b/keras_adamw/__init__.py @@ -28,4 +28,4 @@ from .utils import get_weight_decays, fill_dict_in_order from .utils import reset_seeds, K_eval -__version__ = '1.37' +__version__ = '1.38' diff --git a/keras_adamw/utils.py b/keras_adamw/utils.py index 4deacb2..05b60ea 100644 --- a/keras_adamw/utils.py +++ b/keras_adamw/utils.py @@ -164,11 +164,8 @@ def _get_layer_penalties(layer, zero_penalties=False): for weight_name in ['kernel', 'bias']: _lambda = getattr(layer, weight_name + '_regularizer', None) if _lambda is not None: - l1l2 = (float(_lambda.l1), float(_lambda.l2)) + l1l2 = _get_and_maybe_zero_penalties(_lambda, zero_penalties) penalties.append([getattr(layer, weight_name).name, l1l2]) - if zero_penalties: - _lambda.l1 = np.array(0., dtype=_lambda.l1.dtype) - _lambda.l2 = np.array(0., dtype=_lambda.l2.dtype) return penalties @@ -190,14 +187,21 @@ def _cell_penalties(rnn_cell, zero_penalties=False): _lambda = getattr(cell, weight_type + '_regularizer', None) if _lambda is not None: weight_name = cell.weights[weight_idx].name - l1l2 = (float(_lambda.l1), float(_lambda.l2)) + l1l2 = _get_and_maybe_zero_penalties(_lambda, zero_penalties) penalties.append([weight_name, l1l2]) - if zero_penalties: - _lambda.l1 = np.array(0., dtype=_lambda.l1.dtype) - _lambda.l2 = np.array(0., dtype=_lambda.l2.dtype) return penalties +def _get_and_maybe_zero_penalties(_lambda, zero_penalties): + if zero_penalties: + if hasattr(_lambda, 'l1'): + _lambda.l1 = np.array(0., dtype=_lambda.l1.dtype) + if hasattr(_lambda, 'l2'): + _lambda.l2 = np.array(0., dtype=_lambda.l2.dtype) + return (float(getattr(_lambda, 'l1', 0.)), + float(getattr(_lambda, 'l2', 0.))) + + def fill_dict_in_order(_dict, values_list): for idx, key in enumerate(_dict.keys()): _dict[key] = values_list[idx] diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 5125966..17527fb 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -90,8 +90,8 @@ def test_misc(): # tests of non-main features to improve coverage embed_input_dim = 5 # arbitrarily select SGDW for coverage testing - l1_reg = 1e-4 if optimizer_name == 'SGDW' else 0 - l2_reg = 1e-4 if optimizer_name != 'SGDW' else 0 + l1_reg = 1e-4 if optimizer_name == 'SGDW' else None + l2_reg = 1e-4 if optimizer_name != 'SGDW' else None if optimizer_name == 'SGDW': optimizer_kw.update(dict(zero_penalties=False, weight_decays={},