
Commit

fix: fix optree compatibility
XuehaiPan committed Nov 8, 2023
1 parent ce9e83b commit 6186b6c
Showing 10 changed files with 110 additions and 86 deletions.
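The change repeated across these files is a calling-convention fix for optree-style tree mapping: each per-leaf function now takes the possibly-missing gradient as its last argument and returns it unchanged when it is None, and tree_map / tree_map_ are called with the tree that never contains None (the parameters or the optimizer state) first, so that tree defines the traversal structure. A minimal sketch of that convention, using optree directly; the tensors and the add_weight_decay helper below are illustrative only and are not part of the commit:

from __future__ import annotations

import torch
import optree

params = {'weight': torch.ones(2), 'bias': torch.ones(1)}
# A gradient tree may carry None leaves, e.g. for parameters that received no grad.
updates = {'weight': torch.full((2,), 0.1), 'bias': None}


def add_weight_decay(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
    if g is None:  # tolerate the missing gradient instead of crashing
        return g
    return g.add(p, alpha=1e-2)


# `params` (never None) defines the structure; with none_is_leaf=True the None
# leaf in `updates` is passed to the function like any other leaf.
new_updates = optree.tree_map(add_weight_decay, params, updates, none_is_leaf=True)
# -> the 'bias' leaf stays None, the 'weight' leaf becomes g + 1e-2 * p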
26 changes: 17 additions & 9 deletions torchopt/alias/utils.py
@@ -108,19 +108,23 @@ def update_fn(

         if inplace:

-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.add_(p, alpha=weight_decay)
                 return g.add_(p.data, alpha=weight_decay)

-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)

         else:

-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.add(p, alpha=weight_decay)

-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)

         return updates, state

@@ -139,7 +143,7 @@ def update_fn(
             def f(g: torch.Tensor) -> torch.Tensor:
                 return g.neg_()

-            updates = tree_map_(f, updates)
+            tree_map_(f, updates)

         else:

@@ -166,19 +170,23 @@ def update_fn(

         if inplace:

-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.neg_().add_(p, alpha=weight_decay)
                 return g.neg_().add_(p.data, alpha=weight_decay)

-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)

         else:

-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.neg().add_(p, alpha=weight_decay)

-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)

         return updates, state
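Note that the in-place branches above no longer assign the result of tree_map_: f mutates each update tensor with add_, so the updates tree already holds the new values and only the side effect of the traversal is needed. A rough sketch of that behaviour (the tree_map_ stand-in below is simplified and not torchopt's actual helper):

from __future__ import annotations

import torch
import optree


def tree_map_(func, tree, *rests):
    # Simplified stand-in: run func for its side effects, return the first tree.
    optree.tree_map(func, tree, *rests, none_is_leaf=True)
    return tree


params = [torch.ones(3)]
updates = [torch.zeros(3)]


def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
    if g is None:
        return g
    return g.add_(p, alpha=1e-2)  # mutates the update tensor in place


tree_map_(f, params, updates)  # no `updates = ...` reassignment needed
print(updates[0])  # tensor([0.0100, 0.0100, 0.0100])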

12 changes: 8 additions & 4 deletions torchopt/transform/add_decayed_weights.py
@@ -226,19 +226,23 @@ def update_fn(

         if inplace:

-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 if g.requires_grad:
                     return g.add_(p, alpha=weight_decay)
                 return g.add_(p.data, alpha=weight_decay)

-            updates = tree_map_(f, updates, params)
+            tree_map_(f, params, updates)

         else:

-            def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.add(p, alpha=weight_decay)

-            updates = tree_map(f, updates, params)
+            updates = tree_map(f, params, updates)

         return updates, state

18 changes: 7 additions & 11 deletions torchopt/transform/scale_by_adadelta.py
@@ -129,23 +129,19 @@ def update_fn(

         if inplace:

-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_())

         else:

-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.mul(v.add(eps).div_(m.add(eps)).sqrt_())

-        updates = tree_map(f, updates, mu, state.nu)
+        updates = tree_map(f, mu, state.nu, updates)

         nu = update_moment.impl(  # type: ignore[attr-defined]
             updates,
18 changes: 7 additions & 11 deletions torchopt/transform/scale_by_adam.py
@@ -201,23 +201,19 @@ def update_fn(

         if inplace:

-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return m.div_(v.add_(eps_root).sqrt_().add(eps))

         else:

-            def f(
-                g: torch.Tensor,  # pylint: disable=unused-argument
-                m: torch.Tensor,
-                v: torch.Tensor,
-            ) -> torch.Tensor:
+            def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return m.div(v.add(eps_root).sqrt_().add(eps))

-        updates = tree_map(f, updates, mu_hat, nu_hat)
+        updates = tree_map(f, mu_hat, nu_hat, updates)
         return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc)

     return GradientTransformation(init_fn, update_fn)
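For orientation, the mapped function here is the usual Adam step: it rescales the bias-corrected first moment by the square root of the bias-corrected second moment and only consults the gradient leaf for the None check. A standalone numeric sketch of that per-leaf expression (the moment values are invented for illustration):

import torch

eps, eps_root = 1e-8, 0.0
m_hat = torch.tensor([0.1, -0.2])   # bias-corrected first moment
v_hat = torch.tensor([0.01, 0.04])  # bias-corrected second moment

# Same expression as the out-of-place branch: m / (sqrt(v + eps_root) + eps)
update = m_hat.div(v_hat.add(eps_root).sqrt_().add(eps))
print(update)  # approximately tensor([ 1.0000, -1.0000])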
18 changes: 8 additions & 10 deletions torchopt/transform/scale_by_adamax.py
@@ -137,23 +137,21 @@ def update_fn(
             already_flattened=already_flattened,
         )

-        def update_nu(
-            g: torch.Tensor,
-            n: torch.Tensor,
-        ) -> torch.Tensor:
+        def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+            if g is None:
+                return None
             return torch.max(n.mul(b2), g.abs().add_(eps))

-        nu = tree_map(update_nu, updates, state.nu)
+        nu = tree_map(update_nu, state.nu, updates)

         one_minus_b1_pow_t = 1 - b1**state.t

-        def f(
-            n: torch.Tensor,
-            m: torch.Tensor,
-        ) -> torch.Tensor:
+        def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor:
+            if n is None:
+                return m
             return m.div(n).div_(one_minus_b1_pow_t)

-        updates = tree_map(f, nu, mu)
+        updates = tree_map(f, mu, nu)

         return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1)
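The two mapped functions in this file implement the Adamax rule: update_nu keeps a decayed elementwise maximum of gradient magnitudes, and f divides the bias-corrected first moment by it; a None gradient leaves the corresponding nu leaf as None, and f then falls back to returning m unscaled. A small numeric sketch of those expressions with invented values:

import torch

b1, b2, eps, t = 0.9, 0.999, 1e-8, 1
g = torch.tensor([0.5, -2.0])   # gradient leaf
m = torch.tensor([0.05, -0.2])  # first-moment leaf
n = torch.tensor([0.4, 1.0])    # infinity-norm (nu) leaf

# update_nu: decayed elementwise max against the new gradient magnitude
n = torch.max(n.mul(b2), g.abs().add_(eps))

# f: bias-corrected step m / n / (1 - b1**t)
one_minus_b1_pow_t = 1 - b1**t
update = m.div(n).div_(one_minus_b1_pow_t)
print(n, update)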

12 changes: 8 additions & 4 deletions torchopt/transform/scale_by_rms.py
@@ -136,17 +136,21 @@ def update_fn(

         if inplace:

-            def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+            def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.div_(n.sqrt().add_(eps))

-            updates = tree_map_(f, updates, nu)
+            tree_map_(f, nu, updates)

         else:

-            def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+            def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.div(n.sqrt().add(eps))

-            updates = tree_map(f, updates, nu)
+            updates = tree_map(f, nu, updates)

         return updates, ScaleByRmsState(nu=nu)

22 changes: 9 additions & 13 deletions torchopt/transform/scale_by_rss.py
@@ -128,23 +128,19 @@ def update_fn(

         if inplace:

-            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
-                return torch.where(
-                    sos > 0.0,
-                    g.div_(sos.sqrt().add_(eps)),
-                    0.0,
-                )
+            def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
+                return torch.where(sos > 0.0, g.div_(sos.sqrt().add_(eps)), 0.0)

         else:

-            def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor:
-                return torch.where(
-                    sos > 0.0,
-                    g.div(sos.sqrt().add(eps)),
-                    0.0,
-                )
+            def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
+                return torch.where(sos > 0.0, g.div(sos.sqrt().add(eps)), 0.0)

-        updates = tree_map(f, updates, sum_of_squares)
+        updates = tree_map(f, sum_of_squares, updates)
         return updates, ScaleByRssState(sum_of_squares=sum_of_squares)

     return GradientTransformation(init_fn, update_fn)
16 changes: 10 additions & 6 deletions torchopt/transform/scale_by_schedule.py
@@ -96,20 +96,24 @@ def update_fn(
         inplace: bool = True,
     ) -> tuple[Updates, OptState]:
         if inplace:
-
-            def f(g: torch.Tensor, c: Numeric) -> torch.Tensor:  # pylint: disable=invalid-name
+            # pylint: disable-next=invalid-name
+            def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 step_size = step_size_fn(c)
                 return g.mul_(step_size)

-            updates = tree_map_(f, updates, state.count)
+            tree_map_(f, state.count, updates)

         else:
-
-            def f(g: torch.Tensor, c: Numeric) -> torch.Tensor:  # pylint: disable=invalid-name
+            # pylint: disable-next=invalid-name
+            def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 step_size = step_size_fn(c)
                 return g.mul(step_size)

-            updates = tree_map(f, updates, state.count)
+            updates = tree_map(f, state.count, updates)

         return (
             updates,
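Here the count leaf c is turned into a multiplier through the user-supplied step_size_fn before scaling the matching update leaf. A sketch of that per-leaf computation; the decay schedule below is made up, the real one is whatever callable is passed to scale_by_schedule:

import torch


def step_size_fn(count: int) -> float:
    # Hypothetical schedule: shrink the step as 1 / (1 + count).
    return 1.0 / (1.0 + count)


count = 4                      # one leaf of state.count
g = torch.tensor([1.0, -2.0])  # matching update leaf

update = g.mul(step_size_fn(count))  # same expression as the out-of-place branch
print(update)  # tensor([ 0.2000, -0.4000])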
12 changes: 8 additions & 4 deletions torchopt/transform/scale_by_stddev.py
@@ -148,17 +148,21 @@ def update_fn(

         if inplace:

-            def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
+            def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps))

-            updates = tree_map_(f, updates, mu, nu)
+            tree_map_(f, mu, nu, updates)

         else:

-            def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
+            def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                if g is None:
+                    return g
                 return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps))

-            updates = tree_map(f, updates, mu, nu)
+            updates = tree_map(f, mu, nu, updates)

         return updates, ScaleByRStdDevState(mu=mu, nu=nu)

42 changes: 28 additions & 14 deletions torchopt/transform/trace.py
@@ -148,52 +148,66 @@ def update_fn(
         if nesterov:
             if inplace:

-                def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add_(g)
                     return t.mul_(momentum).add_(g)

-                def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     return g.add_(t, alpha=momentum)

-                new_trace = tree_map(f1, updates, state.trace)
-                updates = tree_map_(f2, updates, new_trace)
+                new_trace = tree_map(f1, state.trace, updates)
+                tree_map_(f2, new_trace, updates)

             else:

-                def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add(g)
                     return t.mul(momentum).add_(g)

-                def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     return g.add(t, alpha=momentum)

-                new_trace = tree_map(f1, updates, state.trace)
-                updates = tree_map(f2, updates, new_trace)
+                new_trace = tree_map(f1, state.trace, updates)
+                updates = tree_map(f2, new_trace, updates)

         else:
             if inplace:

-                def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add_(g)
                     return t.mul_(momentum).add_(g, alpha=1.0 - dampening)

-                def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def copy_to_(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     return g.copy_(t)

-                new_trace = tree_map(f, updates, state.trace)
-                updates = tree_map_(copy_, updates, new_trace)
+                new_trace = tree_map(f, state.trace, updates)
+                tree_map_(copy_to_, new_trace, updates)

             else:

-                def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+                def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+                    if g is None:
+                        return g
                     if first_call:
                         return t.add(g)
                     return t.mul(momentum).add_(g, alpha=1.0 - dampening)

-                new_trace = tree_map(f, updates, state.trace)
+                new_trace = tree_map(f, state.trace, updates)
                 updates = tree_map(torch.clone, new_trace)

         first_call = False
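In the Nesterov branch, the first mapped function refreshes the momentum trace and the second pushes the refreshed trace back into the update (the look-ahead step); with the new ordering, the trace leaf comes first and a None gradient simply propagates through both passes. A per-leaf numeric sketch of the out-of-place Nesterov path with invented values:

import torch

momentum, first_call = 0.9, False
g = torch.tensor([1.0, -1.0])  # incoming update leaf
t = torch.tensor([0.5, 0.5])   # momentum trace leaf from state.trace

# f1: refresh the trace
new_trace = t.add(g) if first_call else t.mul(momentum).add_(g)
# f2: Nesterov look-ahead, add momentum * refreshed trace back onto the update
update = g.add(new_trace, alpha=momentum)
print(new_trace, update)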
