mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[CI][BE] Factor out repeated test code (#166481)
Into `_run_single_arg_fwd` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166481 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
98d640bb11
commit
47f0024310
|
|
@ -204,32 +204,28 @@ class DistConvolutionOpsTest(DTensorTestBase):
|
|||
self.assertTrue(b_dt.grad is not None)
|
||||
self.assertTrue(x_dt.grad is None)
|
||||
|
||||
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
|
||||
device_mesh = self.build_device_mesh()
|
||||
model_copy = copy.deepcopy(model).to(device=self.device_type)
|
||||
dist_model = distribute_module(model, device_mesh, _conv_fn)
|
||||
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
|
||||
out_dt = dist_model(arg_dt.to(device=self.device_type))
|
||||
out = model_copy(arg)
|
||||
return (out_dt.full_tensor(), out)
|
||||
|
||||
@with_comms
|
||||
def test_conv1d(self):
|
||||
device_mesh = self.build_device_mesh()
|
||||
model = nn.Conv1d(64, 64, 3, padding=1)
|
||||
model_gt = copy.deepcopy(model)
|
||||
x = torch.randn(1, 64, 8)
|
||||
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
|
||||
model_dt = distribute_module(
|
||||
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
|
||||
)
|
||||
out_dt = model_dt(x_dt)
|
||||
out = model_gt(x)
|
||||
x = torch.randn(1, 64, 8, device=self.device_type)
|
||||
out_dt, out = self._run_single_arg_fwd(model, x)
|
||||
self.assertEqual(out_dt.shape, out.shape)
|
||||
|
||||
@with_comms
|
||||
def test_conv3d(self):
|
||||
device_mesh = self.build_device_mesh()
|
||||
model = nn.Conv3d(64, 64, 3, padding=1)
|
||||
model_gt = copy.deepcopy(model).to(device=self.device_type)
|
||||
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
|
||||
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
|
||||
model_dt = distribute_module(
|
||||
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
|
||||
)
|
||||
out_dt = model_dt(x_dt)
|
||||
out = model_gt(x)
|
||||
out_dt, out = self._run_single_arg_fwd(model, x)
|
||||
self.assertEqual(out_dt.shape, out.shape)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user