Fixes #102428. Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2, 2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)                     # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_none)      # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK
optimizer.register_step_post_hook(hook_fn_other_optimizer)     # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
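For context, the kind of error shown in the last line above arises when the hook's first parameter is bound to the registering optimizer's own type. The block below is a minimal, self-contained sketch of one way to express that with `typing_extensions.Self`; it is an illustration of the idea only, not the actual `torch/optim/optimizer.pyi` stub, and the `RemovableHandle` class here is a stand-in for `torch.utils.hooks.RemovableHandle`.

```python
# Sketch only: a self-typed hook registration API, NOT the real torch.optim stub.
from typing import Any, Callable, Dict, Optional, Tuple

from typing_extensions import Self  # PEP 673 Self; available in typing on Python 3.11+


class RemovableHandle:
    """Stand-in for torch.utils.hooks.RemovableHandle."""
    def remove(self) -> None: ...


class Optimizer:
    # Pre hooks may return replacement (args, kwargs) or None.
    def register_step_pre_hook(
        self,
        hook: Callable[
            [Self, Tuple[Any, ...], Dict[str, Any]],
            Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]],
        ],
    ) -> RemovableHandle: ...

    # Post hooks return nothing.
    def register_step_post_hook(
        self, hook: Callable[[Self, Tuple[Any, ...], Dict[str, Any]], None]
    ) -> RemovableHandle: ...
```

With `Self` in the hook signature, a type checker binds it to the concrete subclass at the call site (`Adam` in the PR example), so a hook annotated for an unrelated subclass such as `Adagrad` is rejected, while hooks annotated for `Adam` or the base `Optimizer` are accepted because callable parameters are contravariant.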
13 lines
269 B
Python
from typing import Tuple

from .optimizer import Optimizer, params_t


class SparseAdam(Optimizer):
    def __init__(
        self,
        params: params_t,
        lr: float = ...,
        betas: Tuple[float, float] = ...,
        eps: float = ...,
    ) -> None: ...
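The stub elides default values with `...`, as is conventional in `.pyi` files. Below is a brief usage sketch matching this signature; the values passed for `betas` and `eps` are the documented `torch.optim.SparseAdam` defaults, and the embedding layer is just an arbitrary example of a module that produces the sparse gradients SparseAdam expects.

```python
import torch
from torch import nn
from torch.optim import SparseAdam

# SparseAdam is meant for parameters that receive sparse gradients,
# e.g. an embedding created with sparse=True.
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4, sparse=True)
optimizer = SparseAdam(embedding.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-8)

loss = embedding(torch.tensor([1, 2, 3])).sum()
loss.backward()         # embedding.weight.grad is a sparse tensor
optimizer.step()        # only rows that appear in the gradient are updated
optimizer.zero_grad()
```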