[dynamo] Add more dynamo call_methods and getattr support or Placement (#117733)

Summary:
Explained by title.
This fix is part of: https://github.com/pytorch/pytorch/issues/117670

Test Plan:
Unit tetst and CI
- Unit test: `buck2 test mode/dev-nosan //caffe2/test/distributed/_tensor:dtensor_compile -- test_placement_compile`

Differential Revision: D52863073

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117733
Approved by: https://github.com/yanboliang
This commit is contained in:
Yue Dong 2024-01-22 18:22:54 +00:00 committed by PyTorch MergeBot
parent f612e96180
commit 56ef5afdee
2 changed files with 46 additions and 4 deletions

View File

@ -19,6 +19,7 @@ from torch.distributed._tensor import (
Replicate,
Shard,
)
from torch.distributed._tensor.placement_types import _Partial
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
@ -98,6 +99,32 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
def world_size(self) -> int:
return 2
def test_placement_compile(self):
def fn(x):
a = 0
if x.is_replicate():
a += 1
if x.is_shard():
a += 2
if x.dim < 0:
raise RuntimeError("dim < 0")
if x.is_shard(0):
a += 2
if x.is_shard(dim=0):
a += 2
if x.is_shard(dim=None):
a += 2
if x.is_partial():
a += 3
return a
compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)
for x in [Shard(0), Replicate(), _Partial()]:
opt_fn = fn(x)
compiled_out = compiled_fn(x)
self.assertEqual(opt_fn, compiled_out)
def test_fakify_dtensor(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

View File

@ -103,6 +103,11 @@ class PlacementVariable(DistributedVariable):
def as_python_constant(self):
return self.value
def var_getattr(self, tx, name: str) -> VariableTracker:
if name == "dim":
return ConstantVariable.create(self.value.dim)
return super().var_getattr(tx, name)
def call_method(
self,
tx,
@ -112,10 +117,20 @@ class PlacementVariable(DistributedVariable):
) -> "VariableTracker":
from . import ConstantVariable
allowed_methods = ["__init__", "__setattr__"]
# placement types dynamo tracking allows only __init__
# and __setattr__ methods, the latter is for case like `Shard(dim)`
if name in allowed_methods:
# Placement types dynamo tracking only allows following methods
# and __setattr__ is for case like `Shard(dim)` and methods.
# Methods in the list must satisfy:
# 1. Input arguments are constants and do not need to be guarded on;
# 2. Output is constant with respect to their inputs
constant_fold_functions = [
"__init__",
"__setattr__",
"is_shard",
"is_partial",
"is_replicate",
]
if name in constant_fold_functions:
try:
value_type = type(self.value)
assert (