[dynamo] Add more dynamo call_method and getattr support for Placement (#117733)
Summary: Explained by the title. This fix is part of https://github.com/pytorch/pytorch/issues/117670
Test Plan: Unit test and CI. Unit test: `buck2 test mode/dev-nosan //caffe2/test/distributed/_tensor:dtensor_compile -- test_placement_compile`
Differential Revision: D52863073
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117733
Approved by: https://github.com/yanboliang
This commit is contained in:
parent
f612e96180
commit
56ef5afdee
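In practice, this lets torch.compile trace functions that branch on a DTensor Placement (is_shard / is_replicate / is_partial, or its dim attribute) without graph-breaking. A minimal sketch in the spirit of the new test below, assuming a PyTorch build with distributed support and this change applied (aot_eager is the backend the test itself uses):

    import torch
    from torch.distributed._tensor import Replicate, Shard

    def classify(placement):
        # These Placement queries are constant-folded by dynamo, so the
        # branches are resolved at trace time instead of graph-breaking.
        if placement.is_shard(dim=0):
            return 1
        if placement.is_replicate():
            return 2
        return 0

    compiled = torch.compile(classify, backend="aot_eager", fullgraph=True)
    assert compiled(Shard(0)) == 1
    assert compiled(Replicate()) == 2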
@@ -19,6 +19,7 @@ from torch.distributed._tensor import (
     Replicate,
     Shard,
 )
+from torch.distributed._tensor.placement_types import _Partial
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
     CheckpointImpl,
@@ -98,6 +99,32 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
     def world_size(self) -> int:
         return 2
 
+    def test_placement_compile(self):
+        def fn(x):
+            a = 0
+            if x.is_replicate():
+                a += 1
+            if x.is_shard():
+                a += 2
+            if x.dim < 0:
+                raise RuntimeError("dim < 0")
+            if x.is_shard(0):
+                a += 2
+            if x.is_shard(dim=0):
+                a += 2
+            if x.is_shard(dim=None):
+                a += 2
+            if x.is_partial():
+                a += 3
+            return a
+
+        compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)
+
+        for x in [Shard(0), Replicate(), _Partial()]:
+            opt_fn = fn(x)
+            compiled_out = compiled_fn(x)
+            self.assertEqual(opt_fn, compiled_out)
+
     def test_fakify_dtensor(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
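The Test Plan above gives the internal buck2 target; in an OSS checkout the same test can presumably be run with pytest, with the path inferred from the target name //caffe2/test/distributed/_tensor:dtensor_compile:

    pytest test/distributed/_tensor/test_dtensor_compile.py -k test_placement_compile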
@@ -103,6 +103,11 @@ class PlacementVariable(DistributedVariable):
     def as_python_constant(self):
         return self.value
 
+    def var_getattr(self, tx, name: str) -> VariableTracker:
+        if name == "dim":
+            return ConstantVariable.create(self.value.dim)
+        return super().var_getattr(tx, name)
+
     def call_method(
         self,
         tx,
@@ -112,10 +117,20 @@ class PlacementVariable(DistributedVariable):
     ) -> "VariableTracker":
         from . import ConstantVariable
 
-        allowed_methods = ["__init__", "__setattr__"]
-        # placement types dynamo tracking allows only __init__
-        # and __setattr__ methods, the latter is for case like `Shard(dim)`
-        if name in allowed_methods:
+        # Placement types dynamo tracking only allows following methods
+        # and __setattr__ is for case like `Shard(dim)` and methods.
+        # Methods in the list must satisfy:
+        # 1. Input arguments are constants and do not need to be guarded on;
+        # 2. Output is constant with respect to their inputs
+        constant_fold_functions = [
+            "__init__",
+            "__setattr__",
+            "is_shard",
+            "is_partial",
+            "is_replicate",
+        ]
+
+        if name in constant_fold_functions:
             try:
                 value_type = type(self.value)
                 assert (
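For context, the attribute and methods handled above are pure functions of the Placement instance: .dim, is_shard(), is_partial(), and is_replicate() depend only on the placement object, so dynamo can evaluate them on self.value at trace time and wrap the result in a ConstantVariable. A rough eager-mode illustration of that invariance (behavior assumed from torch.distributed._tensor.placement_types, not part of this diff):

    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed._tensor.placement_types import _Partial

    # Each query depends only on the placement object itself, so for a
    # fixed input placement the result is a compile-time constant.
    assert Shard(0).is_shard() and Shard(0).is_shard(dim=0)
    assert not Shard(0).is_replicate()
    assert Replicate().is_replicate()
    assert _Partial().is_partial()
    assert Shard(1).dim == 1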