[ROCm] Enable several fsdp related UTs (#149369)
Enabling 26 UTs for ROCm in the following files:

- distributed._shard.sharded_optim.test_sharded_optim - 2 UTs
- distributed._shard.sharded_tensor.ops.test_binary_cmp - 4 UTs
- distributed._shard.sharded_tensor.ops.test_init - 3 UTs
- distributed._shard.sharded_tensor.ops.test_embedding - 2 UTs
- distributed._shard.sharded_tensor.ops.test_embedding_bag - 2 UTs
- distributed._composable.test_replicate_with_compiler - 4 UTs
- distributed._composable.fsdp.test_fully_shard_grad_scaler - 1 UT
- distributed.tensor.test_attention - 4 UTs
- distributed.tensor.test_matrix_ops - 1 UT
- distributed.tensor.test_tensor_ops - 1 UT
- distributed.fsdp.test_fsdp_grad_acc - 2 UTs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149369
Approved by: https://github.com/jeffdaily
commit 47cdad2995
parent 7c858066ae
@@ -131,7 +131,6 @@ class TestFullyShardCompile(FSDPTest):
         if not sm_is_or_higher_than(device, 8, 0):
             self.skipTest("bf16 requires sm >= 8.0")
 
-    @skipIfRocm
     def test_dynamo_trace_use_training_state(self):
         torch._dynamo.reset()
         # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
@@ -169,7 +168,6 @@ class TestFullyShardCompile(FSDPTest):
         self.assertEqual(cnt.op_count, 1)
         self.assertEqual(len(cnt.graphs), 1)
 
-    @skipIfRocm
     def test_trace_fsdp_copy_(self):
         @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
         def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
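For context, `torch.library.custom_op` (used by the test above, available in PyTorch 2.4+) registers a Python function as a dispatchable operator; `mutates_args` declares which arguments are written in place. A minimal self-contained sketch of the pattern — the op body here is illustrative, since the diff truncates before it:

```python
import torch

@torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
    # Out-variant op: writes x + 1 into the preallocated `out` buffer.
    out.copy_(x + 1)

x = torch.zeros(3)
out = torch.empty(3)
add_one_out(x, out)  # out is now tensor([1., 1., 1.])
```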
@@ -13,12 +13,11 @@ from torch.distributed.tensor.parallel import (
 )
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest, MLP
-from torch.testing._internal.common_utils import run_tests, skipIfRocm
+from torch.testing._internal.common_utils import run_tests
 
 
 class TestFullyShardGradientScaler(FSDPTest):
     @skip_if_lt_x_gpu(4)
-    @skipIfRocm
     def test_gradient_scaler(self):
         self.run_subtests(
             {"has_inf": [True, False], "test_2d": [True, False]},
@@ -28,11 +28,10 @@ from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.testing._internal.common_distributed import (
     DistributedTestBase,
     skip_if_lt_x_gpu,
-    skip_if_rocm_multiprocess,
     sm_is_or_higher_than,
 )
 from torch.testing._internal.common_fsdp import get_devtype
-from torch.testing._internal.common_utils import run_tests, skipIfRocm
+from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed.fake_pg import FakeStore
 from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.utils.checkpoint import checkpoint
@@ -194,7 +193,6 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self._test_compile(no_sync=True, device="cpu")
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     @torch._inductor.config.patch(
         reorder_for_locality=False, reorder_for_peak_memory=False
@@ -203,7 +201,6 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self._test_compile(no_sync=False, checkpoint=False, device=device_type)
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     @torch._inductor.config.patch(
         reorder_for_locality=False, reorder_for_peak_memory=False
@@ -212,11 +209,13 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self._test_compile(no_sync=False, checkpoint=True, device=device_type)
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_bf16(self):
         # Check device capability wrt bf16
-        if not sm_is_or_higher_than(torch.device(device_type), 8, 0):
+        if (
+            not sm_is_or_higher_than(torch.device(device_type), 8, 0)
+            and torch.version.hip is None
+        ):
             self.skipTest("bf16 requires sm >= 8.0")
 
         def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
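The reworked capability check above only applies the SM 8.0 gate on CUDA builds: `torch.version.hip` is non-None on ROCm, where the tested parts support bf16 without an SM-style version check. A standalone sketch of the same gate (the helper name is hypothetical, and it assumes a visible GPU):

```python
import torch

def bf16_supported(device: torch.device) -> bool:
    # On ROCm builds torch.version.hip is a version string, so the
    # CUDA compute-capability gate below is intentionally bypassed.
    if torch.version.hip is not None:
        return True
    # bf16 kernels generally require SM 8.0 (Ampere) or newer on CUDA.
    return torch.cuda.get_device_capability(device) >= (8, 0)
```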
@@ -230,7 +229,6 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self._test_compile(no_sync=False, setup_func=setup, device=device_type)
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_fp16(self):
         def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -247,7 +245,6 @@ class ReplicateTest(MultiProcessInductorTestCase):
         )
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_backward_only(self):
         self._test_compile(no_sync=False, no_compile_forward=True, device=device_type)
@@ -387,7 +384,6 @@ class DDP_TP_Test(InductorTestCase):
         "Temporarily disabled due to SymInt error: `unhashable type: non-nested SymInt`"
     )
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @skipIfRocm
     def test_ddp_tp(self):
         ref_model = Net()
         compiled_replicate_model = deepcopy(ref_model)
@@ -24,7 +24,6 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
     run_tests,
-    skipIfRocm,
     TEST_WITH_DEV_DBG_ASAN,
 )
 
@@ -275,7 +274,6 @@ class TestGradAcc(FSDPTest):
         )
 
     @skip_if_lt_x_gpu(2)
-    @skipIfRocm
     @parametrize("use_orig_params", [False, True])
     def test_grad_acc_cpu_offload(
         self,
@@ -18,7 +18,8 @@ from torch.distributed.tensor import (
 )
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
-from torch.testing._internal.common_utils import run_tests, skipIfRocm
+from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type
+from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_unless_torch_gpu,
@@ -33,8 +34,10 @@ def scale_for_fp8(
         t = t.unsqueeze(0).unsqueeze(-2)
     else:
         t = t.unflatten(0, (scale_shape[0], -1)).unflatten(-1, (scale_shape[1], -1))
-    scale = t.abs().amax(dim=[1, -1]).float() / torch.finfo(torch.float8_e4m3fn).max
-    t_fp8 = (t / scale[:, None, :, None]).to(torch.float8_e4m3fn)
+
+    scale = t.abs().amax(dim=[1, -1]).float() / E4M3_MAX_POS
+    t_fp8 = (t / scale[:, None, :, None]).to(e4m3_type)
+
     return t_fp8.flatten(end_dim=1).flatten(start_dim=-2), scale.view(scale_shape)
 
 
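The rewrite above swaps hard-coded `torch.float8_e4m3fn` constants for the platform-resolved `e4m3_type`/`E4M3_MAX_POS` aliases, so the same amax-based scaling works on gfx94x ROCm parts (which use the `fnuz` fp8 variants). The idea in isolation — a hypothetical row-wise helper, not the test's own code:

```python
import torch

def quantize_fp8_rowwise(t: torch.Tensor, fp8_dtype: torch.dtype = torch.float8_e4m3fn):
    # Scale each row so its absolute max lands on the format's largest
    # representable value, then cast; keep the scale for dequantization.
    max_pos = torch.finfo(fp8_dtype).max
    scale = t.abs().amax(dim=-1, keepdim=True).float() / max_pos
    scale = scale.clamp(min=1e-12)  # avoid division by zero
    return (t / scale).to(fp8_dtype), scale
```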
@@ -205,7 +208,7 @@ class DistMatrixOpsTest(DTensorTestBase):
 
         full_dist_res = dist_res.full_tensor()
         # Fp8 matmuls are quite inaccurate, we need high tolerances
-        self.assertEqual(full_dist_res, full_ref_res, atol=1, rtol=7e-2)
+        self.assertEqual(full_dist_res, full_ref_res, atol=1.5, rtol=7e-2)
 
         self.assertEqual(comm_mode.get_total_counts(), 0)
 
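The loosened `atol` follows the usual mixed-tolerance acceptance rule: a comparison passes when `|actual - expected| <= atol + rtol * |expected|` elementwise. `torch.testing.assert_close` exposes the same semantics outside a TestCase; a worked illustration with the new tolerances:

```python
import torch

expected = torch.tensor([10.0, 200.0])
actual = torch.tensor([11.4, 213.0])
# Passes iff |actual - expected| <= atol + rtol * |expected| elementwise:
# elem 1: 1.4 <= 1.5 + 0.07 * 10  = 2.2
# elem 2: 13.0 <= 1.5 + 0.07 * 200 = 15.5
torch.testing.assert_close(actual, expected, atol=1.5, rtol=7e-2)
```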
@@ -448,7 +451,6 @@ class DistMatrixOpsTest(DTensorTestBase):
         self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1))
         self.assertEqual(dist_value.grad.full_tensor(), value.grad)
 
-    @skipIfRocm
     @skip_unless_torch_gpu
     @with_comms()
     def test_dtensor_mm(self):
@@ -472,7 +474,9 @@ class DistMatrixOpsTest(DTensorTestBase):
         lhs_dtensor = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()])
         rhs_dtensor = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)])
         dtensor_result = lhs_dtensor @ rhs_dtensor
-        self.assertEqual(dtensor_result.full_tensor(), mm_result)
+        self.assertEqual(
+            dtensor_result.full_tensor(), mm_result, atol=1.5e-5, rtol=1e-6
+        )
 
     @with_comms
     @skip_unless_torch_gpu
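For context, the test shards a matmul across a 2-D device mesh and gathers the result. Roughly, the pattern is as follows — a sketch assuming four ranks launched via torchrun, one GPU per rank, and a recent PyTorch where DTensor lives under `torch.distributed.tensor`:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cuda", (2, 2))  # 2x2 mesh over 4 ranks
lhs, rhs = torch.randn(8, 4), torch.randn(4, 8)
lhs_d = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()])  # row-sharded
rhs_d = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)])  # col-sharded
result = lhs_d @ rhs_d        # DTensor matmul
full = result.full_tensor()   # gather back to a regular tensor
```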
@@ -171,19 +171,10 @@ ROCM_BLOCKLIST = [
     "distributed/rpc/test_tensorpipe_agent",
     "distributed/rpc/test_share_memory",
     "distributed/rpc/cuda/test_tensorpipe_agent",
-    "distributed/_shard/checkpoint/test_checkpoint"
-    "distributed/_shard/checkpoint/test_file_system_checkpoint"
-    "distributed/_shard/sharding_spec/test_sharding_spec",
-    "distributed/_shard/sharded_tensor/ops/test_embedding",
-    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
-    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
-    "distributed/_shard/sharded_tensor/ops/test_init",
-    "distributed/_shard/sharded_optim/test_sharded_optim",
     "test_determination",
     "test_jit_legacy",
     "test_cuda_nvml_based_avail",
     "test_jit_cuda_fuser",
-    "distributed/tensor/test_attention",
 ]
 
 S390X_BLOCKLIST = [
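Worth noting about the removed block: the first two entries had no trailing comma, so Python's implicit string-literal concatenation had silently fused them with the entry that followed. A minimal illustration of the pitfall:

```python
# Adjacent string literals concatenate, so a missing comma merges what
# were meant to be separate list entries into one bogus path:
blocklist = [
    "distributed/_shard/checkpoint/test_checkpoint"      # <- no comma
    "distributed/_shard/checkpoint/test_file_system_checkpoint"
    "distributed/_shard/sharding_spec/test_sharding_spec",
]
assert len(blocklist) == 1  # three literals, one (useless) entry
```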
@@ -24,7 +24,7 @@ from torch.testing._internal.common_cuda import (
     SM90OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8,
-    PLATFORM_SUPPORTS_MX_GEMM
+    PLATFORM_SUPPORTS_MX_GEMM,
 )
 from torch.testing._internal.common_device_type import (
     dtypes,
@@ -32,6 +32,10 @@ from torch.testing._internal.common_device_type import (
     onlyCUDA,
     tol as xtol,
     toleranceOverride,
+    e4m3_type,
+    e5m2_type,
+    E4M3_MAX_POS,
+    E5M2_MAX_POS,
 )
 
 from torch.testing._internal.common_utils import (
@@ -258,17 +262,6 @@ class TestMatmulCuda(TestCase):
 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
 
-if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
-    e4m3_type = torch.float8_e4m3fnuz
-    e5m2_type = torch.float8_e5m2fnuz
-    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
-    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
-else:
-    e4m3_type = torch.float8_e4m3fn
-    e5m2_type = torch.float8_e5m2
-    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
-    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
-
 # avoid division by zero when calculating scale
 EPS = 1e-12
 
@@ -1982,3 +1982,13 @@ flex_attention_supported_platform = unittest.skipUnless(
     and torch.cuda.get_device_capability() >= (8, 0),
     "Requires CUDA and Triton",
 )
+if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName:
+    e4m3_type = torch.float8_e4m3fnuz
+    e5m2_type = torch.float8_e5m2fnuz
+    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
+    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
+else:
+    e4m3_type = torch.float8_e4m3fn
+    e5m2_type = torch.float8_e5m2
+    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
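This hunk centralizes the gfx94x-aware fp8 dtype selection in `torch.testing._internal.common_device_type`, replacing the per-file copy removed from the matmul tests above, so tests import one set of aliases instead of re-deriving them. A sketch of how a consumer would use the exported names (the surrounding code is illustrative, not from the diff):

```python
import torch
from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type

# Callers stay agnostic to whether the platform uses the OCP fp8
# formats or the ROCm "fnuz" variants selected above.
x = torch.randn(4, 4)
scale = x.abs().max().float() / E4M3_MAX_POS
x_fp8 = (x / scale).to(e4m3_type)
```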