diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 80ed2876..89c26c3c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -29,10 +29,10 @@ jobs: submodules: "recursive" fetch-depth: 1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v4 with: - python-version: "3.8" + python-version: "3.9" update-environment: true - name: Setup CUDA Toolkit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4d2ef96d..975bb7b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ ci: default_stages: [commit, push, manual] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-symlinks - id: destroyed-symlinks @@ -26,11 +26,11 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.6 + rev: v17.0.4 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.287 + rev: v0.1.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] @@ -39,11 +39,11 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.11.0 hooks: - id: black-jupyter - repo: https://github.com/asottile/pyupgrade - rev: v3.10.1 + rev: v3.15.0 hooks: - id: pyupgrade args: [--py38-plus] # sync with requires-python @@ -68,7 +68,7 @@ repos: ^docs/source/conf.py$ ) - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.2.6 hooks: - id: codespell additional_dependencies: [".[toml]"] diff --git a/CHANGELOG.md b/CHANGELOG.md index 24e4eea4..797b6ca0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,11 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). ### Fixed -- +- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195). ### Removed diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b091a22..eca14815 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,13 +17,13 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent project(torchopt LANGUAGES CXX) include(FetchContent) -set(PYBIND11_VERSION v2.10.3) +set(PYBIND11_VERSION v2.11.1) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Threads REQUIRED) # -pthread diff --git a/conda-recipe.yaml b/conda-recipe.yaml index 997f11c5..12b3a6d0 100644 --- a/conda-recipe.yaml +++ b/conda-recipe.yaml @@ -77,7 +77,6 @@ dependencies: - hunspell-en - myst-nb - ipykernel - - pandoc - docutils # Testing diff --git a/docs/conda-recipe.yaml b/docs/conda-recipe.yaml index 9a14af3f..30ec372e 100644 --- a/docs/conda-recipe.yaml +++ b/docs/conda-recipe.yaml @@ -67,5 +67,4 @@ dependencies: - hunspell-en - myst-nb - ipykernel - - pandoc - docutils diff --git a/docs/requirements.txt b/docs/requirements.txt index 655c64ff..82bf2d91 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,17 +4,17 @@ torch >= 1.13 --requirement ../requirements.txt -sphinx >= 5.2.1 +sphinx >= 5.2.1, < 7.0.0a0 +sphinxcontrib-bibtex >= 2.4 +sphinx-autodoc-typehints >= 1.20 +myst-nb >= 0.15 + sphinx-autoapi sphinx-autobuild sphinx-copybutton sphinx-rtd-theme sphinxcontrib-katex -sphinxcontrib-bibtex -sphinx-autodoc-typehints >= 1.19.2 IPython ipykernel -pandoc -myst-nb docutils matplotlib diff --git a/torchopt/alias/sgd.py b/torchopt/alias/sgd.py index 6fb3c6db..4c5b8317 100644 --- a/torchopt/alias/sgd.py +++ b/torchopt/alias/sgd.py @@ -44,6 +44,7 @@ __all__ = ['sgd'] +# pylint: disable-next=too-many-arguments def sgd( lr: ScalarOrSchedule, momentum: float = 0.0, diff --git a/torchopt/alias/utils.py b/torchopt/alias/utils.py index 5c8dc97a..2c2ec0e4 100644 --- a/torchopt/alias/utils.py +++ b/torchopt/alias/utils.py @@ -108,19 +108,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.add(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state @@ -139,7 +141,7 @@ def update_fn( def f(g: torch.Tensor) -> torch.Tensor: return g.neg_() - updates = tree_map_(f, updates) + tree_map_(f, updates) else: @@ -166,19 +168,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.neg_().add_(p, alpha=weight_decay) return g.neg_().add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.neg().add_(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.neg().add_(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state diff --git a/torchopt/distributed/api.py b/torchopt/distributed/api.py index 3a6f0526..fb7461e4 100644 --- a/torchopt/distributed/api.py +++ b/torchopt/distributed/api.py @@ -271,6 +271,7 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor: return torch.sum(torch.stack(tuple(results), dim=0), dim=0) +# pylint: disable-next=too-many-arguments def remote_async_call( func: Callable[..., T], *, @@ -328,6 +329,7 @@ def remote_async_call( return future +# pylint: disable-next=too-many-arguments def remote_sync_call( func: Callable[..., T], *, diff --git a/torchopt/linalg/cg.py b/torchopt/linalg/cg.py index 9cd57cd8..42cb6bea 100644 --- a/torchopt/linalg/cg.py +++ b/torchopt/linalg/cg.py @@ -53,7 +53,7 @@ def _identity(x: TensorTree) -> TensorTree: return x -# pylint: disable-next=too-many-locals +# pylint: disable-next=too-many-arguments,too-many-locals def _cg_solve( A: Callable[[TensorTree], TensorTree], b: TensorTree, @@ -102,6 +102,7 @@ def body_fn( return x_final +# pylint: disable-next=too-many-arguments def _isolve( _isolve_solve: Callable, A: TensorTree | Callable[[TensorTree], TensorTree], @@ -134,6 +135,7 @@ def _isolve( return isolve_solve(A, b) +# pylint: disable-next=too-many-arguments def cg( A: TensorTree | Callable[[TensorTree], TensorTree], b: TensorTree, diff --git a/torchopt/linalg/ns.py b/torchopt/linalg/ns.py index 747ad3cf..ce49fe77 100644 --- a/torchopt/linalg/ns.py +++ b/torchopt/linalg/ns.py @@ -123,12 +123,14 @@ def _ns_inv(A: torch.Tensor, maxiter: int, alpha: float | None = None) -> torch. # A^{-1} = a [I - (I - a A)]^{-1} = a [I + (I - a A) + (I - a A)^2 + (I - a A)^3 + ...] M = I - alpha * A for rank in range(maxiter): + # pylint: disable-next=not-callable inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) inv_A_hat = alpha * inv_A_hat else: # A^{-1} = [I - (I - A)]^{-1} = I + (I - A) + (I - A)^2 + (I - A)^3 + ... M = I - A for rank in range(maxiter): + # pylint: disable-next=not-callable inv_A_hat = inv_A_hat + torch.linalg.matrix_power(M, rank) return inv_A_hat diff --git a/torchopt/nn/stateless.py b/torchopt/nn/stateless.py index e547b5cb..8268ca3f 100644 --- a/torchopt/nn/stateless.py +++ b/torchopt/nn/stateless.py @@ -66,8 +66,8 @@ def recursive_setattr(path: str, value: torch.Tensor) -> torch.Tensor: mod._parameters[attr] = value # type: ignore[assignment] elif hasattr(mod, '_buffers') and attr in mod._buffers: mod._buffers[attr] = value - elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: # type: ignore[operator] - mod._meta_parameters[attr] = value # type: ignore[operator,index] + elif hasattr(mod, '_meta_parameters') and attr in mod._meta_parameters: + mod._meta_parameters[attr] = value else: setattr(mod, attr, value) # pylint: enable=protected-access diff --git a/torchopt/transform/add_decayed_weights.py b/torchopt/transform/add_decayed_weights.py index 04d564d7..39948694 100644 --- a/torchopt/transform/add_decayed_weights.py +++ b/torchopt/transform/add_decayed_weights.py @@ -226,19 +226,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if g.requires_grad: return g.add_(p, alpha=weight_decay) return g.add_(p.data, alpha=weight_decay) - updates = tree_map_(f, updates, params) + tree_map_(f, params, updates) else: - def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - return g.add(p, alpha=weight_decay) + def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(p, alpha=weight_decay) if g is not None else g - updates = tree_map(f, updates, params) + updates = tree_map(f, params, updates) return updates, state diff --git a/torchopt/transform/scale_by_adadelta.py b/torchopt/transform/scale_by_adadelta.py index fb5431a3..bbe40080 100644 --- a/torchopt/transform/scale_by_adadelta.py +++ b/torchopt/transform/scale_by_adadelta.py @@ -129,23 +129,15 @@ def update_fn( if inplace: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul_(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g else: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.mul(v.add(eps).div_(m.add(eps)).sqrt_()) if g is not None else g - updates = tree_map(f, updates, mu, state.nu) + updates = tree_map(f, mu, state.nu, updates) nu = update_moment.impl( # type: ignore[attr-defined] updates, diff --git a/torchopt/transform/scale_by_adam.py b/torchopt/transform/scale_by_adam.py index c3c6254e..cc0ea3b6 100644 --- a/torchopt/transform/scale_by_adam.py +++ b/torchopt/transform/scale_by_adam.py @@ -132,6 +132,7 @@ def _scale_by_adam_flat( ) +# pylint: disable-next=too-many-arguments def _scale_by_adam( b1: float = 0.9, b2: float = 0.999, @@ -200,23 +201,15 @@ def update_fn( if inplace: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return m.div_(v.add_(eps_root).sqrt_().add(eps)) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div_(v.add_(eps_root).sqrt_().add(eps)) if g is not None else g else: - def f( - g: torch.Tensor, # pylint: disable=unused-argument - m: torch.Tensor, - v: torch.Tensor, - ) -> torch.Tensor: - return m.div(v.add(eps_root).sqrt_().add(eps)) + def f(m: torch.Tensor, v: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return m.div(v.add(eps_root).sqrt_().add(eps)) if g is not None else g - updates = tree_map(f, updates, mu_hat, nu_hat) + updates = tree_map(f, mu_hat, nu_hat, updates) return updates, ScaleByAdamState(mu=mu, nu=nu, count=count_inc) return GradientTransformation(init_fn, update_fn) @@ -283,6 +276,7 @@ def _scale_by_accelerated_adam_flat( ) +# pylint: disable-next=too-many-arguments def _scale_by_accelerated_adam( b1: float = 0.9, b2: float = 0.999, diff --git a/torchopt/transform/scale_by_adamax.py b/torchopt/transform/scale_by_adamax.py index 504e82cd..0a1c3ec9 100644 --- a/torchopt/transform/scale_by_adamax.py +++ b/torchopt/transform/scale_by_adamax.py @@ -137,23 +137,17 @@ def update_fn( already_flattened=already_flattened, ) - def update_nu( - g: torch.Tensor, - n: torch.Tensor, - ) -> torch.Tensor: - return torch.max(n.mul(b2), g.abs().add_(eps)) + def update_nu(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return torch.max(n.mul(b2), g.abs().add_(eps)) if g is not None else g - nu = tree_map(update_nu, updates, state.nu) + nu = tree_map(update_nu, state.nu, updates) one_minus_b1_pow_t = 1 - b1**state.t - def f( - n: torch.Tensor, - m: torch.Tensor, - ) -> torch.Tensor: - return m.div(n).div_(one_minus_b1_pow_t) + def f(m: torch.Tensor, n: torch.Tensor | None) -> torch.Tensor: + return m.div(n).div_(one_minus_b1_pow_t) if n is not None else m - updates = tree_map(f, nu, mu) + updates = tree_map(f, mu, nu) return updates, ScaleByAdamaxState(mu=mu, nu=nu, t=state.t + 1) diff --git a/torchopt/transform/scale_by_rms.py b/torchopt/transform/scale_by_rms.py index ac2fef16..084be839 100644 --- a/torchopt/transform/scale_by_rms.py +++ b/torchopt/transform/scale_by_rms.py @@ -136,17 +136,17 @@ def update_fn( if inplace: - def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return g.div_(n.sqrt().add_(eps)) + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.sqrt().add_(eps)) if g is not None else g - updates = tree_map_(f, updates, nu) + tree_map_(f, nu, updates) else: - def f(g: torch.Tensor, n: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name - return g.div(n.sqrt().add(eps)) + def f(n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.sqrt().add(eps)) if g is not None else g - updates = tree_map(f, updates, nu) + updates = tree_map(f, nu, updates) return updates, ScaleByRmsState(nu=nu) diff --git a/torchopt/transform/scale_by_rss.py b/torchopt/transform/scale_by_rss.py index 68021e5e..b1f3d2a8 100644 --- a/torchopt/transform/scale_by_rss.py +++ b/torchopt/transform/scale_by_rss.py @@ -128,23 +128,21 @@ def update_fn( if inplace: - def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: - return torch.where( - sos > 0.0, - g.div_(sos.sqrt().add_(eps)), - 0.0, + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div_(sos.sqrt().add_(eps)), 0.0) + if g is not None + else g ) else: - def f(g: torch.Tensor, sos: torch.Tensor) -> torch.Tensor: - return torch.where( - sos > 0.0, - g.div(sos.sqrt().add(eps)), - 0.0, + def f(sos: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return ( + torch.where(sos > 0.0, g.div(sos.sqrt().add(eps)), 0.0) if g is not None else g ) - updates = tree_map(f, updates, sum_of_squares) + updates = tree_map(f, sum_of_squares, updates) return updates, ScaleByRssState(sum_of_squares=sum_of_squares) return GradientTransformation(init_fn, update_fn) diff --git a/torchopt/transform/scale_by_schedule.py b/torchopt/transform/scale_by_schedule.py index f27fb7e8..749b1853 100644 --- a/torchopt/transform/scale_by_schedule.py +++ b/torchopt/transform/scale_by_schedule.py @@ -96,20 +96,24 @@ def update_fn( inplace: bool = True, ) -> tuple[Updates, OptState]: if inplace: - - def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g step_size = step_size_fn(c) return g.mul_(step_size) - updates = tree_map_(f, updates, state.count) + tree_map_(f, state.count, updates) else: - - def f(g: torch.Tensor, c: Numeric) -> torch.Tensor: # pylint: disable=invalid-name + # pylint: disable-next=invalid-name + def f(c: Numeric, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g step_size = step_size_fn(c) return g.mul(step_size) - updates = tree_map(f, updates, state.count) + updates = tree_map(f, state.count, updates) return ( updates, diff --git a/torchopt/transform/scale_by_stddev.py b/torchopt/transform/scale_by_stddev.py index bbbfb384..d9589c45 100644 --- a/torchopt/transform/scale_by_stddev.py +++ b/torchopt/transform/scale_by_stddev.py @@ -148,17 +148,17 @@ def update_fn( if inplace: - def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: - return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div_(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g - updates = tree_map_(f, updates, mu, nu) + tree_map_(f, mu, nu, updates) else: - def f(g: torch.Tensor, m: torch.Tensor, n: torch.Tensor) -> torch.Tensor: - return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) + def f(m: torch.Tensor, n: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.div(n.addcmul(m, m, value=-1.0).sqrt_().add(eps)) if g is not None else g - updates = tree_map(f, updates, mu, nu) + updates = tree_map(f, mu, nu, updates) return updates, ScaleByRStdDevState(mu=mu, nu=nu) diff --git a/torchopt/transform/trace.py b/torchopt/transform/trace.py index 7a1e1971..d530a676 100644 --- a/torchopt/transform/trace.py +++ b/torchopt/transform/trace.py @@ -148,52 +148,60 @@ def update_fn( if nesterov: if inplace: - def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add_(g) return t.mul_(momentum).add_(g) - def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.add_(t, alpha=momentum) + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add_(t, alpha=momentum) if g is not None else g - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map_(f2, updates, new_trace) + new_trace = tree_map(f1, state.trace, updates) + tree_map_(f2, new_trace, updates) else: - def f1(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f1(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add(g) return t.mul(momentum).add_(g) - def f2(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.add(t, alpha=momentum) + def f2(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.add(t, alpha=momentum) if g is not None else g - new_trace = tree_map(f1, updates, state.trace) - updates = tree_map(f2, updates, new_trace) + new_trace = tree_map(f1, state.trace, updates) + updates = tree_map(f2, new_trace, updates) else: if inplace: - def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add_(g) return t.mul_(momentum).add_(g, alpha=1.0 - dampening) - def copy_(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return g.copy_(t) + def copy_to_(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + return g.copy_(t) if g is not None else g - new_trace = tree_map(f, updates, state.trace) - updates = tree_map_(copy_, updates, new_trace) + new_trace = tree_map(f, state.trace, updates) + tree_map_(copy_to_, new_trace, updates) else: - def f(g: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + def f(t: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None: + if g is None: + return g if first_call: return t.add(g) return t.mul(momentum).add_(g, alpha=1.0 - dampening) - new_trace = tree_map(f, updates, state.trace) + new_trace = tree_map(f, state.trace, updates) updates = tree_map(torch.clone, new_trace) first_call = False diff --git a/torchopt/transform/utils.py b/torchopt/transform/utils.py index 8c67fd7e..f1ed39da 100644 --- a/torchopt/transform/utils.py +++ b/torchopt/transform/utils.py @@ -160,6 +160,7 @@ def _update_moment_flat( ) +# pylint: disable-next=too-many-arguments def _update_moment( updates: Updates, moments: TensorTree, diff --git a/torchopt/utils.py b/torchopt/utils.py index 5414db80..ef771966 100644 --- a/torchopt/utils.py +++ b/torchopt/utils.py @@ -91,7 +91,7 @@ def fn_(obj: Any) -> None: @overload -def extract_state_dict( +def extract_state_dict( # pylint: disable=too-many-arguments target: nn.Module, *, by: CopyMode = 'reference', @@ -114,7 +114,7 @@ def extract_state_dict( ... -# pylint: disable-next=too-many-branches,too-many-locals +# pylint: disable-next=too-many-arguments,too-many-branches,too-many-locals def extract_state_dict( target: nn.Module | MetaOptimizer, *,