From 6186b6c5e83d4c8cb667605d3306cd171215f77b Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Wed, 8 Nov 2023 13:19:42 +0800
Subject: [PATCH] fix: fix optree compatibility

Map over the parameter/state trees first and pass the update tree, whose
leaves may be None for parameters that received no gradient, as the last
argument to tree_map / tree_map_. Each mapped function now returns early
when the incoming gradient is None. Because tree_map_ returns its first
argument, which after the reordering is no longer the update tree, its
return value is no longer assigned back to updates.
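As an illustration only (not part of the patch itself), here is a minimal
sketch of the calling convention this change moves to. It assumes optree's
none_is_leaf=True mode, so that None gradients reach the mapped function as
ordinary leaves; the tensor values and the helpers f / f_ are hypothetical:

    import torch
    from optree import tree_map, tree_map_

    params = {'w': torch.ones(2), 'b': torch.zeros(2)}
    # A parameter that received no gradient shows up as a None leaf.
    updates = {'w': torch.full((2,), 0.1), 'b': None}

    def f(p, g):
        if g is None:  # skip parameters without gradients
            return g
        return g.add(p, alpha=1e-4)  # hypothetical decoupled weight decay

    # Out of place: traverse params (which never contains None) and pass
    # the possibly-None updates as the trailing tree.
    new_updates = tree_map(f, params, updates, none_is_leaf=True)

    def f_(p, g):
        if g is None:
            return g
        return g.add_(p, alpha=1e-4)  # mutates the update tensor in place

    # In place: the leaves of updates are mutated directly; tree_map_
    # returns its first argument (params here), so assigning the result
    # back to updates would be wrong.
    tree_map_(f_, params, updates, none_is_leaf=True)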
---
 torchopt/alias/utils.py                   | 26 +++++++++-----
 torchopt/transform/add_decayed_weights.py | 12 ++++---
 torchopt/transform/scale_by_adadelta.py   | 18 ++++------
 torchopt/transform/scale_by_adam.py       | 18 ++++------
 torchopt/transform/scale_by_adamax.py     | 18 +++++-----
 torchopt/transform/scale_by_rms.py        | 12 ++++---
 torchopt/transform/scale_by_rss.py        | 22 +++++-------
 torchopt/transform/scale_by_schedule.py   | 16 +++++----
 torchopt/transform/scale_by_stddev.py     | 12 ++++---
 torchopt/transform/trace.py               | 42 +++++++++++++++--------
 10 files changed, 110 insertions(+), 86 deletions(-)

diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py
index 5c8dc97a..a984a889 100644
--- a/torchopt/alias/utils.py
+++ b/torchopt/alias/utils.py
@@ -108,19 +108,23 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.add_(p, alpha=weight_decay)
                 return g.add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.add(p, alpha=weight_decay)
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state
@@ -139,7 +143,7 @@ def update_fn(
             def f(g: torch.Tensor) -> torch.Tensor:
                 return g.neg_()
 
-            updates = tree_map_(f, updates)
+            tree_map_(f, updates)
 
         else:
@@ -166,19 +170,23 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.neg_().add_(p, alpha=weight_decay)
                 return g.neg_().add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.neg().add_(p, alpha=weight_decay)
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state
diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py
index 04d564d7..6643e6c7 100644
--- a/torchopt/transform/add_decayed_weights.py
+++ b/torchopt/transform/add_decayed_weights.py
@@ -226,19 +226,23 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.add_(p, alpha=weight_decay)
                 return g.add_(p.data, alpha=weight_decay)
 
-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)
 
         else:
 
-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.add(p, alpha=weight_decay)
 
-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)
 
         return updates, state
diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py
index fb5431a3..d644cf55 100644
--- a/torchopt/transform/scale_by_adadelta.py
+++ b/torchopt/transform/scale_by_adadelta.py
@@ -129,23 +129,19 @@ def update_fn(
         if inplace:
 
-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_())
 
         else:
 
-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.mul(v.add(eps).div_(m.add(eps)).sqrt_())
 
-        updates = tree_map(f, updates, mu, state.nu)
+        updates = tree_map(f, mu, state.nu, updates)
 
         nu = update_moment.impl(  # type: ignore[attr-defined]
             updates,
diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py
index 5bf84eba..cc9ea146 100644
--- a/torchopt/transform/scale_by_adam.py
+++ b/torchopt/transform/scale_by_adam.py
@@ -201,23 +201,19 @@ def update_fn(
         if inplace:
 
-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return m.div_(v.add_(eps_root).sqrt_().add(eps))
 
         else:
 
-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return m.div(v.add(eps_root).sqrt_().add(eps))
 
-        updates = tree_map(f, updates, mu_hat, nu_hat)
+        updates = tree_map(f, mu_hat, nu_hat, updates)
 
         return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc)
 
     return GradientTransformation(init_fn, update_fn)
diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py
index 504e82cd..54ea861b 100644
--- a/torchopt/transform/scale_by_adamax.py
+++ b/torchopt/transform/scale_by_adamax.py
@@ -137,23 +137,21 @@ def update_fn(
             already_flattened=already_flattened,
         )
 
-        def update_nu(
-            g: torch.Tensor,
-            n: torch.Tensor,
-        ) -> torch.Tensor:
+        def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+            if g is None:
+                return None
             return torch.max(n.mul(b2), g.abs().add_(eps))
 
-        nu = tree_map(update_nu, updates, state.nu)
+        nu = tree_map(update_nu, state.nu, updates)
 
         one_minus_b1_pow_t = 1 - b1**state.t
 
-        def f(
-            n: torch.Tensor,
-            m: torch.Tensor,
-        ) -> torch.Tensor:
+        def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor:
+            if n is None:
+                return m
             return m.div(n).div_(one_minus_b1_pow_t)
 
-        updates = tree_map(f, nu, mu)
+        updates = tree_map(f, mu, nu)
 
         return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1)
diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py
index ac2fef16..23b385e7 100644
--- a/torchopt/transform/scale_by_rms.py
+++ b/torchopt/transform/scale_by_rms.py
@@ -136,17 +136,21 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+            def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.div_(n.sqrt().add_(eps))
 
-            updates = tree_map_(f, updates, nu)
+            tree_map_(f, nu, updates)
 
         else:
 
-            def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+            def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.div(n.sqrt().add(eps))
 
-            updates = tree_map(f, updates, nu)
+            updates = tree_map(f, nu, updates)
 
         return updates, ScaleByRmsState(nu=nu)
diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py
index 68021e5e..f4000eea 100644
--- a/torchopt/transform/scale_by_rss.py
+++ b/torchopt/transform/scale_by_rss.py
@@ -128,23 +128,19 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
-                return torch.where(
-                    sos > 0.0,
-                    g.div_(sos.sqrt().add_(eps)),
-                    0.0,
-                )
+            def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
+                return torch.where(sos > 0.0, g.div_(sos.sqrt().add_(eps)), 0.0)
 
         else:
 
-            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
-                return torch.where(
-                    sos > 0.0,
-                    g.div(sos.sqrt().add(eps)),
-                    0.0,
-                )
+            def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
+                return torch.where(sos > 0.0, g.div(sos.sqrt().add(eps)), 0.0)
 
-        updates = tree_map(f, updates, sum_of_squares)
+        updates = tree_map(f, sum_of_squares, updates)
 
         return updates, ScaleByRssState(sum_of_squares=sum_of_squares)
 
     return GradientTransformation(init_fn, update_fn)
diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py
index f27fb7e8..749b1853 100644
--- a/torchopt/transform/scale_by_schedule.py
+++ b/torchopt/transform/scale_by_schedule.py
@@ -96,20 +96,24 @@ def update_fn(
         inplace: bool = True,
     ) -> tuple[Updates, OptState]:
         if inplace:
-
-            def f(g: torch.Tensor, c: Numeric) -> torch.Tensor:  # pylint: disable=invalid-name
+            # pylint: disable-next=invalid-name
+            def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 step_size = step_size_fn(c)
                 return g.mul_(step_size)
 
-            updates = tree_map_(f, updates, state.count)
+            tree_map_(f, state.count, updates)
 
         else:
-
-            def f(g: torch.Tensor, c: Numeric) -> torch.Tensor:  # pylint: disable=invalid-name
+            # pylint: disable-next=invalid-name
+            def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 step_size = step_size_fn(c)
                 return g.mul(step_size)
 
-            updates = tree_map(f, updates, state.count)
+            updates = tree_map(f, state.count, updates)
 
         return (
             updates,
diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py
index bbbfb384..30e799b6 100644
--- a/torchopt/transform/scale_by_stddev.py
+++ b/torchopt/transform/scale_by_stddev.py
@@ -148,17 +148,21 @@ def update_fn(
         if inplace:
 
-            def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
+            def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps))
 
-            updates = tree_map_(f, updates, mu, nu)
+            tree_map_(f, mu, nu, updates)
 
         else:
 
-            def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
+            def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps))
 
-            updates = tree_map(f, updates, mu, nu)
+            updates = tree_map(f, mu, nu, updates)
 
         return updates, ScaleByRStdDevState(mu=mu, nu=nu)
diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py
index 7a1e1971..5f82c067 100644
--- a/torchopt/transform/trace.py
+++ b/torchopt/transform/trace.py
@@ -148,52 +148,66 @@ def update_fn(
         if nesterov:
             if inplace:
 
-                def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add_(g)
                     return t.mul_(momentum).add_(g)
 
-                def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     return g.add_(t, alpha=momentum)
 
-                new_trace = tree_map(f1, updates, state.trace)
-                updates = tree_map_(f2, updates, new_trace)
+                new_trace = tree_map(f1, state.trace, updates)
+                tree_map_(f2, new_trace, updates)
 
             else:
 
-                def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add(g)
                     return t.mul(momentum).add_(g)
 
-                def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     return g.add(t, alpha=momentum)
 
-                new_trace = tree_map(f1, updates, state.trace)
-                updates = tree_map(f2, updates, new_trace)
+                new_trace = tree_map(f1, state.trace, updates)
+                updates = tree_map(f2, new_trace, updates)
 
         else:
             if inplace:
 
-                def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add_(g)
                     return t.mul_(momentum).add_(g, alpha=1.0 - dampening)
 
-                def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def copy_to_(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     return g.copy_(t)
 
-                new_trace = tree_map(f, updates, state.trace)
-                updates = tree_map_(copy_, updates, new_trace)
+                new_trace = tree_map(f, state.trace, updates)
+                tree_map_(copy_to_, new_trace, updates)
 
             else:
 
-                def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add(g)
                     return t.mul(momentum).add_(g, alpha=1.0 - dampening)
 
-                new_trace = tree_map(f, updates, state.trace)
+                new_trace = tree_map(f, state.trace, updates)
                 updates = tree_map(torch.clone, new_trace)
 
         first_call = False