pytorch/torch/_dynamo/variables/distributed.py
Wanchao Liang f139aab2f4 [dynamo] add initial dynamo support for DTensor (#103146)
This PR adds initial dynamo support for DTensor. In particular, it:
- allows a DTensor to be passed into a compiled function, and fakifies the
DTensor during dynamo tracing by turning its inner local tensor into a meta
tensor.
- uses `allow_in_graph` so that `DTensor` and `DTensor.from_local` are represented as `TorchVariable` (a rough registration sketch follows this list).
- lets the created DTensor become a normal `TensorVariable`, which inserts any tensor operations into the output graph just like a plain torch.Tensor.
- notes that DTensor has an extra instance method, `redistribute`, compared to a plain tensor; we currently special-case it in `TensorVariable`.
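
A rough sketch of what the `allow_in_graph` registration looks like (the exact call site in the PR may differ):

```
import torch._dynamo
from torch.distributed._tensor import DTensor

# Treat the DTensor construction entry points as allowed calls so dynamo does
# not try to trace into their Python internals.
torch._dynamo.allow_in_graph(DTensor)
torch._dynamo.allow_in_graph(DTensor.from_local)
```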

`from_local` and `redistribute` both accept some non-trivial metadata as arguments (i.e. DeviceMesh, Placement), which fx.Graph does not support. In order to let these two APIs appear in the dynamo-captured graph, we encode the metadata into a new function (like `functools.partial`) that only accepts prim args (i.e. tensors), then we put a `call_function` node targeting this new function into the graph. This was suggested by @ezyang. The underlying rationale is that the metadata will not change across graph invocations, so it is safe to encode it.
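
A hedged illustration of the encoding trick (the helper name below is made up; the real wrappers show up in the captured graph below as `prim_from_local` / `prim_redistribute`):

```
from torch.distributed._tensor import DTensor

# Hypothetical sketch: mesh/placements are closed over, so the emitted
# function only takes prim args (tensors and simple constants).
def make_prim_from_local(mesh, placements):
    def prim_from_local(local_tensor, run_check=False):
        # mesh/placements are baked in here; they do not change across
        # invocations of the captured graph, so specializing on them is safe.
        return DTensor.from_local(local_tensor, mesh, placements, run_check=run_check)

    return prim_from_local
```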

Captured graph:
```
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False);  l_x_ = None

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
        prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local);  prim_from_local = None
        to_local = prim_redistribute.to_local();  prim_redistribute = None
        add = to_local + 2;  to_local = None
        return (add,)
```
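
For reference, the eager-mode function that produced this capture (reconstructed from the file/line comments in the graph; `mesh` is assumed to be a `DeviceMesh` defined outside the compiled function):

```
from torch.distributed._tensor import DTensor, Replicate, Shard

def fn(x):
    # mesh is a DeviceMesh captured from the enclosing scope
    dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    return dt.redistribute(mesh, [Replicate()]).to_local() + 2
```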

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
2023-07-19 16:01:12 +00:00


import inspect
from typing import Dict, List

import torch

from ..exc import unimplemented
from ..utils import istype
from .base import VariableTracker


class DistributedVariable(VariableTracker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not DistributedVariable.is_available():
            unimplemented("torch.distributed package is not available!")

    @staticmethod
    def is_available():
        # check if the distributed package is available or not
        return torch.distributed.is_available()


def is_from_local(value):
    if not DistributedVariable.is_available():
        return False
    from torch.distributed._tensor import DTensor

    return inspect.isfunction(value) and value is DTensor.from_local


class PlacementClassVariable(DistributedVariable):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_placement_type(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return type(value) is type and issubclass(value, Placement)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        options = VariableTracker.propagate(self, args, kwargs.values())
        if (
            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
            and self.source
        ):
            # NOTE: we don't need to track mutations to the placement class,
            # as its instances are supposed to be immutable.
            new_obj = object.__new__(self.value)
            var = PlacementVariable(new_obj, **options)
            if inspect.getattr_static(self.value, "__init__", None):
                return var.add_options(var.call_method(tx, "__init__", args, kwargs))

        return super().call_function(tx, args, kwargs)


class PlacementVariable(DistributedVariable):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_placement(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.placement_types import Placement

        return istype(value, Placement)

    def as_python_constant(self):
        return self.value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        allowed_methods = ["__init__", "__setattr__"]
        # dynamo tracking of placement types only allows the __init__ and
        # __setattr__ methods; the latter is for cases like `Shard(dim)`
        if name in allowed_methods:
            try:
                value_type = type(self.value)
                assert (
                    inspect.getattr_static(value_type, "__getattr__", None) is None
                ), "no custom getattr allowed!"
                method = inspect.getattr_static(value_type, name)
            except AttributeError:
                method = None
            if method is object.__init__:
                return ConstantVariable(None, **options)

            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            method(self.value, *args, **kwargs)
            return self

        return super().call_method(tx, name, args, kwargs)
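

# Illustrative sketch (not executed here, names assumed): for user code like
# `Shard(0)` inside a compiled region, dynamo materializes the placement
# eagerly at trace time rather than emitting graph nodes for it, roughly:
#
#     shard = object.__new__(Shard)   # PlacementClassVariable.call_function
#     Shard.__init__(shard, 0)        # PlacementVariable.call_method("__init__", ...)
#
# The resulting concrete Placement can then be closed over by the
# metadata-encoded prim_* wrapper functions described in the PR message.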


class DeviceMeshVariable(DistributedVariable):
    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_device_mesh(value):
        # we can't rely on importing/accessing torch distributed, it is not always built.
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)

    def as_python_constant(self):
        return self.value
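

# Note (illustrative, not from the original file): as_python_constant() hands
# back the real DeviceMesh object, so when tracing DTensor.from_local or
# redistribute, dynamo can recover the concrete mesh (and placements, via
# PlacementVariable.as_python_constant) and bake them into the specialized
# wrapper functions instead of threading non-tensor metadata through the graph:
#
#     mesh_vt = DeviceMeshVariable(mesh)            # wraps a concrete DeviceMesh
#     assert mesh_vt.as_python_constant() is mesh   # the mesh itself, not a proxy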