[dynamo] Support passing arguments to DeviceMesh.get_group (#147741)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147741
Approved by: https://github.com/StrongerXi
This commit is contained in:
dan_the_3rd 2025-03-04 21:19:43 +00:00 committed by PyTorch MergeBot
parent f30776c37a
commit d1abde11ec
2 changed files with 10 additions and 5 deletions

View File

@ -154,7 +154,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
self.assertEqual(opt_fn, compiled_out)
def test_device_mesh_compile(self):
def fn(x):
def fn(x: DeviceMesh):
# test size()
a = x.size()
b = x.size(0)
@ -163,13 +163,14 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
# test get_coordinate()
coord = x.get_coordinate()
# test get_group()
group = x.get_group()
return size, coord, group
group0 = x.get_group(0)
group1 = x.get_group(mesh_dim=1)
return size, coord, group0, group1
# Cant be fullgraph=True because ProcessGroup is not reconstructible in dynamo
compiled_fn = torch.compile(backend="aot_eager")(fn)
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).unsqueeze(1))
opt_fn = fn(mesh)
compiled_out = compiled_fn(mesh)
self.assertEqual(opt_fn, compiled_out)

View File

@ -260,7 +260,11 @@ class DeviceMeshVariable(DistributedVariable):
if name == "get_coordinate":
return ConstantVariable.create(self.value.get_coordinate())
if name == "get_group":
return ProcessGroupVariable(self.value.get_group())
const_args = [x.as_python_constant() for x in args]
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
return ProcessGroupVariable(
self.value.get_group(*const_args, **const_kwargs)
)
if name == "_get_or_create_default_group":
return ProcessGroupVariable(self.value._get_or_create_default_group())
return super().call_method(tx, name, args, kwargs)