mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[5/N] Apply Ruff fixes and pyupgrade to Python 3.9 (#144205)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144205
Approved by: https://github.com/albanD
This commit is contained in:
parent db787181b5
commit d87aad6877
@@ -9,7 +9,8 @@ half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdo
 """

 import warnings
-from typing import cast, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import cast, List, Optional, Tuple, Union

 import torch
 from torch import _vmap_internals
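Note the import split above: `Sequence` now comes from `collections.abc`, since the `typing` aliases for container ABCs are deprecated as of Python 3.9 (this is what Ruff's UP035 rule flags, assuming the default pyupgrade rule set is in use). A small sketch of the same idiom:

from collections.abc import Iterable, Sequence

def total(xs: Sequence[float]) -> float:
    # collections.abc ABCs are subscriptable at runtime on Python 3.9+
    acc = 0.0
    for x in xs:
        acc += x
    return acc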
@@ -60,7 +61,7 @@ def _calculate_shape(
     output: Union[torch.Tensor, graph.GradientEdge],
     grad: torch.Tensor,
     is_grads_batched: bool,
-) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
+) -> tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
     # is_same_size ensures that both tensors are either nested or non nested
     # circular import
     from torch.nested._internal.nested_tensor import NestedTensor
@@ -89,8 +90,8 @@ def _make_grads(
     outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
     grads: Sequence[_OptionalTensor],
     is_grads_batched: bool,
-) -> Tuple[_OptionalTensor, ...]:
-    new_grads: List[_OptionalTensor] = []
+) -> tuple[_OptionalTensor, ...]:
+    new_grads: list[_OptionalTensor] = []
     for out, grad in zip(outputs, grads):
         out = cast(Union[torch.Tensor, graph.GradientEdge], out)
         out_size = None
@@ -231,7 +232,7 @@ def _make_grads(

 def _tensor_or_tensors_to_tuple(
     tensors: Optional[_TensorOrTensors], length: int
-) -> Tuple[_OptionalTensor, ...]:
+) -> tuple[_OptionalTensor, ...]:
     if tensors is None:
         return (None,) * length
     if isinstance(tensors, torch.Tensor):
@@ -370,7 +371,7 @@ def grad(
     allow_unused: Optional[bool] = None,
     is_grads_batched: bool = False,
     materialize_grads: bool = False,
-) -> Tuple[torch.Tensor, ...]:
+) -> tuple[torch.Tensor, ...]:
     r"""Compute and return the sum of gradients of outputs with respect to the inputs.

     ``grad_outputs`` should be a sequence of length matching ``output``
@@ -4,7 +4,7 @@ import inspect
 import itertools
 import warnings
 from collections import OrderedDict
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional
 from typing_extensions import deprecated

 import torch
@@ -389,7 +389,7 @@ class _SingleLevelFunction(
     )

     @staticmethod
-    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
+    def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> Any:
         r"""There are two ways to define the forward pass of an autograd.Function.

         Either:
@@ -722,7 +722,7 @@ def _unflatten(input, proto):
     # unflatten a list or tuple input into a nested list/tuple structure
     # specified by proto
     def unflatten_helper(input, proto):
-        res: List[Optional[torch.Tensor]] = []
+        res: list[Optional[torch.Tensor]] = []
         if hasattr(proto, "_jit_wrap"):
             return proto._jit_wrap(input)
         if not isinstance(proto, (list, tuple)):
@@ -1,5 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import List, Tuple

 import torch
 from torch._vmap_internals import _vmap
@@ -181,8 +180,8 @@ def _autograd_grad(
     assert isinstance(grad_outputs, tuple)
     assert len(outputs) == len(grad_outputs)

-    new_outputs: Tuple[torch.Tensor, ...] = ()
-    new_grad_outputs: Tuple[torch.Tensor, ...] = ()
+    new_outputs: tuple[torch.Tensor, ...] = ()
+    new_grad_outputs: tuple[torch.Tensor, ...] = ()
     for out, grad_out in zip(outputs, grad_outputs):
         if out is not None and out.requires_grad:
             new_outputs += (out,)
@@ -211,7 +210,7 @@ def _fill_in_zeros(grads, refs, strict, create_graph, stage):
     if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
         raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")

-    res: Tuple[torch.Tensor, ...] = ()
+    res: tuple[torch.Tensor, ...] = ()
     for i, grads_i in enumerate(grads):
         if grads_i is None:
             if strict:
@@ -470,8 +469,8 @@ def jvp(func, inputs, v=None, create_graph=False, strict=False):


 def _construct_standard_basis_for(
-    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
-) -> Tuple[torch.Tensor, ...]:
+    tensors: tuple[torch.Tensor, ...], tensor_numels: tuple[int, ...]
+) -> tuple[torch.Tensor, ...]:
     # This function:
     # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
     # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
@@ -780,11 +779,11 @@ def jacobian(
             jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
         )

-    jacobian: Tuple[torch.Tensor, ...] = ()
+    jacobian: tuple[torch.Tensor, ...] = ()

     for i, out in enumerate(outputs):
         # mypy complains that expression and variable have different types due to the empty list
-        jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
+        jac_i: tuple[list[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
         for j in range(out.nelement()):
             vj = _autograd_grad(
                 (out.reshape(-1)[j],),
@@ -2,8 +2,9 @@
 import collections
 import functools
 import warnings
+from collections.abc import Iterable
 from itertools import product
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Optional, Union
 from typing_extensions import deprecated

 import torch
@@ -48,14 +49,14 @@ def _is_float_or_complex_tensor(obj):


 def _allocate_jacobians_with_inputs(
-    input_tensors: Tuple, numel_output
-) -> Tuple[torch.Tensor, ...]:
+    input_tensors: tuple, numel_output
+) -> tuple[torch.Tensor, ...]:
     # Makes zero-filled tensors from inputs. If `numel_output` is not None, for
     # each tensor in `input_tensors`, returns a new zero-filled tensor with height
     # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
     # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
     # the same dtype and device as those of the corresponding input.
-    out: List[torch.Tensor] = [
+    out: list[torch.Tensor] = [
         t.new_zeros((t.numel(), numel_output), layout=torch.strided)
         for t in input_tensors
         if _is_float_or_complex_tensor(t) and t.requires_grad
@@ -64,14 +65,14 @@ def _allocate_jacobians_with_inputs(


 def _allocate_jacobians_with_outputs(
-    output_tensors: Tuple, numel_input, dtype=None, device=None
-) -> Tuple[torch.Tensor, ...]:
+    output_tensors: tuple, numel_input, dtype=None, device=None
+) -> tuple[torch.Tensor, ...]:
     # Makes zero-filled tensors from outputs. If `dim` is not None, for each tensor
     # in `output_tensors`, returns a new zero-filled tensor with height of `dim` and
     # width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size
     # (t.numel,).
     options = {"dtype": dtype, "device": device, "layout": torch.strided}
-    out: List[torch.Tensor] = [
+    out: list[torch.Tensor] = [
         t.new_zeros((numel_input, t.numel()), **options)
         for t in output_tensors
         if _is_float_or_complex_tensor(t)
@@ -258,7 +259,7 @@ def _iter_tensor(x_tensor):

 def _get_numerical_jacobian(
     fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False
-) -> List[Tuple[torch.Tensor, ...]]:
+) -> list[tuple[torch.Tensor, ...]]:
     """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`.

     If not specified, targets are the input. Returns M * N Jacobians where N is the
@@ -281,7 +282,7 @@ def _get_numerical_jacobian(
     Note that `target` may not even be part of `input` to `fn`, so please be
     **very careful** in this to not clone `target`.
     """
-    jacobians: List[Tuple[torch.Tensor, ...]] = []
+    jacobians: list[tuple[torch.Tensor, ...]] = []
     if outputs is None:
         outputs = _as_tuple(fn(*_as_tuple(inputs)))
     if not is_forward_ad and any(o.is_complex() for o in outputs):
@@ -386,13 +387,13 @@ def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):

 def _compute_numerical_jvps_wrt_specific_input(
     jvp_fn, delta, input_is_complex, is_forward_ad=False
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     # Computing the jacobian only works for real delta
     # For details on the algorithm used here, refer:
     # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
     # s = fn(z) where z = x for real valued input
     # and z = x + yj for complex valued input
-    jvps: List[torch.Tensor] = []
+    jvps: list[torch.Tensor] = []
     ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)

     if input_is_complex:  # C -> R
@@ -412,8 +413,8 @@ def _compute_numerical_jvps_wrt_specific_input(


 def _combine_jacobian_cols(
-    jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel
-) -> Tuple[torch.Tensor, ...]:
+    jacobians_cols: dict[int, list[torch.Tensor]], outputs, input, numel
+) -> tuple[torch.Tensor, ...]:
     # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor
     # we return a list that maps output_idx -> full jacobian Tensor
     jacobians = _allocate_jacobians_with_outputs(
@@ -467,13 +468,13 @@ def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None

 def get_numerical_jacobian_wrt_specific_input(
     fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False
-) -> Tuple[torch.Tensor, ...]:
+) -> tuple[torch.Tensor, ...]:
     # Computes the numerical jacobians wrt to a single input. Returns N jacobian
     # tensors, where N is the number of outputs. We use a dictionary for
     # jacobian_cols because indices aren't necessarily consecutive for sparse inputs
     # When we perturb only a single element of the input tensor at a time, the jvp
     # is equivalent to a single col of the Jacobian matrix of fn.
-    jacobian_cols: Dict[int, List[torch.Tensor]] = {}
+    jacobian_cols: dict[int, list[torch.Tensor]] = {}
     input = inputs[input_idx] if input is None else input
     assert input.requires_grad
     for x, idx, d_idx in _iter_tensor(input):
@@ -493,7 +494,7 @@ def get_numerical_jacobian_wrt_specific_input(

 def _get_analytical_jacobian_forward_ad(
     fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None
-) -> Tuple[Tuple[torch.Tensor, ...], ...]:
+) -> tuple[tuple[torch.Tensor, ...], ...]:
     """Compute the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect to `target`.

     Return N * M Jacobians where N is the number of tensors in target that require grad and
@@ -659,7 +660,7 @@ def _mul_tensor_or_tuple(u, k):

 def _get_numerical_jvp_wrt_specific_input(
     fn, input_idx, inputs, u, eps, is_forward_ad=False
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
     input = inputs[input_idx]
     input_to_perturb = _get_input_to_perturb(input)
     wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True)
@@ -676,7 +677,7 @@ def _get_numerical_vJu(
     fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad
 ):
     # Note that all_v can also be None, in that case, this function only computes Ju.
-    reduced_jacobians: List[List[torch.Tensor]] = []
+    reduced_jacobians: list[list[torch.Tensor]] = []
     for inp_idx, u in zip(inp_indices, all_u):
         all_Ju = _get_numerical_jvp_wrt_specific_input(
             fn, inp_idx, inputs, u, eps, is_forward_ad
@@ -692,7 +693,7 @@ def _get_numerical_vJu(
             # TODO: handle the other Ju
             pass
     if all_v is not None:
-        jacobian_scalars: List[torch.Tensor] = []
+        jacobian_scalars: list[torch.Tensor] = []
         for v, Ju in zip(all_v, filtered_Ju):
             jacobian_scalars.append(_dot_with_type_promotion(v, Ju))
         reduced_jacobians.append(jacobian_scalars)
@@ -712,7 +713,7 @@ def _check_jacobians_equal(j1, j2, atol):

 def _stack_and_check_tensors(
     list_of_list_of_tensors, inputs, numel_outputs
-) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]:
+) -> tuple[tuple[torch.Tensor, ...], bool, bool]:
     # For the ith tensor in the inner list checks whether it has the same size and
     # dtype as the ith differentiable input.
     out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs)
@@ -757,7 +758,7 @@ If the test

 def _check_analytical_jacobian_attributes(
     inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
-) -> Tuple[torch.Tensor, ...]:
+) -> tuple[torch.Tensor, ...]:
     # This is used by both fast and slow mode:
     # - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
     #   input.
@@ -802,12 +803,12 @@ def _check_analytical_jacobian_attributes(
 def _get_analytical_vJu_backward_mode(
     inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u
 ):
-    reduced_jacobians: List[List[torch.Tensor]] = []
+    reduced_jacobians: list[list[torch.Tensor]] = []
     for output, v in zip(outputs, all_v):
         all_vJ = _check_analytical_jacobian_attributes(
             inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v
         )
-        jacobian_scalars: List[torch.Tensor] = []
+        jacobian_scalars: list[torch.Tensor] = []
         for vJ, u in zip(all_vJ, all_u):
             # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse
             # the error checking logic from slow mode
@@ -878,7 +879,7 @@ def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):

 def _compute_analytical_jacobian_rows(
     vjp_fn, sample_output
-) -> List[List[Optional[torch.Tensor]]]:
+) -> list[list[Optional[torch.Tensor]]]:
     # Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis
     # vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian.
     # NB: this function does not assume vjp_fn(v) to return tensors with the same
@@ -889,7 +890,7 @@ def _compute_analytical_jacobian_rows(
     )
     flat_grad_out = grad_out_base.view(-1)
     # jacobians_rows[i][j] is the Jacobian jth row for the ith input
-    jacobians_rows: List[List[Optional[torch.Tensor]]] = []
+    jacobians_rows: list[list[Optional[torch.Tensor]]] = []
     for j in range(flat_grad_out.numel()):
         flat_grad_out.zero_()
         flat_grad_out[j] = 1.0  # projection for jth row of Jacobian
@@ -905,9 +906,9 @@ def _compute_analytical_jacobian_rows(

 def _get_analytical_vjps_wrt_specific_output(
     vjp_fn, sample_output, v
-) -> List[List[Optional[torch.Tensor]]]:
+) -> list[list[Optional[torch.Tensor]]]:
     grad_inputs = vjp_fn(v.reshape(sample_output.shape))
-    vjps: List[List[Optional[torch.Tensor]]] = [
+    vjps: list[list[Optional[torch.Tensor]]] = [
         [vjp.clone() if isinstance(vjp, torch.Tensor) else None] for vjp in grad_inputs
     ]
     return vjps
@@ -1174,7 +1175,7 @@ def _test_batched_grad(input, output, output_idx) -> bool:

 def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool:
     # Tests that backward is multiplied by grad_output
-    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
+    diff_input_list: list[torch.Tensor] = list(_iter_tensors(inputs, True))
     if not diff_input_list:
         raise GradcheckError("no Tensors requiring grad found in input")
     grads_input = torch.autograd.grad(
@@ -1294,7 +1295,7 @@ def _test_undefined_forward_mode(func, outputs, inputs):


 def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
-    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
+    diff_input_list: list[torch.Tensor] = list(_iter_tensors(inputs, True))
     if not diff_input_list:
         raise GradcheckError("no Tensors requiring grad found in input")

@@ -4,23 +4,15 @@ import functools
 import logging
 import threading
 from collections import defaultdict, deque
+from collections.abc import Generator, Iterable, Iterator, MutableMapping, Sequence
 from typing import (
     Any,
     Callable,
     cast,
-    Deque,
-    Dict,
-    Generator,
-    Iterable,
-    Iterator,
-    List,
     Literal,
-    MutableMapping,
     NamedTuple,
     Optional,
-    Sequence,
-    Set,
-    Tuple,
     TYPE_CHECKING,
     Union,
 )
@@ -71,7 +63,7 @@ class Node(abc.ABC):

     @property
     @abc.abstractmethod
-    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
+    def next_functions(self) -> tuple[tuple[Optional["Node"], int], ...]:
         raise NotImplementedError

     @abc.abstractmethod
@@ -81,7 +73,7 @@ class Node(abc.ABC):

     @property
     @abc.abstractmethod
-    def _input_metadata(self) -> List[Any]:
+    def _input_metadata(self) -> list[Any]:
         raise NotImplementedError

     @abc.abstractmethod
@@ -367,7 +359,7 @@ class save_on_cpu(saved_tensors_hooks):
     def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None:
         device_module = getattr(torch, device_type, torch.cuda)

-        def pack_to_cpu(tensor: torch.Tensor) -> Tuple[torch.device, torch.Tensor]:
+        def pack_to_cpu(tensor: torch.Tensor) -> tuple[torch.device, torch.Tensor]:
             if not pin_memory:
                 return (tensor.device, tensor.cpu())
             packed = torch.empty(
@@ -379,7 +371,7 @@ class save_on_cpu(saved_tensors_hooks):
             packed.copy_(tensor)
             return (tensor.device, packed)

-        def unpack_from_cpu(packed: Tuple[torch.device, torch.Tensor]) -> torch.Tensor:
+        def unpack_from_cpu(packed: tuple[torch.device, torch.Tensor]) -> torch.Tensor:
             device, tensor = packed
             return tensor.to(device, non_blocking=pin_memory)

@@ -423,19 +415,19 @@ def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, Non


 class _MultiHandle(RemovableHandle):
-    handles: Tuple[RemovableHandle, ...]
+    handles: tuple[RemovableHandle, ...]

-    def __init__(self, handles: Tuple[RemovableHandle, ...]) -> None:
+    def __init__(self, handles: tuple[RemovableHandle, ...]) -> None:
         self.handles = handles

     def remove(self) -> None:
         for handle in self.handles:
             handle.remove()

-    def __getstate__(self) -> Tuple[RemovableHandle, ...]:
+    def __getstate__(self) -> tuple[RemovableHandle, ...]:
         return self.handles

-    def __setstate__(self, state: Tuple[RemovableHandle, ...]) -> None:
+    def __setstate__(self, state: tuple[RemovableHandle, ...]) -> None:
         self.handles = state

@@ -502,9 +494,9 @@ def register_multi_grad_hook(
         raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")

     if mode == "all":
-        count: Dict[int, int] = {}
+        count: dict[int, int] = {}
         nb_calls = None
-        buffer: Dict[int, List[Optional[torch.Tensor]]] = {}
+        buffer: dict[int, list[Optional[torch.Tensor]]] = {}

         grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
         len_tensors = len(tensors)
@@ -544,7 +536,7 @@ def register_multi_grad_hook(
         )
     elif mode == "any":
         fn = cast(Callable[[torch.Tensor], None], fn)
-        ran_hook: Dict[int, bool] = defaultdict(bool)
+        ran_hook: dict[int, bool] = defaultdict(bool)

         @functools.wraps(fn)
         def wrapped_fn(grad: torch.Tensor) -> None:
@@ -582,8 +574,8 @@ def register_multi_grad_hook(
 _allow_mutation_on_saved_tensors_enabled: bool = False


-_TID: TypeAlias = Tuple[int, int, int]
-_SID: TypeAlias = Tuple[int, int]
+_TID: TypeAlias = tuple[int, int, int]
+_SID: TypeAlias = tuple[int, int]


 def _get_tid(tensor: torch.Tensor) -> _TID:
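`TypeAlias` declarations accept the builtin generics too, since `tuple[int, int, int]` is a real runtime object on Python 3.9+. A sketch of the `_TID`/`_SID` pattern (importing `TypeAlias` from `typing_extensions` is an assumption; the tuple contents in the comment are illustrative):

from typing_extensions import TypeAlias

_TID: TypeAlias = tuple[int, int, int]  # some fixed-width id triple

def make_tid(a: int, b: int, c: int) -> _TID:
    return (a, b, c)

print(make_tid(1, 2, 3))  # (1, 2, 3)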
@@ -665,8 +657,8 @@ class _CloneArgBeforeMutateMode(TorchDispatchMode):
         self,
         func: "OpOverload",
         types: Iterable[type],
-        args: Tuple[Any, ...] = (),
-        kwargs: Optional[Dict[Any, Any]] = None,
+        args: tuple[Any, ...] = (),
+        kwargs: Optional[dict[Any, Any]] = None,
     ) -> Any:
         kwargs = kwargs or {}

@@ -704,7 +696,7 @@ class _AllowMutationOnSavedContext:
         self.cloned: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
         self.original: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
         self.tid_to_weakhandle: MutableMapping[_TID, _Handle] = WeakValueDictionary()
-        self.sid_to_tid: Dict[_SID, Set[_TID]] = defaultdict(set)
+        self.sid_to_tid: dict[_SID, set[_TID]] = defaultdict(set)

     def clear(self) -> None:
         self.cloned.clear()
@@ -768,10 +760,10 @@ def _register_logging_hooks_on_whole_graph(
 ) -> Callable[[], None]:
     grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

-    def iter_graph(roots: List[Node]) -> Iterator[Node]:
+    def iter_graph(roots: list[Node]) -> Iterator[Node]:
         if not roots:
             return
-        seen: Set[Node] = set()
-        q: Deque[Node] = deque()
+        seen: set[Node] = set()
+        q: deque[Node] = deque()
         for node in roots:
             if node is not None:
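(The extraction dropped the replacement for the `q:` line; since `Deque` was removed from the `typing` import at the top of this file, `collections.deque`, which is subscriptable on 3.9+, is the evident substitute and is reconstructed above.) A self-contained sketch of the same BFS-queue idiom used by `iter_graph`, on hypothetical graph data rather than autograd nodes:

from collections import deque

def bfs(adj: dict[int, list[int]], root: int) -> list[int]:
    seen: set[int] = {root}
    q: deque[int] = deque([root])  # deque is its own generic type on 3.9+
    order: list[int] = []
    while q:
        node = q.popleft()
        order.append(node)
        for nxt in adj.get(node, []):
            if nxt not in seen:
                seen.add(nxt)
                q.append(nxt)
    return order

print(bfs({1: [2, 3], 2: [3]}, 1))  # [1, 2, 3]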
@@ -815,7 +807,7 @@ def _engine_run_backward(
     t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
     *args: Any,
     **kwargs: Any,
-) -> Tuple[torch.Tensor, ...]:
+) -> tuple[torch.Tensor, ...]:
     attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
     if attach_logging_hooks:
         unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
@@ -1,9 +1,10 @@
 # mypy: allow-untyped-defs
 import uuid
 from collections import defaultdict
+from collections.abc import Iterable
 from dataclasses import dataclass
 from time import perf_counter_ns
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Optional
 from warnings import warn

 import torch
@@ -557,7 +558,7 @@ class profile:
         # frontend_function_events contains the events in aten or torch frontend level,
         # whose correlation id is 0
         frontend_function_events = []
-        device_corr_map: Dict[int, List[FunctionEvent]] = {}
+        device_corr_map: dict[int, list[FunctionEvent]] = {}
         max_evt_id = 0
         for kineto_event in result.events():
             if _filter_name(kineto_event.name()):
@@ -1141,7 +1142,7 @@ class KinetoStepTracker:
     """

     _current_step = 0
-    _step_dict: Dict[str, int] = defaultdict(int)
+    _step_dict: dict[str, int] = defaultdict(int)

     @classmethod
     def init_step_count(cls, requester: str):
@@ -4,7 +4,7 @@ import itertools
 import math
 from collections import defaultdict, namedtuple
 from operator import attrgetter
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional
 from typing_extensions import deprecated

 import torch
@@ -114,7 +114,7 @@ class EventList(list):
                 thread_events,
                 key=lambda event: [event.time_range.start, -event.time_range.end],
             )
-            current_events: List[FunctionEvent] = []
+            current_events: list[FunctionEvent] = []
             for event in thread_events_:
                 while len(current_events) > 0:
                     parent = current_events[-1]
@@ -310,9 +310,9 @@ class EventList(list):
             An EventList containing FunctionEventAvg objects.
         """
         assert self._tree_built
-        stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
+        stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)

-        def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]:
+        def get_key(event, group_by_input_shapes, group_by_stack_n) -> tuple[str, ...]:
             key = [
                 str(event.key),
                 str(event.node_id),
@@ -476,14 +476,14 @@ class FunctionEvent(FormattedTimesMixin):
         self.time_range: Interval = Interval(start_us, end_us)
         self.thread: int = thread
         self.fwd_thread: Optional[int] = fwd_thread
-        self.kernels: List[Kernel] = []
+        self.kernels: list[Kernel] = []
         self.count: int = 1
-        self.cpu_children: List[FunctionEvent] = []
+        self.cpu_children: list[FunctionEvent] = []
         self.cpu_parent: Optional[FunctionEvent] = None
-        self.input_shapes: Tuple[int, ...] = input_shapes
-        self.concrete_inputs: List[Any] = concrete_inputs
-        self.kwinputs: Dict[str, Any] = kwinputs
-        self.stack: List = stack
+        self.input_shapes: tuple[int, ...] = input_shapes
+        self.concrete_inputs: list[Any] = concrete_inputs
+        self.kwinputs: dict[str, Any] = kwinputs
+        self.stack: list = stack
         self.scope: int = scope
         self.use_device: Optional[str] = use_device
         self.cpu_memory_usage: int = cpu_memory_usage
@@ -656,14 +656,14 @@ class FunctionEventAvg(FormattedTimesMixin):
         self.device_time_total: int = 0
         self.self_cpu_time_total: int = 0
         self.self_device_time_total: int = 0
-        self.input_shapes: Optional[List[List[int]]] = None
-        self.stack: Optional[List] = None
+        self.input_shapes: Optional[list[list[int]]] = None
+        self.stack: Optional[list] = None
         self.scope: Optional[int] = None
         self.cpu_memory_usage: int = 0
         self.device_memory_usage: int = 0
         self.self_cpu_memory_usage: int = 0
         self.self_device_memory_usage: int = 0
-        self.cpu_children: Optional[List[FunctionEvent]] = None
+        self.cpu_children: Optional[list[FunctionEvent]] = None
         self.cpu_parent: Optional[FunctionEvent] = None
         self.device_type: DeviceType = DeviceType.CPU
         self.is_legacy: bool = False
@@ -734,8 +734,8 @@ class MemRecordsAcc:

     def __init__(self, mem_records):
         self._mem_records = mem_records
-        self._start_nses: List[int] = []
-        self._indices: List[int] = []
+        self._start_nses: list[int] = []
+        self._indices: list[int] = []
         if len(mem_records) > 0:
             tmp = sorted([(r[0].start_ns(), i) for i, r in enumerate(mem_records)])
             self._start_nses, self._indices = zip(*tmp)  # type: ignore[assignment]
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
 """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
 import contextlib
-from typing import Iterable, List, Union
+from collections.abc import Iterable
+from typing import List, Union
 from warnings import warn

 import torch.backends.cuda
@@ -13,7 +14,7 @@ from torch.backends.cuda import (
 )


-__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
+__all__: list[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]

 # Note: [SDPA warnings]
 # TODO: Consider using this for sdpa regardless of subclasses
@@ -73,7 +74,7 @@ def _backend_from_string(name: str):


 def _cur_sdpa_kernel_backends():
-    backends: List[SDPBackend] = []
+    backends: list[SDPBackend] = []
     for name, val in _backend_names.items():
         if getattr(torch.backends.cuda, f"{name}_sdp_enabled")():
             backends.append(getattr(SDPBackend, val))
@@ -88,7 +89,7 @@ def _sdpa_kernel(backends: Iterable[SDPBackend]):

 @contextlib.contextmanager
 def sdpa_kernel(
-    backends: Union[List[SDPBackend], SDPBackend], set_priority: bool = False
+    backends: Union[list[SDPBackend], SDPBackend], set_priority: bool = False
 ):
     r"""
     Context manager to select which backend to use for scaled dot product attention.
@@ -1,12 +1,12 @@
 # mypy: allow-untyped-defs
 """Defines utilities for interacting with scaled_dot_product_attention"""
 import math
-from typing import List, Optional, Union
+from typing import Optional, Union

 import torch


-__all__: List[str] = []
+__all__: list[str] = []


 def _input_requires_grad(*tensors: torch.Tensor) -> bool:
@@ -9,7 +9,7 @@ import math
 import operator
 import warnings
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union

 import torch
 from torch import Tensor
@@ -75,9 +75,9 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
 # Need to define it here so that Dynamo doesn't skip it
 def _vmap_for_bhqkv(
     fn: Callable,
-    prefix: Tuple[Optional[int], ...],
-    suffix: Tuple[Optional[int], ...] = (),
-    out_dims: Union[int, List[Optional[int]]] = 0,
+    prefix: tuple[Optional[int], ...],
+    suffix: tuple[Optional[int], ...] = (),
+    out_dims: Union[int, list[Optional[int]]] = 0,
     group_dim: bool = False,
 ):
     """Used to vmap both score_mods and mask_mods over 4-dimensional/5-dimension inputs.
@@ -98,7 +98,7 @@ def _vmap_for_bhqkv(
         callable: The vmapped function.
     """
     # We vamp a function 4 times, broadcasting the [b, h, q_idx, kv_idx] dimensions
-    dimensions: List[Tuple[None | int, None | int, None | int, None | int]] = []
+    dimensions: list[tuple[None | int, None | int, None | int, None | int]] = []
     dimensions = [
         (None, None, None, 0),
         (None, None, 0, None),
@@ -173,7 +173,7 @@ def _ordered_to_dense(num_blocks_in_row: Tensor, col_indices: Tensor):
     return out


-def _dense_to_ordered(dense_mask) -> Tuple[Tensor, Tensor]:
+def _dense_to_ordered(dense_mask) -> tuple[Tensor, Tensor]:
     dense_mask = dense_mask.to(dtype=torch.int32)
     num_blocks_in_row = dense_mask.sum(dim=-1)
     col_indices = torch.argsort(dense_mask, dim=-1, descending=True, stable=True)
@@ -262,7 +262,7 @@ class BlockMask:
     the backwards pass. These are autogenerated from 2.
     """

-    seq_lengths: Tuple[int, int]
+    seq_lengths: tuple[int, int]
     kv_num_blocks: Tensor
     kv_indices: Tensor
     full_kv_num_blocks: Optional[Tensor]
@@ -271,12 +271,12 @@ class BlockMask:
     q_indices: Optional[Tensor]
     full_q_num_blocks: Optional[Tensor]
     full_q_indices: Optional[Tensor]
-    BLOCK_SIZE: Tuple[int, int]
+    BLOCK_SIZE: tuple[int, int]
     mask_mod: _mask_mod_signature

     def __init__(
         self,
-        seq_lengths: Tuple[int, int],
+        seq_lengths: tuple[int, int],
         kv_num_blocks: Tensor,
         kv_indices: Tensor,
         full_kv_num_blocks: Optional[Tensor],
@@ -285,7 +285,7 @@ class BlockMask:
         q_indices: Optional[Tensor],
         full_q_num_blocks: Optional[Tensor],
         full_q_indices: Optional[Tensor],
-        BLOCK_SIZE: Tuple[int, int],
+        BLOCK_SIZE: tuple[int, int],
         mask_mod: _mask_mod_signature,
     ):
         if kv_indices.dim() < 2:
@@ -320,9 +320,9 @@ class BlockMask:
         kv_indices: Tensor,
         full_kv_num_blocks: Optional[Tensor] = None,
         full_kv_indices: Optional[Tensor] = None,
-        BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
+        BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
         mask_mod: Optional[_mask_mod_signature] = None,
-        seq_lengths: Optional[Tuple[int, int]] = None,
+        seq_lengths: Optional[tuple[int, int]] = None,
     ):
         """
         Creates a BlockMask instance from key-value block information.
@@ -332,7 +332,7 @@ class BlockMask:
             kv_indices (Tensor): Indices of key-value blocks in each Q_BLOCK_SIZE row tile.
             full_kv_num_blocks (Optional[Tensor]): Number of full kv_blocks in each Q_BLOCK_SIZE row tile.
             full_kv_indices (Optional[Tensor]): Indices of full key-value blocks in each Q_BLOCK_SIZE row tile.
-            BLOCK_SIZE (Union[int, Tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles.
+            BLOCK_SIZE (Union[int, tuple[int, int]]): Size of KV_BLOCK_SIZE x Q_BLOCK_SIZE tiles.
             mask_mod (Optional[Callable]): Function to modify the mask.

         Returns:
@@ -664,7 +664,7 @@ def _convert_mask_to_block_mask(
     Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
     KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
     separate_full_blocks: bool = False,
-) -> Tuple[Tensor, Optional[Tensor]]:
+) -> tuple[Tensor, Optional[Tensor]]:
     assert mask.dtype == torch.bool
     mask = _broadcast_to_dim(mask, 4)

@@ -748,9 +748,9 @@ def _convert_block_mask_to_mask(


 def _create_sparse_block_from_block_mask(
-    block_mask: Tuple[Tensor, Optional[Tensor]],
+    block_mask: tuple[Tensor, Optional[Tensor]],
     mask_mod: Optional[Callable],
-    seq_lengths: Tuple[int, int],
+    seq_lengths: tuple[int, int],
     Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
     KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
 ) -> BlockMask:
@@ -758,7 +758,7 @@ def _create_sparse_block_from_block_mask(

     partial_bm = _dense_to_ordered(partial_blocks)
     if full_blocks is not None:
-        full_bm: Tuple[Optional[Tensor], Optional[Tensor]] = _dense_to_ordered(
+        full_bm: tuple[Optional[Tensor], Optional[Tensor]] = _dense_to_ordered(
             full_blocks
         )
     else:
@@ -829,7 +829,7 @@ def create_block_mask(
     Q_LEN: int,
     KV_LEN: int,
     device: str = "cuda",
-    BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
+    BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
     _compile=False,
 ) -> BlockMask:
     r"""This function creates a block mask tuple from a mask_mod function.
@@ -845,7 +845,7 @@ def create_block_mask(
         Q_LEN (int): Sequence length of query.
         KV_LEN (int): Sequence length of key/value.
        device (str): Device to run the mask creation on.
-        BLOCK_SIZE (int or Tuple[int, int]): Block size for the block mask. If a single int is provided it is used for both query and key/value.
+        BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is provided it is used for both query and key/value.

     Returns:
         BlockMask: A BlockMask object that contains the block mask information.
@@ -1002,7 +1002,7 @@ def create_nested_block_mask(
     H: Optional[int],
     q_nt: torch.Tensor,
     kv_nt: Optional[torch.Tensor] = None,
-    BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
+    BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
     _compile=False,
 ) -> BlockMask:
     r"""This function creates a nested tensor compatible block mask tuple from a mask_mod
@@ -1024,7 +1024,7 @@ def create_nested_block_mask(
             constructed to operate on a "stacked sequence" of length ``sum(S)`` for sequence
             length ``S`` from the NJT. If this is None, ``q_nt`` is used to define the structure
             for key / value as well. Default: None
-        BLOCK_SIZE (int or Tuple[int, int]): Block size for the block mask. If a single int is
+        BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is
             provided it is used for both query and key/value.

     Returns:
@@ -1167,8 +1167,8 @@ def flex_attention(
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     return_lse: bool = False,
-    kernel_options: Optional[Dict[str, Any]] = None,
-) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+    kernel_options: Optional[dict[str, Any]] = None,
+) -> Union[Tensor, tuple[Tensor, Tensor]]:
     r"""This function implements scaled dot product attention with an arbitrary attention score modification function.

     This function computes the scaled dot product attention between query, key, and value tensors with a user-defined
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, TypeVar, Union
+from typing import Optional, TypeVar, Union

 from torch import Tensor

@@ -9,13 +9,13 @@ from torch import Tensor
 # broadcast to a tuple.
 # Comes in several variants: A tuple of unknown size, and a fixed-size tuple for 1d, 2d, or 3d operations.
 T = TypeVar("T")
-_scalar_or_tuple_any_t = Union[T, Tuple[T, ...]]
-_scalar_or_tuple_1_t = Union[T, Tuple[T]]
-_scalar_or_tuple_2_t = Union[T, Tuple[T, T]]
-_scalar_or_tuple_3_t = Union[T, Tuple[T, T, T]]
-_scalar_or_tuple_4_t = Union[T, Tuple[T, T, T, T]]
-_scalar_or_tuple_5_t = Union[T, Tuple[T, T, T, T, T]]
-_scalar_or_tuple_6_t = Union[T, Tuple[T, T, T, T, T, T]]
+_scalar_or_tuple_any_t = Union[T, tuple[T, ...]]
+_scalar_or_tuple_1_t = Union[T, tuple[T]]
+_scalar_or_tuple_2_t = Union[T, tuple[T, T]]
+_scalar_or_tuple_3_t = Union[T, tuple[T, T, T]]
+_scalar_or_tuple_4_t = Union[T, tuple[T, T, T, T]]
+_scalar_or_tuple_5_t = Union[T, tuple[T, T, T, T, T]]
+_scalar_or_tuple_6_t = Union[T, tuple[T, T, T, T, T, T]]

 # For arguments which represent size parameters (eg, kernel size, padding)
 _size_any_t = _scalar_or_tuple_any_t[int]
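The rewritten aliases stay subscriptable: `Union[T, tuple[T, T]]` produces a generic alias that can be specialized exactly the way this file does with `_scalar_or_tuple_any_t[int]`, because builtin `tuple` supports subscripting from Python 3.9. A minimal sketch with a hypothetical `normalize` helper:

from typing import TypeVar, Union

T = TypeVar("T")
_scalar_or_tuple_2_t = Union[T, tuple[T, T]]
_size_2_t = _scalar_or_tuple_2_t[int]  # same specialization style as the file

def normalize(v: _size_2_t) -> tuple[int, int]:
    # Broadcast a bare int to a pair, the way nn modules treat size arguments
    return v if isinstance(v, tuple) else (v, v)

print(normalize(3), normalize((4, 5)))  # (3, 3) (4, 5)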
@@ -3,7 +3,7 @@
 import importlib
 import math
 import warnings
-from typing import Callable, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Callable, List, Optional, TYPE_CHECKING, Union

 import torch
 from torch import _VF, sym_int as _sym_int, Tensor
@@ -440,7 +440,7 @@ def fractional_max_pool2d_with_indices(
     output_ratio: Optional[BroadcastingList2[float]] = None,
     return_indices: bool = False,
     _random_samples: Optional[Tensor] = None,
-) -> Tuple[Tensor, Tensor]:  # noqa: D400
+) -> tuple[Tensor, Tensor]:  # noqa: D400
     r"""
     fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

@@ -552,7 +552,7 @@ def fractional_max_pool3d_with_indices(
     output_ratio: Optional[BroadcastingList3[float]] = None,
     return_indices: bool = False,
     _random_samples: Optional[Tensor] = None,
-) -> Tuple[Tensor, Tensor]:  # noqa: D400
+) -> tuple[Tensor, Tensor]:  # noqa: D400
     r"""
     fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None)

@@ -669,7 +669,7 @@ def max_pool1d_with_indices(
     dilation: BroadcastingList1[int] = 1,
     ceil_mode: bool = False,
     return_indices: bool = False,
-) -> Tuple[Tensor, Tensor]:  # noqa: D400
+) -> tuple[Tensor, Tensor]:  # noqa: D400
     r"""
     max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

@@ -759,7 +759,7 @@ def max_pool2d_with_indices(
     dilation: BroadcastingList2[int] = 1,
     ceil_mode: bool = False,
     return_indices: bool = False,
-) -> Tuple[Tensor, Tensor]:  # noqa: D400
+) -> tuple[Tensor, Tensor]:  # noqa: D400
     r"""
     max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

@@ -849,7 +849,7 @@ def max_pool3d_with_indices(
     dilation: BroadcastingList3[int] = 1,
     ceil_mode: bool = False,
     return_indices: bool = False,
-) -> Tuple[Tensor, Tensor]:  # noqa: D400
+) -> tuple[Tensor, Tensor]:  # noqa: D400
     r"""
     max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)

@@ -933,11 +933,11 @@ max_pool3d = boolean_dispatch(

 def _unpool_output_size(
     input: Tensor,
-    kernel_size: List[int],
-    stride: List[int],
-    padding: List[int],
-    output_size: Optional[List[int]],
-) -> List[int]:
+    kernel_size: list[int],
+    stride: list[int],
+    padding: list[int],
+    output_size: Optional[list[int]],
+) -> list[int]:
     input_size = input.size()
     default_size = torch.jit.annotate(List[int], [])
     for d in range(len(kernel_size)):
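One detail worth flagging: the signature above switches to builtin `list`, but the body keeps `torch.jit.annotate(List[int], [])`, which is presumably why `List` survives in this file's `from typing import ...` line earlier in the diff. TorchScript evaluates that annotation at runtime, so the `typing` spelling is the safe one to keep. A sketch of the idiom (assumes `torch` is installed; support for the `list[int]` spelling in newer TorchScript versions may vary):

from typing import List

import torch

@torch.jit.script
def empty_sizes() -> List[int]:
    # TorchScript needs an explicit element type for an empty list literal
    sizes = torch.jit.annotate(List[int], [])
    sizes.append(3)
    return sizes

print(empty_sizes())  # [3]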
@@ -1187,7 +1187,7 @@ def adaptive_max_pool1d_with_indices(
     input: Tensor,
     output_size: BroadcastingList1[int],
     return_indices: bool = False,
-) -> Tuple[Tensor, Tensor]:  # noqa: D400
+) -> tuple[Tensor, Tensor]:  # noqa: D400
     r"""
     adaptive_max_pool1d(input, output_size, return_indices=False)

@@ -1242,7 +1242,7 @@ def adaptive_max_pool2d_with_indices(
     input: Tensor,
     output_size: BroadcastingList2[int],
     return_indices: bool = False,
-) -> Tuple[Tensor, Tensor]:  # noqa: D400
+) -> tuple[Tensor, Tensor]:  # noqa: D400
     r"""adaptive_max_pool2d(input, output_size, return_indices=False)

     Applies a 2D adaptive max pooling over an input signal composed of
@@ -1298,7 +1298,7 @@ def adaptive_max_pool3d_with_indices(
     input: Tensor,
     output_size: BroadcastingList3[int],
     return_indices: bool = False,
-) -> Tuple[Tensor, Tensor]:  # noqa: D400
+) -> tuple[Tensor, Tensor]:  # noqa: D400
     r"""
     adaptive_max_pool3d(input, output_size, return_indices=False)

@@ -2430,7 +2430,7 @@ def _no_grad_embedding_renorm_(
     input: Tensor,
     max_norm: float,
     norm_type: float,
-) -> Tuple[Tensor, Tensor]:
+) -> tuple[Tensor, Tensor]:
     torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)

@@ -2769,7 +2769,7 @@ if embedding_bag.__doc__:
     embedding_bag.__doc__ = embedding_bag.__doc__.format(**reproducibility_notes)


-def _verify_batch_size(size: List[int]) -> None:
+def _verify_batch_size(size: list[int]) -> None:
     # XXX: JIT script does not support the reduce from functools, and mul op is a
     # builtin, which cannot be used as a value to a func yet, so rewrite this size
     # check to a simple equivalent for loop
@@ -2832,7 +2832,7 @@ def batch_norm(
     )


-def _verify_spatial_size(size: List[int]) -> None:
+def _verify_spatial_size(size: list[int]) -> None:
     # Verify that there is > 1 spatial element for instance norm calculation.
     size_prods = 1
     for i in range(2, len(size)):
@@ -2888,7 +2888,7 @@ def instance_norm(

 def layer_norm(
     input: Tensor,
-    normalized_shape: List[int],
+    normalized_shape: list[int],
     weight: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     eps: float = 1e-5,
@@ -2914,7 +2914,7 @@ def layer_norm(

 def rms_norm(
     input: Tensor,
-    normalized_shape: List[int],
+    normalized_shape: list[int],
     weight: Optional[Tensor] = None,
     eps: Optional[float] = None,
 ) -> Tensor:
@@ -4301,7 +4301,7 @@ def upsample(  # noqa: F811
 @_overload
 def upsample(  # noqa: F811
     input: Tensor,
-    size: Optional[List[int]] = None,
+    size: Optional[list[int]] = None,
     scale_factor: Optional[float] = None,
     mode: str = "nearest",
     align_corners: Optional[bool] = None,
@@ -4402,7 +4402,7 @@ def _is_integer(x) -> bool:
 def interpolate(  # noqa: F811
     input: Tensor,
     size: Optional[int] = None,
-    scale_factor: Optional[List[float]] = None,
+    scale_factor: Optional[list[float]] = None,
     mode: str = "nearest",
     align_corners: Optional[bool] = None,
     recompute_scale_factor: Optional[bool] = None,
@@ -4414,8 +4414,8 @@ def interpolate(  # noqa: F811
 @_overload
 def interpolate(  # noqa: F811
     input: Tensor,
-    size: Optional[List[int]] = None,
-    scale_factor: Optional[List[float]] = None,
+    size: Optional[list[int]] = None,
+    scale_factor: Optional[list[float]] = None,
     mode: str = "nearest",
     align_corners: Optional[bool] = None,
     recompute_scale_factor: Optional[bool] = None,
@@ -4440,7 +4440,7 @@ def interpolate(  # noqa: F811
 @_overload
 def interpolate(  # noqa: F811
     input: Tensor,
-    size: Optional[List[int]] = None,
+    size: Optional[list[int]] = None,
     scale_factor: Optional[float] = None,
     mode: str = "nearest",
     align_corners: Optional[bool] = None,
@@ -4453,7 +4453,7 @@ def interpolate(  # noqa: F811
 def interpolate(  # noqa: F811
     input: Tensor,
     size: Optional[int] = None,
-    scale_factor: Optional[List[float]] = None,
+    scale_factor: Optional[list[float]] = None,
     mode: str = "nearest",
     align_corners: Optional[bool] = None,
     recompute_scale_factor: Optional[bool] = None,
@@ -4744,7 +4744,7 @@ def upsample_nearest(  # noqa: F811
 @_overload
 def upsample_nearest(  # noqa: F811
     input: Tensor,
-    size: Optional[List[int]] = None,
+    size: Optional[list[int]] = None,
     scale_factor: Optional[float] = None,
 ) -> Tensor:
     pass
@@ -4794,7 +4794,7 @@ def upsample_bilinear(  # noqa: F811
 @_overload
 def upsample_bilinear(  # noqa: F811
     input: Tensor,
-    size: Optional[List[int]] = None,
+    size: Optional[list[int]] = None,
     scale_factor: Optional[float] = None,
 ) -> Tensor:
     pass
@@ -4804,7 +4804,7 @@ def upsample_bilinear(  # noqa: F811
 def upsample_bilinear(  # noqa: F811
     input: Tensor,
     size: Optional[int] = None,
-    scale_factor: Optional[List[float]] = None,
+    scale_factor: Optional[list[float]] = None,
 ) -> Tensor:
     pass

@@ -4812,8 +4812,8 @@ def upsample_bilinear(  # noqa: F811
 @_overload
 def upsample_bilinear(  # noqa: F811
     input: Tensor,
-    size: Optional[List[int]] = None,
-    scale_factor: Optional[List[float]] = None,
+    size: Optional[list[int]] = None,
+    scale_factor: Optional[list[float]] = None,
 ) -> Tensor:
     pass

@@ -5025,7 +5025,7 @@ def grid_sample(

 def affine_grid(
     theta: Tensor,
-    size: List[int],
+    size: list[int],
     align_corners: Optional[bool] = None,
 ) -> Tensor:
     r"""Generate 2D or 3D flow field (sampling grid), given a batch of affine matrices :attr:`theta`.
@@ -5127,7 +5127,7 @@ def affine_grid(

 def pad(
     input: Tensor,
-    pad: List[int],
+    pad: list[int],
     mode: str = "constant",
     value: Optional[float] = None,
 ) -> Tensor:
@@ -5490,7 +5490,7 @@ def normalize(
     return torch.div(input, denom, out=out)


-def assert_int_or_pair(arg: List[int], arg_name: str, message: str) -> None:
+def assert_int_or_pair(arg: list[int], arg_name: str, message: str) -> None:
     assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)

@@ -5579,7 +5579,7 @@ def _in_projection_packed(
     v: Tensor,
     w: Tensor,
     b: Optional[Tensor] = None,
-) -> List[Tensor]:
+) -> list[Tensor]:
     r"""Perform the in-projection step of the attention operation, using packed weights.

     Output is a triple containing projection tensors for query, key and value.
@@ -5658,7 +5658,7 @@ def _in_projection(
     b_q: Optional[Tensor] = None,
     b_k: Optional[Tensor] = None,
     b_v: Optional[Tensor] = None,
-) -> Tuple[Tensor, Tensor, Tensor]:
+) -> tuple[Tensor, Tensor, Tensor]:
     r"""Perform the in-projection step of the attention operation.

     This is simply a triple of linear projections,
@@ -6019,7 +6019,7 @@ def multi_head_attention_forward(
     static_v: Optional[Tensor] = None,
     average_attn_weights: bool = True,
     is_causal: bool = False,
-) -> Tuple[Tensor, Optional[Tensor]]:
+) -> tuple[Tensor, Optional[Tensor]]:
     r"""Forward method for MultiHeadAttention.

     .. note::
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import warnings
-from typing import Optional, Tuple
+from typing import Optional

 import torch
 import torch.nn.functional as F
@@ -1146,7 +1146,7 @@ class MultiheadAttention(Module):
         attn_mask: Optional[Tensor] = None,
         average_attn_weights: bool = True,
         is_causal: bool = False,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> tuple[Tensor, Optional[Tensor]]:
         r"""Compute attention outputs using query, key, and value embeddings.

         Supports optional parameters for padding, masks and attention weights.
@@ -1401,7 +1401,7 @@ class MultiheadAttention(Module):
         attn_mask: Optional[Tensor],
         key_padding_mask: Optional[Tensor],
         query: Tensor,
-    ) -> Tuple[Optional[Tensor], Optional[int]]:
+    ) -> tuple[Optional[Tensor], Optional[int]]:
         r"""Determine mask type and combine masks if necessary.

         If only one mask is provided, that mask
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs

 from collections import namedtuple
-from typing import List, Sequence
+from collections.abc import Sequence

 import torch
 import torch.nn.functional as F
@@ -107,7 +107,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):

     in_features: int
     n_classes: int
-    cutoffs: List[int]
+    cutoffs: list[int]
     div_value: float
     head_bias: bool
     head: Linear
@@ -2,19 +2,9 @@
 # mypy: allow-untyped-defs
 import operator
 from collections import abc as container_abcs, OrderedDict
+from collections.abc import Iterable, Iterator, Mapping
 from itertools import chain, islice
-from typing import (
-    Any,
-    Dict,
-    Iterable,
-    Iterator,
-    Mapping,
-    Optional,
-    overload,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Optional, overload, TypeVar, Union
 from typing_extensions import deprecated, Self

 import torch
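After the rewrite, `collections.abc` ABCs and builtin generics compose freely in signatures like the `items() -> Iterable[tuple[str, Module]]` change below. A stripped-down sketch, with a plain dict standing in for a module container:

from collections.abc import Iterable, Mapping

def items(d: Mapping[str, int]) -> Iterable[tuple[str, int]]:
    # dict views already satisfy Iterable[tuple[key, value]]
    return d.items()

for k, v in items({"a": 1, "b": 2}):
    print(k, v)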
@@ -107,7 +97,7 @@ class Sequential(Module):
                 ]))
     """

-    _modules: Dict[str, Module]  # type: ignore[assignment]
+    _modules: dict[str, Module]  # type: ignore[assignment]

     @overload
     def __init__(self, *args: Module) -> None:
@@ -302,7 +292,7 @@ class ModuleList(Module):
             return x
     """

-    _modules: Dict[str, Module]  # type: ignore[assignment]
+    _modules: dict[str, Module]  # type: ignore[assignment]

     def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
         super().__init__()
@@ -490,7 +480,7 @@ class ModuleDict(Module):
             return x
     """

-    _modules: Dict[str, Module]  # type: ignore[assignment]
+    _modules: dict[str, Module]  # type: ignore[assignment]

     def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
         super().__init__()
@@ -539,7 +529,7 @@ class ModuleDict(Module):
         return self._modules.keys()

     @_copy_to_script_wrapper
-    def items(self) -> Iterable[Tuple[str, Module]]:
+    def items(self) -> Iterable[tuple[str, Module]]:
         r"""Return an iterable of the ModuleDict key/value pairs."""
         return self._modules.items()

@@ -771,7 +761,7 @@ class ParameterDict(Module):

     def __init__(self, parameters: Any = None) -> None:
         super().__init__()
-        self._keys: Dict[str, None] = {}
+        self._keys: dict[str, None] = {}
         if parameters is not None:
             self.update(parameters)

@@ -855,7 +845,7 @@ class ParameterDict(Module):
         del self[key]
         return v

-    def popitem(self) -> Tuple[str, Any]:
+    def popitem(self) -> tuple[str, Any]:
         """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
         k, _ = self._keys.popitem()
         # We need the key in the _keys to be able to access/del
@@ -888,7 +878,7 @@ class ParameterDict(Module):
         r"""Return an iterable of the ParameterDict keys."""
         return self._keys.keys()

-    def items(self) -> Iterable[Tuple[str, Any]]:
+    def items(self) -> Iterable[tuple[str, Any]]:
         r"""Return an iterable of the ParameterDict key/value pairs."""
         return ((k, self[k]) for k in self._keys)

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 from typing_extensions import deprecated

 import torch
@@ -70,14 +70,14 @@ class _ConvNd(Module):
         ...

     in_channels: int
-    _reversed_padding_repeated_twice: List[int]
+    _reversed_padding_repeated_twice: list[int]
     out_channels: int
-    kernel_size: Tuple[int, ...]
-    stride: Tuple[int, ...]
-    padding: Union[str, Tuple[int, ...]]
-    dilation: Tuple[int, ...]
+    kernel_size: tuple[int, ...]
+    stride: tuple[int, ...]
+    padding: Union[str, tuple[int, ...]]
+    dilation: tuple[int, ...]
     transposed: bool
-    output_padding: Tuple[int, ...]
+    output_padding: tuple[int, ...]
     groups: int
     padding_mode: str
     weight: Tensor
@@ -87,12 +87,12 @@ class _ConvNd(Module):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: Tuple[int, ...],
-        stride: Tuple[int, ...],
-        padding: Union[str, Tuple[int, ...]],
-        dilation: Tuple[int, ...],
+        kernel_size: tuple[int, ...],
+        stride: tuple[int, ...],
+        padding: Union[str, tuple[int, ...]],
+        dilation: tuple[int, ...],
         transposed: bool,
-        output_padding: Tuple[int, ...],
+        output_padding: tuple[int, ...],
         groups: int,
         bias: bool,
         padding_mode: str,
@@ -768,13 +768,13 @@ class _ConvTransposeNd(_ConvNd):
     def _output_padding(
         self,
         input: Tensor,
-        output_size: Optional[List[int]],
-        stride: List[int],
-        padding: List[int],
-        kernel_size: List[int],
+        output_size: Optional[list[int]],
+        stride: list[int],
+        padding: list[int],
+        kernel_size: list[int],
         num_spatial_dims: int,
-        dilation: Optional[List[int]] = None,
-    ) -> List[int]:
+        dilation: Optional[list[int]] = None,
+    ) -> list[int]:
         if output_size is None:
             ret = _single(self.output_padding)  # converting to list if was not already
         else:
@ -952,7 +952,7 @@ class ConvTranspose1d(_ConvTransposeNd):
|
|||
**factory_kwargs,
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
||||
def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
|
||||
if self.padding_mode != "zeros":
|
||||
raise ValueError(
|
||||
"Only `zeros` padding mode is supported for ConvTranspose1d"
|
||||
|
|
@ -1139,7 +1139,7 @@ class ConvTranspose2d(_ConvTransposeNd):
|
|||
**factory_kwargs,
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
||||
def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
|
||||
if self.padding_mode != "zeros":
|
||||
raise ValueError(
|
||||
"Only `zeros` padding mode is supported for ConvTranspose2d"
|
||||
|
|
@ -1324,7 +1324,7 @@ class ConvTranspose3d(_ConvTransposeNd):
|
|||
**factory_kwargs,
|
||||
)
|
||||
|
||||
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
||||
def forward(self, input: Tensor, output_size: Optional[list[int]] = None) -> Tensor:
|
||||
if self.padding_mode != "zeros":
|
||||
raise ValueError(
|
||||
"Only `zeros` padding mode is supported for ConvTranspose3d"
|
||||
|
|
@ -1391,7 +1391,7 @@ class _LazyConvXdMixin(LazyModuleMixin):
|
|||
transposed: bool
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
kernel_size: Tuple[int, ...]
|
||||
kernel_size: tuple[int, ...]
|
||||
weight: UninitializedParameter
|
||||
bias: UninitializedParameter
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Tuple, Union
from typing import Union

from torch import Tensor
from torch.types import _size

@@ -103,7 +103,7 @@ class Unflatten(Module):
        torch.Size([2, 2, 5, 5])
    """

    NamedShape = Tuple[Tuple[str, int]]
    NamedShape = tuple[tuple[str, int]]

    __constants__ = ["dim", "unflattened_size"]
    dim: Union[int, str]
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import itertools
from typing import Any, Optional, Protocol, Type
from typing import Any, Optional, Protocol

import torch
from torch.nn.parameter import is_lazy

@@ -180,7 +180,7 @@ class LazyModuleMixin:

    # modules inheriting from this will change their __class__ to the specified
    # one after they are fully initialized
    cls_to_become: Optional[Type[Any]] = None
    cls_to_become: Optional[type[Any]] = None

    def __init__(self: _LazyProtocol, *args, **kwargs):
        # Mypy doesnt like this super call in a mixin
@@ -6,20 +6,8 @@ import itertools
import warnings
import weakref
from collections import namedtuple, OrderedDict
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    overload,
    Set,
    Tuple,
    TypeVar,
    Union,
)
from collections.abc import Iterator, Mapping
from typing import Any, Callable, Optional, overload, TypeVar, Union
from typing_extensions import Self

import torch
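
Note: this hunk shows the other recurring import change. Abstract container types (Iterator, Mapping, and elsewhere Sequence and Iterable) now come from collections.abc rather than typing, whose aliases are deprecated. A small hedged sketch of the idiom, with a hypothetical function that is not taken from the PR:

    from collections.abc import Iterator, Mapping

    def iter_keys(m: Mapping[str, int]) -> Iterator[str]:
        # collections.abc classes are subscriptable at runtime since Python 3.9
        yield from m
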
@@ -42,7 +30,7 @@ __all__ = [
    "Module",
]

_grad_t = Union[Tuple[Tensor, ...], Tensor]
_grad_t = Union[tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.

@@ -74,9 +62,9 @@ def _addindent(s_, numSpaces):

r"""This tracks hooks common to all modules that are executed immediately before
.registering the buffer/module/parameter"""
_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_module_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_buffer_registration_hooks: dict[int, Callable] = OrderedDict()
_global_module_registration_hooks: dict[int, Callable] = OrderedDict()
_global_parameter_registration_hooks: dict[int, Callable] = OrderedDict()


class _WrappedHook:

@@ -98,14 +86,14 @@ class _WrappedHook:
            return self.hook(module, *args, **kwargs)
        return self.hook(*args, **kwargs)

    def __getstate__(self) -> Dict:
    def __getstate__(self) -> dict:
        result = {"hook": self.hook, "with_module": self.with_module}
        if self.with_module:
            result["module"] = self.module()

        return result

    def __setstate__(self, state: Dict):
    def __setstate__(self, state: dict):
        self.hook = state["hook"]
        self.with_module = state["with_module"]

@@ -120,13 +108,13 @@ class _WrappedHook:
r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes"""
_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_backward_hooks: Dict[int, Callable] = OrderedDict()
_global_backward_pre_hooks: dict[int, Callable] = OrderedDict()
_global_backward_hooks: dict[int, Callable] = OrderedDict()
_global_is_full_backward_hook: Optional[bool] = None
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict()
_global_forward_hooks_with_kwargs: Dict[int, bool] = OrderedDict()
_global_forward_pre_hooks: dict[int, Callable] = OrderedDict()
_global_forward_hooks: dict[int, Callable] = OrderedDict()
_global_forward_hooks_always_called: dict[int, bool] = OrderedDict()
_global_forward_hooks_with_kwargs: dict[int, bool] = OrderedDict()

_EXTRA_STATE_KEY_SUFFIX = "_extra_state"


@@ -447,29 +435,29 @@ class Module:
    the change."""

    training: bool
    _parameters: Dict[str, Optional[Parameter]]
    _buffers: Dict[str, Optional[Tensor]]
    _non_persistent_buffers_set: Set[str]
    _backward_pre_hooks: Dict[int, Callable]
    _backward_hooks: Dict[int, Callable]
    _parameters: dict[str, Optional[Parameter]]
    _buffers: dict[str, Optional[Tensor]]
    _non_persistent_buffers_set: set[str]
    _backward_pre_hooks: dict[int, Callable]
    _backward_hooks: dict[int, Callable]
    _is_full_backward_hook: Optional[bool]
    _forward_hooks: Dict[int, Callable]
    _forward_hooks: dict[int, Callable]
    # Marks whether the corresponding _forward_hooks accept kwargs or not.
    # As JIT does not support Set[int], this dict is used as a set, where all
    # hooks represented in this dict accept kwargs.
    _forward_hooks_with_kwargs: Dict[int, bool]
    _forward_hooks_with_kwargs: dict[int, bool]
    # forward hooks that should always be called even if an exception is raised
    _forward_hooks_always_called: Dict[int, bool]
    _forward_pre_hooks: Dict[int, Callable]
    _forward_hooks_always_called: dict[int, bool]
    _forward_pre_hooks: dict[int, Callable]
    # Marks whether the corresponding _forward_hooks accept kwargs or not.
    # As JIT does not support Set[int], this dict is used as a set, where all
    # hooks represented in this dict accept kwargs.
    _forward_pre_hooks_with_kwargs: Dict[int, bool]
    _state_dict_hooks: Dict[int, Callable]
    _load_state_dict_pre_hooks: Dict[int, Callable]
    _state_dict_pre_hooks: Dict[int, Callable]
    _load_state_dict_post_hooks: Dict[int, Callable]
    _modules: Dict[str, Optional["Module"]]
    _forward_pre_hooks_with_kwargs: dict[int, bool]
    _state_dict_hooks: dict[int, Callable]
    _load_state_dict_pre_hooks: dict[int, Callable]
    _state_dict_pre_hooks: dict[int, Callable]
    _load_state_dict_post_hooks: dict[int, Callable]
    _modules: dict[str, Optional["Module"]]
    call_super_init: bool = False
    _compiled_call_impl: Optional[Callable] = None

@@ -712,7 +700,7 @@ class Module:
        if target == "":
            return self

        atoms: List[str] = target.split(".")
        atoms: list[str] = target.split(".")
        mod: torch.nn.Module = self

        for item in atoms:

@@ -769,7 +757,7 @@ class Module:
        if target == "":
            raise ValueError("Cannot set the submodule without a target name!")

        atoms: List[str] = target.split(".")
        atoms: list[str] = target.split(".")
        name = atoms.pop(-1)
        mod: torch.nn.Module = self

@@ -1485,13 +1473,13 @@ class Module:
        It returns two lists, one with the full backward hooks and one with the non-full
        backward hooks.
        """
        full_backward_hooks: List[Callable] = []
        full_backward_hooks: list[Callable] = []
        if _global_is_full_backward_hook is True:
            full_backward_hooks += _global_backward_hooks.values()
        if self._is_full_backward_hook is True:
            full_backward_hooks += self._backward_hooks.values()

        non_full_backward_hooks: List[Callable] = []
        non_full_backward_hooks: list[Callable] = []
        if _global_is_full_backward_hook is False:
            non_full_backward_hooks += _global_backward_hooks.values()
        if self._is_full_backward_hook is False:

@@ -1500,7 +1488,7 @@ class Module:
        return full_backward_hooks, non_full_backward_hooks

    def _get_backward_pre_hooks(self):
        backward_pre_hooks: List[Callable] = []
        backward_pre_hooks: list[Callable] = []
        backward_pre_hooks += _global_backward_pre_hooks.values()
        backward_pre_hooks += self._backward_pre_hooks.values()

@@ -1580,10 +1568,10 @@ class Module:
    def register_forward_pre_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[[T, tuple[Any, ...]], Optional[Any]],
            Callable[
                [T, Tuple[Any, ...], Dict[str, Any]],
                Optional[Tuple[Any, Dict[str, Any]]],
                [T, tuple[Any, ...], dict[str, Any]],
                Optional[tuple[Any, dict[str, Any]]],
            ],
        ],
        *,

@@ -1646,8 +1634,8 @@ class Module:
    def register_forward_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
            Callable[[T, tuple[Any, ...], Any], Optional[Any]],
            Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]],
        ],
        *,
        prepend: bool = False,

@@ -2125,7 +2113,7 @@ class Module:

    # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
    # back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
    T_destination = TypeVar("T_destination", bound=Dict[str, Any])
    T_destination = TypeVar("T_destination", bound=dict[str, Any])

    @overload
    def state_dict(

@@ -2134,7 +2122,7 @@ class Module:
        ...

    @overload
    def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
    def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]:
        ...

    # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
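
Note: the `T_destination` hunk shows that TypeVar bounds accept builtin generics as well, so the `state_dict` contract is unchanged: pass in your own mapping and the same type comes back, pass nothing and a plain dict is returned. A hedged sketch of that overload idiom with hypothetical names:

    from typing import Any, Optional, TypeVar, Union

    T_dest = TypeVar("T_dest", bound=dict[str, Any])

    def collect(dest: Optional[T_dest] = None) -> Union[T_dest, dict[str, Any]]:
        # Mirrors Module.state_dict: reuse the caller's mapping if given,
        # otherwise create and return a fresh dict.
        if dest is None:
            return {"weight": 1.0}
        dest["weight"] = 1.0
        return dest
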
@@ -2514,9 +2502,9 @@ class Module:
                f"Expected state_dict to be dict-like, got {type(state_dict)}."
            )

        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []
        missing_keys: list[str] = []
        unexpected_keys: list[str] = []
        error_msgs: list[str] = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)

@@ -2632,7 +2620,7 @@ class Module:

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
    ) -> Iterator[tuple[str, Parameter]]:
        r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

        Args:

@@ -2687,7 +2675,7 @@ class Module:

    def named_buffers(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
    ) -> Iterator[tuple[str, Tensor]]:
        r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

        Args:

@@ -2725,7 +2713,7 @@ class Module:
        for _name, module in self.named_children():
            yield module

    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
    def named_children(self) -> Iterator[tuple[str, "Module"]]:
        r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

        Yields:

@@ -2774,7 +2762,7 @@ class Module:

    def named_modules(
        self,
        memo: Optional[Set["Module"]] = None,
        memo: Optional[set["Module"]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ):
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import numbers
from typing import List, Optional, Tuple, Union
from typing import Optional, Union

import torch
from torch import Size, Tensor

@@ -88,7 +88,7 @@ class CrossMapLRN2d(Module):
        return "{size}, alpha={alpha}, beta={beta}, k={k}".format(**self.__dict__)


_shape_t = Union[int, List[int], Size]
_shape_t = Union[int, list[int], Size]


class LayerNorm(Module):

@@ -170,7 +170,7 @@ class LayerNorm(Module):
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    normalized_shape: tuple[int, ...]
    eps: float
    elementwise_affine: bool

@@ -359,7 +359,7 @@ class RMSNorm(Module):

    """
    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    normalized_shape: tuple[int, ...]
    eps: Optional[float]
    elementwise_affine: bool
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Sequence, Tuple
from collections.abc import Sequence

import torch.nn.functional as F
from torch import Tensor

@@ -83,7 +83,7 @@ class CircularPad1d(_CircularPadNd):
            [5., 6., 7., 4., 5., 6., 7., 4.]]])
    """

    padding: Tuple[int, int]
    padding: tuple[int, int]

    def __init__(self, padding: _size_2_t) -> None:
        super().__init__()

@@ -142,7 +142,7 @@ class CircularPad2d(_CircularPadNd):
            [8., 6., 7., 8., 6.]]]])
    """

    padding: Tuple[int, int, int, int]
    padding: tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t) -> None:
        super().__init__()

@@ -191,7 +191,7 @@ class CircularPad3d(_CircularPadNd):
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int, int, int]
    padding: tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t) -> None:
        super().__init__()

@@ -262,7 +262,7 @@ class ConstantPad1d(_ConstantPadNd):
            [ 3.5000,  3.5000,  3.5000, -3.6372,  0.1182, -1.8652,  3.5000]]])
    """

    padding: Tuple[int, int]
    padding: tuple[int, int]

    def __init__(self, padding: _size_2_t, value: float):
        super().__init__(value)

@@ -313,7 +313,7 @@ class ConstantPad2d(_ConstantPadNd):
    """

    __constants__ = ["padding", "value"]
    padding: Tuple[int, int, int, int]
    padding: tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t, value: float) -> None:
        super().__init__(value)

@@ -353,7 +353,7 @@ class ConstantPad3d(_ConstantPadNd):
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int, int, int]
    padding: tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t, value: float) -> None:
        super().__init__(value)

@@ -405,7 +405,7 @@ class ReflectionPad1d(_ReflectionPadNd):
            [7., 6., 5., 4., 5., 6., 7., 6.]]])
    """

    padding: Tuple[int, int]
    padding: tuple[int, int]

    def __init__(self, padding: _size_2_t) -> None:
        super().__init__()

@@ -458,7 +458,7 @@ class ReflectionPad2d(_ReflectionPadNd):
            [7., 6., 7., 8., 7.]]]])
    """

    padding: Tuple[int, int, int, int]
    padding: tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t) -> None:
        super().__init__()

@@ -512,7 +512,7 @@ class ReflectionPad3d(_ReflectionPadNd):
            [1., 0., 1., 0.]]]]])
    """

    padding: Tuple[int, int, int, int, int, int]
    padding: tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t) -> None:
        super().__init__()

@@ -564,7 +564,7 @@ class ReplicationPad1d(_ReplicationPadNd):
            [4., 4., 4., 4., 5., 6., 7., 7.]]])
    """

    padding: Tuple[int, int]
    padding: tuple[int, int]

    def __init__(self, padding: _size_2_t) -> None:
        super().__init__()

@@ -616,7 +616,7 @@ class ReplicationPad2d(_ReplicationPadNd):
            [6., 6., 7., 8., 8.]]]])
    """

    padding: Tuple[int, int, int, int]
    padding: tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t) -> None:
        super().__init__()

@@ -657,7 +657,7 @@ class ReplicationPad3d(_ReplicationPadNd):
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int, int, int]
    padding: tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t) -> None:
        super().__init__()

@@ -708,7 +708,7 @@ class ZeroPad1d(ConstantPad1d):
            [ 0.0000,  0.0000,  0.0000, -3.6372,  0.1182, -1.8652,  0.0000]]])
    """

    padding: Tuple[int, int]
    padding: tuple[int, int]

    def __init__(self, padding: _size_2_t) -> None:
        super().__init__(padding, 0.0)

@@ -762,7 +762,7 @@ class ZeroPad2d(ConstantPad2d):
            [ 0.0000, -0.9162, -0.5436, -0.6446,  0.0000]]]])
    """

    padding: Tuple[int, int, int, int]
    padding: tuple[int, int, int, int]

    def __init__(self, padding: _size_4_t) -> None:
        super().__init__(padding, 0.0)

@@ -804,7 +804,7 @@ class ZeroPad3d(ConstantPad3d):
        >>> output = m(input)
    """

    padding: Tuple[int, int, int, int, int, int]
    padding: tuple[int, int, int, int, int, int]

    def __init__(self, padding: _size_6_t) -> None:
        super().__init__(padding, 0.0)
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional

import torch.nn.functional as F
from torch import Tensor

@@ -384,7 +384,7 @@ class MaxUnpool1d(_MaxUnpoolNd):
        self.padding = _single(padding)

    def forward(
        self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None
        self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
    ) -> Tensor:
        return F.max_unpool1d(
            input, indices, self.kernel_size, self.stride, self.padding, output_size

@@ -479,7 +479,7 @@ class MaxUnpool2d(_MaxUnpoolNd):
        self.padding = _pair(padding)

    def forward(
        self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None
        self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
    ) -> Tensor:
        return F.max_unpool2d(
            input, indices, self.kernel_size, self.stride, self.padding, output_size

@@ -557,7 +557,7 @@ class MaxUnpool3d(_MaxUnpoolNd):
        self.padding = _triple(padding)

    def forward(
        self, input: Tensor, indices: Tensor, output_size: Optional[List[int]] = None
        self, input: Tensor, indices: Tensor, output_size: Optional[list[int]] = None
    ) -> Tensor:
        return F.max_unpool3d(
            input, indices, self.kernel_size, self.stride, self.padding, output_size
@@ -4,7 +4,7 @@ import math
import numbers
import warnings
import weakref
from typing import List, Optional, overload, Tuple
from typing import Optional, overload
from typing_extensions import deprecated

import torch

@@ -106,7 +106,7 @@ class RNNBase(Module):
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.proj_size = proj_size
        self._flat_weight_refs: List[Optional[weakref.ReferenceType[Parameter]]] = []
        self._flat_weight_refs: list[Optional[weakref.ReferenceType[Parameter]]] = []
        num_directions = 2 if bidirectional else 1

        if (

@@ -172,7 +172,7 @@ class RNNBase(Module):
                # Second bias vector included for CuDNN compatibility. Only one
                # bias vector is needed in standard definition.
                b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
                layer_params: Tuple[Tensor, ...] = ()
                layer_params: tuple[Tensor, ...] = ()
                if self.proj_size == 0:
                    if bias:
                        layer_params = (w_ih, w_hh, b_ih, b_hh)

@@ -317,7 +317,7 @@ class RNNBase(Module):

    def get_expected_hidden_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
    ) -> tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:

@@ -340,7 +340,7 @@ class RNNBase(Module):
    def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: Tuple[int, int, int],
        expected_hidden_size: tuple[int, int, int],
        msg: str = "Expected hidden size {}, got {}",
    ) -> None:
        if hx.size() != expected_hidden_size:

@@ -451,7 +451,7 @@ class RNNBase(Module):
        ]

    @property
    def all_weights(self) -> List[List[Parameter]]:
    def all_weights(self) -> list[list[Parameter]]:
        return [
            [getattr(self, weight) for weight in weights]
            for weights in self._all_weights

@@ -639,14 +639,14 @@ class RNN(RNNBase):
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
    ) -> tuple[Tensor, Tensor]:
        pass

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:
    ) -> tuple[PackedSequence, Tensor]:
        pass

    def forward(self, input, hx=None):  # noqa: F811

@@ -978,7 +978,7 @@ class LSTM(RNNBase):

    def get_expected_cell_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
    ) -> tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:

@@ -996,7 +996,7 @@ class LSTM(RNNBase):
    def check_forward_args(
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],  # type: ignore[override]
        hidden: tuple[Tensor, Tensor],  # type: ignore[override]
        batch_sizes: Optional[Tensor],
    ):
        self.check_input(input, batch_sizes)

@@ -1014,9 +1014,9 @@ class LSTM(RNNBase):
    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    def permute_hidden(  # type: ignore[override]
        self,
        hx: Tuple[Tensor, Tensor],
        hx: tuple[Tensor, Tensor],
        permutation: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
    ) -> tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(

@@ -1027,16 +1027,16 @@ class LSTM(RNNBase):
    @overload  # type: ignore[override]
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # noqa: F811
        self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa: F811
        self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None
    ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811

@@ -1319,14 +1319,14 @@ class GRU(RNNBase):
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:  # noqa: F811
    ) -> tuple[Tensor, Tensor]:  # noqa: F811
        pass

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:  # noqa: F811
    ) -> tuple[PackedSequence, Tensor]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811

@@ -1679,8 +1679,8 @@ class LSTMCell(RNNCellBase):
        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None
    ) -> tuple[Tensor, Tensor]:
        if input.dim() not in (1, 2):
            raise ValueError(
                f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import collections
from itertools import repeat
from typing import Any, Dict, List
from typing import Any


__all__ = ["consume_prefix_in_state_dict_if_present"]

@@ -32,7 +32,7 @@ def _reverse_repeat_tuple(t, n):
    return tuple(x for x in reversed(t) for _ in range(n))


def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
    import torch

    if isinstance(out_size, (int, torch.SymInt)):

@@ -45,7 +45,7 @@ def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:


def consume_prefix_in_state_dict_if_present(
    state_dict: Dict[str, Any],
    state_dict: dict[str, Any],
    prefix: str,
) -> None:
    r"""Strip the prefix in state_dict in place, if any.
@@ -1,5 +1,5 @@
import warnings
from typing import List, Optional
from typing import Optional

import torch
from torch._utils import _get_device_index

@@ -116,7 +116,7 @@ class Scatter(Function):


# background streams used for copying
_streams: Optional[List[Optional[torch.Stream]]] = None
_streams: Optional[list[Optional[torch.Stream]]] = None


def _get_stream(device: torch.device):
@@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import warnings
from typing import List

import torch
from torch._utils import (

@@ -137,7 +136,7 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
    """
    # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
    # return `inputs`.
    dense_tensors: List[List] = [[] for _ in inputs]  # shape (num_gpus, num_tensors)
    dense_tensors: list[list] = [[] for _ in inputs]  # shape (num_gpus, num_tensors)
    output = []
    ref_order = []
    # process sparse ones first since they may have different sizes on different gpus
@@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import operator
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Any, Generic, Optional, TypeVar, Union

import torch
from torch._utils import (

@@ -195,20 +196,20 @@ class DataParallel(Module, Generic[T]):

    def replicate(
        self, module: T, device_ids: Sequence[Union[int, torch.device]]
    ) -> List[T]:
    ) -> list[T]:
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(
        self,
        inputs: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]],
        inputs: tuple[Any, ...],
        kwargs: Optional[dict[str, Any]],
        device_ids: Sequence[Union[int, torch.device]],
    ) -> Any:
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(
        self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
    ) -> List[Any]:
    ) -> list[Any]:
        return parallel_apply(
            replicas, inputs, kwargs, self.device_ids[: len(replicas)]
        )
@@ -12,7 +12,7 @@ from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING

import torch
import torch.distributed as dist

@@ -815,7 +815,7 @@ class DistributedDataParallel(Module, Joinable):

        # Initialize gradient buffers and register all reduce hook
        self._delay_grad_buffer: Optional[torch.Tensor] = None
        self._delay_grad_views: List[torch.Tensor] = []
        self._delay_grad_views: list[torch.Tensor] = []
        self._delay_all_reduce_all_params = False
        if len(self._delay_all_reduce_params) != 0:
            self._register_delay_all_reduce_hook(

@@ -853,7 +853,7 @@ class DistributedDataParallel(Module, Joinable):
            param_to_name_mapping,
            static_graph,
        )
        self._comm_hooks: List[Tuple[Callable, object]] = []
        self._comm_hooks: list[tuple[Callable, object]] = []

        if self.mixed_precision is not None:
            _setup_mixed_precision_params(self.mixed_precision, self.module)

@@ -907,7 +907,7 @@ class DistributedDataParallel(Module, Joinable):
        # Register the AccumulateGrad post hooks if optimize_ddp is
        # True. The hooks will be deregistered if compiled_autograd is not
        # enabled.
        self._accum_grad_hooks: List[RemovableHandle] = []
        self._accum_grad_hooks: list[RemovableHandle] = []
        optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
        self._use_python_reducer = optimize_ddp in (
            "python_reducer",

@@ -1614,7 +1614,7 @@ class DistributedDataParallel(Module, Joinable):
                treespec,
                output_is_rref,
            ) = _tree_flatten_with_rref(output)
            output_placeholders: List[Optional[torch.Tensor]] = [
            output_placeholders: list[Optional[torch.Tensor]] = [
                None for _ in range(len(output_tensor_list))
            ]
            # Do not touch tensors that have no grad_fn, which can cause issues

@@ -2046,7 +2046,7 @@ class DistributedDataParallel(Module, Joinable):
        self.logger._set_comm_hook_name(str(comm_hook_type))
        dist._register_builtin_comm_hook(self.reducer, comm_hook_type)

    def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
    def _register_fused_optim(self, optim: type, *args, optim_params=None, **kwargs):
        r"""
        Register an optimizer in DDP to optimize parameter immediately after its gradient reduction.
@@ -1,5 +1,6 @@
import threading
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, cast, Optional, Union

import torch
from torch._utils import ExceptionWrapper

@@ -11,7 +12,7 @@ __all__ = ["get_a_var", "parallel_apply"]


def get_a_var(
    obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]],
    obj: Union[torch.Tensor, list[Any], tuple[Any, ...], dict[Any, Any]],
) -> Optional[torch.Tensor]:
    if isinstance(obj, torch.Tensor):
        return obj

@@ -30,9 +31,9 @@ def get_a_var(
def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    kwargs_tup: Optional[Sequence[dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
) -> list[Any]:
    r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.

    Args:

@@ -51,7 +52,7 @@ def parallel_apply(
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
        kwargs_tup = (cast(dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:

@@ -69,7 +70,7 @@ def parallel_apply(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        kwargs: dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
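
Note: `typing.cast` takes the target type as an ordinary value, so the builtin generics drop straight into calls like the `kwargs_tup` default above. A minimal sketch of the idiom in isolation (hypothetical variable names, not from the PR):

    from typing import Any, cast

    # cast() only informs the type checker; at runtime the {} is returned unchanged.
    empty_kwargs = cast(dict[str, Any], {})
    kwargs_tup = (empty_kwargs,) * 3
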
@@ -1,16 +1,6 @@
from collections import OrderedDict
from typing import (
    cast,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Set,
    TYPE_CHECKING,
    TypeVar,
    Union,
)
from collections.abc import Iterator, Sequence
from typing import cast, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import TypeIs

import torch

@@ -59,7 +49,7 @@ def _is_jit_enabled() -> "EnabledProxy":
#
# currently a module cannot be replicated properly if the descendants of
# any ScriptModule contains python module (type 1 above)
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
def _replicatable_module(module: Module, memo: Optional[set[Module]] = None) -> bool:
    # module.modules() contains module itself as the first element
    def descendant_modules(module: Module) -> Iterator[Module]:
        gen = module.modules()

@@ -94,7 +84,7 @@ def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
) -> list[list[torch.Tensor]]:
    from torch.nn.parallel._functions import Broadcast

    if detach:

@@ -118,7 +108,7 @@ def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
) -> list[T]:
    if not _replicatable_module(network):
        raise RuntimeError(
            "Cannot replicate network where python modules are "

@@ -136,8 +126,8 @@ def replicate(
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg: List[torch.Tensor] = []
    buffers_not_rg: List[torch.Tensor] = []
    buffers_rg: list[torch.Tensor] = []
    buffers_not_rg: list[torch.Tensor] = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)

@@ -153,8 +143,8 @@ def replicate(
    )

    modules = list(network.modules())
    module_copies: List[List[Module]] = [[] for _ in devices]
    module_indices: Dict[Module, int] = {}
    module_copies: list[list[Module]] = [[] for _ in devices]
    module_indices: dict[Module, int] = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, overload, Sequence, Tuple, TypeVar, Union
from collections.abc import Sequence
from typing import Any, Optional, overload, TypeVar, Union
from typing_extensions import deprecated

import torch

@@ -34,7 +35,7 @@ def scatter(
    inputs: torch.Tensor,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> Tuple[torch.Tensor, ...]:
) -> tuple[torch.Tensor, ...]:
    ...


@@ -43,7 +44,7 @@ def scatter(
    inputs: T,
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = ...,
) -> List[T]:
) -> list[T]:
    ...


@@ -79,11 +80,11 @@ def scatter(inputs, target_gpus, dim=0):


def scatter_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    inputs: tuple[Any, ...],
    kwargs: Optional[dict[str, Any]],
    target_gpus: Sequence[Union[int, torch.device]],
    dim: int = 0,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]:
    r"""Scatter with support for kwargs dictionary."""
    scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
@@ -1,6 +1,6 @@
import importlib
import warnings
from typing import Callable, List
from typing import Callable


_MESSAGE_TEMPLATE = (

@@ -9,7 +9,7 @@ _MESSAGE_TEMPLATE = (


def lazy_deprecated_import(
    all: List[str],
    all: list[str],
    old_module: str,
    new_module: str,
) -> Callable:
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional

import numpy as np

@@ -68,7 +68,7 @@ def int_padding_for_string_padding(func, padding_style, dilation, kernel_size):
        return dilation[i] if isinstance(dilation, tuple) else dilation

    if padding_style == "same":
        padding: List[int] = []
        padding: list[int] = []
        # F.pad needs the padding in reverse order from what conv expects
        for i in range(conv_picker(func, 0, 1, 2), -1, -1):
            padding += conv_padding_for_same(get_dilation(i), kernel_size[i])

@@ -145,7 +145,7 @@ def conv_backward(func, ctx, grad_output):
    kernel_size = [weight_shape[i] for i in range(2, conv_picker(func, 3, 4, 5))]

    batch_size = ctx.batch_size
    results: List[Optional[torch.Tensor]] = []
    results: list[Optional[torch.Tensor]] = []
    results.append(None)  # for kwarg names
    results.append(None)  # for op reference
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional

import torch
import torch.nn.functional as F

@@ -56,7 +56,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
            1, index, grad_output.reshape(batch_size, -1, embedding_dim)
        )

        results: List[Optional[torch.Tensor]] = []
        results: list[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference
@@ -1,14 +1,14 @@
# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from typing import Callable, Dict
from typing import Callable

import torch
from torch._decomp import decomposition_table
from torch.utils._pytree import tree_map_only


HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}
HANDLED_FUNCTIONS: dict[Callable, torch.autograd.Function] = {}

aten = torch._ops.ops.aten
# __torch_function__ runs before the pydispatcher so we need to manually use the same
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing import List, Optional
from typing import Optional

import torch
import torch.nn.functional as F

@@ -51,7 +51,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
        weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
        mean, rstd = ctx.mean, ctx.rstd

        results: List[Optional[torch.Tensor]] = []
        results: list[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from functools import partial
from typing import List, Optional
from typing import Optional

import torch
import torch.nn.functional as F

@@ -40,7 +40,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function):
        input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
        weight, bias, eps = ctx.weight, ctx.bias, ctx.eps

        results: List[Optional[torch.Tensor]] = []
        results: list[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference
        if input.requires_grad:
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional

import torch
import torch.nn.functional as F

@@ -52,7 +52,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function):
        input, normalized_shape = ctx.args
        mean, rstd = ctx.mean, ctx.rstd

        results: List[Optional[torch.Tensor]] = []
        results: list[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg names
        results.append(None)  # for op reference
        if input.requires_grad:
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional

import torch
import torch.nn.functional as F

@@ -38,7 +38,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
    def backward(ctx, grad_output):
        input, weight = ctx.args
        bias = ctx.kwargs["bias"]
        results: List[Optional[torch.Tensor]] = []
        results: list[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg_names
        results.append(None)  # for op reference
@@ -1,7 +1,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Iterable, List, Tuple
from collections.abc import Iterable

import torch

@@ -113,7 +113,7 @@ class NamedMemberAccessor:

    def __init__(self, module: "torch.nn.Module") -> None:
        self.module = module
        self.memo: Dict[str, torch.nn.Module] = {}
        self.memo: dict[str, torch.nn.Module] = {}

    # Nested attribute access

@@ -225,7 +225,7 @@ class NamedMemberAccessor:

    # Batched operations

    def get_tensors(self, names: Iterable[str]) -> List[torch.Tensor]:
    def get_tensors(self, names: Iterable[str]) -> list[torch.Tensor]:
        """
        Get the tensors specified by the given paths.

@@ -252,7 +252,7 @@ class NamedMemberAccessor:
        for name, value in zip(names, values):
            self.set_tensor(name, value)

    def set_tensors_dict(self, named_tensors: Dict[str, torch.Tensor]) -> None:
    def set_tensors_dict(self, named_tensors: dict[str, torch.Tensor]) -> None:
        """
        Set the attributes specified by the given paths to values.

@@ -281,7 +281,7 @@ class NamedMemberAccessor:
        names: Iterable[str],
        values: Iterable[torch.Tensor],
        allow_missing: bool = False,
    ) -> List[torch.Tensor]:
    ) -> list[torch.Tensor]:
        """
        Swap the attributes specified by the given paths to values.

@@ -301,8 +301,8 @@ class NamedMemberAccessor:
        ]

    def swap_tensors_dict(
        self, named_tensors: Dict[str, torch.Tensor], allow_missing: bool = False
    ) -> Tuple[Dict[str, torch.Tensor], List[str]]:
        self, named_tensors: dict[str, torch.Tensor], allow_missing: bool = False
    ) -> tuple[dict[str, torch.Tensor], list[str]]:
        """
        Swap the attributes specified by the given paths to values.

@@ -332,7 +332,7 @@ class NamedMemberAccessor:
            raise RuntimeError(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
        return orig_named_tensors, missing_keys

    def check_keys(self, keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    def check_keys(self, keys: Iterable[str]) -> tuple[list[str], list[str]]:
        """Check that the given keys are valid."""
        keys = set(keys)
        valid_keys = {name for name, _ in self.named_tensors(remove_duplicate=False)}

@@ -345,21 +345,21 @@ class NamedMemberAccessor:
    def named_parameters(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Iterate over all the parameters in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)

    def named_buffers(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Iterate over all the buffers in the module."""
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

    def named_tensors(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, torch.Tensor]]:
    ) -> Iterable[tuple[str, torch.Tensor]]:
        """Iterate over all the tensors in the module."""
        yield from self.module.named_parameters(remove_duplicate=remove_duplicate)
        yield from self.module.named_buffers(remove_duplicate=remove_duplicate)

@@ -367,6 +367,6 @@ class NamedMemberAccessor:
    def named_modules(
        self,
        remove_duplicate: bool = True,
    ) -> Iterable[Tuple[str, "torch.nn.Module"]]:
    ) -> Iterable[tuple[str, "torch.nn.Module"]]:
        """Iterate over all the modules in the module."""
        yield from self.module.named_modules(remove_duplicate=remove_duplicate)
@@ -1,4 +1,5 @@
from typing import Iterable, Optional
from collections.abc import Iterable
from typing import Optional

import torch
@@ -1,7 +1,7 @@
from __future__ import annotations

import copy
from typing import Optional, Tuple, TypeVar
from typing import TypeVar

import torch

@@ -55,14 +55,14 @@ def fuse_conv_bn_eval(

def fuse_conv_bn_weights(
    conv_w: torch.Tensor,
    conv_b: Optional[torch.Tensor],
    conv_b: torch.Tensor | None,
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: Optional[torch.Tensor],
    bn_b: Optional[torch.Tensor],
    bn_w: torch.Tensor | None,
    bn_b: torch.Tensor | None,
    transpose: bool = False,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.

    Args:

@@ -155,13 +155,13 @@ def fuse_linear_bn_eval(

def fuse_linear_bn_weights(
    linear_w: torch.Tensor,
    linear_b: Optional[torch.Tensor],
    linear_b: torch.Tensor | None,
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor,
    bn_b: torch.Tensor,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.

    Args:
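
Note: this file already has `from __future__ import annotations`, so its Optional[...] parameters can go one step further and use PEP 604 unions (`X | None`); with postponed evaluation the annotation is never executed, which makes the syntax safe even on Python 3.9. A minimal sketch under that assumption (hypothetical function, not from the PR):

    from __future__ import annotations  # annotations stay as unevaluated strings

    import torch

    def scale(t: torch.Tensor, eps: float | None = None) -> torch.Tensor:
        # `float | None` is fine here on 3.9 because of the future import.
        return t / (t.norm() + (1e-6 if eps is None else eps))
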
@@ -2,9 +2,10 @@
# mypy: allow-untyped-defs
import collections
import copyreg
from collections.abc import Sequence
from contextlib import contextmanager
from copy import deepcopy
from typing import Dict, Optional, Sequence, Tuple, Union
from typing import Optional, Union

import torch
from torch import Tensor

@@ -25,7 +26,7 @@ __all__ = [
]

_cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}
_cache: dict[tuple[int, str], Optional[Tensor]] = {}


@contextmanager

@@ -165,9 +166,7 @@ class ParametrizationList(ModuleList):
                pass
            # else, or if it throws, we assume that right_inverse is the identity

        if not isinstance(new, Tensor) and not isinstance(
            new, collections.abc.Sequence
        ):
        if not isinstance(new, Tensor) and not isinstance(new, Sequence):
            raise ValueError(
                "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
                f"Got {type(new).__name__}"
@@ -3,7 +3,6 @@ r"""Pruning methods."""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Tuple

import torch

@@ -270,7 +269,7 @@ class PruningContainer(BasePruningMethod):
    """

    def __init__(self, *args):
        self._pruning_methods: Tuple[BasePruningMethod, ...] = ()
        self._pruning_methods: tuple[BasePruningMethod, ...] = ()
        if not isinstance(args, Iterable):  # only 1 item
            self._tensor_name = args._tensor_name
            self.add_pruning_method(args)
@@ -1,16 +1,6 @@
import warnings
from collections.abc import Iterable
from typing import (
    Any,
    Callable,
    List,
    NamedTuple,
    Optional,
    overload,
    Tuple,
    TypeVar,
    Union,
)
from typing import Any, Callable, NamedTuple, Optional, overload, TypeVar, Union
from typing_extensions import Self

import torch

@@ -228,7 +218,7 @@ def _packed_sequence_init_args(
    batch_sizes: Optional[Tensor] = None,
    sorted_indices: Optional[Tensor] = None,
    unsorted_indices: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    # NB: if unsorted_indices is provided, it should be the inverse permutation
    # to sorted_indices. Don't assert it here because the PackedSequence ctor
    # should only be used internally.

@@ -279,7 +269,7 @@ def invert_permutation(permutation: Optional[Tensor]) -> Optional[Tensor]:

def pack_padded_sequence(
    input: Tensor,
    lengths: Union[Tensor, List[int]],
    lengths: Union[Tensor, list[int]],
    batch_first: bool = False,
    enforce_sorted: bool = True,
) -> PackedSequence:

@@ -347,7 +337,7 @@ def pad_packed_sequence(
    batch_first: bool = False,
    padding_value: float = 0.0,
    total_length: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
    r"""Pad a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

@@ -419,7 +409,7 @@ def pad_packed_sequence(

# NOTE: for JIT-compatibility, we need to be more restrictive here and use specific types instead of Iterable.
def pad_sequence(
    sequences: Union[Tensor, List[Tensor]],
    sequences: Union[Tensor, list[Tensor]],
    batch_first: bool = False,
    padding_value: float = 0.0,
    padding_side: str = "right",

@@ -487,7 +477,7 @@ def unpad_sequence(
    padded_sequences: Tensor,
    lengths: Tensor,
    batch_first: bool = False,
) -> List[Tensor]:
) -> list[Tensor]:
    r"""Unpad padded Tensor into a list of variable length Tensors.

    ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors.

@@ -533,7 +523,7 @@ def unpad_sequence(


def pack_sequence(
    sequences: List[Tensor],
    sequences: list[Tensor],
    enforce_sorted: bool = True,
) -> PackedSequence:
    r"""Packs a list of variable length Tensors.

@@ -571,7 +561,7 @@ def pack_sequence(
    )


def unpack_sequence(packed_sequences: PackedSequence) -> List[Tensor]:
def unpack_sequence(packed_sequences: PackedSequence) -> list[Tensor]:
    r"""Unpack PackedSequence into a list of variable length Tensors.

    ``packed_sequences`` should be a PackedSequence object.
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Any, Dict, Optional, Set, Tuple, Union
from typing import Any, Optional, Union
from typing_extensions import deprecated

import torch

@@ -13,8 +13,8 @@ __all__ = ["functional_call"]

def _untie_named_tensors_map(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
) -> Dict[str, Tensor]:
    parameters_and_buffers: dict[str, Tensor],
) -> dict[str, Tensor]:
    """
    Unties all tied tensors in the module to parameters_and_buffers.

@@ -41,12 +41,12 @@ def _untie_named_tensors_map(
        ValueError: if there are more than one user-given values for the same tied tensor.
    """
    # A map of {name: tensor} for all tensors (including tied ones) in the module.
    all_named_tensors: Dict[str, Tensor] = {}
    all_named_tensors: dict[str, Tensor] = {}
    all_named_tensors.update(module.named_parameters(remove_duplicate=False))
    all_named_tensors.update(module.named_buffers(remove_duplicate=False))

    # A map of {tensor: set(all_tied_names)} for all tensor names in the module.
    tensor_to_tied_names_map: Dict[Tensor, Set[str]] = {}
    tensor_to_tied_names_map: dict[Tensor, set[str]] = {}
    for name, tensor in all_named_tensors.items():
        if tensor not in tensor_to_tied_names_map:
            tensor_to_tied_names_map[tensor] = set()

@@ -54,7 +54,7 @@ def _untie_named_tensors_map(

    # A map of {tied_name: set(all_tied_names)} for all tensor names in the module.
    # If a name is not tied, it will not be in this map.
    tied_names_map: Dict[str, Set[str]] = {}
    tied_names_map: dict[str, set[str]] = {}
    for tied_names in tensor_to_tied_names_map.values():
        if len(tied_names) > 1:
            for tied_name in tied_names:

@@ -98,7 +98,7 @@ def _untie_named_tensors_map(
@contextlib.contextmanager
def _reparametrize_module(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    parameters_and_buffers: dict[str, Tensor],
    tie_weights: bool = False,
    strict: bool = False,
    stack_weights: bool = False,

@@ -132,7 +132,7 @@ def _reparametrize_module(
            )
        )

    orig_parameters_and_buffers: Dict[str, Tensor] = {}
    orig_parameters_and_buffers: dict[str, Tensor] = {}
    try:
        orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
            untied_parameters_and_buffers, allow_missing=True

@@ -167,9 +167,9 @@ def _reparametrize_module(
)
def functional_call(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    args: Optional[Union[Any, Tuple]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    parameters_and_buffers: dict[str, Tensor],
    args: Optional[Union[Any, tuple]] = None,
    kwargs: Optional[dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,

@@ -245,9 +245,9 @@ def functional_call(

def _functional_call(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    args: Optional[Union[Any, Tuple]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    parameters_and_buffers: dict[str, Tensor],
    args: Optional[Union[Any, tuple]] = None,
    kwargs: Optional[dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,