chore(pre-commit): [pre-commit.ci] autoupdate #228

Open · wants to merge 2 commits into base: main
18 changes: 9 additions & 9 deletions .pre-commit-config.yaml
@@ -6,10 +6,10 @@ ci:
autofix_commit_msg: "fix: [pre-commit.ci] auto fixes [...]"
autoupdate_commit_msg: "chore(pre-commit): [pre-commit.ci] autoupdate"
autoupdate_schedule: monthly
-default_stages: [commit, push, manual]
+default_stages: [pre-commit, pre-push, manual]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
-rev: v4.6.0
+rev: v5.0.0
hooks:
- id: check-symlinks
- id: destroyed-symlinks
@@ -26,24 +26,24 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
-rev: v18.1.8
+rev: v20.1.0
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.5.0
+rev: v0.11.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/PyCQA/isort
-rev: 5.13.2
+rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/psf/black
-rev: 24.4.2
+rev: 25.1.0
hooks:
- id: black-jupyter
- repo: https://github.com/asottile/pyupgrade
-rev: v3.16.0
+rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py38-plus] # sync with requires-python
@@ -52,7 +52,7 @@ repos:
^examples/
)
- repo: https://github.com/pycqa/flake8
-rev: 7.1.0
+rev: 7.2.0
hooks:
- id: flake8
additional_dependencies:
@@ -68,7 +68,7 @@ repos:
^docs/source/conf.py$
)
- repo: https://github.com/codespell-project/codespell
-rev: v2.3.0
+rev: v2.4.1
hooks:
- id: codespell
additional_dependencies: [".[toml]"]
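Note on the config changes above: the `default_stages` values follow the renamed hook stages (`commit`/`push` became `pre-commit`/`pre-push` in newer pre-commit releases), and each hook is bumped to its latest tag. The `pyupgrade` hook keeps `--py38-plus`, in sync with the project's `requires-python`. As a rough, hypothetical sketch of what that flag does (not part of this diff, and the exact rewrites depend on the pinned pyupgrade release):

```python
# Hypothetical before/after of a `pyupgrade --py38-plus` rewrite.

# Before:
#   class Config(object): ...
#   name = '{}-{}'.format(prefix, suffix)

# After pyupgrade:
class Config:  # redundant `object` base class dropped
    pass


def join(prefix: str, suffix: str) -> str:
    return f'{prefix}-{suffix}'  # str.format() rewritten to an f-string
```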
4 changes: 2 additions & 2 deletions torchopt/distributed/api.py
@@ -318,12 +318,12 @@ def remote_async_call(
futures.append(fut)

future = cast(
-Future[List[T]],
+'Future[List[T]]',
torch.futures.collect_all(futures).then(lambda fut: [f.wait() for f in fut.wait()]),
)
if reducer is not None:
return cast(
-Future[U],
+'Future[U]',
future.then(lambda fut: reducer(fut.wait())),
)
return future
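Quoting the cast target turns it into a forward reference: `typing.cast` is a runtime no-op that returns its second argument unchanged, so the string is never evaluated and `Future[List[T]]` need not be importable or subscriptable at runtime. A minimal sketch of the pattern (the `collect` helper and its values are hypothetical, not torchopt APIs):

```python
# Minimal sketch of casting to a string (forward-reference) type.
from typing import TYPE_CHECKING, List, TypeVar, cast

import torch

if TYPE_CHECKING:
    # Only the type checker sees this import; it never runs.
    from torch.futures import Future

T = TypeVar('T')


def collect(values: List[T]) -> 'Future[List[T]]':
    fut = cast('Future[List[T]]', torch.futures.Future())  # no-op at runtime
    fut.set_result(values)
    return fut


print(collect([1, 2, 3]).wait())  # [1, 2, 3]
```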
2 changes: 1 addition & 1 deletion torchopt/nn/stateless.py
@@ -84,7 +84,7 @@ def reparametrize(
module: nn.Module,
named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]],
allow_missing: bool = False,
-) -> Generator[nn.Module, None, None]:
+) -> Generator[nn.Module]:
"""Reparameterize the module parameters and/or buffers."""
if not isinstance(named_tensors, dict):
named_tensors = dict(named_tensors)
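The shorter `Generator[nn.Module]` annotation relies on the send and return type parameters of `Generator` defaulting to `None` (PEP 696 type-parameter defaults, available in Python 3.13 and recent `typing_extensions`), so type checkers that support the defaults read it as `Generator[nn.Module, None, None]`. A small sketch of a generator-based context manager annotated this way, with a hypothetical bias swap rather than the real `reparametrize` logic:

```python
# Sketch only: the single-parameter Generator form assumes a typing stack
# with PEP 696 defaults for Generator's send/return parameters.
from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def with_zeroed_bias(module: nn.Linear) -> Generator[nn.Linear]:
    saved = module.bias
    module.bias = nn.Parameter(torch.zeros_like(saved))  # hypothetical swap
    try:
        yield module
    finally:
        module.bias = saved  # restore the original parameter on exit


with with_zeroed_bias(nn.Linear(3, 3)) as layer:
    print(layer.bias)  # temporarily all-zero bias
```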
18 changes: 9 additions & 9 deletions torchopt/utils.py
@@ -79,13 +79,13 @@ def fn_(obj: Any) -> None:
obj.detach_().requires_grad_(requires_grad)

if isinstance(target, ModuleState):
-true_target = cast(TensorTree, (target.params, target.buffers))
+true_target = cast('TensorTree', (target.params, target.buffers))
elif isinstance(target, nn.Module):
-true_target = cast(TensorTree, tuple(target.parameters()))
+true_target = cast('TensorTree', tuple(target.parameters()))
elif isinstance(target, MetaOptimizer):
-true_target = cast(TensorTree, target.state_dict())
+true_target = cast('TensorTree', target.state_dict())
else:
-true_target = cast(TensorTree, target)  # tree of tensors
+true_target = cast('TensorTree', target)  # tree of tensors

pytree.tree_map_(fn_, true_target)

@@ -325,7 +325,7 @@ def recover_state_dict(
from torchopt.optim.meta.base import MetaOptimizer

if isinstance(target, nn.Module):
-params, buffers, *_ = state = cast(ModuleState, state)
+params, buffers, *_ = state = cast('ModuleState', state)
params_containers, buffers_containers = extract_module_containers(target, with_buffers=True)

if state.detach_buffers:
@@ -343,7 +343,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
):
tgt.update(src)
elif isinstance(target, MetaOptimizer):
-state = cast(Sequence[OptState], state)
+state = cast('Sequence[OptState]', state)
target.load_state_dict(state)
else:
raise TypeError(f'Unexpected class of {target}')
@@ -422,9 +422,9 @@ def module_clone(  # noqa: C901

if isinstance(target, (nn.Module, MetaOptimizer)):
if isinstance(target, nn.Module):
-containers = cast(TensorTree, extract_module_containers(target, with_buffers=True))
+containers = cast('TensorTree', extract_module_containers(target, with_buffers=True))
else:
-containers = cast(TensorTree, target.state_dict())
+containers = cast('TensorTree', target.state_dict())
tensors = pytree.tree_leaves(containers)
memo = {id(t): t for t in tensors}
cloned = copy.deepcopy(target, memo=memo)
@@ -476,7 +476,7 @@ def clone_detach_(t: torch.Tensor) -> torch.Tensor:
else:
replicate = clone_detach_

-return pytree.tree_map(replicate, cast(TensorTree, target))
+return pytree.tree_map(replicate, cast('TensorTree', target))


@overload
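Beyond the string casts (the same forward-reference pattern as in `api.py` above), the context around `module_clone` shows the memo trick used for cloning: `copy.deepcopy` is pre-seeded with `id(t) -> t` for every tensor leaf, so the deep copy reuses those exact tensor objects and only the surrounding container structure is duplicated. A standalone sketch of that trick, using a hypothetical dataclass instead of an `nn.Module` or optimizer state:

```python
# Sketch of pre-seeding deepcopy's memo so selected leaves are shared.
import copy
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class Holder:  # hypothetical container standing in for module containers
    tensors: List[torch.Tensor] = field(default_factory=list)


original = Holder([torch.zeros(2), torch.ones(2)])

# deepcopy consults memo by id(); entries found there are returned as-is,
# so the tensors below are treated as "already copied".
memo = {id(t): t for t in original.tensors}
clone = copy.deepcopy(original, memo=memo)

assert clone is not original                    # the container was copied
assert clone.tensors[0] is original.tensors[0]  # the tensors are shared
```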
2 changes: 1 addition & 1 deletion torchopt/visual.py
@@ -129,7 +129,7 @@ def make_dot(  # noqa: C901
elif isinstance(param, Generator):
param_map.update({v: k for k, v in param})
else:
-param_map.update({v: k for k, v in cast(Mapping, param).items()})
+param_map.update({v: k for k, v in cast('Mapping', param).items()})

node_attr = {
'style': 'filled',