PEP585: Missed conversions (#145342)

Differential Revision: [D68785969](https://our.internmc.facebook.com/intern/diff/D68785969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145342
Approved by: https://github.com/bobrenjc93
Authored by Aaron Orenstein on 2025-01-25 11:54:14 -08:00; committed by PyTorch MergeBot
parent 8696e59ae2
commit 7178b827d7
11 changed files with 144 additions and 143 deletions

View File

@@ -20,7 +20,7 @@ import types
 import typing
 import warnings
 import weakref
-from typing import (  # noqa: F401 # (Dict, List, Tuple) imported by torch.jit.annotations
+from typing import (  # noqa: UP035, F401 # (Dict, List, Tuple) imported by torch.jit.annotations
     Any,
     Callable,
     Dict,
@@ -1125,7 +1125,8 @@ def _get_overloaded_methods(method, mod_class):
 def is_tuple(ann) -> bool:
-    if ann is Tuple:
+    # Check for typing.Tuple missing args (but `tuple` is fine)
+    if ann is typing.Tuple:  # noqa: UP006
         raise_error_container_parameter_missing("Tuple")
     # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
@@ -1133,35 +1134,31 @@ def is_tuple(ann) -> bool:
         return False
     ann_origin = get_origin(ann)
-    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
-        return True
-    return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
+    return ann.__module__ in ("builtins", "typing") and ann_origin is tuple

 def is_list(ann) -> bool:
-    if ann is List:
+    # Check for typing.List missing args (but `list` is fine)
+    if ann is typing.List:  # noqa: UP006
         raise_error_container_parameter_missing("List")
     if not hasattr(ann, "__module__"):
         return False
     ann_origin = get_origin(ann)
-    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
-        return True
-    return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
+    return ann.__module__ in ("builtins", "typing") and ann_origin is list

 def is_dict(ann) -> bool:
-    if ann is Dict:
+    # Check for typing.Dict missing args (but `dict` is fine)
+    if ann is typing.Dict:  # noqa: UP006
         raise_error_container_parameter_missing("Dict")
     if not hasattr(ann, "__module__"):
         return False
     ann_origin = get_origin(ann)
-    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
-        return True
-    return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
+    return ann.__module__ in ("builtins", "typing") and ann_origin is dict

 def is_union(ann):
@@ -1371,11 +1368,11 @@ def raise_error_container_parameter_missing(target_type) -> None:
 def check_args_exist(target_type) -> None:
-    if target_type is List or target_type is list:
+    if target_type is typing.List or target_type is list:  # noqa: UP006
         raise_error_container_parameter_missing("List")
-    elif target_type is Tuple or target_type is tuple:
+    elif target_type is typing.Tuple or target_type is tuple:  # noqa: UP006
         raise_error_container_parameter_missing("Tuple")
-    elif target_type is Dict or target_type is dict:
+    elif target_type is typing.Dict or target_type is dict:  # noqa: UP006
         raise_error_container_parameter_missing("Dict")
     elif target_type is None or target_type is Optional:
         raise_error_container_parameter_missing("Optional")
@@ -1399,7 +1396,7 @@ def container_checker(obj, target_type) -> bool:
     check_args_exist(target_type)
     if origin_type is None:
         return False
-    elif origin_type is list or origin_type is List:
+    elif origin_type is list or origin_type is typing.List:  # noqa: UP006
         check_empty_containers(obj)
         if not isinstance(obj, list):
             return False
@@ -1413,7 +1410,7 @@ def container_checker(obj, target_type) -> bool:
             elif not isinstance(el, arg_type):
                 return False
         return True
-    elif origin_type is Dict or origin_type is dict:
+    elif origin_type is typing.Dict or origin_type is dict:  # noqa: UP006
         check_empty_containers(obj)
         if not isinstance(obj, dict):
             return False
@@ -1430,7 +1427,7 @@ def container_checker(obj, target_type) -> bool:
             elif not isinstance(val, val_type):
                 return False
         return True
-    elif origin_type is Tuple or origin_type is tuple:
+    elif origin_type is typing.Tuple or origin_type is tuple:  # noqa: UP006
         check_empty_containers(obj)
         if not isinstance(obj, tuple):
             return False
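A quick standalone sanity check of the two properties the rewritten one-liners lean on (not part of the diff): typing.get_origin() normalizes both spellings of a parameterized container to the built-in class, and __module__ tells the two alias families apart.

    import typing
    from typing import get_origin

    # Both spellings of a parameterized list report the same origin...
    assert get_origin(list[int]) is list
    assert get_origin(typing.List[int]) is list
    # ...while __module__ says which family the alias came from, which is all the
    # new `ann.__module__ in ("builtins", "typing") and ann_origin is list` check needs.
    assert list[int].__module__ == "builtins"
    assert typing.List[int].__module__ == "typing"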

View File

@@ -209,8 +209,8 @@ def derived_types(
     def derived_seq_types(typ: Union[type, typing._SpecialForm]):
         return (
-            typing.Sequence[typ],  # type: ignore[valid-type]
-            typing.List[typ],  # type: ignore[valid-type]
+            typing.Sequence[typ],  # type: ignore[valid-type]  # noqa: UP006
+            typing.List[typ],  # type: ignore[valid-type]  # noqa: UP006
             GenericAlias(collections.abc.Sequence, (typ,)),
             GenericAlias(list, (typ,)),
         )
@@ -252,7 +252,7 @@ def get_supported_param_types():
 SUPPORTED_RETURN_TYPES = {
     Tensor: "Tensor",
-    typing.List[Tensor]: "Tensor[]",
+    typing.List[Tensor]: "Tensor[]",  # noqa: UP006
     list[Tensor]: "Tensor[]",
     int: "SymInt",
     float: "float",
@@ -306,7 +306,7 @@ def tuple_to_list(tuple_type: type[tuple]) -> type[list]:
     # Account for different python versions, e.g. python 3.8 would give ()
     # but python 3.12 would give None.
     if (
-        tuple_type is typing.Tuple
+        tuple_type is typing.Tuple  # noqa: UP006
         or tuple_type is tuple
         or type_args == ()
         or type_args is None
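Both spellings stay in SUPPORTED_RETURN_TYPES so a return annotation written either way maps to the same schema string. A minimal sketch of that lookup pattern (illustrative names, not the real table):

    import typing
    import torch

    RETURN_TYPE_TO_SCHEMA = {
        typing.List[torch.Tensor]: "Tensor[]",
        list[torch.Tensor]: "Tensor[]",
        int: "SymInt",
    }

    # Parameterized built-ins hash and compare by origin and args, so a freshly
    # written list[torch.Tensor] finds the entry registered above.
    assert RETURN_TYPE_TO_SCHEMA[list[torch.Tensor]] == "Tensor[]"
    assert RETURN_TYPE_TO_SCHEMA[typing.List[torch.Tensor]] == "Tensor[]"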

View File

@@ -2,7 +2,7 @@
 import contextlib
 import sys
 import warnings
-from typing import Any, cast, List, Optional, TYPE_CHECKING, Union
+from typing import Any, cast, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
@@ -787,7 +787,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
             FutureWarning,
             stacklevel=3,
         )
-        return c10d._resolve_group_name_by_ranks_and_tag(cast(List[int], group), tag)
+        return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag)
     else:
         raise ValueError(f"Unsupported group type: {type(group)}, {group}")

View File

@@ -3,7 +3,7 @@
 import abc
 import io
 from collections.abc import Sequence
-from typing import cast, IO, Optional, Type
+from typing import cast, IO, Optional

 # introduced as collections.abc.Buffer in Python 3.12
 from typing_extensions import Buffer
@@ -189,11 +189,11 @@ class ZStandard(StreamTransformExtension):
 class ExtensionRegistry:
     def __init__(self) -> None:
         # Populate default registry contents
-        self.extensions: dict[str, Type[Extension]] = {
+        self.extensions: dict[str, type[Extension]] = {
             cls.registry_name(): cls for cls in (ZStandard,)
         }

-    def register(self, cls: Type[Extension]) -> None:
+    def register(self, cls: type[Extension]) -> None:
         self.extensions[cls.registry_name()] = cls

     def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]:
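Since Python 3.9 the built-in type[...] parameterizes class objects directly, so typing.Type is redundant in the registry annotations. A stand-in sketch of the same pattern (Codec/Zstd are made-up names, not the checkpoint extension API):

    class Codec: ...
    class Zstd(Codec): ...

    # Map a registry name to the class itself, PEP 585 style.
    registry: dict[str, type[Codec]] = {"zstd": Zstd}

    def register(name: str, cls: type[Codec]) -> None:
        registry[name] = cls

    register("zstd-fallback", Zstd)
    assert registry["zstd-fallback"] is Zstd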

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import math
-from typing import List, Optional, Union
+from typing import Optional, Union
 from typing_extensions import deprecated

 import torch
@@ -788,8 +788,8 @@ class _ConvTransposeNd(_ConvNd):
                 f"or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})"
             )

-        min_sizes = torch.jit.annotate(List[int], [])
-        max_sizes = torch.jit.annotate(List[int], [])
+        min_sizes = torch.jit.annotate(list[int], [])
+        max_sizes = torch.jit.annotate(list[int], [])
         for d in range(num_spatial_dims):
             dim_size = (
                 (input.size(d + num_non_spatial_dims) - 1) * stride[d]
@@ -811,7 +811,7 @@ class _ConvTransposeNd(_ConvNd):
                 f"from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})"
             )

-        res = torch.jit.annotate(List[int], [])
+        res = torch.jit.annotate(list[int], [])
         for d in range(num_spatial_dims):
             res.append(output_size[d] - min_sizes[d])
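In eager mode torch.jit.annotate simply returns its second argument, so this change is inert there; the list[int] spelling only has to be understood by the TorchScript frontend, which this diff assumes it is on the supported Python versions. A small aside (assumes an installed torch):

    import torch

    sizes = torch.jit.annotate(list[int], [])  # eager: returns [] unchanged
    sizes.append(3)
    assert sizes == [3]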

View File

@@ -1,7 +1,8 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import functools
-from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
+import typing
+from typing import cast, Optional, Union
 from typing_extensions import deprecated

 import torch
@@ -20,7 +21,10 @@ __all__ = [
 ]

-_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
+_tensor_or_tensors = Union[
+    torch.Tensor,
+    typing.Iterable[torch.Tensor],  # noqa: UP006 - needed until XLA's patch is updated
+]

 def _no_grad(func):
@@ -73,13 +77,13 @@ def _get_total_norm(
     if len(tensors) == 0:
         return torch.tensor(0.0)
     first_device = tensors[0].device
-    grouped_tensors: Dict[
-        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    grouped_tensors: dict[
+        tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
     ] = _group_tensors_by_device_and_dtype(
         [tensors]  # type: ignore[list-item]
     )  # type: ignore[assignment]
-    norms: List[Tensor] = []
+    norms: list[Tensor] = []
     for (device, _), ([device_tensors], _) in grouped_tensors.items():
         if (foreach is None and _has_foreach_support(device_tensors, device)) or (
             foreach and _device_has_foreach_support(device)
@@ -146,8 +150,8 @@ def _clip_grads_with_norm_(
     max_norm = float(max_norm)
     if len(grads) == 0:
         return
-    grouped_grads: Dict[
-        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    grouped_grads: dict[
+        tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
     ] = _group_tensors_by_device_and_dtype(
         [grads]
     )  # type: ignore[assignment]
@@ -269,10 +273,10 @@ def clip_grad_value_(
     for (device, _), ([grads], _) in grouped_grads.items():
         if (
             foreach is None
-            and _has_foreach_support(cast(List[Tensor], grads), device=device)
+            and _has_foreach_support(cast(list[Tensor], grads), device=device)
         ) or (foreach and _device_has_foreach_support(device)):
-            torch._foreach_clamp_min_(cast(List[Tensor], grads), -clip_value)
-            torch._foreach_clamp_max_(cast(List[Tensor], grads), clip_value)
+            torch._foreach_clamp_min_(cast(list[Tensor], grads), -clip_value)
+            torch._foreach_clamp_max_(cast(list[Tensor], grads), clip_value)
         elif foreach:
             raise RuntimeError(
                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"

View File

@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Optional, Union

 import torch
 from torch import Tensor
@@ -30,7 +30,7 @@ class Adamax(Optimizer):
         self,
         params: ParamsT,
         lr: Union[float, Tensor] = 2e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0,
         foreach: Optional[bool] = None,
@@ -134,11 +134,11 @@ class Adamax(Optimizer):
             loss = closure()

         for group in self.param_groups:
-            params_with_grad: List[Tensor] = []
-            grads: List[Tensor] = []
-            exp_avgs: List[Tensor] = []
-            exp_infs: List[Tensor] = []
-            state_steps: List[Tensor] = []
+            params_with_grad: list[Tensor] = []
+            grads: list[Tensor] = []
+            exp_avgs: list[Tensor] = []
+            exp_infs: list[Tensor] = []
+            state_steps: list[Tensor] = []

             beta1, beta2 = group["betas"]
             eps = group["eps"]
@@ -223,11 +223,11 @@ Adamax.__doc__ = (
 def _single_tensor_adamax(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_infs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_infs: list[Tensor],
+    state_steps: list[Tensor],
     *,
     eps: float,
     beta1: float,
@@ -297,11 +297,11 @@ def _single_tensor_adamax(
 def _multi_tensor_adamax(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_infs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_infs: list[Tensor],
+    state_steps: list[Tensor],
     *,
     eps: float,
     beta1: float,
@@ -339,11 +339,11 @@ def _multi_tensor_adamax(
         grouped_exp_infs_,
         grouped_state_steps_,
     ), _ in grouped_tensors.values():
-        grouped_params = cast(List[Tensor], grouped_params_)
-        grouped_grads = cast(List[Tensor], grouped_grads_)
-        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
-        grouped_exp_infs = cast(List[Tensor], grouped_exp_infs_)
-        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
+        grouped_params = cast(list[Tensor], grouped_params_)
+        grouped_grads = cast(list[Tensor], grouped_grads_)
+        grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
+        grouped_exp_infs = cast(list[Tensor], grouped_exp_infs_)
+        grouped_state_steps = cast(list[Tensor], grouped_state_steps_)

         if has_complex:
             _view_as_real(
@@ -389,7 +389,7 @@ def _multi_tensor_adamax(
             torch._foreach_add_(grouped_grads, eps)
             torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

-        bias_corrections: Union[Tuple[Tensor, ...], List[Tensor]]
+        bias_corrections: Union[tuple[Tensor, ...], list[Tensor]]
         if capturable:
             bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
             # foreach_sub doesn't allow a scalar as the first arg
@@ -410,11 +410,11 @@ def _multi_tensor_adamax(
 @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
 def adamax(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_infs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_infs: list[Tensor],
+    state_steps: list[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
     foreach: Optional[bool] = None,
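One subtlety behind signature changes like betas: tuple[float, float] = (0.9, 0.999): parameter annotations are evaluated at definition time (absent from __future__ import annotations), so the built-in generic spelling needs Python 3.9+, which is consistent with the IS_PY39_PLUS branches removed earlier in this commit. A self-contained illustration:

    import inspect

    def make_betas(betas: tuple[float, float] = (0.9, 0.999)) -> tuple[float, float]:
        return betas

    # The evaluated annotation object is available for runtime introspection.
    sig = inspect.signature(make_betas)
    assert sig.parameters["betas"].annotation == tuple[float, float]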

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 r"""Implementation for the NAdam algorithm."""
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Optional, Union

 import torch
 from torch import Tensor
@@ -32,7 +32,7 @@ class NAdam(Optimizer):  # noqa: D101
         self,
         params: ParamsT,
         lr: Union[float, Tensor] = 2e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0,
         momentum_decay: float = 4e-3,
@@ -167,13 +167,13 @@ class NAdam(Optimizer):  # noqa: D101
             loss = closure()

         for group in self.param_groups:
-            params_with_grad: List[Tensor] = []
-            grads: List[Tensor] = []
-            exp_avgs: List[Tensor] = []
-            exp_avg_sqs: List[Tensor] = []
-            mu_products: List[Tensor] = []
-            state_steps: List[Tensor] = []
-            beta1, beta2 = cast(Tuple[float, float], group["betas"])
+            params_with_grad: list[Tensor] = []
+            grads: list[Tensor] = []
+            exp_avgs: list[Tensor] = []
+            exp_avg_sqs: list[Tensor] = []
+            mu_products: list[Tensor] = []
+            state_steps: list[Tensor] = []
+            beta1, beta2 = cast(tuple[float, float], group["betas"])

             has_complex = self._init_group(
                 group,
@@ -277,12 +277,12 @@ NAdam.__doc__ = (
 def _single_tensor_nadam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    mu_products: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    mu_products: list[Tensor],
+    state_steps: list[Tensor],
     *,
     beta1: float,
     beta2: float,
@@ -371,12 +371,12 @@ def _single_tensor_nadam(
 def _multi_tensor_nadam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    mu_products: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    mu_products: list[Tensor],
+    state_steps: list[Tensor],
     *,
     beta1: float,
     beta2: float,
@@ -417,12 +417,12 @@ def _multi_tensor_nadam(
         grouped_mu_products_,
         grouped_state_steps_,
     ), _ in grouped_tensors.values():
-        grouped_params = cast(List[Tensor], grouped_params_)
-        grouped_grads = cast(List[Tensor], grouped_grads_)
-        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
-        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
-        grouped_mu_products = cast(List[Tensor], grouped_mu_products_)
-        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
+        grouped_params = cast(list[Tensor], grouped_params_)
+        grouped_grads = cast(list[Tensor], grouped_grads_)
+        grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
+        grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_)
+        grouped_mu_products = cast(list[Tensor], grouped_mu_products_)
+        grouped_state_steps = cast(list[Tensor], grouped_state_steps_)

         # handle complex
         if has_complex:
@@ -469,9 +469,9 @@ def _multi_tensor_nadam(
         exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)

-        bias_correction_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
-        mus: Union[Tuple[Tensor, ...], List[Tensor]]
-        mu_nexts: Union[Tuple[Tensor, ...], List[Tensor]]
+        bias_correction_sqrt: Union[tuple[Tensor, ...], list[Tensor]]
+        mus: Union[tuple[Tensor, ...], list[Tensor]]
+        mu_nexts: Union[tuple[Tensor, ...], list[Tensor]]
         if capturable:
             # mus will be beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
             exponent = torch._foreach_mul(grouped_state_steps, momentum_decay)
@@ -579,12 +579,12 @@ def _multi_tensor_nadam(
 @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
 def nadam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    mu_products: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    mu_products: list[Tensor],
+    state_steps: list[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
     decoupled_weight_decay: bool = False,

View File

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 r"""Implementation for the RAdam algorithm."""
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Optional, Union

 import torch
 from torch import Tensor
@@ -31,7 +31,7 @@ class RAdam(Optimizer):  # noqa: D101
         self,
         params: ParamsT,
         lr: Union[float, Tensor] = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0,
         decoupled_weight_decay: bool = False,
@@ -138,12 +138,12 @@ class RAdam(Optimizer):  # noqa: D101
             loss = closure()

         for group in self.param_groups:
-            params_with_grad: List[Tensor] = []
-            grads: List[Tensor] = []
-            exp_avgs: List[Tensor] = []
-            exp_avg_sqs: List[Tensor] = []
-            state_steps: List[Tensor] = []
-            beta1, beta2 = cast(Tuple[float, float], group["betas"])
+            params_with_grad: list[Tensor] = []
+            grads: list[Tensor] = []
+            exp_avgs: list[Tensor] = []
+            exp_avg_sqs: list[Tensor] = []
+            state_steps: list[Tensor] = []
+            beta1, beta2 = cast(tuple[float, float], group["betas"])

             has_complex = self._init_group(
                 group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
@@ -252,11 +252,11 @@ RAdam.__doc__ = (
 def _single_tensor_radam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    state_steps: list[Tensor],
     *,
     beta1: float,
     beta2: float,
@@ -351,11 +351,11 @@ def _single_tensor_radam(
 def _multi_tensor_radam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    state_steps: list[Tensor],
     *,
     beta1: float,
     beta2: float,
@@ -394,11 +394,11 @@ def _multi_tensor_radam(
         grouped_exp_avg_sqs_,
         grouped_state_steps_,
     ), _ in grouped_tensors.values():
-        grouped_params = cast(List[Tensor], grouped_params_)
-        grouped_grads = cast(List[Tensor], grouped_grads_)
-        grouped_exp_avgs = cast(List[Tensor], grouped_exp_avgs_)
-        grouped_exp_avg_sqs = cast(List[Tensor], grouped_exp_avg_sqs_)
-        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
+        grouped_params = cast(list[Tensor], grouped_params_)
+        grouped_grads = cast(list[Tensor], grouped_grads_)
+        grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
+        grouped_exp_avg_sqs = cast(list[Tensor], grouped_exp_avg_sqs_)
+        grouped_state_steps = cast(list[Tensor], grouped_state_steps_)

         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
@@ -422,9 +422,9 @@ def _multi_tensor_radam(
         # maximum length of the approximated SMA
         rho_inf = 2 / (1 - beta2) - 1
         # compute the length of the approximated SMA
-        bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
-        bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
-        rho_t_list: Union[Tuple[Tensor, ...], List[Tensor]]
+        bias_correction1: Union[tuple[Tensor, ...], list[Tensor]]
+        bias_correction2: Union[tuple[Tensor, ...], list[Tensor]]
+        rho_t_list: Union[tuple[Tensor, ...], list[Tensor]]
         if capturable:
             bias_correction1 = torch._foreach_pow(beta2, grouped_state_steps)
             torch._foreach_neg_(bias_correction1)
@@ -547,11 +547,11 @@ def _multi_tensor_radam(
 @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
 def radam(
-    params: List[Tensor],
-    grads: List[Tensor],
-    exp_avgs: List[Tensor],
-    exp_avg_sqs: List[Tensor],
-    state_steps: List[Tensor],
+    params: list[Tensor],
+    grads: list[Tensor],
+    exp_avgs: list[Tensor],
+    exp_avg_sqs: list[Tensor],
+    state_steps: list[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
     decoupled_weight_decay: bool = False,

View File

@@ -1,15 +1,16 @@
+from collections.abc import Sequence
 from pathlib import Path
 from re import match as _match
-from typing import List, Optional, Sequence, Set, Union
+from typing import Optional, Union


-def read_file(fname: Union[Path, str]) -> List[str]:
+def read_file(fname: Union[Path, str]) -> list[str]:
     with open(fname, encoding="utf-8") as f:
         return f.readlines()


 def _embed_headers(
-    content: List[str], include_dirs: List[Path], processed_files: Set[str]
+    content: list[str], include_dirs: list[Path], processed_files: set[str]
 ) -> str:
     for line_idx, cur_line in enumerate(content):
         m = _match('^\\s*#include\\s*[<"]([^>"]+)[>"]', cur_line)
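typing.Sequence and typing.Set drop out of the import because PEP 585 deprecated those aliases; Sequence now comes from collections.abc (subscriptable in annotations and still usable with isinstance), while the set parameter uses the built-in set[str]. A short standalone sketch:

    from collections.abc import Sequence

    def first_line(lines: Sequence[str]) -> str:
        return lines[0]

    assert first_line(["#include <a.h>", "int x;"]) == "#include <a.h>"
    assert isinstance(("a", "b"), Sequence)  # tuples count as Sequence too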

View File

@@ -32,7 +32,6 @@ from typing import (
     Callable,
     cast,
     Generic,
-    List,
     Optional,
     overload,
     Protocol,
@@ -747,7 +746,7 @@ class TreeSpec:
         return self.num_nodes == 1 and self.num_leaves == 1

     def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
-        def helper(treespec: TreeSpec, tree: PyTree, subtrees: List[PyTree]) -> None:
+        def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None:
             if treespec.is_leaf():
                 subtrees.append(tree)
                 return
@@ -881,7 +880,7 @@ def tree_flatten(
        to reconstruct the pytree.
    """

-    def helper(node: PyTree, leaves: List[Any]) -> TreeSpec:
+    def helper(node: PyTree, leaves: list[Any]) -> TreeSpec:
         if _is_leaf(node, is_leaf=is_leaf):
             leaves.append(node)
             return _LEAF_SPEC