Skip to content

Commit

Permalink
TF 2.3.1 compatibility
Browse files Browse the repository at this point in the history
 - Fixed 'L1' object has no attribute 'l2' (and vice versa for non-`l1_l2` objects)
 - Moved testing to TF2.3.1
  • Loading branch information
OverLordGoldDragon authored Oct 26, 2020
1 parent a9c664a commit 635b5a0
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ env:
- TF_VERSION="1.14.0" KERAS_VERSION="2.2.5" TF_KERAS="1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.0" TF_EAGER="1"
- TF_VERSION="2.2.0" KERAS_VERSION="2.3.0"
- TF_VERSION="2.3.0" KERAS_VERSION="2.3.0" TF_KERAS="1" TF_EAGER="1"
- TF_VERSION="2.3.0" KERAS_VERSION="2.3.0" TF_KERAS="1"
- TF_VERSION="2.3.1" KERAS_VERSION="2.3.0" TF_KERAS="1" TF_EAGER="1"
- TF_VERSION="2.3.1" KERAS_VERSION="2.3.0" TF_KERAS="1"

notifications:
email: false
Expand Down
2 changes: 1 addition & 1 deletion keras_adamw/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@
from .utils import get_weight_decays, fill_dict_in_order
from .utils import reset_seeds, K_eval

__version__ = '1.37'
__version__ = '1.38'
20 changes: 12 additions & 8 deletions keras_adamw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,8 @@ def _get_layer_penalties(layer, zero_penalties=False):
for weight_name in ['kernel', 'bias']:
_lambda = getattr(layer, weight_name + '_regularizer', None)
if _lambda is not None:
l1l2 = (float(_lambda.l1), float(_lambda.l2))
l1l2 = _get_and_maybe_zero_penalties(_lambda, zero_penalties)
penalties.append([getattr(layer, weight_name).name, l1l2])
if zero_penalties:
_lambda.l1 = np.array(0., dtype=_lambda.l1.dtype)
_lambda.l2 = np.array(0., dtype=_lambda.l2.dtype)
return penalties


Expand All @@ -190,14 +187,21 @@ def _cell_penalties(rnn_cell, zero_penalties=False):
_lambda = getattr(cell, weight_type + '_regularizer', None)
if _lambda is not None:
weight_name = cell.weights[weight_idx].name
l1l2 = (float(_lambda.l1), float(_lambda.l2))
l1l2 = _get_and_maybe_zero_penalties(_lambda, zero_penalties)
penalties.append([weight_name, l1l2])
if zero_penalties:
_lambda.l1 = np.array(0., dtype=_lambda.l1.dtype)
_lambda.l2 = np.array(0., dtype=_lambda.l2.dtype)
return penalties


def _get_and_maybe_zero_penalties(_lambda, zero_penalties):
if zero_penalties:
if hasattr(_lambda, 'l1'):
_lambda.l1 = np.array(0., dtype=_lambda.l1.dtype)
if hasattr(_lambda, 'l2'):
_lambda.l2 = np.array(0., dtype=_lambda.l2.dtype)
return (float(getattr(_lambda, 'l1', 0.)),
float(getattr(_lambda, 'l2', 0.)))


def fill_dict_in_order(_dict, values_list):
for idx, key in enumerate(_dict.keys()):
_dict[key] = values_list[idx]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def test_misc(): # tests of non-main features to improve coverage
embed_input_dim = 5

# arbitrarily select SGDW for coverage testing
l1_reg = 1e-4 if optimizer_name == 'SGDW' else 0
l2_reg = 1e-4 if optimizer_name != 'SGDW' else 0
l1_reg = 1e-4 if optimizer_name == 'SGDW' else None
l2_reg = 1e-4 if optimizer_name != 'SGDW' else None
if optimizer_name == 'SGDW':
optimizer_kw.update(dict(zero_penalties=False,
weight_decays={},
Expand Down

0 comments on commit 635b5a0

Please sign in to comment.