[dynamo] make ProcessGroupVariable a DistributedVariable (#105593)
This PR moves ProcessGroupVariable from UserDefinedObjectVariable (UDO) to DistributedVariable so that the distributed VariableTrackers (VTs) are consolidated together.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105593
Approved by: https://github.com/voznesenskym
This commit is contained in:
parent 15442915cf
commit c76c84bde4
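For context, the snippet below is a minimal, hypothetical sketch (not part of this commit) of the kind of user code the consolidated ProcessGroupVariable handles: a compiled function that queries a process group's rank() and size() and calls get_process_group_ranks(), which dynamo resolves to constants under an ID_MATCH guard instead of tracing the ProcessGroup object itself. The single-process gloo setup exists only so the sketch can run standalone.

# Hypothetical illustration, not from the PR: exercises the paths handled by
# ProcessGroupVariable.call_method ("rank"/"size") and is_constant_pg_functions.
import os

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import get_process_group_ranks

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
pg = dist.group.WORLD  # a ProcessGroup instance


@torch.compile
def fn(x):
    # rank()/size() are baked in as constants; get_process_group_ranks() is
    # desugared at trace time into a plain list of ranks.
    ranks = get_process_group_ranks(pg)
    return x * pg.size() + pg.rank() + len(ranks)


print(fn(torch.ones(2)))  # tensor([2., 2.]) with a single-rank world
dist.destroy_process_group()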
torch/_dynamo/variables/__init__.py

@@ -73,7 +73,11 @@ from .dicts import (
     DefaultDictVariable,
     HFPretrainedConfigVariable,
 )
-from .distributed import DeviceMeshVariable, PlacementClassVariable
+from .distributed import (
+    DeviceMeshVariable,
+    PlacementClassVariable,
+    ProcessGroupVariable,
+)
 from .functions import (
     CollectiveFunctionRewriteVariable,
     UserFunctionVariable,
@@ -115,11 +119,7 @@ from .tensor import (
     UnspecializedPythonVariable,
 )
 from .torch import tensor_dunder_fns, torch_special_class_types, TorchVariable
-from .user_defined import (
-    ProcessGroupVariable,
-    UserDefinedClassVariable,
-    UserDefinedObjectVariable,
-)
+from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable

 log = logging.getLogger(__name__)
torch/_dynamo/variables/distributed.py

@@ -2,6 +2,7 @@ import inspect
 from typing import Dict, List

 import torch
+from .. import variables
 from ..exc import unimplemented
 from ..utils import istype
 from .base import VariableTracker
@@ -27,6 +28,23 @@ def is_from_local(value):
     return inspect.isfunction(value) and value is DTensor.from_local


+def is_constant_pg_functions(value):
+    if not DistributedVariable.is_available():
+        return False
+
+    from torch.distributed.distributed_c10d import (
+        _get_group_tag,
+        get_process_group_ranks,
+    )
+
+    constant_processgroup_functions = [
+        get_process_group_ranks,
+        _get_group_tag,
+    ]
+
+    return inspect.isfunction(value) and value in constant_processgroup_functions
+
+
 class PlacementClassVariable(DistributedVariable):
     def __init__(self, value, **kwargs):
         super().__init__(**kwargs)
@@ -128,3 +146,64 @@ class DeviceMeshVariable(DistributedVariable):

     def as_python_constant(self):
         return self.value
+
+
+class ProcessGroupVariable(DistributedVariable):
+    """
+    We don't want a ProcessGroup object to end up in our output graph.
+
+    But it's common for dynamo to intercept a PG that is then used to get info like
+    rank() or world_size(), as well as passed to utility functions in distributed_c10d
+    which desugar it into plain types like a ranklist and tag.
+
+    For convenience and proper guarding, we construct a variable type.
+
+    TODO: make it possible to use ProcessGroupVariable as input to simple functions
+          like _expand_group without dynamo complaining about making a proxy for it.
+          It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
+          torch library functions are dealing with tensor-like types and would have proxies
+          for their args.
+    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
+          or just graph-break whenever one of our special cases is not hit?
+    """
+
+    def __init__(self, value, **kwargs):
+        super().__init__(**kwargs)
+        self.value = value
+
+    def as_python_constant(self):
+        return self.value
+
+    def python_type(self):
+        return type(self.value)
+
+    def call_method(
+        self,
+        tx,
+        name,
+        args: "List[VariableTracker]",
+        kwargs: "Dict[str, VariableTracker]",
+    ) -> "VariableTracker":
+        if name == "rank":
+            return variables.ConstantVariable(self.value.rank())
+        if name == "size":
+            return variables.ConstantVariable(self.value.size())
+
+        return super().call_method(tx, name, args, kwargs)
+
+    def var_getattr(self, tx, name):
+        if name in ["rank", "size"]:
+            return variables.LambdaVariable(
+                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
+            ).add_options(self)
+        # TODO should this just raise unimplemented?
+        return super().var_getattr(tx, name)
+
+    @staticmethod
+    def is_process_group(value):
+        # we can't rely on importing/accessing torch distributed, it is not always built.
+        if not DistributedVariable.is_available():
+            return False
+        from torch._C._distributed_c10d import ProcessGroup
+
+        return istype(value, ProcessGroup)
torch/_dynamo/variables/torch.py

@@ -1,5 +1,4 @@
 import collections
-import inspect
 import logging

 import math
@@ -12,7 +11,6 @@ import torch.fx
 import torch.nn
 import torch.onnx.operators
 from torch._dynamo.variables import UserFunctionVariable
-from torch._dynamo.variables.user_defined import ProcessGroupVariable

 from .. import config, variables
 from ..allowed_functions import torch_get_name
@@ -35,7 +33,7 @@ from .ctx_manager import (
     NullContextVariable,
     TorchFunctionDisableVariable,
 )
-from .distributed import is_from_local
+from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
 from .higher_order_ops import TorchHigherOrderOperatorVariable
 from .lists import ListVariable, TupleVariable
 from .tensor import TensorWithTFOverrideVariable
@@ -81,23 +79,10 @@ constant_fold_functions = [
     torch._C._get_privateuse1_backend_name,
 ]

-constant_processgroup_functions = []

 if torch.distributed.is_available():
     constant_fold_functions.append(torch.distributed.is_initialized)
-
-    from torch.distributed.distributed_c10d import (
-        _get_group_tag,
-        get_process_group_ranks,
-    )
-
-    constant_processgroup_functions.extend(
-        [
-            get_process_group_ranks,
-            _get_group_tag,
-        ]
-    )


 # TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
 def remap_as_fn___radd__(*args):
@@ -539,10 +524,7 @@ class TorchVariable(VariableTracker):
             return TorchVariable(torch.add, **options).call_function(
                 tx, [args[0], result], {}
             )
-        elif (
-            inspect.isfunction(self.value)
-            and self.value in constant_processgroup_functions
-        ):
+        elif is_constant_pg_functions(self.value):
             # becuase the input is a "ProcessGroupVariable", we'll be guarding on its
             # ID_MATCH based on how it was constructed.

torch/_dynamo/variables/user_defined.py

@@ -558,61 +558,3 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         )(
             collections.OrderedDict.__getitem__(self.value, key.as_python_constant())
         ).add_options(key, self)
-
-
-class ProcessGroupVariable(UserDefinedObjectVariable):
-    """
-    We don't want a ProcessGroup object to end up in our output graph.
-
-    But it's common for dynamo to intercept a PG that is then used to get info like
-    rank() or world_size(), as well as passed to utility functions in distributed_c10d
-    which desugar it into plain types like a ranklist and tag.
-
-    For convenience and proper guarding, we construct a variable type.
-
-    TODO: make it possible to use ProcessGroupVariable as input to simple functions
-          like _expand_group without dynamo complaining about making a proxy for it.
-          It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
-          torch library functions are dealing with tensor-like types and would have proxies
-          for their args.
-    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
-          or just graph-break whenever one of our special cases is not hit?
-    """
-
-    def __init__(self, value, **kwargs):
-        super().__init__(value, **kwargs)
-
-    def as_python_constant(self):
-        return self.value
-
-    def call_method(
-        self,
-        tx,
-        name,
-        args: "List[VariableTracker]",
-        kwargs: "Dict[str, VariableTracker]",
-    ) -> "VariableTracker":
-        if name == "rank":
-            return variables.ConstantVariable(self.value.rank())
-        if name == "size":
-            return variables.ConstantVariable(self.value.size())
-
-        # TODO should this just raise unimplemented?
-        return super().call_method(tx, name, args, kwargs)
-
-    def var_getattr(self, tx, name):
-        if name in ["rank", "size"]:
-            return variables.LambdaVariable(
-                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
-            ).add_options(self)
-        # TODO should this just raise unimplemented?
-        return super().var_getattr(tx, name)
-
-    @staticmethod
-    def is_process_group(value):
-        # we can't rely on importing/accessing torch distributed, it is not always built.
-        if torch.distributed.is_available():
-            from torch._C._distributed_c10d import ProcessGroup
-
-            return istype(value, ProcessGroup)
-        return False