[ca][dtensor] run real PG dtensor tests under CA (#152689)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152689
Approved by: https://github.com/bdhirsh
ghstack dependencies: #153300
This commit is contained in:
parent 5aea57d653
commit 1b4749f748
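The diff below threads a use_ca parameter through the E2E DTensor tests and wraps each backward call in a context manager that toggles compiled autograd. As a standalone illustration of that toggle (a minimal sketch, not part of the PR; the toy function and tensor names are invented):

import contextlib

import torch

def bwd_ctx(use_ca: bool):
    # Same pattern as the _bwd_ctx helper added below: a no-op context
    # unless compiled autograd is requested.
    if not use_ca:
        return contextlib.nullcontext()
    return torch._dynamo.compiled_autograd._enable(torch.compile)

x = torch.randn(4, requires_grad=True)
out = torch.compile(lambda t: (t * 2).sum())(x)
with bwd_ctx(use_ca=True):
    out.backward()  # the backward graph is captured and compiled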
@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
 
+import contextlib
 import copy
 import functools
 import unittest
@@ -879,9 +880,17 @@ class TestDTensorCompileE2E(DTensorTestBase):
     def world_size(self):
         return 4
 
+    # multiprocess relies on pickling the source code
+    # so compiled autograd tests can't dynamically wrap this class
+    def _bwd_ctx(self, use_ca):
+        if not use_ca:
+            return contextlib.nullcontext()
+        return torch._dynamo.compiled_autograd._enable(torch.compile)
+
     @with_comms
     @parametrize("is_seq_parallel", [True, False])
-    def test_tp_compile_fullgraph(self, is_seq_parallel):
+    @parametrize("use_ca", [True, False])
+    def test_tp_compile_fullgraph(self, is_seq_parallel, use_ca):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
         model = SimpleModel(self.device_type)
@@ -935,13 +944,15 @@ class TestDTensorCompileE2E(DTensorTestBase):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_mod = torch.compile(model, backend=cnt, fullgraph=True)
         compiled_out = compiled_mod(inp)
-        compiled_out.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_out.sum().backward()
         self.assertEqual(compiled_out, out)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_tp_compile(self):
+    @parametrize("use_ca", [True, False])
+    def test_2d_fsdp_tp_compile(self, use_ca):
         data_parallel_size = 2
         model = SimpleModel(self.device_type)
         model_copy = copy.deepcopy(model)
@@ -984,13 +995,16 @@ class TestDTensorCompileE2E(DTensorTestBase):
         cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
         compiled_2d = torch.compile(fsdp_2d, backend=cnt)
         compiled_output = compiled_2d(inp)
-        compiled_output.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_output.sum().backward()
 
         self.assertEqual(out, compiled_output)
         self.assertEqual(cnt.frame_count, 1)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_2d_fsdp_tp_ac_compile(self):
+    @parametrize("use_ca", [True, False])
+    def test_2d_fsdp_tp_ac_compile(self, use_ca):
         dp_degree = 2
         tp_degree = self.world_size // dp_degree
         model = SimpleModel(self.device_type)
@@ -1033,7 +1047,8 @@ class TestDTensorCompileE2E(DTensorTestBase):
 
         # backward pass
         out.sum().backward()
-        compiled_output.sum().backward()
+        with self._bwd_ctx(use_ca):
+            compiled_output.sum().backward()
 
         # compare the gradients:
         for n, p in zip(fsdp_2d.parameters(), compiled_2d.parameters()):
@@ -1041,7 +1056,8 @@ class TestDTensorCompileE2E(DTensorTestBase):
 
     @with_comms
     @skip_if_lt_x_gpu(4)
-    def test_compile_dtensor_redistribute_backward(self):
+    @parametrize("use_ca", [True, False])
+    def test_compile_dtensor_redistribute_backward(self, use_ca):
         mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))
 
         def fn(x, y):
@@ -1065,7 +1081,8 @@ class TestDTensorCompileE2E(DTensorTestBase):
 
         # Now run and assert the backward + gradients
         ref.sum().backward()
-        res.sum().backward()
+        with self._bwd_ctx(use_ca):
+            res.sum().backward()
 
         self.assertEqual(x_ref.grad, x.grad)
         self.assertEqual(y_ref.grad, y.grad)
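For reference, each stacked @parametrize decorator above multiplies a test into one instance per value once the class passes through instantiate_parametrized_tests. A minimal sketch of that mechanism, assuming PyTorch's internal test utilities keep their current API (the ToyTest class is invented for illustration):

import unittest

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class ToyTest(unittest.TestCase):
    @parametrize("is_seq_parallel", [True, False])
    @parametrize("use_ca", [True, False])
    def test_matrix(self, is_seq_parallel, use_ca):
        # expands to four tests, one per (is_seq_parallel, use_ca) pair
        self.assertIsInstance(use_ca, bool)

instantiate_parametrized_tests(ToyTest)

if __name__ == "__main__":
    run_tests()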