[5/N] Apply Ruff fixes and pyupgrade to Python 3.9 (#144205)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144205
Approved by: https://github.com/albanD
Author: cyy, 2025-01-15 04:00:47 +00:00, committed by PyTorch MergeBot
Commit: d87aad6877 (parent: db787181b5)
46 changed files with 398 additions and 444 deletions
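Nearly every hunk below applies the same mechanical rewrite: typing.Tuple/List/Dict/Set annotations become the built-in generics that are subscriptable since Python 3.9 (PEP 585), and ABCs such as Sequence, Iterable, and Mapping are imported from collections.abc rather than typing. A minimal before/after sketch of the pattern, using a hypothetical helper rather than code taken from the changed files:

    # Hypothetical example, not from this diff.
    # Before: Python 3.8-style annotations spelled with typing generics.
    from typing import Dict, List, Sequence, Tuple

    def summarize(values: Sequence[int]) -> Tuple[List[int], Dict[str, int]]:
        evens: List[int] = [v for v in values if v % 2 == 0]
        stats: Dict[str, int] = {"count": len(values), "total": sum(values)}
        return evens, stats

    # After: built-in generics (PEP 585) plus collections.abc imports, the
    # rewrite that Ruff's pyupgrade-derived rules (e.g. UP006/UP035) automate.
    from collections.abc import Sequence

    def summarize(values: Sequence[int]) -> tuple[list[int], dict[str, int]]:
        evens: list[int] = [v for v in values if v % 2 == 0]
        stats: dict[str, int] = {"count": len(values), "total": sum(values)}
        return evens, stats

Because the codebase already requires Python 3.9 or newer, the two spellings are equivalent at runtime; only the annotation syntax changes.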

View File

@ -9,7 +9,8 @@ half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdo
"""
import warnings
from typing import cast, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import _vmap_internals
@ -60,7 +61,7 @@ def _calculate_shape(
output: Union[torch.Tensor, graph.GradientEdge],
grad: torch.Tensor,
is_grads_batched: bool,
) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
) -> tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
# is_same_size ensures that both tensors are either nested or non nested
# circular import
from torch.nested._internal.nested_tensor import NestedTensor
@ -89,8 +90,8 @@ def _make_grads(
outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
grads: Sequence[_OptionalTensor],
is_grads_batched: bool,
) -> Tuple[_OptionalTensor, ...]:
new_grads: List[_OptionalTensor] = []
) -> tuple[_OptionalTensor, ...]:
new_grads: list[_OptionalTensor] = []
for out, grad in zip(outputs, grads):
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
out_size = None
@ -231,7 +232,7 @@ def _make_grads(
def _tensor_or_tensors_to_tuple(
tensors: Optional[_TensorOrTensors], length: int
) -> Tuple[_OptionalTensor, ...]:
) -> tuple[_OptionalTensor, ...]:
if tensors is None:
return (None,) * length
if isinstance(tensors, torch.Tensor):
@ -370,7 +371,7 @@ def grad(
allow_unused: Optional[bool] = None,
is_grads_batched: bool = False,
materialize_grads: bool = False,
) -> Tuple[torch.Tensor, ...]:
) -> tuple[torch.Tensor, ...]:
r"""Compute and return the sum of gradients of outputs with respect to the inputs.
``grad_outputs`` should be a sequence of length matching ``output``

View File

@ -4,7 +4,7 @@ import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from typing import Any, Optional
from typing_extensions import deprecated
import torch
@ -389,7 +389,7 @@ class _SingleLevelFunction(
)
@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> Any:
r"""There are two ways to define the forward pass of an autograd.Function.
Either:
@ -722,7 +722,7 @@ def _unflatten(input, proto):
# unflatten a list or tuple input into a nested list/tuple structure
# specified by proto
def unflatten_helper(input, proto):
res: List[Optional[torch.Tensor]] = []
res: list[Optional[torch.Tensor]] = []
if hasattr(proto, "_jit_wrap"):
return proto._jit_wrap(input)
if not isinstance(proto, (list, tuple)):

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import List, Tuple
import torch
from torch._vmap_internals import _vmap
@ -181,8 +180,8 @@ def _autograd_grad(
assert isinstance(grad_outputs, tuple)
assert len(outputs) == len(grad_outputs)
new_outputs: Tuple[torch.Tensor, ...] = ()
new_grad_outputs: Tuple[torch.Tensor, ...] = ()
new_outputs: tuple[torch.Tensor, ...] = ()
new_grad_outputs: tuple[torch.Tensor, ...] = ()
for out, grad_out in zip(outputs, grad_outputs):
if out is not None and out.requires_grad:
new_outputs += (out,)
@ -211,7 +210,7 @@ def _fill_in_zeros(grads, refs, strict, create_graph, stage):
if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")
res: Tuple[torch.Tensor, ...] = ()
res: tuple[torch.Tensor, ...] = ()
for i, grads_i in enumerate(grads):
if grads_i is None:
if strict:
@ -470,8 +469,8 @@ def jvp(func, inputs, v=None, create_graph=False, strict=False):
def _construct_standard_basis_for(
tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
) -> Tuple[torch.Tensor, ...]:
tensors: tuple[torch.Tensor, ...], tensor_numels: tuple[int, ...]
) -> tuple[torch.Tensor, ...]:
# This function:
# - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
# - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
@ -780,11 +779,11 @@ def jacobian(
jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
)
jacobian: Tuple[torch.Tensor, ...] = ()
jacobian: tuple[torch.Tensor, ...] = ()
for i, out in enumerate(outputs):
# mypy complains that expression and variable have different types due to the empty list
jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore[assignment]
jac_i: tuple[list[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore[assignment]
for j in range(out.nelement()):
vj = _autograd_grad(
(out.reshape(-1)[j],),

View File

@ -2,8 +2,9 @@
import collections
import functools
import warnings
from collections.abc import Iterable
from itertools import product
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
from typing_extensions import deprecated
import torch
@ -48,14 +49,14 @@ def _is_float_or_complex_tensor(obj):
def _allocate_jacobians_with_inputs(
input_tensors: Tuple, numel_output
) -> Tuple[torch.Tensor, ...]:
input_tensors: tuple, numel_output
) -> tuple[torch.Tensor, ...]:
# Makes zero-filled tensors from inputs. If `numel_output` is not None, for
# each tensor in `input_tensors`, returns a new zero-filled tensor with height
# of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
# a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
# the same dtype and device as those of the corresponding input.
out: List[torch.Tensor] = [
out: list[torch.Tensor] = [
t.new_zeros((t.numel(), numel_output), layout=torch.strided)
for t in input_tensors
if _is_float_or_complex_tensor(t) and t.requires_grad
@ -64,14 +65,14 @@ def _allocate_jacobians_with_inputs(
def _allocate_jacobians_with_outputs(
output_tensors: Tuple, numel_input, dtype=None, device=None
) -> Tuple[torch.Tensor, ...]:
output_tensors: tuple, numel_input, dtype=None, device=None
) -> tuple[torch.Tensor, ...]:
# Makes zero-filled tensors from outputs. If `dim` is not None, for each tensor
# in `output_tensors`, returns a new zero-filled tensor with height of `dim` and
# width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size
# (t.numel,).
options = {"dtype": dtype, "device": device, "layout": torch.strided}
out: List[torch.Tensor] = [
out: list[torch.Tensor] = [
t.new_zeros((numel_input, t.numel()), **options)
for t in output_tensors
if _is_float_or_complex_tensor(t)
@ -258,7 +259,7 @@ def _iter_tensor(x_tensor):
def _get_numerical_jacobian(
fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False
) -> List[Tuple[torch.Tensor, ...]]:
) -> list[tuple[torch.Tensor, ...]]:
"""Compute the numerical Jacobian of `fn(inputs)` with respect to `target`.
If not specified, targets are the input. Returns M * N Jacobians where N is the
@ -281,7 +282,7 @@ def _get_numerical_jacobian(
Note that `target` may not even be part of `input` to `fn`, so please be
**very careful** in this to not clone `target`.
"""
jacobians: List[Tuple[torch.Tensor, ...]] = []
jacobians: list[tuple[torch.Tensor, ...]] = []
if outputs is None:
outputs = _as_tuple(fn(*_as_tuple(inputs)))
if not is_forward_ad and any(o.is_complex() for o in outputs):
@ -386,13 +387,13 @@ def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
def _compute_numerical_jvps_wrt_specific_input(
jvp_fn, delta, input_is_complex, is_forward_ad=False
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
# Computing the jacobian only works for real delta
# For details on the algorithm used here, refer:
# Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
# s = fn(z) where z = x for real valued input
# and z = x + yj for complex valued input
jvps: List[torch.Tensor] = []
jvps: list[torch.Tensor] = []
ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)
if input_is_complex: # C -> R
@ -412,8 +413,8 @@ def _compute_numerical_jvps_wrt_specific_input(
def _combine_jacobian_cols(
jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel
) -> Tuple[torch.Tensor, ...]:
jacobians_cols: dict[int, list[torch.Tensor]], outputs, input, numel
) -> tuple[torch.Tensor, ...]:
# jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor
# we return a list that maps output_idx -> full jacobian Tensor
jacobians = _allocate_jacobians_with_outputs(
@ -467,13 +468,13 @@ def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None
def get_numerical_jacobian_wrt_specific_input(
fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False
) -> Tuple[torch.Tensor, ...]:
) -> tuple[torch.Tensor, ...]:
# Computes the numerical jacobians wrt to a single input. Returns N jacobian
# tensors, where N is the number of outputs. We use a dictionary for
# jacobian_cols because indices aren't necessarily consecutive for sparse inputs
# When we perturb only a single element of the input tensor at a time, the jvp
# is equivalent to a single col of the Jacobian matrix of fn.
jacobian_cols: Dict[int, List[torch.Tensor]] = {}
jacobian_cols: dict[int, list[torch.Tensor]] = {}
input = inputs[input_idx] if input is None else input
assert input.requires_grad
for x, idx, d_idx in _iter_tensor(input):
@ -493,7 +494,7 @@ def get_numerical_jacobian_wrt_specific_input(
def _get_analytical_jacobian_forward_ad(
fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None
) -> Tuple[Tuple[torch.Tensor, ...], ...]:
) -> tuple[tuple[torch.Tensor, ...], ...]:
"""Compute the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect to `target`.
Return N * M Jacobians where N is the number of tensors in target that require grad and
@ -659,7 +660,7 @@ def _mul_tensor_or_tuple(u, k):
def _get_numerical_jvp_wrt_specific_input(
fn, input_idx, inputs, u, eps, is_forward_ad=False
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
input = inputs[input_idx]
input_to_perturb = _get_input_to_perturb(input)
wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True)
@ -676,7 +677,7 @@ def _get_numerical_vJu(
fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad
):
# Note that all_v can also be None, in that case, this function only computes Ju.
reduced_jacobians: List[List[torch.Tensor]] = []
reduced_jacobians: list[list[torch.Tensor]] = []
for inp_idx, u in zip(inp_indices, all_u):
all_Ju = _get_numerical_jvp_wrt_specific_input(
fn, inp_idx, inputs, u, eps, is_forward_ad
@ -692,7 +693,7 @@ def _get_numerical_vJu(
# TODO: handle the other Ju
pass
if all_v is not None:
jacobian_scalars: List[torch.Tensor] = []
jacobian_scalars: list[torch.Tensor] = []
for v, Ju in zip(all_v, filtered_Ju):
jacobian_scalars.append(_dot_with_type_promotion(v, Ju))
reduced_jacobians.append(jacobian_scalars)
@ -712,7 +713,7 @@ def _check_jacobians_equal(j1, j2, atol):
def _stack_and_check_tensors(
list_of_list_of_tensors, inputs, numel_outputs
) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]:
) -> tuple[tuple[torch.Tensor, ...], bool, bool]:
# For the ith tensor in the inner list checks whether it has the same size and
# dtype as the ith differentiable input.
out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs)
@ -757,7 +758,7 @@ If the test
def _check_analytical_jacobian_attributes(
inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
) -> Tuple[torch.Tensor, ...]:
) -> tuple[torch.Tensor, ...]:
# This is used by both fast and slow mode:
# - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
# input.
@ -802,12 +803,12 @@ def _check_analytical_jacobian_attributes(
def _get_analytical_vJu_backward_mode(
inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u
):
reduced_jacobians: List[List[torch.Tensor]] = []
reduced_jacobians: list[list[torch.Tensor]] = []
for output, v in zip(outputs, all_v):
all_vJ = _check_analytical_jacobian_attributes(
inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v
)
jacobian_scalars: List[torch.Tensor] = []
jacobian_scalars: list[torch.Tensor] = []
for vJ, u in zip(all_vJ, all_u):
# Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse
# the error checking logic from slow mode
@ -878,7 +879,7 @@ def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):
def _compute_analytical_jacobian_rows(
vjp_fn, sample_output
) -> List[List[Optional[torch.Tensor]]]:
) -> list[list[Optional[torch.Tensor]]]:
# Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis
# vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian.
# NB: this function does not assume vjp_fn(v) to return tensors with the same
@ -889,7 +890,7 @@ def _compute_analytical_jacobian_rows(
)
flat_grad_out = grad_out_base.view(-1)
# jacobians_rows[i][j] is the Jacobian jth row for the ith input
jacobians_rows: List[List[Optional[torch.Tensor]]] = []
jacobians_rows: list[list[Optional[torch.Tensor]]] = []
for j in range(flat_grad_out.numel()):
flat_grad_out.zero_()
flat_grad_out[j] = 1.0 # projection for jth row of Jacobian
@ -905,9 +906,9 @@ def _compute_analytical_jacobian_rows(
def _get_analytical_vjps_wrt_specific_output(
vjp_fn, sample_output, v
) -> List[List[Optional[torch.Tensor]]]:
) -> list[list[Optional[torch.Tensor]]]:
grad_inputs = vjp_fn(v.reshape(sample_output.shape))
vjps: List[List[Optional[torch.Tensor]]] = [
vjps: list[list[Optional[torch.Tensor]]] = [
[vjp.clone() if isinstance(vjp, torch.Tensor) else None] for vjp in grad_inputs
]
return vjps
@ -1174,7 +1175,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:
def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool:
# Tests that backward is multiplied by grad_output
diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
diff_input_list: list[torch.Tensor] = list(_iter_tensors(inputs, True))
if not diff_input_list:
raise GradcheckError("no Tensors requiring grad found in input")
grads_input = torch.autograd.grad(
@ -1294,7 +1295,7 @@ def _test_undefined_forward_mode(func, outputs, inputs):
def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
diff_input_list: list[torch.Tensor] = list(_iter_tensors(inputs, True))
if not diff_input_list:
raise GradcheckError("no Tensors requiring grad found in input")

View File

@ -4,23 +4,15 @@ import functools
import logging
import threading
from collections import defaultdict, deque
from collections.abc import Generator, Iterable, Iterator, MutableMapping, Sequence
from typing import (
Any,
Callable,
cast,
Deque,
Dict,
Generator,
Iterable,
Iterator,
List,
Literal,
MutableMapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
@ -71,7 +63,7 @@ class Node(abc.ABC):
@property
@abc.abstractmethod
def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
def next_functions(self) -> tuple[tuple[Optional["Node"], int], ...]:
raise NotImplementedError
@abc.abstractmethod
@ -81,7 +73,7 @@ class Node(abc.ABC):
@property
@abc.abstractmethod
def _input_metadata(self) -> List[Any]:
def _input_metadata(self) -> list[Any]:
raise NotImplementedError
@abc.abstractmethod
@ -367,7 +359,7 @@ class save_on_cpu(saved_tensors_hooks):
def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None:
device_module = getattr(torch, device_type, torch.cuda)
def pack_to_cpu(tensor: torch.Tensor) -> Tuple[torch.device, torch.Tensor]:
def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]:
if not pin_memory:
return (tensor.device, tensor.cpu())
packed = torch.empty(
@ -379,7 +371,7 @@ class save_on_cpu(saved_tensors_hooks):
packed.copy_(tensor)
return (tensor.device, packed)
def unpack_from_cpu(packed: Tuple[torch.device, torch.Tensor]) -> torch.Tensor:
def unpack_from_cpu(packed: tuple[torch.device, torch.Tensor]) -> torch.Tensor:
device, tensor = packed
return tensor.to(device, non_blocking=pin_memory)
@ -423,19 +415,19 @@ def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, Non
class _MultiHandle(RemovableHandle):
handles: Tuple[RemovableHandle, ...]
handles: tuple[RemovableHandle, ...]
def __init__(self, handles: Tuple[RemovableHandle, ...]) -> None:
def __init__(self, handles: tuple[RemovableHandle, ...]) -> None:
self.handles = handles
def remove(self) -> None:
for handle in self.handles:
handle.remove()
def __getstate__(self) -> Tuple[RemovableHandle, ...]:
def __getstate__(self) -> tuple[RemovableHandle, ...]:
return self.handles
def __setstate__(self, state: Tuple[RemovableHandle, ...]) -> None:
def __setstate__(self, state: tuple[RemovableHandle, ...]) -> None:
self.handles = state
@ -502,9 +494,9 @@ def register_multi_grad_hook(
raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")
if mode == "all":
count: Dict[int, int] = {}
count: dict[int, int] = {}
nb_calls = None
buffer: Dict[int, List[Optional[torch.Tensor]]] = {}
buffer: dict[int, list[Optional[torch.Tensor]]] = {}
grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
len_tensors = len(tensors)
@ -544,7 +536,7 @@ def register_multi_grad_hook(
)
elif mode == "any":
fn = cast(Callable[[torch.Tensor], None], fn)
ran_hook: Dict[int, bool] = defaultdict(bool)
ran_hook: dict[int, bool] = defaultdict(bool)
@functools.wraps(fn)
def wrapped_fn(grad: torch.Tensor) -> None:
@ -582,8 +574,8 @@ def register_multi_grad_hook(
_allow_mutation_on_saved_tensors_enabled: bool = False
_TID: TypeAlias = Tuple[int, int, int]
_SID: TypeAlias = Tuple[int, int]
_TID: TypeAlias = tuple[int, int, int]
_SID: TypeAlias = tuple[int, int]
def _get_tid(tensor: torch.Tensor) -> _TID:
@ -665,8 +657,8 @@ class _CloneArgBeforeMutateMode(TorchDispatchMode):
self,
func: "OpOverload",
types: Iterable[type],
args: Tuple[Any, ...] = (),
kwargs: Optional[Dict[Any, Any]] = None,
args: tuple[Any, ...] = (),
kwargs: Optional[dict[Any, Any]] = None,
) -> Any:
kwargs = kwargs or {}
@ -704,7 +696,7 @@ class _AllowMutationOnSavedContext:
self.cloned: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
self.original: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
self.tid_to_weakhandle: MutableMapping[_TID, _Handle] = WeakValueDictionary()
self.sid_to_tid: Dict[_SID, Set[_TID]] = defaultdict(set)
self.sid_to_tid: dict[_SID, set[_TID]] = defaultdict(set)
def clear(self) -> None:
self.cloned.clear()
@ -768,10 +760,10 @@ def _register_logging_hooks_on_whole_graph(
) -> Callable[[], None]:
grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
def iter_graph(roots: List[Node]) -> Iterator[Node]:
def iter_graph(roots: list[Node]) -> Iterator[Node]:
if not roots:
return
seen: Set[Node] = set()
seen: set[Node] = set()
q: Deque[Node] = deque()
for node in roots:
if node is not None:
@ -815,7 +807,7 @@ def _engine_run_backward(
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
*args: Any,
**kwargs: Any,
) -> Tuple[torch.Tensor, ...]:
) -> tuple[torch.Tensor, ...]:
attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
if attach_logging_hooks:
unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)

View File

@ -1,9 +1,10 @@
# mypy: allow-untyped-defs
import uuid
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from time import perf_counter_ns
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Optional
from warnings import warn
import torch
@ -557,7 +558,7 @@ class profile:
# frontend_function_events contains the events in aten or torch frontend level,
# whose correlation id is 0
frontend_function_events = []
device_corr_map: Dict[int, List[FunctionEvent]] = {}
device_corr_map: dict[int, list[FunctionEvent]] = {}
max_evt_id = 0
for kineto_event in result.events():
if _filter_name(kineto_event.name()):
@ -1141,7 +1142,7 @@ class KinetoStepTracker:
"""
_current_step = 0
_step_dict: Dict[str, int] = defaultdict(int)
_step_dict: dict[str, int] = defaultdict(int)
@classmethod
def init_step_count(cls, requester: str):

View File

@ -4,7 +4,7 @@ import itertools
import math
from collections import defaultdict, namedtuple
from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from typing_extensions import deprecated
import torch
@ -114,7 +114,7 @@ class EventList(list):
thread_events,
key=lambda event: [event.time_range.start, -event.time_range.end],
)
current_events: List[FunctionEvent] = []
current_events: list[FunctionEvent] = []
for event in thread_events_:
while len(current_events) > 0:
parent = current_events[-1]
@ -310,9 +310,9 @@ class EventList(list):
An EventList containing FunctionEventAvg objects.
"""
assert self._tree_built
stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]:
def get_key(event, group_by_input_shapes, group_by_stack_n) -> tuple[str, ...]:
key = [
str(event.key),
str(event.node_id),
@ -476,14 +476,14 @@ class FunctionEvent(FormattedTimesMixin):
self.time_range: Interval = Interval(start_us, end_us)
self.thread: int = thread
self.fwd_thread: Optional[int] = fwd_thread
self.kernels: List[Kernel] = []
self.kernels: list[Kernel] = []
self.count: int = 1
self.cpu_children: List[FunctionEvent] = []
self.cpu_children: list[FunctionEvent] = []
self.cpu_parent: Optional[FunctionEvent] = None
self.input_shapes: Tuple[int, ...] = input_shapes
self.concrete_inputs: List[Any] = concrete_inputs
self.kwinputs: Dict[str, Any] = kwinputs
self.stack: List = stack
self.input_shapes: tuple[int, ...] = input_shapes
self.concrete_inputs: list[Any] = concrete_inputs
self.kwinputs: dict[str, Any] = kwinputs
self.stack: list = stack
self.scope: int = scope
self.use_device: Optional[str] = use_device
self.cpu_memory_usage: int = cpu_memory_usage
@ -656,14 +656,14 @@ class FunctionEventAvg(FormattedTimesMixin):
self.device_time_total: int = 0
self.self_cpu_time_total: int = 0
self.self_device_time_total: int = 0
self.input_shapes: Optional[List[List[int]]] = None
self.stack: Optional[List] = None
self.input_shapes: Optional[list[list[int]]] = None
self.stack: Optional[list] = None
self.scope: Optional[int] = None
self.cpu_memory_usage: int = 0
self.device_memory_usage: int = 0
self.self_cpu_memory_usage: int = 0
self.self_device_memory_usage: int = 0
self.cpu_children: Optional[List[FunctionEvent]] = None
self.cpu_children: Optional[list[FunctionEvent]] = None
self.cpu_parent: Optional[FunctionEvent] = None
self.device_type: DeviceType = DeviceType.CPU
self.is_legacy: bool = False
@ -734,8 +734,8 @@ class MemRecordsAcc:
def __init__(self, mem_records):
self._mem_records = mem_records
self._start_nses: List[int] = []
self._indices: List[int] = []
self._start_nses: list[int] = []
self._indices: list[int] = []
if len(mem_records) > 0:
tmp = sorted([(r[0].start_ns(), i) for i, r in enumerate(mem_records)])
self._start_nses, self._indices = zip(*tmp) # type: ignore[assignment]

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
""" This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
import contextlib
from typing import Iterable, List, Union
from collections.abc import Iterable
from typing import List, Union
from warnings import warn
import torch.backends.cuda
@ -13,7 +14,7 @@ from torch.backends.cuda import (
)
__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
__all__: list[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses
@ -73,7 +74,7 @@ def _backend_from_string(name: str):
def _cur_sdpa_kernel_backends():
backends: List[SDPBackend] = []
backends: list[SDPBackend] = []
for name, val in _backend_names.items():
if getattr(torch.backends.cuda, f"{name}_sdp_enabled")():
backends.append(getattr(SDPBackend, val))
@ -88,7 +89,7 @@ def _sdpa_kernel(backends: Iterable[SDPBackend]):
@contextlib.contextmanager
def sdpa_kernel(
backends: Union[List[SDPBackend], SDPBackend], set_priority: bool = False
backends: Union[list[SDPBackend], SDPBackend], set_priority: bool = False
):
r"""
Context manager to select which backend to use for scaled dot product attention.

View File

@ -1,12 +1,12 @@
# mypy: allow-untyped-defs
"""Defines utilities for interacting with scaled_dot_product_attention"""
import math
from typing import List, Optional, Union
from typing import Optional, Union
import torch
__all__: List[str] = []
__all__: list[str] = []
def _input_requires_grad(*tensors: torch.Tensor) -> bool:

View File

@ -9,7 +9,7 @@ import math
import operator
import warnings
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
from torch import Tensor
@ -75,9 +75,9 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
# Need to define it here so that Dynamo doesn't skip it
def _vmap_for_bhqkv(
fn: Callable,
prefix: Tuple[Optional[int], ...],
suffix: Tuple[Optional[int], ...] = (),
out_dims: Union[int, List[Optional[int]]] = 0,
prefix: tuple[Optional[int], ...],
suffix: tuple[Optional[int], ...] = (),
out_dims: Union[int, list[Optional[int]]] = 0,
group_dim: bool = False,
):
"""Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimension inputs.
@ -98,7 +98,7 @@ def _vmap_for_bhqkv(
callable: The vmapped function.
"""
# We vamp a function 4 times, broadcasting the [b, h, q_idx, kv_idx] dimensions
dimensions: List[Tuple[None | int, None | int, None | int, None | int]] = []
dimensions: list[tuple[None | int, None | int, None | int, None | int]] = []
dimensions = [
(None, None, None, 0),
(None, None, 0, None),
@ -173,7 +173,7 @@ def _ordered_to_dense(num_blocks_in_row: Tensor, col_indices: Tensor):
return out
def _dense_to_ordered(dense_mask) -> Tuple[Tensor, Tensor]:
def _dense_to_ordered(dense_mask) -> tuple[Tensor, Tensor]:
dense_mask = dense_mask.to(dtype=torch.int32)
num_blocks_in_row = dense_mask.sum(dim=-1)
col_indices = torch.argsort(dense_mask, dim=-1, descending=True, stable=True)
@ -262,7 +262,7 @@ class BlockMask:
the backwards pass. These are autogenerated from 2.
"""
seq_lengths: Tuple[int, int]
seq_lengths: tuple[int, int]
kv_num_blocks: Tensor
kv_indices: Tensor
full_kv_num_blocks: Optional[Tensor]
@ -271,12 +271,12 @@ class BlockMask:
q_indices: Optional[Tensor]
full_q_num_blocks: Optional[Tensor]
full_q_indices: Optional[Tensor]
BLOCK_SIZE: Tuple[int, int]
BLOCK_SIZE: tuple[int, int]
mask_mod: _mask_mod_signature
def __init__(
self,
seq_lengths: Tuple[int, int],
seq_lengths: tuple[int, int],
kv_num_blocks: Tensor,
kv_indices: Tensor,
full_kv_num_blocks: Optional[Tensor],
@ -285,7 +285,7 @@ class BlockMask:
q_indices: Optional[Tensor],
full_q_num_blocks: Optional[Tensor],
full_q_indices: Optional[Tensor],
BLOCK_SIZE: Tuple[int, int],
BLOCK_SIZE: tuple[int, int],
mask_mod: _mask_mod_signature,
):
if kv_indices.dim() < 2:
@ -320,9 +320,9 @@ class BlockMask:
kv_indices: Tensor,
full_kv_num_blocks: Optional[Tensor] = None,
full_kv_indices: Optional[Tensor] = None,
BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
mask_mod: Optional[_mask_mod_signature] = None,
seq_lengths: Optional[Tuple[int, int]] = None,
seq_lengths: Optional[tuple[int, int]] = None,
):
"""
Creates a BlockMask instance from key-value block information.
@ -332,7 +332,7 @@ class BlockMask:
kv_indices (Tensor): Indices of key-value blocks in each Q_BLOCK_SIZE row tile.
full_kv_num_blocks (Optional[Tensor]): Number of full kv_blocks in each Q_BLOCK_SIZE row tile.
full_kv_indices (Optional[Tensor]): Indices of full key-value blocks in each Q_BLOCK_SIZE row tile.
BLOCK_SIZE (Union[int, Tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles.
BLOCK_SIZE (Union[int, tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles.
mask_mod (Optional[Callable]): Function to modify the mask.
Returns:
@ -664,7 +664,7 @@ def _convert_mask_to_block_mask(
Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
separate_full_blocks: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
) -> tuple[Tensor, Optional[Tensor]]:
assert mask.dtype == torch.bool
mask = _broadcast_to_dim(mask, 4)
@ -748,9 +748,9 @@ def _convert_block_mask_to_mask(
def _create_sparse_block_from_block_mask(
block_mask: Tuple[Tensor, Optional[Tensor]],
block_mask: tuple[Tensor, Optional[Tensor]],
mask_mod: Optional[Callable],
seq_lengths: Tuple[int, int],
seq_lengths: tuple[int, int],
Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
) -> BlockMask:
@ -758,7 +758,7 @@ def _create_sparse_block_from_block_mask(
partial_bm = _dense_to_ordered(partial_blocks)
if full_blocks is not None:
full_bm: Tuple[Optional[Tensor], Optional[Tensor]] = _dense_to_ordered(
full_bm: tuple[Optional[Tensor], Optional[Tensor]] = _dense_to_ordered(
full_blocks
)
else:
@ -829,7 +829,7 @@ def create_block_mask(
Q_LEN: int,
KV_LEN: int,
device: str = "cuda",
BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
_compile=False,
) -> BlockMask:
r"""This function creates a block mask tuple from a mask_mod function.
@ -845,7 +845,7 @@ def create_block_mask(
Q_LEN (int): Sequence length of query.
KV_LEN (int): Sequence length of key/value.
device (str): Device to run the mask creation on.
BLOCK_SIZE (int or Tuple[int, int]): Block size for the block mask. If a single int is provided it is used for both query and key/value.
BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is provided it is used for both query and key/value.
Returns:
BlockMask: A BlockMask object that contains the block mask information.
@ -1002,7 +1002,7 @@ def create_nested_block_mask(
H: Optional[int],
q_nt: torch.Tensor,
kv_nt: Optional[torch.Tensor] = None,
BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
_compile=False,
) -> BlockMask:
r"""This function creates a nested tensor compatible block mask tuple from a mask_mod
@ -1024,7 +1024,7 @@ def create_nested_block_mask(
constructed to operate on a "stacked sequence" of length ``sum(S)`` for sequence
length ``S`` from the NJT. If this is None, ``q_nt`` is used to define the structure
for key / value as well. Default: None
BLOCK_SIZE (int or Tuple[int, int]): Block size for the block mask. If a single int is
BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is
provided it is used for both query and key/value.
Returns:
@ -1167,8 +1167,8 @@ def flex_attention(
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
kernel_options: Optional[Dict[str, Any]] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
kernel_options: Optional[dict[str, Any]] = None,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
r"""This function implements scaled dot product attention with an arbitrary attention score modification function.
This function computes the scaled dot product attention between query, key, and value tensors with a user-defined

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple, TypeVar, Union
from typing import Optional, TypeVar, Union
from torch import Tensor
@ -9,13 +9,13 @@ from torch import Tensor
# broadcast to a tuple.
# Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations.
T = TypeVar("T")
_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]]
_scalar_or_tuple_1_t = Union[T, Tuple[T]]
_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]
_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]]
_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]]
_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]]
_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]]
_scalar_or_tuple_any_t = Union[T, tuple[T, ...]]
_scalar_or_tuple_1_t = Union[T, tuple[T]]
_scalar_or_tuple_2_t = Union[T, tuple[T, T]]
_scalar_or_tuple_3_t = Union[T, tuple[T, T, T]]
_scalar_or_tuple_4_t = Union[T, tuple[T, T, T, T]]
_scalar_or_tuple_5_t = Union[T, tuple[T, T, T, T, T]]
_scalar_or_tuple_6_t = Union[T, tuple[T, T, T, T, T, T]]
# For arguments which represent size parameters (eg, kernel size, padding)
_size_any_t = _scalar_or_tuple_any_t[int]

View File

@ -3,7 +3,7 @@
import importlib
import math
import warnings
from typing import Callable, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Callable, List, Optional, TYPE_CHECKING, Union
import torch
from torch import _VF, sym_int as _sym_int, Tensor
@ -440,7 +440,7 @@ def fractional_max_pool2d_with_indices(
output_ratio: Optional[BroadcastingList2[float]] = None,
return_indices: bool = False,
_random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]: # noqa: D400
) -> tuple[Tensor, Tensor]: # noqa: D400
r"""
fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)
@ -552,7 +552,7 @@ def fractional_max_pool3d_with_indices(
output_ratio: Optional[BroadcastingList3[float]] = None,
return_indices: bool = False,
_random_samples: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]: # noqa: D400
) -> tuple[Tensor, Tensor]: # noqa: D400
r"""
fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)
@ -669,7 +669,7 @@ def max_pool1d_with_indices(
dilation: BroadcastingList1[int] = 1,
ceil_mode: bool = False,
return_indices: bool = False,
) -> Tuple[Tensor, Tensor]: # noqa: D400
) -> tuple[Tensor, Tensor]: # noqa: D400
r"""
max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)
@ -759,7 +759,7 @@ def max_pool2d_with_indices(
dilation: BroadcastingList2[int] = 1,
ceil_mode: bool = False,
return_indices: bool = False,
) -> Tuple[Tensor, Tensor]: # noqa: D400
) -> tuple[Tensor, Tensor]: # noqa: D400
r"""
max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)
@ -849,7 +849,7 @@ def max_pool3d_with_indices(
dilation: BroadcastingList3[int] = 1,
ceil_mode: bool = False,
return_indices: bool = False,
) -> Tuple[Tensor, Tensor]: # noqa: D400
) -> tuple[Tensor, Tensor]: # noqa: D400
r"""
max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)
@ -933,11 +933,11 @@ max_pool3d = boolean_dispatch(
def _unpool_output_size(
input: Tensor,
kernel_size: List[int],
stride: List[int],
padding: List[int],
output_size: Optional[List[int]],
) -> List[int]:
kernel_size: list[int],
stride: list[int],
padding: list[int],
output_size: Optional[list[int]],
) -> list[int]:
input_size = input.size()
default_size = torch.jit.annotate(List[int], [])
for d in range(len(kernel_size)):
@ -1187,7 +1187,7 @@ def adaptive_max_pool1d_with_indices(
input: Tensor,
output_size: BroadcastingList1[int],
return_indices: bool = False,
) -> Tuple[Tensor, Tensor]: # noqa: D400
) -> tuple[Tensor, Tensor]: # noqa: D400
r"""
adaptive_max_pool1d(input, output_size, return_indices=False)
@ -1242,7 +1242,7 @@ def adaptive_max_pool2d_with_indices(
input: Tensor,
output_size: BroadcastingList2[int],
return_indices: bool = False,
) -> Tuple[Tensor, Tensor]: # noqa: D400
) -> tuple[Tensor, Tensor]: # noqa: D400
r"""adaptive_max_pool2d(input, output_size, return_indices=False)
Applies a 2D adaptive max pooling over an input signal composed of
@ -1298,7 +1298,7 @@ def adaptive_max_pool3d_with_indices(
input: Tensor,
output_size: BroadcastingList3[int],
return_indices: bool = False,
) -> Tuple[Tensor, Tensor]: # noqa: D400
) -> tuple[Tensor, Tensor]: # noqa: D400
r"""
adaptive_max_pool3d(input, output_size, return_indices=False)
@ -2430,7 +2430,7 @@ def _no_grad_embedding_renorm_(
input: Tensor,
max_norm: float,
norm_type: float,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)
@ -2769,7 +2769,7 @@ if embedding_bag.__doc__:
embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)
def _verify_batch_size(size: List[int]) -> None:
def _verify_batch_size(size: list[int]) -> None:
# XXX: JIT script does not support the reduce from functools, and mul op is a
# builtin, which cannot be used as a value to a func yet, so rewrite this size
# check to a simple equivalent for loop
@ -2832,7 +2832,7 @@ def batch_norm(
)
def _verify_spatial_size(size: List[int]) -> None:
def _verify_spatial_size(size: list[int]) -> None:
# Verify that there is > 1 spatial element for instance norm calculation.
size_prods = 1
for i in range(2, len(size)):
@ -2888,7 +2888,7 @@ def instance_norm(
def layer_norm(
input: Tensor,
normalized_shape: List[int],
normalized_shape: list[int],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
@ -2914,7 +2914,7 @@ def layer_norm(
def rms_norm(
input: Tensor,
normalized_shape: List[int],
normalized_shape: list[int],
weight: Optional[Tensor] = None,
eps: Optional[float] = None,
) -> Tensor:
@ -4301,7 +4301,7 @@ def upsample( # noqa: F811
@_overload
def upsample( # noqa: F811
input: Tensor,
size: Optional[List[int]] = None,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
@ -4402,7 +4402,7 @@ def _is_integer(x) -> bool:
def interpolate( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[List[float]] = None,
scale_factor: Optional[list[float]] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
@ -4414,8 +4414,8 @@ def interpolate( # noqa: F811
@_overload
def interpolate( # noqa: F811
input: Tensor,
size: Optional[List[int]] = None,
scale_factor: Optional[List[float]] = None,
size: Optional[list[int]] = None,
scale_factor: Optional[list[float]] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
@ -4440,7 +4440,7 @@ def interpolate( # noqa: F811
@_overload
def interpolate( # noqa: F811
input: Tensor,
size: Optional[List[int]] = None,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
@ -4453,7 +4453,7 @@ def interpolate( # noqa: F811
def interpolate( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[List[float]] = None,
scale_factor: Optional[list[float]] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
@ -4744,7 +4744,7 @@ def upsample_nearest( # noqa: F811
@_overload
def upsample_nearest( # noqa: F811
input: Tensor,
size: Optional[List[int]] = None,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
) -> Tensor:
pass
@ -4794,7 +4794,7 @@ def upsample_bilinear( # noqa: F811
@_overload
def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[List[int]] = None,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
) -> Tensor:
pass
@ -4804,7 +4804,7 @@ def upsample_bilinear( # noqa: F811
def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[List[float]] = None,
scale_factor: Optional[list[float]] = None,
) -> Tensor:
pass
@ -4812,8 +4812,8 @@ def upsample_bilinear( # noqa: F811
@_overload
def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[List[int]] = None,
scale_factor: Optional[List[float]] = None,
size: Optional[list[int]] = None,
scale_factor: Optional[list[float]] = None,
) -> Tensor:
pass
@ -5025,7 +5025,7 @@ def grid_sample(
def affine_grid(
theta: Tensor,
size: List[int],
size: list[int],
align_corners: Optional[bool] = None,
) -> Tensor:
r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`.
@ -5127,7 +5127,7 @@ def affine_grid(
def pad(
input: Tensor,
pad: List[int],
pad: list[int],
mode: str = "constant",
value: Optional[float] = None,
) -> Tensor:
@ -5490,7 +5490,7 @@ def normalize(
return torch.div(input, denom, out=out)
def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None:
def assert_int_or_pair(arg: list[int], arg_name: str, message: str) -> None:
assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
@ -5579,7 +5579,7 @@ def _in_projection_packed(
v: Tensor,
w: Tensor,
b: Optional[Tensor] = None,
) -> List[Tensor]:
) -> list[Tensor]:
r"""Perform the in-projection step of the attention operation, using packed weights.
Output is a triple containing projection tensors for query, key and value.
@ -5658,7 +5658,7 @@ def _in_projection(
b_q: Optional[Tensor] = None,
b_k: Optional[Tensor] = None,
b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
r"""Perform the in-projection step of the attention operation.
This is simply a triple of linear projections,
@ -6019,7 +6019,7 @@ def multi_head_attention_forward(
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
) -> tuple[Tensor, Optional[Tensor]]:
r"""Forward method for MultiHeadAttention.
.. note::

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import warnings
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn.functional as F
@ -1146,7 +1146,7 @@ class MultiheadAttention(Module):
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
) -> tuple[Tensor, Optional[Tensor]]:
r"""Compute attention outputs using query, key, and value embeddings.
Supports optional parameters for padding, masks and attention weights.
@ -1401,7 +1401,7 @@ class MultiheadAttention(Module):
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
query: Tensor,
) -> Tuple[Optional[Tensor], Optional[int]]:
) -> tuple[Optional[Tensor], Optional[int]]:
r"""Determine mask type and combine masks if necessary.
If only one mask is provided, that mask

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from collections import namedtuple
from typing import List, Sequence
from collections.abc import Sequence
import torch
import torch.nn.functional as F
@ -107,7 +107,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
in_features: int
n_classes: int
cutoffs: List[int]
cutoffs: list[int]
div_value: float
head_bias: bool
head: Linear

View File

@ -2,19 +2,9 @@
# mypy: allow-untyped-defs
import operator
from collections import abc as container_abcs, OrderedDict
from collections.abc import Iterable, Iterator, Mapping
from itertools import chain, islice
from typing import (
Any,
Dict,
Iterable,
Iterator,
Mapping,
Optional,
overload,
Tuple,
TypeVar,
Union,
)
from typing import Any, Optional, overload, TypeVar, Union
from typing_extensions import deprecated, Self
import torch
@ -107,7 +97,7 @@ class Sequential(Module):
]))
"""
_modules: Dict[str, Module] # type: ignore[assignment]
_modules: dict[str, Module] # type: ignore[assignment]
@overload
def __init__(self, *args: Module) -> None:
@ -302,7 +292,7 @@ class ModuleList(Module):
return x
"""
_modules: Dict[str, Module] # type: ignore[assignment]
_modules: dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
super().__init__()
@ -490,7 +480,7 @@ class ModuleDict(Module):
return x
"""
_modules: Dict[str, Module] # type: ignore[assignment]
_modules: dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
super().__init__()
@ -539,7 +529,7 @@ class ModuleDict(Module):
return self._modules.keys()
@_copy_to_script_wrapper
def items(self) -> Iterable[Tuple[str, Module]]:
def items(self) -> Iterable[tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs."""
return self._modules.items()
@ -771,7 +761,7 @@ class ParameterDict(Module):
def __init__(self, parameters: Any = None) -> None:
super().__init__()
self._keys: Dict[str, None] = {}
self._keys: dict[str, None] = {}
if parameters is not None:
self.update(parameters)
@ -855,7 +845,7 @@ class ParameterDict(Module):
del self[key]
return v
def popitem(self) -> Tuple[str, Any]:
def popitem(self) -> tuple[str, Any]:
"""Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
k, _ = self._keys.popitem()
# We need the key in the _keys to be able to access/del
@ -888,7 +878,7 @@ class ParameterDict(Module):
r"""Return an iterable of the ParameterDict keys."""
return self._keys.keys()
def items(self) -> Iterable[Tuple[str, Any]]:
def items(self) -> Iterable[tuple[str, Any]]:
r"""Return an iterable of the ParameterDict key/value pairs."""
return ((k, self[k]) for k in self._keys)

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import math
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union
from typing_extensions import deprecated
import torch
@ -70,14 +70,14 @@ class _ConvNd(Module):
...
in_channels: int
_reversed_padding_repeated_twice: List[int]
_reversed_padding_repeated_twice: list[int]
out_channels: int
kernel_size: Tuple[int, ...]
stride: Tuple[int, ...]
padding: Union[str, Tuple[int, ...]]
dilation: Tuple[int, ...]
kernel_size: tuple[int, ...]
stride: tuple[int, ...]
padding: Union[str, tuple[int, ...]]
dilation: tuple[int, ...]
transposed: bool
output_padding: Tuple[int, ...]
output_padding: tuple[int, ...]
groups: int
padding_mode: str
weight: Tensor
@ -87,12 +87,12 @@ class _ConvNd(Module):
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Union[str, Tuple[int, ...]],
dilation: Tuple[int, ...],
kernel_size: tuple[int, ...],
stride: tuple[int, ...],
padding: Union[str, tuple[int, ...]],
dilation: tuple[int, ...],
transposed: bool,
output_padding: Tuple[int, ...],
output_padding: tuple[int, ...],
groups: int,
bias: bool,
padding_mode: str,
@ -768,13 +768,13 @@ class _ConvTransposeNd(_ConvNd):
def _output_padding(
self,
input: Tensor,
output_size: Optional[List[int]],
stride: List[int],
padding: List[int],
kernel_size: List[int],
output_size: Optional[list[int]],
stride: list[int],
padding: list[int],
kernel_size: list[int],
num_spatial_dims: int,
dilation: Optional[List[int]] = None,
) -> List[int]:
dilation: Optional[list[int]] = None,
) -> list[int]:
if output_size is None:
ret = _single(self.output_padding) # converting to list if was not already
else:
@ -952,7 +952,7 @@ class ConvTranspose1d(_ConvTransposeNd):
**factory_kwargs,
)
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
if self.padding_mode != "zeros":
raise ValueError(
"Only `zeros` padding mode is supported for ConvTranspose1d"
@ -1139,7 +1139,7 @@ class ConvTranspose2d(_ConvTransposeNd):
**factory_kwargs,
)
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
if self.padding_mode != "zeros":
raise ValueError(
"Only `zeros` padding mode is supported for ConvTranspose2d"
@ -1324,7 +1324,7 @@ class ConvTranspose3d(_ConvTransposeNd):
**factory_kwargs,
)
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
if self.padding_mode != "zeros":
raise ValueError(
"Only `zeros` padding mode is supported for ConvTranspose3d"
@ -1391,7 +1391,7 @@ class _LazyConvXdMixin(LazyModuleMixin):
transposed: bool
in_channels: int
out_channels: int
kernel_size: Tuple[int, ...]
kernel_size: tuple[int, ...]
weight: UninitializedParameter
bias: UninitializedParameter

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Tuple, Union
from typing import Union
from torch import Tensor
from torch.types import _size
@ -103,7 +103,7 @@ class Unflatten(Module):
torch.Size([2, 2, 5, 5])
"""
NamedShape = Tuple[Tuple[str, int]]
NamedShape = tuple[tuple[str, int]]
__constants__ = ["dim", "unflattened_size"]
dim: Union[int, str]

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import itertools
from typing import Any, Optional, Protocol, Type
from typing import Any, Optional, Protocol
import torch
from torch.nn.parameter import is_lazy
@ -180,7 +180,7 @@ class LazyModuleMixin:
# modules inheriting from this will change their __class__ to the specified
# one after they are fully initialized
cls_to_become: Optional[Type[Any]] = None
cls_to_become: Optional[type[Any]] = None
def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesnt like this super call in a mixin

View File

@ -6,20 +6,8 @@ import itertools
import warnings
import weakref
from collections import namedtuple, OrderedDict
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
overload,
Set,
Tuple,
TypeVar,
Union,
)
from collections.abc import Iterator, Mapping
from typing import Any, Callable, Optional, overload, TypeVar, Union
from typing_extensions import Self
import torch
@ -42,7 +30,7 @@ __all__ = [
"Module",
]
_grad_t = Union[Tuple[Tensor, ...], Tensor]
_grad_t = Union[tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.
@ -74,9 +62,9 @@ def _addindent(s_, numSpaces):
r"""This tracks hooks common to all modules that are executed immediately before
.registering the buffer/module/parameter"""
_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_module_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_buffer_registration_hooks: dict[int, Callable] = OrderedDict()
_global_module_registration_hooks: dict[int, Callable] = OrderedDict()
_global_parameter_registration_hooks: dict[int, Callable] = OrderedDict()
class _WrappedHook:
@ -98,14 +86,14 @@ class _WrappedHook:
return self.hook(module, *args, **kwargs)
return self.hook(*args, **kwargs)
def __getstate__(self) -> Dict:
def __getstate__(self) -> dict:
result = {"hook": self.hook, "with_module": self.with_module}
if self.with_module:
result["module"] = self.module()
return result
def __setstate__(self, state: Dict):
def __setstate__(self, state: dict):
self.hook = state["hook"]
self.with_module = state["with_module"]
@ -120,13 +108,13 @@ class _WrappedHook:
r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes"""
_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_backward_hooks: Dict[int, Callable] = OrderedDict()
_global_backward_pre_hooks: dict[int, Callable] = OrderedDict()
_global_backward_hooks: dict[int, Callable] = OrderedDict()
_global_is_full_backward_hook: Optional[bool] = None
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict()
_global_forward_hooks_with_kwargs: Dict[int, bool] = OrderedDict()
_global_forward_pre_hooks: dict[int, Callable] = OrderedDict()
_global_forward_hooks: dict[int, Callable] = OrderedDict()
_global_forward_hooks_always_called: dict[int, bool] = OrderedDict()
_global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict()
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
@ -447,29 +435,29 @@ class Module:
the change."""
training: bool
_parameters: Dict[str, Optional[Parameter]]
_buffers: Dict[str, Optional[Tensor]]
_non_persistent_buffers_set: Set[str]
_backward_pre_hooks: Dict[int, Callable]
_backward_hooks: Dict[int, Callable]
_parameters: dict[str, Optional[Parameter]]
_buffers: dict[str, Optional[Tensor]]
_non_persistent_buffers_set: set[str]
_backward_pre_hooks: dict[int, Callable]
_backward_hooks: dict[int, Callable]
_is_full_backward_hook: Optional[bool]
_forward_hooks: Dict[int, Callable]
_forward_hooks: dict[int, Callable]
# Marks whether the corresponding _forward_hooks accept kwargs or not.
# As JIT does not support Set[int], this dict is used as a set, where all
# hooks represented in this dict accept kwargs.
_forward_hooks_with_kwargs: Dict[int, bool]
_forward_hooks_with_kwargs: dict[int, bool]
# forward hooks that should always be called even if an exception is raised
_forward_hooks_always_called: Dict[int, bool]
_forward_pre_hooks: Dict[int, Callable]
_forward_hooks_always_called: dict[int, bool]
_forward_pre_hooks: dict[int, Callable]
# Marks whether the corresponding _forward_hooks accept kwargs or not.
# As JIT does not support Set[int], this dict is used as a set, where all
# hooks represented in this dict accept kwargs.
_forward_pre_hooks_with_kwargs: Dict[int, bool]
_state_dict_hooks: Dict[int, Callable]
_load_state_dict_pre_hooks: Dict[int, Callable]
_state_dict_pre_hooks: Dict[int, Callable]
_load_state_dict_post_hooks: Dict[int, Callable]
_modules: Dict[str, Optional["Module"]]
_forward_pre_hooks_with_kwargs: dict[int, bool]
_state_dict_hooks: dict[int, Callable]
_load_state_dict_pre_hooks: dict[int, Callable]
_state_dict_pre_hooks: dict[int, Callable]
_load_state_dict_post_hooks: dict[int, Callable]
_modules: dict[str, Optional["Module"]]
call_super_init: bool = False
_compiled_call_impl: Optional[Callable] = None
@ -712,7 +700,7 @@ class Module:
if target == "":
return self
atoms: List[str] = target.split(".")
atoms: list[str] = target.split(".")
mod: torch.nn.Module = self
for item in atoms:
@ -769,7 +757,7 @@ class Module:
if target == "":
raise ValueError("Cannot set the submodule without a target name!")
atoms: List[str] = target.split(".")
atoms: list[str] = target.split(".")
name = atoms.pop(-1)
mod: torch.nn.Module = self
@ -1485,13 +1473,13 @@ class Module:
It returns two lists, one with the full backward hooks and one with the non-full
backward hooks.
"""
full_backward_hooks: List[Callable] = []
full_backward_hooks: list[Callable] = []
if _global_is_full_backward_hook is True:
full_backward_hooks += _global_backward_hooks.values()
if self._is_full_backward_hook is True:
full_backward_hooks += self._backward_hooks.values()
non_full_backward_hooks: List[Callable] = []
non_full_backward_hooks: list[Callable] = []
if _global_is_full_backward_hook is False:
non_full_backward_hooks += _global_backward_hooks.values()
if self._is_full_backward_hook is False:
@ -1500,7 +1488,7 @@ class Module:
return full_backward_hooks, non_full_backward_hooks
def _get_backward_pre_hooks(self):
backward_pre_hooks: List[Callable] = []
backward_pre_hooks: list[Callable] = []
backward_pre_hooks += _global_backward_pre_hooks.values()
backward_pre_hooks += self._backward_pre_hooks.values()
@ -1580,10 +1568,10 @@ class Module:
def register_forward_pre_hook(
self,
hook: Union[
Callable[[T, Tuple[Any, ...]], Optional[Any]],
Callable[[T, tuple[Any, ...]], Optional[Any]],
Callable[
[T, Tuple[Any, ...], Dict[str, Any]],
Optional[Tuple[Any, Dict[str, Any]]],
[T, tuple[Any, ...], dict[str, Any]],
Optional[tuple[Any, dict[str, Any]]],
],
],
*,
@ -1646,8 +1634,8 @@ class Module:
def register_forward_hook(
self,
hook: Union[
Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
Callable[[T, tuple[Any, ...], Any], Optional[Any]],
Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]],
],
*,
prepend: bool = False,
@ -2125,7 +2113,7 @@ class Module:
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
T_destination = TypeVar("T_destination", bound=dict[str, Any])
@overload
def state_dict(
@ -2134,7 +2122,7 @@ class Module:
...
@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]:
...
# TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
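A short usage sketch of the two call patterns described in the comment above the `T_destination` TypeVar (the model here is an arbitrary example, not from the diff):

```python
import collections
import torch

model = torch.nn.Linear(2, 2)

sd = model.state_dict()                    # no destination: a fresh OrderedDict is returned
dest = collections.OrderedDict()
same = model.state_dict(destination=dest)  # user-supplied mapping is filled and returned
assert isinstance(sd, collections.OrderedDict) and same is dest
```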
@ -2514,9 +2502,9 @@ class Module:
f"Expected state_dict to be dict-like, got {type(state_dict)}."
)
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
missing_keys: list[str] = []
unexpected_keys: list[str] = []
error_msgs: list[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
@ -2632,7 +2620,7 @@ class Module:
def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Parameter]]:
) -> Iterator[tuple[str, Parameter]]:
r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
Args:
@ -2687,7 +2675,7 @@ class Module:
def named_buffers(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, Tensor]]:
) -> Iterator[tuple[str, Tensor]]:
r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
Args:
@ -2725,7 +2713,7 @@ class Module:
for _name, module in self.named_children():
yield module
def named_children(self) -> Iterator[Tuple[str, "Module"]]:
def named_children(self) -> Iterator[tuple[str, "Module"]]:
r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
Yields:
@ -2774,7 +2762,7 @@ class Module:
def named_modules(
self,
memo: Optional[Set["Module"]] = None,
memo: Optional[set["Module"]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
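Stepping back, the pattern repeated across every file below is the Python 3.9 typing modernization: `typing.List`/`Dict`/`Tuple`/`Set` become the PEP 585 builtin generics, and ABCs such as `Sequence` move to `collections.abc` (these appear to correspond to Ruff's pyupgrade-style rules, e.g. UP006/UP035). A hedged before/after sketch with illustrative names:

```python
from collections.abc import Sequence
from typing import Optional

# Before (pre-3.9):  def pad(sizes: List[int]) -> Tuple[int, ...]: ...
# After (3.9+), as produced by the automated rewrite:
def pad(sizes: list[int], fill: Optional[float] = None) -> tuple[int, ...]:
    return tuple(int(s) for s in sizes)

def total(values: Sequence[int]) -> int:   # Sequence imported from collections.abc
    return sum(values)
```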

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import numbers
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
import torch
from torch import Size, Tensor
@ -88,7 +88,7 @@ class CrossMapLRN2d(Module):
return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__)
_shape_t = Union[int, List[int], Size]
_shape_t = Union[int, list[int], Size]
class LayerNorm(Module):
@ -170,7 +170,7 @@ class LayerNorm(Module):
"""
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: Tuple[int, ...]
normalized_shape: tuple[int, ...]
eps: float
elementwise_affine: bool
@ -359,7 +359,7 @@ class RMSNorm(Module):
"""
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: Tuple[int, ...]
normalized_shape: tuple[int, ...]
eps: Optional[float]
elementwise_affine: bool
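A small usage sketch of the `_shape_t` alias above: `normalized_shape` may be given as an int, a `list[int]`, or a `torch.Size`, and is stored internally as a plain tuple:

```python
import torch

x = torch.randn(2, 3, 4)
print(torch.nn.LayerNorm(4).normalized_shape)            # (4,)
print(torch.nn.LayerNorm([3, 4]).normalized_shape)       # (3, 4)
print(torch.nn.LayerNorm(x.shape[1:]).normalized_shape)  # (3, 4)
```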

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Sequence, Tuple
from collections.abc import Sequence
import torch.nn.functional as F
from torch import Tensor
@ -83,7 +83,7 @@ class CircularPad1d(_CircularPadNd):
[5., 6., 7., 4., 5., 6., 7., 4.]]])
"""
padding: Tuple[int, int]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
super().__init__()
@ -142,7 +142,7 @@ class CircularPad2d(_CircularPadNd):
[8., 6., 7., 8., 6.]]]])
"""
padding: Tuple[int, int, int, int]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
super().__init__()
@ -191,7 +191,7 @@ class CircularPad3d(_CircularPadNd):
>>> output = m(input)
"""
padding: Tuple[int, int, int, int, int, int]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:
super().__init__()
@ -262,7 +262,7 @@ class ConstantPad1d(_ConstantPadNd):
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
"""
padding: Tuple[int, int]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t, value: float):
super().__init__(value)
@ -313,7 +313,7 @@ class ConstantPad2d(_ConstantPadNd):
"""
__constants__ = ["padding", "value"]
padding: Tuple[int, int, int, int]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t, value: float) -> None:
super().__init__(value)
@ -353,7 +353,7 @@ class ConstantPad3d(_ConstantPadNd):
>>> output = m(input)
"""
padding: Tuple[int, int, int, int, int, int]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t, value: float) -> None:
super().__init__(value)
@ -405,7 +405,7 @@ class ReflectionPad1d(_ReflectionPadNd):
[7., 6., 5., 4., 5., 6., 7., 6.]]])
"""
padding: Tuple[int, int]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
super().__init__()
@ -458,7 +458,7 @@ class ReflectionPad2d(_ReflectionPadNd):
[7., 6., 7., 8., 7.]]]])
"""
padding: Tuple[int, int, int, int]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
super().__init__()
@ -512,7 +512,7 @@ class ReflectionPad3d(_ReflectionPadNd):
[1., 0., 1., 0.]]]]])
"""
padding: Tuple[int, int, int, int, int, int]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:
super().__init__()
@ -564,7 +564,7 @@ class ReplicationPad1d(_ReplicationPadNd):
[4., 4., 4., 4., 5., 6., 7., 7.]]])
"""
padding: Tuple[int, int]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
super().__init__()
@ -616,7 +616,7 @@ class ReplicationPad2d(_ReplicationPadNd):
[6., 6., 7., 8., 8.]]]])
"""
padding: Tuple[int, int, int, int]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
super().__init__()
@ -657,7 +657,7 @@ class ReplicationPad3d(_ReplicationPadNd):
>>> output = m(input)
"""
padding: Tuple[int, int, int, int, int, int]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:
super().__init__()
@ -708,7 +708,7 @@ class ZeroPad1d(ConstantPad1d):
[ 0.0000, 0.0000, 0.0000, -3.6372, 0.1182, -1.8652, 0.0000]]])
"""
padding: Tuple[int, int]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
super().__init__(padding, 0.0)
@ -762,7 +762,7 @@ class ZeroPad2d(ConstantPad2d):
[ 0.0000, -0.9162, -0.5436, -0.6446, 0.0000]]]])
"""
padding: Tuple[int, int, int, int]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
super().__init__(padding, 0.0)
@ -804,7 +804,7 @@ class ZeroPad3d(ConstantPad3d):
>>> output = m(input)
"""
padding: Tuple[int, int, int, int, int, int]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:
super().__init__(padding, 0.0)

View File

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional
import torch.nn.functional as F
from torch import Tensor
@ -384,7 +384,7 @@ class MaxUnpool1d(_MaxUnpoolNd):
self.padding = _single(padding)
def forward(
self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None
self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
) -> Tensor:
return F.max_unpool1d(
input, indices, self.kernel_size, self.stride, self.padding, output_size
@ -479,7 +479,7 @@ class MaxUnpool2d(_MaxUnpoolNd):
self.padding = _pair(padding)
def forward(
self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None
self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
) -> Tensor:
return F.max_unpool2d(
input, indices, self.kernel_size, self.stride, self.padding, output_size
@ -557,7 +557,7 @@ class MaxUnpool3d(_MaxUnpoolNd):
self.padding = _triple(padding)
def forward(
self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None
self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
) -> Tensor:
return F.max_unpool3d(
input, indices, self.kernel_size, self.stride, self.padding, output_size
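A usage sketch of the `output_size` argument typed above, which lets `MaxUnpool1d` restore an odd input length that pooling alone could not recover:

```python
import torch

pool = torch.nn.MaxPool1d(2, stride=2, return_indices=True)
unpool = torch.nn.MaxUnpool1d(2, stride=2)

x = torch.arange(1.0, 10.0).reshape(1, 1, 9)             # length 9
pooled, indices = pool(x)                                 # length 4
restored = unpool(pooled, indices, output_size=x.size())  # back to length 9
print(restored.shape)                                     # torch.Size([1, 1, 9])
```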

View File

@ -4,7 +4,7 @@ import math
import numbers
import warnings
import weakref
from typing import List, Optional, overload, Tuple
from typing import Optional, overload
from typing_extensions import deprecated
import torch
@ -106,7 +106,7 @@ class RNNBase(Module):
self.dropout = float(dropout)
self.bidirectional = bidirectional
self.proj_size = proj_size
self._flat_weight_refs: List[Optional[weakref.ReferenceType[Parameter]]] = []
self._flat_weight_refs: list[Optional[weakref.ReferenceType[Parameter]]] = []
num_directions = 2 if bidirectional else 1
if (
@ -172,7 +172,7 @@ class RNNBase(Module):
# Second bias vector included for CuDNN compatibility. Only one
# bias vector is needed in standard definition.
b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
layer_params: Tuple[Tensor, ...] = ()
layer_params: tuple[Tensor, ...] = ()
if self.proj_size == 0:
if bias:
layer_params = (w_ih, w_hh, b_ih, b_hh)
@ -317,7 +317,7 @@ class RNNBase(Module):
def get_expected_hidden_size(
self, input: Tensor, batch_sizes: Optional[Tensor]
) -> Tuple[int, int, int]:
) -> tuple[int, int, int]:
if batch_sizes is not None:
mini_batch = int(batch_sizes[0])
else:
@ -340,7 +340,7 @@ class RNNBase(Module):
def check_hidden_size(
self,
hx: Tensor,
expected_hidden_size: Tuple[int, int, int],
expected_hidden_size: tuple[int, int, int],
msg: str = "Expected hidden size {}, got {}",
) -> None:
if hx.size() != expected_hidden_size:
@ -451,7 +451,7 @@ class RNNBase(Module):
]
@property
def all_weights(self) -> List[List[Parameter]]:
def all_weights(self) -> list[list[Parameter]]:
return [
[getattr(self, weight) for weight in weights]
for weights in self._all_weights
@ -639,14 +639,14 @@ class RNN(RNNBase):
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: Tensor, hx: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
pass
@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[Tensor] = None
) -> Tuple[PackedSequence, Tensor]:
) -> tuple[PackedSequence, Tensor]:
pass
def forward(self, input, hx=None): # noqa: F811
@ -978,7 +978,7 @@ class LSTM(RNNBase):
def get_expected_cell_size(
self, input: Tensor, batch_sizes: Optional[Tensor]
) -> Tuple[int, int, int]:
) -> tuple[int, int, int]:
if batch_sizes is not None:
mini_batch = int(batch_sizes[0])
else:
@ -996,7 +996,7 @@ class LSTM(RNNBase):
def check_forward_args(
self,
input: Tensor,
hidden: Tuple[Tensor, Tensor], # type: ignore[override]
hidden: tuple[Tensor, Tensor], # type: ignore[override]
batch_sizes: Optional[Tensor],
):
self.check_input(input, batch_sizes)
@ -1014,9 +1014,9 @@ class LSTM(RNNBase):
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
def permute_hidden( # type: ignore[override]
self,
hx: Tuple[Tensor, Tensor],
hx: tuple[Tensor, Tensor],
permutation: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
if permutation is None:
return hx
return _apply_permutation(hx[0], permutation), _apply_permutation(
@ -1027,16 +1027,16 @@ class LSTM(RNNBase):
@overload # type: ignore[override]
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # noqa: F811
self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811
pass
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]: # noqa: F811
self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None
) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811
pass
def forward(self, input, hx=None): # noqa: F811
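A usage sketch matching the tensor overload above: `LSTM.forward` returns the output sequence together with an `(h_n, c_n)` tuple of final hidden and cell states:

```python
import torch

lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(2, 5, 4)                   # (batch, seq, feature)
output, (h_n, c_n) = lstm(x)
print(output.shape, h_n.shape, c_n.shape)  # torch.Size([2, 5, 8]) torch.Size([1, 2, 8]) torch.Size([1, 2, 8])
```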
@ -1319,14 +1319,14 @@ class GRU(RNNBase):
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: Tensor, hx: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]: # noqa: F811
) -> tuple[Tensor, Tensor]: # noqa: F811
pass
@overload
@torch._jit_internal._overload_method # noqa: F811
def forward(
self, input: PackedSequence, hx: Optional[Tensor] = None
) -> Tuple[PackedSequence, Tensor]: # noqa: F811
) -> tuple[PackedSequence, Tensor]: # noqa: F811
pass
def forward(self, input, hx=None): # noqa: F811
@ -1679,8 +1679,8 @@ class LSTMCell(RNNCellBase):
super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
def forward(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tensor]:
self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
) -> tuple[Tensor, Tensor]:
if input.dim() not in (1, 2):
raise ValueError(
f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import collections
from itertools import repeat
from typing import Any, Dict, List
from typing import Any
__all__ = ["consume_prefix_in_state_dict_if_present"]
@ -32,7 +32,7 @@ def _reverse_repeat_tuple(t, n):
return tuple(x for x in reversed(t) for _ in range(n))
def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
import torch
if isinstance(out_size, (int, torch.SymInt)):
@ -45,7 +45,7 @@ def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
def consume_prefix_in_state_dict_if_present(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
) -> None:
r"""Strip the prefix in state_dict in place, if any.

View File

@ -1,5 +1,5 @@
import warnings
from typing import List, Optional
from typing import Optional
import torch
from torch._utils import _get_device_index
@ -116,7 +116,7 @@ class Scatter(Function):
# background streams used for copying
_streams: Optional[List[Optional[torch.Stream]]] = None
_streams: Optional[list[Optional[torch.Stream]]] = None
def _get_stream(device: torch.device):

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import warnings
from typing import List
import torch
from torch._utils import (
@ -137,7 +136,7 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
"""
# TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
# return `inputs`.
dense_tensors: List[List] = [[] for _ in inputs] # shape (num_gpus, num_tensors)
dense_tensors: list[list] = [[] for _ in inputs] # shape (num_gpus, num_tensors)
output = []
ref_order = []
# process sparse ones first since they may have different sizes on different gpus

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import operator
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Any, Generic, Optional, TypeVar, Union
import torch
from torch._utils import (
@ -195,20 +196,20 @@ class DataParallel(Module, Generic[T]):
def replicate(
self, module: T, device_ids: Sequence[Union[int, torch.device]]
) -> List[T]:
) -> list[T]:
return replicate(module, device_ids, not torch.is_grad_enabled())
def scatter(
self,
inputs: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]],
inputs: tuple[Any, ...],
kwargs: Optional[dict[str, Any]],
device_ids: Sequence[Union[int, torch.device]],
) -> Any:
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
def parallel_apply(
self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
) -> List[Any]:
) -> list[Any]:
return parallel_apply(
replicas, inputs, kwargs, self.device_ids[: len(replicas)]
)

View File

@ -12,7 +12,7 @@ from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
import torch.distributed as dist
@ -815,7 +815,7 @@ class DistributedDataParallel(Module, Joinable):
# Initialize gradient buffers and register all reduce hook
self._delay_grad_buffer: Optional[torch.Tensor] = None
self._delay_grad_views: List[torch.Tensor] = []
self._delay_grad_views: list[torch.Tensor] = []
self._delay_all_reduce_all_params = False
if len(self._delay_all_reduce_params) != 0:
self._register_delay_all_reduce_hook(
@ -853,7 +853,7 @@ class DistributedDataParallel(Module, Joinable):
param_to_name_mapping,
static_graph,
)
self._comm_hooks: List[Tuple[Callable, object]] = []
self._comm_hooks: list[tuple[Callable, object]] = []
if self.mixed_precision is not None:
_setup_mixed_precision_params(self.mixed_precision, self.module)
@ -907,7 +907,7 @@ class DistributedDataParallel(Module, Joinable):
# Register the AccumulateGrad post hooks if optimize_ddp is
# True. The hooks will be deregistered if compiled_autograd is not
# enabled.
self._accum_grad_hooks: List[RemovableHandle] = []
self._accum_grad_hooks: list[RemovableHandle] = []
optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
self._use_python_reducer = optimize_ddp in (
"python_reducer",
@ -1614,7 +1614,7 @@ class DistributedDataParallel(Module, Joinable):
treespec,
output_is_rref,
) = _tree_flatten_with_rref(output)
output_placeholders: List[Optional[torch.Tensor]] = [
output_placeholders: list[Optional[torch.Tensor]] = [
None for _ in range(len(output_tensor_list))
]
# Do not touch tensors that have no grad_fn, which can cause issues
@ -2046,7 +2046,7 @@ class DistributedDataParallel(Module, Joinable):
self.logger._set_comm_hook_name(str(comm_hook_type))
dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
def _register_fused_optim(self, optim: type, *args, optim_params=None, **kwargs):
r"""
Register an optimizer in DDP to optimize parameter immediately after its gradient reduction.

View File

@ -1,5 +1,6 @@
import threading
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, cast, Optional, Union
import torch
from torch._utils import ExceptionWrapper
@ -11,7 +12,7 @@ __all__ = ["get_a_var", "parallel_apply"]
def get_a_var(
obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]],
obj: Union[torch.Tensor, list[Any], tuple[Any, ...], dict[Any, Any]],
) -> Optional[torch.Tensor]:
if isinstance(obj, torch.Tensor):
return obj
@ -30,9 +31,9 @@ def get_a_var(
def parallel_apply(
modules: Sequence[Module],
inputs: Sequence[Any],
kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
kwargs_tup: Optional[Sequence[dict[str, Any]]] = None,
devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
) -> list[Any]:
r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.
Args:
@ -51,7 +52,7 @@ def parallel_apply(
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
kwargs_tup = (cast(dict[str, Any], {}),) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
@ -69,7 +70,7 @@ def parallel_apply(
i: int,
module: Module,
input: Any,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
device: Optional[Union[int, torch.device]] = None,
stream: Optional[torch.cuda.Stream] = None,
) -> None:

View File

@ -1,16 +1,6 @@
from collections import OrderedDict
from typing import (
cast,
Dict,
Iterator,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
TypeVar,
Union,
)
from collections.abc import Iterator, Sequence
from typing import cast, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import TypeIs
import torch
@ -59,7 +49,7 @@ def _is_jit_enabled() -> "EnabledProxy":
#
# currently a module cannot be replicated properly if the descendants of
# any ScriptModule contains python module (type 1 above)
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
def _replicatable_module(module: Module, memo: Optional[set[Module]] = None) -> bool:
# module.modules() contains module itself as the first element
def descendant_modules(module: Module) -> Iterator[Module]:
gen = module.modules()
@ -94,7 +84,7 @@ def _broadcast_coalesced_reshape(
tensors: Sequence[torch.Tensor],
devices: Sequence[Union[int, torch.device]],
detach: bool = False,
) -> List[List[torch.Tensor]]:
) -> list[list[torch.Tensor]]:
from torch.nn.parallel._functions import Broadcast
if detach:
@ -118,7 +108,7 @@ def replicate(
network: T,
devices: Sequence[Union[int, torch.device]],
detach: bool = False,
) -> List[T]:
) -> list[T]:
if not _replicatable_module(network):
raise RuntimeError(
"Cannot replicate network where python modules are "
@ -136,8 +126,8 @@ def replicate(
param_copies = _broadcast_coalesced_reshape(params, devices, detach)
buffers = list(network.buffers())
buffers_rg: List[torch.Tensor] = []
buffers_not_rg: List[torch.Tensor] = []
buffers_rg: list[torch.Tensor] = []
buffers_not_rg: list[torch.Tensor] = []
for buf in buffers:
if buf.requires_grad and not detach:
buffers_rg.append(buf)
@ -153,8 +143,8 @@ def replicate(
)
modules = list(network.modules())
module_copies: List[List[Module]] = [[] for _ in devices]
module_indices: Dict[Module, int] = {}
module_copies: list[list[Module]] = [[] for _ in devices]
module_indices: dict[Module, int] = {}
for i, module in enumerate(modules):
module_indices[module] = i

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar, Union
from collections.abc import Sequence
from typing import Any, Optional, overload, TypeVar, Union
from typing_extensions import deprecated
import torch
@ -34,7 +35,7 @@ def scatter(
inputs: torch.Tensor,
target_gpus: Sequence[Union[int, torch.device]],
dim: int = ...,
) -> Tuple[torch.Tensor, ...]:
) -> tuple[torch.Tensor, ...]:
...
@ -43,7 +44,7 @@ def scatter(
inputs: T,
target_gpus: Sequence[Union[int, torch.device]],
dim: int = ...,
) -> List[T]:
) -> list[T]:
...
@ -79,11 +80,11 @@ def scatter(inputs, target_gpus, dim=0):
def scatter_kwargs(
inputs: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]],
inputs: tuple[Any, ...],
kwargs: Optional[dict[str, Any]],
target_gpus: Sequence[Union[int, torch.device]],
dim: int = 0,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]:
r"""Scatter with support for kwargs dictionary."""
scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []

View File

@ -1,6 +1,6 @@
import importlib
import warnings
from typing import Callable, List
from typing import Callable
_MESSAGE_TEMPLATE = (
@ -9,7 +9,7 @@ _MESSAGE_TEMPLATE = (
def lazy_deprecated_import(
all: List[str],
all: list[str],
old_module: str,
new_module: str,
) -> Callable:

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional
import numpy as np
@ -68,7 +68,7 @@ def int_padding_for_string_padding(func, padding_style, dilation, kernel_size):
return dilation[i] if isinstance(dilation, tuple) else dilation
if padding_style == "same":
padding: List[int] = []
padding: list[int] = []
# F.pad needs the padding in reverse order from what conv expects
for i in range(conv_picker(func, 0, 1, 2), -1, -1):
padding += conv_padding_for_same(get_dilation(i), kernel_size[i])
@ -145,7 +145,7 @@ def conv_backward(func, ctx, grad_output):
kernel_size = [weight_shape[i] for i in range(2, conv_picker(func, 3, 4, 5))]
batch_size = ctx.batch_size
results: List[Optional[torch.Tensor]] = []
results: list[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional
import torch
import torch.nn.functional as F
@ -56,7 +56,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
1, index, grad_output.reshape(batch_size, -1, embedding_dim)
)
results: List[Optional[torch.Tensor]] = []
results: list[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference

View File

@ -1,14 +1,14 @@
# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from typing import Callable, Dict
from typing import Callable
import torch
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_map_only
HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}
HANDLED_FUNCTIONS: dict[Callable, torch.autograd.Function] = {}
aten = torch._ops.ops.aten
# __torch_function__ runs before the pydispatcher so we need to manually use the same

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing import List, Optional
from typing import Optional
import torch
import torch.nn.functional as F
@ -51,7 +51,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
mean, rstd = ctx.mean, ctx.rstd
results: List[Optional[torch.Tensor]] = []
results: list[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from functools import partial
from typing import List, Optional
from typing import Optional
import torch
import torch.nn.functional as F
@ -40,7 +40,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function):
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
results: List[Optional[torch.Tensor]] = []
results: list[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference
if input.requires_grad:

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional
import torch
import torch.nn.functional as F
@ -52,7 +52,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function):
input, normalized_shape = ctx.args
mean, rstd = ctx.mean, ctx.rstd
results: List[Optional[torch.Tensor]] = []
results: list[Optional[torch.Tensor]] = []
results.append(None) # for kwarg names
results.append(None) # for op reference
if input.requires_grad:

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional
import torch
import torch.nn.functional as F
@ -38,7 +38,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight = ctx.args
bias = ctx.kwargs["bias"]
results: List[Optional[torch.Tensor]] = []
results: list[Optional[torch.Tensor]] = []
results.append(None) # for kwarg_names
results.append(None) # for op reference

View File

@ -1,7 +1,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Iterable, List, Tuple
from collections.abc import Iterable
import torch
@ -113,7 +113,7 @@ class NamedMemberAccessor:
def __init__(self, module: "torch.nn.Module") -> None:
self.module = module
self.memo: Dict[str, torch.nn.Module] = {}
self.memo: dict[str, torch.nn.Module] = {}
# Nested attribute access
@ -225,7 +225,7 @@ class NamedMemberAccessor:
# Batched operations
def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
def get_tensors(self, names: Iterable[str]) -> list[torch.Tensor]:
"""
Get the tensors specified by the given paths.
@ -252,7 +252,7 @@ class NamedMemberAccessor:
for name, value in zip(names, values):
self.set_tensor(name, value)
def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
def set_tensors_dict(self, named_tensors: dict[str, torch.Tensor]) -> None:
"""
Set the attributes specified by the given paths to values.
@ -281,7 +281,7 @@ class NamedMemberAccessor:
names: Iterable[str],
values: Iterable[torch.Tensor],
allow_missing: bool = False,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
"""
Swap the attributes specified by the given paths to values.
@ -301,8 +301,8 @@ class NamedMemberAccessor:
]
def swap_tensors_dict(
self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
) -> Tuple[Dict[str, torch.Tensor], List[str]]:
self, named_tensors: dict[str, torch.Tensor], allow_missing: bool = False
) -> tuple[dict[str, torch.Tensor], list[str]]:
"""
Swap the attributes specified by the given paths to values.
@ -332,7 +332,7 @@ class NamedMemberAccessor:
raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
return orig_named_tensors, missing_keys
def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
def check_keys(self, keys: Iterable[str]) -> tuple[list[str], list[str]]:
"""Check that the given keys are valid."""
keys = set(keys)
valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}
@ -345,21 +345,21 @@ class NamedMemberAccessor:
def named_parameters(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
) -> Iterable[tuple[str, torch.Tensor]]:
"""Iterate over all the parameters in the module."""
yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
def named_buffers(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
) -> Iterable[tuple[str, torch.Tensor]]:
"""Iterate over all the buffers in the module."""
yield from self.module.named_buffers(remove_duplicate=remove_duplicate)
def named_tensors(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, torch.Tensor]]:
) -> Iterable[tuple[str, torch.Tensor]]:
"""Iterate over all the tensors in the module."""
yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
yield from self.module.named_buffers(remove_duplicate=remove_duplicate)
@ -367,6 +367,6 @@ class NamedMemberAccessor:
def named_modules(
self,
remove_duplicate: bool = True,
) -> Iterable[Tuple[str, "torch.nn.Module"]]:
) -> Iterable[tuple[str, "torch.nn.Module"]]:
"""Iterate over all the modules in the module."""
yield from self.module.named_modules(remove_duplicate=remove_duplicate)

View File

@ -1,4 +1,5 @@
from typing import Iterable, Optional
from collections.abc import Iterable
from typing import Optional
import torch

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import copy
from typing import Optional, Tuple, TypeVar
from typing import TypeVar
import torch
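Unlike most files in this commit, this one also rewrites `Optional[torch.Tensor]` as `torch.Tensor | None`; that spelling is safe on Python 3.9 only because the `from __future__ import annotations` import above keeps annotations unevaluated at runtime (PEP 563/604). A standalone sketch with illustrative names:

```python
from __future__ import annotations  # without this, `torch.Tensor | None` raises TypeError on 3.9

import torch

def maybe_normalize(t: torch.Tensor | None) -> torch.Tensor | None:
    # the annotation is stored as a string, so `|` is never evaluated at runtime
    return None if t is None else t / t.norm()
```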
@ -55,14 +55,14 @@ def fuse_conv_bn_eval(
def fuse_conv_bn_weights(
conv_w: torch.Tensor,
conv_b: Optional[torch.Tensor],
conv_b: torch.Tensor | None,
bn_rm: torch.Tensor,
bn_rv: torch.Tensor,
bn_eps: float,
bn_w: Optional[torch.Tensor],
bn_b: Optional[torch.Tensor],
bn_w: torch.Tensor | None,
bn_b: torch.Tensor | None,
transpose: bool = False,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.
Args:
@ -155,13 +155,13 @@ def fuse_linear_bn_eval(
def fuse_linear_bn_weights(
linear_w: torch.Tensor,
linear_b: Optional[torch.Tensor],
linear_b: torch.Tensor | None,
bn_rm: torch.Tensor,
bn_rv: torch.Tensor,
bn_eps: float,
bn_w: torch.Tensor,
bn_b: torch.Tensor,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.
Args:

View File

@ -2,9 +2,10 @@
# mypy: allow-untyped-defs
import collections
import copyreg
from collections.abc import Sequence
from contextlib import contextmanager
from copy import deepcopy
from typing import Dict, Optional, Sequence, Tuple, Union
from typing import Optional, Union
import torch
from torch import Tensor
@ -25,7 +26,7 @@ __all__ = [
]
_cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}
_cache: dict[tuple[int, str], Optional[Tensor]] = {}
@contextmanager
@ -165,9 +166,7 @@ class ParametrizationList(ModuleList):
pass
# else, or if it throws, we assume that right_inverse is the identity
if not isinstance(new, Tensor) and not isinstance(
new, collections.abc.Sequence
):
if not isinstance(new, Tensor) and not isinstance(new, Sequence):
raise ValueError(
"'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
f"Got {type(new).__name__}"

View File

@ -3,7 +3,6 @@ r"""Pruning methods."""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Tuple
import torch
@ -270,7 +269,7 @@ class PruningContainer(BasePruningMethod):
"""
def __init__(self, *args):
self._pruning_methods: Tuple[BasePruningMethod, ...] = ()
self._pruning_methods: tuple[BasePruningMethod, ...] = ()
if not isinstance(args, Iterable): # only 1 item
self._tensor_name = args._tensor_name
self.add_pruning_method(args)

View File

@ -1,16 +1,6 @@
import warnings
from collections.abc import Iterable
from typing import (
Any,
Callable,
List,
NamedTuple,
Optional,
overload,
Tuple,
TypeVar,
Union,
)
from typing import Any, Callable, NamedTuple, Optional, overload, TypeVar, Union
from typing_extensions import Self
import torch
@ -228,7 +218,7 @@ def _packed_sequence_init_args(
batch_sizes: Optional[Tensor] = None,
sorted_indices: Optional[Tensor] = None,
unsorted_indices: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
# NB: if unsorted_indices is provided, it should be the inverse permutation
# to sorted_indices. Don't assert it here because the PackedSequence ctor
# should only be used internally.
@ -279,7 +269,7 @@ def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:
def pack_padded_sequence(
input: Tensor,
lengths: Union[Tensor, List[int]],
lengths: Union[Tensor, list[int]],
batch_first: bool = False,
enforce_sorted: bool = True,
) -> PackedSequence:
@ -347,7 +337,7 @@ def pad_packed_sequence(
batch_first: bool = False,
padding_value: float = 0.0,
total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
r"""Pad a packed batch of variable length sequences.
It is an inverse operation to :func:`pack_padded_sequence`.
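A round-trip sketch of the inverse relationship described above (batch-first tensors, lengths already sorted so the defaults apply):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

padded = torch.tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 0.0]])   # second sequence is padded with 0
lengths = torch.tensor([3, 2])

packed = pack_padded_sequence(padded, lengths, batch_first=True)
unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)
assert torch.equal(unpacked, padded) and torch.equal(out_lengths, lengths)
```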
@ -419,7 +409,7 @@ def pad_packed_sequence(
# NOTE: for JIT-compatibility, we need to be more restrictive here and use specific types instead of Iterable.
def pad_sequence(
sequences: Union[Tensor, List[Tensor]],
sequences: Union[Tensor, list[Tensor]],
batch_first: bool = False,
padding_value: float = 0.0,
padding_side: str = "right",
@ -487,7 +477,7 @@ def unpad_sequence(
padded_sequences: Tensor,
lengths: Tensor,
batch_first: bool = False,
) -> List[Tensor]:
) -> list[Tensor]:
r"""Unpad padded Tensor into a list of variable length Tensors.
``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors.
@ -533,7 +523,7 @@ def unpad_sequence(
def pack_sequence(
sequences: List[Tensor],
sequences: list[Tensor],
enforce_sorted: bool = True,
) -> PackedSequence:
r"""Packs a list of variable length Tensors.
@ -571,7 +561,7 @@ def pack_sequence(
)
def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
def unpack_sequence(packed_sequences: PackedSequence) -> list[Tensor]:
r"""Unpack PackedSequence into a list of variable length Tensors.
``packed_sequences`` should be a PackedSequence object.

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Any, Dict, Optional, Set, Tuple, Union
from typing import Any, Optional, Union
from typing_extensions import deprecated
import torch
@ -13,8 +13,8 @@ __all__ = ["functional_call"]
def _untie_named_tensors_map(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
) -> Dict[str, Tensor]:
parameters_and_buffers: dict[str, Tensor],
) -> dict[str, Tensor]:
"""
Unties all tied tensors in the module to parameters_and_buffers.
@ -41,12 +41,12 @@ def _untie_named_tensors_map(
ValueError: if there are more than one user-given values for the same tied tensor.
"""
# A map of {name: tensor} for all tensors (including tied ones) in the module.
all_named_tensors: Dict[str, Tensor] = {}
all_named_tensors: dict[str, Tensor] = {}
all_named_tensors.update(module.named_parameters(remove_duplicate=False))
all_named_tensors.update(module.named_buffers(remove_duplicate=False))
# A map of {tensor: set(all_tied_names)} for all tensor names in the module.
tensor_to_tied_names_map: Dict[Tensor, Set[str]] = {}
tensor_to_tied_names_map: dict[Tensor, set[str]] = {}
for name, tensor in all_named_tensors.items():
if tensor not in tensor_to_tied_names_map:
tensor_to_tied_names_map[tensor] = set()
@ -54,7 +54,7 @@ def _untie_named_tensors_map(
# A map of {tied_name: set(all_tied_names)} for all tensor names in the module.
# If a name is not tied, it will not be in this map.
tied_names_map: Dict[str, Set[str]] = {}
tied_names_map: dict[str, set[str]] = {}
for tied_names in tensor_to_tied_names_map.values():
if len(tied_names) > 1:
for tied_name in tied_names:
@ -98,7 +98,7 @@ def _untie_named_tensors_map(
@contextlib.contextmanager
def _reparametrize_module(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
parameters_and_buffers: dict[str, Tensor],
tie_weights: bool = False,
strict: bool = False,
stack_weights: bool = False,
@ -132,7 +132,7 @@ def _reparametrize_module(
)
)
orig_parameters_and_buffers: Dict[str, Tensor] = {}
orig_parameters_and_buffers: dict[str, Tensor] = {}
try:
orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
untied_parameters_and_buffers, allow_missing=True
@ -167,9 +167,9 @@ def _reparametrize_module(
)
def functional_call(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
args: Optional[Union[Any, Tuple]] = None,
kwargs: Optional[Dict[str, Any]] = None,
parameters_and_buffers: dict[str, Tensor],
args: Optional[Union[Any, tuple]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
@ -245,9 +245,9 @@ def functional_call(
def _functional_call(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
args: Optional[Union[Any, Tuple]] = None,
kwargs: Optional[Dict[str, Any]] = None,
parameters_and_buffers: dict[str, Tensor],
args: Optional[Union[Any, tuple]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,