fix: fix optree compatibility for multi-tree-map with None values (#195)
XuehaiPan authored Nov 9, 2023
1 parent 86b167c commit 93cc7ec
Showing 24 changed files with 126 additions and 124 deletions.
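The change this commit makes everywhere below: torchopt maps one function over several tensor trees at once (parameters, gradients, moments), and the gradient/update tree may contain `None` entries for parameters that received no gradient. In optree's multi-tree `tree_map`, the first tree defines the reference structure, so a `None` there is not handled the way a `None` in a later tree is. The commit therefore reorders every call site so that a tree guaranteed to be `None`-free (parameters or moments) comes first. A minimal standalone sketch, not from the repository, assuming optree's default `none_is_leaf=False` (torchopt routes these calls through its own `torchopt.pytree` wrappers):

# Illustrative sketch only -- not part of this commit.
import optree

params = {'w': 1.0, 'b': 2.0}    # parameter tree: never contains None
updates = {'w': 0.1, 'b': None}  # gradient tree: 'b' received no gradient

# With optree's default none_is_leaf=False, a None in the *first* tree is
# an empty internal node, so this raises a structure mismatch:
#     optree.tree_map(lambda g, p: g + p, updates, params)

# With the None-free tree first, each position where the reference tree
# holds a leaf takes the other tree's value verbatim -- including None --
# and the mapped function can skip it:
new_updates = optree.tree_map(
    lambda p, g: p + g if g is not None else None,
    params,
    updates,
)
print(new_updates)  # {'w': 1.1, 'b': None}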
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
@@ -29,10 +29,10 @@ jobs:
          submodules: "recursive"
          fetch-depth: 1

-      - name: Set up Python 3.8
+      - name: Set up Python 3.9
        uses: actions/setup-python@v4
        with:
-          python-version: "3.8"
+          python-version: "3.9"
          update-environment: true

      - name: Setup CUDA Toolkit
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -9,7 +9,7 @@ ci:
default_stages: [commit, push, manual]
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
    hooks:
      - id: check-symlinks
      - id: destroyed-symlinks
@@ -26,11 +26,11 @@ repos:
      - id: debug-statements
      - id: double-quote-string-fixer
  - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v16.0.6
+    rev: v17.0.4
    hooks:
      - id: clang-format
  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.287
+    rev: v0.1.5
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
@@ -39,11 +39,11 @@ repos:
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 23.11.0
    hooks:
      - id: black-jupyter
  - repo: https://github.com/asottile/pyupgrade
-    rev: v3.10.1
+    rev: v3.15.0
    hooks:
      - id: pyupgrade
        args: [--py38-plus] # sync with requires-python
@@ -68,7 +68,7 @@ repos:
          ^docs/source/conf.py$
        )
  - repo: https://github.com/codespell-project/codespell
-    rev: v2.2.5
+    rev: v2.2.6
    hooks:
      - id: codespell
        additional_dependencies: [".[toml]"]
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -17,11 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

--
+- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).

### Fixed

--
+- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).

### Removed

4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -17,13 +17,13 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent
project(torchopt LANGUAGES CXX)

include(FetchContent)
-set(PYBIND11_VERSION v2.10.3)
+set(PYBIND11_VERSION v2.11.1)

if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release)
endif()

-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Threads REQUIRED) # -pthread
1 change: 0 additions & 1 deletion conda-recipe.yaml
@@ -77,7 +77,6 @@ dependencies:
  - hunspell-en
  - myst-nb
  - ipykernel
-  - pandoc
  - docutils

  # Testing
1 change: 0 additions & 1 deletion docs/conda-recipe.yaml
@@ -67,5 +67,4 @@ dependencies:
  - hunspell-en
  - myst-nb
  - ipykernel
-  - pandoc
  - docutils
10 changes: 5 additions & 5 deletions docs/requirements.txt
@@ -4,17 +4,17 @@ torch >= 1.13

--requirement ../requirements.txt

-sphinx >= 5.2.1
+sphinx >= 5.2.1, < 7.0.0a0
+sphinxcontrib-bibtex >= 2.4
+sphinx-autodoc-typehints >= 1.20
+myst-nb >= 0.15

sphinx-autoapi
sphinx-autobuild
sphinx-copybutton
sphinx-rtd-theme
sphinxcontrib-katex
-sphinxcontrib-bibtex
-sphinx-autodoc-typehints >= 1.19.2
IPython
ipykernel
-pandoc
-myst-nb
docutils
matplotlib
1 change: 1 addition & 0 deletions torchopt/alias/sgd.py
@@ -44,6 +44,7 @@
__all__ = ['sgd']


+# pylint: disable-next=too-many-arguments
def sgd(
    lr: ScalarOrSchedule,
    momentum: float = 0.0,
26 changes: 15 additions & 11 deletions torchopt/alias/utils.py
@@ -108,19 +108,21 @@ def update_fn(

        if inplace:

-           def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+           def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               if g is None:
+                   return g
                if g.requires_grad:
                    return g.add_(p, alpha=weight_decay)
                return g.add_(p.data, alpha=weight_decay)

-           updates = tree_map_(f, updates, params)
+           tree_map_(f, params, updates)

        else:

-           def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-               return g.add(p, alpha=weight_decay)
+           def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               return g.add(p, alpha=weight_decay) if g is not None else g

-           updates = tree_map(f, updates, params)
+           updates = tree_map(f, params, updates)

        return updates, state

@@ -139,7 +141,7 @@ def update_fn(
            def f(g: torch.Tensor) -> torch.Tensor:
                return g.neg_()

-           updates = tree_map_(f, updates)
+           tree_map_(f, updates)

        else:

@@ -166,19 +168,21 @@ def update_fn(

        if inplace:

-           def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+           def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               if g is None:
+                   return g
                if g.requires_grad:
                    return g.neg_().add_(p, alpha=weight_decay)
                return g.neg_().add_(p.data, alpha=weight_decay)

-           updates = tree_map_(f, updates, params)
+           tree_map_(f, params, updates)

        else:

-           def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-               return g.neg().add_(p, alpha=weight_decay)
+           def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               return g.neg().add_(p, alpha=weight_decay) if g is not None else g

-           updates = tree_map(f, updates, params)
+           updates = tree_map(f, params, updates)

        return updates, state
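The pattern in this file recurs throughout the commit: `f` now takes the parameter first and a possibly-`None` gradient second, and the `None`-free `params` tree leads the map so optree never sees `None` in the reference structure. The in-place branches also stop rebinding `updates`: `tree_map_` returns its first tree, which is now `params`, while the gradient tensors inside `updates` are already mutated in place. A self-contained sketch of the out-of-place branch with made-up values:

from __future__ import annotations

import torch
from optree import tree_map  # torchopt wraps this in torchopt.pytree

weight_decay = 1e-2
params = [torch.ones(3), torch.ones(2)]
updates = [torch.full((3,), 0.5), None]  # second parameter has no gradient

def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
    # Skip missing gradients; otherwise fold in the weight-decay term.
    return g.add(p, alpha=weight_decay) if g is not None else g

new_updates = tree_map(f, params, updates)  # None-free tree leads
print(new_updates)  # [tensor([0.5100, 0.5100, 0.5100]), None]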
2 changes: 2 additions & 0 deletions torchopt/distributed/api.py
@@ -271,6 +271,7 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
    return torch.sum(torch.stack(tuple(results), dim=0), dim=0)


+# pylint: disable-next=too-many-arguments
def remote_async_call(
    func: Callable[..., T],
    *,
@@ -328,6 +329,7 @@ def remote_async_call(
    return future


+# pylint: disable-next=too-many-arguments
def remote_sync_call(
    func: Callable[..., T],
    *,
4 changes: 3 additions & 1 deletion torchopt/linalg/cg.py
@@ -53,7 +53,7 @@ def _identity(x: TensorTree) -> TensorTree:
    return x


-# pylint: disable-next=too-many-locals
+# pylint: disable-next=too-many-arguments,too-many-locals
def _cg_solve(
    A: Callable[[TensorTree], TensorTree],
    b: TensorTree,
@@ -102,6 +102,7 @@ def body_fn(
    return x_final


+# pylint: disable-next=too-many-arguments
def _isolve(
    _isolve_solve: Callable,
    A: TensorTree | Callable[[TensorTree], TensorTree],
@@ -134,6 +135,7 @@ def _isolve(
    return isolve_solve(A, b)


+# pylint: disable-next=too-many-arguments
def cg(
    A: TensorTree | Callable[[TensorTree], TensorTree],
    b: TensorTree,
2 changes: 2 additions & 0 deletions torchopt/linalg/ns.py
@@ -123,12 +123,14 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch.Tensor:
        # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...]
        M = I - alpha * A
        for rank in range(maxiter):
+           # pylint: disable-next=not-callable
            inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank)
        inv_A_hat = alpha * inv_A_hat
    else:
        # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ...
        M = I - A
        for rank in range(maxiter):
+           # pylint: disable-next=not-callable
            inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank)
    return inv_A_hat
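For context, `_ns_inv` truncates the Neumann series (I - M)^{-1} = I + M + M^2 + ... with M = I - A (or M = I - alpha*A, rescaled by alpha). A rough standalone equivalent of the alpha-free branch, as a hypothetical helper that accumulates the matrix power incrementally instead of calling `torch.linalg.matrix_power` on every iteration:

import torch

def ns_inv_sketch(A: torch.Tensor, maxiter: int = 10) -> torch.Tensor:
    # Truncated Neumann series: A^{-1} ~= sum_{k < maxiter} (I - A)^k,
    # valid when the spectral radius of (I - A) is below 1.
    I = torch.eye(A.size(0), dtype=A.dtype, device=A.device)
    M = I - A
    inv_A_hat = torch.zeros_like(A)
    term = I  # running power M^k, starting at M^0
    for _ in range(maxiter):
        inv_A_hat = inv_A_hat + term
        term = term @ M
    return inv_A_hat

A = torch.tensor([[1.0, 0.1], [0.1, 1.0]])
print(ns_inv_sketch(A, maxiter=20) @ A)  # close to the 2x2 identity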
4 changes: 2 additions & 2 deletions torchopt/nn/stateless.py
@@ -66,8 +66,8 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor:
            mod._parameters[attr] = value  # type: ignore[assignment]
        elif hasattr(mod, '_buffers') and attr in mod._buffers:
            mod._buffers[attr] = value
-       elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters:  # type: ignore[operator]
-           mod._meta_parameters[attr] = value  # type: ignore[operator,index]
+       elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters:
+           mod._meta_parameters[attr] = value
        else:
            setattr(mod, attr, value)
        # pylint: enable=protected-access
12 changes: 7 additions & 5 deletions torchopt/transform/add_decayed_weights.py
@@ -226,19 +226,21 @@ def update_fn(

        if inplace:

-           def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+           def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               if g is None:
+                   return g
                if g.requires_grad:
                    return g.add_(p, alpha=weight_decay)
                return g.add_(p.data, alpha=weight_decay)

-           updates = tree_map_(f, updates, params)
+           tree_map_(f, params, updates)

        else:

-           def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-               return g.add(p, alpha=weight_decay)
+           def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               return g.add(p, alpha=weight_decay) if g is not None else g

-           updates = tree_map(f, updates, params)
+           updates = tree_map(f, params, updates)

        return updates, state
18 changes: 5 additions & 13 deletions torchopt/transform/scale_by_adadelta.py
@@ -129,23 +129,15 @@ def update_fn(

        if inplace:

-           def f(
-               g: torch.Tensor,  # pylint: disable=unused-argument
-               m: torch.Tensor,
-               v: torch.Tensor,
-           ) -> torch.Tensor:
-               return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_())
+           def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g

        else:

-           def f(
-               g: torch.Tensor,  # pylint: disable=unused-argument
-               m: torch.Tensor,
-               v: torch.Tensor,
-           ) -> torch.Tensor:
-               return g.mul(v.add(eps).div_(m.add(eps)).sqrt_())
+           def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g

-       updates = tree_map(f, updates, mu, state.nu)
+       updates = tree_map(f, mu, state.nu, updates)

        nu = update_moment.impl(  # type: ignore[attr-defined]
            updates,
20 changes: 7 additions & 13 deletions torchopt/transform/scale_by_adam.py
@@ -132,6 +132,7 @@ def _scale_by_adam_flat(
    )


+# pylint: disable-next=too-many-arguments
def _scale_by_adam(
    b1: float = 0.9,
    b2: float = 0.999,
@@ -200,23 +201,15 @@ def update_fn(

        if inplace:

-           def f(
-               g: torch.Tensor,  # pylint: disable=unused-argument
-               m: torch.Tensor,
-               v: torch.Tensor,
-           ) -> torch.Tensor:
-               return m.div_(v.add_(eps_root).sqrt_().add(eps))
+           def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               return m.div_(v.add_(eps_root).sqrt_().add(eps)) if g is not None else g

        else:

-           def f(
-               g: torch.Tensor,  # pylint: disable=unused-argument
-               m: torch.Tensor,
-               v: torch.Tensor,
-           ) -> torch.Tensor:
-               return m.div(v.add(eps_root).sqrt_().add(eps))
+           def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+               return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g

-       updates = tree_map(f, updates, mu_hat, nu_hat)
+       updates = tree_map(f, mu_hat, nu_hat, updates)
        return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc)

    return GradientTransformation(init_fn, update_fn)
@@ -283,6 +276,7 @@ def _scale_by_accelerated_adam_flat(
    )


+# pylint: disable-next=too-many-arguments
def _scale_by_accelerated_adam(
    b1: float = 0.9,
    b2: float = 0.999,
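Three-tree maps follow the same rule: the moment trees `mu_hat` and `nu_hat` are always dense, so either may lead, and the possibly-sparse `updates` tree rides last purely as a `None` mask. A toy sketch with invented values:

from __future__ import annotations

import torch
from optree import tree_map

eps, eps_root = 1e-8, 0.0
mu_hat = [torch.tensor([0.5]), torch.tensor([0.2])]
nu_hat = [torch.tensor([0.25]), torch.tensor([0.04])]
updates = [torch.tensor([1.0]), None]  # second parameter: no gradient

def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
    # g only gates the computation; the value comes from the moments.
    return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g

print(tree_map(f, mu_hat, nu_hat, updates))
# [tensor([1.0000]), None]   (0.5 / sqrt(0.25) = 1.0)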
18 changes: 6 additions & 12 deletions torchopt/transform/scale_by_adamax.py
@@ -137,23 +137,17 @@ def update_fn(
            already_flattened=already_flattened,
        )

-       def update_nu(
-           g: torch.Tensor,
-           n: torch.Tensor,
-       ) -> torch.Tensor:
-           return torch.max(n.mul(b2), g.abs().add_(eps))
+       def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
+           return torch.max(n.mul(b2), g.abs().add_(eps)) if g is not None else g

-       nu = tree_map(update_nu, updates, state.nu)
+       nu = tree_map(update_nu, state.nu, updates)

        one_minus_b1_pow_t = 1 - b1**state.t

-       def f(
-           n: torch.Tensor,
-           m: torch.Tensor,
-       ) -> torch.Tensor:
-           return m.div(n).div_(one_minus_b1_pow_t)
+       def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor:
+           return m.div(n).div_(one_minus_b1_pow_t) if n is not None else m

-       updates = tree_map(f, nu, mu)
+       updates = tree_map(f, mu, nu)

        return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1)