diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index e50ef336572..852b241978b 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -73,7 +73,11 @@ from .dicts import (
     DefaultDictVariable,
     HFPretrainedConfigVariable,
 )
-from .distributed import DeviceMeshVariable, PlacementClassVariable
+from .distributed import (
+    DeviceMeshVariable,
+    PlacementClassVariable,
+    ProcessGroupVariable,
+)
 from .functions import (
     CollectiveFunctionRewriteVariable,
     UserFunctionVariable,
@@ -115,11 +119,7 @@ from .tensor import (
     UnspecializedPythonVariable,
 )
 from .torch import tensor_dunder_fns, torch_special_class_types, TorchVariable
-from .user_defined import (
-    ProcessGroupVariable,
-    UserDefinedClassVariable,
-    UserDefinedObjectVariable,
-)
+from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable
 
 log = logging.getLogger(__name__)
 
diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py
index c0fe074d7e5..a7e25548cf1 100644
--- a/torch/_dynamo/variables/distributed.py
+++ b/torch/_dynamo/variables/distributed.py
@@ -2,6 +2,7 @@ import inspect
 from typing import Dict, List
 
 import torch
+from .. import variables
 from ..exc import unimplemented
 from ..utils import istype
 from .base import VariableTracker
@@ -27,6 +28,23 @@ def is_from_local(value):
     return inspect.isfunction(value) and value is DTensor.from_local
 
 
+def is_constant_pg_functions(value):
+    if not DistributedVariable.is_available():
+        return False
+
+    from torch.distributed.distributed_c10d import (
+        _get_group_tag,
+        get_process_group_ranks,
+    )
+
+    constant_processgroup_functions = [
+        get_process_group_ranks,
+        _get_group_tag,
+    ]
+
+    return inspect.isfunction(value) and value in constant_processgroup_functions
+
+
 class PlacementClassVariable(DistributedVariable):
     def __init__(self, value, **kwargs):
         super().__init__(**kwargs)
@@ -128,3 +146,64 @@ class DeviceMeshVariable(DistributedVariable):
 
     def as_python_constant(self):
         return self.value
+
+
+class ProcessGroupVariable(DistributedVariable):
+    """
+    We don't want a ProcessGroup object to end up in our output graph.
+
+    But it's common for dynamo to intercept a PG that is then used to get info like
+    rank() or world_size(), as well as passed to utility functions in distributed_c10d
+    which desugar it into plain types like a ranklist and tag.
+
+    For convenience and proper guarding, we construct a variable type.
+
+    TODO: make it possible to use ProcessGroupVariable as input to simple functions
+    like _expand_group without dynamo complaining about making a proxy for it.
+    It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
+    torch library functions are dealing with tensor-like types and would have proxies
+    for their args.
+    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
+    or just graph-break whenever one of our special cases is not hit?
+ """ + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def as_python_constant(self): + return self.value + + def python_type(self): + return type(self.value) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + if name == "rank": + return variables.ConstantVariable(self.value.rank()) + if name == "size": + return variables.ConstantVariable(self.value.size()) + + return super().call_method(tx, name, args, kwargs) + + def var_getattr(self, tx, name): + if name in ["rank", "size"]: + return variables.LambdaVariable( + lambda *args, **kwargs: self.call_method(tx, name, args, kwargs) + ).add_options(self) + # TODO should this just raise unimplemented? + return super().var_getattr(tx, name) + + @staticmethod + def is_process_group(value): + # we can't rely on importing/accessing torch distributed, it is not always built. + if not DistributedVariable.is_available(): + return False + from torch._C._distributed_c10d import ProcessGroup + + return istype(value, ProcessGroup) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index bf308402217..216b73a4a13 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1,5 +1,4 @@ import collections -import inspect import logging import math @@ -12,7 +11,6 @@ import torch.fx import torch.nn import torch.onnx.operators from torch._dynamo.variables import UserFunctionVariable -from torch._dynamo.variables.user_defined import ProcessGroupVariable from .. import config, variables from ..allowed_functions import torch_get_name @@ -35,7 +33,7 @@ from .ctx_manager import ( NullContextVariable, TorchFunctionDisableVariable, ) -from .distributed import is_from_local +from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable from .higher_order_ops import TorchHigherOrderOperatorVariable from .lists import ListVariable, TupleVariable from .tensor import TensorWithTFOverrideVariable @@ -81,23 +79,10 @@ constant_fold_functions = [ torch._C._get_privateuse1_backend_name, ] -constant_processgroup_functions = [] if torch.distributed.is_available(): constant_fold_functions.append(torch.distributed.is_initialized) - from torch.distributed.distributed_c10d import ( - _get_group_tag, - get_process_group_ranks, - ) - - constant_processgroup_functions.extend( - [ - get_process_group_ranks, - _get_group_tag, - ] - ) - # TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API. def remap_as_fn___radd__(*args): @@ -539,10 +524,7 @@ class TorchVariable(VariableTracker): return TorchVariable(torch.add, **options).call_function( tx, [args[0], result], {} ) - elif ( - inspect.isfunction(self.value) - and self.value in constant_processgroup_functions - ): + elif is_constant_pg_functions(self.value): # becuase the input is a "ProcessGroupVariable", we'll be guarding on its # ID_MATCH based on how it was constructed. 
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 59da1cdcf4a..0e1851e662f 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -558,61 +558,3 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         )(
             collections.OrderedDict.__getitem__(self.value, key.as_python_constant())
         ).add_options(key, self)
-
-
-class ProcessGroupVariable(UserDefinedObjectVariable):
-    """
-    We don't want a ProcessGroup object to end up in our output graph.
-
-    But it's common for dynamo to intercept a PG that is then used to get info like
-    rank() or world_size(), as well as passed to utility functions in distributed_c10d
-    which desugar it into plain types like a ranklist and tag.
-
-    For convenience and proper guarding, we construct a variable type.
-
-    TODO: make it possible to use ProcessGroupVariable as input to simple functions
-    like _expand_group without dynamo complaining about making a proxy for it.
-    It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
-    torch library functions are dealing with tensor-like types and would have proxies
-    for their args.
-    TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
-    or just graph-break whenever one of our special cases is not hit?
-    """
-
-    def __init__(self, value, **kwargs):
-        super().__init__(value, **kwargs)
-
-    def as_python_constant(self):
-        return self.value
-
-    def call_method(
-        self,
-        tx,
-        name,
-        args: "List[VariableTracker]",
-        kwargs: "Dict[str, VariableTracker]",
-    ) -> "VariableTracker":
-        if name == "rank":
-            return variables.ConstantVariable(self.value.rank())
-        if name == "size":
-            return variables.ConstantVariable(self.value.size())
-
-        # TODO should this just raise unimplemented?
-        return super().call_method(tx, name, args, kwargs)
-
-    def var_getattr(self, tx, name):
-        if name in ["rank", "size"]:
-            return variables.LambdaVariable(
-                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
-            ).add_options(self)
-        # TODO should this just raise unimplemented?
-        return super().var_getattr(tx, name)
-
-    @staticmethod
-    def is_process_group(value):
-        # we can't rely on importing/accessing torch distributed, it is not always built.
-        if torch.distributed.is_available():
-            from torch._C._distributed_c10d import ProcessGroup
-
-            return istype(value, ProcessGroup)
-        return False
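
Not part of the patch: a minimal usage sketch of the behavior this diff preserves while relocating it, assuming a single-process gloo group on a hypothetical local port (29500). Once a ProcessGroup reaches a compiled region, rank()/size() and the c10d helpers listed in is_constant_pg_functions are resolved to plain Python constants at trace time rather than being proxied into the graph.

    # illustrative sketch only; the single-process gloo setup and port are assumptions
    import torch
    import torch.distributed as dist

    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    pg = dist.group.WORLD  # default ProcessGroup created above


    def fn(x, pg):
        # ProcessGroupVariable + is_constant_pg_functions let dynamo desugar these
        # calls into constants (rank 0, ranks [0]) while tracing.
        ranks = dist.get_process_group_ranks(pg)
        return x + pg.rank() + len(ranks)


    compiled = torch.compile(fn, backend="eager")
    print(compiled(torch.ones(2), pg))  # tensor([2., 2.]) for rank 0, world_size 1

    dist.destroy_process_group()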