PEP585 update - test (#145176)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145176
Approved by: https://github.com/bobrenjc93
parent: 40e27fbcf2
commit: 99dbc5b0e2
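The change below is mechanical: annotations that use the typing aliases List, Tuple, Dict, Set, and Type are rewritten to the builtin generics standardized by PEP 585, which are subscriptable on Python 3.9 and later (Optional and Union are still imported from typing). A minimal sketch of the pattern, using an illustrative function that is not part of this diff:

from typing import Optional

# Pre-PEP 585 style (what the diff removes):
#   def scale(xs: List[float], factor: Optional[Tuple[int, int]]) -> Dict[str, List[float]]: ...
# PEP 585 style (what the diff adds): the builtin list/tuple/dict are subscripted directly.
def scale(xs: list[float], factor: Optional[tuple[int, int]] = None) -> dict[str, list[float]]:
    f = 1.0 if factor is None else factor[0] / factor[1]
    return {"scaled": [x * f for x in xs]}

print(scale([1.0, 2.0], (3, 2)))  # {'scaled': [1.5, 3.0]}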
@@ -7,7 +7,7 @@ import re
-from typing import Any, List, Tuple
+from typing import Any

@@ -67,7 +67,7 @@ def generate_callgrind_artifacts() -> None:
-Tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
+tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]

@@ -85,9 +85,9 @@ def load_callgrind_artifacts() -> (
-count_strings: List[str], inclusive: bool
+count_strings: list[str], inclusive: bool
-data: List[benchmark_utils.FunctionCount] = []
+data: list[benchmark_utils.FunctionCount] = []
@@ -7,7 +7,7 @@ import sys
-from typing import Any, List, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union

@@ -241,7 +241,7 @@ def pytest_report_teststatus(report, config):
-def pytest_collection_modifyitems(items: List[Any]) -> None:
+def pytest_collection_modifyitems(items: list[Any]) -> None:

@@ -304,7 +304,7 @@ class StepcurrentPlugin:
-def pytest_collection_modifyitems(self, config: Config, items: List[Any]) -> None:
+def pytest_collection_modifyitems(self, config: Config, items: list[Any]) -> None:
@@ -1,5 +1,3 @@
-from typing import List
-

@@ -8,7 +6,7 @@ lib = torch.library._scoped_library("python_agnostic", "FRAGMENT")
-def ultra_norm(inputs: List[Tensor]) -> Tensor:
+def ultra_norm(inputs: list[Tensor]) -> Tensor:
@@ -2,7 +2,7 @@
-from typing import List, Optional, Union
+from typing import Optional, Union

@@ -87,7 +87,7 @@ class TestInferSchemaWithAnnotation(TestCase):
-def foo_op_3(x: typing.List[int]) -> int:
+def foo_op_3(x: list[int]) -> int:

@@ -99,7 +99,7 @@ class TestInferSchemaWithAnnotation(TestCase):
-def foo_op_5(x: typing.Optional[typing.List[int]]) -> int:
+def foo_op_5(x: typing.Optional[list[int]]) -> int:

@@ -136,7 +136,7 @@ class TestInferSchemaWithAnnotation(TestCase):
-def foo_op_4(x: List[int]) -> types.Number:
+def foo_op_4(x: list[int]) -> types.Number:

@@ -154,7 +154,7 @@ class TestInferSchemaWithAnnotation(TestCase):
-def foo_op_7(x: List[int]) -> int:
+def foo_op_7(x: list[int]) -> int:

@@ -166,7 +166,7 @@ class TestInferSchemaWithAnnotation(TestCase):
-def foo_op_9(x: Optional[List[int]]) -> int:
+def foo_op_9(x: Optional[list[int]]) -> int:
@@ -5,7 +5,7 @@ import copy
-from typing import Any, List, Optional, Type, Union
+from typing import Any, Optional, Union

@@ -117,7 +117,7 @@ class TestFullyShardAutograd(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -141,7 +141,7 @@ class TestFullyShardAutograd(FSDPTest):
-def _test_nontensor_activations(self, container_type: Type):
+def _test_nontensor_activations(self, container_type: type):

@@ -170,7 +170,7 @@ class TestFullyShardAutograd(FSDPTest):
-def __init__(self, container_type: Type):
+def __init__(self, container_type: type):

@@ -190,7 +190,7 @@ class TestFullyShardAutograd(FSDPTest):
-def __init__(self, container_type: Type):
+def __init__(self, container_type: type):

@@ -227,7 +227,7 @@ class TestFullyShardAutograd(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []
@@ -4,7 +4,7 @@ import copy
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Optional, Union

@@ -58,7 +58,7 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
-EventType = Tuple[str, str, TrainingState]
+EventType = tuple[str, str, TrainingState]

@@ -70,7 +70,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
-def _get_param_sizes(self) -> List[torch.Size]:
+def _get_param_sizes(self) -> list[torch.Size]:

@@ -84,7 +84,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
-def _init_params(self, param_sizes: List[torch.Size]) -> List[nn.Parameter]:
+def _init_params(self, param_sizes: list[torch.Size]) -> list[nn.Parameter]:

@@ -96,7 +96,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
-self, params: List[nn.Parameter], reshard_after_forward: Union[bool, int]
+self, params: list[nn.Parameter], reshard_after_forward: Union[bool, int]

@@ -143,7 +143,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
-param_sizes: List[torch.Size],
+param_sizes: list[torch.Size],

@@ -165,7 +165,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
-orig_params: List[nn.Parameter], module: nn.Module
+orig_params: list[nn.Parameter], module: nn.Module

@@ -228,7 +228,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
-param_sizes: List[torch.Size],
+param_sizes: list[torch.Size],

@@ -453,7 +453,7 @@ class TestFullyShardPrefetch(FSDPTest):
-events: List[EventType] = []
+events: list[EventType] = []

@@ -504,7 +504,7 @@ class TestFullyShardPrefetch(FSDPTest):
-events: List[EventType] = []
+events: list[EventType] = []

@@ -582,7 +582,7 @@ class TestFullyShardPrefetch(FSDPTest):
-events: List[EventType] = []
+events: list[EventType] = []

@@ -652,7 +652,7 @@ class TestFullyShardPrefetch(FSDPTest):
-events: List[EventType] = []
+events: list[EventType] = []

@@ -742,7 +742,7 @@ class TestFullyShardPrefetch(FSDPTest):
-events: List[EventType] = []
+events: list[EventType] = []

@@ -834,7 +834,7 @@ class TestFullyShardPrefetch(FSDPTest):
-events: List[EventType] = []
+events: list[EventType] = []

@@ -915,7 +915,7 @@ class TestFullyShardPrefetch(FSDPTest):
-events: List[EventType] = []
+events: list[EventType] = []

@@ -1011,7 +1011,7 @@ class TestFullyShardPrefetch(FSDPTest):
-self, orig_unshard: Callable, events: List[EventType]
+self, orig_unshard: Callable, events: list[EventType]

@@ -1025,7 +1025,7 @@ class TestFullyShardPrefetch(FSDPTest):
-self, orig_reshard: Callable, events: List[EventType]
+self, orig_reshard: Callable, events: list[EventType]

@@ -1040,7 +1040,7 @@ class TestFullyShardPrefetch(FSDPTest):
-self, orig_post_backward: Callable, events: List[EventType]
+self, orig_post_backward: Callable, events: list[EventType]

@@ -1080,7 +1080,7 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
-def forward(self, ys: List[torch.Tensor], works: List[dist.Work]):
+def forward(self, ys: list[torch.Tensor], works: list[dist.Work]):

@@ -1126,7 +1126,7 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []
@@ -6,7 +6,7 @@ import functools
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

@@ -29,7 +29,7 @@ from torch.testing._internal.two_tensor import TwoTensor
-) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+) -> tuple[tuple[torch.Tensor, ...], Any]:

@@ -39,10 +39,10 @@ def two_tensor_fsdp_pre_all_gather_v2(
-outer_stride: Tuple[int, ...],
+outer_stride: tuple[int, ...],
-) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+) -> tuple[tuple[torch.Tensor, ...], Any]:

@@ -50,12 +50,12 @@ def two_tensor_fsdp_pre_all_gather_v2(
-all_gather_outputs: Tuple[torch.Tensor, ...],
+all_gather_outputs: tuple[torch.Tensor, ...],
-) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
+) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:

@@ -96,10 +96,10 @@ class BFloat16AllGatherTensor(torch.Tensor):
-outer_stride: Tuple[int, ...],
+outer_stride: tuple[int, ...],
-) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+) -> tuple[tuple[torch.Tensor, ...], Any]:

@@ -116,12 +116,12 @@ class BFloat16AllGatherTensor(torch.Tensor):
-all_gather_outputs: Tuple[torch.Tensor, ...],
+all_gather_outputs: tuple[torch.Tensor, ...],
-) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
+) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:

@@ -157,7 +157,7 @@ class BFloat16AllGatherTensor(torch.Tensor):
-inner_tensors, outer_size: torch.Size, outer_stride: Tuple[int, ...]
+inner_tensors, outer_size: torch.Size, outer_stride: tuple[int, ...]

@@ -236,7 +236,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -314,10 +314,10 @@ class TestFullyShardAllGatherExtensionsMultiThread(
-outer_stride: Tuple[int, ...],
+outer_stride: tuple[int, ...],
-) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+) -> tuple[tuple[torch.Tensor, ...], Any]:

@@ -325,12 +325,12 @@ class TestFullyShardAllGatherExtensionsMultiThread(
-all_gather_outputs: Tuple[torch.Tensor, ...],
+all_gather_outputs: tuple[torch.Tensor, ...],
-) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
+) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:

@@ -416,10 +416,10 @@ class TestFullyShardAllGatherExtensionsMultiThread(
-outer_stride: Tuple[int, ...],
+outer_stride: tuple[int, ...],
-) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+) -> tuple[tuple[torch.Tensor, ...], Any]:

@@ -427,12 +427,12 @@ class TestFullyShardAllGatherExtensionsMultiThread(
-all_gather_outputs: Tuple[torch.Tensor, ...],
+all_gather_outputs: tuple[torch.Tensor, ...],
-) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
+) -> Union[tuple[torch.Tensor, tuple[torch.Tensor, ...]], None]:
@@ -3,7 +3,7 @@
-from typing import List, Union
+from typing import Union

@@ -116,7 +116,7 @@ class TestFullyShardFrozen(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -151,7 +151,7 @@ class TestFullyShardFrozen(FSDPTest):
-modules: List[nn.Module] = []
+modules: list[nn.Module] = []

@@ -187,7 +187,7 @@ class TestFullyShardFrozen(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -251,7 +251,7 @@ class TestFullyShardFrozen(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []
@@ -3,7 +3,7 @@
-from typing import List, Optional
+from typing import Optional

@@ -211,8 +211,8 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
-managed_modules: List[nn.Module],
-expected_managed_modules: List[nn.Module],
+managed_modules: list[nn.Module],
+expected_managed_modules: list[nn.Module],

@@ -262,10 +262,10 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
-managed_params: List[nn.Parameter],
-managed_buffers: List[torch.Tensor],
-expected_managed_params: List[nn.Parameter],
-expected_managed_buffers: List[torch.Tensor],
+managed_params: list[nn.Parameter],
+managed_buffers: list[torch.Tensor],
+expected_managed_params: list[nn.Parameter],
+expected_managed_buffers: list[torch.Tensor],

@@ -370,7 +370,7 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
-self, orig_params: List[nn.Parameter], sharded_params: List[nn.Parameter]
+self, orig_params: list[nn.Parameter], sharded_params: list[nn.Parameter]
@@ -2,7 +2,7 @@
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union

@@ -339,7 +339,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -391,7 +391,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
-forward_inputs: Dict[str, nn.Module] = {}
+forward_inputs: dict[str, nn.Module] = {}

@@ -405,7 +405,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
-forward_inputs: Dict[nn.Module, torch.Tensor] = {}
+forward_inputs: dict[nn.Module, torch.Tensor] = {}

@@ -423,7 +423,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
-forward_inputs: Dict[nn.Module, torch.Tensor] = {}
+forward_inputs: dict[nn.Module, torch.Tensor] = {}

@@ -448,7 +448,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
-def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
+def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:

@@ -459,7 +459,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
-def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
+def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:

@@ -472,7 +472,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
-forward_inputs: Dict[str, torch.Tensor] = {}
+forward_inputs: dict[str, torch.Tensor] = {}
@@ -4,7 +4,7 @@ import copy
-from typing import Dict, Optional
+from typing import Optional

@@ -308,7 +308,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
-new_state_dict: Dict[str, DTensor] = {}
+new_state_dict: dict[str, DTensor] = {}
@@ -7,7 +7,7 @@ import itertools
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

@@ -65,7 +65,7 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
-def forward(self, x: torch.Tensor, ys: Tuple[torch.Tensor, ...]):
+def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]):

@@ -224,7 +224,7 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -281,7 +281,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
-self, lin_shapes: List[Tuple[int, int]], use_shard_placement_fn: bool
+self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool

@@ -308,7 +308,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -461,7 +461,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -554,7 +554,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -592,7 +592,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -623,8 +623,8 @@ class TestFullyShard1DTrainingCore(FSDPTest):
-ref_losses: List[torch.Tensor] = []
-losses: List[torch.Tensor] = []
+ref_losses: list[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -736,7 +736,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -886,7 +886,7 @@ class TestFullyShardSharedParams(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -1009,7 +1009,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -1125,8 +1125,8 @@ class TestFullyShardGradientAccumulation(FSDPTest):
-losses: List[torch.Tensor] = []
-ref_losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []
+ref_losses: list[torch.Tensor] = []

@@ -1210,7 +1210,7 @@ class TestFullyShardNDTraining(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -1281,7 +1281,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -1360,7 +1360,7 @@ class TestFullyShardHSDPTraining(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []
@@ -5,7 +5,6 @@ from collections import deque, OrderedDict
-from typing import Tuple

@@ -70,7 +69,7 @@ class MultiOutputModel(nn.Module):
-def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

@@ -82,7 +81,7 @@ class MultiInputModel(nn.Module):
-def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+def forward(self, xs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
@@ -4,7 +4,7 @@ import copy
-from typing import List, Optional, Type
+from typing import Optional

@@ -166,7 +166,7 @@ class TestFullyShard2DTraining(FSDPTest):
-losses: List[torch.Tensor] = []
+losses: list[torch.Tensor] = []

@@ -335,7 +335,7 @@ class TestFullyShard2DTraining(FSDPTest):
-optimizer_class: Type[torch.optim.Optimizer],
+optimizer_class: type[torch.optim.Optimizer],
@@ -1,7 +1,6 @@
-from typing import List, Tuple

@@ -28,12 +27,12 @@ class TestContract(TestCase):
-module: nn.Module, inp: Tuple[torch.Tensor]
-) -> Tuple[torch.Tensor]:
+module: nn.Module, inp: tuple[torch.Tensor]
+) -> tuple[torch.Tensor]:
-module: nn.Module, inp: Tuple[torch.Tensor], out: torch.Tensor
+module: nn.Module, inp: tuple[torch.Tensor], out: torch.Tensor

@@ -44,9 +43,9 @@ class TestContract(TestCase):
-grad_input: Tuple[torch.Tensor],
+grad_input: tuple[torch.Tensor],
-) -> Tuple[torch.Tensor]:
+) -> tuple[torch.Tensor]:

@@ -92,8 +91,8 @@ class TestContract(TestCase):
-module: nn.Module, inp: Tuple[torch.Tensor]
-) -> Tuple[torch.Tensor]:
+module: nn.Module, inp: tuple[torch.Tensor]
+) -> tuple[torch.Tensor]:

@@ -139,7 +138,7 @@ class TestContract(TestCase):
-def multi_module_api(modules: List[nn.Module]) -> nn.Module:
+def multi_module_api(modules: list[nn.Module]) -> nn.Module:
@@ -6,7 +6,6 @@ import itertools
-from typing import List

@@ -3186,7 +3185,7 @@ class TestCreateTensorNoProcessGroupMode(TestCase):
-st_local_shards: List[Shard] = []
+st_local_shards: list[Shard] = []

@@ -3215,7 +3214,7 @@ class TestCreateTensorNoProcessGroupMode(TestCase):
-st_local_shards: List[Shard] = []
+st_local_shards: list[Shard] = []
@@ -1,7 +1,7 @@
-from typing import List, Union
+from typing import Union

@@ -495,7 +495,7 @@ class TestShardingSpec(TestCase):
-placements: List[Union[torch.distributed._remote_device, str]]
+placements: list[Union[torch.distributed._remote_device, str]]
@@ -1,7 +1,6 @@
-from typing import Tuple

@@ -161,7 +160,7 @@ class TestMemTracker(TestCase):
-) -> Tuple[int, int, int]:
+) -> tuple[int, int, int]:

@@ -179,7 +178,7 @@ class TestMemTracker(TestCase):
-) -> Tuple[int, int, int]:
+) -> tuple[int, int, int]:
@@ -1,7 +1,7 @@
-from typing import Any, Callable, cast, Tuple, Union
+from typing import Any, Callable, cast, Union

@@ -73,7 +73,7 @@ class TestRuntimeEstimator(TestCase):
-args: Tuple[Any, ...],
+args: tuple[Any, ...],

@@ -92,7 +92,7 @@ class TestRuntimeEstimator(TestCase):
-args: Tuple[Any, ...],
+args: tuple[Any, ...],

@@ -106,7 +106,7 @@ class TestRuntimeEstimator(TestCase):
-) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]:
+) -> tuple[nn.Module, optim.Optimizer, torch.Tensor]:
@@ -1,7 +1,6 @@
-from typing import Tuple

@@ -40,7 +39,7 @@ class TestSACILP(TestCase):
-) -> Tuple[torch.nn.Module, torch.optim.Optimizer, torch.Tensor]:
+) -> tuple[torch.nn.Module, torch.optim.Optimizer, torch.Tensor]:
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
-from typing import Any, Dict, List
+from typing import Any

@@ -95,9 +95,9 @@ class ModelType(Enum):
-losses: List[float] = field(default_factory=list)
+losses: list[float] = field(default_factory=list)
-def state_dict(self) -> Dict[str, Any]:
+def state_dict(self) -> dict[str, Any]:

@@ -284,7 +284,7 @@ class TestE2ESaveAndLoad(DTensorTestBase, VerifyStateDictMixin):
-class StateDict(Dict):
+class StateDict(dict):
@@ -2,7 +2,7 @@
-from typing import cast, List, Optional, Union
+from typing import cast, Optional, Union

@@ -177,17 +177,17 @@ class FaultyStorageWriter(TestStorageBase, StorageWriter):
-def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
+def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
-) -> Future[List[WriteResult]]:
+) -> Future[list[WriteResult]]:
-def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
+def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:

@@ -210,7 +210,7 @@ class FaultyStorageReader(TestStorageBase, StorageReader):
-def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
+def prepare_global_plan(self, plans: list[LoadPlan]) -> list[LoadPlan]:
@@ -1,5 +1,5 @@
-from typing import Dict, Union
+from typing import Union

@@ -58,14 +58,14 @@ class MyTestModule(torch.nn.Module):
-def get_extra_state(self) -> Dict[str, Union[int, torch._tensor.Tensor]]:
+def get_extra_state(self) -> dict[str, Union[int, torch._tensor.Tensor]]:
-self, state: Dict[str, Union[int, torch._tensor.Tensor]]
+self, state: dict[str, Union[int, torch._tensor.Tensor]]
@@ -4,7 +4,6 @@ import os
-from typing import Dict

@@ -53,8 +52,8 @@ if TEST_WITH_DEV_DBG_ASAN:
-state_dict_1: Dict[str, torch.Tensor],
-state_dict_2: Dict[str, torch.Tensor],
+state_dict_1: dict[str, torch.Tensor],
+state_dict_2: dict[str, torch.Tensor],
@@ -2,7 +2,7 @@
-from typing import Any, Dict, IO
+from typing import Any, IO

@@ -56,8 +56,8 @@ _THREAD_COUNTS = {1, 2}
-state_dict_1: Dict[str, torch.Tensor],
-state_dict_2: Dict[str, torch.Tensor],
+state_dict_1: dict[str, torch.Tensor],
+state_dict_2: dict[str, torch.Tensor],

@@ -113,10 +113,10 @@ class BlobState:
-def state_dict(self) -> Dict[str, Any]:
+def state_dict(self) -> dict[str, Any]:
-def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+def load_state_dict(self, state_dict: dict[str, Any]) -> None:
@@ -3,7 +3,7 @@
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Optional

@@ -35,7 +35,7 @@ def with_temp_dir(
-def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None:
+def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
@@ -4,7 +4,7 @@ import copy
-from typing import Callable, Tuple, Type, Union
+from typing import Callable, Union

@@ -154,9 +154,9 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
-wrapping: Tuple[nn.Module] = (),
+wrapping: tuple[nn.Module] = (),
-optimizer_class: Type[Optimizer],
+optimizer_class: type[Optimizer],

@@ -232,7 +232,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
-optimizer_class: Type[Optimizer],
+optimizer_class: type[Optimizer],

@@ -272,7 +272,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
-def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None:
+def _test_ddp(self, use_composable: bool, optimizer_class: type[Optimizer]) -> None:

@@ -303,7 +303,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
-optimizer_class: Type[Optimizer],
+optimizer_class: type[Optimizer],

@@ -347,7 +347,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
-def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None:
+def _test_single_gpu(self, optimizer_class: type[Optimizer]) -> None:

@@ -399,7 +399,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
-self, optimizer_class: Type[Optimizer]
+self, optimizer_class: type[Optimizer]
@@ -14,7 +14,7 @@ import signal
-from typing import Any, Dict, List
+from typing import Any

@@ -135,7 +135,7 @@ class TestAgent(SimpleElasticAgent):
-def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
+def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]:

@@ -477,7 +477,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
-def get_worker_assigned(self, store, role_infos_len, info) -> List[Worker]:
+def get_worker_assigned(self, store, role_infos_len, info) -> list[Worker]:
@@ -17,7 +17,7 @@ import time
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Optional

@@ -256,7 +256,7 @@ class Conf:
-args: Tuple = ()
+args: tuple = ()

@@ -394,10 +394,10 @@ class LocalElasticAgentTest(unittest.TestCase):
-node_configs: List[Conf],
+node_configs: list[Conf],
-) -> Dict[str, List[RunResult]]:
+) -> dict[str, list[RunResult]]:

@@ -431,7 +431,7 @@ class LocalElasticAgentTest(unittest.TestCase):
-results: Dict[str, List[RunResult]] = {}
+results: dict[str, list[RunResult]] = {}

@@ -1032,8 +1032,8 @@ class LocalElasticAgentTest(unittest.TestCase):
-run_results: Dict[str, List[RunResult]],
-expected_role_world_sizes: Dict[str, int],
+run_results: dict[str, list[RunResult]],
+expected_role_world_sizes: dict[str, int],

@@ -1042,11 +1042,11 @@ class LocalElasticAgentTest(unittest.TestCase):
-global_ranks: List[int] = []
+global_ranks: list[int] = []
-role_ranks: Dict[str, List[int]] = {}
+role_ranks: dict[str, list[int]] = {}
-grouped_ranks: Dict[int, List[Tuple[int, int]]] = {}
+grouped_ranks: dict[int, list[tuple[int, int]]] = {}
@@ -16,7 +16,7 @@ import sys
import tempfile
import time
from itertools import product
from typing import Callable, Dict, List, Union
from typing import Callable, Union
from unittest import mock

import torch

@@ -141,7 +141,7 @@ def echo2(msg: str, fail: bool = False) -> str:
return msg


def echo_large(size: int) -> Dict[int, str]:
def echo_large(size: int) -> dict[int, str]:
"""
returns a large output ({0: test0", 1: "test1", ..., (size-1):f"test{size-1}"})
"""

@@ -167,13 +167,13 @@ def dummy_compute() -> torch.Tensor:
return torch.rand(100, 100)


def redirects_oss_test() -> List[Std]:
def redirects_oss_test() -> list[Std]:
return [
Std.NONE,
]


def redirects_all() -> List[Std]:
def redirects_all() -> list[Std]:
return [
Std.NONE,
Std.OUT,

@@ -240,14 +240,14 @@ class _StartProcessesTest(TestCase):
def log_dir(self):
return tempfile.mkdtemp(dir=self.test_dir)

def assert_in_file(self, expected: List[str], filename: str) -> None:
def assert_in_file(self, expected: list[str], filename: str) -> None:
expected = [f"{line.rstrip()}\n" for line in expected]
with open(filename) as fp:
actual = fp.readlines()
for line in expected:
self.assertIn(line, actual)

def assert_pids_noexist(self, pids: Dict[int, int]):
def assert_pids_noexist(self, pids: dict[int, int]):
for local_rank, pid in pids.items():
with self.assertRaises(
OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"

@@ -16,7 +16,6 @@ import unittest
from concurrent.futures import wait
from concurrent.futures._base import ALL_COMPLETED
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Dict, Set
from unittest import mock

from torch.distributed.elastic.multiprocessing.tail_log import TailLog

@@ -72,7 +71,7 @@ class TailLogTest(unittest.TestCase):
tail.stop()

dst.seek(0)
actual: Dict[int, Set[int]] = {}
actual: dict[int, set[int]] = {}

for line in dst.readlines():
header, num = line.split(":")

@@ -123,7 +122,7 @@ class TailLogTest(unittest.TestCase):
tail.stop()
dst.seek(0)

headers: Set[str] = set()
headers: set[str] = set()
for line in dst.readlines():
header, _ = line.split(":")
headers.add(header)

@@ -6,7 +6,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast, Dict, SupportsInt
from typing import Any, cast, SupportsInt
from unittest import TestCase

from torch.distributed.elastic.rendezvous import (

@@ -24,7 +24,7 @@ class RendezvousParametersTest(TestCase):
self._run_id = "dummy_run_id"
self._min_nodes = 3
self._max_nodes = 6
self._kwargs: Dict[str, Any] = {}
self._kwargs: dict[str, Any] = {}

def _create_params(self) -> RendezvousParameters:
return RendezvousParameters(

@@ -15,7 +15,7 @@ import time
from abc import ABC, abstractmethod
from base64 import b64encode
from datetime import datetime, timedelta, timezone
from typing import Callable, cast, Optional, Tuple
from typing import Callable, cast, Optional
from unittest import TestCase
from unittest.mock import call, MagicMock, Mock, patch, PropertyMock

@@ -186,7 +186,7 @@ class FakeRendezvousBackend(RendezvousBackend):
def name(self) -> str:
return "fake_backend"

def get_state(self) -> Optional[Tuple[bytes, Token]]:
def get_state(self) -> Optional[tuple[bytes, Token]]:
if self._token == 0:
return None

@@ -194,7 +194,7 @@ class FakeRendezvousBackend(RendezvousBackend):

def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
) -> Optional[tuple[bytes, Token, bool]]:
if token is None:
token = 0

@@ -7,7 +7,7 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Callable, cast, Optional, Tuple
from typing import Any, Callable, cast, Optional

from torch.distributed.elastic.rendezvous import RendezvousStateError
from torch.distributed.elastic.rendezvous.dynamic_rendezvous import (

@@ -32,12 +32,12 @@ class RendezvousBackendTestMixin(ABC):

def _set_state(
self, state: bytes, token: Optional[Any] = None
) -> Tuple[bytes, Token, bool]:
) -> tuple[bytes, Token, bool]:
result = self._backend.set_state(state, token)

self.assertIsNotNone(result)

return cast(Tuple[bytes, Token, bool], result)
return cast(tuple[bytes, Token, bool], result)

def test_get_state_returns_backend_state(self) -> None:
self._backend.set_state(b"x")

@@ -46,7 +46,7 @@ class RendezvousBackendTestMixin(ABC):

self.assertIsNotNone(result)

state, token = cast(Tuple[bytes, Token], result)
state, token = cast(tuple[bytes, Token], result)

self.assertEqual(b"x", state)
self.assertIsNotNone(token)

@@ -10,7 +10,6 @@ import socket
import threading
import time
from datetime import timedelta
from typing import List
from unittest import TestCase
from unittest.mock import patch

@@ -350,7 +349,7 @@ class PeriodicTimerTest(TestCase):
call_interval = 0.2

# Keep the log of intervals between each consecutive call.
actual_call_intervals: List[float] = []
actual_call_intervals: list[float] = []

# Keep the number of times the function was called.
call_count = 0

@@ -7,7 +7,6 @@ import pickle
import socket
import tempfile
from contextlib import contextmanager
from typing import Dict

from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool

@@ -181,7 +180,7 @@ class WorkerServerTest(TestCase):
def body(self) -> bytes:
return b"dummy"

def params(self) -> Dict[str, str]:
def params(self) -> dict[str, str]:
return {}

class Response(_Response):

@@ -9,7 +9,6 @@

import datetime
from multiprocessing.pool import ThreadPool
from typing import List
from unittest import mock

import torch.distributed as dist

@@ -40,7 +39,7 @@ class MockStore:
self.ops.append(("get", key))
return "value"

def multi_get(self, keys: List[str]) -> List[str]:
def multi_get(self, keys: list[str]) -> list[str]:
self.ops.append(("multi_get", keys))
return ["value"] * len(keys)

@@ -48,7 +47,7 @@ class MockStore:
self.ops.append(("add", key, val))
return 3

def wait(self, keys: List[str]) -> None:
def wait(self, keys: list[str]) -> None:
self.ops.append(("wait", keys))


@@ -157,7 +156,7 @@ class StoreUtilTest(TestCase):
return ""

with ThreadPool(N - 1) as pool:
outputs: List[str] = pool.map(run_barrier_for_rank, range(N - 1))
outputs: list[str] = pool.map(run_barrier_for_rank, range(N - 1))

self.assertTrue(any("missing_ranks=[Rank 2 host]" in msg for msg in outputs))

@@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]

import sys
from typing import List
from unittest.mock import patch

import torch

@@ -102,7 +101,7 @@ class TestBackwardPrefetch(FSDPTest):
tgt = torch.randn((20, 1, 1024), device=device_type)

# monkey patch
all_handle_fqns: List[List[str]] = []
all_handle_fqns: list[list[str]] = []

def patched_get_handle_to_prefetch(*args, **kwargs):
handle = orig_get_handle_to_prefetch(*args, **kwargs)

@@ -2,7 +2,7 @@
import sys
from contextlib import nullcontext
from enum import auto, Enum
from typing import List, Optional
from typing import Optional
from unittest.mock import patch

import torch

@@ -319,7 +319,7 @@ class TestExplicitUnshard(FSDPTest):
self.mlp2 = MLP(dim)
self.mlp3 = MLP(dim)

def forward(self, ys: List[torch.Tensor], works: List[dist.Work]):
def forward(self, ys: list[torch.Tensor], works: list[dist.Work]):
(y1, y2, y3), (work1, work2, work3) = ys, works
work1.wait()
z1 = self.mlp1(y1)

@@ -372,7 +372,7 @@ class TestExplicitUnshard(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((batch_size, dim), device=device_type)
for _ in range(10):
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
losses.append(_model(inp).sum())
losses[-1].backward()

@@ -4,7 +4,7 @@ import functools
import itertools
import sys
import unittest
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
from unittest import mock

import torch

@@ -76,7 +76,7 @@ class TestParityWithDDP(FSDPTest):
PyTorch DDP vs. FullyShardedDataParallel.
"""

def _get_device_init_modes(self, cpu_offload: CPUOffload) -> List[DEVICEInitMode]:
def _get_device_init_modes(self, cpu_offload: CPUOffload) -> list[DEVICEInitMode]:
modes = [
DEVICEInitMode.DEVICE_AFTER,
DEVICEInitMode.DEVICE_BEFORE,

@@ -89,7 +89,7 @@ class TestParityWithDDP(FSDPTest):
modes.append(DEVICEInitMode.DEVICE_NEVER)
return modes

def _get_subtest_config(self, cpu_offload: CPUOffload) -> Dict[str, List[Any]]:
def _get_subtest_config(self, cpu_offload: CPUOffload) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests CUDA initialization
modes and prefetching settings together."""
return {

@@ -4,7 +4,7 @@ import contextlib
import itertools
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

import torch
from torch import distributed as dist

@@ -71,7 +71,7 @@ class _GradAccConfigs:
sole purpose of overriding :meth:`__repr__` to remove spaces.
"""

configs: List[_GradAccConfig]
configs: list[_GradAccConfig]

def __repr__(self) -> str:
# Override to remove any spaces in the string to appease the internal

@@ -90,7 +90,7 @@ class TestGradAcc(FSDPTest):
def _test_grad_acc(
self,
batch_dim: int,
configs: List[_GradAccConfig],
configs: list[_GradAccConfig],
cpu_offload: CPUOffload,
backward_prefetch: Optional[BackwardPrefetch],
sharding_strategy: ShardingStrategy,

@@ -146,8 +146,8 @@ class TestGradAcc(FSDPTest):
def permute_tensor(x: torch.Tensor):
return x.view(-1)[torch.randperm(x.numel())].view_as(x)

batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
batches: List[Tuple[torch.Tensor, ...]] = [batch]
batch: tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
batches: list[tuple[torch.Tensor, ...]] = [batch]
num_iters_to_acc = sum(config.num_iters for config in configs)
for _ in range(num_iters_to_acc - 1):
batches.append(tuple(permute_tensor(t) for t in batch))

@@ -158,7 +158,7 @@ class TestGradAcc(FSDPTest):
), "Check the test to make sure that batches are distinct"

# Concatenate the batches along the given batch dimension
concat_batch: Tuple[torch.Tensor, ...] = tuple(
concat_batch: tuple[torch.Tensor, ...] = tuple(
torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
)

@@ -214,7 +214,7 @@ class TestGradAcc(FSDPTest):
# Check that the optimizer step does not error
optim.step()

def _get_subtest_config(self) -> Dict[str, List[Any]]:
def _get_subtest_config(self) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests prefetching."""
return {
"backward_prefetch": [

@@ -5,7 +5,7 @@ import sys
from collections import Counter
from enum import auto, Enum
from functools import partial
from typing import List, Optional, Tuple
from typing import Optional

import torch
import torch.distributed as dist

@@ -363,7 +363,7 @@ class TestFSDPHybridShard(FSDPTest):
torch.manual_seed(global_pg.rank() + 1)
for _ in range(5):
inp = fsdp_model.module.get_input(torch.device("cuda"))
losses: List[torch.Tensor] = []
losses: list[torch.Tensor] = []
for model, optim in ((fsdp_model, fsdp_optim), (hsdp_model, hsdp_optim)):
optim.zero_grad()
loss = model(*inp).sum()

@@ -396,7 +396,7 @@ class TestFSDPHybridShard(FSDPTest):
sharding_strategy_mode: str,
use_orig_params: bool,
hsdp_process_groups: Optional[
Tuple[dist.ProcessGroup, dist.ProcessGroup]
tuple[dist.ProcessGroup, dist.ProcessGroup]
] = None,
hsdp_device_mesh: Optional = None,
):

@@ -8,7 +8,7 @@ from collections import namedtuple
from contextlib import nullcontext
from copy import deepcopy
from itertools import chain
from typing import Any, Tuple
from typing import Any

import torch
import torch.distributed as dist

@@ -945,7 +945,7 @@ class TestFSDPMiscMultiThread(FSDPTestMultiThread):
self._test_homogeneous_attributes,
)

def _test_homogeneous_attributes(self, attr_name_and_values: Tuple[str, Any, Any]):
def _test_homogeneous_attributes(self, attr_name_and_values: tuple[str, Any, Any]):
model = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.NO_FSDP,

@@ -6,7 +6,7 @@ import os
import sys
from functools import partial
from itertools import product
from typing import Any, Dict, List
from typing import Any

import torch
import torch.cuda.nccl as nccl

@@ -521,7 +521,7 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
def world_size(self):
return 2

def _get_subtest_config(self) -> Dict[str, List[Any]]:
def _get_subtest_config(self) -> dict[str, list[Any]]:
"""Returns a subtest configuration that subtests prefetching settings
together."""
return {

@@ -1136,7 +1136,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):

@skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule(self):
forward_inputs: Dict[str, nn.Module] = {}
forward_inputs: dict[str, nn.Module] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)

model = SaveForwardInputsModel(

@@ -1158,7 +1158,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):

@skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule_skip_inputs(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)

model = SaveForwardInputsModel(

@@ -1179,7 +1179,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):

@skip_if_lt_x_gpu(2)
def test_float16_on_one_submodule_skip_inputs_error(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)

model = SaveForwardInputsModel(

@@ -1198,7 +1198,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):

@skip_if_lt_x_gpu(2)
def test_submodules_with_different_precisions_error(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

@@ -1222,7 +1222,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):

@skip_if_lt_x_gpu(2)
def test_submodules_with_different_precisions(self):
forward_inputs: Dict[nn.Module, torch.Tensor] = {}
forward_inputs: dict[nn.Module, torch.Tensor] = {}
float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

@@ -1244,7 +1244,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_submodules_with_external_inputs(self):
class ToyModule(nn.Module):
def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l = nn.Linear(100, 100)
self.forward_inputs = forward_inputs

@@ -1255,7 +1255,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
return self.l(x)

class ToyModel(nn.Module):
def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
def __init__(self, forward_inputs: dict[str, torch.Tensor]) -> None:
super().__init__()
self.l1 = nn.Linear(100, 100)
self.l2 = ToyModule(forward_inputs)

@@ -1266,7 +1266,7 @@ class TestFSDPDifferentSubmodulePrecision(FSDPTest):
y = torch.ones(2, 100, device="cuda", dtype=torch.float32)
return self.l2(self.l1(x), y)

forward_inputs: Dict[str, torch.Tensor] = {}
forward_inputs: dict[str, torch.Tensor] = {}

float16 = MixedPrecision(param_dtype=torch.float16)
model = ToyModel(forward_inputs).cuda()

@@ -1343,7 +1343,7 @@ class TestFSDPTrainEval(FSDPTest):
torch.manual_seed(1 + self.rank)
eval_src = torch.randn((8, 1, 512), device=device)
eval_tgt = torch.randn((16, 1, 512), device=device)
eval_out_sums: List[torch.Tensor] = []
eval_out_sums: list[torch.Tensor] = []
# An iteration consists of training forward/backward/optimizer,
# updating the EMA copy with the main copy, and eval forward
for _ in range(3):

@@ -4,7 +4,7 @@ import bisect
import sys
from copy import deepcopy
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Optional

import torch
import torch.nn as nn

@@ -177,7 +177,7 @@ class NestedModel(torch.nn.Module):
model: torch.nn.Module,
group: Optional[dist.ProcessGroup] = None,
ignore_modules: bool = False,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
) -> torch.nn.Module:
if fsdp_kwargs is None:
fsdp_kwargs = {}

@@ -214,7 +214,7 @@ class NestedModel(torch.nn.Module):
def wrap_alt(
model: torch.nn.Module,
group: Optional[dist.ProcessGroup] = None,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
) -> torch.nn.Module:
if fsdp_kwargs is None:
fsdp_kwargs = {}

@@ -231,7 +231,7 @@ class NestedModel(torch.nn.Module):
model,
add_to_fsdp_module: bool,
group=None,
) -> Tuple[torch.nn.Module, List[torch.nn.Parameter]]:
) -> tuple[torch.nn.Module, list[torch.nn.Parameter]]:
"""Registers unmanaged parameters before wrapping with :meth:`wrap`."""
device = next(model.parameters()).device
unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device))

@@ -277,12 +277,12 @@ class NestedModel(torch.nn.Module):

# NOTE: We exclude `self.bias` from either parameter group to test the
# case where the optimizer input does not include all model parameters
def param_group0(self) -> List[torch.nn.Parameter]:
def param_group0(self) -> list[torch.nn.Parameter]:
# Use `block1`'s parameters for the first parameter group to deviate
# from the `model.parameters()` order
return list(self.block1.parameters())

def param_group1(self) -> List[torch.nn.Parameter]:
def param_group1(self) -> list[torch.nn.Parameter]:
# Deviate from the `model.parameters()` order further by rearranging
# `block2`'s parameters to be before `block0`'s parameters
return list(self.block2.parameters()) + list(self.block0.parameters())

@@ -322,10 +322,10 @@ class TestFSDPOptimState(FSDPTest):
wrap_alt: bool = False, # ignored if `wrap=False`
device: torch.device = torch.device("cuda"),
group=None,
optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
optim_class: type[torch.optim.Optimizer] = torch.optim.Adam,
use_multiple_param_groups: bool = False,
use_diff_optim_inputs: bool = False,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
):
model = NestedModel().to(device)
if wrap:

@@ -356,7 +356,7 @@ class TestFSDPOptimState(FSDPTest):
wrap: bool,
device: torch.device = torch.device("cuda"),
group=None,
optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
optim_class: type[torch.optim.Optimizer] = torch.optim.Adam,
use_multiple_param_groups: bool = False,
use_diff_optim_inputs: bool = False,
):

@@ -383,7 +383,7 @@ class TestFSDPOptimState(FSDPTest):
optim: torch.optim.Optimizer,
device: torch.device = torch.device("cuda"),
num_iters: int = 1,
) -> List[float]:
) -> list[float]:
"""Performs a forward pass, backward pass, and optimizer step
``num_iters``-many times, and returns the per-iteration losses."""
torch.manual_seed(0) # set seed for determinism

@@ -399,7 +399,7 @@ class TestFSDPOptimState(FSDPTest):
optim.step()
return losses

def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None):
def _broadcast_full_osd(self, full_osd: dict[str, Any], group=None):
"""Broadcasts the full optimizer state dict in place of using
``torch.save()`` and ``torch.load()`` so that all ranks can have it."""
obj_list = [full_osd]

@@ -413,8 +413,8 @@ class TestFSDPOptimState(FSDPTest):

def _are_equal_states(
self,
state1: Dict[str, Any],
state2: Dict[str, Any],
state1: dict[str, Any],
state2: dict[str, Any],
) -> bool:
"""Checks if ``state1`` and ``state2`` contain the same mappings."""
if set(state1.keys()) != set(state2.keys()):

@@ -1450,7 +1450,7 @@ class TestFSDPOptimState(FSDPTest):
self,
should_check_method_fn: Callable[[str], bool],
context_fn: Callable,
fsdp_kwargs: Optional[Dict[str, Any]],
fsdp_kwargs: Optional[dict[str, Any]],
):
"""
Runs through all optimizer state checkpointing APIs with a context

@@ -5,7 +5,7 @@ import functools
import itertools
import sys
import unittest
from typing import List, Optional
from typing import Optional

import torch
from torch import distributed as dist

@@ -259,7 +259,7 @@ class TestShardedGradScalerParityWithDDP(FSDPTest):
)
grad_scaler = ShardedGradScaler(init_scale=2.0)
ref_grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
scaled_losses: List[torch.Tensor] = []
scaled_losses: list[torch.Tensor] = []
device = torch.device("cuda")
torch.manual_seed(42 + self.rank + 1)

@@ -6,7 +6,7 @@ import sys
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict
from typing import Any

import torch
import torch.nn as nn

@@ -787,7 +787,7 @@ class TestFSDPStateDict(FSDPTest):

@staticmethod
def _load_state_dict(
model: Module, state_dict_type: str, state_dict: Dict[str, Any]
model: Module, state_dict_type: str, state_dict: dict[str, Any]
):
try:
enum_val = STATE_DICT_MAPPING[state_dict_type]

@@ -2,7 +2,7 @@
import copy
import sys
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
from typing import Optional

import torch
from torch import distributed as dist

@@ -62,11 +62,11 @@ class SimpleModel(torch.nn.Module):
return self.net3(self.net2(self.relu(self.net1(x))))

@staticmethod
def get_sharded_param_names() -> List[str]:
def get_sharded_param_names() -> list[str]:
return ["net1.weight", "net1.bias", "net2.weight"]

@staticmethod
def get_non_sharded_param_names() -> List[str]:
def get_non_sharded_param_names() -> list[str]:
return ["net3.weight", "net3.bias"]


@@ -87,9 +87,9 @@ class TestTPFSDPIntegration(FSDPTest):
def _get_params_and_sharding_info(
self,
model: SimpleModel,
sharded_param_names: List[str],
sharded_param_names: list[str],
tensor_parallel_size: int,
) -> Tuple[Dict[str, int], Dict[str, Tuple[torch.Size, int]]]:
) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]:
""" """
assert (
type(model) is SimpleModel

@@ -131,8 +131,8 @@ class TestTPFSDPIntegration(FSDPTest):
self,
tp_fsdp_model: FSDP,
tp_pg: dist.ProcessGroup,
param_name_to_numel: Dict[str, int],
non_sharded_param_names: List[str],
param_name_to_numel: dict[str, int],
non_sharded_param_names: list[str],
) -> None:
"""
Syncs the tensor parallel parameters' gradients following the data

@@ -177,11 +177,11 @@ class TestTPFSDPIntegration(FSDPTest):
self,
model: FSDP,
uses_tp: bool,
param_name_to_numel: Dict[str, int],
param_name_to_sharding_info: Dict[str, Tuple[torch.Size, int]],
param_name_to_numel: dict[str, int],
param_name_to_sharding_info: dict[str, tuple[torch.Size, int]],
tp_pg: Optional[dist.ProcessGroup],
fsdp_pg: Optional[dist.ProcessGroup],
sharded_param_names: Optional[List[str]],
sharded_param_names: Optional[list[str]],
) -> torch.Tensor:
"""
Returns all unsharded gradients as a single flattened tensor. This

@@ -3,7 +3,7 @@ import contextlib
import itertools
import math
import sys
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils

@@ -55,7 +55,7 @@ class TestUnshardParamsBase(FSDPTest):
self,
writeback: bool,
check_outer: bool,
**fsdp_kwargs: Dict[str, Any],
**fsdp_kwargs: dict[str, Any],
):
model = nn.Sequential(
nn.Linear(5, 5, bias=False, device=device_type.type),

@@ -101,7 +101,7 @@ class TestUnshardParamsBase(FSDPTest):
for param in model.parameters():
self.assertEqual(param.device, cpu_device)

def _get_test_unshard_params_writeback_config(self) -> Dict[str, List[Any]]:
def _get_test_unshard_params_writeback_config(self) -> dict[str, list[Any]]:
return {
"writeback": [True, False],
"check_outer": [True, False],

@@ -193,7 +193,7 @@ class TestUnshardParamsBase(FSDPTest):
num_fsdp_roots += fsdp_state._is_root
self.assertGreater(num_fsdp_roots, 1)

def _get_test_unshard_params_param_data_config(self) -> Dict[str, List[Any]]:
def _get_test_unshard_params_param_data_config(self) -> dict[str, list[Any]]:
return {
"rank0_only": [False, True],
"offload_to_cpu": [False, True],

@@ -493,7 +493,7 @@ class TestUnshardParams(TestUnshardParamsBase):
def _check_grads(
ddp_model: DDP,
fsdp_model: FSDP,
old_fsdp_grads: Optional[List[torch.Tensor]],
old_fsdp_grads: Optional[list[torch.Tensor]],
):
"""
Checks that writes to the FSDP parameters' gradients persist or do

@@ -6,7 +6,7 @@ import itertools
import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional

import torch
import torch.nn as nn

@@ -65,7 +65,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
def world_size(self) -> int:
return 2

def _get_param_groups(self, model: nn.Module) -> List[Dict[str, Any]]:
def _get_param_groups(self, model: nn.Module) -> list[dict[str, Any]]:
"""
Constructs separate parameter groups for weights, biases, and other
parameters.

@@ -87,7 +87,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
def _get_optim(
self,
model: nn.Module,
optim_class: Type[torch.optim.Optimizer],
optim_class: type[torch.optim.Optimizer],
multi_tensor: bool,
) -> torch.optim.Optimizer:
"""

@@ -117,12 +117,12 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
self,
device_init_mode: DEVICEInitMode,
init_optim_before_wrap: bool,
optim_class: Type[torch.optim.Optimizer],
optim_class: type[torch.optim.Optimizer],
multi_tensor: bool,
sharding_strategy: ShardingStrategy,
backward_prefetch: Optional[BackwardPrefetch],
cpu_offload: CPUOffload,
) -> Tuple[FSDP, torch.optim.Optimizer]:
) -> tuple[FSDP, torch.optim.Optimizer]:
"""
Returns a transformer with shared parameters wrapped with FSDP and a
corresponding optimizer.

@@ -335,7 +335,7 @@ class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
self,
device_init_mode: DEVICEInitMode,
init_optim_before_wrap: bool,
optim_class: Type[torch.optim.Optimizer],
optim_class: type[torch.optim.Optimizer],
multi_tensor: bool,
set_to_none: bool,
backward_prefetch: Optional[BackwardPrefetch],

@@ -566,7 +566,7 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
self,
sharding_strategy: ShardingStrategy,
cpu_offload: CPUOffload,
) -> Tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
) -> tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
"""
Returns a pair of (FSDP model, optimizer) for ``use_orig_params=False``
and ``True``, respectively.

@@ -778,7 +778,7 @@ class TestFSDPUseOrigParamsParamAccess(FSDPTest):
z = self.lin2(z)
return z

def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
def get_input(self, device: torch.device) -> tuple[torch.Tensor, ...]:
return (torch.randn((2, 5)).to(device),)

def get_loss(self, inp, out):

@@ -872,7 +872,7 @@ class TestFSDPUseOrigParamsWriteback(FSDPTest):
z = self.lin2(z)
return z

def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
def get_input(self, device: torch.device) -> tuple[torch.Tensor, ...]:
return (torch.randn((2, 5)).to(device),)

def get_loss(self, inp, out):

@@ -4,7 +4,6 @@ import random
import sys
from collections import OrderedDict
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn

@@ -62,13 +61,13 @@ class TestUtils(TestCase):
class NonFrozenDataClass:
some_key: str
some_float: float
some_tensor: List[torch.Tensor]
some_tensor: list[torch.Tensor]

@dataclass(frozen=True)
class FrozenDataClass:
some_key: str
some_float: float
some_tensor: List[torch.Tensor]
some_tensor: list[torch.Tensor]

# create a mixed bag of data.
data = [1, "str"]

@@ -16,7 +16,7 @@ import time
import unittest
import uuid
from contextlib import closing
from typing import Any, Dict, Optional
from typing import Any, Optional
from unittest import mock
from unittest.mock import MagicMock, Mock, patch

@@ -59,7 +59,7 @@ def get_test_launch_config(
nproc_per_node: int,
run_id: str = "",
rdzv_backend: str = "etcd",
config: Optional[Dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
) -> LaunchConfig:
rdzv_configs = {}
if config:

@@ -3,7 +3,6 @@

import sys
from pathlib import Path
from typing import Tuple

import torch
import torch.distributed as dist

@@ -22,7 +21,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
class MyModuleInterface:
def forward(
self, tensor: Tensor, number: int, word: str = "default"
) -> Tuple[Tensor, int, str]:
) -> tuple[Tensor, int, str]:
pass

@@ -10,7 +10,7 @@ import os
import sys
import unittest
from contextlib import nullcontext
from typing import Any, cast, List
from typing import Any, cast

import numpy as np

@@ -207,7 +207,7 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
super().step()
kwarg.append(5)

kwarg: List[Any] = []
kwarg: list[Any] = []
x = torch.tensor([1.0], device=self.device, requires_grad=True)
o = ZeroRedundancyOptimizer(
[x],

@@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"]
# This file is a Schedule zoo for testing torch.distributed.pipelining.
# It includes schedules designed purely for testing purposes
from typing import Callable, Dict, List, Optional
from typing import Callable, Optional

from torch.distributed.pipelining.schedules import (
_Action,

@@ -32,9 +32,9 @@ class ScheduleVShaped(PipelineScheduleMulti):

def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
stage_index_to_group_rank: Dict[int, int],
stage_index_to_group_rank: dict[int, int],
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,
):

@@ -82,9 +82,9 @@ class ScheduleUnbalanced(PipelineScheduleMulti):

def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
stage_index_to_group_rank: Dict[int, int],
stage_index_to_group_rank: dict[int, int],
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,
):

@@ -134,7 +134,7 @@ class ScheduleWithW(PipelineScheduleMulti):

def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
enable_zero_bubble: bool = True,

@@ -195,7 +195,7 @@ class ScheduleWithReorderedB(_PipelineScheduleRuntime):

def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,

@@ -4,7 +4,6 @@ import copy
import csv
import logging
import os
from typing import List

from model_registry import MultiMLP

@@ -356,7 +355,7 @@ instantiate_parametrized_tests(TestSchedulePlan)
class TestScheduleLowering(TestCase):
"""Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules"""

def _parse_actions(self, actions: List[str]) -> List[_Action]:
def _parse_actions(self, actions: list[str]) -> list[_Action]:
return [_Action.from_str(s) for s in actions]

@parametrize(

@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

from typing import Any, Dict
from typing import Any

import torch
from torch.distributed._tensor import DeviceMesh

@@ -57,8 +57,8 @@ class TestCommModeFeatures(DTensorTestBase):
Used to generate the ground-truth parameter and sharding info for a given distributed model to
verify comm_mode correctness
"""
module_parameters_dict: Dict[str, Any] = {}
module_sharding_dict: Dict[str, Any] = {}
module_parameters_dict: dict[str, Any] = {}
module_sharding_dict: dict[str, Any] = {}

for name, parameters in model.named_parameters():
# splits name into module name to create FQN and parameter name

@@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]
from collections import defaultdict
from typing import Dict

import torch
from torch.distributed._tensor.experimental._tp_transform import (

@@ -57,9 +56,9 @@ class TensorParallelTest(DTensorTestBase):
super().setUp()

def assert_has_c10d_ops(
self, gm: torch.fx.GraphModule, expected_ops_count: Dict[str, int]
self, gm: torch.fx.GraphModule, expected_ops_count: dict[str, int]
) -> None:
actual_ops_count: Dict[str, int] = defaultdict(int)
actual_ops_count: dict[str, int] = defaultdict(int)
for node in gm.graph.nodes:
if node.op == "call_function":
if "c10d_functional" in str(node.target):

@@ -100,7 +99,7 @@ class TensorParallelTest(DTensorTestBase):
torch.manual_seed(0)
model = MLPListModule(2).to(device=self.device_type)
inputs = (torch.randn((10, 12)).to(device=self.device_type),)
parallel_strategies: Dict[str, ParallelStyle] = {
parallel_strategies: dict[str, ParallelStyle] = {
"mlps.0.0": ColwiseParallel,
"mlps.0.2": RowwiseParallel,
"mlps.1.0": ColwiseParallel,

@@ -137,7 +136,7 @@ class TensorParallelTest(DTensorTestBase):
torch.manual_seed(0)
model = MLPListModule(1, bias=False).to(device=self.device_type)
inputs = (torch.randn((10, 12)).to(device=self.device_type),)
parallel_strategies: Dict[str, ParallelStyle] = {
parallel_strategies: dict[str, ParallelStyle] = {
"mlps.0.0": ColwiseParallel,
"mlps.0.2": RowwiseParallel,
}

@@ -3,7 +3,7 @@

import itertools
from copy import deepcopy
from typing import Dict, NamedTuple, Optional
from typing import NamedTuple, Optional

import torch
import torch.distributed as dist

@@ -52,9 +52,9 @@ reduce_scatter, all_gather, all_reduce = (


class ExpCommCounts(NamedTuple):
fwd: Optional[Dict] = None
bwd: Optional[Dict] = None
optim: Optional[Dict] = None
fwd: Optional[dict] = None
bwd: Optional[dict] = None
optim: Optional[dict] = None


class DistTensorParallelExampleTest(DTensorTestBase):

@@ -3,7 +3,7 @@

import itertools
import unittest
from typing import cast, List, Optional, Tuple
from typing import cast, Optional

import torch
import torch.nn.functional as F

@@ -27,8 +27,8 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (


def scale_for_fp8(
t: torch.Tensor, scale_shape: Tuple[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
t: torch.Tensor, scale_shape: tuple[int]
) -> tuple[torch.Tensor, torch.Tensor]:
if all(d == 1 for d in scale_shape):
t = t.unsqueeze(0).unsqueeze(-2)
else:

@@ -116,7 +116,7 @@ class DistMatrixOpsTest(DTensorTestBase):
local_res = torch.mm(t1, t2)

def test_placement_comb(
placements1: List[Placement], placements2: List[Placement]
placements1: list[Placement], placements2: list[Placement]
) -> None:
dt1 = distribute_tensor(t1, device_mesh, placements1)
dt2 = distribute_tensor(t2, device_mesh, placements2)

@@ -272,9 +272,9 @@ class DistMatrixOpsTest(DTensorTestBase):
batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True)

def test_placement_comb(
tensor_placements: List[Placement],
batch_1_placements: List[Placement],
batch_2_placements: List[Placement],
tensor_placements: list[Placement],
batch_1_placements: list[Placement],
batch_2_placements: list[Placement],
beta: int,
alpha: int,
batch_1_grad: Optional[torch.Tensor],

@@ -338,8 +338,8 @@ class DistMatrixOpsTest(DTensorTestBase):
local_result.backward(grad_local_res)

def test_placement_comb(
placements1: List[Placement],
placements2: List[Placement],
placements1: list[Placement],
placements2: list[Placement],
) -> None:
mat1_dt = distribute_tensor(mat1, device_mesh, placements1)
mat2_dt = distribute_tensor(mat2, device_mesh, placements2)

@@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"]

from collections.abc import Sequence
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
from unittest import skip

import torch

@@ -76,7 +76,7 @@ class DistElementwiseOpsTest(DTensorOpTestBase):
op: Callable,
pre_op_fn: Optional[Callable] = None,
args: Sequence[Any] = (),
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
):
if pre_op_fn is None:
pre_op_fn = no_op

@@ -2,7 +2,7 @@
# Owner(s): ["oncall: distributed"]

import itertools
from typing import cast, List
from typing import cast

import torch
import torch.distributed as dist

@@ -160,7 +160,7 @@ class TestViewOps(DTensorTestBase):
if op == torch.unbind:
no_shard_dims.add(kwargs.get("dim", 0))

sharding_choices = cast(List[Placement], [Replicate()]) + [
sharding_choices = cast(list[Placement], [Replicate()]) + [
Shard(i) for i, s in enumerate(in_shape) if s > 1 and i not in no_shard_dims
]

@@ -513,7 +513,7 @@ class TestViewOps(DTensorTestBase):
# test sharded computation correctness
# NOTE: For the input to torch.view_as_complex, sharding
# on the last two dimensions is not supported.
sharding_choices: List[Placement] = [Replicate(), Shard(0)]
sharding_choices: list[Placement] = [Replicate(), Shard(0)]
all_sharding_choices = itertools.product(
*(self.device_mesh.ndim * [sharding_choices])
)

@@ -4,7 +4,7 @@

import os
import unittest
from functools import wraps
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable

import numpy as np

@@ -26,7 +26,7 @@ def with_xla(func: Callable) -> Callable:

@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc]
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
) -> None:
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
os.environ["XLA_USE_SPMD"] = "1"

@@ -12,7 +12,7 @@ from dataclasses import dataclass
from datetime import timedelta
from itertools import product
from sys import platform
from typing import Dict, Optional
from typing import Optional

import torch
import torch.distributed as dist

@@ -972,7 +972,7 @@ class CommonDistributedDataParallelTest:
@dataclass
class CustomOutput:
o1: Optional[torch.Tensor]
o2: Dict[str, torch.Tensor]
o2: dict[str, torch.Tensor]

class DataclassOutputModule(nn.Module):
def __init__(self, skip_o1):

@@ -1,7 +1,6 @@
# Owner(s): ["module: c10d"]
import threading
import unittest
from typing import List

import torch
import torch.distributed as dist

@@ -65,7 +64,7 @@ class TestWithNCCL(MultiProcessTestCase):
return 2

@property
def ranks(self) -> List[int]:
def ranks(self) -> list[int]:
return list(range(self.world_size))

@property

@@ -556,7 +555,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_reduce_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
def func(args: list[torch.Tensor]) -> torch.Tensor:
bufs = [arg + 42 for arg in args]
# Expect in-place with inductor allocated buf
ar0 = funcol.all_reduce_coalesced(bufs, "avg", "0")

@@ -714,7 +713,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_all_gather_into_tensor_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
def func(args: list[torch.Tensor]) -> torch.Tensor:
ag0 = funcol.all_gather_into_tensor_coalesced(args, "0")
ag0 = [funcol.wait_tensor(out) for out in ag0]
return ag0

@@ -796,7 +795,7 @@ class CompileTest(TestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_inductor_reduce_scatter_tensor_coalesced(self):
def func(args: List[torch.Tensor]) -> torch.Tensor:
def func(args: list[torch.Tensor]) -> torch.Tensor:
rs0 = funcol.reduce_scatter_tensor_coalesced(
args, "avg", [0] * len(args), "0"
)

@@ -27,7 +27,6 @@ if not c10d.is_available() or not c10d.is_nccl_available():
print("c10d NCCL not available, skipping tests", file=sys.stderr)
sys.exit(0)

from typing import Dict, List

import test_c10d_common
from test_c10d_common import ConvNet, DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook

@@ -2546,7 +2545,7 @@ class WorkHookTest(MultiProcessTestCase):
def test_on_completion_hook_broadcast(self):
pg = self._get_process_group()
num_hook_fired = 0
durations: List[float] = []
durations: list[float] = []

def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations

@@ -2574,7 +2573,7 @@ class WorkHookTest(MultiProcessTestCase):
def test_on_completion_hook_mixed_ops(self):
pg = self._get_process_group()
num_hook_fired = 0
durations: List[float] = []
durations: list[float] = []

def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations

@@ -2615,8 +2614,8 @@ class WorkHookTest(MultiProcessTestCase):
@skip_if_lt_x_gpu(2)
def test_on_completion_hook_with_ddp(self):
pg = self._get_process_group()
num_hook_fired: Dict[int, int] = {}
durations: Dict[OpType, List[float]] = {}
num_hook_fired: dict[int, int] = {}
durations: dict[OpType, list[float]] = {}

def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations

@@ -2673,8 +2672,8 @@ class WorkHookTest(MultiProcessTestCase):
torch.cuda.set_device(self.rank)

pg = self._get_process_group()
num_hook_fired: Dict[int, int] = {}
durations: Dict[OpType, List[float]] = {}
num_hook_fired: dict[int, int] = {}
durations: dict[OpType, list[float]] = {}

def hook(work_info: torch._C._distributed_c10d.WorkInfo):
nonlocal num_hook_fired, durations

@@ -2,21 +2,15 @@

import copy
import os
import sys
import tempfile

import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions

import torch
import torch.distributed as c10d
import torch.nn as nn
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
create_device,
requires_gloo,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import requires_gloo, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,

@@ -26,95 +20,6 @@ from torch.testing._internal.common_utils import (


# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):

class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def opts(cls, threads=2):
opts = c10d.ProcessGroupGloo._Options()
opts._timeout = 5.0
opts._devices = [create_device(interface="lo")]
opts._threads = threads
return opts

@classmethod
def _init_pg_gloo(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
backend = c10d.ProcessGroupGloo(
store, rank, world_size, ProcessGroupShareTensorTest.opts()
)
# set process group backends manually
c10d.init_process_group(
backend="gloo", store=store, rank=rank, world_size=world_size
)
pg = c10d.distributed_c10d._get_default_group()
pg._register_backend(
torch.device("cpu"), c10d.ProcessGroup.BackendType.GLOO, backend
)
pg._register_backend(
torch.device("cuda"), c10d.ProcessGroup.BackendType.GLOO, backend
)

return pg

@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_broadcast_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
1,
)

@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allreduce_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
1,
)

@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allgather_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_gloo,
self.world_size,
)

@classmethod
def _test_allgather_chunk_process(
cls, rank, filename, shared_tensor, world_size, init_pg, c2p, p2c
):
pg = init_pg(rank, filename, world_size)
chunks = torch.chunk(shared_tensor, world_size, dim=0)
x = chunks[rank]
ys = [torch.zeros_like(x) for _ in range(world_size)]
pg.allgather(ys, x).wait()
c2p.put((rank, chunks[0].to("cpu"), ys[0].to("cpu")))
c2p.put((rank, chunks[1].to("cpu"), ys[1].to("cpu")))
p2c.get()

@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
def test_shared_allgather_chunk_gloo(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_chunk_process,
torch.tensor(range(4)).reshape(2, 2),
ProcessGroupShareTensorTest._init_pg_gloo,
self.world_size,
)


class DistributedDataParallelSingleProcessTest(TestCase):

@@ -1,95 +1,21 @@
# Owner(s): ["oncall: distributed"]

import sys

import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions

import torch
import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)


NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")

# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):

class ProcessGroupShareTensorTest(
test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
):
@classmethod
def _init_pg_nccl(cls, rank, filename, world_size):
store = c10d.FileStore(filename, world_size)
return c10d.ProcessGroupNCCL(store, rank, world_size)

@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_broadcast_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_broadcast_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)

@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_allreduce_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allreduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)

@classmethod
def _test_reduce_process(
cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c
):
pg = init_pg(rank, filename, world_size)
x = shared_tensors[rank]
pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
if rank == 0:
c2p.put((rank, torch.ones(2, 2) * 2, x.to("cpu")))
else:
c2p.put((rank, torch.ones(2, 2), x.to("cpu")))
p2c.get()

@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_reduce_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_reduce_process,
[torch.ones(2, 2).to(i) for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
1,
)

@skip_but_pass_in_sandcastle_if(
not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
)
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
def test_shared_allgather_nccl(self):
self._test_multiprocess(
ProcessGroupShareTensorTest._test_allgather_process,
[torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
ProcessGroupShareTensorTest._init_pg_nccl,
self.world_size,
)


# Skip dev-asan as torch + multiprocessing spawn have known issues


@@ -1,74 +1,21 @@
# Owner(s): ["oncall: distributed"]

import sys

import test_c10d_spawn
from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions

import torch
import torch.distributed as c10d
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import requires_ucc, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


NO_UCC = not hasattr(c10d, "ProcessGroupUCC")

# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619
if sys.version_info < (3, 9):

    class ProcessGroupShareTensorTest(
        test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase
    ):
        @classmethod
        def _init_pg_ucc(cls, rank, filename, world_size):
            store = c10d.FileStore(filename, world_size)
            c10d.init_process_group(
                backend="ucc", store=store, rank=rank, world_size=world_size
            )
            return c10d.distributed_c10d._get_default_group()

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
        def test_shared_broadcast_ucc(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_broadcast_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_ucc,
                1,
            )

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
        def test_shared_allreduce_ucc(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allreduce_process,
                [torch.ones(2, 2).to(i) for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_ucc,
                1,
            )

        @skip_but_pass_in_sandcastle_if(
            not TEST_MULTIGPU, "At least 2 CUDA GPUS needed"
        )
        @skip_but_pass_in_sandcastle_if(NO_UCC, "UCC needed")
        def test_shared_allgather_ucc(self):
            self._test_multiprocess(
                ProcessGroupShareTensorTest._test_allgather_process,
                [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
                ProcessGroupShareTensorTest._init_pg_ucc,
                self.world_size,
            )


# Skip dev-asan as torch + multiprocessing spawn have known issues

@@ -7,7 +7,6 @@ import unittest
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from typing import List
from unittest.mock import patch

import numpy as np

@@ -1959,7 +1958,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
        model = ModuleWithStaticMethod(False)
        x = torch.randn((2, 3), device="cuda")
        ref_out = model(x)
        test_outs: List[torch.Tensor] = []
        test_outs: list[torch.Tensor] = []

        for use_self in (False, True):
            model = ModuleWithStaticMethod(use_self)

@@ -5,7 +5,6 @@ import functools
import math
import unittest  # noqa: F811
from importlib import import_module
from typing import Set

import torch
import torch._dynamo.config

@@ -86,7 +85,7 @@ def count_ops(
    return gm


def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: set[str]):
    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:  # fwd graph
        return_node = list(graph.nodes)[-1]
        assert return_node.target == "output"

@@ -12,7 +12,6 @@ import operator
import unittest
from collections.abc import Sequence
from enum import Enum
from typing import Dict, List
from unittest.mock import patch

import torch

@@ -1999,7 +1998,7 @@ def forward(self, l_x_):
        self.assertIn("val", node.meta)

    def test_input_container_type(self):
        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        def f(x: torch.Tensor, y: list[torch.Tensor]) -> dict[str, torch.Tensor]:
            return {"a": x.sum() + sum(y).sum()}

        inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])

@@ -11,7 +11,7 @@ import random
import sys
import unittest
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, TypeVar
from typing import Any, Generic, TypeVar
from typing_extensions import NamedTuple
from unittest.mock import patch

@@ -3896,8 +3896,8 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
        @dataclass
        class Output:
            scalar: int = 2
            named_tensors: Dict[str, torch.Tensor] = field(default_factory=dict)
            lists: List[torch.Tensor] = field(default_factory=list)
            named_tensors: dict[str, torch.Tensor] = field(default_factory=dict)
            lists: list[torch.Tensor] = field(default_factory=list)

            def scale(self):
                return self.scalar * 2
@ -12,7 +12,7 @@ import types
|
|||
import unittest
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict, NamedTuple, Tuple
|
||||
from typing import NamedTuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
|
@ -602,7 +602,7 @@ class LazyMLP(torch.nn.Module):
|
|||
|
||||
|
||||
class MyInput(NamedTuple):
|
||||
x: Dict[str, Dict[str, torch.Tensor]]
|
||||
x: dict[str, dict[str, torch.Tensor]]
|
||||
y: torch.Tensor
|
||||
|
||||
|
||||
|
|
@ -2311,7 +2311,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
m = TestModule()
|
||||
|
||||
def forward_hook(
|
||||
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
|
||||
module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return 2 * output + 1
|
||||
|
||||
|
|
@ -2358,7 +2358,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
m = TestModule()
|
||||
|
||||
def forward_hook(
|
||||
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
|
||||
module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return 2 * output + 1
|
||||
|
||||
|
|
@ -2407,7 +2407,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(compiled_func(inp).item(), 15)
|
||||
|
||||
def new_forward_hook(
|
||||
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
|
||||
module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return 2 * output + 2
|
||||
|
||||
|
|
@ -2426,7 +2426,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||
m = TestModule()
|
||||
|
||||
def forward_hook(
|
||||
module: torch.nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor
|
||||
module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return 2 * output + 1
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional
|
||||
from typing import Callable, List, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
|
@ -50,20 +50,20 @@ class Variable:
|
|||
def sum(self, name: Optional[str] = None) -> "Variable":
|
||||
return operator_sum(self, name)
|
||||
|
||||
def expand(self, sizes: List[int]) -> "Variable":
|
||||
def expand(self, sizes: list[int]) -> "Variable":
|
||||
return operator_expand(self, sizes)
|
||||
|
||||
|
||||
class TapeEntry(NamedTuple):
|
||||
# names of the inputs to the original computation
|
||||
inputs: List[str]
|
||||
inputs: list[str]
|
||||
# names of the outputs of the original computation
|
||||
outputs: List[str]
|
||||
outputs: list[str]
|
||||
# apply chain rule
|
||||
propagate: "Callable[List[Variable], List[Variable]]"
|
||||
propagate: "Callable[list[Variable], list[Variable]]"
|
||||
|
||||
|
||||
gradient_tape: List[TapeEntry] = []
|
||||
gradient_tape: list[TapeEntry] = []
|
||||
|
||||
|
||||
def reset_tape():
|
||||
|
|
@ -72,9 +72,9 @@ def reset_tape():
|
|||
_name = 0
|
||||
|
||||
|
||||
def grad(L, desired_results: List[Variable]) -> List[Variable]:
|
||||
def grad(L, desired_results: list[Variable]) -> list[Variable]:
|
||||
# this map holds dL/dX for all values X
|
||||
dL_d: Dict[str, Variable] = {}
|
||||
dL_d: dict[str, Variable] = {}
|
||||
# It starts by initializing the 'seed' dL/dL, which is 1
|
||||
dL_d[L.name] = Variable(torch.ones(()))
|
||||
# print(f'd{L.name} ------------------------')
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
import contextlib
|
||||
import dis
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
|
|
@ -33,7 +32,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
Emit code to reconstruct only the key that changed
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
self.assertEqual(len(build_map), 1)
|
||||
# reconstruct only d[40]
|
||||
|
|
@ -57,7 +56,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
If something is pop'ed from the dict, we reconstruct everything
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
self.assertEqual(len(build_map), 1)
|
||||
# reconstruct everything
|
||||
|
|
@ -84,7 +83,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
If something is pop'ed from the dict, we reconstruct everything
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
self.assertEqual(len(build_map), 1)
|
||||
# reconstruct everything
|
||||
|
|
@ -128,7 +127,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
If something is deleted from the dict, we reconstruct everything
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
self.assertEqual(len(build_map), 1)
|
||||
# reconstruct everything
|
||||
|
|
@ -154,7 +153,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
dict.get shouldn't affect anything
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
self.assertEqual(len(build_map), 1)
|
||||
self.assertEqual(build_map[0].argval, 1)
|
||||
|
|
@ -180,7 +179,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
If dict.clear() is used, we reconstruct everything
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
self.assertEqual(len(build_map), 1)
|
||||
# reconstruct everything
|
||||
|
|
@ -206,7 +205,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
If dict is created inside a function, everything needs to be reconstructed
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
self.assertEqual(len(build_map), 1)
|
||||
# reconstruct everything
|
||||
|
|
@ -231,7 +230,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
PyTorch shouldn't codegen any key/value when functional_call is used
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
# don't reconstruct anything
|
||||
self.assertEqual(len(build_map), 0)
|
||||
|
|
@ -260,7 +259,7 @@ class ReconstructTest(torch._dynamo.test_case.TestCase):
|
|||
PyTorch shouldn't codegen any key/value when functional_call is used
|
||||
"""
|
||||
|
||||
def hook(instructions: List[dis.Instruction]):
|
||||
def hook(instructions: list[dis.Instruction]):
|
||||
build_map = _filter_instructions(instructions, "BUILD_MAP")
|
||||
# don't reconstruct anything
|
||||
self.assertEqual(len(build_map), 0)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from collections.abc import Iterator
|
|||
from copy import deepcopy
|
||||
from enum import Enum, IntEnum
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, List, Literal, Tuple, TypedDict
|
||||
from typing import Any, Literal, TypedDict
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -712,7 +712,7 @@ def create_rand_mask_from_inputs(
|
|||
class SequentialAppendList(torch.nn.Sequential):
|
||||
"""from timm/models/vovnet.py"""
|
||||
|
||||
def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, concat_list: list[torch.Tensor]) -> torch.Tensor:
|
||||
for i, module in enumerate(self):
|
||||
if i == 0:
|
||||
concat_list.append(module(x))
|
||||
|
|
@ -4108,7 +4108,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
def test_graph_break_on_jit_isinstance(self):
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
if torch.jit.isinstance(x, List[str]):
|
||||
if torch.jit.isinstance(x, list[str]):
|
||||
return x * 2
|
||||
return x
|
||||
|
||||
|
|
@ -4819,14 +4819,14 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
|
||||
def test_detectron2_instances_cat(self):
|
||||
class Instances:
|
||||
def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
|
||||
def __init__(self, image_size: tuple[int, int], **kwargs: Any):
|
||||
self._image_size = image_size
|
||||
self._fields: Dict[str, Any] = {}
|
||||
self._fields: dict[str, Any] = {}
|
||||
for k, v in kwargs.items():
|
||||
self.set(k, v)
|
||||
|
||||
@property
|
||||
def image_size(self) -> Tuple[int, int]:
|
||||
def image_size(self) -> tuple[int, int]:
|
||||
return self._image_size
|
||||
|
||||
def __setattr__(self, name: str, val: Any) -> None:
|
||||
|
|
@ -4861,7 +4861,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
return self._fields[name]
|
||||
|
||||
@staticmethod
|
||||
def cat(instance_lists: List["Instances"]) -> "Instances":
|
||||
def cat(instance_lists: list["Instances"]) -> "Instances":
|
||||
assert all(isinstance(i, Instances) for i in instance_lists)
|
||||
assert len(instance_lists) > 0
|
||||
if len(instance_lists) == 1:
|
||||
|
|
@ -5997,7 +5997,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
|
||||
# https://github.com/pytorch/pytorch/issues/88813
|
||||
def test_return_value_duplication_tensor(self) -> None:
|
||||
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return val * 2, val * 2
|
||||
|
||||
x = torch.randn(2, requires_grad=True)
|
||||
|
|
@ -6016,7 +6016,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
|
||||
# https://github.com/pytorch/pytorch/issues/114344
|
||||
def test_return_value_duplication_mixed_grad(self) -> None:
|
||||
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
with torch.no_grad():
|
||||
out0 = val + 1
|
||||
out1 = val + 1
|
||||
|
|
@ -6033,7 +6033,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|||
|
||||
# https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371
|
||||
def test_return_value_duplication_scalar(self) -> None:
|
||||
def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
x, y = val * 2, val * 2
|
||||
return x[0], y[0]
|
||||
|
||||
|
|
|
|||
|
|
@ -1822,7 +1822,7 @@ class GraphModule(torch.nn.Module):
|
|||
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
|
||||
@parametrize("dynamic", [True, False])
|
||||
def test_mark_static_with_subclass_desugaring(self, dynamic):
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from torch._dynamo.decorators import mark_static_address
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
|
@ -1835,9 +1835,9 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
def inner_compile(
|
||||
gm: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor],
|
||||
example_inputs: list[torch.Tensor],
|
||||
cudagraphs: Optional[BoxedBool] = None,
|
||||
static_input_idxs: Optional[List[int]] = None,
|
||||
static_input_idxs: Optional[list[int]] = None,
|
||||
is_backward: bool = False,
|
||||
graph_id: Optional[int] = None,
|
||||
cpp_wrapper: bool = False,
|
||||
|
|
@ -1845,7 +1845,7 @@ class GraphModule(torch.nn.Module):
|
|||
is_inference: bool = False,
|
||||
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
|
||||
layout_opt: Optional[bool] = None,
|
||||
extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
|
||||
extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None,
|
||||
):
|
||||
if dynamic:
|
||||
self.assertEqual(static_input_idxs, [2, 3, 4])
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch._dynamo.config
|
||||
|
|
@ -23,7 +22,7 @@ except ImportError:
|
|||
|
||||
@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
|
||||
class BucketizeMod(torch.nn.Module):
|
||||
def __init__(self, feature_boundaries: Dict[str, List[float]]):
|
||||
def __init__(self, feature_boundaries: dict[str, list[float]]):
|
||||
super().__init__()
|
||||
self.bucket_w = torch.nn.ParameterDict()
|
||||
self.boundaries_dict = {}
|
||||
|
|
@ -84,7 +83,7 @@ class TorchRecTests(TestCase):
|
|||
|
||||
@torch.compile(backend=counter, fullgraph=True, dynamic=True)
|
||||
def f(id_list_features: KeyedJaggedTensor):
|
||||
id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
|
||||
id_list_jt_dict: dict[str, JaggedTensor] = id_list_features.to_dict()
|
||||
pooled_embeddings = {}
|
||||
# TODO: run feature processor
|
||||
for emb_module, feature_names in tables:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import math
|
|||
import types
|
||||
import unittest
|
||||
import warnings
|
||||
from typing import Any, Dict, Set
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._dynamo.config as config
|
||||
|
|
@ -103,10 +103,10 @@ class AllowedObjects:
|
|||
from the heuristic defined in `gen_allowed_objs_and_ids`.
|
||||
"""
|
||||
|
||||
object_ids: Dict[int, str]
|
||||
c_binding_in_graph_functions: Set[Any]
|
||||
non_c_binding_in_graph_functions: Set[Any]
|
||||
name_rule_map: Dict[str, Any]
|
||||
object_ids: dict[int, str]
|
||||
c_binding_in_graph_functions: set[Any]
|
||||
non_c_binding_in_graph_functions: set[Any]
|
||||
name_rule_map: dict[str, Any]
|
||||
|
||||
|
||||
def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
|
@ -65,11 +65,11 @@ class TestConverter(TestCase):
|
|||
self,
|
||||
M,
|
||||
tracing_inputs,
|
||||
option: Optional[List[str]] = None,
|
||||
option: Optional[list[str]] = None,
|
||||
check_persistent=False,
|
||||
lifted_tensor_constants=None,
|
||||
runtime_inputs: Optional[List[Any]] = None,
|
||||
) -> List[ExportedProgram]:
|
||||
runtime_inputs: Optional[list[Any]] = None,
|
||||
) -> list[ExportedProgram]:
|
||||
# By default, it tests both jit.trace and jit.script.
|
||||
if option is None:
|
||||
option = ["trace", "script"]
|
||||
|
|
@ -130,7 +130,7 @@ class TestConverter(TestCase):
|
|||
self._check_tensor_list_equal(ep_out, orig_out)
|
||||
return ep_list
|
||||
|
||||
def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]):
|
||||
def _check_tensor_list_equal(self, xs: list[torch.Tensor], ys: list[torch.Tensor]):
|
||||
self.assertEqual(len(xs), len(ys))
|
||||
for x, y in zip(xs, ys):
|
||||
if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
|
||||
|
|
@ -219,7 +219,7 @@ class TestConverter(TestCase):
|
|||
self._check_equal_ts_ep_converter(Module(), inp)
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: List[int]):
|
||||
def forward(self, x: list[int]):
|
||||
length = len(x)
|
||||
return torch.ones(length)
|
||||
|
||||
|
|
@ -228,7 +228,7 @@ class TestConverter(TestCase):
|
|||
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: Dict[int, str]):
|
||||
def forward(self, x: dict[int, str]):
|
||||
length = len(x)
|
||||
return torch.ones(length)
|
||||
|
||||
|
|
@ -237,7 +237,7 @@ class TestConverter(TestCase):
|
|||
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: Dict[bool, str]):
|
||||
def forward(self, x: dict[bool, str]):
|
||||
length = len(x)
|
||||
return torch.ones(length)
|
||||
|
||||
|
|
@ -246,7 +246,7 @@ class TestConverter(TestCase):
|
|||
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: Dict[float, str]):
|
||||
def forward(self, x: dict[float, str]):
|
||||
length = len(x)
|
||||
return torch.ones(length)
|
||||
|
||||
|
|
@ -255,7 +255,7 @@ class TestConverter(TestCase):
|
|||
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: Dict[torch.Tensor, str]):
|
||||
def forward(self, x: dict[torch.Tensor, str]):
|
||||
length = len(x)
|
||||
return torch.ones(length)
|
||||
|
||||
|
|
@ -273,7 +273,7 @@ class TestConverter(TestCase):
|
|||
def test_aten_add_t(self):
|
||||
# python list append
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: List[torch.Tensor]):
|
||||
def forward(self, x: list[torch.Tensor]):
|
||||
out = []
|
||||
out = out + x
|
||||
a = torch.cat(out)
|
||||
|
|
@ -531,7 +531,7 @@ class TestConverter(TestCase):
|
|||
class Module(torch.nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, y: torch.Tensor
|
||||
) -> Tuple[bool, torch.Tensor]:
|
||||
) -> tuple[bool, torch.Tensor]:
|
||||
z = x + 1
|
||||
return x is y, z
|
||||
|
||||
|
|
@ -546,7 +546,7 @@ class TestConverter(TestCase):
|
|||
class Module(torch.nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, y: torch.Tensor
|
||||
) -> Tuple[bool, torch.Tensor]:
|
||||
) -> tuple[bool, torch.Tensor]:
|
||||
z = x + 1
|
||||
return x is not y, z
|
||||
|
||||
|
|
@ -558,7 +558,7 @@ class TestConverter(TestCase):
|
|||
class Module(torch.nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, y: torch.Tensor
|
||||
) -> Tuple[bool, torch.Tensor]:
|
||||
) -> tuple[bool, torch.Tensor]:
|
||||
z = x + 1
|
||||
return not (x is not y), z
|
||||
|
||||
|
|
@ -573,7 +573,7 @@ class TestConverter(TestCase):
|
|||
return x + y
|
||||
|
||||
class MUnpackTuple(torch.nn.Module):
|
||||
def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]):
|
||||
def forward(self, x_tuple: tuple[torch.Tensor, torch.Tensor]):
|
||||
x, y = x_tuple
|
||||
x = x.cos()
|
||||
return x + y
|
||||
|
|
@ -904,7 +904,7 @@ class TestConverter(TestCase):
|
|||
return x.dtype in [torch.int8]
|
||||
|
||||
class MTensorIn(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
|
||||
def forward(self, x: torch.Tensor, x_dict: dict[torch.Tensor, str]):
|
||||
return x in x_dict
|
||||
|
||||
# Traced function must return output that has tensors.
|
||||
|
|
@ -1118,14 +1118,14 @@ class TestConverter(TestCase):
|
|||
|
||||
def test_prim_tolist(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> List[int]:
|
||||
def forward(self, x: torch.Tensor) -> list[int]:
|
||||
return x.tolist()
|
||||
|
||||
inp = (torch.tensor([1, 2, 3]),)
|
||||
self._check_equal_ts_ep_converter(Module(), inp, ["script"])
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> List[List[int]]:
|
||||
def forward(self, x: torch.Tensor) -> list[list[int]]:
|
||||
return x.tolist()
|
||||
|
||||
inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
|
||||
|
|
@ -1353,7 +1353,7 @@ class TestConverter(TestCase):
|
|||
|
||||
def test_aten_append_t(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: List[torch.Tensor]):
|
||||
def forward(self, x: list[torch.Tensor]):
|
||||
out = []
|
||||
out.append(x[0] + x[1])
|
||||
out.append(x[0] - x[1])
|
||||
|
|
@ -1381,7 +1381,7 @@ class TestConverter(TestCase):
|
|||
self._check_equal_ts_ep_converter(M1(), inp, ["script"])
|
||||
|
||||
def test_ts2ep_with_loop(self):
|
||||
def func1(x, x_list: List[torch.Tensor]):
|
||||
def func1(x, x_list: list[torch.Tensor]):
|
||||
a, b, c = x, x, x
|
||||
for _ in range(1, 5, 2):
|
||||
for k in range(5):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.export import Dim, export
|
||||
|
|
@ -325,7 +324,7 @@ class TestDraftExport(TestCase):
|
|||
return torch.ops.mylib.foo(a)
|
||||
|
||||
@torch.library.custom_op("mylib::foo", mutates_args={})
|
||||
def foo(a: torch.Tensor) -> List[torch.Tensor]:
|
||||
def foo(a: torch.Tensor) -> list[torch.Tensor]:
|
||||
x = a * 2
|
||||
y = a.repeat(2, 2)
|
||||
z = a.to(torch.bfloat16)
|
||||
|
|
@ -370,7 +369,7 @@ class TestDraftExport(TestCase):
|
|||
return torch.ops.mylib.foo(a)
|
||||
|
||||
@torch.library.custom_op("mylib::foo", mutates_args={})
|
||||
def foo(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def foo(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return a * 2, a + 2
|
||||
|
||||
@foo.register_fake
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# Owner(s): ["oncall: export"]
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch._export.passes.lift_constants_pass import (
|
||||
|
|
@ -34,9 +34,9 @@ class GraphBuilder:
|
|||
self.graph = torch.fx.Graph()
|
||||
self.nodes = {}
|
||||
self.values = {}
|
||||
self.nn_module_stack_key: Dict[str, int] = {}
|
||||
self.nn_module_stack_key: dict[str, int] = {}
|
||||
self.latest_id = 0
|
||||
self.input_to_kind: Dict[torch.fx.Node, InputKind] = {}
|
||||
self.input_to_kind: dict[torch.fx.Node, InputKind] = {}
|
||||
|
||||
def input(self, name: str, value: torch.Tensor, kind: InputKind):
|
||||
node = self.graph.placeholder(name)
|
||||
|
|
@ -87,7 +87,7 @@ class GraphBuilder:
|
|||
|
||||
def create_nn_module_stack(
|
||||
self, module_fqn: str
|
||||
) -> OrderedDict[int, Tuple[str, type]]:
|
||||
) -> OrderedDict[int, tuple[str, type]]:
|
||||
cur_name = ""
|
||||
nn_module_stack = OrderedDict()
|
||||
for atom in module_fqn.split("."):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import math
|
|||
import operator
|
||||
import unittest
|
||||
from re import escape
|
||||
from typing import List, Set
|
||||
|
||||
import torch
|
||||
from functorch.experimental.control_flow import cond
|
||||
|
|
@ -75,11 +74,11 @@ class _AtenAddOperatorSupport(OperatorSupport):
|
|||
return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor}
|
||||
|
||||
|
||||
def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]:
|
||||
def _to_partition_names(partitions: list[Partition]) -> list[set[str]]:
|
||||
return [{n.name for n in p.nodes} for p in partitions]
|
||||
|
||||
|
||||
def _get_output_names(gm: torch.fx.GraphModule) -> List[str]:
|
||||
def _get_output_names(gm: torch.fx.GraphModule) -> list[str]:
|
||||
output_node = next(n for n in gm.graph.nodes if n.op == "output")
|
||||
args = pytree.tree_leaves(output_node.args)
|
||||
# if isinstance(args, tuple) and len(args) == 1:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import unittest
|
|||
import warnings
|
||||
from contextlib import ContextDecorator, nullcontext
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
from common_utils import decorate, decorateForModules, skip, skipOps, xfail
|
||||
|
|
@ -319,8 +319,8 @@ class TestAOTAutograd(AOTTestCase):
|
|||
def run_autograd(
|
||||
self,
|
||||
f: Callable,
|
||||
fw_graph_cell: List[Optional[Callable]],
|
||||
decompositions: Optional[Dict],
|
||||
fw_graph_cell: list[Optional[Callable]],
|
||||
decompositions: Optional[dict],
|
||||
keep_input_mutations: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
|
|
@ -358,11 +358,11 @@ class TestAOTAutograd(AOTTestCase):
|
|||
def verify_aot_autograd(
|
||||
self,
|
||||
f,
|
||||
inp_: Union[Callable, List[Any]],
|
||||
inp_: Union[Callable, list[Any]],
|
||||
*,
|
||||
test_mutation: bool = False,
|
||||
keep_inp_mutations: bool = False,
|
||||
decompositions: Optional[Dict] = None,
|
||||
decompositions: Optional[dict] = None,
|
||||
dynamic: bool = False,
|
||||
# Only active when inp_ is Callable.
|
||||
# TODO: probably consolidate all tests to make inp a Callable.
|
||||
|
|
@ -6748,8 +6748,8 @@ class TestAOTAutogradWithDynamo(TestAOTAutograd):
|
|||
def run_autograd(
|
||||
self,
|
||||
f: Callable,
|
||||
fw_graph_cell: List[Optional[Callable]],
|
||||
decompositions: Optional[Dict],
|
||||
fw_graph_cell: list[Optional[Callable]],
|
||||
decompositions: Optional[dict],
|
||||
keep_input_mutations: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
|
|
@ -6880,8 +6880,8 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
|
|||
def run_autograd(
|
||||
self,
|
||||
f: Callable,
|
||||
fw_graph_cell: List[Optional[Callable]],
|
||||
decompositions: Optional[Dict],
|
||||
fw_graph_cell: list[Optional[Callable]],
|
||||
decompositions: Optional[dict],
|
||||
keep_input_mutations: bool,
|
||||
dynamic: bool,
|
||||
):
|
||||
|
|
@ -6904,11 +6904,11 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
|
|||
def verify_aot_autograd(
|
||||
self,
|
||||
f,
|
||||
inp_: Union[Callable, List[Any]],
|
||||
inp_: Union[Callable, list[Any]],
|
||||
*,
|
||||
test_mutation: bool = False,
|
||||
keep_inp_mutations: bool = False,
|
||||
decompositions: Optional[Dict] = None,
|
||||
decompositions: Optional[dict] = None,
|
||||
dynamic: bool = False,
|
||||
# Only active when inp_ is Callable.
|
||||
# TODO: probably consolidate all tests to make inp a Callable.
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Any, Callable
|
||||
from unittest import mock
|
||||
|
||||
from functorch.einops._parsing import (
|
||||
|
|
@ -206,7 +206,7 @@ class TestParsedExpression(TestCase):
|
|||
|
||||
class TestParsingUtils(TestCase):
|
||||
def test_parse_pattern_number_of_arrows(self) -> None:
|
||||
axes_lengths: Dict[str, int] = {}
|
||||
axes_lengths: dict[str, int] = {}
|
||||
|
||||
too_many_arrows_pattern = "a -> b -> c -> d"
|
||||
with self.assertRaises(ValueError):
|
||||
|
|
@ -220,13 +220,13 @@ class TestParsingUtils(TestCase):
|
|||
parse_pattern(just_right_arrows, axes_lengths)
|
||||
|
||||
def test_ellipsis_invalid_identifier(self) -> None:
|
||||
axes_lengths: Dict[str, int] = {"a": 1, _ellipsis: 2}
|
||||
axes_lengths: dict[str, int] = {"a": 1, _ellipsis: 2}
|
||||
pattern = f"a {_ellipsis} -> {_ellipsis} a"
|
||||
with self.assertRaises(ValueError):
|
||||
parse_pattern(pattern, axes_lengths)
|
||||
|
||||
def test_ellipsis_matching(self) -> None:
|
||||
axes_lengths: Dict[str, int] = {}
|
||||
axes_lengths: dict[str, int] = {}
|
||||
|
||||
pattern = "a -> a ..."
|
||||
with self.assertRaises(ValueError):
|
||||
|
|
@ -240,7 +240,7 @@ class TestParsingUtils(TestCase):
|
|||
parse_pattern(pattern, axes_lengths)
|
||||
|
||||
def test_left_parenthesized_ellipsis(self) -> None:
|
||||
axes_lengths: Dict[str, int] = {}
|
||||
axes_lengths: dict[str, int] = {}
|
||||
|
||||
pattern = "(...) -> ..."
|
||||
with self.assertRaises(ValueError):
|
||||
|
|
@ -254,7 +254,7 @@ class MaliciousRepr:
|
|||
|
||||
class TestValidateRearrangeExpressions(TestCase):
|
||||
def test_validate_axes_lengths_are_integers(self) -> None:
|
||||
axes_lengths: Dict[str, Any] = {"a": 1, "b": 2, "c": 3}
|
||||
axes_lengths: dict[str, Any] = {"a": 1, "b": 2, "c": 3}
|
||||
pattern = "a b c -> c b a"
|
||||
left, right = parse_pattern(pattern, axes_lengths)
|
||||
validate_rearrange_expressions(left, right, axes_lengths)
|
||||
|
|
@ -265,7 +265,7 @@ class TestValidateRearrangeExpressions(TestCase):
|
|||
validate_rearrange_expressions(left, right, axes_lengths)
|
||||
|
||||
def test_non_unitary_anonymous_axes_raises_error(self) -> None:
|
||||
axes_lengths: Dict[str, int] = {}
|
||||
axes_lengths: dict[str, int] = {}
|
||||
|
||||
left_non_unitary_axis = "a 2 -> 1 1 a"
|
||||
left, right = parse_pattern(left_non_unitary_axis, axes_lengths)
|
||||
|
|
@ -278,7 +278,7 @@ class TestValidateRearrangeExpressions(TestCase):
|
|||
validate_rearrange_expressions(left, right, axes_lengths)
|
||||
|
||||
def test_identifier_mismatch(self) -> None:
|
||||
axes_lengths: Dict[str, int] = {}
|
||||
axes_lengths: dict[str, int] = {}
|
||||
|
||||
mismatched_identifiers = "a -> a b"
|
||||
left, right = parse_pattern(mismatched_identifiers, axes_lengths)
|
||||
|
|
@ -291,7 +291,7 @@ class TestValidateRearrangeExpressions(TestCase):
|
|||
validate_rearrange_expressions(left, right, axes_lengths)
|
||||
|
||||
def test_unexpected_axes_lengths(self) -> None:
|
||||
axes_lengths: Dict[str, int] = {"c": 2}
|
||||
axes_lengths: dict[str, int] = {"c": 2}
|
||||
|
||||
pattern = "a b -> b a"
|
||||
left, right = parse_pattern(pattern, axes_lengths)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -34,7 +33,7 @@ from functorch.einops import rearrange
|
|||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
identity_patterns: List[str] = [
|
||||
identity_patterns: list[str] = [
|
||||
"...->...",
|
||||
"a b c d e-> a b c d e",
|
||||
"a b c d e ...-> ... a b c d e",
|
||||
|
|
@ -45,7 +44,7 @@ identity_patterns: List[str] = [
|
|||
"a ... c d e -> a (...) c d e",
|
||||
]
|
||||
|
||||
equivalent_rearrange_patterns: List[Tuple[str, str]] = [
|
||||
equivalent_rearrange_patterns: list[tuple[str, str]] = [
|
||||
("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "),
|
||||
("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"),
|
||||
("a b c d e -> a b c d e", "... -> ... "),
|
||||
|
|
@ -149,7 +148,7 @@ class TestRearrange(TestCase):
|
|||
|
||||
def test_concatenations_and_stacking(self) -> None:
|
||||
for n_arrays in [1, 2, 5]:
|
||||
shapes: List[List[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
|
||||
shapes: list[list[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
|
||||
for shape in shapes:
|
||||
arrays1 = [
|
||||
torch.arange(i, i + np.prod(shape, dtype=int)).reshape(shape)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# Owner(s): ["module: fx"]
|
||||
import copy
|
||||
import unittest
|
||||
from typing import Optional, Set, Type
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
|
@ -38,7 +38,7 @@ class TestDCE(TestCase):
|
|||
self,
|
||||
m: torch.nn.Module,
|
||||
expect_dce_changes: bool,
|
||||
modules_to_be_leafs: Optional[Set[Type]] = None,
|
||||
modules_to_be_leafs: Optional[set[type]] = None,
|
||||
custom: bool = False,
|
||||
):
|
||||
class TestTracer(torch.fx.Tracer):
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
from __future__ import annotations # type: ignore[attr-defined]
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
from torch.fx import symbolic_trace
|
||||
|
||||
|
|
@ -27,13 +25,13 @@ class M2(torch.nn.Module):
|
|||
|
||||
# Non-torch annotation with no internal forward references
|
||||
class M3(torch.nn.Module):
|
||||
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
|
||||
def forward(self, x: list[torch.Tensor], a: A) -> torch.Tensor:
|
||||
return a(x[0])
|
||||
|
||||
|
||||
# Non-torch annotation with internal forward references
|
||||
class M4(torch.nn.Module):
|
||||
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
|
||||
def forward(self, x: list[torch.Tensor], a: A) -> torch.Tensor:
|
||||
return a(x[0])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# Owner(s): ["module: fx"]
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx.passes.split_utils import split_by_tags
|
||||
|
|
@ -63,8 +62,8 @@ class TestSplitByTags(TestCase):
|
|||
|
||||
@staticmethod
|
||||
def trace_and_tag(
|
||||
module: torch.nn.Module, tags: List[str]
|
||||
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
|
||||
module: torch.nn.Module, tags: list[str]
|
||||
) -> tuple[torch.fx.GraphModule, dict[str, list[str]]]:
|
||||
"""
|
||||
Test simple gm consists of nodes with tag (only show call_module nodes here):
|
||||
linear1 - tag: "red"
|
||||
|
|
@ -167,8 +166,8 @@ class TestSplitOutputType(TestCase):
|
|||
|
||||
@staticmethod
|
||||
def trace_and_tag(
|
||||
module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
|
||||
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
|
||||
module: torch.nn.Module, inputs: torch.Tensor, tags: list[str]
|
||||
) -> tuple[torch.fx.GraphModule, dict[str, list[str]]]:
|
||||
"""
|
||||
Test simple gm consists of nodes with tag (only show call_module nodes here):
|
||||
conv - tag: "red"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import unittest
|
||||
from collections import deque
|
||||
from functools import partial
|
||||
from typing import List, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
|
@ -390,7 +390,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
|||
return output
|
||||
|
||||
def add_hooks(module, config):
|
||||
handles: List[RemovableHandle] = []
|
||||
handles: list[RemovableHandle] = []
|
||||
q = deque([(module.__class__.__name__, module)])
|
||||
while q:
|
||||
name, m = q.pop()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import os
|
|||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, Tuple
|
||||
from unittest import skip
|
||||
|
||||
import torch
|
||||
|
|
@ -1744,7 +1743,7 @@ class AOTInductorTestsTemplate:
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: Dict[str, torch.Tensor]):
|
||||
def forward(self, x: dict[str, torch.Tensor]):
|
||||
device = next(iter(x.values())).device
|
||||
add_ = torch.zeros(5, device=device)
|
||||
mul_ = torch.ones(5, device=device)
|
||||
|
|
@ -2660,7 +2659,7 @@ class AOTInductorTestsTemplate:
|
|||
def forward(
|
||||
self,
|
||||
self_tensor: torch.Tensor,
|
||||
indices: Tuple[torch.Tensor],
|
||||
indices: tuple[torch.Tensor],
|
||||
values: torch.Tensor,
|
||||
):
|
||||
return torch.index_put(
|
||||
|
|
@ -4285,7 +4284,7 @@ def fail_cpu(is_skip=False):
|
|||
)
|
||||
|
||||
|
||||
def fail_gpu(suffixes: Tuple[str, ...], is_skip=False):
|
||||
def fail_gpu(suffixes: tuple[str, ...], is_skip=False):
|
||||
return TestFailure(
|
||||
suffixes,
|
||||
is_skip=is_skip,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import pickle
|
|||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
|
@ -1434,7 +1434,7 @@ class TestCudaCompileCommand(TestCase):
|
|||
with mock.patch("subprocess.check_output") as check_output_mock:
|
||||
CUDACodeCache.compile("test123.cu", "so", ["-Wsomething"])
|
||||
check_output_mock.assert_called()
|
||||
cmd_parts: List[str] = check_output_mock.call_args[0][0]
|
||||
cmd_parts: list[str] = check_output_mock.call_args[0][0]
|
||||
assert cmd_parts[0] == "nvcc", cmd_parts
|
||||
assert "-Wsomething" in cmd_parts, cmd_parts
|
||||
assert "-DNDEBUG" in cmd_parts, cmd_parts
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
import unittest
|
||||
from typing import Any, Dict, List, Type
|
||||
from typing import Any
|
||||
|
||||
import sympy
|
||||
|
||||
|
|
@ -160,11 +160,11 @@ class TestFixedConfigs(TestCase):
|
|||
class MyHeuristics(InductorChoices):
|
||||
def triton_kernel_kwargs(
|
||||
self,
|
||||
kernel_cls: Type[TritonKernel],
|
||||
kernel_cls: type[TritonKernel],
|
||||
features: SIMDKernelFeatures,
|
||||
groups: List[sympy.Expr],
|
||||
kernel_kwargs: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
groups: list[sympy.Expr],
|
||||
kernel_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
**kernel_kwargs,
|
||||
"override_cooperative_reduction": cooperative,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import math
|
||||
import os
|
||||
import unittest
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
from unittest import mock
|
||||
|
||||
from torch.export import Dim
|
||||
|
|
@ -114,7 +114,7 @@ class TestCutlassBackend(TestCase):
|
|||
) as mocked_select_algorithm:
|
||||
Y_compiled = torch.compile(mm, dynamic=False)(a, b)
|
||||
Y = mm(a, b)
|
||||
passed_choice_callers: List[ChoiceCaller] = mocked_select_algorithm[0][
|
||||
passed_choice_callers: list[ChoiceCaller] = mocked_select_algorithm[0][
|
||||
1
|
||||
]
|
||||
assert all(
|
||||
|
|
@ -573,7 +573,7 @@ class TestCutlassBackend(TestCase):
|
|||
return torch.addmm(x, a, b, alpha=alpha, beta=beta)
|
||||
|
||||
def compare_results(
|
||||
m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int]
|
||||
m: int, k: int, n: int, alpha: float, beta: float, x_shape: list[int]
|
||||
) -> None:
|
||||
x = torch.randn(x_shape).cuda().half()
|
||||
a = torch.randn(m, k).cuda().half()
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from collections import namedtuple
|
|||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Union
|
||||
from unittest import expectedFailure, skip, skipUnless
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -511,7 +511,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
block_mask,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
page_size: int = 128,
|
||||
) -> Tuple[Tensor, Tensor, BlockMask, _score_mod_signature]:
|
||||
) -> tuple[Tensor, Tensor, BlockMask, _score_mod_signature]:
|
||||
assert block_mask is not None, "Must provide block_mask"
|
||||
Q_B, Q_H, Q_S, _ = q.shape
|
||||
KV_B, KV_H, KV_S, QK_D = k.shape
|
||||
|
|
@ -596,7 +596,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
v: Tensor,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
block_mask: Optional[BlockMask] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
B, Q_H, Q_S, KV_H, KV_S = (
|
||||
q.shape[0],
|
||||
q.shape[1],
|
||||
|
|
@ -797,7 +797,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
|
||||
def run_dynamic_test(
|
||||
self,
|
||||
score_mask_mod: Tuple[Callable, Callable],
|
||||
score_mask_mod: tuple[Callable, Callable],
|
||||
dtype: torch.dtype = torch.float16,
|
||||
B: int = B,
|
||||
H: int = H,
|
||||
|
|
@ -1089,7 +1089,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
@common_utils.parametrize("dtype", test_dtypes_fast)
|
||||
@common_utils.parametrize("score_mask_mod", test_score_mask_mod_map.items())
|
||||
def test_builtin_score_mods_dynamic(
|
||||
self, dtype: torch.dtype, score_mask_mod: Tuple[Callable, Callable]
|
||||
self, dtype: torch.dtype, score_mask_mod: tuple[Callable, Callable]
|
||||
):
|
||||
self.run_dynamic_test(score_mask_mod, dtype)
|
||||
|
||||
|
|
@ -1127,7 +1127,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
self,
|
||||
dtype: torch.dtype,
|
||||
score_mod: Callable,
|
||||
BLOCK_SIZE: Union[int, Tuple[int, int]],
|
||||
BLOCK_SIZE: Union[int, tuple[int, int]],
|
||||
):
|
||||
block_mask = create_block_mask(
|
||||
noop_mask, B, H, S, S, BLOCK_SIZE=BLOCK_SIZE, device=self.device
|
||||
|
|
@ -1142,8 +1142,8 @@ class TestFlexAttention(InductorTestCase):
|
|||
def test_kv_batch_broadcast(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
batch_dims: Tuple[int, int],
|
||||
head_dims: Tuple[int, int],
|
||||
batch_dims: tuple[int, int],
|
||||
head_dims: tuple[int, int],
|
||||
score_mod: Callable,
|
||||
):
|
||||
Hq, Hkv = head_dims
|
||||
|
|
@ -1175,8 +1175,8 @@ class TestFlexAttention(InductorTestCase):
|
|||
def test_kv_batch_broadcast_causal_mask(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
batch_dims: Tuple[int, int],
|
||||
head_dims: Tuple[int, int],
|
||||
batch_dims: tuple[int, int],
|
||||
head_dims: tuple[int, int],
|
||||
score_mod: Callable,
|
||||
):
|
||||
Hq, Hkv = head_dims
|
||||
|
|
@ -3616,7 +3616,7 @@ class TestBlockMask(InductorTestCase):
|
|||
|
||||
@supported_platform
|
||||
@common_utils.parametrize("BLOCK_SIZE", [32, 64, 128, 256, (32, 64), (64, 32)])
|
||||
def test_block_size_changes(self, BLOCK_SIZE: Union[int, Tuple[int, int]]):
|
||||
def test_block_size_changes(self, BLOCK_SIZE: Union[int, tuple[int, int]]):
|
||||
B, H, Q_LEN, KV_LEN = 4, 2, 2048, 2048
|
||||
|
||||
if isinstance(BLOCK_SIZE, int):
|
||||
|
|
@ -3990,7 +3990,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
|
|||
)
|
||||
|
||||
def length_to_offsets(
|
||||
lengths: List[int], device: Union[str, torch.device]
|
||||
lengths: list[int], device: Union[str, torch.device]
|
||||
) -> Tensor:
|
||||
offsets = [0]
|
||||
offsets.extend(lengths)
|
||||
|
|
@ -4561,7 +4561,7 @@ class Params:
|
|||
return f"batch:{self.batch_size}_head:{self.num_heads}_seq_len:{self.seq_length}_headdim:{self.head_dim}_dtype:{str(self.dtype).split('.')[-1]}"
|
||||
|
||||
|
||||
def get_params(dtypes: List[torch.dtype]) -> List[Params]:
|
||||
def get_params(dtypes: list[torch.dtype]) -> list[Params]:
|
||||
params = []
|
||||
seq_lengths = [37, 256, 277]
|
||||
for seq_len, dtype in product(seq_lengths, dtypes):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import functools
|
||||
from collections import namedtuple
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Union
|
||||
from unittest import expectedFailure, skipUnless
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -645,7 +645,7 @@ class TestFlexDecoding(InductorTestCase):
|
|||
self,
|
||||
dtype: torch.dtype,
|
||||
score_mod: Callable,
|
||||
head_dims: Tuple[int, int],
|
||||
head_dims: tuple[int, int],
|
||||
page_size: int,
|
||||
):
|
||||
Hq, Hkv = head_dims
|
||||
|
|
@ -681,7 +681,7 @@ class TestFlexDecoding(InductorTestCase):
|
|||
self,
|
||||
dtype: torch.dtype,
|
||||
score_mod: Callable,
|
||||
BLOCK_SIZE: Union[int, Tuple[int, int]],
|
||||
BLOCK_SIZE: Union[int, tuple[int, int]],
|
||||
):
|
||||
block_mask = create_block_mask(noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE)
|
||||
self.run_test(score_mod, dtype, block_mask=block_mask)
|
||||
|
|
@ -763,8 +763,8 @@ class TestFlexDecoding(InductorTestCase):
|
|||
def test_kv_batch_broadcast(
|
||||
self,
|
||||
dtype: torch.dtype,
|
||||
head_dims: Tuple[int, int],
|
||||
batch_dims: Tuple[int, int],
|
||||
head_dims: tuple[int, int],
|
||||
batch_dims: tuple[int, int],
|
||||
score_mod: Callable,
|
||||
):
|
||||
Hq, Hkv = head_dims
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import functools
|
||||
import unittest
|
||||
from typing import List, Tuple, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
@ -89,8 +89,8 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
|
|||
|
||||
|
||||
def _fix_fp8_dtype_for_rocm(
|
||||
dtype: Union[torch.dtype, List[torch.dtype], Tuple[torch.dtype]], device
|
||||
) -> Union[torch.dtype, List[torch.dtype], Tuple[torch.dtype]]:
|
||||
dtype: Union[torch.dtype, list[torch.dtype], tuple[torch.dtype]], device
|
||||
) -> Union[torch.dtype, list[torch.dtype], tuple[torch.dtype]]:
|
||||
# This function is used to change FP8 data types
|
||||
# with MI300 supported FP8 types if device is GPU:
|
||||
# e4m3fn -> e4m3fnuz
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import sys
|
||||
import unittest
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
|
@ -50,8 +50,8 @@ class TestConfigFuzzer(TestCase):
|
|||
self.assertEqual(toggle("", bool, True), False)
|
||||
self.assertEqual(toggle("", Literal["foo", "bar"], "foo"), "bar")
|
||||
self.assertEqual(toggle("", Literal["foo", "bar"], "bar"), "foo")
|
||||
self.assertTrue("bar" in toggle("", List[Literal["foo", "bar"]], ["foo"]))
|
||||
self.assertTrue("foo" in toggle("", List[Literal["foo", "bar"]], ["bar"]))
|
||||
self.assertTrue("bar" in toggle("", list[Literal["foo", "bar"]], ["foo"]))
|
||||
self.assertTrue("foo" in toggle("", list[Literal["foo", "bar"]], ["bar"]))
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 10), "python < 3.10 not supported")
|
||||
def test_sampling_method_random(self):
|
||||
|
|
|
|||
|
|

@@ -2,7 +2,6 @@

import collections
import unittest
from typing import List

import torch
import torch._inductor

@@ -39,7 +38,7 @@ class TestHighwaySelfGating(torch.nn.Module):

    def forward(
        self,
        inputs: List[torch.Tensor],
        inputs: list[torch.Tensor],
    ) -> torch.Tensor:
        results = []
        for i in range(self.size):
Some files were not shown because too many files have changed in this diff.
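
Every hunk shown above applies the same PEP 585 rewrite: container annotations imported from typing (List, Dict, Tuple, Set, Type) become the subscriptable built-ins (list, dict, tuple, set, type), and typing imports that are no longer needed for those annotations are dropped. A minimal illustrative sketch of the before/after pattern on Python 3.9+ (the function below is hypothetical and not part of this PR):

# Hypothetical example of the PEP 585 migration applied throughout this diff.
# Before: container generics had to be imported from the typing module.
from typing import Dict, List, Tuple


def summarize_legacy(xs: List[int]) -> Tuple[int, Dict[str, int]]:
    return sum(xs), {"count": len(xs)}


# After: on Python 3.9+ the built-in containers are subscriptable directly,
# so the typing imports above are no longer needed for these annotations.
def summarize(xs: list[int]) -> tuple[int, dict[str, int]]:
    return sum(xs), {"count": len(xs)}


assert summarize([1, 2, 3]) == summarize_legacy([1, 2, 3]) == (6, {"count": 3})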