[ca][dtensor] run real PG dtensor tests under CA (#152689)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152689
Approved by: https://github.com/bdhirsh
ghstack dependencies: #153300
Authored by Simon Fan on 2025-05-15 11:20:33 -07:00; committed by PyTorch MergeBot
parent 5aea57d653
commit 1b4749f748

@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
+import contextlib
 import copy
 import functools
 import unittest
@@ -879,9 +880,17 @@ class TestDTensorCompileE2E(DTensorTestBase):
     def world_size(self):
         return 4
 
+    # multiprocess relies on pickling the source code
+    # so compiled autograd tests can't dynamically wrap this class
+    def _bwd_ctx(self, use_ca):
+        if not use_ca:
+            return contextlib.nullcontext()
+        return torch._dynamo.compiled_autograd._enable(torch.compile)
+
     @with_comms
     @parametrize("is_seq_parallel", [True, False])
-    def test_tp_compile_fullgraph(self, is_seq_parallel):
+    @parametrize("use_ca", [True, False])
+    def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
         model = SimpleModel(self.device_type)
@@ -935,13 +944,15 @@ class TestDTensorCompileE2E(DTensorTestBase):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
         compiled_out = compiled_mod(inp)
-        compiled_out.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_out.sum().backward()
         self.assertEqual(compiled_out, out)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_tp_compile(self):
+    @parametrize("use_ca", [True, False])
+    def test_2d_fsdp_tp_compile(self, use_ca):
         data_parallel_size = 2
         model = SimpleModel(self.device_type)
         model_copy = copy.deepcopy(model)
@@ -984,13 +995,16 @@ class TestDTensorCompileE2E(DTensorTestBase):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_2d = torch.compile(fsdp_2d, backend=cnt)
         compiled_output = compiled_2d(inp)
+        with self._bwd_ctx(use_ca):
+            compiled_output.sum().backward()
 
         self.assertEqual(out, compiled_output)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_tp_ac_compile(self):
+    @parametrize("use_ca", [True, False])
+    def test_2d_fsdp_tp_ac_compile(self, use_ca):
         dp_degree = 2
         tp_degree = self.world_size // dp_degree
         model = SimpleModel(self.device_type)
@@ -1033,7 +1047,8 @@ class TestDTensorCompileE2E(DTensorTestBase):
         # backward pass
         out.sum().backward()
-        compiled_output.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_output.sum().backward()
 
         # compare the gradients:
         for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
@@ -1041,7 +1056,8 @@ class TestDTensorCompileE2E(DTensorTestBase):
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_compile_dtensor_redistribute_backward(self):
+    @parametrize("use_ca", [True, False])
+    def test_compile_dtensor_redistribute_backward(self, use_ca):
         mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))
 
         def fn(x, y):
@@ -1065,7 +1081,8 @@ class TestDTensorCompileE2E(DTensorTestBase):
         # Now run and assert the backward + gradients
         ref.sum().backward()
-        res.sum().backward()
+        with self._bwd_ctx(use_ca):
+            res.sum().backward()
 
         self.assertEqual(x_ref.grad, x.grad)
         self.assertEqual(y_ref.grad, y.grad)
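
For context, the sketch below pulls the backward-wrapping pattern introduced by this commit out of the test class so it can run standalone. It is a minimal, hypothetical example and not part of the commit: the free-standing bwd_ctx helper, the Linear model, and the random input are stand-ins, and torch._dynamo.compiled_autograd._enable is a private API that may change between releases.

import contextlib

import torch


def bwd_ctx(use_ca: bool):
    # Mirror of the _bwd_ctx helper added above: no-op context manager when
    # compiled autograd is disabled, otherwise trace the backward graph and
    # compile it with torch.compile.
    if not use_ca:
        return contextlib.nullcontext()
    return torch._dynamo.compiled_autograd._enable(torch.compile)


# Hypothetical standalone usage: any module and input would do.
model = torch.nn.Linear(8, 8)
inp = torch.randn(4, 8)
out = torch.compile(model, backend="aot_eager")(inp)
with bwd_ctx(use_ca=True):
    out.sum().backward()

With use_ca=False the same code runs under eager autograd, which is how the parametrized tests compare the two paths against identical expected outputs and gradients.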