import abc
import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import logging
import operator
import re
import types
from typing import List, NamedTuple, Optional, Union

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

import torch
from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    DimConstraint,
    DimDynamic,
    RelaxedUnspecConstraint,
)
from torch.fx.immutable_collections import immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef, WeakIdRef

from .. import config, mutation_guard, replay_record, skipfiles
from ..allowed_functions import (
    is_allowed,
    is_builtin_callable,
    is_numpy,
    is_user_defined_allowed,
)
from ..exc import unimplemented
from ..guards import GuardBuilder, make_dupe_guard
from ..side_effects import SideEffects
from ..source import (
    AttrSource,
    ConstantSource,
    ConvertIntSource,
    GetItemSource,
    GlobalWeakRefSource,
    is_constant_source,
    LocalSource,
    NumpyTensorSource,
    RandomValueSource,
    Source,
    TupleIteratorGetItemSource,
)
from ..utils import (
    build_checkpoint_variable,
    clone_input,
    get_fake_value,
    get_static_address_type,
    global_key_name,
    is_namedtuple,
    is_typing,
    is_utils_checkpoint,
    istype,
    odict_values,
    preserve_rng_state,
    tensor_always_has_static_shape,
    tuple_iterator,
    tuple_iterator_getitem,
    tuple_iterator_len,
    wrap_fake_exception,
)
from .base import MutableLocal, typestr, VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import CUDAStreamVariable, NullContextVariable
from .dicts import (
    ConstDictVariable,
    DataClassVariable,
    DefaultDictVariable,
    HFPretrainedConfigVariable,
)
from .distributed import (
    DeviceMeshVariable,
    PlacementClassVariable,
    PlacementVariable,
    ProcessGroupVariable,
)
from .functions import (
    CollectiveFunctionRewriteVariable,
    FunctoolsPartialVariable,
    TritonKernelVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import (
    BaseListVariable,
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    SetVariable,
    SizeVariable,
    SliceVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .misc import (
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    GetAttrVariable,
    GetSetDescriptorVariable,
    InspectSignatureVariable,
    LambdaVariable,
    MethodWrapperVariable,
    NumpyVariable,
    PythonModuleVariable,
    SkipFilesVariable,
    TypingVariable,
)
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorSubclassVariable,
    TensorVariable,
    TensorWithTFOverrideVariable,
    UnspecializedPythonVariable,
)
from .torch import tensor_dunder_fns, torch_special_class_types, TorchVariable
from .user_defined import (
    KeyedJaggedTensorVariable,
    UserDefinedClassVariable,
    UserDefinedObjectVariable,
)

log = logging.getLogger(__name__)

DimList = List


class _missing:
    pass


@dataclasses.dataclass
class GraphArg:
    source: Source
    # TODO: storing a SymInt here but not a FakeTensor is a pretty strange
    # thing to do.
    # Probably should have example (which stores an int) and
    # fake_example
    _example: Union[TensorWeakRef, torch.SymInt]
    is_unspecialized: bool
    fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # UnspecializedPythonVariable often masquerades as a tensor.
    # We MUST NOT generate shape guard code
    # that actually tries to access tensor properties on these values.
    # is_tensor lets us tell if this graph arg actually is a tensor
    # or not.
    is_tensor: bool = True
    # Sometimes, the Tensor we pass to example is freshly allocated (smh).
    # Then we cannot keep only a weak reference to it. This lets you
    # stash a strong reference too.
    example_strong_ref: Optional[torch.Tensor] = None

    @property
    def example(self):
        if isinstance(self._example, TensorWeakRef):
            r = self._example()
            assert r is not None
            return r
        else:
            return self._example

    def __post_init__(self):
        if isinstance(self._example, torch.Tensor):
            self._example = TensorWeakRef(self._example)
            assert is_fake(self.fake_tensor)

    def load(self, tx):
        return self.source.reconstruct(tx)

    def erase(self):
        self._example = None

    def __eq__(self, other):
        return self.source.name() == other.source.name()


@dataclasses.dataclass
class FrameStateSizeEntry:
    scalar: Optional[int]
    size: Optional[List[int]]


class VariableBuilder:
    """Wrap a python value in a VariableTracker() instance"""

    def __init__(
        self,
        tx,
        source: Source,
    ):
        assert (
            source is not None
        ), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
        assert TracingContext.get() is not None, "Expected active TracingContext"

        super().__init__()
        self.tx = tx
        self.source = source
        self.name = source.name()

    def __call__(self, value):
        if value in self.tx.output.side_effects:
            side_effect_result = self.tx.output.side_effects[value]
            dup_guard = make_dupe_guard(self.source, side_effect_result.source)
            if dup_guard:
                side_effect_result = side_effect_result.add_guards(
                    self.make_guards(dup_guard)
                )
            return side_effect_result
        vt = self._wrap(value).clone(**self.options())
        if self._can_lift_attrs_to_inputs(vt):
            vt = self.tx.output.side_effects.track_object_existing(
                self.source, value, vt
            )
        return vt

    def _can_lift_attrs_to_inputs(self, vt):
        if type(vt) in [
            TensorVariable,
            TensorWithTFOverrideVariable,
            UserDefinedObjectVariable,
            NumpyNdarrayVariable,
        ]:
            return True
        return False

    @staticmethod
    @functools.lru_cache(None)
    def _common_constants():
        return {
            # We zero-one specialize shapes, so specialize these constants
            # too
            0,
            1,
            # NB: There used to be more constants here, but honestly it was
            # pretty confusing. Note we specialize floats by default, and
            # DON'T specialize ints by default.
This all only matters with # dynamic_shapes } def get_source(self): return self.source def options(self): return {"source": self.get_source()} def make_guards(self, *guards): source = self.get_source() if ( isinstance(source, ConstantSource) or source.guard_source() == GuardSource.CONSTANT ): return None return {source.make_guard(guard) for guard in guards} @classmethod @functools.lru_cache(None) def _type_dispatch(cls): # NB: Careful not to close over self to avoid ref cycle from lru_cache entries = [ ( (torch.Tensor, torch.nn.Parameter, torch._subclasses.FakeTensor), cls.wrap_tensor, ), ((tuple, list, odict_values, collections.deque), cls.wrap_listlike), (tuple_iterator, cls.wrap_tuple_iterator), ((slice, range), cls.wrap_slice_range), ( ( int, float, bool, type(None), str, torch.Size, torch.device, torch.dtype, ), cls.wrap_literal, ), ] if config.trace_numpy and np: entries.append((np.ndarray, cls.wrap_numpy_ndarray)) result = {} for ts, fn in entries: for t in ts if isinstance(ts, tuple) else (ts,): assert t not in result result[t] = fn return result @classmethod @functools.lru_cache(None) def _id_dispatch(cls): from ..comptime import comptime entries = [ ( inspect.signature, lambda self, value: LambdaVariable( InspectSignatureVariable.create, source=self.source, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH), ), ), (comptime, lambda self, value: ComptimeVariable()), ( dataclasses.fields, lambda self, value: LambdaVariable( _dataclasses_fields_lambda, source=self.source, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH), ), ), ( tensor_dunder_fns, lambda self, value: TorchVariable( value, source=self.source, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH), ), ), ] result = {} for ts, fn in entries: for t in ts if isinstance(ts, (tuple, list)) else (ts,): assert t not in result result[id(t)] = fn return result def _wrap(self, value): # import here to avoid circular dependencies from torch.utils._triton import has_triton if has_triton(): from triton.runtime.jit import JITFunction else: class JITFunction: pass make_guards = self.make_guards # Handle exact type() match type_dispatch = self._type_dispatch().get(type(value)) if type_dispatch is not None: return type_dispatch(self, value) # Handle exact id() match id_dispatch = self._id_dispatch().get(id(value)) if id_dispatch is not None: return id_dispatch(self, value) # Note - There are some nested values where types mismatch! # We want to get those out and wrap those. value = inspect.getattr_static(value, "_torchdynamo_inline", value) # Everything else (NB: order matters!) if is_traceable_wrapper_subclass(value) or istype( value, config.traceable_tensor_subclasses ): return self.wrap_tensor(value) elif is_namedtuple(value): return self.wrap_listlike(value) elif value is torch.utils._pytree.SUPPORTED_NODES: result = { k: UserDefinedObjectVariable( value[k], source=GetItemSource(self.get_source(), k), # For SUPPORTED_NODES, we guard on the dictionary version (PEP509) # under the assumption that the values themselves don't change. 
guards=self.make_guards(GuardBuilder.DICT_VERSION), ) for k in value.keys() } return ConstDictVariable(result, type(value)) elif istype( value, (dict, collections.defaultdict, collections.OrderedDict) ) and all( ConstantVariable.is_literal(k) or self.tensor_can_be_dict_key(k) or isinstance(k, enum.Enum) for k in value.keys() ): if not value and self.get_source().is_nn_module(): # It is faster to guard on 'false' property than to guard # on actual dict keys, but we can't do this fast guard in general because # it omits a crucial type check that ensures the value is actually still a dict at runtime. # Why is this OK for (specialized) nnmodules? We set up a setattr hook # to check for module property mutations, which does a reasonable, # but not completely secure job ensuring a property wasn't changed. guards = self.make_guards(GuardBuilder.BOOL_FALSE) else: guards = self.make_guards(GuardBuilder.DICT_KEYS) # store key variables in global location for reconstruction for key in value.keys(): if self.tensor_can_be_dict_key(key): self.tx.store_global_weakref(global_key_name(key), key) def index_source(key): if self.tensor_can_be_dict_key(key): return GlobalWeakRefSource(global_key_name(key)) else: return key result = { k: VariableBuilder( self.tx, GetItemSource(self.get_source(), index_source(k)) )(value[k]).add_guards(guards) for k in value.keys() } if istype(value, collections.defaultdict): result = DefaultDictVariable( result, type(value), self._wrap(value.default_factory), guards=guards, ) else: result = ConstDictVariable(result, type(value), guards=guards) return self.tx.output.side_effects.track_dict(self.source, value, result) elif isinstance(value, torch.nn.Module): return self.wrap_module(value) elif ConstantVariable.is_literal(value): # non-atomic literals return self.wrap_literal(value) elif istype(value, frozenset) and ( all(is_allowed(x) or ConstantVariable.is_literal(x) for x in value) ): # For frozenset, we can guard by object ID instead of value # equality, this allows us to handle non-literal values return ConstantVariable.create( value=value, source=self.source, guards=make_guards(GuardBuilder.ID_MATCH), ) elif isinstance(value, enum.Enum): return EnumVariable( value=value, source=self.source, guards=make_guards(GuardBuilder.ID_MATCH), ) elif is_builtin_callable(value): return BuiltinVariable( value, source=self.source, guards=make_guards(GuardBuilder.BUILTIN_MATCH), ) elif is_utils_checkpoint(value): return build_checkpoint_variable(source=self.source) elif is_allowed(value): if is_user_defined_allowed(value): self.tx.output.has_user_defined_allowed_in_graph = True return TorchVariable( value, source=self.source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif isinstance(value, functools.partial): func_src = AttrSource(self.get_source(), "func") func_obj = VariableBuilder(self.tx, func_src)(value.func) args = [] args_source = AttrSource(self.get_source(), "args") for i, arg in enumerate(value.args): args.append( VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) ) keywords = {} keywords_source = AttrSource(self.get_source(), "keywords") for k, v in value.keywords.items(): keywords[k] = VariableBuilder( self.tx, GetItemSource(keywords_source, k) )(v) guards = { self.get_source().make_guard(GuardBuilder.TYPE_MATCH), keywords_source.make_guard(GuardBuilder.DICT_KEYS), args_source.make_guard(GuardBuilder.LIST_LENGTH), } return FunctoolsPartialVariable( func_obj, args, keywords, original=value, guards=guards ) elif is_typing(value): # typing.List, typing.Mapping, 
etc. return TypingVariable( value, source=self.source, guards=make_guards(GuardBuilder.ID_MATCH), ) elif is_numpy(value): assert np return NumpyVariable( value, source=self.source, guards=make_guards( GuardBuilder.FUNCTION_MATCH if callable(value) else GuardBuilder.TYPE_MATCH ), ) elif ( istype(value, (type, types.FunctionType)) and skipfiles.check(value, allow_torch=True) and not inspect.getattr_static(value, "_torchdynamo_inline", False) ): return SkipFilesVariable( value, skipfiles.check_verbose(value, allow_torch=True).reason, source=self.source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) # NB: These can't be put in type_dispatch, they have to run later elif CollectiveFunctionRewriteVariable.can_rewrite(value): new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(value) old_source = self.source self.source = new_source return CollectiveFunctionRewriteVariable( new_fn, orig_fn=value, orig_source=old_source, source=new_source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)): return UserFunctionVariable( value, source=self.source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif istype(value, (types.ModuleType, replay_record.DummyModule)): return PythonModuleVariable( value, source=self.source, guards=make_guards(GuardBuilder.PYMODULE_MATCH), ) elif istype(value, torch.autograd.function.FunctionMeta): return AutogradFunctionVariable( value, source=self.source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif isinstance(value, torch.autograd.function.FunctionCtx): # The autograd.function context return self.tx.output.side_effects.track_object_existing( self.source, value, AutogradFunctionContextVariable( value, source=self.source, guards=make_guards(GuardBuilder.TYPE_MATCH), ), ) elif ( isinstance(value, types.MethodType) and istype( getattr(value, "__self__", None), torch.autograd.function.FunctionMeta ) and getattr(value, "__name__", "") == "apply" and value == getattr(value.__self__, "apply", None) ): # handle aliased autograd function `apply` calls return GetAttrVariable( AutogradFunctionVariable( value.__self__, source=self.source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ), "apply", ) elif np and isinstance(value, np.number): return self.wrap_unspecialized_primitive(value) elif DataClassVariable.is_matching_object(value): return DataClassVariable.wrap(self, value).add_guards( make_guards(GuardBuilder.TYPE_MATCH) ) elif HFPretrainedConfigVariable.is_matching_object(value): return HFPretrainedConfigVariable( value, guards=make_guards(GuardBuilder.TYPE_MATCH) ) elif isinstance(value, HigherOrderOperator): return TorchHigherOrderOperatorVariable.make( value, source=self.source, guards=self.make_guards( GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH ), ) elif type(value).__name__ == "builtin_function_or_method" and isinstance( value.__self__, torch_special_class_types ): return TorchVariable( value, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif isinstance(value, torch.cuda.streams.Stream): unimplemented("CUDAStreamVariable does not currently work soundly.") # return CUDAStreamVariable( # None, # value, # source=self.source, # guards=self.make_guards(GuardBuilder.ID_MATCH), # ) elif ( isinstance(value, torch._C._TensorMeta) and value in config.traceable_tensor_subclasses ): return TensorSubclassVariable(value, source=self.source) elif isinstance(value, types.MethodType) and isinstance( value.__self__, torch.nn.Module ): # don't let MethodTypes fall through to 
UserDefinedObject, # which doesn't support 'CALL_FUNCTION' # TODO(whc): Why do we limit this to methods on NNModules? # I don't have a good reason for this, but it preserves the existing behavior # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise. # I suspect we probably want to relax this check and dig deeper there. # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python, # but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here # and then `__func__` gets wrapped inside UserMethodVariable. self_obj = VariableBuilder( self.tx, source=AttrSource(self.source, "__self__") )(value.__self__) assert self_obj and isinstance( self_obj, VariableTracker ), "Failed to produce a valid self obj" return UserMethodVariable( value.__func__, self_obj, source=self.source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif ( istype(value, contextlib.nullcontext) and inspect.getattr_static(value, "enter_result", None) is None ): return NullContextVariable( source=self.source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif KeyedJaggedTensorVariable.is_matching_object(value): result = KeyedJaggedTensorVariable( value, source=self.source, guards=self.make_guards(GuardBuilder.TYPE_MATCH), ) # TODO: this doing it manually is bad return self.tx.output.side_effects.track_object_existing( self.source, value, result ) elif isinstance(value, types.GetSetDescriptorType): return GetSetDescriptorVariable( value, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH) ) elif isinstance(value, types.MethodWrapperType): return MethodWrapperVariable( value, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH) ) elif isinstance(value, torch.optim.Optimizer): return OptimizerVariable( value, source=self.source, guards=self.make_guards(GuardBuilder.TYPE_MATCH), ) elif ProcessGroupVariable.is_process_group(value): return ProcessGroupVariable( value, source=self.source, guards=self.make_guards(GuardBuilder.ID_MATCH), ) elif DeviceMeshVariable.is_device_mesh(value): # TODO: see if we need to add custom guard instead # of a simple ID_MATCH return DeviceMeshVariable( value, source=self.source, guards=self.make_guards(GuardBuilder.ID_MATCH), ) elif PlacementClassVariable.is_placement_type(value): # TODO: see if we need to add custom guard instead # of a simple ID_MATCH return PlacementClassVariable( value, source=self.source, guards=make_guards(GuardBuilder.ID_MATCH), ) elif PlacementVariable.is_placement(value): # TODO: see if we need to add custom guard instead # of a simple ID_MATCH return PlacementVariable( value, source=self.source, guards=make_guards(GuardBuilder.ID_MATCH), ) elif issubclass(type(value), type): # TODO(whc) the following seems preferable but breaks some tests, debug # elif inspect.isclass(value): return UserDefinedClassVariable( value, source=self.source, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) elif isinstance(value, torch.SymBool): # Note: the idea here is to re-use the infra we've built for SymInt by simulating the # user provided SymBool with a SymInt in dynamo. # Concretely, # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source). # so that guards on the SymInts can be effectively applied on the original SymBool in user program. # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program # depends on the value being a SymBool. 
This allows dynamo to interpret the user's program correctly. value_hint = value.node.require_hint() new_source = ConvertIntSource(self.source) new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol( int(value_hint), new_source, dynamic_dim=DimDynamic.DYNAMIC, ) sym_node_proxy = self.tx.output.root_tracer.create_graph_input( re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(new_symint), source=new_source, ) sym_node_proxy.node.meta["grapharg"] = GraphArg( new_source, new_symint, False, None, is_tensor=False, example_strong_ref=new_symint, ) self.tx.output.tracked_fakes.append( TrackedFake(new_symint, new_source, None) ) return SymNodeVariable( sym_node_proxy, new_symint == 1, ) elif isinstance(value, JITFunction): return TritonKernelVariable( value, None, # No kernel idx provided None, # No grid provided source=self.source, guards=make_guards(GuardBuilder.ID_MATCH), ) else: result = UserDefinedObjectVariable( value, source=self.source, guards=self.make_guards(GuardBuilder.TYPE_MATCH), ) if not SideEffects.cls_supports_mutation_side_effects(type(value)): # don't allow STORE_ATTR mutation with custom __setattr__ return result return self.tx.output.side_effects.track_object_existing( self.source, value, result ) def tensor_can_be_dict_key(self, value): # only allow Parameter and another specific Tensor can be used as dict key return ( isinstance(value, torch.nn.Parameter) or isinstance(self.source, AttrSource) and self.source.member == "state" and isinstance(self.source.base, LocalSource) ) def tensor_should_specialize(self): return ( self.source and isinstance(self.source, GetItemSource) and isinstance(self.source.base, GetItemSource) and self.source.base.index == "params" and isinstance(self.source.base.base, GetItemSource) and isinstance(self.source.base.base.base, AttrSource) and self.source.base.base.base.member == "param_groups" and isinstance(self.source.base.base.base.base, LocalSource) and ( isinstance( self.tx.f_locals[self.source.base.base.base.base.local_name], torch.optim.Optimizer, ) if self.source.base.base.base.base.local_name in self.tx.f_locals.keys() else True ) ) def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): # One can index a tensor with a list/tuple. Therefore, we need to # have a stricter match. 
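# Illustrative sketch (not part of this module): the idea behind the SymBool
# handling in _wrap above -- model a boolean input as an integer symbol constrained
# to {0, 1}, so guard machinery that understands integer symbols can still guard
# the original bool, and recover the boolean as `i == 1`.  The names here
# (_symbool_as_int_sketch, as_int, recovered) are hypothetical.
def _symbool_as_int_sketch(flag: bool) -> bool:
    as_int = int(flag)          # roughly what ConvertIntSource + the new SymInt stand for
    assert as_int in (0, 1)     # the value range guards are allowed to assume
    recovered = as_int == 1     # the SymBool handed back to the user's program
    return recovered

# e.g. _symbool_as_int_sketch(True) -> True, _symbool_as_int_sketch(False) -> False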
guards = self.make_guards(GuardBuilder.LIST_LENGTH) for item in value: if item is value: unimplemented("list elements are pointing to the list itself") output = [ VariableBuilder(self.tx, GetItemSource(self.get_source(), i))( item ).add_guards(guards) for i, item in enumerate(value) ] result = BaseListVariable.cls_for_instance(value)( output, mutable_local=MutableLocal(), guards=guards ) if istype(value, list): return self.tx.output.side_effects.track_list(self.source, value, result) return result def wrap_tuple_iterator(self, value: tuple_iterator): guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN) output = [ VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))( tuple_iterator_getitem(value, i) ).add_guards(guards) for i in range(tuple_iterator_len(value)) ] return TupleIteratorVariable( output, mutable_local=MutableLocal(), guards=guards ) def wrap_slice_range(self, value: Union[slice, range]): items = [ VariableBuilder(self.tx, AttrSource(self.get_source(), k))( getattr(value, k) ) for k in ("start", "stop", "step") ] if isinstance(value, slice): return SliceVariable( items, guards=self.make_guards(GuardBuilder.TYPE_MATCH) ) else: return RangeVariable( items, guards=self.make_guards(GuardBuilder.EQUALS_MATCH) ) def wrap_module(self, value: torch.nn.Module): from ..eval_frame import OptimizedModule if istype(value, OptimizedModule): guards = self.make_guards(GuardBuilder.TYPE_MATCH) self.source = AttrSource(self.source, "_orig_mod") return self.wrap_module(value._orig_mod).add_guards(guards) if ( isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) and not config.allow_rnn ): unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs") if mutation_guard.is_dynamic_nn_module(value): # created dynamically, don't specialize on it result = UnspecializedNNModuleVariable( value, guards=self.make_guards(GuardBuilder.TYPE_MATCH) ) if not SideEffects.cls_supports_mutation_side_effects(type(value)): # don't allow STORE_ATTR mutation with custom __setattr__ return result return self.tx.output.side_effects.track_object_existing( self.source, value, result ) elif issubclass( value.__class__, torch.nn.parallel.distributed.DistributedDataParallel ): return UnspecializedNNModuleVariable( value, guards=self.make_guards(GuardBuilder.TYPE_MATCH) ) elif getattr(value, "_is_fsdp_managed_module", False): # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] # in fully_sharded_data_parallel.py for more information # we can't do this assert inside FSDP constructor, # since we don't know yet whether dynamo will be used assert getattr( value, "_fsdp_use_orig_params", False ), "Dynamo only supports FSDP with use_orig_params=True" # Note on FSDP guarding # 1. We expect FSDP wrapping mutates an nn module irreversably (no way to de-wrap). # 2. Eager FSDP already assumes (requires, but without enforcement) that users don't mutate their # model parameters/structure after FSDP wrapping, because FSDP wouldn't notice or update its FlatParams. # # Due to (1), once we enter this path we expect not to go back nor have to guard on type # or _is_fsdp_managed_module. # # TODO(whc) We could add a guard on the opposite case, where a user compiled/ran # pre-FSDP-wrapped model, then wrapped, to ensure that we recompile with the FSDP handling. # # Due to (2), we skip guards on inner contents of fsdp_managed modules, by using FSDPNNModuleSource as the # guard source. This behavior is gated on config.skip_fsdp_guards. 
# # ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps # them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager) return FSDPManagedNNModuleVariable( value, guards=self.make_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH), source=self.get_source(), ) else: return self.tx.output.register_attr_or_module( value, self.name, source=self.get_source(), # Guards are added inside register_attr_or_module ) def wrap_literal(self, value): unspec = not config.specialize_int if unspec and type(value) is torch.Size: return SizeVariable( [ VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(v) for i, v in enumerate(value) ], guards=self.make_guards(GuardBuilder.LIST_LENGTH), ) elif unspec and type(value) is int: # unspecializing int by default, but still # specialize for the following conditions if not TracingContext.get().force_unspec_int_unbacked_size_like and ( value in self._common_constants() # Assume integers from global variables want to be specialized or not self.source.guard_source().is_local() # Assume that integers that came from NN modules want to be # specialized (as we don't expect users to be changing the # NN modules on the fly) or self.source.guard_source().is_nn_module() ): return ConstantVariable.create( value=value, guards=self.make_guards(GuardBuilder.CONSTANT_MATCH), ) else: return self.wrap_unspecialized_primitive(value) else: return ConstantVariable.create( value=value, guards=self.make_guards(GuardBuilder.CONSTANT_MATCH), ) def wrap_tensor(self, value: torch.Tensor): source = self.get_source() if ( source.guard_source().is_nn_module() or get_static_address_type(value) is not None ) and not source.guard_source().is_fsdp_module(): return self.tx.output.register_attr_or_module( value, self.name, source=source, # Guards are done inside register_attr_or_module # guards=self.make_guards(GuardBuilder.TENSOR_MATCH), ) if is_constant_source(source): return self.tx.output.register_attr_or_module( value, re.sub(r"[^a-zA-Z0-9]+", "_", self.name), source=source, # Guards are added inside register_attr_or_module ) if type(value) in config.traceable_tensor_subclasses: # Ordinarily, we would fakeify a tensor so that it can get dynamic # shapes and be computed on without triggering actual operations. # However, how can we fakeify a tensor subclass? Ordinary # inheritance (nor multiple inheritance) won't work work. # # Instead, our plan is to *manually simulate* the tensor subclass # inheriting from a fake tensor with dynamo. This means our # data representation for a tensor subclass will be a fake tensor # + tensor subclass type + any extra data the subclass may have # been storing on the tensor. Because all Python accesses are # mediated through TensorWithTFOverrideVariable, we can ensure # that we dispatch differently, e.g., according to # __torch_function__ # # To simplify things for now, the __dict__ tracking bits haven't # been implemented yet, but they can be added into this design at # a later point in time. ignore_subclass = True else: assert type(value) in ( torch.Tensor, torch.nn.Parameter, torch._subclasses.fake_tensor.FakeTensor, ) or is_traceable_wrapper_subclass(value), type(value) ignore_subclass = False # NB: this just says we accessed a tensor from the same source again # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). # This is distinct from two distinct sources mapping to the same # Tensor (per id())! No guard is necessary here. 
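# Illustrative sketch (not part of this module): the specialization policy that
# wrap_literal above applies to plain ints when specialize_int is off.  The helper
# name and the boolean arguments are hypothetical stand-ins for the source checks.
def _int_specialization_policy_sketch(
    value: int, is_from_local_source: bool, is_from_nn_module: bool
) -> str:
    common_constants = {0, 1}  # shapes are 0/1-specialized, so these ints are too
    if (
        value in common_constants
        or not is_from_local_source  # ints from globals are assumed to want specialization
        or is_from_nn_module         # ints stored on NN modules are not expected to change
    ):
        return "specialize: ConstantVariable with a CONSTANT_MATCH guard"
    return "leave unspecialized: becomes a symbolic int input"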
See below for the # other case. is_duplicate_tensor = source in self.tx.output.input_source_to_var if is_duplicate_tensor: return self.tx.output.input_source_to_var[source] # We have accessed the SAME tensor from a different source. In some # situations, it doesn't matter if you have the same tensor identity # or not, but we are unable to do this fine-grained tracking. So # instead we just say, if x is y, then to successfully reuse this # compiled tensor again, you must have x is y again. Negative # aliases, that is, that x is not y, are IMPLICITLY checked as part of # the code cache matching process, you don't need to explicitly # generate a guard for it (nor would you want to, you need O(n^2) # pairwise 'is not' tests to do it.) if value in self.tx.output.real_value_tensor_positive_aliases: stored_value = self.tx.output.real_value_tensor_positive_aliases[value] # TODO(voz): Decently common pattern, refactor at some point. dup_guard = self._make_dupe_guard(stored_value) if dup_guard: stored_value = stored_value.add_guards(self.make_guards(dup_guard)) return stored_value # tx.output has multiple tracers if we're introspecting HigherOrderOperator. # When we've discovered an untracked tensor, then we actually need # to get Dynamo to track the tensor (which is what this function does) # and put it as a graph input on the root tracer. Later on, # if the input is actually used in the body of the HigherOrderOperator, # then the relevant SubgraphTracer will lift it to being an input of # the subgraph. # See NOTE [HigherOrderOperator tracing design] for more details. tensor_proxy = self.tx.output.root_tracer.create_graph_input( re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source ) tensor_variable = wrap_fx_proxy( tx=self.tx, proxy=tensor_proxy, example_value=value, guards=self.make_guards( functools.partial( GuardBuilder.TENSOR_MATCH, value=value if isinstance(source, NumpyTensorSource) else TensorWeakRef(value), ) ), should_specialize=self.tensor_should_specialize(), ignore_subclass=ignore_subclass, source=source, ) self.tx.output.input_source_to_var[source] = tensor_variable assert "tensor_dict" not in tensor_proxy.node.meta tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy() # TODO: I think the result is guaranteed to be fake with # ignore_subclass changes fake_tensor_value = None example_value = tensor_variable.proxy.node.meta["example_value"] if is_fake(example_value): fake_tensor_value = example_value grapharg = GraphArg(source, value, False, fake_tensor_value) tensor_proxy.node.meta["grapharg"] = grapharg self.tx.output.add_symbol_bindings(grapharg) if type(value) in config.traceable_tensor_subclasses: # NB: This is slightly misnamed, a tensor subclass might not have # any explicit __torch_function__ implementation and is relying # on the default inherited from torch.Tensor return TensorWithTFOverrideVariable.create( self.tx, tensor_variable, source, value.__torch_function__.__func__, type(value), ) return tensor_variable def wrap_numpy_ndarray(self, value): assert np is not None assert isinstance(value, np.ndarray) source = NumpyTensorSource(self.get_source()) tensor_value = torch.as_tensor(value) # We do this because we want the full behavior of guarding the numpy ndarray as if it were # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here # that there's not another great way to do this atm. # This creates the right graphargs, as well as registration for guards in tensor names and shape env. 
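# Illustrative sketch (not part of this module): the two-level deduplication used
# in wrap_tensor above.  The same *source* maps straight back to the cached
# variable; the same *tensor object* reached through a different source is reused
# but gains an extra "x is y" duplicate-input guard.  All names are hypothetical,
# and the alias map is keyed by id() here only for simplicity.
def _dedupe_tensor_sketch(source_key, tensor, source_to_var, alias_map, add_dupe_guard):
    if source_key in source_to_var:
        # Same source accessed twice: no extra guard needed.
        return source_to_var[source_key]
    if id(tensor) in alias_map:
        # Same tensor via a different source: reuse the variable, but guard that
        # the aliasing still holds the next time this code runs.
        var = alias_map[id(tensor)]
        add_dupe_guard(var)
        return var
    return None  # caller wraps the tensor as a brand new graph input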
tensor_vt = VariableBuilder(self.tx, source)(tensor_value) proxy = self.tx.output.root_tracer.create_graph_input( re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source ) options = {"source": source, "guards": tensor_vt.guards} numpy_ndarray_variable = wrap_fx_proxy_cls( target_cls=NumpyNdarrayVariable, tx=self.tx, proxy=proxy, example_value=tensor_value, **options, ) self.tx.output.input_source_to_var[source] = numpy_ndarray_variable example_value = numpy_ndarray_variable.proxy.node.meta["example_value"] # is_unspecialized should be true because we are wrapping a np.ndarray as argument input, and it needs to be # converted to a tensor. grapharg = GraphArg( source, tensor_value, is_unspecialized=True, fake_tensor=example_value, is_tensor=True, example_strong_ref=tensor_value, ) proxy.node.meta["grapharg"] = grapharg return numpy_ndarray_variable def wrap_unspecialized_primitive(self, value): if self.name in self.tx.output.unspec_variable_map: return self.tx.output.unspec_variable_map[self.name] else: shape_env = self.tx.output.shape_env if TracingContext.get().force_unspec_int_unbacked_size_like and isinstance( value, int ): wrapped_value = shape_env.create_unbacked_symint() _constrain_range_for_size(wrapped_value) self.tx.output.tracked_fakes.append( TrackedFake(wrapped_value, self.source, None) ) # NB: We do not do float. For motivation, see # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit # but the general idea is that we generate kernels that can # take unspecialized floats and use them in sizevar computation elif ( isinstance(value, int) and not is_constant_source(self.get_source()) and not isinstance(self.get_source(), RandomValueSource) ): if torch._dynamo.config.specialize_int: # If specialize_int is False, also return # a constant (but this should have been handled # in the caller, TBH) return ConstantVariable.create( value=value, guards=self.make_guards(GuardBuilder.CONSTANT_MATCH), ) name = self.source.name() if name not in self.tx.output.frame_state: # Note - this essentially means that if this name gets reused as a tensor, # it will start fully dynamic. That should always be a safe option, and not awfully inefficient. # Alternatively, if we want to improve pef here, we can add a third state of unset, but I am not # sure that is necessary for now. 
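# Illustrative sketch (not part of this module): the viewing step behind
# wrap_numpy_ndarray above -- an np.ndarray input is traced through the
# torch.Tensor that views the same data, then results are rewrapped as ndarrays.
# The helper name is hypothetical.
def _ndarray_as_tensor_sketch():
    import numpy as np
    import torch

    arr = np.arange(6, dtype=np.float32).reshape(2, 3)
    t = torch.as_tensor(arr)  # shares memory with `arr` where possible (CPU, matching dtype)
    t.mul_(2)
    # The original ndarray observes the in-place change because no copy was made.
    print(arr[0, 1])  # 2.0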
frame_state_entry = FrameStateSizeEntry(scalar=value, size=None) else: frame_state_entry = self.tx.output.frame_state[name] if frame_state_entry.scalar != value: log.debug( "automatic dynamic int %s val %s != %s", name, value, frame_state_entry.scalar, ) frame_state_entry.scalar = None self.tx.output.frame_state[name] = frame_state_entry # TODO: This should be dynamic, as we in general do not # know if bare integers are actually going to be sizevars # and it is inappropriate to eagerly duck size them with # real sizevars if ( config.automatic_dynamic_shapes and frame_state_entry.scalar is None ) or not config.assume_static_by_default: dynamic_dim = DimDynamic.DYNAMIC else: # assume_static_by_default # TODO: dynamic_dim = DimDynamic.STATIC should work but # for some reason it doesn't return ConstantVariable.create( value=value, guards=self.make_guards(GuardBuilder.CONSTANT_MATCH), ) wrapped_value = shape_env.create_unspecified_symint_and_symbol( value, source=self.source, dynamic_dim=dynamic_dim, ) self.tx.output.tracked_fakes.append( TrackedFake(wrapped_value, self.source, None) ) else: wrapped_value = torch.tensor(value) if not isinstance(self.get_source(), RandomValueSource): guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH, True)} options = {"guards": guards} else: options = {} options.update({"source": self.get_source()}) if isinstance(wrapped_value, torch.Tensor): options.update({"raw_value": value}) proxy = self.tx.output.root_tracer.create_graph_input( re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(wrapped_value), source=self.get_source(), ) unspec_var = wrap_fx_proxy_cls( UnspecializedPythonVariable, tx=self.tx, proxy=proxy, example_value=wrapped_value, **options, ) self.tx.output.unspec_variable_map[self.name] = unspec_var if not is_constant_source(self.get_source()): if self.tx.export and not isinstance(self.get_source(), LocalSource): raise AssertionError( "Dynamo attempts to add additional input during export: value={}, source={}".format( wrapped_value, self.get_source() ) ) fake_tensor_value = None if isinstance(unspec_var, ConstantVariable): example_value = unspec_var.value else: example_value = unspec_var.proxy.node.meta["example_value"] if is_fake(example_value): fake_tensor_value = example_value assert fake_tensor_value.fake_mode is self.tx.fake_mode, ( f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode" "({self.tx.fake_mode}) from InstructionTranslator" ) proxy.node.meta["grapharg"] = GraphArg( self.get_source(), wrapped_value, isinstance(wrapped_value, torch.Tensor), fake_tensor_value, is_tensor=False, example_strong_ref=wrapped_value, ) return unspec_var def _dataclasses_fields_lambda(obj): if isinstance(obj, UserDefinedObjectVariable): value = obj.value elif isinstance(obj, DataClassVariable): value = obj.user_cls else: unimplemented(f"Dataclass fields handling fails for type {obj}") items = [] for field in dataclasses.fields(value): source = None if obj.source: source = GetItemSource( AttrSource(obj.source, "__dataclass_fields__"), field.name ) items.append(UserDefinedObjectVariable(field, source=source).add_options(obj)) return TupleVariable(items).add_options(obj) def wrap_fx_proxy(tx, proxy, example_value=None, **options): return wrap_fx_proxy_cls( target_cls=TensorVariable, tx=tx, proxy=proxy, example_value=example_value, **options, ) # Note: Unfortunate split due to some gross classes existing that subclass TensorVariable # Should be compositional instead # # This is a horribly complicated function that does too 
many things, to # explain what it does, let's first talk about the classic usage wrap_fx_proxy # for a TensorVariable. There are two primary modes of use: # # 1. Wrapping a pre-existing Tensor. In this case, example_value is set # to the pre-existing Tensor. (Note that this example_value will NOT # be the final example_value we put into node.meta['example_value'], # instead it is converted into a fake tensor using # wrap_to_fake_tensor_and_record and registered as a graph input.) # # 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In # this case, example_value is None (and we are going to figure it out # ourselves using FakeTensors, via get_fake_value, which will run # the operation represented by the (singular!) FX node referenced by # the passed in proxy.) # # The expectation is you end up with a Tensor output, and everything is # straightforwardly traced into the graph. # # Upon closer inspection, you may notice that there are a slurry of non-Tensor # output cases. What gives? Well, we sometimes trace operations into the # graph that don't involve tensors. # # * Some operators return tuples; we need to recursively handle their # contents # # * Some operators have side effects that will affect subsequent AOTAutograd # tracing but don't otherwise return anything. # # * Some operators return symbolic ints/floats/bools which can go in the # graph and be traced (but only if they're actually symbolic! If they're # static you don't want to put them in the graph, which means you # shouldn't call this function.) # # The common theme is that you only use this function WHEN YOU ARE TRACING # SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call # this function without a proxy. def wrap_fx_proxy_cls( target_cls, tx, proxy, example_value=None, ignore_subclass=False, **options ): from ..symbolic_convert import InstructionTranslatorBase assert isinstance(tx, InstructionTranslatorBase) if "guards" in options and options["guards"] is not None: tx.output.guards.update(options["guards"]) assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}" initial_example_value = example_value def _is_functional_tensor_fakified_by_dynamo(x): if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x): reapply_views = torch._C._functionalization_reapply_views_tls() unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views) return ( isinstance(unwrapped, FakeTensor) and unwrapped.fake_mode == tx.fake_mode ) return False def _clone_input(value): if isinstance(value, torch.Tensor): # tensor subclasses will not be converted to FakeTensors and need to be cloned if not ( isinstance(value, FakeTensor) or _is_functional_tensor_fakified_by_dynamo(value) or value.is_nested ): # NB: ensure strides are preserved value = clone_input(value) return value with preserve_rng_state(): if example_value is None: example_value = get_fake_value(proxy.node, tx) # Handle recursive calls here elif ( is_fake(example_value) and maybe_get_fake_mode(example_value) is tx.fake_mode ) or _is_functional_tensor_fakified_by_dynamo(example_value): pass elif isinstance(example_value, torch.Tensor): if tx.export: # The legacy behavior for real value cache with subclasses was # to perform a clone WITHOUT preserving the subclass. It's # not entirely clear this is what you actually want though. 
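# Illustrative sketch (not part of this module): the idea behind "mode 2" described
# in the comment above -- when no example_value is supplied, the traced operation is
# evaluated on FakeTensors so output shapes/dtypes are learned without real compute.
# This uses FakeTensorMode directly rather than dynamo's get_fake_value; the helper
# name is hypothetical.
def _fake_eval_sketch():
    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        a = torch.randn(4, 8)   # FakeTensor: metadata only, no real storage touched
        b = torch.randn(8, 16)
        out = a @ b             # runs the meta kernel, not a real matmul
    print(out.shape, out.dtype)  # torch.Size([4, 16]) torch.float32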
with torch._C.DisableTorchFunctionSubclass(): proxy.tracer.real_value_cache[proxy.node] = _clone_input( example_value ) # NB: If we're ignoring subclass, then the expectation is you will # take the returned TensorVariable and wrap it into a more # accurate TensorVariable that is able to track subclass-ness; # otherwise this is wrong! kwargs = { "ignore_subclass": ignore_subclass, "is_tensor": target_cls is TensorVariable, } assert "source" in options and options["source"] is not None kwargs["source"] = options["source"] example_value = wrap_to_fake_tensor_and_record( example_value, tx=tx, **kwargs ) if isinstance(example_value, torch.Tensor): is_parameter = isinstance(example_value, torch.nn.Parameter) should_specialize = options.pop("should_specialize", False) if is_parameter or should_specialize: specialized_value = initial_example_value else: specialized_value = None # NB: In most (all?) cases, this does not actually do a clone. # (WARNING: this means that if we mutate metadata on the fake # tensor, the stored example value will update too!) example_value = _clone_input(example_value) proxy.node.meta["example_value"] = example_value specialized_props = target_cls.specialize(example_value) # TODO: not sure about this fake mode test if ( isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) and example_value.fake_mode is tx.fake_mode ): # NB: This will be wrong for ignore_subclass; fix it up later! specialized_props["class_type"] = ( torch.nn.Parameter if is_parameter else torch.Tensor ) specialized_props["specialized_value"] = specialized_value options.update(specialized_props) return target_cls(proxy, **options) elif ( hasattr(proxy.node.target, "__name__") and proxy.node.target.__name__ == "set_state" and isinstance(proxy.node.target.__self__, torch._C.Generator) or proxy.node.target == torch.random.set_rng_state ): from . import TorchVariable return TorchVariable(proxy.node.target) elif ( proxy.node.target == torch._C._DisableFuncTorch or proxy.node.target == torch.cuda._is_in_bad_fork ): from . import UserDefinedObjectVariable return UserDefinedObjectVariable(example_value) elif istype(example_value, torch.Size) and all( isinstance(x, int) for x in example_value ): sizes = [ConstantVariable.create(x) for x in example_value] return SizeVariable(sizes, **options) elif isinstance(example_value, (tuple, list, set)): proxy.node.meta["example_value"] = example_value unpacked = [] for i, val in enumerate(example_value): if val is None: # nn.MultiheadAttention() can return None, see issue #175 unpacked.append( ConstantVariable.create(None, **options), ) else: unpacked.append( wrap_fx_proxy_cls( target_cls, tx, proxy.tracer.create_proxy( "call_function", operator.getitem, (proxy, i), {} ), example_value=val, **options, ) ) if isinstance(example_value, torch.Size): # NB: Keep the old proxy around. 
See SizeVariable for an # explanation why return SizeVariable(unpacked, proxy, **options) elif istype(example_value, tuple): return TupleVariable(unpacked, **options) elif istype(example_value, (list, immutable_list)): return ListVariable(unpacked, mutable_local=MutableLocal(), **options) elif istype(example_value, set): return SetVariable(unpacked, mutable_local=MutableLocal(), **options) else: assert example_value.__class__.__module__ == "torch.return_types" or hasattr( example_value, "_fields" ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}" return NamedTupleVariable(unpacked, example_value.__class__, **options) elif example_value is None or proxy.node.target is torch.manual_seed: return ConstantVariable.create(None, **options) elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): proxy.node.meta["example_value"] = example_value return SymNodeVariable(proxy, example_value, **options) elif proxy.node.target in [torch.cuda.streams.Stream, torch.cuda.current_stream]: proxy.node.meta["example_value"] = example_value return CUDAStreamVariable(proxy, example_value, **options) elif isinstance(example_value, int) and proxy.node.target in [ torch.sym_int, getattr, operator.getitem, torch._utils._element_size, torch.seed, operator.mod, # some mac builds are missing torch.distributed.get_rank() getattr(torch.distributed, "get_rank", _missing), getattr(torch.distributed, "get_world_size", _missing), # This always wants to be in the graph, even if the constraint # results in a constant int torch._constrain_as_value, torch._constrain_as_size, ]: proxy.node.meta["example_value"] = example_value return ConstantVariable.create(example_value, **options) else: unimplemented( "torch.* op returned non-Tensor " + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" ) # Tracks the sources of all fake tensors we wrap in Dynamo. # Used by shape guard computation. @dataclasses.dataclass class TrackedFake: fake: Union[FakeTensor, SymInt] source: Source # Is None when fake is SymInt constraint_dims: Optional[DimList[DimConstraint]] def __hash__(self) -> int: return hash((self.fake, self.source.name())) def __eq__(self, other: object) -> bool: if isinstance(other, TrackedFake): return self.fake is other.fake and self.source.name() == other.source.name() return False # Performs automatic dynamic dim determination. # Returns tuple of (dynamic_dims, constraint_dims) where each is either a list of dims or None. def _automatic_dynamic(e, tx, name, static_shapes): if static_shapes: return [DimDynamic.STATIC] * e.dim(), [None] * e.dim() # We preserve the dynamism of inputs. For example, when users call # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. if any(isinstance(s, SymInt) for s in e.size()): return [ DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC for s in e.size() ], [None] * e.dim() # Prep for automatic dynamic frame_state_entry = None if name not in tx.output.frame_state: # If there is no entry for this source, add the tensor to frame state with its current static size. # E.g., {} -> {"x": [2, 4]} frame_state_entry = FrameStateSizeEntry(None, None) frame_state_entry.size = list(e.size()) else: frame_state_entry = tx.output.frame_state[name] if frame_state_entry.size is not None: if e.ndim != len(frame_state_entry.size): # If there is already an entry, and the dim mismatches, replace the frame state entry with None. # E.g. 
{"x": [2, 3, 4]} -> {"x": None} log.debug( "automatic dynamic %s dim %s != %s", name, e.ndim, frame_state_entry.size, ) frame_state_entry.size = None else: # If there is already an entry, and the dim matches, for every size in the frame state which # disagrees with the current static size, replace it with None. E.g., {"x": [2, 3]} -> {"x": [2, None]} for i, dim in enumerate(frame_state_entry.size): if dim is not None and e.size()[i] != dim: log.debug( "automatic dynamic %s size(%s) %s != %s", name, i, e.size(i), dim, ) frame_state_entry.size[i] = None # TODO: index export_constraints ahead of time so we don't have to # do a linear scan every time here t_id = id(e) dim2constraint = {} def update_dim2constraint(dim, constraint_range, debug_name): if dim in dim2constraint: from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint old_constraint_range, old_debug_name = dim2constraint[dim] new_constraint_range = StrictMinMaxConstraint( vr=constraint_range.vr & old_constraint_range.vr, warn_only=False, ) if old_debug_name is not None: assert debug_name is None or debug_name == old_debug_name new_debug_name = old_debug_name else: new_debug_name = debug_name dim2constraint[dim] = new_constraint_range, new_debug_name else: dim2constraint[dim] = constraint_range, debug_name if tx.output.export_constraints: for constraint in tx.output.export_constraints: if constraint.t_id == t_id: update_dim2constraint( constraint.dim, constraint.constraint_range, constraint.debug_name ) if constraint.shared is not None and constraint.shared.t_id == t_id: # We process constraint ranges for each shared dimension separately # so that we can directly check range constraint violations on them # without looking up which other shared dimensions have this info. # In other words, for this t_id, we will have processed all of its # constraint ranges, no matter where / how they were specified, by # by the end of this loop. 
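# Illustrative sketch (not part of this module): the frame-state comparison that
# _automatic_dynamic performs above.  Sizes recorded on a previous frame are
# compared against the current tensor; a rank change drops the whole entry, and any
# dim that disagreed is forgotten (None), making it a candidate for dynamic shapes.
# The helper name is hypothetical and the real code also handles user markings.
def _frame_state_size_sketch(stored, current):
    # stored: Optional[List[Optional[int]]] from the previous frame
    # current: List[int] sizes observed on this frame
    if stored is None or len(stored) != len(current):
        return None  # rank changed (or tracking already dropped): no per-dim info
    return [s if s == c else None for s, c in zip(stored, current)]

# e.g. _frame_state_size_sketch([2, 3], [2, 5]) -> [2, None]  (dim 1 becomes dynamic)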
update_dim2constraint( constraint.shared.dim, constraint.constraint_range, constraint.debug_name, ) dynamic_dims = [] constraint_dims = [] for i in range(e.dim()): # NB: mark dynamic has precedence over static marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) marked_static = i in getattr(e, "_dynamo_static_indices", set()) # NB: both static and dynamic have precedence over automatic_dynamic = config.automatic_dynamic_shapes and ( frame_state_entry.size is None or frame_state_entry.size[i] is None ) # Reflect the user directive in the frame_state # For dynamic, apply None always if frame_state_entry.size and marked_dynamic: log.debug("automatic dynamic %s marked dynamic", name) frame_state_entry.size[i] = None # We will process constraints first, as they will imply that we # have a dynamic dimension # Precedence: export constraints > eager constraints constraint = dim2constraint.get(i) if constraint is None: if marked_dynamic and not config.allow_ignore_mark_dynamic: constraint_dim = RelaxedUnspecConstraint(warn_only=False) elif not marked_static and automatic_dynamic: constraint_dim = RelaxedUnspecConstraint(warn_only=True) else: constraint_dim = None else: constraint_dim, debug_name = constraint if debug_name is not None: dim_name = f"{name}.size()[{i}]" tx.output.shape_env.source_name_to_debug_name[dim_name] = debug_name constraint_dims.append(constraint_dim) # Now, figure out if the dim is dynamic/duck/static if constraint_dim is not None or marked_dynamic or marked_weak_dynamic: # NB: We could assert static_shapes is False here, but it # seems better to allow the user to override policy in this # case dynamic = DimDynamic.DYNAMIC elif static_shapes or config.assume_static_by_default or marked_static: dynamic = DimDynamic.STATIC else: dynamic = DimDynamic.DUCK dynamic_dims.append(dynamic) tx.output.frame_state[name] = frame_state_entry return dynamic_dims, constraint_dims def wrap_to_fake_tensor_and_record( e, tx, ignore_subclass=False, *, source: Optional[Source], is_tensor: bool ): if ( type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) or (ignore_subclass and isinstance(e, torch.Tensor)) or is_traceable_wrapper_subclass(e) ): assert source is not None static_shapes, reason = tensor_always_has_static_shape( e, is_tensor, guard_source=source.guard_source() ) dynamic_dims, constraint_dims = None, None if not e.is_nested: # TODO: We should probably support this for nested tensors too dynamic_dims, constraint_dims = _automatic_dynamic( e, tx, source.name(), static_shapes ) log.debug( "wrap_to_fake %s %s %s %s", source.name(), tuple(e.shape), dynamic_dims, constraint_dims, ) fake_e = wrap_fake_exception( lambda: tx.fake_mode.from_tensor( e, ignore_subclass=ignore_subclass, source=source, dynamic_dims=dynamic_dims, constraint_dims=constraint_dims, ) ) if is_tensor and not (static_shapes and source.is_nn_module()): tx.output.tracked_fakes.append(TrackedFake(fake_e, source, constraint_dims)) tx.output.tracked_fakes_id_to_source[id(e)].append(source) tx.output.tensor_weakref_to_sizes_strides[WeakIdRef(e)] = { "size": fake_e.size(), "stride": fake_e.stride(), } return fake_e else: return e class SourcelessBuilder: """ Like builder, but stateless and does not require a source. 
    Useful for simple type->VT objects, or objects that are being created/evaporated
    during inlining (ex: consider a locally made list of tensors we then iterate
    over).  Such a list should not show up as an artifact from inputs, nor in
    reconstruction, nor in the graph.  However, there may be reasons to represent it
    as a ListVariable internally.

    NOTE - Objects produced here are born UNGUARDED due to the nature of sources!

    NOTE - This class is very new! It will have some rough edges, but it was created
    to stem the bleeding of giant if/else type->VariableTracker trees that were
    cropping up all over dynamo.
    """

    def __call__(self, tx, value) -> VariableTracker:
        if isinstance(value, VariableTracker):
            # This is always valid to call, and useful for recursive calls.
            return value
        if isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS):
            return UserDefinedObjectVariable(value)
        if ConstantVariable.is_literal(value):
            return SourcelessBuilder.wrap_constant_literal(value)
        elif is_builtin_callable(value):
            return BuiltinVariable(value)
        elif is_allowed(value):
            if is_user_defined_allowed(value):
                # NB: SourcelessBuilder is stateless; use the tx passed in here
                # rather than a (nonexistent) self.tx.
                tx.output.has_user_defined_allowed_in_graph = True
            return TorchVariable(value)
        elif isinstance(value, types.FunctionType):
            return UserFunctionVariable(value)
        elif isinstance(value, enum.Enum):
            return EnumVariable(value)
        elif isinstance(value, (type, abc.ABCMeta)):
            return UserDefinedClassVariable(value)
        elif isinstance(value, dict):
            return ConstDictVariable(
                {k: self(tx, v) for k, v in value.items()},
                dict,
                mutable_local=MutableLocal(),
            )
        elif isinstance(value, (tuple, list)):
            cls = BaseListVariable.cls_for(type(value))
            return cls([self(tx, x) for x in value], mutable_local=MutableLocal())
        elif isinstance(value, types.MethodWrapperType):
            return MethodWrapperVariable(value)
        unimplemented(f"Unexpected type in sourceless builder {type(value)}")

    @staticmethod
    def wrap_constant_literal(value):
        assert ConstantVariable.is_literal(value)
        return ConstantVariable.create(value=value)
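# Illustrative sketch (not part of this module): the kind of dispatch cascade
# SourcelessBuilder.__call__ implements, reduced to plain Python.  It maps an
# ephemeral value to a descriptive tag the way __call__ maps it to a
# VariableTracker; all names here are hypothetical and there is no tx or guard
# machinery involved.
def _sourceless_dispatch_sketch(value):
    import enum
    import types

    if isinstance(value, (bool, int, float, str, type(None))):
        return "constant literal"
    if isinstance(value, types.FunctionType):
        return "user function"
    if isinstance(value, enum.Enum):
        return "enum member"
    if isinstance(value, dict):
        return {k: _sourceless_dispatch_sketch(v) for k, v in value.items()}
    if isinstance(value, (tuple, list)):
        return [_sourceless_dispatch_sketch(x) for x in value]
    raise NotImplementedError(f"Unexpected type in sketch: {type(value)}")

# e.g. _sourceless_dispatch_sketch({"lr": 0.1, "dims": [1, 2]}) recurses into each entry.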