PEP585 update - test (#145176)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176
Approved by: https://github.com/bobrenjc93
Aaron Orenstein 2025-01-21 11:22:20 -08:00 committed by PyTorch MergeBot
parent 40e27fbcf2
commit 99dbc5b0e2
146 changed files with 801 additions and 1099 deletions
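The diffs that follow apply PEP 585 mechanically: annotations that used the typing-module aliases (List, Tuple, Dict, Set, Type) are rewritten to the builtin generic types (list, tuple, dict, set, type), which support subscription since Python 3.9, and the now-unused names are dropped from the typing imports. A minimal sketch of the pattern, using a hypothetical function that is not part of this commit:

# Hypothetical example illustrating the PEP 585 rewrite; not taken from the diffs below.

# Before: typing aliases, deprecated since Python 3.9 by PEP 585
from typing import Dict, List, Optional, Tuple

def summarize_old(values: List[float], bounds: Optional[Tuple[float, float]] = None) -> Dict[str, float]:
    lo, hi = bounds if bounds is not None else (min(values), max(values))
    return {"lo": lo, "hi": hi, "mean": sum(values) / len(values)}

# After: builtin generics; only the names still needed are imported from typing
from typing import Optional

def summarize_new(values: list[float], bounds: Optional[tuple[float, float]] = None) -> dict[str, float]:
    lo, hi = bounds if bounds is not None else (min(values), max(values))
    return {"lo": lo, "hi": hi, "mean": sum(values) / len(values)}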

View File

@ -7,7 +7,7 @@ import re
import textwrap
import timeit
import unittest
from typing import Any, List, Tuple
from typing import Any
import expecttest
import numpy as np
@ -67,7 +67,7 @@ def generate_callgrind_artifacts() -> None:
def load_callgrind_artifacts() -> (
Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
):
"""Hermetic artifact to unit test Callgrind wrapper.
@ -85,9 +85,9 @@ def load_callgrind_artifacts() -> (
pattern = re.compile(r"^\s*([0-9]+)\s(.+)$")
def to_function_counts(
count_strings: List[str], inclusive: bool
count_strings: list[str], inclusive: bool
) -> benchmark_utils.FunctionCounts:
data: List[benchmark_utils.FunctionCount] = []
data: list[benchmark_utils.FunctionCount] = []
for cs in count_strings:
# Storing entries as f"{c} {fn}" rather than [c, fn] adds some work
# reviving the artifact, but it makes the json much easier to read.

View File

@ -7,7 +7,7 @@ import sys
import xml.etree.ElementTree as ET
from collections import defaultdict
from types import MethodType
from typing import Any, List, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import pytest
from _pytest.config import Config, filename_arg
@ -241,7 +241,7 @@ def pytest_report_teststatus(report, config):
@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(items: List[Any]) -> None:
def pytest_collection_modifyitems(items: list[Any]) -> None:
"""
This hook is used when rerunning disabled tests to get rid of all skipped tests
instead of running and skipping them N times. This avoids flooding the console
@ -304,7 +304,7 @@ class StepcurrentPlugin:
self.skip: bool = config.getoption("stepcurrent_skip")
self.run_single: bool = config.getoption("run_single")
def pytest_collection_modifyitems(self, config: Config, items: List[Any]) -> None:
def pytest_collection_modifyitems(self, config: Config, items: list[Any]) -> None:
if not self.lastrun:
self.report_status = "Cannot find last run test, not skipping"
return

View File

@ -1,5 +1,3 @@
from typing import List
import torch
from torch import Tensor
@ -8,7 +6,7 @@ lib = torch.library._scoped_library("python_agnostic", "FRAGMENT")
lib.define("ultra_norm(Tensor[] inputs) -> Tensor")
def ultra_norm(inputs: List[Tensor]) -> Tensor:
def ultra_norm(inputs: list[Tensor]) -> Tensor:
"""
Computes the ultra-L2-norm of a list of tensors via computing the norm of norms.

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import typing
from typing import List, Optional, Union
from typing import Optional, Union
import torch
from torch import Tensor, types
@ -87,7 +87,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> SymInt")
def foo_op_3(x: typing.List[int]) -> int:
def foo_op_3(x: list[int]) -> int:
return 1
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
@ -99,7 +99,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
def foo_op_5(x: typing.Optional[typing.List[int]]) -> int:
def foo_op_5(x: typing.Optional[list[int]]) -> int:
return 1
result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args)
@ -136,7 +136,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
self.assertEqual(result, "(Tensor x) -> Tensor")
def foo_op_4(x: List[int]) -> types.Number:
def foo_op_4(x: list[int]) -> types.Number:
return x[0]
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
@ -154,7 +154,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> SymInt")
def foo_op_7(x: List[int]) -> int:
def foo_op_7(x: list[int]) -> int:
return 1
result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args)
@ -166,7 +166,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_8, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
def foo_op_9(x: Optional[List[int]]) -> int:
def foo_op_9(x: Optional[list[int]]) -> int:
return 1
result = torch.library.infer_schema(foo_op_9, mutates_args=mutates_args)

View File

@ -5,7 +5,7 @@ import copy
import functools
import itertools
import unittest
from typing import Any, List, Optional, Type, Union
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
@ -117,7 +117,7 @@ class TestFullyShardAutograd(FSDPTest):
local_inp = global_inp[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
].detach()
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, inp in ((ref_model, global_inp), (model, local_inp)):
losses.append(_model(inp).sum())
losses[-1].backward()
@ -141,7 +141,7 @@ class TestFullyShardAutograd(FSDPTest):
self._test_nontensor_activations,
)
def _test_nontensor_activations(self, container_type: Type):
def _test_nontensor_activations(self, container_type: type):
class Module(nn.Module):
def __init__(self, dim: int):
super().__init__()
@ -170,7 +170,7 @@ class TestFullyShardAutograd(FSDPTest):
return self.relu(self.lin2(self.relu(self.lin1(x))))
class ToContainerType(nn.Module):
def __init__(self, container_type: Type):
def __init__(self, container_type: type):
super().__init__()
self.container_type = container_type
@ -190,7 +190,7 @@ class TestFullyShardAutograd(FSDPTest):
)
class FromContainerType(nn.Module):
def __init__(self, container_type: Type):
def __init__(self, container_type: type):
super().__init__()
self.container_type = container_type
@ -227,7 +227,7 @@ class TestFullyShardAutograd(FSDPTest):
local_inp = global_inp[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
].detach()
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, inp in ((ref_model, global_inp), (model, local_inp)):
losses.append(_model(inp).sum())
losses[-1].backward()

View File

@ -4,7 +4,7 @@ import copy
import functools
import itertools
import unittest
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch
import torch.distributed as dist
@ -58,7 +58,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
c10d_ops = torch.ops.c10d
# For recording FSDP events like unshard or post-backward
EventType = Tuple[str, str, TrainingState]
EventType = tuple[str, str, TrainingState]
class TestFullyShardCollectiveOps(FSDPTestMultiThread):
@ -70,7 +70,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
def device(self) -> torch.device:
return torch.device("cuda:0")
def _get_param_sizes(self) -> List[torch.Size]:
def _get_param_sizes(self) -> list[torch.Size]:
# For world size 128, the fp32 all-gather and reduce-scatter testing
# requires ~0.22 GB
return [
@ -84,7 +84,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
torch.Size([64, 297]),
]
def _init_params(self, param_sizes: List[torch.Size]) -> List[nn.Parameter]:
def _init_params(self, param_sizes: list[torch.Size]) -> list[nn.Parameter]:
torch.manual_seed(42)
orig_params = [
nn.Parameter(torch.randn(size, device=self.device)) for size in param_sizes
@ -96,7 +96,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
return orig_params
def _init_fsdp_param_group(
self, params: List[nn.Parameter], reshard_after_forward: Union[bool, int]
self, params: list[nn.Parameter], reshard_after_forward: Union[bool, int]
):
module = nn.ParameterList([param.detach().clone() for param in params])
mesh_info = FSDPMeshInfo(_init_default_fully_shard_mesh(), shard_mesh_dim=0)
@ -143,7 +143,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
def _test_all_gather(
self,
param_sizes: List[torch.Size],
param_sizes: list[torch.Size],
reshard_after_forward: Union[bool, int],
async_op: bool,
all_gather_copy_in_stream: torch.cuda.Stream,
@ -165,7 +165,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
fsdp_param_group._to_unsharded()
def check_all_gathered_params(
orig_params: List[nn.Parameter], module: nn.Module
orig_params: list[nn.Parameter], module: nn.Module
):
for orig_param, param in zip(orig_params, module.parameters()):
self.assertIsInstance(param, torch.Tensor)
@ -228,7 +228,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
def _test_reduce_scatter(
self,
param_sizes: List[torch.Size],
param_sizes: list[torch.Size],
reduce_scatter_stream: torch.cuda.Stream,
reduce_scatter_dtype: torch.dtype,
):
@ -453,7 +453,7 @@ class TestFullyShardPrefetch(FSDPTest):
model, optim, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)
events: List[EventType] = []
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@ -504,7 +504,7 @@ class TestFullyShardPrefetch(FSDPTest):
model, _, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)
events: List[EventType] = []
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@ -582,7 +582,7 @@ class TestFullyShardPrefetch(FSDPTest):
fully_shard(model[1].lin2, reshard_after_forward=reshard_after_forward)
fully_shard(model, reshard_after_forward=reshard_after_forward)
inp = torch.randn((4, dim), device="cuda")
events: List[EventType] = []
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@ -652,7 +652,7 @@ class TestFullyShardPrefetch(FSDPTest):
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
events: List[EventType] = []
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@ -742,7 +742,7 @@ class TestFullyShardPrefetch(FSDPTest):
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
events: List[EventType] = []
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@ -834,7 +834,7 @@ class TestFullyShardPrefetch(FSDPTest):
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
events: List[EventType] = []
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@ -915,7 +915,7 @@ class TestFullyShardPrefetch(FSDPTest):
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
events: List[EventType] = []
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@ -1011,7 +1011,7 @@ class TestFullyShardPrefetch(FSDPTest):
return model, optim, inp
def _get_unshard_with_record(
self, orig_unshard: Callable, events: List[EventType]
self, orig_unshard: Callable, events: list[EventType]
) -> Callable:
def unshard_with_record(self, *args, **kwargs):
nonlocal events
@ -1025,7 +1025,7 @@ class TestFullyShardPrefetch(FSDPTest):
return unshard_with_record
def _get_reshard_with_record(
self, orig_reshard: Callable, events: List[EventType]
self, orig_reshard: Callable, events: list[EventType]
) -> Callable:
def reshard_with_record(self, *args, **kwargs):
nonlocal events
@ -1040,7 +1040,7 @@ class TestFullyShardPrefetch(FSDPTest):
return reshard_with_record
def _get_post_backward_with_record(
self, orig_post_backward: Callable, events: List[EventType]
self, orig_post_backward: Callable, events: list[EventType]
) -> Callable:
def post_backward_with_record(self, *args, **kwargs):
nonlocal events
@ -1080,7 +1080,7 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
self.mlp2 = MLP(dim)
self.mlp3 = MLP(dim)
def forward(self, ys: List[torch.Tensor], works: List[dist.Work]):
def forward(self, ys: list[torch.Tensor], works: list[dist.Work]):
(y1, y2, y3), (work1, work2, work3) = ys, works
work1.wait()
z1 = self.mlp1(y1)
@ -1126,7 +1126,7 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((batch_size, dim), device="cuda")
for _ in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
losses.append(_model(inp).sum())
losses[-1].backward()

View File

@ -6,7 +6,7 @@ import functools
import math
import threading
import unittest
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
@ -29,7 +29,7 @@ from torch.testing._internal.two_tensor import TwoTensor
def two_tensor_fsdp_pre_all_gather_v1(
self, mesh: DeviceMesh
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
) -> tuple[tuple[torch.Tensor, ...], Any]:
all_gather_inputs = (self.a, self.b)
metadata = None
return all_gather_inputs, metadata
@ -39,10 +39,10 @@ def two_tensor_fsdp_pre_all_gather_v2(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
outer_stride: Tuple[int, ...],
outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
) -> tuple[tuple[torch.Tensor, ...], Any]:
all_gather_inputs = (self.a, self.b)
metadata = None
return all_gather_inputs, metadata
@ -50,12 +50,12 @@ def two_tensor_fsdp_pre_all_gather_v2(
def two_tensor_fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
assert metadata is None, f"{metadata}"
a, b = all_gather_outputs
if out is not None:
@ -96,10 +96,10 @@ class BFloat16AllGatherTensor(torch.Tensor):
self,
mesh: DeviceMesh,
outer_size: torch.Size,
outer_stride: Tuple[int, ...],
outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
) -> tuple[tuple[torch.Tensor, ...], Any]:
assert mesh.ndim == 1, f"{mesh.ndim}"
mesh_size = mesh.size()
requires_padding = outer_size[0] % mesh_size != 0
@ -116,12 +116,12 @@ class BFloat16AllGatherTensor(torch.Tensor):
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
assert metadata is None, f"{metadata}"
(tensor,) = all_gather_outputs
assert tensor.dtype == torch.bfloat16, f"{tensor.dtype}"
@ -157,7 +157,7 @@ class BFloat16AllGatherTensor(torch.Tensor):
@staticmethod
def __tensor_unflatten__(
inner_tensors, outer_size: torch.Size, outer_stride: Tuple[int, ...]
inner_tensors, outer_size: torch.Size, outer_stride: tuple[int, ...]
):
return inner_tensors["_data"]
@ -236,7 +236,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, 8), device="cuda")
for iter_idx in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
losses.append(_model(inp).sum())
losses[-1].backward()
@ -314,10 +314,10 @@ class TestFullyShardAllGatherExtensionsMultiThread(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
outer_stride: Tuple[int, ...],
outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
) -> tuple[tuple[torch.Tensor, ...], Any]:
nonlocal tls
tls.ran_pre_all_gather = True
return (self.to(torch.bfloat16),), None
@ -325,12 +325,12 @@ class TestFullyShardAllGatherExtensionsMultiThread(
@torch.no_grad()
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
(tensor,) = all_gather_outputs
assert metadata is None, f"{metadata}"
assert tensor.dtype == torch.bfloat16, f"{tensor.dtype}"
@ -416,10 +416,10 @@ class TestFullyShardAllGatherExtensionsMultiThread(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
outer_stride: Tuple[int, ...],
outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
) -> tuple[tuple[torch.Tensor, ...], Any]:
nonlocal tls
tls.mesh = mesh
return (self,), None
@ -427,12 +427,12 @@ class TestFullyShardAllGatherExtensionsMultiThread(
@torch.no_grad()
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
(tensor,) = all_gather_outputs
if out is not None:
return

View File

@ -3,7 +3,7 @@
import copy
import functools
import itertools
from typing import List, Union
from typing import Union
import torch
import torch.distributed as dist
@ -116,7 +116,7 @@ class TestFullyShardFrozen(FSDPTest):
), patch_register_post_backward_hook_backward(backward_with_count):
for iter_idx in range(10):
inp = torch.randn((8, lin_dim), device=device)
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@ -151,7 +151,7 @@ class TestFullyShardFrozen(FSDPTest):
):
torch.manual_seed(42)
num_linears, lin_dim = (6, 32)
modules: List[nn.Module] = []
modules: list[nn.Module] = []
for _ in range(num_linears):
modules += [nn.Linear(lin_dim, lin_dim), nn.ReLU()]
model = nn.Sequential(*modules)
@ -187,7 +187,7 @@ class TestFullyShardFrozen(FSDPTest):
inp = torch.randn((8, lin_dim), device="cuda")
with patch_register_post_backward_hook_backward(backward_with_count):
for iter_idx in range(num_iters):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
# Unfreeze the parameters on the last step to emulate some
# kinds of fine-tuning
@ -251,7 +251,7 @@ class TestFullyShardFrozen(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for iter_idx in range(10):
inp = torch.randn((8, 5), device="cuda")
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())

View File

@ -3,7 +3,7 @@
import copy
import itertools
import unittest
from typing import List, Optional
from typing import Optional
import torch
import torch.distributed as dist
@ -211,8 +211,8 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
def _check_managed_modules(
self,
managed_modules: List[nn.Module],
expected_managed_modules: List[nn.Module],
managed_modules: list[nn.Module],
expected_managed_modules: list[nn.Module],
):
self.assertEqual(len(managed_modules), len(expected_managed_modules))
# Check set comparison since we do not require anything about the order
@ -262,10 +262,10 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
def _check_managed_states(
self,
managed_params: List[nn.Parameter],
managed_buffers: List[torch.Tensor],
expected_managed_params: List[nn.Parameter],
expected_managed_buffers: List[torch.Tensor],
managed_params: list[nn.Parameter],
managed_buffers: list[torch.Tensor],
expected_managed_params: list[nn.Parameter],
expected_managed_buffers: list[torch.Tensor],
):
self.assertEqual(len(managed_params), len(expected_managed_params))
self.assertEqual(len(managed_buffers), len(expected_managed_buffers))
@ -370,7 +370,7 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
self._check_1d_sharded_parameters(orig_params, sharded_params)
def _check_1d_sharded_parameters(
self, orig_params: List[nn.Parameter], sharded_params: List[nn.Parameter]
self, orig_params: list[nn.Parameter], sharded_params: list[nn.Parameter]
):
self.assertEqual(len(orig_params), len(sharded_params))
global_mesh = init_device_mesh("cuda", (self.world_size,))

View File

@ -2,7 +2,7 @@
import copy
import functools
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -339,7 +339,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
model.set_reshard_after_backward(
is_last_microbatch or reshard_after_forward
)
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model in (ref_model_compute, model):
losses.append(
_model(microbatch_inps[microbatch_idx].detach()).sum()
@ -391,7 +391,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
# Subtest 1: use fp16 on the second child submodule -- does not require
# any additional casting logic
forward_inputs: Dict[str, nn.Module] = {}
forward_inputs: dict[str, nn.Module] = {}
model = SaveForwardInputsModel(
forward_inputs,
cast_forward_inputs=False,
@ -405,7 +405,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
# Subtest 2: use fp16 on the second child module, where the user module
# owns the cast
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=True
).cuda()
@ -423,7 +423,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
# Subtest 3: use fp16 on the first child module and specify its output
# dtype so that the second child module does not need to cast
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=False
).cuda()
@ -448,7 +448,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
class ToyModule(nn.Module):
def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l = nn.Linear(100, 100)
self.forward_inputs = forward_inputs
@ -459,7 +459,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
return self.l(x)
class ToyModel(nn.Module):
def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l1 = nn.Linear(100, 100)
self.l2 = ToyModule(forward_inputs)
@ -472,7 +472,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
) # external input
return self.l2(self.l1(x), y)
forward_inputs: Dict[str, torch.Tensor] = {}
forward_inputs: dict[str, torch.Tensor] = {}
model = ToyModel(forward_inputs).cuda()
x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
fully_shard(

View File

@ -4,7 +4,7 @@ import copy
import functools
import unittest
from contextlib import nullcontext
from typing import Dict, Optional
from typing import Optional
import torch
import torch.nn as nn
@ -308,7 +308,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
# Verify that we can load a new state dict that contains DTensors with
# storages different from the current model parameters
new_state_dict: Dict[str, DTensor] = {}
new_state_dict: dict[str, DTensor] = {}
for param_name, dtensor in state_dict.items():
# Construct new DTensors to exercise load state dict writeback
new_state_dict[param_name] = dtensor.detach().clone().fill_(new_fill_value)

View File

@ -7,7 +7,7 @@ import itertools
import unittest
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
@ -65,7 +65,7 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
device = torch.device("cuda", 0)
class ParamlessModule(nn.Module):
def forward(self, x: torch.Tensor, ys: Tuple[torch.Tensor, ...]):
def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]):
# Check that FSDP moved the inputs to GPU, including recursing
# into the tuple data structure
assert x.device == device, f"Expects {device} but got {x.device}"
@ -224,7 +224,7 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, mlp_dim), device="cuda", dtype=dtype)
for iter_idx in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
losses.append(_model(inp).sum())
losses[-1].backward()
@ -281,7 +281,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
)
def _test_train_parity_single_group(
self, lin_shapes: List[Tuple[int, int]], use_shard_placement_fn: bool
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
):
torch.manual_seed(42)
model = nn.Sequential(
@ -308,7 +308,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
inp = (torch.randn((4, lin_shapes[0][0]), device="cuda"),)
for iter_idx in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(*inp).sum())
@ -461,7 +461,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
with patch_all_gather_ctx, patch_reduce_scatter_ctx:
for iter_idx in range(10):
inp = torch.randint(0, vocab_size, (3, 64), device=device_type)
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@ -554,7 +554,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
torch.manual_seed(42 + self.rank)
inp = torch.randn((32, 4), device="cuda")
for iter_idx in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@ -592,7 +592,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
for _ in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad()
losses.append(_model(inp).sum())
@ -623,8 +623,8 @@ class TestFullyShard1DTrainingCore(FSDPTest):
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
# Track all losses and check for equality at the end to avoid a CPU
# sync point after each iteration
ref_losses: List[torch.Tensor] = []
losses: List[torch.Tensor] = []
ref_losses: list[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _ in range(10):
ref_optim.zero_grad()
ref_losses.append(ref_model(inp).sum())
@ -736,7 +736,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
)
for iter_idx in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
torch.manual_seed(iter_idx + 1) # for dropout determinism
losses.append(_model(inp).sum())
@ -886,7 +886,7 @@ class TestFullyShardSharedParams(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(10):
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@ -1009,7 +1009,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
is_last_microbatch = microbatch_idx == num_microbatches - 1
set_backward_flags(model, is_last_microbatch)
inp = torch.randn(batch_size, lin_dim, device="cuda")
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
with CommDebugMode() as comm_mode:
losses.append(_model(inp).sum())
@ -1125,8 +1125,8 @@ class TestFullyShardGradientAccumulation(FSDPTest):
# Emulate the 1f1b pipeline schedule and only reduce gradients on the
# last microbatch
losses: List[torch.Tensor] = []
ref_losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
ref_losses: list[torch.Tensor] = []
for inp_idx, inp in enumerate(inps):
is_last_microbatch = inp_idx == num_microbatches - 1
model.set_requires_gradient_sync(is_last_microbatch)
@ -1210,7 +1210,7 @@ class TestFullyShardNDTraining(FSDPTest):
device = torch.device("cuda")
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@ -1281,7 +1281,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
device = torch.device("cuda")
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@ -1360,7 +1360,7 @@ class TestFullyShardHSDPTraining(FSDPTest):
if sync_gradients_at_last_batch:
model.set_requires_gradient_sync(is_last_microbatch)
inp = torch.randn((8, mlp_dim), device=device)
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
losses.append(_model(inp).sum())
losses[-1].backward()

View File

@ -5,7 +5,6 @@ from collections import deque, OrderedDict
from contextlib import ContextDecorator, contextmanager, nullcontext
from copy import deepcopy
from functools import partial
from typing import Tuple
import torch
import torch.nn as nn
@ -70,7 +69,7 @@ class MultiOutputModel(nn.Module):
self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
z = x @ self.w1
z = nn.functional.relu(z)
z = z @ self.w2
@ -82,7 +81,7 @@ class MultiInputModel(nn.Module):
super().__init__()
self.w = nn.Parameter(torch.randn((100, 100), device=device))
def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
def forward(self, xs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
x, y = xs
z = x + y

View File

@ -4,7 +4,7 @@ import copy
import functools
import io
from copy import deepcopy
from typing import List, Optional, Type
from typing import Optional
import torch
import torch.distributed as dist
@ -166,7 +166,7 @@ class TestFullyShard2DTraining(FSDPTest):
device = torch.device("cuda")
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@ -335,7 +335,7 @@ class TestFullyShard2DTraining(FSDPTest):
self,
use_seq_parallel: bool,
reuse_model_optim: bool,
optimizer_class: Type[torch.optim.Optimizer],
optimizer_class: type[torch.optim.Optimizer],
foreach: bool,
):
def train_step(

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
from copy import deepcopy
from typing import List, Tuple
import torch
import torch.nn as nn
@ -28,12 +27,12 @@ class TestContract(TestCase):
@skipIfTorchDynamo("Dynamo does not support the state key")
def test_add_hooks(self):
def forward_pre_hook(
module: nn.Module, inp: Tuple[torch.Tensor]
) -> Tuple[torch.Tensor]:
module: nn.Module, inp: tuple[torch.Tensor]
) -> tuple[torch.Tensor]:
return inp
def forward_hook(
module: nn.Module, inp: Tuple[torch.Tensor], out: torch.Tensor
module: nn.Module, inp: tuple[torch.Tensor], out: torch.Tensor
) -> torch.Tensor:
return out
@ -44,9 +43,9 @@ class TestContract(TestCase):
def backward_hook(
module: nn.Module,
grad_input: Tuple[torch.Tensor],
grad_input: tuple[torch.Tensor],
grad_output: torch.Tensor,
) -> Tuple[torch.Tensor]:
) -> tuple[torch.Tensor]:
return grad_input
@contract()
@ -92,8 +91,8 @@ class TestContract(TestCase):
@skipIfTorchDynamo("Dynamo does not support the state key")
def test_state(self):
def check_and_update_state_hook(
module: nn.Module, inp: Tuple[torch.Tensor]
) -> Tuple[torch.Tensor]:
module: nn.Module, inp: tuple[torch.Tensor]
) -> tuple[torch.Tensor]:
self.assertEqual(api.state(module).dummy_state, 7)
api.state(module).dummy_state = 8
return inp
@ -139,7 +138,7 @@ class TestContract(TestCase):
@skipIfTorchDynamo("Dynamo does not support the state key")
def test_multi_module_api(self):
@contract()
def multi_module_api(modules: List[nn.Module]) -> nn.Module:
def multi_module_api(modules: list[nn.Module]) -> nn.Module:
return modules
model = nn.Sequential(*[nn.Linear(3, 3) for _ in range(5)])

View File

@ -6,7 +6,6 @@ import itertools
import math
import pickle
import sys
from typing import List
import torch
import torch.distributed as dist
@ -3186,7 +3185,7 @@ class TestCreateTensorNoProcessGroupMode(TestCase):
],
size=torch.Size([4, 2]),
)
st_local_shards: List[Shard] = []
st_local_shards: list[Shard] = []
for shard_metadata in st_metadata.shards_metadata:
st_local_shards.append(
Shard(
@ -3215,7 +3214,7 @@ class TestCreateTensorNoProcessGroupMode(TestCase):
],
size=torch.Size([4, 2]),
)
st_local_shards: List[Shard] = []
st_local_shards: list[Shard] = []
src = torch.randn(4, 2)
for shard_metadata in st_metadata.shards_metadata:
offsets = shard_metadata.shard_offsets

View File

@ -1,7 +1,7 @@
# Owner(s): ["oncall: distributed"]
import copy
from dataclasses import dataclass
from typing import List, Union
from typing import Union
import torch
from torch.distributed._shard import _shard_tensor, sharded_tensor
@ -495,7 +495,7 @@ class TestShardingSpec(TestCase):
@dataclass
class GridShardingSpec(ShardingSpec):
grid_size: int
placements: List[Union[torch.distributed._remote_device, str]]
placements: list[Union[torch.distributed._remote_device, str]]
def __post_init__(self):
for i, remote_device in enumerate(self.placements):

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: unknown"]
import gc
import unittest
from typing import Tuple
import torch
import torch.nn as nn
@ -161,7 +160,7 @@ class TestMemTracker(TestCase):
def get_param_grad_optstate_actual_bytes(
model: nn.Module, opt: torch.optim.Optimizer
) -> Tuple[int, int, int]:
) -> tuple[int, int, int]:
param_bytes = 0
grad_bytes = 0
opt_state_bytes = 0
@ -179,7 +178,7 @@ class TestMemTracker(TestCase):
def get_param_grad_optstate_bytes_from_tracker(
tracker: MemTracker,
) -> Tuple[int, int, int]:
) -> tuple[int, int, int]:
snapshot = tracker.get_tracker_snapshot()
param_bytes = snapshot[dev]["Parameter"]
grad_bytes = snapshot[dev]["Gradient"]

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: unknown"]
import unittest
from dataclasses import dataclass
from typing import Any, Callable, cast, Tuple, Union
from typing import Any, Callable, cast, Union
import torch
from torch import nn, optim
@ -73,7 +73,7 @@ class TestRuntimeEstimator(TestCase):
def _measure_actual_cuda_time(
self,
func: Callable,
args: Tuple[Any, ...],
args: tuple[Any, ...],
) -> float:
warmup_iters, actual_iters = 2, 5
start_event = torch.cuda.Event(enable_timing=True)
@ -92,7 +92,7 @@ class TestRuntimeEstimator(TestCase):
self,
estimate_mode: str,
func: Callable,
args: Tuple[Any, ...],
args: tuple[Any, ...],
) -> float:
# Optimizer init step
func(*args)
@ -106,7 +106,7 @@ class TestRuntimeEstimator(TestCase):
model_type: str,
model_args: Union[ConvArgs, ModelArgs],
bsz: int,
) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]:
) -> tuple[nn.Module, optim.Optimizer, torch.Tensor]:
dev = torch.cuda.current_device()
if model_type == "Transformer":
model_args = cast(ModelArgs, model_args)

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: unknown"]
import copy
import unittest
from typing import Tuple
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
@ -40,7 +39,7 @@ class TestSACILP(TestCase):
def _init_model_input_optimizer(
self,
) -> Tuple[torch.nn.Module, torch.optim.Optimizer, torch.Tensor]:
) -> tuple[torch.nn.Module, torch.optim.Optimizer, torch.Tensor]:
bsz = 8
model_args = ModelArgs(
n_layers=4,

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from enum import auto, Enum
from functools import partial
from io import BytesIO
from typing import Any, Dict, List
from typing import Any
import torch
import torch.distributed as dist
@ -95,9 +95,9 @@ class ModelType(Enum):
class TestTrainState:
step: int = 0
current_loss: float = -1
losses: List[float] = field(default_factory=list)
losses: list[float] = field(default_factory=list)
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
loss_bytes = BytesIO()
torch.save(self.losses, loss_bytes)
return {
@ -284,7 +284,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
@with_temp_dir
def test_stateful_and_non_stateful_loads(self) -> None:
class StateDict(Dict):
class StateDict(dict):
def __init__(self):
self.set_sd_item_called = False

View File

@ -2,7 +2,7 @@
import os
import sys
from typing import cast, List, Optional, Union
from typing import cast, Optional, Union
import torch
import torch.distributed as dist
@ -177,17 +177,17 @@ class FaultyStorageWriter(TestStorageBase, StorageWriter):
self._fail_rank("fail_prepare_local_plan")
return plan
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
self._fail_rank("fail_prepare_global_plan")
return plans
def write_data(
self, plan: SavePlan, planner: SavePlanner
) -> Future[List[WriteResult]]:
) -> Future[list[WriteResult]]:
self._fail_rank("fail_write_data")
return self._fail_rank_async("fail_write_data_async", [])
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
self._fail_rank("fail_finish")
@classmethod
@ -210,7 +210,7 @@ class FaultyStorageReader(TestStorageBase, StorageReader):
self._fail_rank("fail_prepare_local_plan")
return plan
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
self._fail_rank("fail_prepare_global_plan")
return plans

View File

@ -1,5 +1,5 @@
# Owner(s): ["oncall: distributed"]
from typing import Dict, Union
from typing import Union
import torch
import torch.distributed as dist
@ -58,14 +58,14 @@ class MyTestModule(torch.nn.Module):
def extra_state_tensor(self, new_extra_state_tensor: torch.Tensor) -> None:
self._extra_state_tensor = new_extra_state_tensor
def get_extra_state(self) -> Dict[str, Union[int, torch._tensor.Tensor]]:
def get_extra_state(self) -> dict[str, Union[int, torch._tensor.Tensor]]:
return {
"extra_state": self._extra_state,
"extra_state_tensor": self._extra_state_tensor,
}
def set_extra_state(
self, state: Dict[str, Union[int, torch._tensor.Tensor]]
self, state: dict[str, Union[int, torch._tensor.Tensor]]
) -> None:
self._extra_state = state["extra_state"] # pyre-ignore[8]
self._extra_state_tensor = state["extra_state_tensor"] # pyre-ignore[8]

View File

@ -4,7 +4,6 @@ import os
import shutil
import sys
import tempfile
from typing import Dict
import torch
import torch.distributed as dist
@ -53,8 +52,8 @@ if TEST_WITH_DEV_DBG_ASAN:
def assert_state_dict_equal(
self: TestCase,
state_dict_1: Dict[str, torch.Tensor],
state_dict_2: Dict[str, torch.Tensor],
state_dict_1: dict[str, torch.Tensor],
state_dict_2: dict[str, torch.Tensor],
) -> bool:
self.assertEqual(
len(state_dict_1), len(state_dict_2), "state_dict must be the same size"

View File

@ -2,7 +2,7 @@
import sys
import tempfile
from typing import Any, Dict, IO
from typing import Any, IO
import torch
import torch.distributed as dist
@ -56,8 +56,8 @@ _THREAD_COUNTS = {1, 2}
def assert_state_dict_equal(
self: TestCase,
state_dict_1: Dict[str, torch.Tensor],
state_dict_2: Dict[str, torch.Tensor],
state_dict_1: dict[str, torch.Tensor],
state_dict_2: dict[str, torch.Tensor],
) -> bool:
self.assertEqual(
len(state_dict_1), len(state_dict_2), "state_dict must be the same size"
@ -113,10 +113,10 @@ class BlobState:
def __init__(self, value: IO[bytes]) -> Any:
self.state = {"blob": value}
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
return self.state
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.state = state_dict
def __eq__(self, other: object) -> bool:

View File

@ -3,7 +3,7 @@
import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
@ -35,7 +35,7 @@ def with_temp_dir(
assert func is not None
@wraps(func)
def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None:
def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
# Only create temp_dir when rank is 0 (or no pg)
if not dist.is_initialized() or dist.get_rank() == 0:
temp_dir = tempfile.mkdtemp()

View File

@ -4,7 +4,7 @@ import copy
import functools
import sys
from itertools import chain
from typing import Callable, Tuple, Type, Union
from typing import Callable, Union
import torch
import torch.distributed as dist
@ -154,9 +154,9 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
*,
use_orig_params: bool,
use_dtensor: bool,
wrapping: Tuple[nn.Module] = (),
wrapping: tuple[nn.Module] = (),
compile_model: bool = False,
optimizer_class: Type[Optimizer],
optimizer_class: type[Optimizer],
) -> None:
if not use_orig_params:
return
@ -232,7 +232,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self,
*,
reshard_after_forward: Union[bool, int],
optimizer_class: Type[Optimizer],
optimizer_class: type[Optimizer],
compile_model: bool,
foreach: bool = True,
):
@ -272,7 +272,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self._test_fsdp2,
)
def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None:
def _test_ddp(self, use_composable: bool, optimizer_class: type[Optimizer]) -> None:
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
@ -303,7 +303,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
def _test_fsdp_ddp(
self,
optimizer_class: Type[Optimizer],
optimizer_class: type[Optimizer],
optim_in_backward: bool = False,
test_frozen: bool = False,
) -> None:
@ -347,7 +347,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self._test_fsdp_ddp,
)
def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None:
def _test_single_gpu(self, optimizer_class: type[Optimizer]) -> None:
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
@ -399,7 +399,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
)
def _test_cpu_offload_full_state_dict(
self, optimizer_class: Type[Optimizer]
self, optimizer_class: type[Optimizer]
) -> None:
orig_model = CompositeParamModel(device=torch.device("cuda"))
device_mesh = init_device_mesh("cuda", (self.world_size,))

View File

@ -14,7 +14,7 @@ import signal
import unittest
import uuid
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List
from typing import Any
from unittest.mock import call, patch
import torch.distributed as dist
@ -135,7 +135,7 @@ class TestAgent(SimpleElasticAgent):
worker_group.group_world_size = None
self.stop_workers_call_count += 1
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
# create fake workers; make worker id equal to global rank
ids = {}
for worker in worker_group.workers:
@ -477,7 +477,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
self.assertEqual(1, mock_monitor_workers.call_count)
self.assertEqual(spec.max_restarts, agent._remaining_restarts)
def get_worker_assigned(self, store, role_infos_len, info) -> List[Worker]:
def get_worker_assigned(self, store, role_infos_len, info) -> list[Worker]:
i, role_info = info
spec = self._get_worker_spec(
max_restarts=3,

View File

@ -17,7 +17,7 @@ import time
import unittest
import uuid
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Optional
from unittest import mock
from unittest.mock import Mock, patch
@ -256,7 +256,7 @@ class Conf:
entrypoint: Callable
local_world_size: int
args: Tuple = ()
args: tuple = ()
role: str = "default"
redirects: Std = Std.NONE
tee: Std = Std.NONE
@ -394,10 +394,10 @@ class LocalElasticAgentTest(unittest.TestCase):
def run_job(
self,
node_configs: List[Conf],
node_configs: list[Conf],
exit_barrier_timeout: int = 5,
log_line_prefix_template: Optional[str] = None,
) -> Dict[str, List[RunResult]]:
) -> dict[str, list[RunResult]]:
"""
Simulates running a distributed job by running multiple agents
(one on each process). Agent 0 is run on the main process for
@ -431,7 +431,7 @@ class LocalElasticAgentTest(unittest.TestCase):
for p in procs:
p.join()
results: Dict[str, List[RunResult]] = {}
results: dict[str, list[RunResult]] = {}
while not agent_results.empty():
role, run_result = agent_results.get()
results.setdefault(role, []).append(run_result)
@ -1032,8 +1032,8 @@ class LocalElasticAgentTest(unittest.TestCase):
def assert_rank_consistency(
self,
run_results: Dict[str, List[RunResult]],
expected_role_world_sizes: Dict[str, int],
run_results: dict[str, list[RunResult]],
expected_role_world_sizes: dict[str, int],
):
"""
Asserts that ranks are consecutive w.r.t role_rank. If local world sizes are 4:
@ -1042,11 +1042,11 @@ class LocalElasticAgentTest(unittest.TestCase):
... etc ...
"""
global_ranks: List[int] = []
global_ranks: list[int] = []
# role -> [role_rank,...]
role_ranks: Dict[str, List[int]] = {}
role_ranks: dict[str, list[int]] = {}
# group rank -> [(rank, role_rank),...]
grouped_ranks: Dict[int, List[Tuple[int, int]]] = {}
grouped_ranks: dict[int, list[tuple[int, int]]] = {}
# global world size == sum of all the role world sizes
expected_world_size = sum(expected_role_world_sizes.values())

View File

@ -16,7 +16,7 @@ import sys
import tempfile
import time
from itertools import product
from typing import Callable, Dict, List, Union
from typing import Callable, Union
from unittest import mock
import torch
@ -141,7 +141,7 @@ def echo2(msg: str, fail: bool = False) -> str:
return msg
def echo_large(size: int) -> Dict[int, str]:
def echo_large(size: int) -> dict[int, str]:
"""
returns a large output ({0: "test0", 1: "test1", ..., (size-1): f"test{size-1}"})
"""
@ -167,13 +167,13 @@ def dummy_compute() -> torch.Tensor:
return torch.rand(100, 100)
def redirects_oss_test() -> List[Std]:
def redirects_oss_test() -> list[Std]:
return [
Std.NONE,
]
def redirects_all() -> List[Std]:
def redirects_all() -> list[Std]:
return [
Std.NONE,
Std.OUT,
@ -240,14 +240,14 @@ class _StartProcessesTest(TestCase):
def log_dir(self):
return tempfile.mkdtemp(dir=self.test_dir)
def assert_in_file(self, expected: List[str], filename: str) -> None:
def assert_in_file(self, expected: list[str], filename: str) -> None:
expected = [f"{line.rstrip()}\n" for line in expected]
with open(filename) as fp:
actual = fp.readlines()
for line in expected:
self.assertIn(line, actual)
def assert_pids_noexist(self, pids: Dict[int, int]):
def assert_pids_noexist(self, pids: dict[int, int]):
for local_rank, pid in pids.items():
with self.assertRaises(
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"

View File

@ -16,7 +16,6 @@ import unittest
from concurrent.futures import wait
from concurrent.futures._base import ALL_COMPLETED
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Dict, Set
from unittest import mock
from torch.distributed.elastic.multiprocessing.tail_log import TailLog
@ -72,7 +71,7 @@ class TailLogTest(unittest.TestCase):
tail.stop()
dst.seek(0)
actual: Dict[int, Set[int]] = {}
actual: dict[int, set[int]] = {}
for line in dst.readlines():
header, num = line.split(":")
@ -123,7 +122,7 @@ class TailLogTest(unittest.TestCase):
tail.stop()
dst.seek(0)
headers: Set[str] = set()
headers: set[str] = set()
for line in dst.readlines():
header, _ = line.split(":")
headers.add(header)

View File

@ -6,7 +6,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, cast, Dict, SupportsInt
from typing import Any, cast, SupportsInt
from unittest import TestCase
from torch.distributed.elastic.rendezvous import (
@ -24,7 +24,7 @@ class RendezvousParametersTest(TestCase):
self._run_id = "dummy_run_id"
self._min_nodes = 3
self._max_nodes = 6
self._kwargs: Dict[str, Any] = {}
self._kwargs: dict[str, Any] = {}
def _create_params(self) -> RendezvousParameters:
return RendezvousParameters(

View File

@ -15,7 +15,7 @@ import time
from abc import ABC, abstractmethod
from base64 import b64encode
from datetime import datetime, timedelta, timezone
from typing import Callable, cast, Optional, Tuple
from typing import Callable, cast, Optional
from unittest import TestCase
from unittest.mock import call, MagicMock, Mock, patch, PropertyMock
@ -186,7 +186,7 @@ class FakeRendezvousBackend(RendezvousBackend):
def name(self) -> str:
return "fake_backend"
def get_state(self) -> Optional[Tuple[bytes, Token]]:
def get_state(self) -> Optional[tuple[bytes, Token]]:
if self._token == 0:
return None
@ -194,7 +194,7 @@ class FakeRendezvousBackend(RendezvousBackend):
def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
) -> Optional[tuple[bytes, Token, bool]]:
if token is None:
token = 0

View File

@ -7,7 +7,7 @@
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Callable, cast, Optional, Tuple
from typing import Any, Callable, cast, Optional
from torch.distributed.elastic.rendezvous import RendezvousStateError
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import (
@ -32,12 +32,12 @@ class RendezvousBackendTestMixin(ABC):
def _set_state(
self, state: bytes, token: Optional[Any] = None
) -> Tuple[bytes, Token, bool]:
) -> tuple[bytes, Token, bool]:
result = self._backend.set_state(state, token)
self.assertIsNotNone(result)
return cast(Tuple[bytes, Token, bool], result)
return cast(tuple[bytes, Token, bool], result)
def test_get_state_returns_backend_state(self) -> None:
self._backend.set_state(b"x")
@ -46,7 +46,7 @@ class RendezvousBackendTestMixin(ABC):
self.assertIsNotNone(result)
state, token = cast(Tuple[bytes, Token], result)
state, token = cast(tuple[bytes, Token], result)
self.assertEqual(b"x", state)
self.assertIsNotNone(token)

View File

@ -10,7 +10,6 @@ import socket
import threading
import time
from datetime import timedelta
from typing import List
from unittest import TestCase
from unittest.mock import patch
@ -350,7 +349,7 @@ class PeriodicTimerTest(TestCase):
call_interval = 0.2
# Keep the log of intervals between each consecutive call.
actual_call_intervals: List[float] = []
actual_call_intervals: list[float] = []
# Keep the number of times the function was called.
call_count = 0

View File

@ -7,7 +7,6 @@ import pickle
import socket
import tempfile
from contextlib import contextmanager
from typing import Dict
from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool
@ -181,7 +180,7 @@ class WorkerServerTest(TestCase):
def body(self) -> bytes:
return b"dummy"
def params(self) -> Dict[str, str]:
def params(self) -> dict[str, str]:
return {}
class Response(_Response):

View File

@ -9,7 +9,6 @@
import datetime
from multiprocessing.pool import ThreadPool
from typing import List
from unittest import mock
import torch.distributed as dist
@ -40,7 +39,7 @@ class MockStore:
self.ops.append(("get", key))
return "value"
def multi_get(self, keys: List[str]) -> List[str]:
def multi_get(self, keys: list[str]) -> list[str]:
self.ops.append(("multi_get", keys))
return ["value"] * len(keys)
@ -48,7 +47,7 @@ class MockStore:
self.ops.append(("add", key, val))
return 3
def wait(self, keys: List[str]) -> None:
def wait(self, keys: list[str]) -> None:
self.ops.append(("wait", keys))
@ -157,7 +156,7 @@ class StoreUtilTest(TestCase):
return ""
with ThreadPool(N - 1) as pool:
outputs: List[str] = pool.map(run_barrier_for_rank, range(N - 1))
outputs: list[str] = pool.map(run_barrier_for_rank, range(N - 1))
self.assertTrue(any("missing_ranks=[Rank 2 host]" in msg for msg in outputs))

View File

@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
import sys
from typing import List
from unittest.mock import patch
import torch
@ -102,7 +101,7 @@ class TestBackwardPrefetch(FSDPTest):
tgt = torch.randn((20, 1, 1024), device=device_type)
# monkey patch
all_handle_fqns: List[List[str]] = []
all_handle_fqns: list[list[str]] = []
def patched_get_handle_to_prefetch(*args, **kwargs):
handle = orig_get_handle_to_prefetch(*args, **kwargs)

View File

@ -2,7 +2,7 @@
import sys
from contextlib import nullcontext
from enum import auto, Enum
from typing import List, Optional
from typing import Optional
from unittest.mock import patch
import torch
@ -319,7 +319,7 @@ class TestExplicitUnshard(FSDPTest):
self.mlp2 = MLP(dim)
self.mlp3 = MLP(dim)
def forward(self, ys: List[torch.Tensor], works: List[dist.Work]):
def forward(self, ys: list[torch.Tensor], works: list[dist.Work]):
(y1, y2, y3), (work1, work2, work3) = ys, works
work1.wait()
z1 = self.mlp1(y1)
@ -372,7 +372,7 @@ class TestExplicitUnshard(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((batch_size, dim), device=device_type)
for _ in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
losses.append(_model(inp).sum())
losses[-1].backward()

View File

@ -4,7 +4,7 @@ import functools
import itertools
import sys
import unittest
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
from unittest import mock
import torch
@ -76,7 +76,7 @@ class TestParityWithDDP(FSDPTest):
PyTorch DDP vs. FullyShardedDataParallel.
"""
def _get_device_init_modes(self, cpu_offload: CPUOffload) -> List[DEVICEInitMode]:
def _get_device_init_modes(self, cpu_offload: CPUOffload) -> list[DEVICEInitMode]:
modes = [
DEVICEInitMode.DEVICE_AFTER,
DEVICEInitMode.DEVICE_BEFORE,
@ -89,7 +89,7 @@ class TestParityWithDDP(FSDPTest):
modes.append(DEVICEInitMode.DEVICE_NEVER)
return modes
def _get_subtest_config(self, cpu_offload: CPUOffload) -> Dict[str, List[Any]]:
def _get_subtest_config(self, cpu_offload: CPUOffload) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests CUDA initialization
modes and prefetching settings together."""
return {

View File

@ -4,7 +4,7 @@ import contextlib
import itertools
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
from torch import distributed as dist
@ -71,7 +71,7 @@ class _GradAccConfigs:
sole purpose of overriding :meth:`__repr__` to remove spaces.
"""
configs: List[_GradAccConfig]
configs: list[_GradAccConfig]
def __repr__(self) -> str:
# Override to remove any spaces in the string to appease the internal
@ -90,7 +90,7 @@ class TestGradAcc(FSDPTest):
def _test_grad_acc(
self,
batch_dim: int,
configs: List[_GradAccConfig],
configs: list[_GradAccConfig],
cpu_offload: CPUOffload,
backward_prefetch: Optional[BackwardPrefetch],
sharding_strategy: ShardingStrategy,
@ -146,8 +146,8 @@ class TestGradAcc(FSDPTest):
def permute_tensor(x: torch.Tensor):
return x.view(-1)[torch.randperm(x.numel())].view_as(x)
batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
batches: List[Tuple[torch.Tensor, ...]] = [batch]
batch: tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
batches: list[tuple[torch.Tensor, ...]] = [batch]
num_iters_to_acc = sum(config.num_iters for config in configs)
for _ in range(num_iters_to_acc - 1):
batches.append(tuple(permute_tensor(t) for t in batch))
@ -158,7 +158,7 @@ class TestGradAcc(FSDPTest):
), "Check the test to make sure that batches are distinct"
# Concatenate the batches along the given batch dimension
concat_batch: Tuple[torch.Tensor, ...] = tuple(
concat_batch: tuple[torch.Tensor, ...] = tuple(
torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
)
@ -214,7 +214,7 @@ class TestGradAcc(FSDPTest):
# Check that the optimizer step does not error
optim.step()
def _get_subtest_config(self) -> Dict[str, List[Any]]:
def _get_subtest_config(self) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests prefetching."""
return {
"backward_prefetch": [

View File

@ -5,7 +5,7 @@ import sys
from collections import Counter
from enum import auto, Enum
from functools import partial
from typing import List, Optional, Tuple
from typing import Optional
import torch
import torch.distributed as dist
@ -363,7 +363,7 @@ class TestFSDPHybridShard(FSDPTest):
torch.manual_seed(global_pg.rank() + 1)
for _ in range(5):
inp = fsdp_model.module.get_input(torch.device("cuda"))
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for model, optim in ((fsdp_model, fsdp_optim), (hsdp_model, hsdp_optim)):
optim.zero_grad()
loss = model(*inp).sum()
@ -396,7 +396,7 @@ class TestFSDPHybridShard(FSDPTest):
sharding_strategy_mode: str,
use_orig_params: bool,
hsdp_process_groups: Optional[
Tuple[dist.ProcessGroup, dist.ProcessGroup]
tuple[dist.ProcessGroup, dist.ProcessGroup]
] = None,
hsdp_device_mesh: Optional = None,
):

View File

@ -8,7 +8,7 @@ from collections import namedtuple
from contextlib import nullcontext
from copy import deepcopy
from itertools import chain
from typing import Any, Tuple
from typing import Any
import torch
import torch.distributed as dist
@ -945,7 +945,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
self._test_homogeneous_attributes,
)
def _test_homogeneous_attributes(self, attr_name_and_values: Tuple[str, Any, Any]):
def _test_homogeneous_attributes(self, attr_name_and_values: tuple[str, Any, Any]):
model = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.NO_FSDP,

View File

@ -6,7 +6,7 @@ import os
import sys
from functools import partial
from itertools import product
from typing import Any, Dict, List
from typing import Any
import torch
import torch.cuda.nccl as nccl
@ -521,7 +521,7 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
def world_size(self):
return 2
def _get_subtest_config(self) -> Dict[str, List[Any]]:
def _get_subtest_config(self) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests prefetching settings
together."""
return {
@ -1136,7 +1136,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule(self):
forward_inputs: Dict[str, nn.Module] = {}
forward_inputs: dict[str, nn.Module] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
model = SaveForwardInputsModel(
@ -1158,7 +1158,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule_skip_inputs(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)
model = SaveForwardInputsModel(
@ -1179,7 +1179,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule_skip_inputs_error(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)
model = SaveForwardInputsModel(
@ -1198,7 +1198,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_submodules_with_different_precisions_error(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)
@ -1222,7 +1222,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_submodules_with_different_precisions(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)
@ -1244,7 +1244,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_submodules_with_external_inputs(self):
class ToyModule(nn.Module):
def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l = nn.Linear(100, 100)
self.forward_inputs = forward_inputs
@ -1255,7 +1255,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
return self.l(x)
class ToyModel(nn.Module):
def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l1 = nn.Linear(100, 100)
self.l2 = ToyModule(forward_inputs)
@ -1266,7 +1266,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
y = torch.ones(2, 100, device="cuda", dtype=torch.float32)
return self.l2(self.l1(x), y)
forward_inputs: Dict[str, torch.Tensor] = {}
forward_inputs: dict[str, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16)
model = ToyModel(forward_inputs).cuda()
@ -1343,7 +1343,7 @@ class TestFSDPTrainEval(FSDPTest):
torch.manual_seed(1 + self.rank)
eval_src = torch.randn((8, 1, 512), device=device)
eval_tgt = torch.randn((16, 1, 512), device=device)
eval_out_sums: List[torch.Tensor] = []
eval_out_sums: list[torch.Tensor] = []
# An iteration consists of training forward/backward/optimizer,
# updating the EMA copy with the main copy, and eval forward
for _ in range(3):

View File

@ -4,7 +4,7 @@ import bisect
import sys
from copy import deepcopy
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Optional
import torch
import torch.nn as nn
@ -177,7 +177,7 @@ class NestedModel(torch.nn.Module):
model: torch.nn.Module,
group: Optional[dist.ProcessGroup] = None,
ignore_modules: bool = False,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
) -> torch.nn.Module:
if fsdp_kwargs is None:
fsdp_kwargs = {}
@ -214,7 +214,7 @@ class NestedModel(torch.nn.Module):
def wrap_alt(
model: torch.nn.Module,
group: Optional[dist.ProcessGroup] = None,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
) -> torch.nn.Module:
if fsdp_kwargs is None:
fsdp_kwargs = {}
@ -231,7 +231,7 @@ class NestedModel(torch.nn.Module):
model,
add_to_fsdp_module: bool,
group=None,
) -> Tuple[torch.nn.Module, List[torch.nn.Parameter]]:
) -> tuple[torch.nn.Module, list[torch.nn.Parameter]]:
"""Registers unmanaged parameters before wrapping with :meth:`wrap`."""
device = next(model.parameters()).device
unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device))
@ -277,12 +277,12 @@ class NestedModel(torch.nn.Module):
# NOTE: We exclude `self.bias` from either parameter group to test the
# case where the optimizer input does not include all model parameters
def param_group0(self) -> List[torch.nn.Parameter]:
def param_group0(self) -> list[torch.nn.Parameter]:
# Use `block1`'s parameters for the first parameter group to deviate
# from the `model.parameters()` order
return list(self.block1.parameters())
def param_group1(self) -> List[torch.nn.Parameter]:
def param_group1(self) -> list[torch.nn.Parameter]:
# Deviate from the `model.parameters()` order further by rearranging
# `block2`'s parameters to be before `block0`'s parameters
return list(self.block2.parameters()) + list(self.block0.parameters())
@ -322,10 +322,10 @@ class TestFSDPOptimState(FSDPTest):
wrap_alt: bool = False, # ignored if `wrap=False`
device: torch.device = torch.device("cuda"),
group=None,
optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
optim_class: type[torch.optim.Optimizer] = torch.optim.Adam,
use_multiple_param_groups: bool = False,
use_diff_optim_inputs: bool = False,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
):
model = NestedModel().to(device)
if wrap:
@ -356,7 +356,7 @@ class TestFSDPOptimState(FSDPTest):
wrap: bool,
device: torch.device = torch.device("cuda"),
group=None,
optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
optim_class: type[torch.optim.Optimizer] = torch.optim.Adam,
use_multiple_param_groups: bool = False,
use_diff_optim_inputs: bool = False,
):
@ -383,7 +383,7 @@ class TestFSDPOptimState(FSDPTest):
optim: torch.optim.Optimizer,
device: torch.device = torch.device("cuda"),
num_iters: int = 1,
) -> List[float]:
) -> list[float]:
"""Performs a forward pass, backward pass, and optimizer step
``num_iters``-many times, and returns the per-iteration losses."""
torch.manual_seed(0) # set seed for determinism
@ -399,7 +399,7 @@ class TestFSDPOptimState(FSDPTest):
optim.step()
return losses
def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None):
def _broadcast_full_osd(self, full_osd: dict[str, Any], group=None):
"""Broadcasts the full optimizer state dict in place of using
``torch.save()`` and ``torch.load()`` so that all ranks can have it."""
obj_list = [full_osd]
@ -413,8 +413,8 @@ class TestFSDPOptimState(FSDPTest):
def _are_equal_states(
self,
state1: Dict[str, Any],
state2: Dict[str, Any],
state1: dict[str, Any],
state2: dict[str, Any],
) -> bool:
"""Checks if ``state1`` and ``state2`` contain the same mappings."""
if set(state1.keys()) != set(state2.keys()):
@ -1450,7 +1450,7 @@ class TestFSDPOptimState(FSDPTest):
self,
should_check_method_fn: Callable[[str], bool],
context_fn: Callable,
fsdp_kwargs: Optional[Dict[str, Any]],
fsdp_kwargs: Optional[dict[str, Any]],
):
"""
Runs through all optimizer state checkpointing APIs with a context

View File

@ -5,7 +5,7 @@ import functools
import itertools
import sys
import unittest
from typing import List, Optional
from typing import Optional
import torch
from torch import distributed as dist
@ -259,7 +259,7 @@ class TestShardedGradScalerParityWithDDP(FSDPTest):
)
grad_scaler = ShardedGradScaler(init_scale=2.0)
ref_grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
scaled_losses: List[torch.Tensor] = []
scaled_losses: list[torch.Tensor] = []
device = torch.device("cuda")
torch.manual_seed(42 + self.rank + 1)

View File

@ -6,7 +6,7 @@ import sys
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict
from typing import Any
import torch
import torch.nn as nn
@ -787,7 +787,7 @@ class TestFSDPStateDict(FSDPTest):
@staticmethod
def _load_state_dict(
model: Module, state_dict_type: str, state_dict: Dict[str, Any]
model: Module, state_dict_type: str, state_dict: dict[str, Any]
):
try:
enum_val = STATE_DICT_MAPPING[state_dict_type]

View File

@ -2,7 +2,7 @@
import copy
import sys
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
from typing import Optional
import torch
from torch import distributed as dist
@ -62,11 +62,11 @@ class SimpleModel(torch.nn.Module):
return self.net3(self.net2(self.relu(self.net1(x))))
@staticmethod
def get_sharded_param_names() -> List[str]:
def get_sharded_param_names() -> list[str]:
return ["net1.weight", "net1.bias", "net2.weight"]
@staticmethod
def get_non_sharded_param_names() -> List[str]:
def get_non_sharded_param_names() -> list[str]:
return ["net3.weight", "net3.bias"]
@ -87,9 +87,9 @@ class TestTPFSDPIntegration(FSDPTest):
def _get_params_and_sharding_info(
self,
model: SimpleModel,
sharded_param_names: List[str],
sharded_param_names: list[str],
tensor_parallel_size: int,
) -> Tuple[Dict[str, int], Dict[str, Tuple[torch.Size, int]]]:
) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]:
""" """
assert (
type(model) is SimpleModel
@ -131,8 +131,8 @@ class TestTPFSDPIntegration(FSDPTest):
self,
tp_fsdp_model: FSDP,
tp_pg: dist.ProcessGroup,
param_name_to_numel: Dict[str, int],
non_sharded_param_names: List[str],
param_name_to_numel: dict[str, int],
non_sharded_param_names: list[str],
) -> None:
"""
Syncs the tensor parallel parameters' gradients following the data
@ -177,11 +177,11 @@ class TestTPFSDPIntegration(FSDPTest):
self,
model: FSDP,
uses_tp: bool,
param_name_to_numel: Dict[str, int],
param_name_to_sharding_info: Dict[str, Tuple[torch.Size, int]],
param_name_to_numel: dict[str, int],
param_name_to_sharding_info: dict[str, tuple[torch.Size, int]],
tp_pg: Optional[dist.ProcessGroup],
fsdp_pg: Optional[dist.ProcessGroup],
sharded_param_names: Optional[List[str]],
sharded_param_names: Optional[list[str]],
) -> torch.Tensor:
"""
Returns all unsharded gradients as a single flattened tensor. This

View File

@ -3,7 +3,7 @@ import contextlib
import itertools
import math
import sys
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
@ -55,7 +55,7 @@ class TestUnshardParamsBase(FSDPTest):
self,
writeback: bool,
check_outer: bool,
**fsdp_kwargs: Dict[str, Any],
**fsdp_kwargs: dict[str, Any],
):
model = nn.Sequential(
nn.Linear(5, 5, bias=False, device=device_type.type),
@ -101,7 +101,7 @@ class TestUnshardParamsBase(FSDPTest):
for param in model.parameters():
self.assertEqual(param.device, cpu_device)
def _get_test_unshard_params_writeback_config(self) -> Dict[str, List[Any]]:
def _get_test_unshard_params_writeback_config(self) -> dict[str, list[Any]]:
return {
"writeback": [True, False],
"check_outer": [True, False],
@ -193,7 +193,7 @@ class TestUnshardParamsBase(FSDPTest):
num_fsdp_roots += fsdp_state._is_root
self.assertGreater(num_fsdp_roots, 1)
def _get_test_unshard_params_param_data_config(self) -> Dict[str, List[Any]]:
def _get_test_unshard_params_param_data_config(self) -> dict[str, list[Any]]:
return {
"rank0_only": [False, True],
"offload_to_cpu": [False, True],
@ -493,7 +493,7 @@ class TestUnshardParams(TestUnshardParamsBase):
def _check_grads(
ddp_model: DDP,
fsdp_model: FSDP,
old_fsdp_grads: Optional[List[torch.Tensor]],
old_fsdp_grads: Optional[list[torch.Tensor]],
):
"""
Checks that writes to the FSDP parameters' gradients persist or do

View File

@ -6,7 +6,7 @@ import itertools
import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional
import torch
import torch.nn as nn
@ -65,7 +65,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
def world_size(self) -> int:
return 2
def _get_param_groups(self, model: nn.Module) -> List[Dict[str, Any]]:
def _get_param_groups(self, model: nn.Module) -> list[dict[str, Any]]:
"""
Constructs separate parameter groups for weights, biases, and other
parameters.
@ -87,7 +87,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
def _get_optim(
self,
model: nn.Module,
optim_class: Type[torch.optim.Optimizer],
optim_class: type[torch.optim.Optimizer],
multi_tensor: bool,
) -> torch.optim.Optimizer:
"""
@ -117,12 +117,12 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
self,
device_init_mode: DEVICEInitMode,
init_optim_before_wrap: bool,
optim_class: Type[torch.optim.Optimizer],
optim_class: type[torch.optim.Optimizer],
multi_tensor: bool,
sharding_strategy: ShardingStrategy,
backward_prefetch: Optional[BackwardPrefetch],
cpu_offload: CPUOffload,
) -> Tuple[FSDP, torch.optim.Optimizer]:
) -> tuple[FSDP, torch.optim.Optimizer]:
"""
Returns a transformer with shared parameters wrapped with FSDP and a
corresponding optimizer.
@ -335,7 +335,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
self,
device_init_mode: DEVICEInitMode,
init_optim_before_wrap: bool,
optim_class: Type[torch.optim.Optimizer],
optim_class: type[torch.optim.Optimizer],
multi_tensor: bool,
set_to_none: bool,
backward_prefetch: Optional[BackwardPrefetch],
@ -566,7 +566,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
self,
sharding_strategy: ShardingStrategy,
cpu_offload: CPUOffload,
) -> Tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
) -> tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
"""
Returns a pair of (FSDP model, optimizer) for ``use_orig_params=False``
and ``True``, respectively.
@ -778,7 +778,7 @@ class TestFSDPUseOrigParamsParamAccess(FSDPTest):
z = self.lin2(z)
return z
def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
def get_input(self, device: torch.device) -> tuple[torch.Tensor, ...]:
return (torch.randn((2, 5)).to(device),)
def get_loss(self, inp, out):
@ -872,7 +872,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
z = self.lin2(z)
return z
def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
def get_input(self, device: torch.device) -> tuple[torch.Tensor, ...]:
return (torch.randn((2, 5)).to(device),)
def get_loss(self, inp, out):
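Besides List/Dict/Tuple, the `optim_class: Type[torch.optim.Optimizer]` parameters in this file become `type[torch.optim.Optimizer]`. A short sketch of that pattern, with a hypothetical factory function rather than the test code itself:

from typing import Optional

import torch

# PEP 585 lowercases typing.Type as well: a parameter that accepts an optimizer
# *class* (rather than an instance) is annotated with the built-in ``type``.
def make_optimizer(
    params,
    optim_class: type[torch.optim.Optimizer] = torch.optim.Adam,  # previously Type[torch.optim.Optimizer]
    lr: Optional[float] = None,
) -> torch.optim.Optimizer:
    # Hypothetical factory for illustration; it mirrors the updated signatures above.
    return optim_class(params, lr=lr) if lr is not None else optim_class(params)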

View File

@ -4,7 +4,6 @@ import random
import sys
from collections import OrderedDict
from dataclasses import dataclass
from typing import List
import torch
import torch.nn as nn
@ -62,13 +61,13 @@ class TestUtils(TestCase):
class NonFrozenDataClass:
some_key: str
some_float: float
some_tensor: List[torch.Tensor]
some_tensor: list[torch.Tensor]
@dataclass(frozen=True)
class FrozenDataClass:
some_key: str
some_float: float
some_tensor: List[torch.Tensor]
some_tensor: list[torch.Tensor]
# create a mixed bag of data.
data = [1, "str"]

View File

@ -16,7 +16,7 @@ import time
import unittest
import uuid
from contextlib import closing
from typing import Any, Dict, Optional
from typing import Any, Optional
from unittest import mock
from unittest.mock import MagicMock, Mock, patch
@ -59,7 +59,7 @@ def get_test_launch_config(
nproc_per_node: int,
run_id: str = "",
rdzv_backend: str = "etcd",
config: Optional[Dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
) -> LaunchConfig:
rdzv_configs = {}
if config:

View File

@ -3,7 +3,6 @@
import sys
from pathlib import Path
from typing import Tuple
import torch
import torch.distributed as dist
@ -22,7 +21,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
class MyModuleInterface:
def forward(
self, tensor: Tensor, number: int, word: str = "default"
) -> Tuple[Tensor, int, str]:
) -> tuple[Tensor, int, str]:
pass

View File

@ -10,7 +10,7 @@ import os
import sys
import unittest
from contextlib import nullcontext
from typing import Any, cast, List
from typing import Any, cast
import numpy as np
@ -207,7 +207,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
super().step()
kwarg.append(5)
kwarg: List[Any] = []
kwarg: list[Any] = []
x = torch.tensor([1.0], device=self.device, requires_grad=True)
o = ZeroRedundancyOptimizer(
[x],

View File

@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"]
# This file is a Schedule zoo for testing torch.distributed.pipelining.
# It includes schedules designed purely for testing purposes
from typing import Callable, Dict, List, Optional
from typing import Callable, Optional
from torch.distributed.pipelining.schedules import (
_Action,
@ -32,9 +32,9 @@ class ScheduleVShaped(PipelineScheduleMulti):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
stage_index_to_group_rank: Dict[int, int],
stage_index_to_group_rank: dict[int, int],
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,
):
@ -82,9 +82,9 @@ class ScheduleUnbalanced(PipelineScheduleMulti):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
stage_index_to_group_rank: Dict[int, int],
stage_index_to_group_rank: dict[int, int],
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,
):
@ -134,7 +134,7 @@ class ScheduleWithW(PipelineScheduleMulti):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
enable_zero_bubble: bool = True,
@ -195,7 +195,7 @@ class ScheduleWithReorderedB(_PipelineScheduleRuntime):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,

View File

@ -4,7 +4,6 @@ import copy
import csv
import logging
import os
from typing import List
from model_registry import MultiMLP
@ -356,7 +355,7 @@ instantiate_parametrized_tests(TestSchedulePlan)
class TestScheduleLowering(TestCase):
"""Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules"""
def _parse_actions(self, actions: List[str]) -> List[_Action]:
def _parse_actions(self, actions: list[str]) -> list[_Action]:
return [_Action.from_str(s) for s in actions]
@parametrize(

View File

@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from typing import Any, Dict
from typing import Any
import torch
from torch.distributed._tensor import DeviceMesh
@ -57,8 +57,8 @@ class TestCommModeFeatures(DTensorTestBase):
Used to generate the ground-truth parameter and sharding info for a given distributed model to
verify comm_mode correctness
"""
module_parameters_dict: Dict[str, Any] = {}
module_sharding_dict: Dict[str, Any] = {}
module_parameters_dict: dict[str, Any] = {}
module_sharding_dict: dict[str, Any] = {}
for name, parameters in model.named_parameters():
# splits name into module name to create FQN and parameter name

View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]
from collections import defaultdict
from typing import Dict
import torch
from torch.distributed._tensor.experimental._tp_transform import (
@ -57,9 +56,9 @@ class TensorParallelTest(DTensorTestBase):
super().setUp()
def assert_has_c10d_ops(
self, gm: torch.fx.GraphModule, expected_ops_count: Dict[str, int]
self, gm: torch.fx.GraphModule, expected_ops_count: dict[str, int]
) -> None:
actual_ops_count: Dict[str, int] = defaultdict(int)
actual_ops_count: dict[str, int] = defaultdict(int)
for node in gm.graph.nodes:
if node.op == "call_function":
if "c10d_functional" in str(node.target):
@ -100,7 +99,7 @@ class TensorParallelTest(DTensorTestBase):
torch.manual_seed(0)
model = MLPListModule(2).to(device=self.device_type)
inputs = (torch.randn((10, 12)).to(device=self.device_type),)
parallel_strategies: Dict[str, ParallelStyle] = {
parallel_strategies: dict[str, ParallelStyle] = {
"mlps.0.0": ColwiseParallel,
"mlps.0.2": RowwiseParallel,
"mlps.1.0": ColwiseParallel,
@ -137,7 +136,7 @@ class TensorParallelTest(DTensorTestBase):
torch.manual_seed(0)
model = MLPListModule(1, bias=False).to(device=self.device_type)
inputs = (torch.randn((10, 12)).to(device=self.device_type),)
parallel_strategies: Dict[str, ParallelStyle] = {
parallel_strategies: dict[str, ParallelStyle] = {
"mlps.0.0": ColwiseParallel,
"mlps.0.2": RowwiseParallel,
}

View File

@ -3,7 +3,7 @@
import itertools
from copy import deepcopy
from typing import Dict, NamedTuple, Optional
from typing import NamedTuple, Optional
import torch
import torch.distributed as dist
@ -52,9 +52,9 @@ reduce_scatter, all_gather, all_reduce = (
class ExpCommCounts(NamedTuple):
fwd: Optional[Dict] = None
bwd: Optional[Dict] = None
optim: Optional[Dict] = None
fwd: Optional[dict] = None
bwd: Optional[dict] = None
optim: Optional[dict] = None
class DistTensorParallelExampleTest(DTensorTestBase):

View File

@ -3,7 +3,7 @@
import itertools
import unittest
from typing import cast, List, Optional, Tuple
from typing import cast, Optional
import torch
import torch.nn.functional as F
@ -27,8 +27,8 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
def scale_for_fp8(
t: torch.Tensor, scale_shape: Tuple[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
t: torch.Tensor, scale_shape: tuple[int]
) -> tuple[torch.Tensor, torch.Tensor]:
if all(d == 1 for d in scale_shape):
t = t.unsqueeze(0).unsqueeze(-2)
else:
@ -116,7 +116,7 @@ class DistMatrixOpsTest(DTensorTestBase):
local_res = torch.mm(t1, t2)
def test_placement_comb(
placements1: List[Placement], placements2: List[Placement]
placements1: list[Placement], placements2: list[Placement]
) -> None:
dt1 = distribute_tensor(t1, device_mesh, placements1)
dt2 = distribute_tensor(t2, device_mesh, placements2)
@ -272,9 +272,9 @@ class DistMatrixOpsTest(DTensorTestBase):
batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True)
def test_placement_comb(
tensor_placements: List[Placement],
batch_1_placements: List[Placement],
batch_2_placements: List[Placement],
tensor_placements: list[Placement],
batch_1_placements: list[Placement],
batch_2_placements: list[Placement],
beta: int,
alpha: int,
batch_1_grad: Optional[torch.Tensor],
@ -338,8 +338,8 @@ class DistMatrixOpsTest(DTensorTestBase):
local_result.backward(grad_local_res)
def test_placement_comb(
placements1: List[Placement],
placements2: List[Placement],
placements1: list[Placement],
placements2: list[Placement],
) -> None:
mat1_dt = distribute_tensor(mat1, device_mesh, placements1)
mat2_dt = distribute_tensor(mat2, device_mesh, placements2)

View File

@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"]
from collections.abc import Sequence
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
from unittest import skip
import torch
@ -76,7 +76,7 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
op: Callable,
pre_op_fn: Optional[Callable] = None,
args: Sequence[Any] = (),
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
):
if pre_op_fn is None:
pre_op_fn = no_op

View File

@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"]
import itertools
from typing import cast, List
from typing import cast
import torch
import torch.distributed as dist
@ -160,7 +160,7 @@ class TestViewOps(DTensorTestBase):
if op == torch.unbind:
no_shard_dims.add(kwargs.get("dim", 0))
sharding_choices = cast(List[Placement], [Replicate()]) + [
sharding_choices = cast(list[Placement], [Replicate()]) + [
Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims
]
@ -513,7 +513,7 @@ class TestViewOps(DTensorTestBase):
# test sharded computation correctness
# NOTE: For the input to torch.view_as_complex, sharding
# on the last two dimensions is not supported.
sharding_choices: List[Placement] = [Replicate(), Shard(0)]
sharding_choices: list[Placement] = [Replicate(), Shard(0)]
all_sharding_choices = itertools.product(
*(self.device_mesh.ndim * [sharding_choices])
)
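The `cast(List[Placement], ...)` call above becomes `cast(list[Placement], ...)`; the built-in generics are accepted anywhere the typing aliases were. A tiny standalone sketch with plain strings standing in for Placement objects:

from typing import cast

# typing.cast accepts the parameterized built-ins just as it accepted the aliases;
# at runtime it simply returns its second argument unchanged.
placement_names = cast(list[str], ["replicate", "shard(0)"])  # previously cast(List[str], ...)
assert placement_names == ["replicate", "shard(0)"]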

View File

@ -4,7 +4,7 @@
import os
import unittest
from functools import wraps
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable
import numpy as np
@ -26,7 +26,7 @@ def with_xla(func: Callable) -> Callable:
@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc]
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
) -> None:
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
os.environ["XLA_USE_SPMD"] = "1"

View File

@ -12,7 +12,7 @@ from dataclasses import dataclass
from datetime import timedelta
from itertools import product
from sys import platform
from typing import Dict, Optional
from typing import Optional
import torch
import torch.distributed as dist
@ -972,7 +972,7 @@ class CommonDistributedDataParallelTest:
@dataclass
class CustomOutput:
o1: Optional[torch.Tensor]
o2: Dict[str, torch.Tensor]
o2: dict[str, torch.Tensor]
class DataclassOutputModule(nn.Module):
def __init__(self, skip_o1):

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: c10d"]
import threading
import unittest
from typing import List
import torch
import torch.distributed as dist
@ -65,7 +64,7 @@ class TestWithNCCL(MultiProcessTestCase):
return 2
@property
def ranks(self) -> List[int]:
def ranks(self) -> list[int]:
return list(range(self.world_size))
@property
@ -556,7 +555,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_reduce_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
def func(args: list[torch.Tensor]) -> torch.Tensor:
bufs = [arg + 42 for arg in args]
# Expect in-place with inductor allocated buf
ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")
@ -714,7 +713,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_gather_into_tensor_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
def func(args: list[torch.Tensor]) -> torch.Tensor:
ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
ag0 = [funcol.wait_tensor(out) for out in ag0]
return ag0
@ -796,7 +795,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_reduce_scatter_tensor_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
def func(args: list[torch.Tensor]) -> torch.Tensor:
rs0 = funcol.reduce_scatter_tensor_coalesced(
args, "avg", [0] * len(args), "0"
)

View File

@ -27,7 +27,6 @@ if not c10d.is_available() or not c10d.is_nccl_available():
print("c10d NCCL not available, skipping tests", file=sys.stderr)
sys.exit(0)
from typing import Dict, List
import test_c10d_common
from test_c10d_common import ConvNet, DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook
@ -2546,7 +2545,7 @@ class WorkHookTest(MultiProcessTestCase):
def test_on_completion_hook_broadcast(self):
pg = self._get_process_group()
num_hook_fired = 0
durations: List[float] = []
durations: list[float] = []
def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations
@ -2574,7 +2573,7 @@ class WorkHookTest(MultiProcessTestCase):
def test_on_completion_hook_mixed_ops(self):
pg = self._get_process_group()
num_hook_fired = 0
durations: List[float] = []
durations: list[float] = []
def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations
@ -2615,8 +2614,8 @@ class WorkHookTest(MultiProcessTestCase):
@skip_if_lt_x_gpu(2)
def test_on_completion_hook_with_ddp(self):
pg = self._get_process_group()
num_hook_fired: Dict[int, int] = {}
durations: Dict[OpType, List[float]] = {}
num_hook_fired: dict[int, int] = {}
durations: dict[OpType, list[float]] = {}
def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations
@ -2673,8 +2672,8 @@ class WorkHookTest(MultiProcessTestCase):
torch.cuda.set_device(self.rank)
pg = self._get_process_group()
num_hook_fired: Dict[int, int] = {}
durations: Dict[OpType, List[float]] = {}
num_hook_fired: dict[int, int] = {}
durations: dict[OpType, list[float]] = {}
def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations

View File

@ -2,21 +2,15 @@
import copy
import os
import sys
import tempfile
import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
import torch
import torch.distributed as c10d
import torch.nn as nn
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
create_device,
requires_gloo,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import requires_gloo, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
@ -26,95 +20,6 @@ from torch.testing._internal.common_utils import (
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def opts(cls, threads=2):
opts = c10d.ProcessGroupGloo._Options()
opts._timeout = 5.0
opts._devices = [create_device(interface="lo")]
opts._threads = threads
return opts
@classmethod
def _init_pg_gloo(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
backend = c10d.ProcessGroupGloo(
store, rank, world_size, ProcessGroupShareTensorTest.opts()
)
# set process group backends manually
c10d.init_process_group(
backend="gloo", store=store, rank=rank, world_size=world_size
)
pg = c10d.distributed_c10d._get_default_group()
pg._register_backend(
torch.device("cpu"), c10d.ProcessGroup.BackendType.GLOO, backend
)
pg._register_backend(
torch.device("cuda"), c10d.ProcessGroup.BackendType.GLOO, backend
)
return pg
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_broadcast_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allreduce_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allgather_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
self.world_size,
)
@classmethod
def _test_allgather_chunk_process(
cls, rank, filename, shared_tensor, world_size, init_pg, c2p, p2c
):
pg = init_pg(rank, filename, world_size)
chunks = torch.chunk(shared_tensor, world_size, dim=0)
x = chunks[rank]
ys = [torch.zeros_like(x) for _ in range(world_size)]
pg.allgather(ys, x).wait()
c2p.put((rank, chunks[0].to("cpu"), ys[0].to("cpu")))
c2p.put((rank, chunks[1].to("cpu"), ys[1].to("cpu")))
p2c.get()
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allgather_chunk_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_chunk_process,
torch.tensor(range(4)).reshape(2, 2),
ProcessGroupShareTensorTest._init_pg_gloo,
self.world_size,
)
class DistributedDataParallelSingleProcessTest(TestCase):

View File

@ -1,95 +1,21 @@
# Owner(s): ["oncall: distributed"]
import sys
import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
import torch
import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def _init_pg_nccl(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
return c10d.ProcessGroupNCCL(store, rank, world_size)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_broadcast_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_allreduce_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)
@classmethod
def _test_reduce_process(
cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
):
pg = init_pg(rank, filename, world_size)
x = shared_tensors[rank]
pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
if rank == 0:
c2p.put((rank, torch.ones(2, 2) * 2, x.to("cpu")))
else:
c2p.put((rank, torch.ones(2, 2), x.to("cpu")))
p2c.get()
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_reduce_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_reduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_allgather_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
self.world_size,
)
# Skip dev-asan as torch + multiprocessing spawn have known issues

View File

@ -1,74 +1,21 @@
# Owner(s): ["oncall: distributed"]
import sys
import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
import torch
import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_ucc, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
NO_UCC = not hasattr(c10d, "ProcessGroupUCC")
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def _init_pg_ucc(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
c10d.init_process_group(
backend="ucc", store=store, rank=rank, world_size=world_size
)
return c10d.distributed_c10d._get_default_group()
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
def test_shared_broadcast_ucc(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_ucc,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
def test_shared_allreduce_ucc(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_ucc,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
def test_shared_allgather_ucc(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_ucc,
self.world_size,
)
# Skip dev-asan as torch + multiprocessing spawn have known issues

View File

@ -7,7 +7,6 @@ import unittest
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from typing import List
from unittest.mock import patch
import numpy as np
@ -1959,7 +1958,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
model = ModuleWithStaticMethod(False)
x = torch.randn((2, 3), device="cuda")
ref_out = model(x)
test_outs: List[torch.Tensor] = []
test_outs: list[torch.Tensor] = []
for use_self in (False, True):
model = ModuleWithStaticMethod(use_self)

View File

@ -5,7 +5,6 @@ import functools
import math
import unittest # noqa: F811
from importlib import import_module
from typing import Set
import torch
import torch._dynamo.config
@ -86,7 +85,7 @@ def count_ops(
return gm
def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: set[str]):
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph
return_node = list(graph.nodes)[-1]
assert return_node.target == "output"

View File

@ -12,7 +12,6 @@ import operator
import unittest
from collections.abc import Sequence
from enum import Enum
from typing import Dict, List
from unittest.mock import patch
import torch
@ -1999,7 +1998,7 @@ def forward(self, l_x_):
self.assertIn("val", node.meta)
def test_input_container_type(self):
def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
def f(x: torch.Tensor, y: list[torch.Tensor]) -> dict[str, torch.Tensor]:
return {"a": x.sum() + sum(y).sum()}
inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])

View File

@ -11,7 +11,7 @@ import random
import sys
import unittest
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, TypeVar
from typing import Any, Generic, TypeVar
from typing_extensions import NamedTuple
from unittest.mock import patch
@ -3896,8 +3896,8 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
@dataclass
class Output:
scalar: int = 2
named_tensors: Dict[str, torch.Tensor] = field(default_factory=dict)
lists: List[torch.Tensor] = field(default_factory=list)
named_tensors: dict[str, torch.Tensor] = field(default_factory=dict)
lists: list[torch.Tensor] = field(default_factory=list)
def scale(self):
return self.scalar * 2
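The dataclass hunk above is a convenient self-contained illustration of the field-annotation side of the change. A runnable sketch of the same shape (renamed to OutputSketch to keep it clearly separate from the test class):

from dataclasses import dataclass, field

import torch

# Dataclass fields take the built-in generics directly; default_factory still
# supplies the concrete container instance.
@dataclass
class OutputSketch:
    scalar: int = 2
    named_tensors: dict[str, torch.Tensor] = field(default_factory=dict)
    lists: list[torch.Tensor] = field(default_factory=list)

out = OutputSketch()
out.named_tensors["x"] = torch.ones(2)
out.lists.append(torch.zeros(3))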

View File

@ -12,7 +12,7 @@ import types
import unittest
from copy import deepcopy
from functools import partial
from typing import Dict, NamedTuple, Tuple
from typing import NamedTuple
from unittest.mock import patch
import torch
@ -602,7 +602,7 @@ class LazyMLP(torch.nn.Module):
class MyInput(NamedTuple):
x: Dict[str, Dict[str, torch.Tensor]]
x: dict[str, dict[str, torch.Tensor]]
y: torch.Tensor
@ -2311,7 +2311,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
m = TestModule()
def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor:
return 2 * output + 1
@ -2358,7 +2358,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
m = TestModule()
def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor:
return 2 * output + 1
@ -2407,7 +2407,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
self.assertEqual(compiled_func(inp).item(), 15)
def new_forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor:
return 2 * output + 2
@ -2426,7 +2426,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
m = TestModule()
def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor:
return 2 * output + 1

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, Dict, List, NamedTuple, Optional
from typing import Callable, List, NamedTuple, Optional
import torch
import torch._dynamo
@ -50,20 +50,20 @@ class Variable:
def sum(self, name: Optional[str] = None) -> "Variable":
return operator_sum(self, name)
def expand(self, sizes: List[int]) -> "Variable":
def expand(self, sizes: list[int]) -> "Variable":
return operator_expand(self, sizes)
class TapeEntry(NamedTuple):
# names of the inputs to the original computation
inputs: List[str]
inputs: list[str]
# names of the outputs of the original computation
outputs: List[str]
outputs: list[str]
# apply chain rule
propagate: "Callable[List[Variable], List[Variable]]"
propagate: "Callable[list[Variable], list[Variable]]"
gradient_tape: List[TapeEntry] = []
gradient_tape: list[TapeEntry] = []
def reset_tape():
@ -72,9 +72,9 @@ def reset_tape():
_name = 0
def grad(L, desired_results: List[Variable]) -> List[Variable]:
def grad(L, desired_results: list[Variable]) -> list[Variable]:
# this map holds dL/dX for all values X
dL_d: Dict[str, Variable] = {}
dL_d: dict[str, Variable] = {}
# It starts by initializing the 'seed' dL/dL, which is 1
dL_d[L.name] = Variable(torch.ones(()))
# print(f'd{L.name} ------------------------')

View File

@ -3,7 +3,6 @@
import contextlib
import dis
import unittest
from typing import List
import torch
import torch._dynamo.test_case
@ -33,7 +32,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
Emit code to reconstruct only the key that changed
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1)
# reconstruct only d[40]
@ -57,7 +56,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If something is pop'ed from the dict, we reconstruct everything
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1)
# reconstruct everything
@ -84,7 +83,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If something is pop'ed from the dict, we reconstruct everything
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1)
# reconstruct everything
@ -128,7 +127,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If something is deleted from the dict, we reconstruct everything
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1)
# reconstruct everything
@ -154,7 +153,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
dict.get shouldn't affect anything
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1)
self.assertEqual(build_map[0].argval, 1)
@ -180,7 +179,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If dict.clear() is used, we reconstruct everything
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1)
# reconstruct everything
@ -206,7 +205,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If dict is created inside a function, everything needs to be reconstructed
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1)
# reconstruct everything
@ -231,7 +230,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
PyTorch shouldn't codegen any key/value when functional_call is used
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
# don't reconstruct anything
self.assertEqual(len(build_map), 0)
@ -260,7 +259,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
PyTorch shouldn't codegen any key/value when functional_call is used
"""
def hook(instructions: List[dis.Instruction]):
def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP")
# don't reconstruct anything
self.assertEqual(len(build_map), 0)

View File

@ -25,7 +25,7 @@ from collections.abc import Iterator
from copy import deepcopy
from enum import Enum, IntEnum
from functools import wraps
from typing import Any, Dict, List, Literal, Tuple, TypedDict
from typing import Any, Literal, TypedDict
from unittest import mock
import numpy as np
@ -712,7 +712,7 @@ def create_rand_mask_from_inputs(
class SequentialAppendList(torch.nn.Sequential):
"""from timm/models/vovnet.py"""
def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
def forward(self, x: torch.Tensor, concat_list: list[torch.Tensor]) -> torch.Tensor:
for i, module in enumerate(self):
if i == 0:
concat_list.append(module(x))
@ -4108,7 +4108,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
def test_graph_break_on_jit_isinstance(self):
@torch.compile(backend="eager")
def fn(x):
if torch.jit.isinstance(x, List[str]):
if torch.jit.isinstance(x, list[str]):
return x * 2
return x
@ -4819,14 +4819,14 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
def test_detectron2_instances_cat(self):
class Instances:
def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
def __init__(self, image_size: tuple[int, int], **kwargs: Any):
self._image_size = image_size
self._fields: Dict[str, Any] = {}
self._fields: dict[str, Any] = {}
for k, v in kwargs.items():
self.set(k, v)
@property
def image_size(self) -> Tuple[int, int]:
def image_size(self) -> tuple[int, int]:
return self._image_size
def __setattr__(self, name: str, val: Any) -> None:
@ -4861,7 +4861,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
return self._fields[name]
@staticmethod
def cat(instance_lists: List["Instances"]) -> "Instances":
def cat(instance_lists: list["Instances"]) -> "Instances":
assert all(isinstance(i, Instances) for i in instance_lists)
assert len(instance_lists) > 0
if len(instance_lists) == 1:
@ -5997,7 +5997,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
# https://github.com/pytorch/pytorch/issues/88813
def test_return_value_duplication_tensor(self) -> None:
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return val * 2, val * 2
x = torch.randn(2, requires_grad=True)
@ -6016,7 +6016,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
# https://github.com/pytorch/pytorch/issues/114344
def test_return_value_duplication_mixed_grad(self) -> None:
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
out0 = val + 1
out1 = val + 1
@ -6033,7 +6033,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
# https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371
def test_return_value_duplication_scalar(self) -> None:
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x, y = val * 2, val * 2
return x[0], y[0]
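One hunk above is worth calling out: torch.jit.isinstance takes the type to check against as a runtime value, and the commit passes it the parameterized built-in directly. A minimal eager-mode sketch of that call, with a hypothetical function outside torch.compile:

import torch

def maybe_double(x):
    # torch.jit.isinstance performs a runtime structural check; as the hunk above
    # shows, it accepts the parameterized built-in list[str] in place of List[str].
    if torch.jit.isinstance(x, list[str]):
        return x * 2
    return x

assert maybe_double(["a"]) == ["a", "a"]
assert torch.equal(maybe_double(torch.ones(2)), torch.ones(2))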

View File

@ -1822,7 +1822,7 @@ class GraphModule(torch.nn.Module):
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
@parametrize("dynamic", [True, False])
def test_mark_static_with_subclass_desugaring(self, dynamic):
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Optional
from torch._dynamo.decorators import mark_static_address
from torch._inductor.compile_fx import compile_fx
@ -1835,9 +1835,9 @@ class GraphModule(torch.nn.Module):
def inner_compile(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
example_inputs: list[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None,
static_input_idxs: Optional[List[int]] = None,
static_input_idxs: Optional[list[int]] = None,
is_backward: bool = False,
graph_id: Optional[int] = None,
cpp_wrapper: bool = False,
@ -1845,7 +1845,7 @@ class GraphModule(torch.nn.Module):
is_inference: bool = False,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None,
):
if dynamic:
self.assertEqual(static_input_idxs, [2, 3, 4])

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List
import torch
import torch._dynamo.config
@ -23,7 +22,7 @@ except ImportError:
@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
class BucketizeMod(torch.nn.Module):
def __init__(self, feature_boundaries: Dict[str, List[float]]):
def __init__(self, feature_boundaries: dict[str, list[float]]):
super().__init__()
self.bucket_w = torch.nn.ParameterDict()
self.boundaries_dict = {}
@ -84,7 +83,7 @@ class TorchRecTests(TestCase):
@torch.compile(backend=counter, fullgraph=True, dynamic=True)
def f(id_list_features: KeyedJaggedTensor):
id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
id_list_jt_dict: dict[str, JaggedTensor] = id_list_features.to_dict()
pooled_embeddings = {}
# TODO: run feature processor
for emb_module, feature_names in tables:

View File

@ -6,7 +6,7 @@ import math
import types
import unittest
import warnings
from typing import Any, Dict, Set
from typing import Any
import torch
import torch._dynamo.config as config
@ -103,10 +103,10 @@ class AllowedObjects:
from the heuristic defined in `gen_allowed_objs_and_ids`.
"""
object_ids: Dict[int, str]
c_binding_in_graph_functions: Set[Any]
non_c_binding_in_graph_functions: Set[Any]
name_rule_map: Dict[str, Any]
object_ids: dict[int, str]
c_binding_in_graph_functions: set[Any]
non_c_binding_in_graph_functions: set[Any]
name_rule_map: dict[str, Any]
def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:

View File

@ -2,7 +2,7 @@
import unittest
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
@ -65,11 +65,11 @@ class TestConverter(TestCase):
self,
M,
tracing_inputs,
option: Optional[List[str]] = None,
option: Optional[list[str]] = None,
check_persistent=False,
lifted_tensor_constants=None,
runtime_inputs: Optional[List[Any]] = None,
) -> List[ExportedProgram]:
runtime_inputs: Optional[list[Any]] = None,
) -> list[ExportedProgram]:
# By default, it tests both jit.trace and jit.script.
if option is None:
option = ["trace", "script"]
@ -130,7 +130,7 @@ class TestConverter(TestCase):
self._check_tensor_list_equal(ep_out, orig_out)
return ep_list
def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]):
def _check_tensor_list_equal(self, xs: list[torch.Tensor], ys: list[torch.Tensor]):
self.assertEqual(len(xs), len(ys))
for x, y in zip(xs, ys):
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
@ -219,7 +219,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp)
class Module(torch.nn.Module):
def forward(self, x: List[int]):
def forward(self, x: list[int]):
length = len(x)
return torch.ones(length)
@ -228,7 +228,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: Dict[int, str]):
def forward(self, x: dict[int, str]):
length = len(x)
return torch.ones(length)
@ -237,7 +237,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: Dict[bool, str]):
def forward(self, x: dict[bool, str]):
length = len(x)
return torch.ones(length)
@ -246,7 +246,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: Dict[float, str]):
def forward(self, x: dict[float, str]):
length = len(x)
return torch.ones(length)
@ -255,7 +255,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: Dict[torch.Tensor, str]):
def forward(self, x: dict[torch.Tensor, str]):
length = len(x)
return torch.ones(length)
@ -273,7 +273,7 @@ class TestConverter(TestCase):
def test_aten_add_t(self):
# python list append
class Module(torch.nn.Module):
def forward(self, x: List[torch.Tensor]):
def forward(self, x: list[torch.Tensor]):
out = []
out = out + x
a = torch.cat(out)
@ -531,7 +531,7 @@ class TestConverter(TestCase):
class Module(torch.nn.Module):
def forward(
self, x: torch.Tensor, y: torch.Tensor
) -> Tuple[bool, torch.Tensor]:
) -> tuple[bool, torch.Tensor]:
z = x + 1
return x is y, z
@ -546,7 +546,7 @@ class TestConverter(TestCase):
class Module(torch.nn.Module):
def forward(
self, x: torch.Tensor, y: torch.Tensor
) -> Tuple[bool, torch.Tensor]:
) -> tuple[bool, torch.Tensor]:
z = x + 1
return x is not y, z
@ -558,7 +558,7 @@ class TestConverter(TestCase):
class Module(torch.nn.Module):
def forward(
self, x: torch.Tensor, y: torch.Tensor
) -> Tuple[bool, torch.Tensor]:
) -> tuple[bool, torch.Tensor]:
z = x + 1
return not (x is not y), z
@ -573,7 +573,7 @@ class TestConverter(TestCase):
return x + y
class MUnpackTuple(torch.nn.Module):
def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]):
def forward(self, x_tuple: tuple[torch.Tensor, torch.Tensor]):
x, y = x_tuple
x = x.cos()
return x + y
@ -904,7 +904,7 @@ class TestConverter(TestCase):
return x.dtype in [torch.int8]
class MTensorIn(torch.nn.Module):
def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
def forward(self, x: torch.Tensor, x_dict: dict[torch.Tensor, str]):
return x in x_dict
# Traced function must return output that has tensors.
@ -1118,14 +1118,14 @@ class TestConverter(TestCase):
def test_prim_tolist(self):
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor) -> List[int]:
def forward(self, x: torch.Tensor) -> list[int]:
return x.tolist()
inp = (torch.tensor([1, 2, 3]),)
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module):
def forward(self, x: torch.Tensor) -> List[List[int]]:
def forward(self, x: torch.Tensor) -> list[list[int]]:
return x.tolist()
inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
@ -1353,7 +1353,7 @@ class TestConverter(TestCase):
def test_aten_append_t(self):
class M(torch.nn.Module):
def forward(self, x: List[torch.Tensor]):
def forward(self, x: list[torch.Tensor]):
out = []
out.append(x[0] + x[1])
out.append(x[0] - x[1])
@ -1381,7 +1381,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(M1(), inp, ["script"])
def test_ts2ep_with_loop(self):
def func1(x, x_list: List[torch.Tensor]):
def func1(x, x_list: list[torch.Tensor]):
a, b, c = x, x, x
for _ in range(1, 5, 2):
for k in range(5):

View File

@ -2,7 +2,6 @@
import copy
import tempfile
import unittest
from typing import List, Tuple
import torch
from torch.export import Dim, export
@ -325,7 +324,7 @@ class TestDraftExport(TestCase):
return torch.ops.mylib.foo(a)
@torch.library.custom_op("mylib::foo", mutates_args={})
def foo(a: torch.Tensor) -> List[torch.Tensor]:
def foo(a: torch.Tensor) -> list[torch.Tensor]:
x = a * 2
y = a.repeat(2, 2)
z = a.to(torch.bfloat16)
@ -370,7 +369,7 @@ class TestDraftExport(TestCase):
return torch.ops.mylib.foo(a)
@torch.library.custom_op("mylib::foo", mutates_args={})
def foo(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def foo(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return a * 2, a + 2
@foo.register_fake

View File

@ -1,7 +1,7 @@
# Owner(s): ["oncall: export"]
import unittest
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple
from typing import Any, Optional
import torch
from torch._export.passes.lift_constants_pass import (
@ -34,9 +34,9 @@ class GraphBuilder:
self.graph = torch.fx.Graph()
self.nodes = {}
self.values = {}
self.nn_module_stack_key: Dict[str, int] = {}
self.nn_module_stack_key: dict[str, int] = {}
self.latest_id = 0
self.input_to_kind: Dict[torch.fx.Node, InputKind] = {}
self.input_to_kind: dict[torch.fx.Node, InputKind] = {}
def input(self, name: str, value: torch.Tensor, kind: InputKind):
node = self.graph.placeholder(name)
@ -87,7 +87,7 @@ class GraphBuilder:
def create_nn_module_stack(
self, module_fqn: str
) -> OrderedDict[int, Tuple[str, type]]:
) -> OrderedDict[int, tuple[str, type]]:
cur_name = ""
nn_module_stack = OrderedDict()
for atom in module_fqn.split("."):

View File

@ -8,7 +8,6 @@ import math
import operator
import unittest
from re import escape
from typing import List, Set
import torch
from functorch.experimental.control_flow import cond
@ -75,11 +74,11 @@ class _AtenAddOperatorSupport(OperatorSupport):
return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor}
def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]:
def _to_partition_names(partitions: list[Partition]) -> list[set[str]]:
return [{n.name for n in p.nodes} for p in partitions]
def _get_output_names(gm: torch.fx.GraphModule) -> List[str]:
def _get_output_names(gm: torch.fx.GraphModule) -> list[str]:
output_node = next(n for n in gm.graph.nodes if n.op == "output")
args = pytree.tree_leaves(output_node.args)
# if isinstance(args, tuple) and len(args) == 1:

View File

@ -12,7 +12,7 @@ import unittest
import warnings
from contextlib import ContextDecorator, nullcontext
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
from common_utils import decorate, decorateForModules, skip, skipOps, xfail
@ -319,8 +319,8 @@ class TestAOTAutograd(AOTTestCase):
def run_autograd(
self,
f: Callable,
fw_graph_cell: List[Optional[Callable]],
decompositions: Optional[Dict],
fw_graph_cell: list[Optional[Callable]],
decompositions: Optional[dict],
keep_input_mutations: bool,
dynamic: bool,
):
@ -358,11 +358,11 @@ class TestAOTAutograd(AOTTestCase):
def verify_aot_autograd(
self,
f,
inp_: Union[Callable, List[Any]],
inp_: Union[Callable, list[Any]],
*,
test_mutation: bool = False,
keep_inp_mutations: bool = False,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
dynamic: bool = False,
# Only active when inp_ is Callable.
# TODO: probably consolidate all tests to make inp a Callable.
@ -6748,8 +6748,8 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
def run_autograd(
self,
f: Callable,
fw_graph_cell: List[Optional[Callable]],
decompositions: Optional[Dict],
fw_graph_cell: list[Optional[Callable]],
decompositions: Optional[dict],
keep_input_mutations: bool,
dynamic: bool,
):
@ -6880,8 +6880,8 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
def run_autograd(
self,
f: Callable,
fw_graph_cell: List[Optional[Callable]],
decompositions: Optional[Dict],
fw_graph_cell: list[Optional[Callable]],
decompositions: Optional[dict],
keep_input_mutations: bool,
dynamic: bool,
):
@ -6904,11 +6904,11 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
def verify_aot_autograd(
self,
f,
inp_: Union[Callable, List[Any]],
inp_: Union[Callable, list[Any]],
*,
test_mutation: bool = False,
keep_inp_mutations: bool = False,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
dynamic: bool = False,
# Only active when inp_ is Callable.
# TODO: probably consolidate all tests to make inp a Callable.

View File

@ -24,7 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import Any, Callable, Dict
from typing import Any, Callable
from unittest import mock
from functorch.einops._parsing import (
@ -206,7 +206,7 @@ class TestParsedExpression(TestCase):
class TestParsingUtils(TestCase):
def test_parse_pattern_number_of_arrows(self) -> None:
axes_lengths: Dict[str, int] = {}
axes_lengths: dict[str, int] = {}
too_many_arrows_pattern = "a -> b -> c -> d"
with self.assertRaises(ValueError):
@ -220,13 +220,13 @@ class TestParsingUtils(TestCase):
parse_pattern(just_right_arrows, axes_lengths)
def test_ellipsis_invalid_identifier(self) -> None:
axes_lengths: Dict[str, int] = {"a": 1, _ellipsis: 2}
axes_lengths: dict[str, int] = {"a": 1, _ellipsis: 2}
pattern = f"a {_ellipsis} -> {_ellipsis} a"
with self.assertRaises(ValueError):
parse_pattern(pattern, axes_lengths)
def test_ellipsis_matching(self) -> None:
axes_lengths: Dict[str, int] = {}
axes_lengths: dict[str, int] = {}
pattern = "a -> a ..."
with self.assertRaises(ValueError):
@ -240,7 +240,7 @@ class TestParsingUtils(TestCase):
parse_pattern(pattern, axes_lengths)
def test_left_parenthesized_ellipsis(self) -> None:
axes_lengths: Dict[str, int] = {}
axes_lengths: dict[str, int] = {}
pattern = "(...) -> ..."
with self.assertRaises(ValueError):
@ -254,7 +254,7 @@ class MaliciousRepr:
class TestValidateRearrangeExpressions(TestCase):
def test_validate_axes_lengths_are_integers(self) -> None:
axes_lengths: Dict[str, Any] = {"a": 1, "b": 2, "c": 3}
axes_lengths: dict[str, Any] = {"a": 1, "b": 2, "c": 3}
pattern = "a b c -> c b a"
left, right = parse_pattern(pattern, axes_lengths)
validate_rearrange_expressions(left, right, axes_lengths)
@ -265,7 +265,7 @@ class TestValidateRearrangeExpressions(TestCase):
validate_rearrange_expressions(left, right, axes_lengths)
def test_non_unitary_anonymous_axes_raises_error(self) -> None:
axes_lengths: Dict[str, int] = {}
axes_lengths: dict[str, int] = {}
left_non_unitary_axis = "a 2 -> 1 1 a"
left, right = parse_pattern(left_non_unitary_axis, axes_lengths)
@ -278,7 +278,7 @@ class TestValidateRearrangeExpressions(TestCase):
validate_rearrange_expressions(left, right, axes_lengths)
def test_identifier_mismatch(self) -> None:
axes_lengths: Dict[str, int] = {}
axes_lengths: dict[str, int] = {}
mismatched_identifiers = "a -> a b"
left, right = parse_pattern(mismatched_identifiers, axes_lengths)
@ -291,7 +291,7 @@ class TestValidateRearrangeExpressions(TestCase):
validate_rearrange_expressions(left, right, axes_lengths)
def test_unexpected_axes_lengths(self) -> None:
axes_lengths: Dict[str, int] = {"c": 2}
axes_lengths: dict[str, int] = {"c": 2}
pattern = "a b -> b a"
left, right = parse_pattern(pattern, axes_lengths)

View File

@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import List, Tuple
import numpy as np
@ -34,7 +33,7 @@ from functorch.einops import rearrange
from torch.testing._internal.common_utils import run_tests, TestCase
identity_patterns: List[str] = [
identity_patterns: list[str] = [
"...->...",
"a b c d e-> a b c d e",
"a b c d e ...-> ... a b c d e",
@ -45,7 +44,7 @@ identity_patterns: List[str] = [
"a ... c d e -> a (...) c d e",
]
equivalent_rearrange_patterns: List[Tuple[str, str]] = [
equivalent_rearrange_patterns: list[tuple[str, str]] = [
("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "),
("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"),
("a b c d e -> a b c d e", "... -> ... "),
@ -149,7 +148,7 @@ class TestRearrange(TestCase):
def test_concatenations_and_stacking(self) -> None:
for n_arrays in [1, 2, 5]:
shapes: List[List[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
shapes: list[list[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
for shape in shapes:
arrays1 = [
torch.arange(i, i + np.prod(shape, dtype=int)).reshape(shape)

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: fx"]
import copy
import unittest
from typing import Optional, Set, Type
from typing import Optional
import torch
import torch.fx
@ -38,7 +38,7 @@ class TestDCE(TestCase):
self,
m: torch.nn.Module,
expect_dce_changes: bool,
modules_to_be_leafs: Optional[Set[Type]] = None,
modules_to_be_leafs: Optional[set[type]] = None,
custom: bool = False,
):
class TestTracer(torch.fx.Tracer):

View File

@ -2,8 +2,6 @@
from __future__ import annotations # type: ignore[attr-defined]
import typing
import torch
from torch.fx import symbolic_trace
@ -27,13 +25,13 @@ class M2(torch.nn.Module):
# Non-torch annotation with no internal forward references
class M3(torch.nn.Module):
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
def forward(self, x: list[torch.Tensor], a: A) -> torch.Tensor:
return a(x[0])
# Non-torch annotation with internal forward references
class M4(torch.nn.Module):
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
def forward(self, x: list[torch.Tensor], a: A) -> torch.Tensor:
return a(x[0])

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: fx"]
from collections import defaultdict
from typing import Dict, List, Tuple
import torch
from torch.fx.passes.split_utils import split_by_tags
@ -63,8 +62,8 @@ class TestSplitByTags(TestCase):
@staticmethod
def trace_and_tag(
module: torch.nn.Module, tags: List[str]
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
module: torch.nn.Module, tags: list[str]
) -> tuple[torch.fx.GraphModule, dict[str, list[str]]]:
"""
Test simple gm consists of nodes with tag (only show call_module nodes here):
linear1 - tag: "red"
@ -167,8 +166,8 @@ class TestSplitOutputType(TestCase):
@staticmethod
def trace_and_tag(
module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
module: torch.nn.Module, inputs: torch.Tensor, tags: list[str]
) -> tuple[torch.fx.GraphModule, dict[str, list[str]]]:
"""
Test simple gm consists of nodes with tag (only show call_module nodes here):
conv - tag: "red"

View File

@ -4,7 +4,7 @@
import unittest
from collections import deque
from functools import partial
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING
import torch
import torch._dynamo
@ -390,7 +390,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
return output
def add_hooks(module, config):
handles: List[RemovableHandle] = []
handles: list[RemovableHandle] = []
q = deque([(module.__class__.__name__, module)])
while q:
name, m = q.pop()

View File

@ -5,7 +5,6 @@ import os
import sys
import tempfile
import unittest
from typing import Dict, Tuple
from unittest import skip
import torch
@ -1744,7 +1743,7 @@ class AOTInductorTestsTemplate:
def __init__(self) -> None:
super().__init__()
def forward(self, x: Dict[str, torch.Tensor]):
def forward(self, x: dict[str, torch.Tensor]):
device = next(iter(x.values())).device
add_ = torch.zeros(5, device=device)
mul_ = torch.ones(5, device=device)
@ -2660,7 +2659,7 @@ class AOTInductorTestsTemplate:
def forward(
self,
self_tensor: torch.Tensor,
indices: Tuple[torch.Tensor],
indices: tuple[torch.Tensor],
values: torch.Tensor,
):
return torch.index_put(
@ -4285,7 +4284,7 @@ def fail_cpu(is_skip=False):
)
def fail_gpu(suffixes: Tuple[str, ...], is_skip=False):
def fail_gpu(suffixes: tuple[str, ...], is_skip=False):
return TestFailure(
suffixes,
is_skip=is_skip,

View File

@ -4,7 +4,7 @@ import pickle
import shutil
import tempfile
import unittest
from typing import List, Optional, Union
from typing import Optional, Union
from unittest import mock
import torch
@ -1434,7 +1434,7 @@ class TestCudaCompileCommand(TestCase):
with mock.patch("subprocess.check_output") as check_output_mock:
CUDACodeCache.compile("test123.cu", "so", ["-Wsomething"])
check_output_mock.assert_called()
cmd_parts: List[str] = check_output_mock.call_args[0][0]
cmd_parts: list[str] = check_output_mock.call_args[0][0]
assert cmd_parts[0] == "nvcc", cmd_parts
assert "-Wsomething" in cmd_parts, cmd_parts
assert "-DNDEBUG" in cmd_parts, cmd_parts

View File

@ -1,6 +1,6 @@
# Owner(s): ["module: inductor"]
import unittest
from typing import Any, Dict, List, Type
from typing import Any
import sympy
@ -160,11 +160,11 @@ class TestFixedConfigs(TestCase):
class MyHeuristics(InductorChoices):
def triton_kernel_kwargs(
self,
kernel_cls: Type[TritonKernel],
kernel_cls: type[TritonKernel],
features: SIMDKernelFeatures,
groups: List[sympy.Expr],
kernel_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
groups: list[sympy.Expr],
kernel_kwargs: dict[str, Any],
) -> dict[str, Any]:
return {
**kernel_kwargs,
"override_cooperative_reduction": cooperative,

View File

@ -3,7 +3,7 @@ import logging
import math
import os
import unittest
from typing import Callable, List, Optional
from typing import Callable, Optional
from unittest import mock
from torch.export import Dim
@ -114,7 +114,7 @@ class TestCutlassBackend(TestCase):
) as mocked_select_algorithm:
Y_compiled = torch.compile(mm, dynamic=False)(a, b)
Y = mm(a, b)
passed_choice_callers: List[ChoiceCaller] = mocked_select_algorithm[0][
passed_choice_callers: list[ChoiceCaller] = mocked_select_algorithm[0][
1
]
assert all(
@ -573,7 +573,7 @@ class TestCutlassBackend(TestCase):
return torch.addmm(x, a, b, alpha=alpha, beta=beta)
def compare_results(
m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int]
m: int, k: int, n: int, alpha: float, beta: float, x_shape: list[int]
) -> None:
x = torch.randn(x_shape).cuda().half()
a = torch.randn(m, k).cuda().half()

View File

@ -9,7 +9,7 @@ from collections import namedtuple
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import product
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Optional, Union
from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch
@ -511,7 +511,7 @@ class TestFlexAttention(InductorTestCase):
block_mask,
dtype: torch.dtype = torch.float16,
page_size: int = 128,
) -> Tuple[Tensor, Tensor, BlockMask, _score_mod_signature]:
) -> tuple[Tensor, Tensor, BlockMask, _score_mod_signature]:
assert block_mask is not None, "Must provide block_mask"
Q_B, Q_H, Q_S, _ = q.shape
KV_B, KV_H, KV_S, QK_D = k.shape
@ -596,7 +596,7 @@ class TestFlexAttention(InductorTestCase):
v: Tensor,
dtype: torch.dtype = torch.float16,
block_mask: Optional[BlockMask] = None,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
B, Q_H, Q_S, KV_H, KV_S = (
q.shape[0],
q.shape[1],
@ -797,7 +797,7 @@ class TestFlexAttention(InductorTestCase):
def run_dynamic_test(
self,
score_mask_mod: Tuple[Callable, Callable],
score_mask_mod: tuple[Callable, Callable],
dtype: torch.dtype = torch.float16,
B: int = B,
H: int = H,
@ -1089,7 +1089,7 @@ class TestFlexAttention(InductorTestCase):
@common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("score_mask_mod", test_score_mask_mod_map.items())
def test_builtin_score_mods_dynamic(
self, dtype: torch.dtype, score_mask_mod: Tuple[Callable, Callable]
self, dtype: torch.dtype, score_mask_mod: tuple[Callable, Callable]
):
self.run_dynamic_test(score_mask_mod, dtype)
@ -1127,7 +1127,7 @@ class TestFlexAttention(InductorTestCase):
self,
dtype: torch.dtype,
score_mod: Callable,
BLOCK_SIZE: Union[int, Tuple[int, int]],
BLOCK_SIZE: Union[int, tuple[int, int]],
):
block_mask = create_block_mask(
noop_mask, B, H, S, S, BLOCK_SIZE=BLOCK_SIZE, device=self.device
@ -1142,8 +1142,8 @@ class TestFlexAttention(InductorTestCase):
def test_kv_batch_broadcast(
self,
dtype: torch.dtype,
batch_dims: Tuple[int, int],
head_dims: Tuple[int, int],
batch_dims: tuple[int, int],
head_dims: tuple[int, int],
score_mod: Callable,
):
Hq, Hkv = head_dims
@ -1175,8 +1175,8 @@ class TestFlexAttention(InductorTestCase):
def test_kv_batch_broadcast_causal_mask(
self,
dtype: torch.dtype,
batch_dims: Tuple[int, int],
head_dims: Tuple[int, int],
batch_dims: tuple[int, int],
head_dims: tuple[int, int],
score_mod: Callable,
):
Hq, Hkv = head_dims
@ -3616,7 +3616,7 @@ class TestBlockMask(InductorTestCase):
@supported_platform
@common_utils.parametrize("BLOCK_SIZE", [32, 64, 128, 256, (32, 64), (64, 32)])
def test_block_size_changes(self, BLOCK_SIZE: Union[int, Tuple[int, int]]):
def test_block_size_changes(self, BLOCK_SIZE: Union[int, tuple[int, int]]):
B, H, Q_LEN, KV_LEN = 4, 2, 2048, 2048
if isinstance(BLOCK_SIZE, int):
@ -3990,7 +3990,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
)
def length_to_offsets(
lengths: List[int], device: Union[str, torch.device]
lengths: list[int], device: Union[str, torch.device]
) -> Tensor:
offsets = [0]
offsets.extend(lengths)
@ -4561,7 +4561,7 @@ class Params:
return f"batch:{self.batch_size}_head:{self.num_heads}_seq_len:{self.seq_length}_headdim:{self.head_dim}_dtype:{str(self.dtype).split('.')[-1]}"
def get_params(dtypes: List[torch.dtype]) -> List[Params]:
def get_params(dtypes: list[torch.dtype]) -> list[Params]:
params = []
seq_lengths = [37, 256, 277]
for seq_len, dtype in product(seq_lengths, dtypes):

View File

@ -3,7 +3,7 @@
import functools
from collections import namedtuple
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Union
from unittest import expectedFailure, skipUnless
from unittest.mock import patch
@ -645,7 +645,7 @@ class TestFlexDecoding(InductorTestCase):
self,
dtype: torch.dtype,
score_mod: Callable,
head_dims: Tuple[int, int],
head_dims: tuple[int, int],
page_size: int,
):
Hq, Hkv = head_dims
@ -681,7 +681,7 @@ class TestFlexDecoding(InductorTestCase):
self,
dtype: torch.dtype,
score_mod: Callable,
BLOCK_SIZE: Union[int, Tuple[int, int]],
BLOCK_SIZE: Union[int, tuple[int, int]],
):
block_mask = create_block_mask(noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE)
self.run_test(score_mod, dtype, block_mask=block_mask)
@ -763,8 +763,8 @@ class TestFlexDecoding(InductorTestCase):
def test_kv_batch_broadcast(
self,
dtype: torch.dtype,
head_dims: Tuple[int, int],
batch_dims: Tuple[int, int],
head_dims: tuple[int, int],
batch_dims: tuple[int, int],
score_mod: Callable,
):
Hq, Hkv = head_dims

View File

@ -2,7 +2,7 @@
import functools
import unittest
from typing import List, Tuple, Union
from typing import Union
import torch
from torch import Tensor
@ -89,8 +89,8 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
def _fix_fp8_dtype_for_rocm(
dtype: Union[torch.dtype, List[torch.dtype], Tuple[torch.dtype]], device
) -> Union[torch.dtype, List[torch.dtype], Tuple[torch.dtype]]:
dtype: Union[torch.dtype, list[torch.dtype], tuple[torch.dtype]], device
) -> Union[torch.dtype, list[torch.dtype], tuple[torch.dtype]]:
# This function is used to change FP8 data types
# with MI300 supported FP8 types if device is GPU:
# e4m3fn -> e4m3fnuz

View File

@ -2,7 +2,7 @@
import sys
import unittest
from typing import List, Literal
from typing import Literal
from unittest.mock import MagicMock, patch
import torch
@ -50,8 +50,8 @@ class TestConfigFuzzer(TestCase):
self.assertEqual(toggle("", bool, True), False)
self.assertEqual(toggle("", Literal["foo", "bar"], "foo"), "bar")
self.assertEqual(toggle("", Literal["foo", "bar"], "bar"), "foo")
self.assertTrue("bar" in toggle("", List[Literal["foo", "bar"]], ["foo"]))
self.assertTrue("foo" in toggle("", List[Literal["foo", "bar"]], ["bar"]))
self.assertTrue("bar" in toggle("", list[Literal["foo", "bar"]], ["foo"]))
self.assertTrue("foo" in toggle("", list[Literal["foo", "bar"]], ["bar"]))
@unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
def test_sampling_method_random(self):

View File

@ -2,7 +2,6 @@
import collections
import unittest
from typing import List
import torch
import torch._inductor
@ -39,7 +38,7 @@ class TestHighwaySelfGating(torch.nn.Module):
def forward(
self,
inputs: List[torch.Tensor],
inputs: list[torch.Tensor],
) -> torch.Tensor:
results = []
for i in range(self.size):

Some files were not shown because too many files have changed in this diff.