mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix mesh.get_local_rank when it is > 1d (#164473)
Previously, we would not take the arguments passed by get_local_rank into account. This means that we wouldn't be able to trace this call if we had a device_mesh > 1d Pull Request resolved: https://github.com/pytorch/pytorch/pull/164473 Approved by: https://github.com/xmfan, https://github.com/Skylion007
This commit is contained in:
parent
5103ecc5d8
commit
83d71dfb2f
|
|
@ -378,7 +378,7 @@ vgg16,pass,0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vision_maskrcnn,pass,20
|
vision_maskrcnn,pass,18
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
|
@ -286,7 +286,7 @@ vgg16,pass,6
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
vision_maskrcnn,pass,39
|
vision_maskrcnn,pass,37
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
|
@ -274,7 +274,11 @@ class DeviceMeshVariable(DistributedVariable):
|
||||||
if name == "get_rank":
|
if name == "get_rank":
|
||||||
return ConstantVariable.create(self.value.get_rank())
|
return ConstantVariable.create(self.value.get_rank())
|
||||||
if name == "get_local_rank":
|
if name == "get_local_rank":
|
||||||
return ConstantVariable.create(self.value.get_local_rank())
|
const_args = [x.as_python_constant() for x in args]
|
||||||
|
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||||
|
return ConstantVariable.create(
|
||||||
|
self.value.get_local_rank(*const_args, **const_kwargs)
|
||||||
|
)
|
||||||
if name == "get_group":
|
if name == "get_group":
|
||||||
const_args = [x.as_python_constant() for x in args]
|
const_args = [x.as_python_constant() for x in args]
|
||||||
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user