[CI][BE] Factor out repeated test code (#166481)

Into `_run_single_arg_fwd`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166481
Approved by: https://github.com/Skylion007
This commit is contained in:
Nikita Shulga 2025-10-30 14:28:04 -07:00 committed by PyTorch MergeBot
parent 98d640bb11
commit 47f0024310

View File

@ -204,32 +204,28 @@ class DistConvolutionOpsTest(DTensorTestBase):
self.assertTrue(b_dt.grad is not None) self.assertTrue(b_dt.grad is not None)
self.assertTrue(x_dt.grad is None) self.assertTrue(x_dt.grad is None)
def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]:
"""Given model and arg, runs fwd model local and distbuted given device_mesh"""
device_mesh = self.build_device_mesh()
model_copy = copy.deepcopy(model).to(device=self.device_type)
dist_model = distribute_module(model, device_mesh, _conv_fn)
arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()])
out_dt = dist_model(arg_dt.to(device=self.device_type))
out = model_copy(arg)
return (out_dt.full_tensor(), out)
@with_comms @with_comms
def test_conv1d(self): def test_conv1d(self):
device_mesh = self.build_device_mesh()
model = nn.Conv1d(64, 64, 3, padding=1) model = nn.Conv1d(64, 64, 3, padding=1)
model_gt = copy.deepcopy(model) x = torch.randn(1, 64, 8, device=self.device_type)
x = torch.randn(1, 64, 8) out_dt, out = self._run_single_arg_fwd(model, x)
x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
model_dt = distribute_module(
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
)
out_dt = model_dt(x_dt)
out = model_gt(x)
self.assertEqual(out_dt.shape, out.shape) self.assertEqual(out_dt.shape, out.shape)
@with_comms @with_comms
def test_conv3d(self): def test_conv3d(self):
device_mesh = self.build_device_mesh()
model = nn.Conv3d(64, 64, 3, padding=1) model = nn.Conv3d(64, 64, 3, padding=1)
model_gt = copy.deepcopy(model).to(device=self.device_type)
x = torch.randn(1, 64, 8, 8, 8, device=self.device_type) x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
x_dt = DTensor.from_local(x, device_mesh, [Replicate()]) out_dt, out = self._run_single_arg_fwd(model, x)
model_dt = distribute_module(
model, device_mesh, _conv_fn, input_fn=None, output_fn=None
)
out_dt = model_dt(x_dt)
out = model_gt(x)
self.assertEqual(out_dt.shape, out.shape) self.assertEqual(out_dt.shape, out.shape)