mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[CI][BE] Factor out repeated test code (#166481)
Into `_run_single_arg_fwd` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166481 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
98d640bb11
commit
47f0024310
|
|
@ -204,32 +204,28 @@ class DistConvolutionOpsTest(DTensorTestBase):
|
||||||
self.assertTrue(b_dt.grad is not None)
|
self.assertTrue(b_dt.grad is not None)
|
||||||
self.assertTrue(x_dt.grad is None)
|
self.assertTrue(x_dt.grad is None)
|
||||||
|
|
||||||
|
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
|
||||||
|
device_mesh = self.build_device_mesh()
|
||||||
|
model_copy = copy.deepcopy(model).to(device=self.device_type)
|
||||||
|
dist_model = distribute_module(model, device_mesh, _conv_fn)
|
||||||
|
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
|
||||||
|
out_dt = dist_model(arg_dt.to(device=self.device_type))
|
||||||
|
out = model_copy(arg)
|
||||||
|
return (out_dt.full_tensor(), out)
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
def test_conv1d(self):
|
def test_conv1d(self):
|
||||||
device_mesh = self.build_device_mesh()
|
|
||||||
model = nn.Conv1d(64, 64, 3, padding=1)
|
model = nn.Conv1d(64, 64, 3, padding=1)
|
||||||
model_gt = copy.deepcopy(model)
|
x = torch.randn(1, 64, 8, device=self.device_type)
|
||||||
x = torch.randn(1, 64, 8)
|
out_dt, out = self._run_single_arg_fwd(model, x)
|
||||||
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
|
|
||||||
model_dt = distribute_module(
|
|
||||||
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
|
|
||||||
)
|
|
||||||
out_dt = model_dt(x_dt)
|
|
||||||
out = model_gt(x)
|
|
||||||
self.assertEqual(out_dt.shape, out.shape)
|
self.assertEqual(out_dt.shape, out.shape)
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
def test_conv3d(self):
|
def test_conv3d(self):
|
||||||
device_mesh = self.build_device_mesh()
|
|
||||||
model = nn.Conv3d(64, 64, 3, padding=1)
|
model = nn.Conv3d(64, 64, 3, padding=1)
|
||||||
model_gt = copy.deepcopy(model).to(device=self.device_type)
|
|
||||||
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
|
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
|
||||||
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
|
out_dt, out = self._run_single_arg_fwd(model, x)
|
||||||
model_dt = distribute_module(
|
|
||||||
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
|
|
||||||
)
|
|
||||||
out_dt = model_dt(x_dt)
|
|
||||||
out = model_gt(x)
|
|
||||||
self.assertEqual(out_dt.shape, out.shape)
|
self.assertEqual(out_dt.shape, out.shape)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user