PEP585 update - test (#145176)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176
Approved by: https://github.com/bobrenjc93
Authored by Aaron Orenstein on 2025-01-21 11:22:20 -08:00; committed by PyTorch MergeBot
parent 40e27fbcf2
commit 99dbc5b0e2
146 changed files with 801 additions and 1099 deletions
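
The change applies PEP 585: container annotations imported from typing (List, Tuple, Dict, Type), deprecated since Python 3.9, are replaced with the built-in generics list, tuple, dict, and type, and the now-unneeded names are dropped from the typing imports, while Optional and Union are kept. A minimal sketch of the pattern, using hypothetical function and variable names rather than code taken from this diff:

# Before: typing aliases, deprecated since Python 3.9 (PEP 585)
from typing import Dict, List, Optional, Tuple

def summarize(losses: List[float], shape: Tuple[int, int]) -> Dict[str, Optional[float]]:
    # Return a tiny summary dict keyed by string
    mean = sum(losses) / len(losses) if losses else None
    return {"mean": mean, "batch": float(shape[0])}

# After: built-in generics; Optional (and Union) still come from typing here
from typing import Optional

def summarize(losses: list[float], shape: tuple[int, int]) -> dict[str, Optional[float]]:
    mean = sum(losses) / len(losses) if losses else None
    return {"mean": mean, "batch": float(shape[0])}

Optional and Union are left untouched in this change; rewriting them with the PEP 604 "X | Y" syntax would raise the minimum supported Python version to 3.10.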


@@ -7,7 +7,7 @@ import re
import textwrap
import timeit
import unittest
- from typing import Any, List, Tuple
+ from typing import Any
import expecttest
import numpy as np
@@ -67,7 +67,7 @@ def generate_callgrind_artifacts() -> None:
def load_callgrind_artifacts() -> (
- Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
+ tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
):
"""Hermetic artifact to unit test Callgrind wrapper.
@@ -85,9 +85,9 @@ def load_callgrind_artifacts() -> (
pattern = re.compile(r"^\s*([0-9]+)\s(.+)$")
def to_function_counts(
- count_strings: List[str], inclusive: bool
+ count_strings: list[str], inclusive: bool
) -> benchmark_utils.FunctionCounts:
- data: List[benchmark_utils.FunctionCount] = []
+ data: list[benchmark_utils.FunctionCount] = []
for cs in count_strings:
# Storing entries as f"{c} {fn}" rather than [c, fn] adds some work
# reviving the artifact, but it makes the json much easier to read.


@@ -7,7 +7,7 @@ import sys
import xml.etree.ElementTree as ET
from collections import defaultdict
from types import MethodType
- from typing import Any, List, Optional, TYPE_CHECKING, Union
+ from typing import Any, Optional, TYPE_CHECKING, Union
import pytest
from _pytest.config import Config, filename_arg
@@ -241,7 +241,7 @@ def pytest_report_teststatus(report, config):
@pytest.hookimpl(trylast=True)
- def pytest_collection_modifyitems(items: List[Any]) -> None:
+ def pytest_collection_modifyitems(items: list[Any]) -> None:
"""
This hook is used when rerunning disabled tests to get rid of all skipped tests
instead of running and skipping them N times. This avoids flooding the console
@@ -304,7 +304,7 @@ class StepcurrentPlugin:
self.skip: bool = config.getoption("stepcurrent_skip")
self.run_single: bool = config.getoption("run_single")
- def pytest_collection_modifyitems(self, config: Config, items: List[Any]) -> None:
+ def pytest_collection_modifyitems(self, config: Config, items: list[Any]) -> None:
if not self.lastrun:
self.report_status = "Cannot find last run test, not skipping"
return


@@ -1,5 +1,3 @@
- from typing import List
import torch
from torch import Tensor
@@ -8,7 +6,7 @@ lib = torch.library._scoped_library("python_agnostic", "FRAGMENT")
lib.define("ultra_norm(Tensor[] inputs) -> Tensor")
- def ultra_norm(inputs: List[Tensor]) -> Tensor:
+ def ultra_norm(inputs: list[Tensor]) -> Tensor:
"""
Computes the ultra-L2-norm of a list of tensors via computing the norm of norms.


@@ -2,7 +2,7 @@
from __future__ import annotations
import typing
- from typing import List, Optional, Union
+ from typing import Optional, Union
import torch
from torch import Tensor, types
@@ -87,7 +87,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_2, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> SymInt")
- def foo_op_3(x: typing.List[int]) -> int:
+ def foo_op_3(x: list[int]) -> int:
return 1
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
@@ -99,7 +99,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
- def foo_op_5(x: typing.Optional[typing.List[int]]) -> int:
+ def foo_op_5(x: typing.Optional[list[int]]) -> int:
return 1
result = torch.library.infer_schema(foo_op_5, mutates_args=mutates_args)
@@ -136,7 +136,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_3, mutates_args=mutates_args)
self.assertEqual(result, "(Tensor x) -> Tensor")
- def foo_op_4(x: List[int]) -> types.Number:
+ def foo_op_4(x: list[int]) -> types.Number:
return x[0]
result = torch.library.infer_schema(foo_op_4, mutates_args=mutates_args)
@@ -154,7 +154,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_6, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[] x) -> SymInt")
- def foo_op_7(x: List[int]) -> int:
+ def foo_op_7(x: list[int]) -> int:
return 1
result = torch.library.infer_schema(foo_op_7, mutates_args=mutates_args)
@@ -166,7 +166,7 @@ class TestInferSchemaWithAnnotation(TestCase):
result = torch.library.infer_schema(foo_op_8, mutates_args=mutates_args)
self.assertEqual(result, "(SymInt[]? x) -> SymInt")
- def foo_op_9(x: Optional[List[int]]) -> int:
+ def foo_op_9(x: Optional[list[int]]) -> int:
return 1
result = torch.library.infer_schema(foo_op_9, mutates_args=mutates_args)


@@ -5,7 +5,7 @@ import copy
import functools
import itertools
import unittest
- from typing import Any, List, Optional, Type, Union
+ from typing import Any, Optional, Union
import torch
import torch.distributed as dist
@@ -117,7 +117,7 @@ class TestFullyShardAutograd(FSDPTest):
local_inp = global_inp[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
].detach()
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, inp in ((ref_model, global_inp), (model, local_inp)):
losses.append(_model(inp).sum())
losses[-1].backward()
@@ -141,7 +141,7 @@ class TestFullyShardAutograd(FSDPTest):
self._test_nontensor_activations,
)
- def _test_nontensor_activations(self, container_type: Type):
+ def _test_nontensor_activations(self, container_type: type):
class Module(nn.Module):
def __init__(self, dim: int):
super().__init__()
@@ -170,7 +170,7 @@ class TestFullyShardAutograd(FSDPTest):
return self.relu(self.lin2(self.relu(self.lin1(x))))
class ToContainerType(nn.Module):
- def __init__(self, container_type: Type):
+ def __init__(self, container_type: type):
super().__init__()
self.container_type = container_type
@@ -190,7 +190,7 @@ class TestFullyShardAutograd(FSDPTest):
)
class FromContainerType(nn.Module):
- def __init__(self, container_type: Type):
+ def __init__(self, container_type: type):
super().__init__()
self.container_type = container_type
@@ -227,7 +227,7 @@ class TestFullyShardAutograd(FSDPTest):
local_inp = global_inp[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
].detach()
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, inp in ((ref_model, global_inp), (model, local_inp)):
losses.append(_model(inp).sum())
losses[-1].backward()


@@ -4,7 +4,7 @@ import copy
import functools
import itertools
import unittest
- from typing import Callable, List, Optional, Tuple, Union
+ from typing import Callable, Optional, Union
import torch
import torch.distributed as dist
@@ -58,7 +58,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
c10d_ops = torch.ops.c10d
# For recording FSDP events like unshard or post-backward
- EventType = Tuple[str, str, TrainingState]
+ EventType = tuple[str, str, TrainingState]
class TestFullyShardCollectiveOps(FSDPTestMultiThread):
@@ -70,7 +70,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
def device(self) -> torch.device:
return torch.device("cuda:0")
- def _get_param_sizes(self) -> List[torch.Size]:
+ def _get_param_sizes(self) -> list[torch.Size]:
# For world size 128, the fp32 all-gather and reduce-scatter testing
# requires ~0.22 GB
return [
@@ -84,7 +84,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
torch.Size([64, 297]),
]
- def _init_params(self, param_sizes: List[torch.Size]) -> List[nn.Parameter]:
+ def _init_params(self, param_sizes: list[torch.Size]) -> list[nn.Parameter]:
torch.manual_seed(42)
orig_params = [
nn.Parameter(torch.randn(size, device=self.device)) for size in param_sizes
@@ -96,7 +96,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
return orig_params
def _init_fsdp_param_group(
- self, params: List[nn.Parameter], reshard_after_forward: Union[bool, int]
+ self, params: list[nn.Parameter], reshard_after_forward: Union[bool, int]
):
module = nn.ParameterList([param.detach().clone() for param in params])
mesh_info = FSDPMeshInfo(_init_default_fully_shard_mesh(), shard_mesh_dim=0)
@@ -143,7 +143,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
def _test_all_gather(
self,
- param_sizes: List[torch.Size],
+ param_sizes: list[torch.Size],
reshard_after_forward: Union[bool, int],
async_op: bool,
all_gather_copy_in_stream: torch.cuda.Stream,
@@ -165,7 +165,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
fsdp_param_group._to_unsharded()
def check_all_gathered_params(
- orig_params: List[nn.Parameter], module: nn.Module
+ orig_params: list[nn.Parameter], module: nn.Module
):
for orig_param, param in zip(orig_params, module.parameters()):
self.assertIsInstance(param, torch.Tensor)
@@ -228,7 +228,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
def _test_reduce_scatter(
self,
- param_sizes: List[torch.Size],
+ param_sizes: list[torch.Size],
reduce_scatter_stream: torch.cuda.Stream,
reduce_scatter_dtype: torch.dtype,
):
@@ -453,7 +453,7 @@ class TestFullyShardPrefetch(FSDPTest):
model, optim, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)
- events: List[EventType] = []
+ events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@@ -504,7 +504,7 @@ class TestFullyShardPrefetch(FSDPTest):
model, _, inp = self._init_transformer(
n_layers, reshard_after_forward, checkpoint_impl
)
- events: List[EventType] = []
+ events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@@ -582,7 +582,7 @@ class TestFullyShardPrefetch(FSDPTest):
fully_shard(model[1].lin2, reshard_after_forward=reshard_after_forward)
fully_shard(model, reshard_after_forward=reshard_after_forward)
inp = torch.randn((4, dim), device="cuda")
- events: List[EventType] = []
+ events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@@ -652,7 +652,7 @@ class TestFullyShardPrefetch(FSDPTest):
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
- events: List[EventType] = []
+ events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@@ -742,7 +742,7 @@ class TestFullyShardPrefetch(FSDPTest):
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
- events: List[EventType] = []
+ events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@@ -834,7 +834,7 @@ class TestFullyShardPrefetch(FSDPTest):
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
- events: List[EventType] = []
+ events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@@ -915,7 +915,7 @@ class TestFullyShardPrefetch(FSDPTest):
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
- events: List[EventType] = []
+ events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
)
@@ -1011,7 +1011,7 @@ class TestFullyShardPrefetch(FSDPTest):
return model, optim, inp
def _get_unshard_with_record(
- self, orig_unshard: Callable, events: List[EventType]
+ self, orig_unshard: Callable, events: list[EventType]
) -> Callable:
def unshard_with_record(self, *args, **kwargs):
nonlocal events
@@ -1025,7 +1025,7 @@ class TestFullyShardPrefetch(FSDPTest):
return unshard_with_record
def _get_reshard_with_record(
- self, orig_reshard: Callable, events: List[EventType]
+ self, orig_reshard: Callable, events: list[EventType]
) -> Callable:
def reshard_with_record(self, *args, **kwargs):
nonlocal events
@@ -1040,7 +1040,7 @@ class TestFullyShardPrefetch(FSDPTest):
return reshard_with_record
def _get_post_backward_with_record(
- self, orig_post_backward: Callable, events: List[EventType]
+ self, orig_post_backward: Callable, events: list[EventType]
) -> Callable:
def post_backward_with_record(self, *args, **kwargs):
nonlocal events
@@ -1080,7 +1080,7 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
self.mlp2 = MLP(dim)
self.mlp3 = MLP(dim)
- def forward(self, ys: List[torch.Tensor], works: List[dist.Work]):
+ def forward(self, ys: list[torch.Tensor], works: list[dist.Work]):
(y1, y2, y3), (work1, work2, work3) = ys, works
work1.wait()
z1 = self.mlp1(y1)
@@ -1126,7 +1126,7 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((batch_size, dim), device="cuda")
for _ in range(10):
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
losses.append(_model(inp).sum())
losses[-1].backward()


@@ -6,7 +6,7 @@ import functools
import math
import threading
import unittest
- from typing import Any, List, Optional, Tuple, Union
+ from typing import Any, Optional, Union
import torch
import torch.distributed as dist
@@ -29,7 +29,7 @@ from torch.testing._internal.two_tensor import TwoTensor
def two_tensor_fsdp_pre_all_gather_v1(
self, mesh: DeviceMesh
- ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+ ) -> tuple[tuple[torch.Tensor, ...], Any]:
all_gather_inputs = (self.a, self.b)
metadata = None
return all_gather_inputs, metadata
@@ -39,10 +39,10 @@ def two_tensor_fsdp_pre_all_gather_v2(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
- outer_stride: Tuple[int, ...],
+ outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
- ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+ ) -> tuple[tuple[torch.Tensor, ...], Any]:
all_gather_inputs = (self.a, self.b)
metadata = None
return all_gather_inputs, metadata
@@ -50,12 +50,12 @@ def two_tensor_fsdp_post_all_gather(
def two_tensor_fsdp_post_all_gather(
self,
- all_gather_outputs: Tuple[torch.Tensor, ...],
+ all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
- ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
+ ) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
assert metadata is None, f"{metadata}"
a, b = all_gather_outputs
if out is not None:
@@ -96,10 +96,10 @@ class BFloat16AllGatherTensor(torch.Tensor):
self,
mesh: DeviceMesh,
outer_size: torch.Size,
- outer_stride: Tuple[int, ...],
+ outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
- ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+ ) -> tuple[tuple[torch.Tensor, ...], Any]:
assert mesh.ndim == 1, f"{mesh.ndim}"
mesh_size = mesh.size()
requires_padding = outer_size[0] % mesh_size != 0
@@ -116,12 +116,12 @@ class BFloat16AllGatherTensor(torch.Tensor):
def fsdp_post_all_gather(
self,
- all_gather_outputs: Tuple[torch.Tensor, ...],
+ all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
- ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
+ ) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
assert metadata is None, f"{metadata}"
(tensor,) = all_gather_outputs
assert tensor.dtype == torch.bfloat16, f"{tensor.dtype}"
@@ -157,7 +157,7 @@ class BFloat16AllGatherTensor(torch.Tensor):
@staticmethod
def __tensor_unflatten__(
- inner_tensors, outer_size: torch.Size, outer_stride: Tuple[int, ...]
+ inner_tensors, outer_size: torch.Size, outer_stride: tuple[int, ...]
):
return inner_tensors["_data"]
@@ -236,7 +236,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, 8), device="cuda")
for iter_idx in range(10):
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model in (ref_model, model):
losses.append(_model(inp).sum())
losses[-1].backward()
@@ -314,10 +314,10 @@ class TestFullyShardAllGatherExtensionsMultiThread(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
- outer_stride: Tuple[int, ...],
+ outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
- ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+ ) -> tuple[tuple[torch.Tensor, ...], Any]:
nonlocal tls
tls.ran_pre_all_gather = True
return (self.to(torch.bfloat16),), None
@@ -325,12 +325,12 @@ class TestFullyShardAllGatherExtensionsMultiThread(
@torch.no_grad()
def fsdp_post_all_gather(
self,
- all_gather_outputs: Tuple[torch.Tensor, ...],
+ all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
- ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
+ ) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
(tensor,) = all_gather_outputs
assert metadata is None, f"{metadata}"
assert tensor.dtype == torch.bfloat16, f"{tensor.dtype}"
@@ -416,10 +416,10 @@ class TestFullyShardAllGatherExtensionsMultiThread(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
- outer_stride: Tuple[int, ...],
+ outer_stride: tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
- ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+ ) -> tuple[tuple[torch.Tensor, ...], Any]:
nonlocal tls
tls.mesh = mesh
return (self,), None
@@ -427,12 +427,12 @@ class TestFullyShardAllGatherExtensionsMultiThread(
@torch.no_grad()
def fsdp_post_all_gather(
self,
- all_gather_outputs: Tuple[torch.Tensor, ...],
+ all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
- ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
+ ) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
(tensor,) = all_gather_outputs
if out is not None:
return


@@ -3,7 +3,7 @@
import copy
import functools
import itertools
- from typing import List, Union
+ from typing import Union
import torch
import torch.distributed as dist
@@ -116,7 +116,7 @@ class TestFullyShardFrozen(FSDPTest):
), patch_register_post_backward_hook_backward(backward_with_count):
for iter_idx in range(10):
inp = torch.randn((8, lin_dim), device=device)
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@@ -151,7 +151,7 @@ class TestFullyShardFrozen(FSDPTest):
):
torch.manual_seed(42)
num_linears, lin_dim = (6, 32)
- modules: List[nn.Module] = []
+ modules: list[nn.Module] = []
for _ in range(num_linears):
modules += [nn.Linear(lin_dim, lin_dim), nn.ReLU()]
model = nn.Sequential(*modules)
@@ -187,7 +187,7 @@ class TestFullyShardFrozen(FSDPTest):
inp = torch.randn((8, lin_dim), device="cuda")
with patch_register_post_backward_hook_backward(backward_with_count):
for iter_idx in range(num_iters):
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
# Unfreeze the parameters on the last step to emulate some
# kinds of fine-tuning
@@ -251,7 +251,7 @@ class TestFullyShardFrozen(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for iter_idx in range(10):
inp = torch.randn((8, 5), device="cuda")
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())


@@ -3,7 +3,7 @@
import copy
import itertools
import unittest
- from typing import List, Optional
+ from typing import Optional
import torch
import torch.distributed as dist
@@ -211,8 +211,8 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
def _check_managed_modules(
self,
- managed_modules: List[nn.Module],
- expected_managed_modules: List[nn.Module],
+ managed_modules: list[nn.Module],
+ expected_managed_modules: list[nn.Module],
):
self.assertEqual(len(managed_modules), len(expected_managed_modules))
# Check set comparison since we do not require anything about the order
@@ -262,10 +262,10 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
def _check_managed_states(
self,
- managed_params: List[nn.Parameter],
- managed_buffers: List[torch.Tensor],
- expected_managed_params: List[nn.Parameter],
- expected_managed_buffers: List[torch.Tensor],
+ managed_params: list[nn.Parameter],
+ managed_buffers: list[torch.Tensor],
+ expected_managed_params: list[nn.Parameter],
+ expected_managed_buffers: list[torch.Tensor],
):
self.assertEqual(len(managed_params), len(expected_managed_params))
self.assertEqual(len(managed_buffers), len(expected_managed_buffers))
@@ -370,7 +370,7 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
self._check_1d_sharded_parameters(orig_params, sharded_params)
def _check_1d_sharded_parameters(
- self, orig_params: List[nn.Parameter], sharded_params: List[nn.Parameter]
+ self, orig_params: list[nn.Parameter], sharded_params: list[nn.Parameter]
):
self.assertEqual(len(orig_params), len(sharded_params))
global_mesh = init_device_mesh("cuda", (self.world_size,))


@@ -2,7 +2,7 @@
import copy
import functools
- from typing import Dict, List, Optional, Union
+ from typing import Optional, Union
import torch
import torch.distributed as dist
@@ -339,7 +339,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
model.set_reshard_after_backward(
is_last_microbatch or reshard_after_forward
)
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model in (ref_model_compute, model):
losses.append(
_model(microbatch_inps[microbatch_idx].detach()).sum()
@@ -391,7 +391,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
# Subtest 1: use fp16 on the second child submodule -- does not require
# any additional casting logic
- forward_inputs: Dict[str, nn.Module] = {}
+ forward_inputs: dict[str, nn.Module] = {}
model = SaveForwardInputsModel(
forward_inputs,
cast_forward_inputs=False,
@@ -405,7 +405,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
# Subtest 2: use fp16 on the second child module, where the user module
# owns the cast
- forward_inputs: Dict[nn.Module, torch.Tensor] = {}
+ forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=True
).cuda()
@@ -423,7 +423,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
# Subtest 3: use fp16 on the first child module and specify its output
# dtype so that the second child module does not need to cast
- forward_inputs: Dict[nn.Module, torch.Tensor] = {}
+ forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=False
).cuda()
@@ -448,7 +448,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
def _test_submodules_with_external_inputs(self, enable_submodule_cast: bool):
class ToyModule(nn.Module):
- def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
+ def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l = nn.Linear(100, 100)
self.forward_inputs = forward_inputs
@@ -459,7 +459,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
return self.l(x)
class ToyModel(nn.Module):
- def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
+ def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l1 = nn.Linear(100, 100)
self.l2 = ToyModule(forward_inputs)
@@ -472,7 +472,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
) # external input
return self.l2(self.l1(x), y)
- forward_inputs: Dict[str, torch.Tensor] = {}
+ forward_inputs: dict[str, torch.Tensor] = {}
model = ToyModel(forward_inputs).cuda()
x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
fully_shard(


@@ -4,7 +4,7 @@ import copy
import functools
import unittest
from contextlib import nullcontext
- from typing import Dict, Optional
+ from typing import Optional
import torch
import torch.nn as nn
@@ -308,7 +308,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
# Verify that we can load a new state dict that contains DTensors with
# storages different from the current model parameters
- new_state_dict: Dict[str, DTensor] = {}
+ new_state_dict: dict[str, DTensor] = {}
for param_name, dtensor in state_dict.items():
# Construct new DTensors to exercise load state dict writeback
new_state_dict[param_name] = dtensor.detach().clone().fill_(new_fill_value)


@@ -7,7 +7,7 @@ import itertools
import unittest
from collections import defaultdict
from collections.abc import Iterable
- from typing import Any, List, Optional, Tuple, Union
+ from typing import Any, Optional, Union
import torch
import torch.distributed as dist
@@ -65,7 +65,7 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
device = torch.device("cuda", 0)
class ParamlessModule(nn.Module):
- def forward(self, x: torch.Tensor, ys: Tuple[torch.Tensor, ...]):
+ def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]):
# Check that FSDP moved the inputs to GPU, including recursing
# into the tuple data structure
assert x.device == device, f"Expects {device} but got {x.device}"
@@ -224,7 +224,7 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, mlp_dim), device="cuda", dtype=dtype)
for iter_idx in range(10):
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model in (ref_model, model):
losses.append(_model(inp).sum())
losses[-1].backward()
@@ -281,7 +281,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
)
def _test_train_parity_single_group(
- self, lin_shapes: List[Tuple[int, int]], use_shard_placement_fn: bool
+ self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
):
torch.manual_seed(42)
model = nn.Sequential(
@@ -308,7 +308,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
inp = (torch.randn((4, lin_shapes[0][0]), device="cuda"),)
for iter_idx in range(10):
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(*inp).sum())
@@ -461,7 +461,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
with patch_all_gather_ctx, patch_reduce_scatter_ctx:
for iter_idx in range(10):
inp = torch.randint(0, vocab_size, (3, 64), device=device_type)
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@@ -554,7 +554,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
torch.manual_seed(42 + self.rank)
inp = torch.randn((32, 4), device="cuda")
for iter_idx in range(10):
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@@ -592,7 +592,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
for _ in range(10):
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad()
losses.append(_model(inp).sum())
@@ -623,8 +623,8 @@ class TestFullyShard1DTrainingCore(FSDPTest):
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
# Track all losses and check for equality at the end to avoid a CPU
# sync point after each iteration
- ref_losses: List[torch.Tensor] = []
- losses: List[torch.Tensor] = []
+ ref_losses: list[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _ in range(10):
ref_optim.zero_grad()
ref_losses.append(ref_model(inp).sum())
@@ -736,7 +736,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
)
for iter_idx in range(10):
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model in (ref_model, model):
torch.manual_seed(iter_idx + 1)  # for dropout determinism
losses.append(_model(inp).sum())
@@ -886,7 +886,7 @@ class TestFullyShardSharedParams(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(10):
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@@ -1009,7 +1009,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
is_last_microbatch = microbatch_idx == num_microbatches - 1
set_backward_flags(model, is_last_microbatch)
inp = torch.randn(batch_size, lin_dim, device="cuda")
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model in (ref_model, model):
with CommDebugMode() as comm_mode:
losses.append(_model(inp).sum())
@@ -1125,8 +1125,8 @@ class TestFullyShardGradientAccumulation(FSDPTest):
# Emulate the 1f1b pipeline schedule and only reduce gradients on the
# last microbatch
- losses: List[torch.Tensor] = []
- ref_losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
+ ref_losses: list[torch.Tensor] = []
for inp_idx, inp in enumerate(inps):
is_last_microbatch = inp_idx == num_microbatches - 1
model.set_requires_gradient_sync(is_last_microbatch)
@@ -1210,7 +1210,7 @@ class TestFullyShardNDTraining(FSDPTest):
device = torch.device("cuda")
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@@ -1281,7 +1281,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
device = torch.device("cuda")
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@@ -1360,7 +1360,7 @@ class TestFullyShardHSDPTraining(FSDPTest):
if sync_gradients_at_last_batch:
model.set_requires_gradient_sync(is_last_microbatch)
inp = torch.randn((8, mlp_dim), device=device)
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
losses.append(_model(inp).sum())
losses[-1].backward()


@@ -5,7 +5,6 @@ from collections import deque, OrderedDict
from contextlib import ContextDecorator, contextmanager, nullcontext
from copy import deepcopy
from functools import partial
- from typing import Tuple
import torch
import torch.nn as nn
@@ -70,7 +69,7 @@ class MultiOutputModel(nn.Module):
self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
z = x @ self.w1
z = nn.functional.relu(z)
z = z @ self.w2
@@ -82,7 +81,7 @@ class MultiInputModel(nn.Module):
super().__init__()
self.w = nn.Parameter(torch.randn((100, 100), device=device))
- def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+ def forward(self, xs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
x, y = xs
z = x + y


@@ -4,7 +4,7 @@ import copy
import functools
import io
from copy import deepcopy
- from typing import List, Optional, Type
+ from typing import Optional
import torch
import torch.distributed as dist
@@ -166,7 +166,7 @@ class TestFullyShard2DTraining(FSDPTest):
device = torch.device("cuda")
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
- losses: List[torch.Tensor] = []
+ losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
@@ -335,7 +335,7 @@ class TestFullyShard2DTraining(FSDPTest):
self,
use_seq_parallel: bool,
reuse_model_optim: bool,
- optimizer_class: Type[torch.optim.Optimizer],
+ optimizer_class: type[torch.optim.Optimizer],
foreach: bool,
):
def train_step(


@@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]
from copy import deepcopy
- from typing import List, Tuple
import torch
import torch.nn as nn
@@ -28,12 +27,12 @@ class TestContract(TestCase):
@skipIfTorchDynamo("Dynamo does not support the state key")
def test_add_hooks(self):
def forward_pre_hook(
- module: nn.Module, inp: Tuple[torch.Tensor]
- ) -> Tuple[torch.Tensor]:
+ module: nn.Module, inp: tuple[torch.Tensor]
+ ) -> tuple[torch.Tensor]:
return inp
def forward_hook(
- module: nn.Module, inp: Tuple[torch.Tensor], out: torch.Tensor
+ module: nn.Module, inp: tuple[torch.Tensor], out: torch.Tensor
) -> torch.Tensor:
return out
@@ -44,9 +43,9 @@ class TestContract(TestCase):
def backward_hook(
module: nn.Module,
- grad_input: Tuple[torch.Tensor],
+ grad_input: tuple[torch.Tensor],
grad_output: torch.Tensor,
- ) -> Tuple[torch.Tensor]:
+ ) -> tuple[torch.Tensor]:
return grad_input
@contract()
@@ -92,8 +91,8 @@ class TestContract(TestCase):
@skipIfTorchDynamo("Dynamo does not support the state key")
def test_state(self):
def check_and_update_state_hook(
- module: nn.Module, inp: Tuple[torch.Tensor]
- ) -> Tuple[torch.Tensor]:
+ module: nn.Module, inp: tuple[torch.Tensor]
+ ) -> tuple[torch.Tensor]:
self.assertEqual(api.state(module).dummy_state, 7)
api.state(module).dummy_state = 8
return inp
@@ -139,7 +138,7 @@ class TestContract(TestCase):
@skipIfTorchDynamo("Dynamo does not support the state key")
def test_multi_module_api(self):
@contract()
- def multi_module_api(modules: List[nn.Module]) -> nn.Module:
+ def multi_module_api(modules: list[nn.Module]) -> nn.Module:
return modules
model = nn.Sequential(*[nn.Linear(3, 3) for _ in range(5)])


@@ -6,7 +6,6 @@ import itertools
import math
import pickle
import sys
- from typing import List
import torch
import torch.distributed as dist
@@ -3186,7 +3185,7 @@ class TestCreateTensorNoProcessGroupMode(TestCase):
],
size=torch.Size([4, 2]),
)
- st_local_shards: List[Shard] = []
+ st_local_shards: list[Shard] = []
for shard_metadata in st_metadata.shards_metadata:
st_local_shards.append(
Shard(
@@ -3215,7 +3214,7 @@ class TestCreateTensorNoProcessGroupMode(TestCase):
],
size=torch.Size([4, 2]),
)
- st_local_shards: List[Shard] = []
+ st_local_shards: list[Shard] = []
src = torch.randn(4, 2)
for shard_metadata in st_metadata.shards_metadata:
offsets = shard_metadata.shard_offsets


@@ -1,7 +1,7 @@
# Owner(s): ["oncall: distributed"]
import copy
from dataclasses import dataclass
- from typing import List, Union
+ from typing import Union
import torch
from torch.distributed._shard import _shard_tensor, sharded_tensor
@@ -495,7 +495,7 @@ class TestShardingSpec(TestCase):
@dataclass
class GridShardingSpec(ShardingSpec):
grid_size: int
- placements: List[Union[torch.distributed._remote_device, str]]
+ placements: list[Union[torch.distributed._remote_device, str]]
def __post_init__(self):
for i, remote_device in enumerate(self.placements):


@@ -1,7 +1,6 @@
# Owner(s): ["module: unknown"]
import gc
import unittest
- from typing import Tuple
import torch
import torch.nn as nn
@@ -161,7 +160,7 @@ class TestMemTracker(TestCase):
def get_param_grad_optstate_actual_bytes(
model: nn.Module, opt: torch.optim.Optimizer
- ) -> Tuple[int, int, int]:
+ ) -> tuple[int, int, int]:
param_bytes = 0
grad_bytes = 0
opt_state_bytes = 0
@@ -179,7 +178,7 @@ class TestMemTracker(TestCase):
def get_param_grad_optstate_bytes_from_tracker(
tracker: MemTracker,
- ) -> Tuple[int, int, int]:
+ ) -> tuple[int, int, int]:
snapshot = tracker.get_tracker_snapshot()
param_bytes = snapshot[dev]["Parameter"]
grad_bytes = snapshot[dev]["Gradient"]


@@ -1,7 +1,7 @@
# Owner(s): ["module: unknown"]
import unittest
from dataclasses import dataclass
- from typing import Any, Callable, cast, Tuple, Union
+ from typing import Any, Callable, cast, Union
import torch
from torch import nn, optim
@@ -73,7 +73,7 @@ class TestRuntimeEstimator(TestCase):
def _measure_actual_cuda_time(
self,
func: Callable,
- args: Tuple[Any, ...],
+ args: tuple[Any, ...],
) -> float:
warmup_iters, actual_iters = 2, 5
start_event = torch.cuda.Event(enable_timing=True)
@@ -92,7 +92,7 @@ class TestRuntimeEstimator(TestCase):
self,
estimate_mode: str,
func: Callable,
- args: Tuple[Any, ...],
+ args: tuple[Any, ...],
) -> float:
# Optimizer init step
func(*args)
@@ -106,7 +106,7 @@ class TestRuntimeEstimator(TestCase):
model_type: str,
model_args: Union[ConvArgs, ModelArgs],
bsz: int,
- ) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]:
+ ) -> tuple[nn.Module, optim.Optimizer, torch.Tensor]:
dev = torch.cuda.current_device()
if model_type == "Transformer":
model_args = cast(ModelArgs, model_args)


@@ -1,7 +1,6 @@
# Owner(s): ["module: unknown"]
import copy
import unittest
- from typing import Tuple
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
@@ -40,7 +39,7 @@ class TestSACILP(TestCase):
def _init_model_input_optimizer(
self,
- ) -> Tuple[torch.nn.Module, torch.optim.Optimizer, torch.Tensor]:
+ ) -> tuple[torch.nn.Module, torch.optim.Optimizer, torch.Tensor]:
bsz = 8
model_args = ModelArgs(
n_layers=4,


@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from enum import auto, Enum from enum import auto, Enum
from functools import partial from functools import partial
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List from typing import Any
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -95,9 +95,9 @@ class ModelType(Enum):
class TestTrainState: class TestTrainState:
step: int = 0 step: int = 0
current_loss: float = -1 current_loss: float = -1
losses: List[float] = field(default_factory=list) losses: list[float] = field(default_factory=list)
def state_dict(self) -> Dict[str, Any]: def state_dict(self) -> dict[str, Any]:
loss_bytes = BytesIO() loss_bytes = BytesIO()
torch.save(self.losses, loss_bytes) torch.save(self.losses, loss_bytes)
return { return {
@ -284,7 +284,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
@with_temp_dir @with_temp_dir
def test_stateful_and_non_stateful_loads(self) -> None: def test_stateful_and_non_stateful_loads(self) -> None:
class StateDict(Dict): class StateDict(dict):
def __init__(self): def __init__(self):
self.set_sd_item_called = False self.set_sd_item_called = False
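Editor's note (illustrative sketch, not from this PR): the same replacement applies inside dataclass field annotations (losses: list[float]) and even to base classes — the test-local StateDict now subclasses the built-in dict instead of typing.Dict. Hypothetical names, assuming Python 3.9+:

from dataclasses import dataclass, field
from typing import Any

@dataclass
class TrainState:
    step: int = 0
    losses: list[float] = field(default_factory=list)  # built-in generic in a field annotation

    def state_dict(self) -> dict[str, Any]:
        return {"step": self.step, "losses": list(self.losses)}

class TrackingDict(dict):  # subclass the built-in dict directly
    def __setitem__(self, key, value):
        super().__setitem__(key, value)

state = TrainState()
state.losses.append(0.5)
d = TrackingDict()
d["losses"] = state.state_dict()["losses"]
print(d)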


@ -2,7 +2,7 @@
import os import os
import sys import sys
from typing import cast, List, Optional, Union from typing import cast, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -177,17 +177,17 @@ class FaultyStorageWriter(TestStorageBase, StorageWriter):
self._fail_rank("fail_prepare_local_plan") self._fail_rank("fail_prepare_local_plan")
return plan return plan
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
self._fail_rank("fail_prepare_global_plan") self._fail_rank("fail_prepare_global_plan")
return plans return plans
def write_data( def write_data(
self, plan: SavePlan, planner: SavePlanner self, plan: SavePlan, planner: SavePlanner
) -> Future[List[WriteResult]]: ) -> Future[list[WriteResult]]:
self._fail_rank("fail_write_data") self._fail_rank("fail_write_data")
return self._fail_rank_async("fail_write_data_async", []) return self._fail_rank_async("fail_write_data_async", [])
def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
self._fail_rank("fail_finish") self._fail_rank("fail_finish")
@classmethod @classmethod
@ -210,7 +210,7 @@ class FaultyStorageReader(TestStorageBase, StorageReader):
self._fail_rank("fail_prepare_local_plan") self._fail_rank("fail_prepare_local_plan")
return plan return plan
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]: def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
self._fail_rank("fail_prepare_global_plan") self._fail_rank("fail_prepare_global_plan")
return plans return plans


@ -1,5 +1,5 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
from typing import Dict, Union from typing import Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -58,14 +58,14 @@ class MyTestModule(torch.nn.Module):
def extra_state_tensor(self, new_extra_state_tensor: torch.Tensor) -> None: def extra_state_tensor(self, new_extra_state_tensor: torch.Tensor) -> None:
self._extra_state_tensor = new_extra_state_tensor self._extra_state_tensor = new_extra_state_tensor
def get_extra_state(self) -> Dict[str, Union[int, torch._tensor.Tensor]]: def get_extra_state(self) -> dict[str, Union[int, torch._tensor.Tensor]]:
return { return {
"extra_state": self._extra_state, "extra_state": self._extra_state,
"extra_state_tensor": self._extra_state_tensor, "extra_state_tensor": self._extra_state_tensor,
} }
def set_extra_state( def set_extra_state(
self, state: Dict[str, Union[int, torch._tensor.Tensor]] self, state: dict[str, Union[int, torch._tensor.Tensor]]
) -> None: ) -> None:
self._extra_state = state["extra_state"] # pyre-ignore[8] self._extra_state = state["extra_state"] # pyre-ignore[8]
self._extra_state_tensor = state["extra_state_tensor"] # pyre-ignore[8] self._extra_state_tensor = state["extra_state_tensor"] # pyre-ignore[8]


@ -4,7 +4,6 @@ import os
import shutil import shutil
import sys import sys
import tempfile import tempfile
from typing import Dict
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -53,8 +52,8 @@ if TEST_WITH_DEV_DBG_ASAN:
def assert_state_dict_equal( def assert_state_dict_equal(
self: TestCase, self: TestCase,
state_dict_1: Dict[str, torch.Tensor], state_dict_1: dict[str, torch.Tensor],
state_dict_2: Dict[str, torch.Tensor], state_dict_2: dict[str, torch.Tensor],
) -> bool: ) -> bool:
self.assertEqual( self.assertEqual(
len(state_dict_1), len(state_dict_2), "state_dict must be the same size" len(state_dict_1), len(state_dict_2), "state_dict must be the same size"


@ -2,7 +2,7 @@
import sys import sys
import tempfile import tempfile
from typing import Any, Dict, IO from typing import Any, IO
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -56,8 +56,8 @@ _THREAD_COUNTS = {1, 2}
def assert_state_dict_equal( def assert_state_dict_equal(
self: TestCase, self: TestCase,
state_dict_1: Dict[str, torch.Tensor], state_dict_1: dict[str, torch.Tensor],
state_dict_2: Dict[str, torch.Tensor], state_dict_2: dict[str, torch.Tensor],
) -> bool: ) -> bool:
self.assertEqual( self.assertEqual(
len(state_dict_1), len(state_dict_2), "state_dict must be the same size" len(state_dict_1), len(state_dict_2), "state_dict must be the same size"
@ -113,10 +113,10 @@ class BlobState:
def __init__(self, value: IO[bytes]) -> Any: def __init__(self, value: IO[bytes]) -> Any:
self.state = {"blob": value} self.state = {"blob": value}
def state_dict(self) -> Dict[str, Any]: def state_dict(self) -> dict[str, Any]:
return self.state return self.state
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.state = state_dict self.state = state_dict
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:


@ -3,7 +3,7 @@
import shutil import shutil
import tempfile import tempfile
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple from typing import Any, Callable, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -35,7 +35,7 @@ def with_temp_dir(
assert func is not None assert func is not None
@wraps(func) @wraps(func)
def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None: def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
# Only create temp_dir when rank is 0 (or no pg) # Only create temp_dir when rank is 0 (or no pg)
if not dist.is_initialized() or dist.get_rank() == 0: if not dist.is_initialized() or dist.get_rank() == 0:
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()


@ -4,7 +4,7 @@ import copy
import functools import functools
import sys import sys
from itertools import chain from itertools import chain
from typing import Callable, Tuple, Type, Union from typing import Callable, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -154,9 +154,9 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
*, *,
use_orig_params: bool, use_orig_params: bool,
use_dtensor: bool, use_dtensor: bool,
wrapping: Tuple[nn.Module] = (), wrapping: tuple[nn.Module] = (),
compile_model: bool = False, compile_model: bool = False,
optimizer_class: Type[Optimizer], optimizer_class: type[Optimizer],
) -> None: ) -> None:
if not use_orig_params: if not use_orig_params:
return return
@ -232,7 +232,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self, self,
*, *,
reshard_after_forward: Union[bool, int], reshard_after_forward: Union[bool, int],
optimizer_class: Type[Optimizer], optimizer_class: type[Optimizer],
compile_model: bool, compile_model: bool,
foreach: bool = True, foreach: bool = True,
): ):
@ -272,7 +272,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self._test_fsdp2, self._test_fsdp2,
) )
def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None: def _test_ddp(self, use_composable: bool, optimizer_class: type[Optimizer]) -> None:
def init_model_optim(): def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda")) orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
@ -303,7 +303,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
def _test_fsdp_ddp( def _test_fsdp_ddp(
self, self,
optimizer_class: Type[Optimizer], optimizer_class: type[Optimizer],
optim_in_backward: bool = False, optim_in_backward: bool = False,
test_frozen: bool = False, test_frozen: bool = False,
) -> None: ) -> None:
@ -347,7 +347,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self._test_fsdp_ddp, self._test_fsdp_ddp,
) )
def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None: def _test_single_gpu(self, optimizer_class: type[Optimizer]) -> None:
def init_model_optim(): def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda")) orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4)
@ -399,7 +399,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
) )
def _test_cpu_offload_full_state_dict( def _test_cpu_offload_full_state_dict(
self, optimizer_class: Type[Optimizer] self, optimizer_class: type[Optimizer]
) -> None: ) -> None:
orig_model = CompositeParamModel(device=torch.device("cuda")) orig_model = CompositeParamModel(device=torch.device("cuda"))
device_mesh = init_device_mesh("cuda", (self.world_size,)) device_mesh = init_device_mesh("cuda", (self.world_size,))
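Editor's note (illustrative sketch, not from this PR): typing.Type follows the same rule, so the optimizer-class parameters above become type[Optimizer]. A hypothetical helper, assuming Python 3.9+ and an installed torch:

import torch
from torch.optim import Optimizer

def make_optim(optim_class: type[Optimizer], params, lr: float = 1e-4) -> Optimizer:
    # type[Optimizer] means "Optimizer or any subclass, passed as a class object"
    return optim_class(params, lr=lr)

model = torch.nn.Linear(4, 4)
opt = make_optim(torch.optim.Adam, model.parameters(), lr=1e-3)
print(type(opt).__name__)  # Adam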


@ -14,7 +14,7 @@ import signal
import unittest import unittest
import uuid import uuid
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List from typing import Any
from unittest.mock import call, patch from unittest.mock import call, patch
import torch.distributed as dist import torch.distributed as dist
@ -135,7 +135,7 @@ class TestAgent(SimpleElasticAgent):
worker_group.group_world_size = None worker_group.group_world_size = None
self.stop_workers_call_count += 1 self.stop_workers_call_count += 1
def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:
# crate fake workers; make worker id equal to global rank # crate fake workers; make worker id equal to global rank
ids = {} ids = {}
for worker in worker_group.workers: for worker in worker_group.workers:
@ -477,7 +477,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
self.assertEqual(1, mock_monitor_workers.call_count) self.assertEqual(1, mock_monitor_workers.call_count)
self.assertEqual(spec.max_restarts, agent._remaining_restarts) self.assertEqual(spec.max_restarts, agent._remaining_restarts)
def get_worker_assigned(self, store, role_infos_len, info) -> List[Worker]: def get_worker_assigned(self, store, role_infos_len, info) -> list[Worker]:
i, role_info = info i, role_info = info
spec = self._get_worker_spec( spec = self._get_worker_spec(
max_restarts=3, max_restarts=3,


@ -17,7 +17,7 @@ import time
import unittest import unittest
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Optional
from unittest import mock from unittest import mock
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -256,7 +256,7 @@ class Conf:
entrypoint: Callable entrypoint: Callable
local_world_size: int local_world_size: int
args: Tuple = () args: tuple = ()
role: str = "default" role: str = "default"
redirects: Std = Std.NONE redirects: Std = Std.NONE
tee: Std = Std.NONE tee: Std = Std.NONE
@ -394,10 +394,10 @@ class LocalElasticAgentTest(unittest.TestCase):
def run_job( def run_job(
self, self,
node_configs: List[Conf], node_configs: list[Conf],
exit_barrier_timeout: int = 5, exit_barrier_timeout: int = 5,
log_line_prefix_template: Optional[str] = None, log_line_prefix_template: Optional[str] = None,
) -> Dict[str, List[RunResult]]: ) -> dict[str, list[RunResult]]:
""" """
Simulates running a distributed job by running multiple agents Simulates running a distributed job by running multiple agents
(one on each process). Agent 0 is run on the main process for (one on each process). Agent 0 is run on the main process for
@ -431,7 +431,7 @@ class LocalElasticAgentTest(unittest.TestCase):
for p in procs: for p in procs:
p.join() p.join()
results: Dict[str, List[RunResult]] = {} results: dict[str, list[RunResult]] = {}
while not agent_results.empty(): while not agent_results.empty():
role, run_result = agent_results.get() role, run_result = agent_results.get()
results.setdefault(role, []).append(run_result) results.setdefault(role, []).append(run_result)
@ -1032,8 +1032,8 @@ class LocalElasticAgentTest(unittest.TestCase):
def assert_rank_consistency( def assert_rank_consistency(
self, self,
run_results: Dict[str, List[RunResult]], run_results: dict[str, list[RunResult]],
expected_role_world_sizes: Dict[str, int], expected_role_world_sizes: dict[str, int],
): ):
""" """
Asserts that ranks are consecutive w.r.t role_rank. If local world sizes are 4: Asserts that ranks are consecutive w.r.t role_rank. If local world sizes are 4:
@ -1042,11 +1042,11 @@ class LocalElasticAgentTest(unittest.TestCase):
... etc ... ... etc ...
""" """
global_ranks: List[int] = [] global_ranks: list[int] = []
# role -> [role_rank,...] # role -> [role_rank,...]
role_ranks: Dict[str, List[int]] = {} role_ranks: dict[str, list[int]] = {}
# group rank -> [(rank, role_rank),...] # group rank -> [(rank, role_rank),...]
grouped_ranks: Dict[int, List[Tuple[int, int]]] = {} grouped_ranks: dict[int, list[tuple[int, int]]] = {}
# global world size == sum of all the role world sizes # global world size == sum of all the role world sizes
expected_world_size = sum(expected_role_world_sizes.values()) expected_world_size = sum(expected_role_world_sizes.values())
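Editor's note (illustrative sketch, not from this PR): nested containers and bare annotations convert the same way — Dict[int, List[Tuple[int, int]]] becomes dict[int, list[tuple[int, int]]], and an unparameterized Tuple is simply tuple. Hypothetical data, Python 3.9+:

# role -> [(global_rank, role_rank), ...] as nested built-in generics
grouped_ranks: dict[int, list[tuple[int, int]]] = {}
grouped_ranks.setdefault(0, []).append((0, 0))
grouped_ranks.setdefault(1, []).append((1, 0))

args: tuple = ()  # a bare Tuple annotation simply becomes tuple

print(grouped_ranks, args)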


@ -16,7 +16,7 @@ import sys
import tempfile import tempfile
import time import time
from itertools import product from itertools import product
from typing import Callable, Dict, List, Union from typing import Callable, Union
from unittest import mock from unittest import mock
import torch import torch
@ -141,7 +141,7 @@ def echo2(msg: str, fail: bool = False) -> str:
return msg return msg
def echo_large(size: int) -> Dict[int, str]: def echo_large(size: int) -> dict[int, str]:
""" """
returns a large output ({0: test0", 1: "test1", ..., (size-1):f"test{size-1}"}) returns a large output ({0: test0", 1: "test1", ..., (size-1):f"test{size-1}"})
""" """
@ -167,13 +167,13 @@ def dummy_compute() -> torch.Tensor:
return torch.rand(100, 100) return torch.rand(100, 100)
def redirects_oss_test() -> List[Std]: def redirects_oss_test() -> list[Std]:
return [ return [
Std.NONE, Std.NONE,
] ]
def redirects_all() -> List[Std]: def redirects_all() -> list[Std]:
return [ return [
Std.NONE, Std.NONE,
Std.OUT, Std.OUT,
@ -240,14 +240,14 @@ class _StartProcessesTest(TestCase):
def log_dir(self): def log_dir(self):
return tempfile.mkdtemp(dir=self.test_dir) return tempfile.mkdtemp(dir=self.test_dir)
def assert_in_file(self, expected: List[str], filename: str) -> None: def assert_in_file(self, expected: list[str], filename: str) -> None:
expected = [f"{line.rstrip()}\n" for line in expected] expected = [f"{line.rstrip()}\n" for line in expected]
with open(filename) as fp: with open(filename) as fp:
actual = fp.readlines() actual = fp.readlines()
for line in expected: for line in expected:
self.assertIn(line, actual) self.assertIn(line, actual)
def assert_pids_noexist(self, pids: Dict[int, int]): def assert_pids_noexist(self, pids: dict[int, int]):
for local_rank, pid in pids.items(): for local_rank, pid in pids.items():
with self.assertRaises( with self.assertRaises(
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist" OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"


@ -16,7 +16,6 @@ import unittest
from concurrent.futures import wait from concurrent.futures import wait
from concurrent.futures._base import ALL_COMPLETED from concurrent.futures._base import ALL_COMPLETED
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from typing import Dict, Set
from unittest import mock from unittest import mock
from torch.distributed.elastic.multiprocessing.tail_log import TailLog from torch.distributed.elastic.multiprocessing.tail_log import TailLog
@ -72,7 +71,7 @@ class TailLogTest(unittest.TestCase):
tail.stop() tail.stop()
dst.seek(0) dst.seek(0)
actual: Dict[int, Set[int]] = {} actual: dict[int, set[int]] = {}
for line in dst.readlines(): for line in dst.readlines():
header, num = line.split(":") header, num = line.split(":")
@ -123,7 +122,7 @@ class TailLogTest(unittest.TestCase):
tail.stop() tail.stop()
dst.seek(0) dst.seek(0)
headers: Set[str] = set() headers: set[str] = set()
for line in dst.readlines(): for line in dst.readlines():
header, _ = line.split(":") header, _ = line.split(":")
headers.add(header) headers.add(header)


@ -6,7 +6,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Any, cast, Dict, SupportsInt from typing import Any, cast, SupportsInt
from unittest import TestCase from unittest import TestCase
from torch.distributed.elastic.rendezvous import ( from torch.distributed.elastic.rendezvous import (
@ -24,7 +24,7 @@ class RendezvousParametersTest(TestCase):
self._run_id = "dummy_run_id" self._run_id = "dummy_run_id"
self._min_nodes = 3 self._min_nodes = 3
self._max_nodes = 6 self._max_nodes = 6
self._kwargs: Dict[str, Any] = {} self._kwargs: dict[str, Any] = {}
def _create_params(self) -> RendezvousParameters: def _create_params(self) -> RendezvousParameters:
return RendezvousParameters( return RendezvousParameters(


@ -15,7 +15,7 @@ import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from base64 import b64encode from base64 import b64encode
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Callable, cast, Optional, Tuple from typing import Callable, cast, Optional
from unittest import TestCase from unittest import TestCase
from unittest.mock import call, MagicMock, Mock, patch, PropertyMock from unittest.mock import call, MagicMock, Mock, patch, PropertyMock
@ -186,7 +186,7 @@ class FakeRendezvousBackend(RendezvousBackend):
def name(self) -> str: def name(self) -> str:
return "fake_backend" return "fake_backend"
def get_state(self) -> Optional[Tuple[bytes, Token]]: def get_state(self) -> Optional[tuple[bytes, Token]]:
if self._token == 0: if self._token == 0:
return None return None
@ -194,7 +194,7 @@ class FakeRendezvousBackend(RendezvousBackend):
def set_state( def set_state(
self, state: bytes, token: Optional[Token] = None self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]: ) -> Optional[tuple[bytes, Token, bool]]:
if token is None: if token is None:
token = 0 token = 0


@ -7,7 +7,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, cast, Optional, Tuple from typing import Any, Callable, cast, Optional
from torch.distributed.elastic.rendezvous import RendezvousStateError from torch.distributed.elastic.rendezvous import RendezvousStateError
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import ( from torch.distributed.elastic.rendezvous.dynamic_rendezvous import (
@ -32,12 +32,12 @@ class RendezvousBackendTestMixin(ABC):
def _set_state( def _set_state(
self, state: bytes, token: Optional[Any] = None self, state: bytes, token: Optional[Any] = None
) -> Tuple[bytes, Token, bool]: ) -> tuple[bytes, Token, bool]:
result = self._backend.set_state(state, token) result = self._backend.set_state(state, token)
self.assertIsNotNone(result) self.assertIsNotNone(result)
return cast(Tuple[bytes, Token, bool], result) return cast(tuple[bytes, Token, bool], result)
def test_get_state_returns_backend_state(self) -> None: def test_get_state_returns_backend_state(self) -> None:
self._backend.set_state(b"x") self._backend.set_state(b"x")
@ -46,7 +46,7 @@ class RendezvousBackendTestMixin(ABC):
self.assertIsNotNone(result) self.assertIsNotNone(result)
state, token = cast(Tuple[bytes, Token], result) state, token = cast(tuple[bytes, Token], result)
self.assertEqual(b"x", state) self.assertEqual(b"x", state)
self.assertIsNotNone(token) self.assertIsNotNone(token)
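Editor's note (illustrative sketch, not from this PR): unlike annotations, the first argument of typing.cast above is evaluated at runtime, so cast(tuple[...], result) relies on built-in types being subscriptable, which is exactly what Python 3.9 added. Hypothetical values:

from typing import cast

raw: object = (b"state", 7, True)
# cast() performs no runtime conversion; it only informs the type checker.
# tuple[bytes, int, bool] builds a types.GenericAlias at runtime on 3.9+.
state, token, created = cast(tuple[bytes, int, bool], raw)
print(state, token, created)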


@ -10,7 +10,6 @@ import socket
import threading import threading
import time import time
from datetime import timedelta from datetime import timedelta
from typing import List
from unittest import TestCase from unittest import TestCase
from unittest.mock import patch from unittest.mock import patch
@ -350,7 +349,7 @@ class PeriodicTimerTest(TestCase):
call_interval = 0.2 call_interval = 0.2
# Keep the log of intervals between each consecutive call. # Keep the log of intervals between each consecutive call.
actual_call_intervals: List[float] = [] actual_call_intervals: list[float] = []
# Keep the number of times the function was called. # Keep the number of times the function was called.
call_count = 0 call_count = 0


@ -7,7 +7,6 @@ import pickle
import socket import socket
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict
from urllib3.connection import HTTPConnection from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool from urllib3.connectionpool import HTTPConnectionPool
@ -181,7 +180,7 @@ class WorkerServerTest(TestCase):
def body(self) -> bytes: def body(self) -> bytes:
return b"dummy" return b"dummy"
def params(self) -> Dict[str, str]: def params(self) -> dict[str, str]:
return {} return {}
class Response(_Response): class Response(_Response):


@ -9,7 +9,6 @@
import datetime import datetime
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from typing import List
from unittest import mock from unittest import mock
import torch.distributed as dist import torch.distributed as dist
@ -40,7 +39,7 @@ class MockStore:
self.ops.append(("get", key)) self.ops.append(("get", key))
return "value" return "value"
def multi_get(self, keys: List[str]) -> List[str]: def multi_get(self, keys: list[str]) -> list[str]:
self.ops.append(("multi_get", keys)) self.ops.append(("multi_get", keys))
return ["value"] * len(keys) return ["value"] * len(keys)
@ -48,7 +47,7 @@ class MockStore:
self.ops.append(("add", key, val)) self.ops.append(("add", key, val))
return 3 return 3
def wait(self, keys: List[str]) -> None: def wait(self, keys: list[str]) -> None:
self.ops.append(("wait", keys)) self.ops.append(("wait", keys))
@ -157,7 +156,7 @@ class StoreUtilTest(TestCase):
return "" return ""
with ThreadPool(N - 1) as pool: with ThreadPool(N - 1) as pool:
outputs: List[str] = pool.map(run_barrier_for_rank, range(N - 1)) outputs: list[str] = pool.map(run_barrier_for_rank, range(N - 1))
self.assertTrue(any("missing_ranks=[Rank 2 host]" in msg for msg in outputs)) self.assertTrue(any("missing_ranks=[Rank 2 host]" in msg for msg in outputs))


@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
import sys import sys
from typing import List
from unittest.mock import patch from unittest.mock import patch
import torch import torch
@ -102,7 +101,7 @@ class TestBackwardPrefetch(FSDPTest):
tgt = torch.randn((20, 1, 1024), device=device_type) tgt = torch.randn((20, 1, 1024), device=device_type)
# monkey patch # monkey patch
all_handle_fqns: List[List[str]] = [] all_handle_fqns: list[list[str]] = []
def patched_get_handle_to_prefetch(*args, **kwargs): def patched_get_handle_to_prefetch(*args, **kwargs):
handle = orig_get_handle_to_prefetch(*args, **kwargs) handle = orig_get_handle_to_prefetch(*args, **kwargs)


@ -2,7 +2,7 @@
import sys import sys
from contextlib import nullcontext from contextlib import nullcontext
from enum import auto, Enum from enum import auto, Enum
from typing import List, Optional from typing import Optional
from unittest.mock import patch from unittest.mock import patch
import torch import torch
@ -319,7 +319,7 @@ class TestExplicitUnshard(FSDPTest):
self.mlp2 = MLP(dim) self.mlp2 = MLP(dim)
self.mlp3 = MLP(dim) self.mlp3 = MLP(dim)
def forward(self, ys: List[torch.Tensor], works: List[dist.Work]): def forward(self, ys: list[torch.Tensor], works: list[dist.Work]):
(y1, y2, y3), (work1, work2, work3) = ys, works (y1, y2, y3), (work1, work2, work3) = ys, works
work1.wait() work1.wait()
z1 = self.mlp1(y1) z1 = self.mlp1(y1)
@ -372,7 +372,7 @@ class TestExplicitUnshard(FSDPTest):
torch.manual_seed(42 + self.rank + 1) torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((batch_size, dim), device=device_type) inp = torch.randn((batch_size, dim), device=device_type)
for _ in range(10): for _ in range(10):
losses: List[torch.Tensor] = [] losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)): for _model, _optim in ((ref_model, ref_optim), (model, optim)):
losses.append(_model(inp).sum()) losses.append(_model(inp).sum())
losses[-1].backward() losses[-1].backward()


@ -4,7 +4,7 @@ import functools
import itertools import itertools
import sys import sys
import unittest import unittest
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Optional
from unittest import mock from unittest import mock
import torch import torch
@ -76,7 +76,7 @@ class TestParityWithDDP(FSDPTest):
PyTorch DDP vs. FullyShardedDataParallel. PyTorch DDP vs. FullyShardedDataParallel.
""" """
def _get_device_init_modes(self, cpu_offload: CPUOffload) -> List[DEVICEInitMode]: def _get_device_init_modes(self, cpu_offload: CPUOffload) -> list[DEVICEInitMode]:
modes = [ modes = [
DEVICEInitMode.DEVICE_AFTER, DEVICEInitMode.DEVICE_AFTER,
DEVICEInitMode.DEVICE_BEFORE, DEVICEInitMode.DEVICE_BEFORE,
@ -89,7 +89,7 @@ class TestParityWithDDP(FSDPTest):
modes.append(DEVICEInitMode.DEVICE_NEVER) modes.append(DEVICEInitMode.DEVICE_NEVER)
return modes return modes
def _get_subtest_config(self, cpu_offload: CPUOffload) -> Dict[str, List[Any]]: def _get_subtest_config(self, cpu_offload: CPUOffload) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests CUDA initialization """Returns a subtest configuration that subtests CUDA initialization
modes and prefetching settings together.""" modes and prefetching settings together."""
return { return {


@ -4,7 +4,7 @@ import contextlib
import itertools import itertools
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Optional
import torch import torch
from torch import distributed as dist from torch import distributed as dist
@ -71,7 +71,7 @@ class _GradAccConfigs:
sole purpose of overriding :meth:`__repr__` to remove spaces. sole purpose of overriding :meth:`__repr__` to remove spaces.
""" """
configs: List[_GradAccConfig] configs: list[_GradAccConfig]
def __repr__(self) -> str: def __repr__(self) -> str:
# Override to remove any spaces in the string to appease the internal # Override to remove any spaces in the string to appease the internal
@ -90,7 +90,7 @@ class TestGradAcc(FSDPTest):
def _test_grad_acc( def _test_grad_acc(
self, self,
batch_dim: int, batch_dim: int,
configs: List[_GradAccConfig], configs: list[_GradAccConfig],
cpu_offload: CPUOffload, cpu_offload: CPUOffload,
backward_prefetch: Optional[BackwardPrefetch], backward_prefetch: Optional[BackwardPrefetch],
sharding_strategy: ShardingStrategy, sharding_strategy: ShardingStrategy,
@ -146,8 +146,8 @@ class TestGradAcc(FSDPTest):
def permute_tensor(x: torch.Tensor): def permute_tensor(x: torch.Tensor):
return x.view(-1)[torch.randperm(x.numel())].view_as(x) return x.view(-1)[torch.randperm(x.numel())].view_as(x)
batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device) batch: tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
batches: List[Tuple[torch.Tensor, ...]] = [batch] batches: list[tuple[torch.Tensor, ...]] = [batch]
num_iters_to_acc = sum(config.num_iters for config in configs) num_iters_to_acc = sum(config.num_iters for config in configs)
for _ in range(num_iters_to_acc - 1): for _ in range(num_iters_to_acc - 1):
batches.append(tuple(permute_tensor(t) for t in batch)) batches.append(tuple(permute_tensor(t) for t in batch))
@ -158,7 +158,7 @@ class TestGradAcc(FSDPTest):
), "Check the test to make sure that batches are distinct" ), "Check the test to make sure that batches are distinct"
# Concatenate the batches along the given batch dimension # Concatenate the batches along the given batch dimension
concat_batch: Tuple[torch.Tensor, ...] = tuple( concat_batch: tuple[torch.Tensor, ...] = tuple(
torch.cat(ts, dim=batch_dim) for ts in zip(*batches) torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
) )
@ -214,7 +214,7 @@ class TestGradAcc(FSDPTest):
# Check that the optimizer step does not error # Check that the optimizer step does not error
optim.step() optim.step()
def _get_subtest_config(self) -> Dict[str, List[Any]]: def _get_subtest_config(self) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests prefetching.""" """Returns a subtest configuration that subtests prefetching."""
return { return {
"backward_prefetch": [ "backward_prefetch": [


@ -5,7 +5,7 @@ import sys
from collections import Counter from collections import Counter
from enum import auto, Enum from enum import auto, Enum
from functools import partial from functools import partial
from typing import List, Optional, Tuple from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -363,7 +363,7 @@ class TestFSDPHybridShard(FSDPTest):
torch.manual_seed(global_pg.rank() + 1) torch.manual_seed(global_pg.rank() + 1)
for _ in range(5): for _ in range(5):
inp = fsdp_model.module.get_input(torch.device("cuda")) inp = fsdp_model.module.get_input(torch.device("cuda"))
losses: List[torch.Tensor] = [] losses: list[torch.Tensor] = []
for model, optim in ((fsdp_model, fsdp_optim), (hsdp_model, hsdp_optim)): for model, optim in ((fsdp_model, fsdp_optim), (hsdp_model, hsdp_optim)):
optim.zero_grad() optim.zero_grad()
loss = model(*inp).sum() loss = model(*inp).sum()
@ -396,7 +396,7 @@ class TestFSDPHybridShard(FSDPTest):
sharding_strategy_mode: str, sharding_strategy_mode: str,
use_orig_params: bool, use_orig_params: bool,
hsdp_process_groups: Optional[ hsdp_process_groups: Optional[
Tuple[dist.ProcessGroup, dist.ProcessGroup] tuple[dist.ProcessGroup, dist.ProcessGroup]
] = None, ] = None,
hsdp_device_mesh: Optional = None, hsdp_device_mesh: Optional = None,
): ):


@ -8,7 +8,7 @@ from collections import namedtuple
from contextlib import nullcontext from contextlib import nullcontext
from copy import deepcopy from copy import deepcopy
from itertools import chain from itertools import chain
from typing import Any, Tuple from typing import Any
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -945,7 +945,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
self._test_homogeneous_attributes, self._test_homogeneous_attributes,
) )
def _test_homogeneous_attributes(self, attr_name_and_values: Tuple[str, Any, Any]): def _test_homogeneous_attributes(self, attr_name_and_values: tuple[str, Any, Any]):
model = NestedWrappedModule.init( model = NestedWrappedModule.init(
self.process_group, self.process_group,
FSDPInitMode.NO_FSDP, FSDPInitMode.NO_FSDP,


@ -6,7 +6,7 @@ import os
import sys import sys
from functools import partial from functools import partial
from itertools import product from itertools import product
from typing import Any, Dict, List from typing import Any
import torch import torch
import torch.cuda.nccl as nccl import torch.cuda.nccl as nccl
@ -521,7 +521,7 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
def world_size(self): def world_size(self):
return 2 return 2
def _get_subtest_config(self) -> Dict[str, List[Any]]: def _get_subtest_config(self) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests prefetching settings """Returns a subtest configuration that subtests prefetching settings
together.""" together."""
return { return {
@ -1136,7 +1136,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule(self): def test_float16_on_one_submodule(self):
forward_inputs: Dict[str, nn.Module] = {} forward_inputs: dict[str, nn.Module] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True) float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
model = SaveForwardInputsModel( model = SaveForwardInputsModel(
@ -1158,7 +1158,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule_skip_inputs(self): def test_float16_on_one_submodule_skip_inputs(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {} forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False) float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)
model = SaveForwardInputsModel( model = SaveForwardInputsModel(
@ -1179,7 +1179,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule_skip_inputs_error(self): def test_float16_on_one_submodule_skip_inputs_error(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {} forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False) float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)
model = SaveForwardInputsModel( model = SaveForwardInputsModel(
@ -1198,7 +1198,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
def test_submodules_with_different_precisions_error(self): def test_submodules_with_different_precisions_error(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {} forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True) float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True) float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)
@ -1222,7 +1222,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
def test_submodules_with_different_precisions(self): def test_submodules_with_different_precisions(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {} forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True) float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True) float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)
@ -1244,7 +1244,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
def test_submodules_with_external_inputs(self): def test_submodules_with_external_inputs(self):
class ToyModule(nn.Module): class ToyModule(nn.Module):
def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None: def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__() super().__init__()
self.l = nn.Linear(100, 100) self.l = nn.Linear(100, 100)
self.forward_inputs = forward_inputs self.forward_inputs = forward_inputs
@ -1255,7 +1255,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
return self.l(x) return self.l(x)
class ToyModel(nn.Module): class ToyModel(nn.Module):
def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None: def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__() super().__init__()
self.l1 = nn.Linear(100, 100) self.l1 = nn.Linear(100, 100)
self.l2 = ToyModule(forward_inputs) self.l2 = ToyModule(forward_inputs)
@ -1266,7 +1266,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
y = torch.ones(2, 100, device="cuda", dtype=torch.float32) y = torch.ones(2, 100, device="cuda", dtype=torch.float32)
return self.l2(self.l1(x), y) return self.l2(self.l1(x), y)
forward_inputs: Dict[str, torch.Tensor] = {} forward_inputs: dict[str, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16) float16 = MixedPrecision(param_dtype=torch.float16)
model = ToyModel(forward_inputs).cuda() model = ToyModel(forward_inputs).cuda()
@ -1343,7 +1343,7 @@ class TestFSDPTrainEval(FSDPTest):
torch.manual_seed(1 + self.rank) torch.manual_seed(1 + self.rank)
eval_src = torch.randn((8, 1, 512), device=device) eval_src = torch.randn((8, 1, 512), device=device)
eval_tgt = torch.randn((16, 1, 512), device=device) eval_tgt = torch.randn((16, 1, 512), device=device)
eval_out_sums: List[torch.Tensor] = [] eval_out_sums: list[torch.Tensor] = []
# An iteration consists of training forward/backward/optimizer, # An iteration consists of training forward/backward/optimizer,
# updating the EMA copy with the main copy, and eval forward # updating the EMA copy with the main copy, and eval forward
for _ in range(3): for _ in range(3):


@ -4,7 +4,7 @@ import bisect
import sys import sys
from copy import deepcopy from copy import deepcopy
from enum import auto, Enum from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type from typing import Any, Callable, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -177,7 +177,7 @@ class NestedModel(torch.nn.Module):
model: torch.nn.Module, model: torch.nn.Module,
group: Optional[dist.ProcessGroup] = None, group: Optional[dist.ProcessGroup] = None,
ignore_modules: bool = False, ignore_modules: bool = False,
fsdp_kwargs: Optional[Dict[str, Any]] = None, fsdp_kwargs: Optional[dict[str, Any]] = None,
) -> torch.nn.Module: ) -> torch.nn.Module:
if fsdp_kwargs is None: if fsdp_kwargs is None:
fsdp_kwargs = {} fsdp_kwargs = {}
@ -214,7 +214,7 @@ class NestedModel(torch.nn.Module):
def wrap_alt( def wrap_alt(
model: torch.nn.Module, model: torch.nn.Module,
group: Optional[dist.ProcessGroup] = None, group: Optional[dist.ProcessGroup] = None,
fsdp_kwargs: Optional[Dict[str, Any]] = None, fsdp_kwargs: Optional[dict[str, Any]] = None,
) -> torch.nn.Module: ) -> torch.nn.Module:
if fsdp_kwargs is None: if fsdp_kwargs is None:
fsdp_kwargs = {} fsdp_kwargs = {}
@ -231,7 +231,7 @@ class NestedModel(torch.nn.Module):
model, model,
add_to_fsdp_module: bool, add_to_fsdp_module: bool,
group=None, group=None,
) -> Tuple[torch.nn.Module, List[torch.nn.Parameter]]: ) -> tuple[torch.nn.Module, list[torch.nn.Parameter]]:
"""Registers unmanaged parameters before wrapping with :meth:`wrap`.""" """Registers unmanaged parameters before wrapping with :meth:`wrap`."""
device = next(model.parameters()).device device = next(model.parameters()).device
unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device)) unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device))
@ -277,12 +277,12 @@ class NestedModel(torch.nn.Module):
# NOTE: We exclude `self.bias` from either parameter group to test the # NOTE: We exclude `self.bias` from either parameter group to test the
# case where the optimizer input does not include all model parameters # case where the optimizer input does not include all model parameters
def param_group0(self) -> List[torch.nn.Parameter]: def param_group0(self) -> list[torch.nn.Parameter]:
# Use `block1`'s parameters for the first parameter group to deviate # Use `block1`'s parameters for the first parameter group to deviate
# from the `model.parameters()` order # from the `model.parameters()` order
return list(self.block1.parameters()) return list(self.block1.parameters())
def param_group1(self) -> List[torch.nn.Parameter]: def param_group1(self) -> list[torch.nn.Parameter]:
# Deviate from the `model.parameters()` order further by rearranging # Deviate from the `model.parameters()` order further by rearranging
# `block2`'s parameters to be before `block0`'s parameters # `block2`'s parameters to be before `block0`'s parameters
return list(self.block2.parameters()) + list(self.block0.parameters()) return list(self.block2.parameters()) + list(self.block0.parameters())
@ -322,10 +322,10 @@ class TestFSDPOptimState(FSDPTest):
wrap_alt: bool = False, # ignored if `wrap=False` wrap_alt: bool = False, # ignored if `wrap=False`
device: torch.device = torch.device("cuda"), device: torch.device = torch.device("cuda"),
group=None, group=None,
optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam, optim_class: type[torch.optim.Optimizer] = torch.optim.Adam,
use_multiple_param_groups: bool = False, use_multiple_param_groups: bool = False,
use_diff_optim_inputs: bool = False, use_diff_optim_inputs: bool = False,
fsdp_kwargs: Optional[Dict[str, Any]] = None, fsdp_kwargs: Optional[dict[str, Any]] = None,
): ):
model = NestedModel().to(device) model = NestedModel().to(device)
if wrap: if wrap:
@ -356,7 +356,7 @@ class TestFSDPOptimState(FSDPTest):
wrap: bool, wrap: bool,
device: torch.device = torch.device("cuda"), device: torch.device = torch.device("cuda"),
group=None, group=None,
optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam, optim_class: type[torch.optim.Optimizer] = torch.optim.Adam,
use_multiple_param_groups: bool = False, use_multiple_param_groups: bool = False,
use_diff_optim_inputs: bool = False, use_diff_optim_inputs: bool = False,
): ):
@ -383,7 +383,7 @@ class TestFSDPOptimState(FSDPTest):
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
device: torch.device = torch.device("cuda"), device: torch.device = torch.device("cuda"),
num_iters: int = 1, num_iters: int = 1,
) -> List[float]: ) -> list[float]:
"""Performs a forward pass, backward pass, and optimizer step """Performs a forward pass, backward pass, and optimizer step
``num_iters``-many times, and returns the per-iteration losses.""" ``num_iters``-many times, and returns the per-iteration losses."""
torch.manual_seed(0) # set seed for determinism torch.manual_seed(0) # set seed for determinism
@ -399,7 +399,7 @@ class TestFSDPOptimState(FSDPTest):
optim.step() optim.step()
return losses return losses
def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None): def _broadcast_full_osd(self, full_osd: dict[str, Any], group=None):
"""Broadcasts the full optimizer state dict in place of using """Broadcasts the full optimizer state dict in place of using
``torch.save()`` and ``torch.load()`` so that all ranks can have it.""" ``torch.save()`` and ``torch.load()`` so that all ranks can have it."""
obj_list = [full_osd] obj_list = [full_osd]
@ -413,8 +413,8 @@ class TestFSDPOptimState(FSDPTest):
def _are_equal_states( def _are_equal_states(
self, self,
state1: Dict[str, Any], state1: dict[str, Any],
state2: Dict[str, Any], state2: dict[str, Any],
) -> bool: ) -> bool:
"""Checks if ``state1`` and ``state2`` contain the same mappings.""" """Checks if ``state1`` and ``state2`` contain the same mappings."""
if set(state1.keys()) != set(state2.keys()): if set(state1.keys()) != set(state2.keys()):
@ -1450,7 +1450,7 @@ class TestFSDPOptimState(FSDPTest):
self, self,
should_check_method_fn: Callable[[str], bool], should_check_method_fn: Callable[[str], bool],
context_fn: Callable, context_fn: Callable,
fsdp_kwargs: Optional[Dict[str, Any]], fsdp_kwargs: Optional[dict[str, Any]],
): ):
""" """
Runs through all optimizer state checkpointing APIs with a context Runs through all optimizer state checkpointing APIs with a context
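Editor's note (illustrative sketch, not from this PR): only the container generics move to built-ins; Any, Callable, Optional, and Union are still imported from typing in these hunks (the X | None spelling of PEP 604 would need Python 3.10+). A hypothetical function, Python 3.9+:

from typing import Any, Optional

def wrap_model(model: Any, fsdp_kwargs: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    # Optional still comes from typing; the dict inside it is the built-in generic
    if fsdp_kwargs is None:
        fsdp_kwargs = {}
    return {"model": model, **fsdp_kwargs}

print(wrap_model("toy", {"use_orig_params": True}))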


@ -5,7 +5,7 @@ import functools
import itertools import itertools
import sys import sys
import unittest import unittest
from typing import List, Optional from typing import Optional
import torch import torch
from torch import distributed as dist from torch import distributed as dist
@ -259,7 +259,7 @@ class TestShardedGradScalerParityWithDDP(FSDPTest):
) )
grad_scaler = ShardedGradScaler(init_scale=2.0) grad_scaler = ShardedGradScaler(init_scale=2.0)
ref_grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0) ref_grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
scaled_losses: List[torch.Tensor] = [] scaled_losses: list[torch.Tensor] = []
device = torch.device("cuda") device = torch.device("cuda")
torch.manual_seed(42 + self.rank + 1) torch.manual_seed(42 + self.rank + 1)


@ -6,7 +6,7 @@ import sys
from contextlib import nullcontext from contextlib import nullcontext
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import Any, Dict from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -787,7 +787,7 @@ class TestFSDPStateDict(FSDPTest):
@staticmethod @staticmethod
def _load_state_dict( def _load_state_dict(
model: Module, state_dict_type: str, state_dict: Dict[str, Any] model: Module, state_dict_type: str, state_dict: dict[str, Any]
): ):
try: try:
enum_val = STATE_DICT_MAPPING[state_dict_type] enum_val = STATE_DICT_MAPPING[state_dict_type]


@ -2,7 +2,7 @@
import copy import copy
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Optional, Tuple from typing import Optional
import torch import torch
from torch import distributed as dist from torch import distributed as dist
@ -62,11 +62,11 @@ class SimpleModel(torch.nn.Module):
return self.net3(self.net2(self.relu(self.net1(x)))) return self.net3(self.net2(self.relu(self.net1(x))))
@staticmethod @staticmethod
def get_sharded_param_names() -> List[str]: def get_sharded_param_names() -> list[str]:
return ["net1.weight", "net1.bias", "net2.weight"] return ["net1.weight", "net1.bias", "net2.weight"]
@staticmethod @staticmethod
def get_non_sharded_param_names() -> List[str]: def get_non_sharded_param_names() -> list[str]:
return ["net3.weight", "net3.bias"] return ["net3.weight", "net3.bias"]
@ -87,9 +87,9 @@ class TestTPFSDPIntegration(FSDPTest):
def _get_params_and_sharding_info( def _get_params_and_sharding_info(
self, self,
model: SimpleModel, model: SimpleModel,
sharded_param_names: List[str], sharded_param_names: list[str],
tensor_parallel_size: int, tensor_parallel_size: int,
) -> Tuple[Dict[str, int], Dict[str, Tuple[torch.Size, int]]]: ) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]:
""" """ """ """
assert ( assert (
type(model) is SimpleModel type(model) is SimpleModel
@ -131,8 +131,8 @@ class TestTPFSDPIntegration(FSDPTest):
self, self,
tp_fsdp_model: FSDP, tp_fsdp_model: FSDP,
tp_pg: dist.ProcessGroup, tp_pg: dist.ProcessGroup,
param_name_to_numel: Dict[str, int], param_name_to_numel: dict[str, int],
non_sharded_param_names: List[str], non_sharded_param_names: list[str],
) -> None: ) -> None:
""" """
Syncs the tensor parallel parameters' gradients following the data Syncs the tensor parallel parameters' gradients following the data
@ -177,11 +177,11 @@ class TestTPFSDPIntegration(FSDPTest):
self, self,
model: FSDP, model: FSDP,
uses_tp: bool, uses_tp: bool,
param_name_to_numel: Dict[str, int], param_name_to_numel: dict[str, int],
param_name_to_sharding_info: Dict[str, Tuple[torch.Size, int]], param_name_to_sharding_info: dict[str, tuple[torch.Size, int]],
tp_pg: Optional[dist.ProcessGroup], tp_pg: Optional[dist.ProcessGroup],
fsdp_pg: Optional[dist.ProcessGroup], fsdp_pg: Optional[dist.ProcessGroup],
sharded_param_names: Optional[List[str]], sharded_param_names: Optional[list[str]],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Returns all unsharded gradients as a single flattened tensor. This Returns all unsharded gradients as a single flattened tensor. This


@ -3,7 +3,7 @@ import contextlib
import itertools import itertools
import math import math
import sys import sys
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
import torch import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils import torch.distributed.fsdp._traversal_utils as traversal_utils
@ -55,7 +55,7 @@ class TestUnshardParamsBase(FSDPTest):
self, self,
writeback: bool, writeback: bool,
check_outer: bool, check_outer: bool,
**fsdp_kwargs: Dict[str, Any], **fsdp_kwargs: dict[str, Any],
): ):
model = nn.Sequential( model = nn.Sequential(
nn.Linear(5, 5, bias=False, device=device_type.type), nn.Linear(5, 5, bias=False, device=device_type.type),
@ -101,7 +101,7 @@ class TestUnshardParamsBase(FSDPTest):
for param in model.parameters(): for param in model.parameters():
self.assertEqual(param.device, cpu_device) self.assertEqual(param.device, cpu_device)
def _get_test_unshard_params_writeback_config(self) -> Dict[str, List[Any]]: def _get_test_unshard_params_writeback_config(self) -> dict[str, list[Any]]:
return { return {
"writeback": [True, False], "writeback": [True, False],
"check_outer": [True, False], "check_outer": [True, False],
@ -193,7 +193,7 @@ class TestUnshardParamsBase(FSDPTest):
num_fsdp_roots += fsdp_state._is_root num_fsdp_roots += fsdp_state._is_root
self.assertGreater(num_fsdp_roots, 1) self.assertGreater(num_fsdp_roots, 1)
def _get_test_unshard_params_param_data_config(self) -> Dict[str, List[Any]]: def _get_test_unshard_params_param_data_config(self) -> dict[str, list[Any]]:
return { return {
"rank0_only": [False, True], "rank0_only": [False, True],
"offload_to_cpu": [False, True], "offload_to_cpu": [False, True],
@ -493,7 +493,7 @@ class TestUnshardParams(TestUnshardParamsBase):
def _check_grads( def _check_grads(
ddp_model: DDP, ddp_model: DDP,
fsdp_model: FSDP, fsdp_model: FSDP,
old_fsdp_grads: Optional[List[torch.Tensor]], old_fsdp_grads: Optional[list[torch.Tensor]],
): ):
""" """
Checks that writes to the FSDP parameters' gradients persist or do Checks that writes to the FSDP parameters' gradients persist or do


@ -6,7 +6,7 @@ import itertools
import os import os
import sys import sys
import unittest import unittest
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -65,7 +65,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
def world_size(self) -> int: def world_size(self) -> int:
return 2 return 2
def _get_param_groups(self, model: nn.Module) -> List[Dict[str, Any]]: def _get_param_groups(self, model: nn.Module) -> list[dict[str, Any]]:
""" """
Constructs separate parameter groups for weights, biases, and other Constructs separate parameter groups for weights, biases, and other
parameters. parameters.
@ -87,7 +87,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
def _get_optim( def _get_optim(
self, self,
model: nn.Module, model: nn.Module,
optim_class: Type[torch.optim.Optimizer], optim_class: type[torch.optim.Optimizer],
multi_tensor: bool, multi_tensor: bool,
) -> torch.optim.Optimizer: ) -> torch.optim.Optimizer:
""" """
@ -117,12 +117,12 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
self, self,
device_init_mode: DEVICEInitMode, device_init_mode: DEVICEInitMode,
init_optim_before_wrap: bool, init_optim_before_wrap: bool,
optim_class: Type[torch.optim.Optimizer], optim_class: type[torch.optim.Optimizer],
multi_tensor: bool, multi_tensor: bool,
sharding_strategy: ShardingStrategy, sharding_strategy: ShardingStrategy,
backward_prefetch: Optional[BackwardPrefetch], backward_prefetch: Optional[BackwardPrefetch],
cpu_offload: CPUOffload, cpu_offload: CPUOffload,
) -> Tuple[FSDP, torch.optim.Optimizer]: ) -> tuple[FSDP, torch.optim.Optimizer]:
""" """
Returns a transformer with shared parameters wrapped with FSDP and a Returns a transformer with shared parameters wrapped with FSDP and a
corresponding optimizer. corresponding optimizer.
@ -335,7 +335,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
self, self,
device_init_mode: DEVICEInitMode, device_init_mode: DEVICEInitMode,
init_optim_before_wrap: bool, init_optim_before_wrap: bool,
optim_class: Type[torch.optim.Optimizer], optim_class: type[torch.optim.Optimizer],
multi_tensor: bool, multi_tensor: bool,
set_to_none: bool, set_to_none: bool,
backward_prefetch: Optional[BackwardPrefetch], backward_prefetch: Optional[BackwardPrefetch],
@ -566,7 +566,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
self, self,
sharding_strategy: ShardingStrategy, sharding_strategy: ShardingStrategy,
cpu_offload: CPUOffload, cpu_offload: CPUOffload,
) -> Tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]: ) -> tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
""" """
Returns a pair of (FSDP model, optimizer) for ``use_orig_params=False`` Returns a pair of (FSDP model, optimizer) for ``use_orig_params=False``
and ``True``, respectively. and ``True``, respectively.
@ -778,7 +778,7 @@ class TestFSDPUseOrigParamsParamAccess(FSDPTest):
z = self.lin2(z) z = self.lin2(z)
return z return z
def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]: def get_input(self, device: torch.device) -> tuple[torch.Tensor, ...]:
return (torch.randn((2, 5)).to(device),) return (torch.randn((2, 5)).to(device),)
def get_loss(self, inp, out): def get_loss(self, inp, out):
@ -872,7 +872,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
z = self.lin2(z) z = self.lin2(z)
return z return z
def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]: def get_input(self, device: torch.device) -> tuple[torch.Tensor, ...]:
return (torch.randn((2, 5)).to(device),) return (torch.randn((2, 5)).to(device),)
def get_loss(self, inp, out): def get_loss(self, inp, out):


@ -4,7 +4,6 @@ import random
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -62,13 +61,13 @@ class TestUtils(TestCase):
class NonFrozenDataClass: class NonFrozenDataClass:
some_key: str some_key: str
some_float: float some_float: float
some_tensor: List[torch.Tensor] some_tensor: list[torch.Tensor]
@dataclass(frozen=True) @dataclass(frozen=True)
class FrozenDataClass: class FrozenDataClass:
some_key: str some_key: str
some_float: float some_float: float
some_tensor: List[torch.Tensor] some_tensor: list[torch.Tensor]
# create a mixed bag of data. # create a mixed bag of data.
data = [1, "str"] data = [1, "str"]


@ -16,7 +16,7 @@ import time
import unittest import unittest
import uuid import uuid
from contextlib import closing from contextlib import closing
from typing import Any, Dict, Optional from typing import Any, Optional
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
@ -59,7 +59,7 @@ def get_test_launch_config(
nproc_per_node: int, nproc_per_node: int,
run_id: str = "", run_id: str = "",
rdzv_backend: str = "etcd", rdzv_backend: str = "etcd",
config: Optional[Dict[str, Any]] = None, config: Optional[dict[str, Any]] = None,
) -> LaunchConfig: ) -> LaunchConfig:
rdzv_configs = {} rdzv_configs = {}
if config: if config:

View File

@ -3,7 +3,6 @@
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -22,7 +21,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
class MyModuleInterface: class MyModuleInterface:
def forward( def forward(
self, tensor: Tensor, number: int, word: str = "default" self, tensor: Tensor, number: int, word: str = "default"
) -> Tuple[Tensor, int, str]: ) -> tuple[Tensor, int, str]:
pass pass

View File

@ -10,7 +10,7 @@ import os
import sys import sys
import unittest import unittest
from contextlib import nullcontext from contextlib import nullcontext
from typing import Any, cast, List from typing import Any, cast
import numpy as np import numpy as np
@ -207,7 +207,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
super().step() super().step()
kwarg.append(5) kwarg.append(5)
kwarg: List[Any] = [] kwarg: list[Any] = []
x = torch.tensor([1.0], device=self.device, requires_grad=True) x = torch.tensor([1.0], device=self.device, requires_grad=True)
o = ZeroRedundancyOptimizer( o = ZeroRedundancyOptimizer(
[x], [x],

View File

@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
# This file is a Schedule zoo for testing torch.distributed.pipelining. # This file is a Schedule zoo for testing torch.distributed.pipelining.
# It includes schedules designed purely for testing purposes # It includes schedules designed purely for testing purposes
from typing import Callable, Dict, List, Optional from typing import Callable, Optional
from torch.distributed.pipelining.schedules import ( from torch.distributed.pipelining.schedules import (
_Action, _Action,
@ -32,9 +32,9 @@ class ScheduleVShaped(PipelineScheduleMulti):
def __init__( def __init__(
self, self,
stages: List[_PipelineStageBase], stages: list[_PipelineStageBase],
n_microbatches: int, n_microbatches: int,
stage_index_to_group_rank: Dict[int, int], stage_index_to_group_rank: dict[int, int],
loss_fn: Optional[Callable] = None, loss_fn: Optional[Callable] = None,
scale_grads: bool = True, scale_grads: bool = True,
): ):
@ -82,9 +82,9 @@ class ScheduleUnbalanced(PipelineScheduleMulti):
def __init__( def __init__(
self, self,
stages: List[_PipelineStageBase], stages: list[_PipelineStageBase],
n_microbatches: int, n_microbatches: int,
stage_index_to_group_rank: Dict[int, int], stage_index_to_group_rank: dict[int, int],
loss_fn: Optional[Callable] = None, loss_fn: Optional[Callable] = None,
scale_grads: bool = True, scale_grads: bool = True,
): ):
@ -134,7 +134,7 @@ class ScheduleWithW(PipelineScheduleMulti):
def __init__( def __init__(
self, self,
stages: List[_PipelineStageBase], stages: list[_PipelineStageBase],
n_microbatches: int, n_microbatches: int,
loss_fn: Optional[Callable] = None, loss_fn: Optional[Callable] = None,
enable_zero_bubble: bool = True, enable_zero_bubble: bool = True,
@ -195,7 +195,7 @@ class ScheduleWithReorderedB(_PipelineScheduleRuntime):
def __init__( def __init__(
self, self,
stages: List[_PipelineStageBase], stages: list[_PipelineStageBase],
n_microbatches: int, n_microbatches: int,
loss_fn: Optional[Callable] = None, loss_fn: Optional[Callable] = None,
scale_grads: bool = True, scale_grads: bool = True,

View File

@ -4,7 +4,6 @@ import copy
import csv import csv
import logging import logging
import os import os
from typing import List
from model_registry import MultiMLP from model_registry import MultiMLP
@ -356,7 +355,7 @@ instantiate_parametrized_tests(TestSchedulePlan)
class TestScheduleLowering(TestCase): class TestScheduleLowering(TestCase):
"""Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules""" """Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules"""
def _parse_actions(self, actions: List[str]) -> List[_Action]: def _parse_actions(self, actions: list[str]) -> list[_Action]:
return [_Action.from_str(s) for s in actions] return [_Action.from_str(s) for s in actions]
@parametrize( @parametrize(

View File

@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates # Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
from typing import Any, Dict from typing import Any
import torch import torch
from torch.distributed._tensor import DeviceMesh from torch.distributed._tensor import DeviceMesh
@ -57,8 +57,8 @@ class TestCommModeFeatures(DTensorTestBase):
Used to generate the ground-truth parameter and sharding info for a given distributed model to Used to generate the ground-truth parameter and sharding info for a given distributed model to
verify comm_mode correctness verify comm_mode correctness
""" """
module_parameters_dict: Dict[str, Any] = {} module_parameters_dict: dict[str, Any] = {}
module_sharding_dict: Dict[str, Any] = {} module_sharding_dict: dict[str, Any] = {}
for name, parameters in model.named_parameters(): for name, parameters in model.named_parameters():
# splits name into module name to create FQN and parameter name # splits name into module name to create FQN and parameter name

View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
from collections import defaultdict from collections import defaultdict
from typing import Dict
import torch import torch
from torch.distributed._tensor.experimental._tp_transform import ( from torch.distributed._tensor.experimental._tp_transform import (
@ -57,9 +56,9 @@ class TensorParallelTest(DTensorTestBase):
super().setUp() super().setUp()
def assert_has_c10d_ops( def assert_has_c10d_ops(
self, gm: torch.fx.GraphModule, expected_ops_count: Dict[str, int] self, gm: torch.fx.GraphModule, expected_ops_count: dict[str, int]
) -> None: ) -> None:
actual_ops_count: Dict[str, int] = defaultdict(int) actual_ops_count: dict[str, int] = defaultdict(int)
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.op == "call_function": if node.op == "call_function":
if "c10d_functional" in str(node.target): if "c10d_functional" in str(node.target):
@ -100,7 +99,7 @@ class TensorParallelTest(DTensorTestBase):
torch.manual_seed(0) torch.manual_seed(0)
model = MLPListModule(2).to(device=self.device_type) model = MLPListModule(2).to(device=self.device_type)
inputs = (torch.randn((10, 12)).to(device=self.device_type),) inputs = (torch.randn((10, 12)).to(device=self.device_type),)
parallel_strategies: Dict[str, ParallelStyle] = { parallel_strategies: dict[str, ParallelStyle] = {
"mlps.0.0": ColwiseParallel, "mlps.0.0": ColwiseParallel,
"mlps.0.2": RowwiseParallel, "mlps.0.2": RowwiseParallel,
"mlps.1.0": ColwiseParallel, "mlps.1.0": ColwiseParallel,
@ -137,7 +136,7 @@ class TensorParallelTest(DTensorTestBase):
torch.manual_seed(0) torch.manual_seed(0)
model = MLPListModule(1, bias=False).to(device=self.device_type) model = MLPListModule(1, bias=False).to(device=self.device_type)
inputs = (torch.randn((10, 12)).to(device=self.device_type),) inputs = (torch.randn((10, 12)).to(device=self.device_type),)
parallel_strategies: Dict[str, ParallelStyle] = { parallel_strategies: dict[str, ParallelStyle] = {
"mlps.0.0": ColwiseParallel, "mlps.0.0": ColwiseParallel,
"mlps.0.2": RowwiseParallel, "mlps.0.2": RowwiseParallel,
} }

View File

@ -3,7 +3,7 @@
import itertools import itertools
from copy import deepcopy from copy import deepcopy
from typing import Dict, NamedTuple, Optional from typing import NamedTuple, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -52,9 +52,9 @@ reduce_scatter, all_gather, all_reduce = (
class ExpCommCounts(NamedTuple): class ExpCommCounts(NamedTuple):
fwd: Optional[Dict] = None fwd: Optional[dict] = None
bwd: Optional[Dict] = None bwd: Optional[dict] = None
optim: Optional[Dict] = None optim: Optional[dict] = None
class DistTensorParallelExampleTest(DTensorTestBase): class DistTensorParallelExampleTest(DTensorTestBase):

View File

@ -3,7 +3,7 @@
import itertools import itertools
import unittest import unittest
from typing import cast, List, Optional, Tuple from typing import cast, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -27,8 +27,8 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
def scale_for_fp8( def scale_for_fp8(
t: torch.Tensor, scale_shape: Tuple[int] t: torch.Tensor, scale_shape: tuple[int]
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if all(d == 1 for d in scale_shape): if all(d == 1 for d in scale_shape):
t = t.unsqueeze(0).unsqueeze(-2) t = t.unsqueeze(0).unsqueeze(-2)
else: else:
@ -116,7 +116,7 @@ class DistMatrixOpsTest(DTensorTestBase):
local_res = torch.mm(t1, t2) local_res = torch.mm(t1, t2)
def test_placement_comb( def test_placement_comb(
placements1: List[Placement], placements2: List[Placement] placements1: list[Placement], placements2: list[Placement]
) -> None: ) -> None:
dt1 = distribute_tensor(t1, device_mesh, placements1) dt1 = distribute_tensor(t1, device_mesh, placements1)
dt2 = distribute_tensor(t2, device_mesh, placements2) dt2 = distribute_tensor(t2, device_mesh, placements2)
@ -272,9 +272,9 @@ class DistMatrixOpsTest(DTensorTestBase):
batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True) batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True)
def test_placement_comb( def test_placement_comb(
tensor_placements: List[Placement], tensor_placements: list[Placement],
batch_1_placements: List[Placement], batch_1_placements: list[Placement],
batch_2_placements: List[Placement], batch_2_placements: list[Placement],
beta: int, beta: int,
alpha: int, alpha: int,
batch_1_grad: Optional[torch.Tensor], batch_1_grad: Optional[torch.Tensor],
@ -338,8 +338,8 @@ class DistMatrixOpsTest(DTensorTestBase):
local_result.backward(grad_local_res) local_result.backward(grad_local_res)
def test_placement_comb( def test_placement_comb(
placements1: List[Placement], placements1: list[Placement],
placements2: List[Placement], placements2: list[Placement],
) -> None: ) -> None:
mat1_dt = distribute_tensor(mat1, device_mesh, placements1) mat1_dt = distribute_tensor(mat1, device_mesh, placements1)
mat2_dt = distribute_tensor(mat2, device_mesh, placements2) mat2_dt = distribute_tensor(mat2, device_mesh, placements2)

View File

@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Optional
from unittest import skip from unittest import skip
import torch import torch
@ -76,7 +76,7 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
op: Callable, op: Callable,
pre_op_fn: Optional[Callable] = None, pre_op_fn: Optional[Callable] = None,
args: Sequence[Any] = (), args: Sequence[Any] = (),
kwargs: Optional[Dict[str, Any]] = None, kwargs: Optional[dict[str, Any]] = None,
): ):
if pre_op_fn is None: if pre_op_fn is None:
pre_op_fn = no_op pre_op_fn = no_op

View File

@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
import itertools import itertools
from typing import cast, List from typing import cast
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -160,7 +160,7 @@ class TestViewOps(DTensorTestBase):
if op == torch.unbind: if op == torch.unbind:
no_shard_dims.add(kwargs.get("dim", 0)) no_shard_dims.add(kwargs.get("dim", 0))
sharding_choices = cast(List[Placement], [Replicate()]) + [ sharding_choices = cast(list[Placement], [Replicate()]) + [
Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims
] ]
@ -513,7 +513,7 @@ class TestViewOps(DTensorTestBase):
# test sharded computation correctness # test sharded computation correctness
# NOTE: For the input to torch.view_as_complex, sharding # NOTE: For the input to torch.view_as_complex, sharding
# on the last two dimensions is not supported. # on the last two dimensions is not supported.
sharding_choices: List[Placement] = [Replicate(), Shard(0)] sharding_choices: list[Placement] = [Replicate(), Shard(0)]
all_sharding_choices = itertools.product( all_sharding_choices = itertools.product(
*(self.device_mesh.ndim * [sharding_choices]) *(self.device_mesh.ndim * [sharding_choices])
) )
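
The cast(list[Placement], ...) change earlier in this file shows that runtime call sites are updated the same way as annotations. A hedged, torch-free illustration (the Replicate/Shard placements are replaced by plain strings so the snippet runs on its own):

from typing import cast

# cast() only guides the type checker and returns its second argument unchanged;
# list[str] evaluates to a types.GenericAlias at runtime, valid on Python 3.9+
# exactly where typing.List[str] was accepted before.
sharding_choices = cast(list[str], ["replicate"]) + [f"shard({i})" for i in range(2)]
print(sharding_choices)  # ['replicate', 'shard(0)', 'shard(1)']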

View File

@ -4,7 +4,7 @@
import os import os
import unittest import unittest
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, Tuple from typing import Any, Callable
import numpy as np import numpy as np
@ -26,7 +26,7 @@ def with_xla(func: Callable) -> Callable:
@wraps(func) # pyre-ignore[6] @wraps(func) # pyre-ignore[6]
def wrapper( def wrapper(
self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc] self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
) -> None: ) -> None:
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag. # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
os.environ["XLA_USE_SPMD"] = "1" os.environ["XLA_USE_SPMD"] = "1"

View File

@ -12,7 +12,7 @@ from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from itertools import product from itertools import product
from sys import platform from sys import platform
from typing import Dict, Optional from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -972,7 +972,7 @@ class CommonDistributedDataParallelTest:
@dataclass @dataclass
class CustomOutput: class CustomOutput:
o1: Optional[torch.Tensor] o1: Optional[torch.Tensor]
o2: Dict[str, torch.Tensor] o2: dict[str, torch.Tensor]
class DataclassOutputModule(nn.Module): class DataclassOutputModule(nn.Module):
def __init__(self, skip_o1): def __init__(self, skip_o1):

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: c10d"] # Owner(s): ["module: c10d"]
import threading import threading
import unittest import unittest
from typing import List
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -65,7 +64,7 @@ class TestWithNCCL(MultiProcessTestCase):
return 2 return 2
@property @property
def ranks(self) -> List[int]: def ranks(self) -> list[int]:
return list(range(self.world_size)) return list(range(self.world_size))
@property @property
@ -556,7 +555,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache() @fresh_inductor_cache()
def test_inductor_all_reduce_coalesced(self): def test_inductor_all_reduce_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor: def func(args: list[torch.Tensor]) -> torch.Tensor:
bufs = [arg + 42 for arg in args] bufs = [arg + 42 for arg in args]
# Expect in-place with inductor allocated buf # Expect in-place with inductor allocated buf
ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0") ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")
@ -714,7 +713,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache() @fresh_inductor_cache()
def test_inductor_all_gather_into_tensor_coalesced(self): def test_inductor_all_gather_into_tensor_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor: def func(args: list[torch.Tensor]) -> torch.Tensor:
ag0 = funcol.all_gather_into_tensor_coalesced(args, "0") ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
ag0 = [funcol.wait_tensor(out) for out in ag0] ag0 = [funcol.wait_tensor(out) for out in ag0]
return ag0 return ag0
@ -796,7 +795,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache() @fresh_inductor_cache()
def test_inductor_reduce_scatter_tensor_coalesced(self): def test_inductor_reduce_scatter_tensor_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor: def func(args: list[torch.Tensor]) -> torch.Tensor:
rs0 = funcol.reduce_scatter_tensor_coalesced( rs0 = funcol.reduce_scatter_tensor_coalesced(
args, "avg", [0] * len(args), "0" args, "avg", [0] * len(args), "0"
) )

View File

@ -27,7 +27,6 @@ if not c10d.is_available() or not c10d.is_nccl_available():
print("c10d NCCL not available, skipping tests", file=sys.stderr) print("c10d NCCL not available, skipping tests", file=sys.stderr)
sys.exit(0) sys.exit(0)
from typing import Dict, List
import test_c10d_common import test_c10d_common
from test_c10d_common import ConvNet, DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook from test_c10d_common import ConvNet, DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook
@ -2546,7 +2545,7 @@ class WorkHookTest(MultiProcessTestCase):
def test_on_completion_hook_broadcast(self): def test_on_completion_hook_broadcast(self):
pg = self._get_process_group() pg = self._get_process_group()
num_hook_fired = 0 num_hook_fired = 0
durations: List[float] = [] durations: list[float] = []
def hook(work_info: torch._C._distributed_c10d.WorkInfo): def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations nonlocal num_hook_fired, durations
@ -2574,7 +2573,7 @@ class WorkHookTest(MultiProcessTestCase):
def test_on_completion_hook_mixed_ops(self): def test_on_completion_hook_mixed_ops(self):
pg = self._get_process_group() pg = self._get_process_group()
num_hook_fired = 0 num_hook_fired = 0
durations: List[float] = [] durations: list[float] = []
def hook(work_info: torch._C._distributed_c10d.WorkInfo): def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations nonlocal num_hook_fired, durations
@ -2615,8 +2614,8 @@ class WorkHookTest(MultiProcessTestCase):
@skip_if_lt_x_gpu(2) @skip_if_lt_x_gpu(2)
def test_on_completion_hook_with_ddp(self): def test_on_completion_hook_with_ddp(self):
pg = self._get_process_group() pg = self._get_process_group()
num_hook_fired: Dict[int, int] = {} num_hook_fired: dict[int, int] = {}
durations: Dict[OpType, List[float]] = {} durations: dict[OpType, list[float]] = {}
def hook(work_info: torch._C._distributed_c10d.WorkInfo): def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations nonlocal num_hook_fired, durations
@ -2673,8 +2672,8 @@ class WorkHookTest(MultiProcessTestCase):
torch.cuda.set_device(self.rank) torch.cuda.set_device(self.rank)
pg = self._get_process_group() pg = self._get_process_group()
num_hook_fired: Dict[int, int] = {} num_hook_fired: dict[int, int] = {}
durations: Dict[OpType, List[float]] = {} durations: dict[OpType, list[float]] = {}
def hook(work_info: torch._C._distributed_c10d.WorkInfo): def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations nonlocal num_hook_fired, durations

View File

@ -2,21 +2,15 @@
import copy import copy
import os import os
import sys
import tempfile import tempfile
import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
import torch import torch
import torch.distributed as c10d import torch.distributed as c10d
import torch.nn as nn import torch.nn as nn
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import ( from torch.testing._internal.common_distributed import requires_gloo, skip_if_lt_x_gpu
create_device,
requires_gloo,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
run_tests, run_tests,
skip_but_pass_in_sandcastle_if, skip_but_pass_in_sandcastle_if,
@ -26,95 +20,6 @@ from torch.testing._internal.common_utils import (
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 # Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def opts(cls, threads=2):
opts = c10d.ProcessGroupGloo._Options()
opts._timeout = 5.0
opts._devices = [create_device(interface="lo")]
opts._threads = threads
return opts
@classmethod
def _init_pg_gloo(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
backend = c10d.ProcessGroupGloo(
store, rank, world_size, ProcessGroupShareTensorTest.opts()
)
# set process group backends manually
c10d.init_process_group(
backend="gloo", store=store, rank=rank, world_size=world_size
)
pg = c10d.distributed_c10d._get_default_group()
pg._register_backend(
torch.device("cpu"), c10d.ProcessGroup.BackendType.GLOO, backend
)
pg._register_backend(
torch.device("cuda"), c10d.ProcessGroup.BackendType.GLOO, backend
)
return pg
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_broadcast_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allreduce_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allgather_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
self.world_size,
)
@classmethod
def _test_allgather_chunk_process(
cls, rank, filename, shared_tensor, world_size, init_pg, c2p, p2c
):
pg = init_pg(rank, filename, world_size)
chunks = torch.chunk(shared_tensor, world_size, dim=0)
x = chunks[rank]
ys = [torch.zeros_like(x) for _ in range(world_size)]
pg.allgather(ys, x).wait()
c2p.put((rank, chunks[0].to("cpu"), ys[0].to("cpu")))
c2p.put((rank, chunks[1].to("cpu"), ys[1].to("cpu")))
p2c.get()
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allgather_chunk_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_chunk_process,
torch.tensor(range(4)).reshape(2, 2),
ProcessGroupShareTensorTest._init_pg_gloo,
self.world_size,
)
class DistributedDataParallelSingleProcessTest(TestCase): class DistributedDataParallelSingleProcessTest(TestCase):
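
The ProcessGroupShareTensorTest block removed above only ran under sys.version_info < (3, 9); once the code is written in the PEP 585 style the effective interpreter floor is 3.9+, so the guard and everything inside it is presumably dead code. A small sketch of why the new annotations themselves assume Python 3.9, using a hypothetical ranks() helper (unrelated to the ranks property changed in an earlier file):

import sys

def ranks(world_size: int) -> list[int]:
    # On Python 3.8 this `def` already fails with
    # "TypeError: 'type' object is not subscriptable" when the annotation is evaluated;
    # on 3.9+ list[int] is an ordinary types.GenericAlias.
    return list(range(world_size))

if __name__ == "__main__":
    assert sys.version_info >= (3, 9), "PEP 585 builtin generics need Python 3.9+"
    print(ranks(4))  # [0, 1, 2, 3]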

View File

@ -1,95 +1,21 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
import sys
import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
import torch import torch
import torch.distributed as c10d import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
run_tests, run_tests,
skip_but_pass_in_sandcastle_if, skip_but_pass_in_sandcastle_if,
TEST_WITH_DEV_DBG_ASAN, TEST_WITH_DEV_DBG_ASAN,
TestCase,
) )
NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL") NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 # Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def _init_pg_nccl(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
return c10d.ProcessGroupNCCL(store, rank, world_size)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_broadcast_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_allreduce_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)
@classmethod
def _test_reduce_process(
cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
):
pg = init_pg(rank, filename, world_size)
x = shared_tensors[rank]
pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
if rank == 0:
c2p.put((rank, torch.ones(2, 2) * 2, x.to("cpu")))
else:
c2p.put((rank, torch.ones(2, 2), x.to("cpu")))
p2c.get()
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_reduce_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_reduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_allgather_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
self.world_size,
)
# Skip dev-asan as torch + multiprocessing spawn have known issues # Skip dev-asan as torch + multiprocessing spawn have known issues

View File

@ -1,74 +1,21 @@
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
import sys
import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions
import torch
import torch.distributed as c10d import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_ucc, skip_if_lt_x_gpu from torch.testing._internal.common_distributed import requires_ucc, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
run_tests, run_tests,
skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if, skip_but_pass_in_sandcastle_if,
TEST_WITH_DEV_DBG_ASAN, TEST_WITH_DEV_DBG_ASAN,
TestCase,
) )
NO_UCC = not hasattr(c10d, "ProcessGroupUCC") NO_UCC = not hasattr(c10d, "ProcessGroupUCC")
# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 # Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):
class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def _init_pg_ucc(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
c10d.init_process_group(
backend="ucc", store=store, rank=rank, world_size=world_size
)
return c10d.distributed_c10d._get_default_group()
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
def test_shared_broadcast_ucc(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_ucc,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
def test_shared_allreduce_ucc(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_ucc,
1,
)
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
def test_shared_allgather_ucc(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_ucc,
self.world_size,
)
# Skip dev-asan as torch + multiprocessing spawn have known issues # Skip dev-asan as torch + multiprocessing spawn have known issues

View File

@ -7,7 +7,6 @@ import unittest
from contextlib import contextmanager from contextlib import contextmanager
from datetime import timedelta from datetime import timedelta
from io import StringIO from io import StringIO
from typing import List
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
@ -1959,7 +1958,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
model = ModuleWithStaticMethod(False) model = ModuleWithStaticMethod(False)
x = torch.randn((2, 3), device="cuda") x = torch.randn((2, 3), device="cuda")
ref_out = model(x) ref_out = model(x)
test_outs: List[torch.Tensor] = [] test_outs: list[torch.Tensor] = []
for use_self in (False, True): for use_self in (False, True):
model = ModuleWithStaticMethod(use_self) model = ModuleWithStaticMethod(use_self)

View File

@ -5,7 +5,6 @@ import functools
import math import math
import unittest # noqa: F811 import unittest # noqa: F811
from importlib import import_module from importlib import import_module
from typing import Set
import torch import torch
import torch._dynamo.config import torch._dynamo.config
@ -86,7 +85,7 @@ def count_ops(
return gm return gm
def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]): def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: set[str]):
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph
return_node = list(graph.nodes)[-1] return_node = list(graph.nodes)[-1]
assert return_node.target == "output" assert return_node.target == "output"

View File

@ -12,7 +12,6 @@ import operator
import unittest import unittest
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum from enum import Enum
from typing import Dict, List
from unittest.mock import patch from unittest.mock import patch
import torch import torch
@ -1999,7 +1998,7 @@ def forward(self, l_x_):
self.assertIn("val", node.meta) self.assertIn("val", node.meta)
def test_input_container_type(self): def test_input_container_type(self):
def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]: def f(x: torch.Tensor, y: list[torch.Tensor]) -> dict[str, torch.Tensor]:
return {"a": x.sum() + sum(y).sum()} return {"a": x.sum() + sum(y).sum()}
inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)]) inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])

View File

@ -11,7 +11,7 @@ import random
import sys import sys
import unittest import unittest
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, TypeVar from typing import Any, Generic, TypeVar
from typing_extensions import NamedTuple from typing_extensions import NamedTuple
from unittest.mock import patch from unittest.mock import patch
@ -3896,8 +3896,8 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
@dataclass @dataclass
class Output: class Output:
scalar: int = 2 scalar: int = 2
named_tensors: Dict[str, torch.Tensor] = field(default_factory=dict) named_tensors: dict[str, torch.Tensor] = field(default_factory=dict)
lists: List[torch.Tensor] = field(default_factory=list) lists: list[torch.Tensor] = field(default_factory=list)
def scale(self): def scale(self):
return self.scalar * 2 return self.scalar * 2
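
The dataclass hunk above needs no other adjustment, since the dataclass machinery treats dict[str, torch.Tensor] and Dict[str, torch.Tensor] identically. A stand-alone sketch with the tensor fields swapped for plain ints so it runs without torch:

from dataclasses import dataclass, field

@dataclass
class Output:
    scalar: int = 2
    named_values: dict[str, int] = field(default_factory=dict)
    lists: list[int] = field(default_factory=list)

out = Output(named_values={"a": 1}, lists=[1, 2])
print(out.scalar * 2, out.lists)  # 4 [1, 2]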

View File

@ -12,7 +12,7 @@ import types
import unittest import unittest
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import Dict, NamedTuple, Tuple from typing import NamedTuple
from unittest.mock import patch from unittest.mock import patch
import torch import torch
@ -602,7 +602,7 @@ class LazyMLP(torch.nn.Module):
class MyInput(NamedTuple): class MyInput(NamedTuple):
x: Dict[str, Dict[str, torch.Tensor]] x: dict[str, dict[str, torch.Tensor]]
y: torch.Tensor y: torch.Tensor
@ -2311,7 +2311,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
m = TestModule() m = TestModule()
def forward_hook( def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
return 2 * output + 1 return 2 * output + 1
@ -2358,7 +2358,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
m = TestModule() m = TestModule()
def forward_hook( def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
return 2 * output + 1 return 2 * output + 1
@ -2407,7 +2407,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
self.assertEqual(compiled_func(inp).item(), 15) self.assertEqual(compiled_func(inp).item(), 15)
def new_forward_hook( def new_forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
return 2 * output + 2 return 2 * output + 2
@ -2426,7 +2426,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
m = TestModule() m = TestModule()
def forward_hook( def forward_hook(
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
return 2 * output + 1 return 2 * output + 1

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"] # Owner(s): ["module: dynamo"]
from typing import Callable, Dict, List, NamedTuple, Optional from typing import Callable, List, NamedTuple, Optional
import torch import torch
import torch._dynamo import torch._dynamo
@ -50,20 +50,20 @@ class Variable:
def sum(self, name: Optional[str] = None) -> "Variable": def sum(self, name: Optional[str] = None) -> "Variable":
return operator_sum(self, name) return operator_sum(self, name)
def expand(self, sizes: List[int]) -> "Variable": def expand(self, sizes: list[int]) -> "Variable":
return operator_expand(self, sizes) return operator_expand(self, sizes)
class TapeEntry(NamedTuple): class TapeEntry(NamedTuple):
# names of the inputs to the original computation # names of the inputs to the original computation
inputs: List[str] inputs: list[str]
# names of the outputs of the original computation # names of the outputs of the original computation
outputs: List[str] outputs: list[str]
# apply chain rule # apply chain rule
propagate: "Callable[List[Variable], List[Variable]]" propagate: "Callable[list[Variable], list[Variable]]"
gradient_tape: List[TapeEntry] = [] gradient_tape: list[TapeEntry] = []
def reset_tape(): def reset_tape():
@ -72,9 +72,9 @@ def reset_tape():
_name = 0 _name = 0
def grad(L, desired_results: List[Variable]) -> List[Variable]: def grad(L, desired_results: list[Variable]) -> list[Variable]:
# this map holds dL/dX for all values X # this map holds dL/dX for all values X
dL_d: Dict[str, Variable] = {} dL_d: dict[str, Variable] = {}
# It starts by initializing the 'seed' dL/dL, which is 1 # It starts by initializing the 'seed' dL/dL, which is 1
dL_d[L.name] = Variable(torch.ones(())) dL_d[L.name] = Variable(torch.ones(()))
# print(f'd{L.name} ------------------------') # print(f'd{L.name} ------------------------')

View File

@ -3,7 +3,6 @@
import contextlib import contextlib
import dis import dis
import unittest import unittest
from typing import List
import torch import torch
import torch._dynamo.test_case import torch._dynamo.test_case
@ -33,7 +32,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
Emit code to reconstruct only the key that changed Emit code to reconstruct only the key that changed
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1) self.assertEqual(len(build_map), 1)
# reconstruct only d[40] # reconstruct only d[40]
@ -57,7 +56,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If something is pop'ed from the dict, we reconstruct everything If something is pop'ed from the dict, we reconstruct everything
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1) self.assertEqual(len(build_map), 1)
# reconstruct everything # reconstruct everything
@ -84,7 +83,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If something is pop'ed from the dict, we reconstruct everything If something is pop'ed from the dict, we reconstruct everything
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1) self.assertEqual(len(build_map), 1)
# reconstruct everything # reconstruct everything
@ -128,7 +127,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If something is deleted from the dict, we reconstruct everything If something is deleted from the dict, we reconstruct everything
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1) self.assertEqual(len(build_map), 1)
# reconstruct everything # reconstruct everything
@ -154,7 +153,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
dict.get shouldn't affect anything dict.get shouldn't affect anything
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1) self.assertEqual(len(build_map), 1)
self.assertEqual(build_map[0].argval, 1) self.assertEqual(build_map[0].argval, 1)
@ -180,7 +179,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If dict.clear() is used, we reconstruct everything If dict.clear() is used, we reconstruct everything
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1) self.assertEqual(len(build_map), 1)
# reconstruct everything # reconstruct everything
@ -206,7 +205,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
If dict is created inside a function, everything needs to be reconstructed If dict is created inside a function, everything needs to be reconstructed
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
self.assertEqual(len(build_map), 1) self.assertEqual(len(build_map), 1)
# reconstruct everything # reconstruct everything
@ -231,7 +230,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
PyTorch shouldn't codegen any key/value when functional_call is used PyTorch shouldn't codegen any key/value when functional_call is used
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
# don't reconstruct anything # don't reconstruct anything
self.assertEqual(len(build_map), 0) self.assertEqual(len(build_map), 0)
@ -260,7 +259,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
PyTorch shouldn't codegen any key/value when functional_call is used PyTorch shouldn't codegen any key/value when functional_call is used
""" """
def hook(instructions: List[dis.Instruction]): def hook(instructions: list[dis.Instruction]):
build_map = _filter_instructions(instructions, "BUILD_MAP") build_map = _filter_instructions(instructions, "BUILD_MAP")
# don't reconstruct anything # don't reconstruct anything
self.assertEqual(len(build_map), 0) self.assertEqual(len(build_map), 0)

View File

@ -25,7 +25,7 @@ from collections.abc import Iterator
from copy import deepcopy from copy import deepcopy
from enum import Enum, IntEnum from enum import Enum, IntEnum
from functools import wraps from functools import wraps
from typing import Any, Dict, List, Literal, Tuple, TypedDict from typing import Any, Literal, TypedDict
from unittest import mock from unittest import mock
import numpy as np import numpy as np
@ -712,7 +712,7 @@ def create_rand_mask_from_inputs(
class SequentialAppendList(torch.nn.Sequential): class SequentialAppendList(torch.nn.Sequential):
"""from timm/models/vovnet.py""" """from timm/models/vovnet.py"""
def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor: def forward(self, x: torch.Tensor, concat_list: list[torch.Tensor]) -> torch.Tensor:
for i, module in enumerate(self): for i, module in enumerate(self):
if i == 0: if i == 0:
concat_list.append(module(x)) concat_list.append(module(x))
@ -4108,7 +4108,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
def test_graph_break_on_jit_isinstance(self): def test_graph_break_on_jit_isinstance(self):
@torch.compile(backend="eager") @torch.compile(backend="eager")
def fn(x): def fn(x):
if torch.jit.isinstance(x, List[str]): if torch.jit.isinstance(x, list[str]):
return x * 2 return x * 2
return x return x
@ -4819,14 +4819,14 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
def test_detectron2_instances_cat(self): def test_detectron2_instances_cat(self):
class Instances: class Instances:
def __init__(self, image_size: Tuple[int, int], **kwargs: Any): def __init__(self, image_size: tuple[int, int], **kwargs: Any):
self._image_size = image_size self._image_size = image_size
self._fields: Dict[str, Any] = {} self._fields: dict[str, Any] = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
self.set(k, v) self.set(k, v)
@property @property
def image_size(self) -> Tuple[int, int]: def image_size(self) -> tuple[int, int]:
return self._image_size return self._image_size
def __setattr__(self, name: str, val: Any) -> None: def __setattr__(self, name: str, val: Any) -> None:
@ -4861,7 +4861,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
return self._fields[name] return self._fields[name]
@staticmethod @staticmethod
def cat(instance_lists: List["Instances"]) -> "Instances": def cat(instance_lists: list["Instances"]) -> "Instances":
assert all(isinstance(i, Instances) for i in instance_lists) assert all(isinstance(i, Instances) for i in instance_lists)
assert len(instance_lists) > 0 assert len(instance_lists) > 0
if len(instance_lists) == 1: if len(instance_lists) == 1:
@ -5997,7 +5997,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
# https://github.com/pytorch/pytorch/issues/88813 # https://github.com/pytorch/pytorch/issues/88813
def test_return_value_duplication_tensor(self) -> None: def test_return_value_duplication_tensor(self) -> None:
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return val * 2, val * 2 return val * 2, val * 2
x = torch.randn(2, requires_grad=True) x = torch.randn(2, requires_grad=True)
@ -6016,7 +6016,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
# https://github.com/pytorch/pytorch/issues/114344 # https://github.com/pytorch/pytorch/issues/114344
def test_return_value_duplication_mixed_grad(self) -> None: def test_return_value_duplication_mixed_grad(self) -> None:
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad(): with torch.no_grad():
out0 = val + 1 out0 = val + 1
out1 = val + 1 out1 = val + 1
@ -6033,7 +6033,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
# https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371 # https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371
def test_return_value_duplication_scalar(self) -> None: def test_return_value_duplication_scalar(self) -> None:
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x, y = val * 2, val * 2 x, y = val * 2, val * 2
return x[0], y[0] return x[0], y[0]

View File

@ -1822,7 +1822,7 @@ class GraphModule(torch.nn.Module):
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
@parametrize("dynamic", [True, False]) @parametrize("dynamic", [True, False])
def test_mark_static_with_subclass_desugaring(self, dynamic): def test_mark_static_with_subclass_desugaring(self, dynamic):
from typing import Any, Callable, List, Optional from typing import Any, Callable, Optional
from torch._dynamo.decorators import mark_static_address from torch._dynamo.decorators import mark_static_address
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx
@ -1835,9 +1835,9 @@ class GraphModule(torch.nn.Module):
def inner_compile( def inner_compile(
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor], example_inputs: list[torch.Tensor],
cudagraphs: Optional[BoxedBool] = None, cudagraphs: Optional[BoxedBool] = None,
static_input_idxs: Optional[List[int]] = None, static_input_idxs: Optional[list[int]] = None,
is_backward: bool = False, is_backward: bool = False,
graph_id: Optional[int] = None, graph_id: Optional[int] = None,
cpp_wrapper: bool = False, cpp_wrapper: bool = False,
@ -1845,7 +1845,7 @@ class GraphModule(torch.nn.Module):
is_inference: bool = False, is_inference: bool = False,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
layout_opt: Optional[bool] = None, layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None, extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None,
): ):
if dynamic: if dynamic:
self.assertEqual(static_input_idxs, [2, 3, 4]) self.assertEqual(static_input_idxs, [2, 3, 4])

View File

@ -1,7 +1,6 @@
# Owner(s): ["module: dynamo"] # Owner(s): ["module: dynamo"]
import sys import sys
import unittest import unittest
from typing import Dict, List
import torch import torch
import torch._dynamo.config import torch._dynamo.config
@ -23,7 +22,7 @@ except ImportError:
@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True) @torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
class BucketizeMod(torch.nn.Module): class BucketizeMod(torch.nn.Module):
def __init__(self, feature_boundaries: Dict[str, List[float]]): def __init__(self, feature_boundaries: dict[str, list[float]]):
super().__init__() super().__init__()
self.bucket_w = torch.nn.ParameterDict() self.bucket_w = torch.nn.ParameterDict()
self.boundaries_dict = {} self.boundaries_dict = {}
@ -84,7 +83,7 @@ class TorchRecTests(TestCase):
@torch.compile(backend=counter, fullgraph=True, dynamic=True) @torch.compile(backend=counter, fullgraph=True, dynamic=True)
def f(id_list_features: KeyedJaggedTensor): def f(id_list_features: KeyedJaggedTensor):
id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict() id_list_jt_dict: dict[str, JaggedTensor] = id_list_features.to_dict()
pooled_embeddings = {} pooled_embeddings = {}
# TODO: run feature processor # TODO: run feature processor
for emb_module, feature_names in tables: for emb_module, feature_names in tables:

View File

@ -6,7 +6,7 @@ import math
import types import types
import unittest import unittest
import warnings import warnings
from typing import Any, Dict, Set from typing import Any
import torch import torch
import torch._dynamo.config as config import torch._dynamo.config as config
@ -103,10 +103,10 @@ class AllowedObjects:
from the heuristic defined in `gen_allowed_objs_and_ids`. from the heuristic defined in `gen_allowed_objs_and_ids`.
""" """
object_ids: Dict[int, str] object_ids: dict[int, str]
c_binding_in_graph_functions: Set[Any] c_binding_in_graph_functions: set[Any]
non_c_binding_in_graph_functions: Set[Any] non_c_binding_in_graph_functions: set[Any]
name_rule_map: Dict[str, Any] name_rule_map: dict[str, Any]
def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects: def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:

View File

@ -2,7 +2,7 @@
import unittest import unittest
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Optional
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@ -65,11 +65,11 @@ class TestConverter(TestCase):
self, self,
M, M,
tracing_inputs, tracing_inputs,
option: Optional[List[str]] = None, option: Optional[list[str]] = None,
check_persistent=False, check_persistent=False,
lifted_tensor_constants=None, lifted_tensor_constants=None,
runtime_inputs: Optional[List[Any]] = None, runtime_inputs: Optional[list[Any]] = None,
) -> List[ExportedProgram]: ) -> list[ExportedProgram]:
# By default, it tests both jit.trace and jit.script. # By default, it tests both jit.trace and jit.script.
if option is None: if option is None:
option = ["trace", "script"] option = ["trace", "script"]
@ -130,7 +130,7 @@ class TestConverter(TestCase):
self._check_tensor_list_equal(ep_out, orig_out) self._check_tensor_list_equal(ep_out, orig_out)
return ep_list return ep_list
def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]): def _check_tensor_list_equal(self, xs: list[torch.Tensor], ys: list[torch.Tensor]):
self.assertEqual(len(xs), len(ys)) self.assertEqual(len(xs), len(ys))
for x, y in zip(xs, ys): for x, y in zip(xs, ys):
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
@ -219,7 +219,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp) self._check_equal_ts_ep_converter(Module(), inp)
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, x: List[int]): def forward(self, x: list[int]):
length = len(x) length = len(x)
return torch.ones(length) return torch.ones(length)
@ -228,7 +228,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp, ["script"]) self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, x: Dict[int, str]): def forward(self, x: dict[int, str]):
length = len(x) length = len(x)
return torch.ones(length) return torch.ones(length)
@ -237,7 +237,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp, ["script"]) self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, x: Dict[bool, str]): def forward(self, x: dict[bool, str]):
length = len(x) length = len(x)
return torch.ones(length) return torch.ones(length)
@ -246,7 +246,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp, ["script"]) self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, x: Dict[float, str]): def forward(self, x: dict[float, str]):
length = len(x) length = len(x)
return torch.ones(length) return torch.ones(length)
@ -255,7 +255,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(Module(), inp, ["script"]) self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, x: Dict[torch.Tensor, str]): def forward(self, x: dict[torch.Tensor, str]):
length = len(x) length = len(x)
return torch.ones(length) return torch.ones(length)
@ -273,7 +273,7 @@ class TestConverter(TestCase):
def test_aten_add_t(self): def test_aten_add_t(self):
# python list append # python list append
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, x: List[torch.Tensor]): def forward(self, x: list[torch.Tensor]):
out = [] out = []
out = out + x out = out + x
a = torch.cat(out) a = torch.cat(out)
@ -531,7 +531,7 @@ class TestConverter(TestCase):
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward( def forward(
self, x: torch.Tensor, y: torch.Tensor self, x: torch.Tensor, y: torch.Tensor
) -> Tuple[bool, torch.Tensor]: ) -> tuple[bool, torch.Tensor]:
z = x + 1 z = x + 1
return x is y, z return x is y, z
@ -546,7 +546,7 @@ class TestConverter(TestCase):
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward( def forward(
self, x: torch.Tensor, y: torch.Tensor self, x: torch.Tensor, y: torch.Tensor
) -> Tuple[bool, torch.Tensor]: ) -> tuple[bool, torch.Tensor]:
z = x + 1 z = x + 1
return x is not y, z return x is not y, z
@ -558,7 +558,7 @@ class TestConverter(TestCase):
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward( def forward(
self, x: torch.Tensor, y: torch.Tensor self, x: torch.Tensor, y: torch.Tensor
) -> Tuple[bool, torch.Tensor]: ) -> tuple[bool, torch.Tensor]:
z = x + 1 z = x + 1
return not (x is not y), z return not (x is not y), z
@ -573,7 +573,7 @@ class TestConverter(TestCase):
return x + y return x + y
class MUnpackTuple(torch.nn.Module): class MUnpackTuple(torch.nn.Module):
def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]): def forward(self, x_tuple: tuple[torch.Tensor, torch.Tensor]):
x, y = x_tuple x, y = x_tuple
x = x.cos() x = x.cos()
return x + y return x + y
@ -904,7 +904,7 @@ class TestConverter(TestCase):
return x.dtype in [torch.int8] return x.dtype in [torch.int8]
class MTensorIn(torch.nn.Module): class MTensorIn(torch.nn.Module):
def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]): def forward(self, x: torch.Tensor, x_dict: dict[torch.Tensor, str]):
return x in x_dict return x in x_dict
# Traced function must return output that has tensors. # Traced function must return output that has tensors.
@ -1118,14 +1118,14 @@ class TestConverter(TestCase):
def test_prim_tolist(self): def test_prim_tolist(self):
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, x: torch.Tensor) -> List[int]: def forward(self, x: torch.Tensor) -> list[int]:
return x.tolist() return x.tolist()
inp = (torch.tensor([1, 2, 3]),) inp = (torch.tensor([1, 2, 3]),)
self._check_equal_ts_ep_converter(Module(), inp, ["script"]) self._check_equal_ts_ep_converter(Module(), inp, ["script"])
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, x: torch.Tensor) -> List[List[int]]: def forward(self, x: torch.Tensor) -> list[list[int]]:
return x.tolist() return x.tolist()
inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),) inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
@ -1353,7 +1353,7 @@ class TestConverter(TestCase):
def test_aten_append_t(self): def test_aten_append_t(self):
class M(torch.nn.Module): class M(torch.nn.Module):
def forward(self, x: List[torch.Tensor]): def forward(self, x: list[torch.Tensor]):
out = [] out = []
out.append(x[0] + x[1]) out.append(x[0] + x[1])
out.append(x[0] - x[1]) out.append(x[0] - x[1])
@ -1381,7 +1381,7 @@ class TestConverter(TestCase):
self._check_equal_ts_ep_converter(M1(), inp, ["script"]) self._check_equal_ts_ep_converter(M1(), inp, ["script"])
def test_ts2ep_with_loop(self): def test_ts2ep_with_loop(self):
def func1(x, x_list: List[torch.Tensor]): def func1(x, x_list: list[torch.Tensor]):
a, b, c = x, x, x a, b, c = x, x, x
for _ in range(1, 5, 2): for _ in range(1, 5, 2):
for k in range(5): for k in range(5):

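The converter tests above run the same modules through torch.jit.script, so the builtin-generic annotations also have to be understood by the TorchScript frontend. A minimal sketch of that pattern, assuming this PyTorch build accepts list[...] in scripted forward signatures (the module name and tensor shapes below are illustrative, not taken from the test suite):

import torch

class CatPairs(torch.nn.Module):  # hypothetical module mirroring test_aten_append_t above
    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
        out = [x[0] + x[1], x[0] - x[1]]
        return torch.cat(out)

scripted = torch.jit.script(CatPairs())  # the ["script"] variant of the converter tests
print(scripted([torch.randn(2), torch.randn(2)]))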

@ -2,7 +2,6 @@
import copy import copy
import tempfile import tempfile
import unittest import unittest
from typing import List, Tuple
import torch import torch
from torch.export import Dim, export from torch.export import Dim, export
@ -325,7 +324,7 @@ class TestDraftExport(TestCase):
return torch.ops.mylib.foo(a) return torch.ops.mylib.foo(a)
@torch.library.custom_op("mylib::foo", mutates_args={}) @torch.library.custom_op("mylib::foo", mutates_args={})
def foo(a: torch.Tensor) -> List[torch.Tensor]: def foo(a: torch.Tensor) -> list[torch.Tensor]:
x = a * 2 x = a * 2
y = a.repeat(2, 2) y = a.repeat(2, 2)
z = a.to(torch.bfloat16) z = a.to(torch.bfloat16)
@ -370,7 +369,7 @@ class TestDraftExport(TestCase):
return torch.ops.mylib.foo(a) return torch.ops.mylib.foo(a)
@torch.library.custom_op("mylib::foo", mutates_args={}) @torch.library.custom_op("mylib::foo", mutates_args={})
def foo(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def foo(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return a * 2, a + 2 return a * 2, a + 2
@foo.register_fake @foo.register_fake

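The TestDraftExport hunks above register custom ops whose Python return annotations drive schema inference, so the PEP 585 spelling has to be accepted there as well. A small sketch under that assumption; the op name mylib::scale_pair and its fake kernel are made up for illustration:

import torch

@torch.library.custom_op("mylib::scale_pair", mutates_args=())  # hypothetical op
def scale_pair(a: torch.Tensor) -> list[torch.Tensor]:
    # The builtin-generic return annotation is what schema inference reads.
    return [a * 2, a + 2]

@scale_pair.register_fake
def _(a: torch.Tensor) -> list[torch.Tensor]:
    return [torch.empty_like(a), torch.empty_like(a)]

doubled, shifted = scale_pair(torch.randn(3))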

@ -1,7 +1,7 @@
# Owner(s): ["oncall: export"] # Owner(s): ["oncall: export"]
import unittest import unittest
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple from typing import Any, Optional
import torch import torch
from torch._export.passes.lift_constants_pass import ( from torch._export.passes.lift_constants_pass import (
@ -34,9 +34,9 @@ class GraphBuilder:
self.graph = torch.fx.Graph() self.graph = torch.fx.Graph()
self.nodes = {} self.nodes = {}
self.values = {} self.values = {}
self.nn_module_stack_key: Dict[str, int] = {} self.nn_module_stack_key: dict[str, int] = {}
self.latest_id = 0 self.latest_id = 0
self.input_to_kind: Dict[torch.fx.Node, InputKind] = {} self.input_to_kind: dict[torch.fx.Node, InputKind] = {}
def input(self, name: str, value: torch.Tensor, kind: InputKind): def input(self, name: str, value: torch.Tensor, kind: InputKind):
node = self.graph.placeholder(name) node = self.graph.placeholder(name)
@ -87,7 +87,7 @@ class GraphBuilder:
def create_nn_module_stack( def create_nn_module_stack(
self, module_fqn: str self, module_fqn: str
) -> OrderedDict[int, Tuple[str, type]]: ) -> OrderedDict[int, tuple[str, type]]:
cur_name = "" cur_name = ""
nn_module_stack = OrderedDict() nn_module_stack = OrderedDict()
for atom in module_fqn.split("."): for atom in module_fqn.split("."):

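The GraphBuilder hunks above subscript collections.OrderedDict directly. PEP 585 covers the collections classes as well as the builtins, so this works at runtime on Python 3.9+ without typing.OrderedDict. A hedged sketch; build_module_stack is a made-up helper shaped like create_nn_module_stack:

from collections import OrderedDict

def build_module_stack(module_fqn: str) -> OrderedDict[int, tuple[str, type]]:
    stack: OrderedDict[int, tuple[str, type]] = OrderedDict()
    for idx, atom in enumerate(module_fqn.split(".")):
        stack[idx] = (atom, object)  # placeholder class; the real pass records module types
    return stack

print(build_module_stack("encoder.layer0.linear"))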

@ -8,7 +8,6 @@ import math
import operator import operator
import unittest import unittest
from re import escape from re import escape
from typing import List, Set
import torch import torch
from functorch.experimental.control_flow import cond from functorch.experimental.control_flow import cond
@ -75,11 +74,11 @@ class _AtenAddOperatorSupport(OperatorSupport):
return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor} return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor}
def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]: def _to_partition_names(partitions: list[Partition]) -> list[set[str]]:
return [{n.name for n in p.nodes} for p in partitions] return [{n.name for n in p.nodes} for p in partitions]
def _get_output_names(gm: torch.fx.GraphModule) -> List[str]: def _get_output_names(gm: torch.fx.GraphModule) -> list[str]:
output_node = next(n for n in gm.graph.nodes if n.op == "output") output_node = next(n for n in gm.graph.nodes if n.op == "output")
args = pytree.tree_leaves(output_node.args) args = pytree.tree_leaves(output_node.args)
# if isinstance(args, tuple) and len(args) == 1: # if isinstance(args, tuple) and len(args) == 1:


@ -12,7 +12,7 @@ import unittest
import warnings import warnings
from contextlib import ContextDecorator, nullcontext from contextlib import ContextDecorator, nullcontext
from functools import partial, wraps from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Optional, Union
from unittest.mock import patch from unittest.mock import patch
from common_utils import decorate, decorateForModules, skip, skipOps, xfail from common_utils import decorate, decorateForModules, skip, skipOps, xfail
@ -319,8 +319,8 @@ class TestAOTAutograd(AOTTestCase):
def run_autograd( def run_autograd(
self, self,
f: Callable, f: Callable,
fw_graph_cell: List[Optional[Callable]], fw_graph_cell: list[Optional[Callable]],
decompositions: Optional[Dict], decompositions: Optional[dict],
keep_input_mutations: bool, keep_input_mutations: bool,
dynamic: bool, dynamic: bool,
): ):
@ -358,11 +358,11 @@ class TestAOTAutograd(AOTTestCase):
def verify_aot_autograd( def verify_aot_autograd(
self, self,
f, f,
inp_: Union[Callable, List[Any]], inp_: Union[Callable, list[Any]],
*, *,
test_mutation: bool = False, test_mutation: bool = False,
keep_inp_mutations: bool = False, keep_inp_mutations: bool = False,
decompositions: Optional[Dict] = None, decompositions: Optional[dict] = None,
dynamic: bool = False, dynamic: bool = False,
# Only active when inp_ is Callable. # Only active when inp_ is Callable.
# TODO: probably consolidate all tests to make inp a Callable. # TODO: probably consolidate all tests to make inp a Callable.
@ -6748,8 +6748,8 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
def run_autograd( def run_autograd(
self, self,
f: Callable, f: Callable,
fw_graph_cell: List[Optional[Callable]], fw_graph_cell: list[Optional[Callable]],
decompositions: Optional[Dict], decompositions: Optional[dict],
keep_input_mutations: bool, keep_input_mutations: bool,
dynamic: bool, dynamic: bool,
): ):
@ -6880,8 +6880,8 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
def run_autograd( def run_autograd(
self, self,
f: Callable, f: Callable,
fw_graph_cell: List[Optional[Callable]], fw_graph_cell: list[Optional[Callable]],
decompositions: Optional[Dict], decompositions: Optional[dict],
keep_input_mutations: bool, keep_input_mutations: bool,
dynamic: bool, dynamic: bool,
): ):
@ -6904,11 +6904,11 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
def verify_aot_autograd( def verify_aot_autograd(
self, self,
f, f,
inp_: Union[Callable, List[Any]], inp_: Union[Callable, list[Any]],
*, *,
test_mutation: bool = False, test_mutation: bool = False,
keep_inp_mutations: bool = False, keep_inp_mutations: bool = False,
decompositions: Optional[Dict] = None, decompositions: Optional[dict] = None,
dynamic: bool = False, dynamic: bool = False,
# Only active when inp_ is Callable. # Only active when inp_ is Callable.
# TODO: probably consolidate all tests to make inp a Callable. # TODO: probably consolidate all tests to make inp a Callable.


@ -24,7 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
""" """
from typing import Any, Callable, Dict from typing import Any, Callable
from unittest import mock from unittest import mock
from functorch.einops._parsing import ( from functorch.einops._parsing import (
@ -206,7 +206,7 @@ class TestParsedExpression(TestCase):
class TestParsingUtils(TestCase): class TestParsingUtils(TestCase):
def test_parse_pattern_number_of_arrows(self) -> None: def test_parse_pattern_number_of_arrows(self) -> None:
axes_lengths: Dict[str, int] = {} axes_lengths: dict[str, int] = {}
too_many_arrows_pattern = "a -> b -> c -> d" too_many_arrows_pattern = "a -> b -> c -> d"
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -220,13 +220,13 @@ class TestParsingUtils(TestCase):
parse_pattern(just_right_arrows, axes_lengths) parse_pattern(just_right_arrows, axes_lengths)
def test_ellipsis_invalid_identifier(self) -> None: def test_ellipsis_invalid_identifier(self) -> None:
axes_lengths: Dict[str, int] = {"a": 1, _ellipsis: 2} axes_lengths: dict[str, int] = {"a": 1, _ellipsis: 2}
pattern = f"a {_ellipsis} -> {_ellipsis} a" pattern = f"a {_ellipsis} -> {_ellipsis} a"
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
parse_pattern(pattern, axes_lengths) parse_pattern(pattern, axes_lengths)
def test_ellipsis_matching(self) -> None: def test_ellipsis_matching(self) -> None:
axes_lengths: Dict[str, int] = {} axes_lengths: dict[str, int] = {}
pattern = "a -> a ..." pattern = "a -> a ..."
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -240,7 +240,7 @@ class TestParsingUtils(TestCase):
parse_pattern(pattern, axes_lengths) parse_pattern(pattern, axes_lengths)
def test_left_parenthesized_ellipsis(self) -> None: def test_left_parenthesized_ellipsis(self) -> None:
axes_lengths: Dict[str, int] = {} axes_lengths: dict[str, int] = {}
pattern = "(...) -> ..." pattern = "(...) -> ..."
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -254,7 +254,7 @@ class MaliciousRepr:
class TestValidateRearrangeExpressions(TestCase): class TestValidateRearrangeExpressions(TestCase):
def test_validate_axes_lengths_are_integers(self) -> None: def test_validate_axes_lengths_are_integers(self) -> None:
axes_lengths: Dict[str, Any] = {"a": 1, "b": 2, "c": 3} axes_lengths: dict[str, Any] = {"a": 1, "b": 2, "c": 3}
pattern = "a b c -> c b a" pattern = "a b c -> c b a"
left, right = parse_pattern(pattern, axes_lengths) left, right = parse_pattern(pattern, axes_lengths)
validate_rearrange_expressions(left, right, axes_lengths) validate_rearrange_expressions(left, right, axes_lengths)
@ -265,7 +265,7 @@ class TestValidateRearrangeExpressions(TestCase):
validate_rearrange_expressions(left, right, axes_lengths) validate_rearrange_expressions(left, right, axes_lengths)
def test_non_unitary_anonymous_axes_raises_error(self) -> None: def test_non_unitary_anonymous_axes_raises_error(self) -> None:
axes_lengths: Dict[str, int] = {} axes_lengths: dict[str, int] = {}
left_non_unitary_axis = "a 2 -> 1 1 a" left_non_unitary_axis = "a 2 -> 1 1 a"
left, right = parse_pattern(left_non_unitary_axis, axes_lengths) left, right = parse_pattern(left_non_unitary_axis, axes_lengths)
@ -278,7 +278,7 @@ class TestValidateRearrangeExpressions(TestCase):
validate_rearrange_expressions(left, right, axes_lengths) validate_rearrange_expressions(left, right, axes_lengths)
def test_identifier_mismatch(self) -> None: def test_identifier_mismatch(self) -> None:
axes_lengths: Dict[str, int] = {} axes_lengths: dict[str, int] = {}
mismatched_identifiers = "a -> a b" mismatched_identifiers = "a -> a b"
left, right = parse_pattern(mismatched_identifiers, axes_lengths) left, right = parse_pattern(mismatched_identifiers, axes_lengths)
@ -291,7 +291,7 @@ class TestValidateRearrangeExpressions(TestCase):
validate_rearrange_expressions(left, right, axes_lengths) validate_rearrange_expressions(left, right, axes_lengths)
def test_unexpected_axes_lengths(self) -> None: def test_unexpected_axes_lengths(self) -> None:
axes_lengths: Dict[str, int] = {"c": 2} axes_lengths: dict[str, int] = {"c": 2}
pattern = "a b -> b a" pattern = "a b -> b a"
left, right = parse_pattern(pattern, axes_lengths) left, right = parse_pattern(pattern, axes_lengths)


@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
""" """
from typing import List, Tuple
import numpy as np import numpy as np
@ -34,7 +33,7 @@ from functorch.einops import rearrange
from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.common_utils import run_tests, TestCase
identity_patterns: List[str] = [ identity_patterns: list[str] = [
"...->...", "...->...",
"a b c d e-> a b c d e", "a b c d e-> a b c d e",
"a b c d e ...-> ... a b c d e", "a b c d e ...-> ... a b c d e",
@ -45,7 +44,7 @@ identity_patterns: List[str] = [
"a ... c d e -> a (...) c d e", "a ... c d e -> a (...) c d e",
] ]
equivalent_rearrange_patterns: List[Tuple[str, str]] = [ equivalent_rearrange_patterns: list[tuple[str, str]] = [
("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "), ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "),
("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"), ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"),
("a b c d e -> a b c d e", "... -> ... "), ("a b c d e -> a b c d e", "... -> ... "),
@ -149,7 +148,7 @@ class TestRearrange(TestCase):
def test_concatenations_and_stacking(self) -> None: def test_concatenations_and_stacking(self) -> None:
for n_arrays in [1, 2, 5]: for n_arrays in [1, 2, 5]:
shapes: List[List[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6] shapes: list[list[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
for shape in shapes: for shape in shapes:
arrays1 = [ arrays1 = [
torch.arange(i, i + np.prod(shape, dtype=int)).reshape(shape) torch.arange(i, i + np.prod(shape, dtype=int)).reshape(shape)


@ -1,7 +1,7 @@
# Owner(s): ["module: fx"] # Owner(s): ["module: fx"]
import copy import copy
import unittest import unittest
from typing import Optional, Set, Type from typing import Optional
import torch import torch
import torch.fx import torch.fx
@ -38,7 +38,7 @@ class TestDCE(TestCase):
self, self,
m: torch.nn.Module, m: torch.nn.Module,
expect_dce_changes: bool, expect_dce_changes: bool,
modules_to_be_leafs: Optional[Set[Type]] = None, modules_to_be_leafs: Optional[set[type]] = None,
custom: bool = False, custom: bool = False,
): ):
class TestTracer(torch.fx.Tracer): class TestTracer(torch.fx.Tracer):


@ -2,8 +2,6 @@
from __future__ import annotations # type: ignore[attr-defined] from __future__ import annotations # type: ignore[attr-defined]
import typing
import torch import torch
from torch.fx import symbolic_trace from torch.fx import symbolic_trace
@ -27,13 +25,13 @@ class M2(torch.nn.Module):
# Non-torch annotation with no internal forward references # Non-torch annotation with no internal forward references
class M3(torch.nn.Module): class M3(torch.nn.Module):
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor: def forward(self, x: list[torch.Tensor], a: A) -> torch.Tensor:
return a(x[0]) return a(x[0])
# Non-torch annotation with internal forward references # Non-torch annotation with internal forward references
class M4(torch.nn.Module): class M4(torch.nn.Module):
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor: def forward(self, x: list[torch.Tensor], a: A) -> torch.Tensor:
return a(x[0]) return a(x[0])

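The file above keeps from __future__ import annotations, so under PEP 563 the new builtin-generic annotations are never evaluated at runtime and symbolic_trace only ever sees the strings. A minimal sketch of that combination (the module M below is hypothetical, in the spirit of M3/M4 above):

from __future__ import annotations  # PEP 563: annotations stay as unevaluated strings

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x: list[torch.Tensor]) -> torch.Tensor:
        return x[0] * 2

traced = symbolic_trace(M())
print(traced([torch.ones(2)]))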

@ -1,7 +1,6 @@
# Owner(s): ["module: fx"] # Owner(s): ["module: fx"]
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Tuple
import torch import torch
from torch.fx.passes.split_utils import split_by_tags from torch.fx.passes.split_utils import split_by_tags
@ -63,8 +62,8 @@ class TestSplitByTags(TestCase):
@staticmethod @staticmethod
def trace_and_tag( def trace_and_tag(
module: torch.nn.Module, tags: List[str] module: torch.nn.Module, tags: list[str]
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]: ) -> tuple[torch.fx.GraphModule, dict[str, list[str]]]:
""" """
Test simple gm consists of nodes with tag (only show call_module nodes here): Test simple gm consists of nodes with tag (only show call_module nodes here):
linear1 - tag: "red" linear1 - tag: "red"
@ -167,8 +166,8 @@ class TestSplitOutputType(TestCase):
@staticmethod @staticmethod
def trace_and_tag( def trace_and_tag(
module: torch.nn.Module, inputs: torch.Tensor, tags: List[str] module: torch.nn.Module, inputs: torch.Tensor, tags: list[str]
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]: ) -> tuple[torch.fx.GraphModule, dict[str, list[str]]]:
""" """
Test simple gm consists of nodes with tag (only show call_module nodes here): Test simple gm consists of nodes with tag (only show call_module nodes here):
conv - tag: "red" conv - tag: "red"


@ -4,7 +4,7 @@
import unittest import unittest
from collections import deque from collections import deque
from functools import partial from functools import partial
from typing import List, TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
import torch._dynamo import torch._dynamo
@ -390,7 +390,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
return output return output
def add_hooks(module, config): def add_hooks(module, config):
handles: List[RemovableHandle] = [] handles: list[RemovableHandle] = []
q = deque([(module.__class__.__name__, module)]) q = deque([(module.__class__.__name__, module)])
while q: while q:
name, m = q.pop() name, m = q.pop()


@ -5,7 +5,6 @@ import os
import sys import sys
import tempfile import tempfile
import unittest import unittest
from typing import Dict, Tuple
from unittest import skip from unittest import skip
import torch import torch
@ -1744,7 +1743,7 @@ class AOTInductorTestsTemplate:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def forward(self, x: Dict[str, torch.Tensor]): def forward(self, x: dict[str, torch.Tensor]):
device = next(iter(x.values())).device device = next(iter(x.values())).device
add_ = torch.zeros(5, device=device) add_ = torch.zeros(5, device=device)
mul_ = torch.ones(5, device=device) mul_ = torch.ones(5, device=device)
@ -2660,7 +2659,7 @@ class AOTInductorTestsTemplate:
def forward( def forward(
self, self,
self_tensor: torch.Tensor, self_tensor: torch.Tensor,
indices: Tuple[torch.Tensor], indices: tuple[torch.Tensor],
values: torch.Tensor, values: torch.Tensor,
): ):
return torch.index_put( return torch.index_put(
@ -4285,7 +4284,7 @@ def fail_cpu(is_skip=False):
) )
def fail_gpu(suffixes: Tuple[str, ...], is_skip=False): def fail_gpu(suffixes: tuple[str, ...], is_skip=False):
return TestFailure( return TestFailure(
suffixes, suffixes,
is_skip=is_skip, is_skip=is_skip,


@ -4,7 +4,7 @@ import pickle
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
from typing import List, Optional, Union from typing import Optional, Union
from unittest import mock from unittest import mock
import torch import torch
@ -1434,7 +1434,7 @@ class TestCudaCompileCommand(TestCase):
with mock.patch("subprocess.check_output") as check_output_mock: with mock.patch("subprocess.check_output") as check_output_mock:
CUDACodeCache.compile("test123.cu", "so", ["-Wsomething"]) CUDACodeCache.compile("test123.cu", "so", ["-Wsomething"])
check_output_mock.assert_called() check_output_mock.assert_called()
cmd_parts: List[str] = check_output_mock.call_args[0][0] cmd_parts: list[str] = check_output_mock.call_args[0][0]
assert cmd_parts[0] == "nvcc", cmd_parts assert cmd_parts[0] == "nvcc", cmd_parts
assert "-Wsomething" in cmd_parts, cmd_parts assert "-Wsomething" in cmd_parts, cmd_parts
assert "-DNDEBUG" in cmd_parts, cmd_parts assert "-DNDEBUG" in cmd_parts, cmd_parts


@ -1,6 +1,6 @@
# Owner(s): ["module: inductor"] # Owner(s): ["module: inductor"]
import unittest import unittest
from typing import Any, Dict, List, Type from typing import Any
import sympy import sympy
@ -160,11 +160,11 @@ class TestFixedConfigs(TestCase):
class MyHeuristics(InductorChoices): class MyHeuristics(InductorChoices):
def triton_kernel_kwargs( def triton_kernel_kwargs(
self, self,
kernel_cls: Type[TritonKernel], kernel_cls: type[TritonKernel],
features: SIMDKernelFeatures, features: SIMDKernelFeatures,
groups: List[sympy.Expr], groups: list[sympy.Expr],
kernel_kwargs: Dict[str, Any], kernel_kwargs: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
**kernel_kwargs, **kernel_kwargs,
"override_cooperative_reduction": cooperative, "override_cooperative_reduction": cooperative,


@ -3,7 +3,7 @@ import logging
import math import math
import os import os
import unittest import unittest
from typing import Callable, List, Optional from typing import Callable, Optional
from unittest import mock from unittest import mock
from torch.export import Dim from torch.export import Dim
@ -114,7 +114,7 @@ class TestCutlassBackend(TestCase):
) as mocked_select_algorithm: ) as mocked_select_algorithm:
Y_compiled = torch.compile(mm, dynamic=False)(a, b) Y_compiled = torch.compile(mm, dynamic=False)(a, b)
Y = mm(a, b) Y = mm(a, b)
passed_choice_callers: List[ChoiceCaller] = mocked_select_algorithm[0][ passed_choice_callers: list[ChoiceCaller] = mocked_select_algorithm[0][
1 1
] ]
assert all( assert all(
@ -573,7 +573,7 @@ class TestCutlassBackend(TestCase):
return torch.addmm(x, a, b, alpha=alpha, beta=beta) return torch.addmm(x, a, b, alpha=alpha, beta=beta)
def compare_results( def compare_results(
m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int] m: int, k: int, n: int, alpha: float, beta: float, x_shape: list[int]
) -> None: ) -> None:
x = torch.randn(x_shape).cuda().half() x = torch.randn(x_shape).cuda().half()
a = torch.randn(m, k).cuda().half() a = torch.randn(m, k).cuda().half()


@ -9,7 +9,7 @@ from collections import namedtuple
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, Optional, Union
from unittest import expectedFailure, skip, skipUnless from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch from unittest.mock import patch
@ -511,7 +511,7 @@ class TestFlexAttention(InductorTestCase):
block_mask, block_mask,
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.float16,
page_size: int = 128, page_size: int = 128,
) -> Tuple[Tensor, Tensor, BlockMask, _score_mod_signature]: ) -> tuple[Tensor, Tensor, BlockMask, _score_mod_signature]:
assert block_mask is not None, "Must provide block_mask" assert block_mask is not None, "Must provide block_mask"
Q_B, Q_H, Q_S, _ = q.shape Q_B, Q_H, Q_S, _ = q.shape
KV_B, KV_H, KV_S, QK_D = k.shape KV_B, KV_H, KV_S, QK_D = k.shape
@ -596,7 +596,7 @@ class TestFlexAttention(InductorTestCase):
v: Tensor, v: Tensor,
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.float16,
block_mask: Optional[BlockMask] = None, block_mask: Optional[BlockMask] = None,
) -> Tuple[Tensor, Tensor]: ) -> tuple[Tensor, Tensor]:
B, Q_H, Q_S, KV_H, KV_S = ( B, Q_H, Q_S, KV_H, KV_S = (
q.shape[0], q.shape[0],
q.shape[1], q.shape[1],
@ -797,7 +797,7 @@ class TestFlexAttention(InductorTestCase):
def run_dynamic_test( def run_dynamic_test(
self, self,
score_mask_mod: Tuple[Callable, Callable], score_mask_mod: tuple[Callable, Callable],
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.float16,
B: int = B, B: int = B,
H: int = H, H: int = H,
@ -1089,7 +1089,7 @@ class TestFlexAttention(InductorTestCase):
@common_utils.parametrize("dtype", test_dtypes_fast) @common_utils.parametrize("dtype", test_dtypes_fast)
@common_utils.parametrize("score_mask_mod", test_score_mask_mod_map.items()) @common_utils.parametrize("score_mask_mod", test_score_mask_mod_map.items())
def test_builtin_score_mods_dynamic( def test_builtin_score_mods_dynamic(
self, dtype: torch.dtype, score_mask_mod: Tuple[Callable, Callable] self, dtype: torch.dtype, score_mask_mod: tuple[Callable, Callable]
): ):
self.run_dynamic_test(score_mask_mod, dtype) self.run_dynamic_test(score_mask_mod, dtype)
@ -1127,7 +1127,7 @@ class TestFlexAttention(InductorTestCase):
self, self,
dtype: torch.dtype, dtype: torch.dtype,
score_mod: Callable, score_mod: Callable,
BLOCK_SIZE: Union[int, Tuple[int, int]], BLOCK_SIZE: Union[int, tuple[int, int]],
): ):
block_mask = create_block_mask( block_mask = create_block_mask(
noop_mask, B, H, S, S, BLOCK_SIZE=BLOCK_SIZE, device=self.device noop_mask, B, H, S, S, BLOCK_SIZE=BLOCK_SIZE, device=self.device
@ -1142,8 +1142,8 @@ class TestFlexAttention(InductorTestCase):
def test_kv_batch_broadcast( def test_kv_batch_broadcast(
self, self,
dtype: torch.dtype, dtype: torch.dtype,
batch_dims: Tuple[int, int], batch_dims: tuple[int, int],
head_dims: Tuple[int, int], head_dims: tuple[int, int],
score_mod: Callable, score_mod: Callable,
): ):
Hq, Hkv = head_dims Hq, Hkv = head_dims
@ -1175,8 +1175,8 @@ class TestFlexAttention(InductorTestCase):
def test_kv_batch_broadcast_causal_mask( def test_kv_batch_broadcast_causal_mask(
self, self,
dtype: torch.dtype, dtype: torch.dtype,
batch_dims: Tuple[int, int], batch_dims: tuple[int, int],
head_dims: Tuple[int, int], head_dims: tuple[int, int],
score_mod: Callable, score_mod: Callable,
): ):
Hq, Hkv = head_dims Hq, Hkv = head_dims
@ -3616,7 +3616,7 @@ class TestBlockMask(InductorTestCase):
@supported_platform @supported_platform
@common_utils.parametrize("BLOCK_SIZE", [32, 64, 128, 256, (32, 64), (64, 32)]) @common_utils.parametrize("BLOCK_SIZE", [32, 64, 128, 256, (32, 64), (64, 32)])
def test_block_size_changes(self, BLOCK_SIZE: Union[int, Tuple[int, int]]): def test_block_size_changes(self, BLOCK_SIZE: Union[int, tuple[int, int]]):
B, H, Q_LEN, KV_LEN = 4, 2, 2048, 2048 B, H, Q_LEN, KV_LEN = 4, 2, 2048, 2048
if isinstance(BLOCK_SIZE, int): if isinstance(BLOCK_SIZE, int):
@ -3990,7 +3990,7 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
) )
def length_to_offsets( def length_to_offsets(
lengths: List[int], device: Union[str, torch.device] lengths: list[int], device: Union[str, torch.device]
) -> Tensor: ) -> Tensor:
offsets = [0] offsets = [0]
offsets.extend(lengths) offsets.extend(lengths)
@ -4561,7 +4561,7 @@ class Params:
return f"batch:{self.batch_size}_head:{self.num_heads}_seq_len:{self.seq_length}_headdim:{self.head_dim}_dtype:{str(self.dtype).split('.')[-1]}" return f"batch:{self.batch_size}_head:{self.num_heads}_seq_len:{self.seq_length}_headdim:{self.head_dim}_dtype:{str(self.dtype).split('.')[-1]}"
def get_params(dtypes: List[torch.dtype]) -> List[Params]: def get_params(dtypes: list[torch.dtype]) -> list[Params]:
params = [] params = []
seq_lengths = [37, 256, 277] seq_lengths = [37, 256, 277]
for seq_len, dtype in product(seq_lengths, dtypes): for seq_len, dtype in product(seq_lengths, dtypes):

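In the flex-attention hunks above, Tuple becomes tuple but Union and Optional stay. That split matches the PEP versions involved: builtin generics are PEP 585 (Python 3.9), while the X | Y union syntax is PEP 604 (Python 3.10), so typing.Union is presumably kept for 3.9 support. A small sketch of the resulting annotation mix; normalize_block_size is a hypothetical helper, not part of the test file:

from typing import Union

def normalize_block_size(block_size: Union[int, tuple[int, int]]) -> tuple[int, int]:
    # Accept either a single size or an explicit (Q_BLOCK, KV_BLOCK) pair.
    if isinstance(block_size, int):
        return (block_size, block_size)
    return block_size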

@ -3,7 +3,7 @@
import functools import functools
from collections import namedtuple from collections import namedtuple
from typing import Callable, Optional, Tuple, Union from typing import Callable, Optional, Union
from unittest import expectedFailure, skipUnless from unittest import expectedFailure, skipUnless
from unittest.mock import patch from unittest.mock import patch
@ -645,7 +645,7 @@ class TestFlexDecoding(InductorTestCase):
self, self,
dtype: torch.dtype, dtype: torch.dtype,
score_mod: Callable, score_mod: Callable,
head_dims: Tuple[int, int], head_dims: tuple[int, int],
page_size: int, page_size: int,
): ):
Hq, Hkv = head_dims Hq, Hkv = head_dims
@ -681,7 +681,7 @@ class TestFlexDecoding(InductorTestCase):
self, self,
dtype: torch.dtype, dtype: torch.dtype,
score_mod: Callable, score_mod: Callable,
BLOCK_SIZE: Union[int, Tuple[int, int]], BLOCK_SIZE: Union[int, tuple[int, int]],
): ):
block_mask = create_block_mask(noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE) block_mask = create_block_mask(noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE)
self.run_test(score_mod, dtype, block_mask=block_mask) self.run_test(score_mod, dtype, block_mask=block_mask)
@ -763,8 +763,8 @@ class TestFlexDecoding(InductorTestCase):
def test_kv_batch_broadcast( def test_kv_batch_broadcast(
self, self,
dtype: torch.dtype, dtype: torch.dtype,
head_dims: Tuple[int, int], head_dims: tuple[int, int],
batch_dims: Tuple[int, int], batch_dims: tuple[int, int],
score_mod: Callable, score_mod: Callable,
): ):
Hq, Hkv = head_dims Hq, Hkv = head_dims


@ -2,7 +2,7 @@
import functools import functools
import unittest import unittest
from typing import List, Tuple, Union from typing import Union
import torch import torch
from torch import Tensor from torch import Tensor
@ -89,8 +89,8 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
def _fix_fp8_dtype_for_rocm( def _fix_fp8_dtype_for_rocm(
dtype: Union[torch.dtype, List[torch.dtype], Tuple[torch.dtype]], device dtype: Union[torch.dtype, list[torch.dtype], tuple[torch.dtype]], device
) -> Union[torch.dtype, List[torch.dtype], Tuple[torch.dtype]]: ) -> Union[torch.dtype, list[torch.dtype], tuple[torch.dtype]]:
# This function is used to change FP8 data types # This function is used to change FP8 data types
# with MI300 supported FP8 types if device is GPU: # with MI300 supported FP8 types if device is GPU:
# e4m3fn -> e4m3fnuz # e4m3fn -> e4m3fnuz


@ -2,7 +2,7 @@
import sys import sys
import unittest import unittest
from typing import List, Literal from typing import Literal
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import torch import torch
@ -50,8 +50,8 @@ class TestConfigFuzzer(TestCase):
self.assertEqual(toggle("", bool, True), False) self.assertEqual(toggle("", bool, True), False)
self.assertEqual(toggle("", Literal["foo", "bar"], "foo"), "bar") self.assertEqual(toggle("", Literal["foo", "bar"], "foo"), "bar")
self.assertEqual(toggle("", Literal["foo", "bar"], "bar"), "foo") self.assertEqual(toggle("", Literal["foo", "bar"], "bar"), "foo")
self.assertTrue("bar" in toggle("", List[Literal["foo", "bar"]], ["foo"])) self.assertTrue("bar" in toggle("", list[Literal["foo", "bar"]], ["foo"]))
self.assertTrue("foo" in toggle("", List[Literal["foo", "bar"]], ["bar"])) self.assertTrue("foo" in toggle("", list[Literal["foo", "bar"]], ["bar"]))
@unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported") @unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
def test_sampling_method_random(self): def test_sampling_method_random(self):

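The config-fuzzer assertions above toggle values typed as list[Literal[...]]. Builtin generic aliases expose the same introspection hooks as the typing spellings, which is presumably what lets the fuzzer branch on an annotation's origin and arguments. A quick sketch using only standard-library introspection:

from typing import Literal, get_args, get_origin

annotation = list[Literal["foo", "bar"]]
assert get_origin(annotation) is list
assert get_args(get_args(annotation)[0]) == ("foo", "bar")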

@ -2,7 +2,6 @@
import collections import collections
import unittest import unittest
from typing import List
import torch import torch
import torch._inductor import torch._inductor
@ -39,7 +38,7 @@ class TestHighwaySelfGating(torch.nn.Module):
def forward( def forward(
self, self,
inputs: List[torch.Tensor], inputs: list[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
results = [] results = []
for i in range(self.size): for i in range(self.size):

Some files were not shown because too many files have changed in this diff.
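Every file in this diff applies the same mechanical change: typing.List, Dict, Tuple, Set, and Type become the subscriptable builtins that PEP 585 enables on Python 3.9+, while Any, Callable, Optional, Union, and Literal are still imported from typing throughout. A minimal before/after sketch with made-up names:

from typing import Optional

def bucket(xs: list[int], limit: Optional[int] = None) -> dict[str, list[int]]:
    # Previously these annotations would have been List[int] and Dict[str, List[int]].
    cutoff = limit if limit is not None else 0
    return {"small": [x for x in xs if x <= cutoff], "large": [x for x in xs if x > cutoff]}

print(bucket([1, 5, 10], limit=4))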