[dynamo] make ProcessGroupVariable a DistributedVariable (#105593)

This PR moves ProcessGroupVariable from a UserDefinedObjectVariable subclass to a DistributedVariable subclass,
so that the distributed VariableTrackers are consolidated in one place.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105593
Approved by: https://github.com/voznesenskym
Wanchao Liang authored 2023-07-25 08:33:29 +00:00, committed by PyTorch MergeBot
parent 15442915cf
commit c76c84bde4
4 changed files with 87 additions and 84 deletions
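
For context: ProcessGroupVariable exists so that dynamo can trace user code that queries a process group without the ProcessGroup object ending up in the output graph. A minimal sketch of the kind of code it handles (illustrative only, not part of this diff; assumes a distributed build and that torch.distributed.init_process_group has already been called, e.g. under torchrun):

import torch
import torch.distributed as dist

def shard_scale(x: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    # pg.rank() and pg.size() are handled by ProcessGroupVariable.call_method
    # and become plain Python constants, so only the tensor math is captured.
    return x * pg.rank() / pg.size()

compiled = torch.compile(shard_scale)
# e.g. compiled(torch.ones(4), dist.group.WORLD) inside an initialized job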

View File

@@ -73,7 +73,11 @@ from .dicts import (
     DefaultDictVariable,
     HFPretrainedConfigVariable,
 )
-from .distributed import DeviceMeshVariable, PlacementClassVariable
+from .distributed import (
+    DeviceMeshVariable,
+    PlacementClassVariable,
+    ProcessGroupVariable,
+)
 from .functions import (
     CollectiveFunctionRewriteVariable,
     UserFunctionVariable,
@@ -115,11 +119,7 @@ from .tensor import (
     UnspecializedPythonVariable,
 )
 from .torch import tensor_dunder_fns, torch_special_class_types, TorchVariable
-from .user_defined import (
-    ProcessGroupVariable,
-    UserDefinedClassVariable,
-    UserDefinedObjectVariable,
-)
+from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable
 
 log = logging.getLogger(__name__)

View File

@@ -2,6 +2,7 @@ import inspect
 from typing import Dict, List
 
 import torch
+from .. import variables
 from ..exc import unimplemented
 from ..utils import istype
 from .base import VariableTracker
@@ -27,6 +28,23 @@ def is_from_local(value):
     return inspect.isfunction(value) and value is DTensor.from_local
 
 
+def is_constant_pg_functions(value):
+    if not DistributedVariable.is_available():
+        return False
+
+    from torch.distributed.distributed_c10d import (
+        _get_group_tag,
+        get_process_group_ranks,
+    )
+
+    constant_processgroup_functions = [
+        get_process_group_ranks,
+        _get_group_tag,
+    ]
+
+    return inspect.isfunction(value) and value in constant_processgroup_functions
+
+
 class PlacementClassVariable(DistributedVariable):
     def __init__(self, value, **kwargs):
         super().__init__(**kwargs)
@@ -128,3 +146,64 @@ class DeviceMeshVariable(DistributedVariable):
 
     def as_python_constant(self):
         return self.value
+
+
+class ProcessGroupVariable(DistributedVariable):
+    """
+    We don't want a ProcessGroup object to end up in our output graph.
+
+    But it's common for dynamo to intercept a PG that is then used to get info like
+    rank() or world_size(), as well as passed to utility functions in distributed_c10d
+    which desugar it into plain types like a ranklist and tag.
+
+    For convenience and proper guarding, we construct a variable type.
+
+    TODO: make it possible to use ProcessGroupVariable as input to simple functions
+    like _expand_group without dynamo complaining about making a proxy for it.
+    It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
+    torch library functions are dealing with tensor-like types and would have proxies
+    for their args.
+    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
+    or just graph-break whenever one of our special cases is not hit?
+    """
+
+    def __init__(self, value, **kwargs):
+        super().__init__(**kwargs)
+        self.value = value
+
+    def as_python_constant(self):
+        return self.value
+
+    def python_type(self):
+        return type(self.value)
+
+    def call_method(
+        self,
+        tx,
+        name,
+        args: "List[VariableTracker]",
+        kwargs: "Dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        if name == "rank":
+            return variables.ConstantVariable(self.value.rank())
+        if name == "size":
+            return variables.ConstantVariable(self.value.size())
+
+        return super().call_method(tx, name, args, kwargs)
+
+    def var_getattr(self, tx, name):
+        if name in ["rank", "size"]:
+            return variables.LambdaVariable(
+                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
+            ).add_options(self)
+
+        # TODO should this just raise unimplemented?
+        return super().var_getattr(tx, name)
+
+    @staticmethod
+    def is_process_group(value):
+        # we can't rely on importing/accessing torch distributed, it is not always built.
+        if not DistributedVariable.is_available():
+            return False
+        from torch._C._distributed_c10d import ProcessGroup
+
+        return istype(value, ProcessGroup)
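
To unpack the docstring's note about c10d helpers that desugar a group into "a ranklist and tag", here is a rough illustration (assumes an initialized default process group; _get_group_tag is a private helper and may change):

import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_group_tag, get_process_group_ranks

# These are exactly the functions is_constant_pg_functions() recognizes: they
# reduce an opaque ProcessGroup to plain Python values dynamo can constant-fold.
pg = dist.group.WORLD
ranks = get_process_group_ranks(pg)  # plain list of global ranks, e.g. [0, 1, 2, 3]
tag = _get_group_tag(pg)             # plain string identifying the group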

View File

@@ -1,5 +1,4 @@
 import collections
-import inspect
 import logging
 import math
@@ -12,7 +11,6 @@ import torch.fx
 import torch.nn
 import torch.onnx.operators
 from torch._dynamo.variables import UserFunctionVariable
-from torch._dynamo.variables.user_defined import ProcessGroupVariable
 
 from .. import config, variables
 from ..allowed_functions import torch_get_name
@@ -35,7 +33,7 @@ from .ctx_manager import (
     NullContextVariable,
     TorchFunctionDisableVariable,
 )
-from .distributed import is_from_local
+from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
 from .higher_order_ops import TorchHigherOrderOperatorVariable
 from .lists import ListVariable, TupleVariable
 from .tensor import TensorWithTFOverrideVariable
@@ -81,23 +79,10 @@ constant_fold_functions = [
     torch._C._get_privateuse1_backend_name,
 ]
 
-constant_processgroup_functions = []
-
 if torch.distributed.is_available():
     constant_fold_functions.append(torch.distributed.is_initialized)
-    from torch.distributed.distributed_c10d import (
-        _get_group_tag,
-        get_process_group_ranks,
-    )
-
-    constant_processgroup_functions.extend(
-        [
-            get_process_group_ranks,
-            _get_group_tag,
-        ]
-    )
 
 
 # TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
 def remap_as_fn___radd__(*args):
@@ -539,10 +524,7 @@ class TorchVariable(VariableTracker):
             return TorchVariable(torch.add, **options).call_function(
                 tx, [args[0], result], {}
             )
-        elif (
-            inspect.isfunction(self.value)
-            and self.value in constant_processgroup_functions
-        ):
+        elif is_constant_pg_functions(self.value):
             # because the input is a "ProcessGroupVariable", we'll be guarding on its
             # ID_MATCH based on how it was constructed.
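
The elif is_constant_pg_functions(self.value) branch above means calls like the following are evaluated at trace time rather than traced into the graph (hypothetical user code, assuming an initialized process group):

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import get_process_group_ranks

def mean_over_group(x: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    # get_process_group_ranks(pg) is constant-folded during tracing, so the
    # group size participates in the graph only as a plain Python number.
    ranks = get_process_group_ranks(pg)
    return x / len(ranks)

compiled = torch.compile(mean_over_group)

The folded result behaves like a constant in the compiled function; the ID_MATCH guard on the group object re-triggers compilation if a different ProcessGroup is passed in later.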

View File

@@ -558,61 +558,3 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         )(
             collections.OrderedDict.__getitem__(self.value, key.as_python_constant())
         ).add_options(key, self)
-
-
-class ProcessGroupVariable(UserDefinedObjectVariable):
-    """
-    We don't want a ProcessGroup object to end up in our output graph.
-
-    But it's common for dynamo to intercept a PG that is then used to get info like
-    rank() or world_size(), as well as passed to utility functions in distributed_c10d
-    which desugar it into plain types like a ranklist and tag.
-
-    For convenience and proper guarding, we construct a variable type.
-
-    TODO: make it possible to use ProcessGroupVariable as input to simple functions
-    like _expand_group without dynamo complaining about making a proxy for it.
-    It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
-    torch library functions are dealing with tensor-like types and would have proxies
-    for their args.
-    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
-    or just graph-break whenever one of our special cases is not hit?
-    """
-
-    def __init__(self, value, **kwargs):
-        super().__init__(value, **kwargs)
-
-    def as_python_constant(self):
-        return self.value
-
-    def call_method(
-        self,
-        tx,
-        name,
-        args: "List[VariableTracker]",
-        kwargs: "Dict[str, VariableTracker]",
-    ) -> "VariableTracker":
-        if name == "rank":
-            return variables.ConstantVariable(self.value.rank())
-        if name == "size":
-            return variables.ConstantVariable(self.value.size())
-
-        # TODO should this just raise unimplemented?
-        return super().call_method(tx, name, args, kwargs)
-
-    def var_getattr(self, tx, name):
-        if name in ["rank", "size"]:
-            return variables.LambdaVariable(
-                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
-            ).add_options(self)
-
-        # TODO should this just raise unimplemented?
-        return super().var_getattr(tx, name)
-
-    @staticmethod
-    def is_process_group(value):
-        # we can't rely on importing/accessing torch distributed, it is not always built.
-        if torch.distributed.is_available():
-            from torch._C._distributed_c10d import ProcessGroup
-
-            return istype(value, ProcessGroup)
-        return False