import functools
import inspect
import operator
import types
from typing import Dict, List

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

import sympy

import torch._numpy as tnp
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
from torch.fx.experimental.symbolic_shapes import free_symbols, guard_scalar, SymTypes

from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import (
    fqn,
    get_custom_getattr,
    get_fake_value,
    get_real_value,
    guard_if_dyn,
    object_has_getattribute,
    product,
    proxy_args_kwargs,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable

supported_tensor_comparison_ops = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
}
supported_const_comparison_ops = {
    "is": operator.is_,
    "is not": operator.is_not,
    "==": operator.eq,
    "!=": operator.ne,
}


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = [
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
    ]

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be slow and memory-intensive.
        """
        return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        *,
        dtype,
        device,
        layout,
        ndim,
        requires_grad,
        is_quantized,
        is_sparse,
        class_type,
        size=None,
        stride=None,
        is_contiguous=None,
        specialized_value=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.layout = layout
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.specialized_value = specialized_value

    def as_proxy(self):
        return self.proxy

    def python_type(self):
        return self.class_type

    def call_isinstance(self, tensor_type):
        def check_type(ty):
            if ty not in tensortype_to_dtype:
                return issubclass(self.python_type(), ty)

            dtypes = tensortype_to_dtype[ty]
            return self.dtype in dtypes

        if type(tensor_type) is tuple:
            return any(check_type(ty) for ty in tensor_type)
        else:
            return check_type(tensor_type)

    @staticmethod
    def specialize(value: torch.Tensor):
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "layout": value.layout,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        if not free_symbols(value):
            # this is a fully static shape, and the keys on props here inform specialization.
            # We have to cast to int here, because these might get accessed as ConstantVariable, which has
            # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
            # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
            # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
            # I'd like to keep it around for now.
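            # e.g. for a plain torch.ones(2, 3) input, the entries recorded below
            # would be size == (2, 3), stride == (3, 1), and an is_contiguous tuple
            # that includes torch.contiguous_format.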
props["size"] = tuple([int(s) for s in value.size()]) props["stride"] = tuple(value.stride()) props["is_contiguous"] = tuple( [ x for x in torch._prims_common._memory_formats if value.is_contiguous(memory_format=x) ] ) return props def dynamic_getattr(self, tx, name): if not self.source: raise NotImplementedError() # For local source, we associate the real value. We use this real value # for implementing getattr fallthrough on the variable tracker base class. # Note - this scope construction is mirrored in guards # A subsequent PR will introduce a util. scope = {"L": tx.output.local_scope, "G": tx.output.global_scope} try: # We raise in case we get a typerror bug w/ SuperSource. # SuperSource has bugs in it atm, and can produce code like # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__, # L['mod'].model.model.encoder.embed_positions)", scope) # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope. _input_associated_real_value = eval(self.source.name(), scope) except Exception as exc: raise NotImplementedError() from exc if _input_associated_real_value is None: raise NotImplementedError() if object_has_getattribute(_input_associated_real_value): raise NotImplementedError() if get_custom_getattr(_input_associated_real_value): raise NotImplementedError() real_value = getattr(_input_associated_real_value, name) if callable(real_value): # Callables have more nuanced handling, and we should let the existing system delegate here. # Raising was past behavior and so should always be sound to fall back. # Note - at a certain point we may want to handle raise NotImplementedError() from ..guards import GuardBuilder from .builder import VariableBuilder attr_source = AttrSource(self.source, name) has_attr_guard = attr_source.make_guard(GuardBuilder.HASATTR) return ( VariableBuilder(tx, attr_source)(real_value) .add_options(self) .add_guard(has_attr_guard) ) def var_getattr(self, tx, name): from . 
        from . import ConstantVariable, TorchVariable

        if tx.strict_checks_enabled:
            if name in self._strict_mode_banned_ops():
                unimplemented(f"Illegal getattr invocation {name} in strict mode")

        result = None
        options = VariableTracker.propagate(self)
        if name == "ndim" and self.ndim is not None:
            result = ConstantVariable.create(self.ndim, **options)
        elif name == "dtype" and self.dtype is not None:
            result = TorchVariable(self.dtype, **options)
        elif name == "device" and self.device is not None:
            result = TorchVariable(self.device, **options)
        elif name == "layout" and self.layout is not None:
            result = TorchVariable(self.layout, **options)
        elif name == "is_cuda" and self.device is not None:
            result = ConstantVariable.create(self.device.type == "cuda", **options)
        elif name == "shape" and self.size is not None:
            sizes = [variables.ConstantVariable.create(x) for x in self.size]
            result = SizeVariable(sizes, **options)
        elif name == "requires_grad" and self.requires_grad is not None:
            result = ConstantVariable.create(self.requires_grad, **options)
        elif name == "is_quantized" and self.is_quantized is not None:
            result = ConstantVariable.create(self.is_quantized, **options)
        elif name == "is_sparse" and self.is_sparse is not None:
            result = ConstantVariable.create(self.is_sparse, **options)
        elif name == "shape" and self.size is None:
            result = self.call_method(tx, "size", [], {})
        elif name == "ndim" and self.ndim is None:
            result = self.call_method(tx, "dim", [], {})
        elif name == "data":
            result = self.call_method(tx, "detach", [], {})

        if name == "__class__":
            return TorchVariable(self.python_type(), **options)

        # Add a guard for type matching, these guards are checked before tensor guards
        # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
        # <tensor> is later changed to another type
        if result is not None and self.source is not None:
            result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))

        # It's hard to get inplace view (metadata mutation) on graph input work properly across
        # dynamo/aot/inductor, just fall back.
        if self.source is not None and hasattr(torch.ops.aten, name):
            fn = getattr(torch.ops.aten, name)
            if (
                hasattr(fn, "overloads")
                and hasattr(fn, fn.overloads()[0])
                and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags
            ):
                # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc.
                return variables.misc.DelayGraphBreakVariable()

        # For attributes (not methods) that were not caught in the special handling above,
        # (e.g. tensor.real), we handle these generically, assuming that the output type is
        # a tensor.
        if result is None:

            def try_generic_attr_handling():
                from .builder import wrap_fx_proxy
                from .misc import GetAttrVariable

                try:
                    static_attr = inspect.getattr_static(torch.Tensor, name)
                except AttributeError:
                    return None

                # Make sure this is an attribute, not a method.
                # type(torch.Tensor.H) should be "getset_descriptor"
                # This is because of the CPython implementation; see THPVariableType:
                # these attributes are implemented under tp_getset, which appear
                # as `getset_descriptor`s, (compared to, say, methods which appear
                # as `method_descriptor`s)
                if type(static_attr) != types.GetSetDescriptorType:
                    return None

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=GetAttrVariable.create_getattr_proxy(self.as_proxy(), name),
                    **options,
                )

            result = try_generic_attr_handling()

        if result is None:
            result = self.dynamic_getattr(tx, name)

        if result is None:
            raise NotImplementedError()
        return result

    def has_unpack_var_sequence(self, tx):
        return self.ndim > 0

    def unpack_var_sequence(self, tx, idxes=None):
        from .builder import wrap_fx_proxy_cls

        options = VariableTracker.propagate(self)
        if idxes is None:
            if self.size:
                length = self.size[0]
            else:
                dyn_length = self.call_method(
                    tx, "size", [ConstantVariable.create(0)], {}
                )
                # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
                # symbolic_shapes, but that end up as int/sympy.Integer
                assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
                if isinstance(dyn_length, SymNodeVariable):
                    length = dyn_length.evaluate_expr(tx.output)
                else:
                    length = dyn_length.value
            idxes = range(length)
        return [
            wrap_fx_proxy_cls(
                target_cls=type(self), tx=tx, proxy=self.as_proxy()[i], **options
            )
            for i in idxes
        ]

    def _strict_mode_banned_ops(self):
        return torch._dynamo.config._autograd_backward_strict_mode_banned_ops

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if tx.strict_checks_enabled:
            if name in self._strict_mode_banned_ops():
                unimplemented(f"Illegal method invocation {name} in strict mode")

        from . import ConstantVariable, TorchVariable, TupleVariable
        from .builder import wrap_fx_proxy

        kwargs = dict(kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())

        if name in ("stride", "size"):
            dim_var = None
            if len(args) == 1:
                dim_var = args[0]
            elif "dim" in kwargs:
                dim_var = kwargs["dim"]
            else:
                assert not args and not kwargs, f"Tensor.{name}() unhandled args/kwargs"
            dim = guard_if_dyn(dim_var)

            def make_const_size_variable(x, **options):
                return SizeVariable(
                    [ConstantVariable.create(y, **options) for y in x], **options
                )

            RetVariable = (
                make_const_size_variable if name == "size" else ConstantVariable.create
            )

            # Technically, this should not be necessary, but I'm including it
            # for enhanced BC, in case example_value is sometimes not set
            # (it really should always be set though!)
            if (r := getattr(self, name)) is not None:
                if dim is None:
                    return RetVariable(r, **options)
                else:
                    return ConstantVariable.create(r[dim], **options)

            # It might still be constant!  Consult the fake tensor and see
            if (fake := self.proxy.node.meta.get("example_value")) is not None:
                if dim is None:
                    fake_r = getattr(fake, name)()
                    if not free_symbols(fake_r):
                        # int conversion for safety, in case a SymInt refined
                        # to constant
                        return RetVariable(tuple(int(r) for r in fake_r), **options)
                else:
                    fake_r = getattr(fake, name)(dim)
                    if not free_symbols(fake_r):
                        return ConstantVariable.create(int(fake_r), **options)

            # Oops, it's not constant.  Do the dynamic shapes path.
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )

        elif name in ("numel", "nelement"):
            if self.size is not None:
                return ConstantVariable.create(product(self.size), **options)
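            # e.g. with a statically known self.size == (2, 3), product(self.size) == 6
            # and numel()/nelement() fold to ConstantVariable(6) above without adding
            # a node to the graph.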
            # It might still be constant!  Consult the fake tensor and see
            if (fake := self.proxy.node.meta.get("example_value")) is not None:
                fake_r = fake.numel()
                if not free_symbols(fake_r):
                    return ConstantVariable.create(int(fake_r), **options)

            assert not kwargs, f"Tensor.{name}() unhandled kwargs"

            # Oops, it's not constant.  Do the dynamic shapes path.
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    "numel",
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )

        elif name in ("ndimension", "dim") and self.ndim is not None:
            constant_result = ConstantVariable.create(self.ndim, **options)
        elif name == "is_floating_point" and self.dtype is not None:
            constant_result = ConstantVariable.create(
                self.dtype.is_floating_point, **options
            )
        elif name == "is_contiguous" and self.is_contiguous is not None:
            if "memory_format" in kwargs:
                memory_format = kwargs.pop("memory_format").as_python_constant()
            else:
                memory_format = torch.contiguous_format
            constant_result = ConstantVariable.create(
                memory_format in self.is_contiguous, **options
            )
        elif (
            name == "type"
            and self.dtype is not None
            and len(args) == 0
            and isinstance(self.device, torch.device)
        ):
            tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
                0
            ]
            if self.device.type == "cuda":
                constant_result = ConstantVariable.create(
                    f"torch.cuda.{tensortype.__name__}", **options
                )
            else:
                constant_result = ConstantVariable.create(
                    f"torch.{tensortype.__name__}", **options
                )
        elif (
            name == "type"
            and len(args) == 1
            and fqn(type(args[0].as_python_constant())) == "torch.tensortype"
        ):
            # torch.FloatTensor, etc. are all of type "torch.tensortype".
            # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
            # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
            tensor_type = args[0].as_python_constant()
            tensor_type_const = ConstantVariable.create(fqn(tensor_type), **options)
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self, tensor_type_const], kwargs),
                ),
                **options,
            )
        elif name == "get_device" and isinstance(self.device, torch.device):
            index = self.device.index if self.device.type != "cpu" else -1
            constant_result = ConstantVariable.create(index, **options)
        else:
            constant_result = None

        if constant_result:
            assert not kwargs, f"Tensor.{name}() unhandled kwargs"
            # TODO: I think this branch is dead
            if len(args) == 1:
                return constant_result.getitem_const(args[0])
            elif args:
                return TupleVariable(
                    [constant_result.getitem_const(a) for a in args], **options
                )
            return constant_result
        elif name == "numpy":
            if not config.trace_numpy:
                unimplemented("Tensor.numpy(). config.trace_numpy is False")
            if not np:
                unimplemented("Tensor.numpy(). NumPy is not available")
            assert not args, "Tensor.numpy() doesn't take args."

            if self.layout != torch.strided:
                raise TypeError(
                    f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first"
                )

            # We don't check that the tensor is on CPU when force is False, as this
            # allows us to execute NumPy code on CUDA.
            # We don't check that requires_grad=False as we are currently doing an
            # unconditional detach.
            # TODO: We may want to avoid detaching if `requires_grad=True`
            # and `force=False` to allow computing gradients.
force = "force" in kwargs and kwargs["force"].as_python_constant() proxy = tx.output.create_proxy( "call_method", "detach", *proxy_args_kwargs([self], {}) ) if force: # TODO Add resolve_conj and resolve_neg once we support complex tensors proxy = tx.output.create_proxy( "call_method", "cpu", *proxy_args_kwargs([self], {}) ) return NumpyNdarrayVariable.create(tx, proxy, **options) elif name == "tolist": from .builder import SourcelessBuilder def tolist(tensor, sub_proxy): def wrap(i, sub_proxy): return SymNodeVariable.create( tx, sub_proxy.item(), sym_num=tx.output.shape_env.create_unbacked_symint(), ) if tensor.dtype not in [ torch.int8, torch.int16, torch.int32, torch.int64, ]: unimplemented("Input tensor for tolist must be an integer tensor") if tensor.dim() == 0: return wrap(tensor, sub_proxy) if tensor.dim() == 1: return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)] return [ tolist(sub_tensor, sub_proxy=sub_proxy[i]) for i, sub_tensor in enumerate(tensor) ] tensor = self.as_proxy().node.meta["example_value"] out = tolist(tensor, self.as_proxy()) return SourcelessBuilder()(tx, out).add_options(options) elif name in ("backward", "data_ptr"): unimplemented(f"Tensor.{name}") elif name == "item" and not config.capture_scalar_outputs: unimplemented(f"Tensor.{name}") elif name == "__len__": return self.call_method( tx, "size", [ConstantVariable.create(0, **options)], {} ) elif name == "__setitem__": key, value = args def has_bool_key(v): if isinstance(v, TensorVariable): return v.dtype in (torch.bool, torch.int8) elif isinstance(v, TupleVariable): return any(has_bool_key(item) for item in v.items) else: return False if ( not config.capture_dynamic_output_shape_ops and has_bool_key(key) and isinstance(value, TensorVariable) and value.requires_grad ): unimplemented( "boolean masking setitem backwards requires dynamic shapes" ) tx.output.guards.update(options["guards"]) tx.output.create_proxy( "call_function", operator.setitem, *proxy_args_kwargs([self] + list(args), kwargs), ) return ConstantVariable.create(None, **options) elif name in ("resize_", "resize_as_"): # Handling resizing in its full generality is difficult. unimplemented(f"Tensor.{name}") elif ( name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs ): result = TorchVariable(torch.mul, **options).call_function( tx, args + [kwargs["alpha"]], {} ) return self.call_method(tx, "add_", [result], {}) elif ( name == "addcdiv_" and len(args) == 2 and len(kwargs) == 1 and "value" in kwargs ): result = TorchVariable(torch.div, **options).call_function(tx, args, {}) result = TorchVariable(torch.mul, **options).call_function( tx, [result, kwargs["value"]], {} ) return self.call_method(tx, "add_", [result], {}) elif name == "__contains__": # Rewrite __contains__ here so that downstream passes can trace through # without dealing with unbacked symbool. 
        elif name == "__contains__":
            # Rewrite __contains__ here so that downstream passes can trace through
            # without dealing with unbacked symbool. Roughly the code we translate is:
            # def __contains__(self, x):
            #     return (x == self).any().item()
            result = TorchVariable(torch.eq, **options).call_function(
                tx, [self, args[0]], {}
            )
            result = TorchVariable(torch.any, **options).call_function(tx, [result], {})
            return result.call_method(tx, "item", [], {})
        elif name == "redistribute":
            # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
            # and rewrite args to have only proxyable args, then insert call_function
            args_as_value = [x.as_python_constant() for x in args]
            kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

            def redistribute_fn_with_prim_types(x):
                return x.redistribute(*args_as_value, **kwargs_as_value)

            # attach the same function name for better debugging
            redistribute_fn_with_prim_types.__name__ = f"prim_{name}"

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    redistribute_fn_with_prim_types,
                    *proxy_args_kwargs([self], {}),
                ),
                **options,
            )
        elif name == "register_hook":
            # see [On tensor.register_hook]
            assert len(args) == 1
            fn_var = args[0]
            if not isinstance(
                fn_var,
                (
                    variables.functions.FunctoolsPartialVariable,
                    variables.UserFunctionVariable,
                    variables.TorchVariable,
                    variables.NNModuleVariable,
                ),
            ):
                unimplemented("Unexpected callable type passed to register_hook")

            # Guards from the fn_var
            options.update(VariableTracker.propagate(fn_var))

            if isinstance(fn_var, variables.NestedUserFunctionVariable):
                # NestedUserFunctionVariable don't carry their fn, but reconstruction builds it
                # This should not be onerous to support when needed.
                unimplemented("NYI - lambda variables as hooks")
            elif isinstance(fn_var, variables.functions.FunctoolsPartialVariable):
                fn = fn_var.as_python_constant()
                name = fn_var.func.fn.__name__
            else:
                fn = fn_var.fn
                name = fn_var.fn.__name__

            handle_variable = variables.user_defined.RemovableHandleVariable(
                mutable_local=variables.base.MutableLocal(),
                **options,
            )
            if not self.source:
                # Intermediary
                src = fn_var.source
                if (
                    not src
                    and isinstance(fn_var, variables.functions.FunctoolsPartialVariable)
                    and fn_var.func.source
                ):
                    src = fn_var.func.source

                if not src:
                    unimplemented("No source for register_hook target fn")

                tx.output.guards.add(src.make_guard(GuardBuilder.ID_MATCH))

                if not compiled_autograd.compiled_autograd_enabled:
                    # TODO(voz):
                    # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                    # python state.
                    # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is
                    # unsafe to run them in a compiled bwd without re-entering dynamo as compiled_autograd does.
                    #
                    # Discussion point 1 - Should we bypass this if nopython/fullgraph = True?
                    #   No. Because this was going to be a graph break anyway - this check does not
                    #   introduce new graph breaks where there were none.
                    #
                    # Discussion point 2 - Should we defer this check to backwards?
                    #   No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user
                    #   would have no recourse - their forward traces just fine, but will fail at backwards unless
                    #   compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today)
                    #   then they have nothing they can do except disable compile.
                    unimplemented(
                        "Compilation of intermediate hooks requires compiled autograd"
                    )

                # This wraps our user provided fn with a function that intercedes and
                # uses our `invoke` higher order op to record a hook invocation in bwd graph.
                fn = functools.partial(trace_wrapped, fn=fn)

                def _register_hook_trampoline(tensor):
                    tensor.register_hook(fn)
                    return tensor

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function",
                        _register_hook_trampoline,
                        (self.as_proxy(),),
                        {},
                    ),
                    **options,
                )

            tx.output.side_effects.register_hook(self, fn_var, handle_variable)
            return handle_variable
        elif name == "requires_grad_" and self.as_proxy().node.meta[
            "example_value"
        ].requires_grad != (args[0].value if len(args) > 0 else True):
            unimplemented("Tensor.requires_grad_")
        else:
            # Convert x.new(torch.Size) into x.new_empty(torch.Size),
            # as Tensor.new acts differently with a Size input versus a tuple input.
            if name == "new" and len(args) == 1 and isinstance(args[0], SizeVariable):
                name = "new_empty"
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )

    def rename(self, tx, name):
        self.proxy.node._rename(name)
        return super().rename(tx, name)


class SymNodeVariable(VariableTracker):
    """
    Represents a symbolic size, e.g., as returned by tensor.size(0)
    """

    @classmethod
    def create(cls, tx, proxy, sym_num, **options):
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == sym_num
        if sym_num is None:
            sym_num = get_fake_value(proxy.node, tx)
        proxy.node.meta["example_value"] = sym_num

        if isinstance(sym_num, (sympy.Integer, int)):
            return ConstantVariable.create(int(sym_num))

        return SymNodeVariable(proxy, sym_num, **options)

    def __init__(self, proxy, sym_num, **kwargs):
        super().__init__(**kwargs)
        self.proxy = proxy
        # TODO: Should we allow non SymTypes here?  Today it is allowed
        self.sym_num = sym_num

    def python_type(self):
        if isinstance(self.sym_num, SymTypes):
            return self.sym_num.node.pytype
        else:
            return type(self.sym_num)

    def unpack_var_sequence(self, tx):
        super().unpack_var_sequence(tx)

    def as_proxy(self):
        return self.proxy

    def evaluate_expr(self, output_graph=None):
        return guard_scalar(self.sym_num)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self, args, kwargs.values())

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self] + list(args), kwargs),
            ),
            **options,
        )


class NumpyNdarrayVariable(TensorVariable):
    """
    Represents an np.ndarray, but backed by torch Tensor via torch._numpy.ndarray.
    Use this for Tensor.numpy() call.
    """

    @staticmethod
    def create(tx, proxy, **options):
        from .builder import wrap_fx_proxy_cls

        return wrap_fx_proxy_cls(
            target_cls=NumpyNdarrayVariable,
            tx=tx,
            proxy=proxy,
            **options,
        )

    def var_getattr(self, tx, name):
        # NB: This INTENTIONALLY does not call super(), because there is
        # no intrinsic reason ndarray properties are related to Tensor
        # properties.  The inheritance here is for implementation sharing.

        from ..utils import numpy_attr_wrapper
        from .builder import wrap_fx_proxy

        result = None
        options = VariableTracker.propagate(self)

        example_value = self.as_proxy().node.meta["example_value"]
        example_ndarray = tnp.ndarray(example_value)

        def insert_into_graph():
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}
                ),
                **options,
            )

        if name in ["T", "real", "imag"]:
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_attr_wrapper,
                (self.as_proxy(), name),
                {},
            )
            result = NumpyNdarrayVariable.create(tx, proxy, **options)
        # These are awkward to implement.  The standard playbook for torch._numpy
        # interop is to trace a call into the torch._numpy wrapper which works for
        # Tensor operations.  However, we don't want to do this for calls
        # that don't return Tensors, because in those cases we may not want
        # to trace the attribute access into the graph at all (it is sort
        # of harmless to do so, because AOTAutograd will eliminate them,
        # but it's best not to trace them in to begin with.)  But in any
        # case, tracing these into the graph is like trying to fit a square
        # peg into a round hole; best not to do it.  So instead we
        # painstakingly implement these by hand
        #
        # NB: only ALWAYS specialized attributes can go here; notably,
        # size/shape not allowed!
        elif name in ("ndim", "itemsize"):
            return ConstantVariable.create(getattr(example_ndarray, name), **options)
        elif name in ("shape", "stride"):
            if not free_symbols(r := getattr(example_ndarray, name)):
                return ConstantVariable.create(tuple(int(x) for x in r), **options)
            return insert_into_graph()
        elif name == "size":
            if not free_symbols(r := example_ndarray.size):
                return ConstantVariable.create(int(r), **options)
            return insert_into_graph()
        elif name in ["base", "flags", "dtype"]:
            unimplemented(f"TODO: add support for ndarray.{name}")

        if result is None:
            raise NotImplementedError()
        return result

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        options = VariableTracker.propagate([[self]], [args], [list(kwargs.values())])
        from ..utils import numpy_method_wrapper

        if name in ["__len__", "size", "tolist"]:
            # delegate back to TensorVariable
            return super().call_method(tx, name, args, kwargs)
        proxy = tx.output.create_proxy(
            "call_function",
            numpy_method_wrapper(name),
            *proxy_args_kwargs([self] + list(args), kwargs),
        )
        return NumpyNdarrayVariable.create(tx, proxy, **options)

    def python_type(self):
        return np.ndarray


class UnspecializedPythonVariable(TensorVariable):
    """
    This is a 1-element tensor that represents an unspecialized python float/int.
    """

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        raw_value = kwargs.pop("raw_value", None)
        need_unwrap = kwargs.pop("need_unwrap", True)
        super().__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
        return UnspecializedPythonVariable(
            **dict(tensor_variable.__dict__),
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )

    def as_specialized(self, tx):
        for graph_arg in tx.output.graphargs:
            if graph_arg.source is self.source:
                graph_arg.erase()

        for g in self.guards:
            if g.is_volatile:
                g.create_fn = GuardBuilder.CONSTANT_MATCH

        return ConstantVariable.create(value=self.raw_value, guards=self.guards)


class FakeItemVariable(TensorVariable):
    """An unspecialized python variable which prevents access to the underlying raw value.
    This is needed if item is called on a FakeTensor."""

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        need_unwrap = kwargs.pop("need_unwrap", False)
        super().__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))


class TensorSubclassVariable(VariableTracker):
    def __init__(self, value, *args, **kwargs):
        self.value = value
        super().__init__(*args, **kwargs)

    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        if len(args) == 1 and isinstance(args[0], TensorVariable):
            from .builder import VariableBuilder
            from .torch_function import TensorWithTFOverrideVariable

            torch_fn = VariableBuilder(
                tx, AttrSource(self.source, "__torch_function__")
            )(self.value.__torch_function__)

            return TensorWithTFOverrideVariable.create(
                tx,
                args[0],
                torch_fn,
                self.value,
            )

        return super().call_function(tx, args, kwargs)