[dynamo] make ProcessGroupVariable a DistributedVariable (#105593)
This PR moves ProcessGroupVariable from being a UserDefinedObjectVariable (UDO) to being a DistributedVariable, so that the distributed VariableTrackers are consolidated together.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105593
Approved by: https://github.com/voznesenskym
This commit is contained in:
parent 15442915cf
commit c76c84bde4
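
For orientation only (not part of the commit), here is a minimal sketch of the user-level pattern this change is about: a ProcessGroup handed to a compiled function and queried for rank/size, which dynamo is expected to specialize into plain constants rather than putting the group object into the graph. The single-process gloo setup and the eager backend are assumptions made just to keep the sketch self-contained.

# Hypothetical repro sketch; assumes torch.distributed is built.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)


@torch.compile(backend="eager")  # backend chosen only to keep the sketch simple
def shard_bias(x, group):
    # group.rank() / group.size() are the calls ProcessGroupVariable.call_method
    # maps to ConstantVariable, so the group object itself never enters the graph.
    return x + group.rank() / group.size()


print(shard_bias(torch.ones(4), dist.group.WORLD))
dist.destroy_process_group()
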
@@ -73,7 +73,11 @@ from .dicts import (
     DefaultDictVariable,
     HFPretrainedConfigVariable,
 )
-from .distributed import DeviceMeshVariable, PlacementClassVariable
+from .distributed import (
+    DeviceMeshVariable,
+    PlacementClassVariable,
+    ProcessGroupVariable,
+)
 from .functions import (
     CollectiveFunctionRewriteVariable,
     UserFunctionVariable,
@@ -115,11 +119,7 @@ from .tensor import (
     UnspecializedPythonVariable,
 )
 from .torch import tensor_dunder_fns, torch_special_class_types, TorchVariable
-from .user_defined import (
-    ProcessGroupVariable,
-    UserDefinedClassVariable,
-    UserDefinedObjectVariable,
-)
+from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable


 log = logging.getLogger(__name__)

@@ -2,6 +2,7 @@ import inspect
 from typing import Dict, List

 import torch
 from .. import variables
 from ..exc import unimplemented
 from ..utils import istype
 from .base import VariableTracker
@@ -27,6 +28,23 @@ def is_from_local(value):
     return inspect.isfunction(value) and value is DTensor.from_local


+def is_constant_pg_functions(value):
+    if not DistributedVariable.is_available():
+        return False
+
+    from torch.distributed.distributed_c10d import (
+        _get_group_tag,
+        get_process_group_ranks,
+    )
+
+    constant_processgroup_functions = [
+        get_process_group_ranks,
+        _get_group_tag,
+    ]
+
+    return inspect.isfunction(value) and value in constant_processgroup_functions
+
+
 class PlacementClassVariable(DistributedVariable):
     def __init__(self, value, **kwargs):
         super().__init__(**kwargs)
@@ -128,3 +146,64 @@ class DeviceMeshVariable(DistributedVariable):

     def as_python_constant(self):
         return self.value
+
+
+class ProcessGroupVariable(DistributedVariable):
+    """
+    We don't want a ProcessGroup object to end up in our output graph.
+
+    But it's common for dynamo to intercept a PG that is then used to get info like
+    rank() or world_size(), as well as passed to utility functions in distributed_c10d
+    which desugar it into plain types like a ranklist and tag.
+
+    For convenience and proper guarding, we construct a variable type.
+
+    TODO: make it possible to use ProcessGroupVariable as input to simple functions
+    like _expand_group without dynamo complaining about making a proxy for it.
+    It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
+    torch library functions are dealing with tensor-like types and would have proxies
+    for their args.
+    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
+    or just graph-break whenever one of our special cases is not hit?
+    """
+
+    def __init__(self, value, **kwargs):
+        super().__init__(**kwargs)
+        self.value = value
+
+    def as_python_constant(self):
+        return self.value
+
+    def python_type(self):
+        return type(self.value)
+
+    def call_method(
+        self,
+        tx,
+        name,
+        args: "List[VariableTracker]",
+        kwargs: "Dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        if name == "rank":
+            return variables.ConstantVariable(self.value.rank())
+        if name == "size":
+            return variables.ConstantVariable(self.value.size())
+
+        return super().call_method(tx, name, args, kwargs)
+
+    def var_getattr(self, tx, name):
+        if name in ["rank", "size"]:
+            return variables.LambdaVariable(
+                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
+            ).add_options(self)
+        # TODO should this just raise unimplemented?
+        return super().var_getattr(tx, name)
+
+    @staticmethod
+    def is_process_group(value):
+        # we can't rely on importing/accessing torch distributed, it is not always built.
+        if not DistributedVariable.is_available():
+            return False
+        from torch._C._distributed_c10d import ProcessGroup
+
+        return istype(value, ProcessGroup)
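
As a quick illustration (not part of the commit), the helpers shown above can be exercised directly. The module path torch._dynamo.variables.distributed and the single-process gloo group are assumptions of this sketch, not something the diff itself demonstrates.

# Hypothetical sanity check of is_constant_pg_functions and ProcessGroupVariable.is_process_group.
import os

import torch.distributed as dist
from torch._dynamo.variables.distributed import (
    ProcessGroupVariable,
    is_constant_pg_functions,
)
from torch.distributed.distributed_c10d import _get_group_tag, get_process_group_ranks

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

pg = dist.group.WORLD
print(ProcessGroupVariable.is_process_group(pg))          # expected True for a real ProcessGroup
print(is_constant_pg_functions(get_process_group_ranks))  # expected True
print(is_constant_pg_functions(_get_group_tag))           # expected True
print(is_constant_pg_functions(dist.all_reduce))          # expected False: not in the constant list

dist.destroy_process_group()
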
@@ -1,5 +1,4 @@
 import collections
 import inspect
 import logging

 import math
@@ -12,7 +11,6 @@ import torch.fx
 import torch.nn
 import torch.onnx.operators
 from torch._dynamo.variables import UserFunctionVariable
-from torch._dynamo.variables.user_defined import ProcessGroupVariable

 from .. import config, variables
 from ..allowed_functions import torch_get_name
@@ -35,7 +33,7 @@ from .ctx_manager import (
     NullContextVariable,
     TorchFunctionDisableVariable,
 )
-from .distributed import is_from_local
+from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
 from .higher_order_ops import TorchHigherOrderOperatorVariable
 from .lists import ListVariable, TupleVariable
 from .tensor import TensorWithTFOverrideVariable
@@ -81,23 +79,10 @@ constant_fold_functions = [
     torch._C._get_privateuse1_backend_name,
 ]

-constant_processgroup_functions = []
-
 if torch.distributed.is_available():
     constant_fold_functions.append(torch.distributed.is_initialized)

-    from torch.distributed.distributed_c10d import (
-        _get_group_tag,
-        get_process_group_ranks,
-    )
-
-    constant_processgroup_functions.extend(
-        [
-            get_process_group_ranks,
-            _get_group_tag,
-        ]
-    )


 # TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
 def remap_as_fn___radd__(*args):
@@ -539,10 +524,7 @@ class TorchVariable(VariableTracker):
             return TorchVariable(torch.add, **options).call_function(
                 tx, [args[0], result], {}
             )
-        elif (
-            inspect.isfunction(self.value)
-            and self.value in constant_processgroup_functions
-        ):
+        elif is_constant_pg_functions(self.value):
             # because the input is a "ProcessGroupVariable", we'll be guarding on its
             # ID_MATCH based on how it was constructed.
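
A hedged sketch (not part of the commit) of the path the new is_constant_pg_functions branch covers: a distributed_c10d utility call that dynamo is expected to evaluate while tracing, baking the resulting rank list in as a constant and guarding the group by identity (the ID_MATCH mentioned in the comment above). The backend choice and process-group setup are illustrative assumptions.

# Hypothetical illustration of the constant process-group-function path.
import os

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import get_process_group_ranks

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)


@torch.compile(backend="eager")  # backend chosen only for the sketch
def scale_by_world(x, group):
    ranks = get_process_group_ranks(group)  # expected to desugar into a plain rank list
    return x * len(ranks)


print(scale_by_world(torch.ones(2), dist.group.WORLD))
dist.destroy_process_group()
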
@@ -558,61 +558,3 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         )(
             collections.OrderedDict.__getitem__(self.value, key.as_python_constant())
         ).add_options(key, self)
-
-
-class ProcessGroupVariable(UserDefinedObjectVariable):
-    """
-    We don't want a ProcessGroup object to end up in our output graph.
-
-    But it's common for dynamo to intercept a PG that is then used to get info like
-    rank() or world_size(), as well as passed to utility functions in distributed_c10d
-    which desugar it into plain types like a ranklist and tag.
-
-    For convenience and proper guarding, we construct a variable type.
-
-    TODO: make it possible to use ProcessGroupVariable as input to simple functions
-    like _expand_group without dynamo complaining about making a proxy for it.
-    It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
-    torch library functions are dealing with tensor-like types and would have proxies
-    for their args.
-    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
-    or just graph-break whenever one of our special cases is not hit?
-    """
-
-    def __init__(self, value, **kwargs):
-        super().__init__(value, **kwargs)
-
-    def as_python_constant(self):
-        return self.value
-
-    def call_method(
-        self,
-        tx,
-        name,
-        args: "List[VariableTracker]",
-        kwargs: "Dict[str, VariableTracker]",
-    ) -> "VariableTracker":
-        if name == "rank":
-            return variables.ConstantVariable(self.value.rank())
-        if name == "size":
-            return variables.ConstantVariable(self.value.size())
-
-        # TODO should this just raise unimplemented?
-        return super().call_method(tx, name, args, kwargs)
-
-    def var_getattr(self, tx, name):
-        if name in ["rank", "size"]:
-            return variables.LambdaVariable(
-                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
-            ).add_options(self)
-        # TODO should this just raise unimplemented?
-        return super().var_getattr(tx, name)
-
-    @staticmethod
-    def is_process_group(value):
-        # we can't rely on importing/accessing torch distributed, it is not always built.
-        if torch.distributed.is_available():
-            from torch._C._distributed_c10d import ProcessGroup
-
-            return istype(value, ProcessGroup)
-        return False