mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo] Support passing arguments to DeviceMesh.get_group (#147741)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147741 Approved by: https://github.com/StrongerXi
This commit is contained in:
parent
f30776c37a
commit
d1abde11ec
|
|
@ -154,7 +154,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(opt_fn, compiled_out)
|
||||
|
||||
def test_device_mesh_compile(self):
|
||||
def fn(x):
|
||||
def fn(x: DeviceMesh):
|
||||
# test size()
|
||||
a = x.size()
|
||||
b = x.size(0)
|
||||
|
|
@ -163,13 +163,14 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
|
|||
# test get_coordinate()
|
||||
coord = x.get_coordinate()
|
||||
# test get_group()
|
||||
group = x.get_group()
|
||||
return size, coord, group
|
||||
group0 = x.get_group(0)
|
||||
group1 = x.get_group(mesh_dim=1)
|
||||
return size, coord, group0, group1
|
||||
|
||||
# Cant be fullgraph=True because ProcessGroup is not reconstructible in dynamo
|
||||
compiled_fn = torch.compile(backend="aot_eager")(fn)
|
||||
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).unsqueeze(1))
|
||||
opt_fn = fn(mesh)
|
||||
compiled_out = compiled_fn(mesh)
|
||||
self.assertEqual(opt_fn, compiled_out)
|
||||
|
|
|
|||
|
|
@ -260,7 +260,11 @@ class DeviceMeshVariable(DistributedVariable):
|
|||
if name == "get_coordinate":
|
||||
return ConstantVariable.create(self.value.get_coordinate())
|
||||
if name == "get_group":
|
||||
return ProcessGroupVariable(self.value.get_group())
|
||||
const_args = [x.as_python_constant() for x in args]
|
||||
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||
return ProcessGroupVariable(
|
||||
self.value.get_group(*const_args, **const_kwargs)
|
||||
)
|
||||
if name == "_get_or_create_default_group":
|
||||
return ProcessGroupVariable(self.value._get_or_create_default_group())
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user