[dynamo] make ProcessGroupVariable a DistributedVariable (#105593)

This PR moves ProcessGroupVariable from UserDefinedObjectVariable to DistributedVariable
so that all of the distributed VariableTrackers are consolidated in one place.
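
For context, a minimal sketch (not part of this commit) of the user code path that ProcessGroupVariable covers: calling rank() and size() on a process group inside a torch.compile'd function, where dynamo folds both calls into constants instead of putting the ProcessGroup object into the output graph. It assumes a single-process "gloo" group so it can run standalone; the function name scale_by_rank is made up for illustration.

    import os
    import torch
    import torch.distributed as dist

    # Single-process setup so the sketch is self-contained (assumes the gloo backend is built).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    pg = dist.group.WORLD

    @torch.compile
    def scale_by_rank(x, pg):
        # pg.rank() and pg.size() are folded to constants by ProcessGroupVariable,
        # so no ProcessGroup object ends up in the traced graph.
        return x * (pg.rank() + 1) / pg.size()

    print(scale_by_rank(torch.ones(2), pg))  # tensor([1., 1.]) on the single rank
    dist.destroy_process_group()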

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105593
Approved by: https://github.com/voznesenskym
Wanchao Liang 2023-07-25 08:33:29 +00:00 committed by PyTorch MergeBot
parent 15442915cf
commit c76c84bde4
4 changed files with 87 additions and 84 deletions

View File

@@ -73,7 +73,11 @@ from .dicts import (
DefaultDictVariable,
HFPretrainedConfigVariable,
)
from .distributed import DeviceMeshVariable, PlacementClassVariable
from .distributed import (
DeviceMeshVariable,
PlacementClassVariable,
ProcessGroupVariable,
)
from .functions import (
CollectiveFunctionRewriteVariable,
UserFunctionVariable,
@@ -115,11 +119,7 @@ from .tensor import (
UnspecializedPythonVariable,
)
from .torch import tensor_dunder_fns, torch_special_class_types, TorchVariable
from .user_defined import (
ProcessGroupVariable,
UserDefinedClassVariable,
UserDefinedObjectVariable,
)
from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable
log = logging.getLogger(__name__)

View File

@@ -2,6 +2,7 @@ import inspect
from typing import Dict, List
import torch
from .. import variables
from ..exc import unimplemented
from ..utils import istype
from .base import VariableTracker
@@ -27,6 +28,23 @@ def is_from_local(value):
return inspect.isfunction(value) and value is DTensor.from_local
def is_constant_pg_functions(value):
if not DistributedVariable.is_available():
return False
from torch.distributed.distributed_c10d import (
_get_group_tag,
get_process_group_ranks,
)
constant_processgroup_functions = [
get_process_group_ranks,
_get_group_tag,
]
return inspect.isfunction(value) and value in constant_processgroup_functions
class PlacementClassVariable(DistributedVariable):
def __init__(self, value, **kwargs):
super().__init__(**kwargs)
@@ -128,3 +146,64 @@ class DeviceMeshVariable(DistributedVariable):
def as_python_constant(self):
return self.value
class ProcessGroupVariable(DistributedVariable):
"""
We don't want a ProcessGroup object to end up in our output graph.
But it's common for dynamo to intercept a PG that is then used to get info like
rank() or world_size(), as well as passed to utility functions in distributed_c10d
which desugar it into plain types like a rank list and a tag.
For convenience and proper guarding, we construct a variable type.
TODO: make it possible to use ProcessGroupVariable as input to simple functions
like _expand_group without dynamo complaining about making a proxy for it.
It is not a tensor-like type, and we don't want a proxy, but dynamo assumes
torch library functions are dealing with tensor-like types and would have proxies
for their args.
TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
or just graph-break whenever one of our special cases is not hit?
"""
def __init__(self, value, **kwargs):
super().__init__(**kwargs)
self.value = value
def as_python_constant(self):
return self.value
def python_type(self):
return type(self.value)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "rank":
return variables.ConstantVariable(self.value.rank())
if name == "size":
return variables.ConstantVariable(self.value.size())
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
if name in ["rank", "size"]:
return variables.LambdaVariable(
lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
).add_options(self)
# TODO should this just raise unimplemented?
return super().var_getattr(tx, name)
@staticmethod
def is_process_group(value):
# we can't rely on importing/accessing torch distributed, it is not always built.
if not DistributedVariable.is_available():
return False
from torch._C._distributed_c10d import ProcessGroup
return istype(value, ProcessGroup)
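
The docstring above describes the contract; as a quick illustration (again not part of the diff), the two helpers classify values roughly as follows. This assumes torch.distributed is built and that the module lives at the internal path torch._dynamo.variables.distributed; both APIs are private and may move.

    import torch.distributed as dist
    from torch._dynamo.variables.distributed import (
        ProcessGroupVariable,
        is_constant_pg_functions,
    )
    from torch.distributed.distributed_c10d import _get_group_tag, get_process_group_ranks

    # Only the allow-listed distributed_c10d helpers are treated as constant-foldable.
    assert is_constant_pg_functions(get_process_group_ranks)
    assert is_constant_pg_functions(_get_group_tag)
    # dist.get_rank is also a plain Python function, but it is not in the list.
    assert not is_constant_pg_functions(dist.get_rank)

    # is_process_group only matches real torch._C._distributed_c10d.ProcessGroup instances.
    assert not ProcessGroupVariable.is_process_group("not a process group")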

View File

@@ -1,5 +1,4 @@
import collections
import inspect
import logging
import math
@@ -12,7 +11,6 @@ import torch.fx
import torch.nn
import torch.onnx.operators
from torch._dynamo.variables import UserFunctionVariable
from torch._dynamo.variables.user_defined import ProcessGroupVariable
from .. import config, variables
from ..allowed_functions import torch_get_name
@@ -35,7 +33,7 @@ from .ctx_manager import (
NullContextVariable,
TorchFunctionDisableVariable,
)
from .distributed import is_from_local
from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import ListVariable, TupleVariable
from .tensor import TensorWithTFOverrideVariable
@@ -81,23 +79,10 @@ constant_fold_functions = [
torch._C._get_privateuse1_backend_name,
]
constant_processgroup_functions = []
if torch.distributed.is_available():
constant_fold_functions.append(torch.distributed.is_initialized)
from torch.distributed.distributed_c10d import (
_get_group_tag,
get_process_group_ranks,
)
constant_processgroup_functions.extend(
[
get_process_group_ranks,
_get_group_tag,
]
)
# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
def remap_as_fn___radd__(*args):
@@ -539,10 +524,7 @@ class TorchVariable(VariableTracker):
return TorchVariable(torch.add, **options).call_function(
tx, [args[0], result], {}
)
elif (
inspect.isfunction(self.value)
and self.value in constant_processgroup_functions
):
elif is_constant_pg_functions(self.value):
# because the input is a "ProcessGroupVariable", we'll be guarding on its
# ID_MATCH based on how it was constructed.

View File

@@ -558,61 +558,3 @@ class UserDefinedObjectVariable(UserDefinedVariable):
)(
collections.OrderedDict.__getitem__(self.value, key.as_python_constant())
).add_options(key, self)
class ProcessGroupVariable(UserDefinedObjectVariable):
"""
We don't want a ProcessGroup object to end up in our output graph.
But it's common for dynamo to intercept a PG that is then used to get info like
rank() or world_size(), as well as passed to utility functions in distributed_c10d
which desugar it into plain types like a ranklist and tag.
For convenience and proper guarding, we construct a variable type.
TODO: make it possible to use ProcessGroupVariable as input to simple functions
like _expand_group without dynamo complaining about making a proxy for it.
It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
torch library functions are dealing with tensor-like types and would have proxies
for their args.
TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
or just graph-break whenever one of our special cases is not hit?
"""
def __init__(self, value, **kwargs):
super().__init__(value, **kwargs)
def as_python_constant(self):
return self.value
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "rank":
return variables.ConstantVariable(self.value.rank())
if name == "size":
return variables.ConstantVariable(self.value.size())
# TODO should this just raise unimplemented?
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
if name in ["rank", "size"]:
return variables.LambdaVariable(
lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
).add_options(self)
# TODO should this just raise unimplemented?
return super().var_getattr(tx, name)
@staticmethod
def is_process_group(value):
# we can't rely on importing/accessing torch distributed, it is not always built.
if torch.distributed.is_available():
from torch._C._distributed_c10d import ProcessGroup
return istype(value, ProcessGroup)
return False