pytorch/torch/_dynamo/variables/tensor.py
Edward Z. Yang 49754f44ee Rewrite size/stride/numel TensorVariable handling (#103438)
The main concept behind this refactor is this: if we know that a size/stride/etc. is constant, do NOT trace it into the graph, EXCEPT for any preexisting special cases that applied for static shapes. The refactor unfolds as follows:

1. Delete the `dynamic_shapes` branches in torch/_dynamo/variables/builder.py which accept int/float/bool outputs. This was over-aggressive and we don't want to allow it (if the operator returns a constant, we shouldn't have called wrap_fx_proxy in the first place). On its own, this deletion causes a bunch of failures, because with dynamic shapes enabled we were blindly feeding the result of a size() call to wrap_fx_proxy.
2. Modify TensorVariable.call_method in torch/_dynamo/variables/tensor.py to avoid sending constant ints to wrap_fx_proxy. After normal specialization (which should be deleted, see https://github.com/pytorch/pytorch/pull/103434), we consult the fake tensor to see whether the values in question have free symbols. If they don't, we short-circuit tracing into the graph; we only trace into the graph if the operation in question is truly symbolic (a simplified sketch follows this list). Note that there is a near miss here: it's OK to trace an x.size() call entirely into the graph, even if not all of its dimensions are dynamic, because operator.getitem with an int output is special-cased in builder.py. This is a preexisting special case and I don't try to get rid of it.
3. It turns out that this change also breaks the torch_np compatibility layer, so I completely rewrite the getattr handling in torch/_dynamo/variables/tensor.py to follow the same pattern (only trace into the graph if the value is truly dynamic).
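
For reference, here is a simplified sketch of the new decision in TensorVariable.call_method; it is not the literal code (the optional dim argument, the preexisting specialization fields, and options propagation are omitted):

    fake = self.proxy.node.meta.get("example_value")
    if fake is not None:
        fake_r = getattr(fake, name)()  # e.g. fake.size()
        if not free_symbols(fake_r):
            # Statically known: wrap as constants, do NOT trace into the graph.
            return SizeVariable([ConstantVariable(int(s)) for s in fake_r])
    # Truly symbolic: trace the call into the graph.
    return wrap_fx_proxy(
        tx,
        tx.output.create_proxy(
            "call_method", name, *proxy_args_kwargs([self] + list(args), kwargs)
        ),
    )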

There's some minor housekeeping in torch/fx/experimental/symbolic_shapes.py and some test files.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103438
Approved by: https://github.com/larryliu0820
2023-06-12 19:36:24 +00:00


import inspect
import operator
import types
from typing import Dict, List
import sympy
import torch.fx
import torch.random
from torch.fx.experimental.symbolic_shapes import free_symbols, guard_scalar, SymTypes
from .. import config, variables
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import (
fqn,
get_fake_value,
get_real_value,
HAS_NUMPY_TORCH_INTEROP,
product,
proxy_args_kwargs,
tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import ShapeVariable, SizeVariable
supported_tensor_comparison_ops = {
">": operator.gt,
"<": operator.lt,
">=": operator.ge,
"<=": operator.le,
"==": operator.eq,
"!=": operator.ne,
}
supported_const_comparison_ops = {
"is": operator.is_,
"is not": operator.is_not,
"==": operator.eq,
"!=": operator.ne,
}
class TensorVariable(VariableTracker):
"""A torch.Tensor input or an intermediate value in the FX graph"""
_nonvar_fields = [
"proxy",
"dtype",
"device",
"layout",
"ndim",
"size",
"stride",
"requires_grad",
"is_quantized",
"is_contiguous",
]
def get_real_value(self):
"""
Get the actual value represented by this variable if computation is run
using the user-provided inputs.
NOTE: this runs actual tensor computation and may be
slow and memory-intensive.
"""
return get_real_value(self.proxy.node, self.proxy.tracer)
def __init__(
self,
proxy: torch.fx.Proxy,
dtype=None,
device=None,
layout=None,
ndim=None,
size=None,
stride=None,
requires_grad=None,
is_quantized=None,
is_contiguous=None,
is_sparse=None,
class_type=torch.Tensor,
specialized_value=None,
**kwargs,
):
super().__init__(**kwargs)
self.proxy = proxy
self.dtype = dtype
self.device = device
self.layout = layout
self.ndim = ndim
self.size = size
self.stride = stride
self.requires_grad = requires_grad
self.is_quantized = is_quantized
self.is_contiguous = is_contiguous
self.is_sparse = is_sparse
self.class_type = class_type
self.specialized_value = specialized_value
def as_proxy(self):
return self.proxy
def python_type(self):
return self.class_type
def call_isinstance(self, tensor_type):
def check_type(ty):
if ty not in tensortype_to_dtype:
return issubclass(self.python_type(), ty)
dtypes = tensortype_to_dtype[ty]
return self.dtype in dtypes
if type(tensor_type) is tuple:
return any(check_type(ty) for ty in tensor_type)
else:
return check_type(tensor_type)
@staticmethod
def specialize(value: torch.Tensor):
props = {
"dtype": value.dtype,
"device": value.device,
"layout": value.layout,
"ndim": int(value.ndim),
"requires_grad": value.requires_grad,
"is_quantized": value.is_quantized,
"is_sparse": value.is_sparse,
"class_type": type(value),
}
if not free_symbols(value):
# this is a fully static shape, and the keys on props here inform specialization.
# We have to cast to int here, because these might get accessed as ConstantVariable, which has
# a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
# already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
# constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
# I'd like to keep it around for now.
props["size"] = tuple([int(s) for s in value.size()])
props["stride"] = tuple(value.stride())
props["is_contiguous"] = tuple(
[
x
for x in torch._prims_common._memory_formats
if value.is_contiguous(memory_format=x)
]
)
return props
def var_getattr(self, tx, name):
from . import ConstantVariable, TorchVariable
if tx.strict_checks_enabled:
if name in self._strict_mode_banned_ops():
unimplemented(f"Illegal getattr invocation {name} in strict mode")
result = None
options = VariableTracker.propagate(self)
if name == "ndim" and self.ndim is not None:
result = ConstantVariable(self.ndim, **options)
elif name == "dtype" and self.dtype is not None:
result = TorchVariable(self.dtype, **options)
elif name == "device" and self.device is not None:
result = TorchVariable(self.device, **options)
elif name == "layout" and self.layout is not None:
result = TorchVariable(self.layout, **options)
elif name == "is_cuda" and self.device is not None:
result = ConstantVariable(self.device.type == "cuda", **options)
elif name == "shape" and self.size is not None:
sizes = [variables.ConstantVariable(x) for x in self.size]
result = ShapeVariable(sizes, **options)
elif name == "requires_grad" and self.requires_grad is not None:
result = ConstantVariable(self.requires_grad, **options)
elif name == "is_quantized" and self.is_quantized is not None:
result = ConstantVariable(self.is_quantized, **options)
elif name == "is_sparse" and self.is_sparse is not None:
result = ConstantVariable(self.is_sparse, **options)
elif name == "shape" and self.size is None:
result = self.call_method(tx, "size", [], {})
elif name == "ndim" and self.ndim is None:
result = self.call_method(tx, "dim", [], {})
elif name == "data":
result = self.call_method(tx, "detach", [], {})
if name == "__class__":
return TorchVariable(self.python_type(), **options)
# Add a guard for type matching; these guards are checked before tensor guards.
# In some cases, a <tensor>.<attr> guard can be evaluated first and break if
# <tensor> is later changed to another type.
if result is not None and self.source is not None:
result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
# It's hard to get an inplace view (metadata mutation) on a graph input to work
# properly across dynamo/aot/inductor, so just fall back.
if self.source is not None and hasattr(torch.ops.aten, name):
fn = getattr(torch.ops.aten, name)
if (
hasattr(fn, "overloads")
and hasattr(fn, fn.overloads()[0])
and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags
):
# Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc.
return variables.misc.DelayGraphBreakVariable()
# For attributes (not methods) that were not caught by the special handling above
# (e.g. tensor.real), we handle them generically, assuming that the output type is
# a tensor.
if result is None:
def try_generic_attr_handling():
from .builder import wrap_fx_proxy
from .misc import GetAttrVariable
try:
static_attr = inspect.getattr_static(torch.Tensor, name)
except AttributeError:
return None
# Make sure this is an attribute, not a method.
# type(torch.Tensor.H) should be "getset_descriptor"
# This is because of the CPython implementation; see THPVariableType:
# these attributes are implemented under tp_getset, so they appear
# as `getset_descriptor`s (compared to, say, methods, which appear
# as `method_descriptor`s)
if type(static_attr) != types.GetSetDescriptorType:
return None
return wrap_fx_proxy(
tx=tx,
proxy=GetAttrVariable.create_getattr_proxy(self.as_proxy(), name),
**options,
)
result = try_generic_attr_handling()
if result is None:
raise NotImplementedError()
return result
def has_unpack_var_sequence(self, tx):
return self.ndim > 0
def unpack_var_sequence(self, tx, idxes=None):
from .builder import wrap_fx_proxy
options = VariableTracker.propagate(self)
if idxes is None:
if self.size:
length = self.size[0]
else:
dyn_length = self.call_method(tx, "size", [ConstantVariable(0)], {})
# SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
# symbolic_shapes that end up as int/sympy.Integer
assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
if isinstance(dyn_length, SymNodeVariable):
length = dyn_length.evaluate_expr(tx.output)
else:
length = dyn_length.value
idxes = range(length)
return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]
def _strict_mode_banned_ops(self):
return torch._dynamo.config._autograd_backward_strict_mode_banned_ops
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if tx.strict_checks_enabled:
if name in self._strict_mode_banned_ops():
unimplemented(f"Illegal method invocation {name} in strict mode")
from . import ConstantVariable, TorchVariable, TupleVariable
from .builder import wrap_fx_proxy
kwargs = dict(kwargs)
options = VariableTracker.propagate(self, args, kwargs.values())
if name in ("stride", "size"):
dim_var = None
if len(args) == 1:
dim_var = args[0]
elif "dim" in kwargs:
dim_var = kwargs["dim"]
else:
assert not args and not kwargs, f"Tensor.{name}() unhandled args/kwargs"
dim = None
if isinstance(dim_var, SymNodeVariable):
# This is because SymNodeVariable intentionally doesn't define
# as_python_constant to avoid shunting down some codepaths
# that expect consts. In this case, we know we definitely
# want to specialize though.
dim = dim_var.evaluate_expr()
elif dim_var is not None:
dim = dim_var.as_python_constant()
def make_const_size_variable(x, **options):
return SizeVariable(
[ConstantVariable(y, **options) for y in x], **options
)
RetVariable = (
make_const_size_variable if name == "size" else ConstantVariable
)
# Technically, this should not be necessary, but I'm including it
# for enhanced BC, in case example_value is sometimes not set
# (it really should always be set though!)
if (r := getattr(self, name)) is not None:
if dim is None:
return RetVariable(r, **options)
else:
return ConstantVariable(r[dim], **options)
# It might still be constant! Consult the fake tensor and see
if (fake := self.proxy.node.meta.get("example_value")) is not None:
if dim is None:
fake_r = getattr(fake, name)()
if not free_symbols(fake_r):
# int conversion for safety, in case a SymInt refined
# to constant
return RetVariable(tuple(int(r) for r in fake_r), **options)
else:
fake_r = getattr(fake, name)(dim)
if not free_symbols(fake_r):
return ConstantVariable(int(fake_r), **options)
# Oops, it's not constant. Do the dynamic shapes path.
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
elif name in ("numel", "nelement"):
if self.size is not None:
return ConstantVariable(product(self.size), **options)
# It might still be constant! Consult the fake tensor and see
if (fake := self.proxy.node.meta.get("example_value")) is not None:
fake_r = fake.numel()
if not free_symbols(fake_r):
return ConstantVariable(int(fake_r), **options)
assert not kwargs, f"Tensor.{name}() unhandled kwargs"
# Oops, it's not constant. Do the dynamic shapes path.
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
"numel",
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
elif name in ("ndimension", "dim") and self.ndim is not None:
constant_result = ConstantVariable(self.ndim, **options)
elif name == "is_floating_point" and self.dtype is not None:
constant_result = ConstantVariable(self.dtype.is_floating_point, **options)
elif name == "is_contiguous" and self.is_contiguous is not None:
if "memory_format" in kwargs:
memory_format = kwargs.pop("memory_format").as_python_constant()
else:
memory_format = torch.contiguous_format
constant_result = ConstantVariable(
memory_format in self.is_contiguous, **options
)
elif (
name == "type"
and self.dtype is not None
and len(args) == 0
and isinstance(self.device, torch.device)
):
tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
0
]
if self.device.type == "cuda":
constant_result = ConstantVariable(
f"torch.cuda.{tensortype.__name__}", **options
)
else:
constant_result = ConstantVariable(
f"torch.{tensortype.__name__}", **options
)
elif (
name == "type"
and len(args) == 1
and fqn(type(args[0].as_python_constant())) == "torch.tensortype"
):
# torch.FloatTensor, etc. are all of type "torch.tensortype".
# torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
# So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
tensor_type = args[0].as_python_constant()
tensor_type_const = ConstantVariable(fqn(tensor_type), **options)
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self, tensor_type_const], kwargs),
),
**options,
)
elif name == "get_device" and isinstance(self.device, torch.device):
index = self.device.index if self.device.type != "cpu" else -1
constant_result = ConstantVariable(index, **options)
else:
constant_result = None
if constant_result:
assert not kwargs, f"Tensor.{name}() unhandled kwargs"
# TODO: I think this branch is dead
if len(args) == 1:
return constant_result.getitem_const(args[0])
elif args:
return TupleVariable(
[constant_result.getitem_const(a) for a in args], **options
)
return constant_result
elif name == "numpy":
if not config.numpy_ndarray_as_tensor or not HAS_NUMPY_TORCH_INTEROP:
unimplemented(
f"Tensor.{name}. Turn on config.numpy_ndarray_as_tensor and install torch_np to support "
f"tensor.numpy(). "
)
from .builder import wrap_fx_proxy_cls
assert not args, "Tensor.numpy() doesn't take args."
# TODO: support force
if kwargs and "force" in kwargs:
unimplemented(f"Tensor.numpy(force={kwargs['force']})")
return wrap_fx_proxy_cls(
target_cls=NumpyNdarrayVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
torch.detach,
*proxy_args_kwargs([self], {}),
),
example_value=None,
**options,
)
elif name in ("tolist", "backward", "data_ptr"):
unimplemented(f"Tensor.{name}")
elif name == "item" and not config.capture_scalar_outputs:
unimplemented(f"Tensor.{name}")
elif name == "__len__":
return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
elif name == "__setitem__":
key, value = args
def has_bool_key(v):
if isinstance(v, TensorVariable):
return v.dtype in (torch.bool, torch.int8)
elif isinstance(v, TupleVariable):
return any(has_bool_key(item) for item in v.items)
else:
return False
if (
not config.capture_dynamic_output_shape_ops
and has_bool_key(key)
and isinstance(value, TensorVariable)
and value.requires_grad
):
unimplemented(
"boolean masking setitem backwards requires dynamic shapes"
)
tx.output.guards.update(options["guards"])
tx.output.create_proxy(
"call_function",
operator.setitem,
*proxy_args_kwargs([self] + list(args), kwargs),
)
return ConstantVariable(None, **options)
elif name in ("resize_", "resize_as_"):
if "memory_format" in kwargs:
memory_format = kwargs["memory_format"].as_python_constant()
else:
memory_format = torch.contiguous_format
if name == "resize_":
self.size = args[0].as_python_constant()
self.is_contiguous = (memory_format,)
else:
assert isinstance(args[0], TensorVariable)
if self.size and args[0].size:
if (
self.size == args[0].size
or memory_format is torch.preserve_format
):
self.is_contiguous = args[0].is_contiguous
else:
self.size = args[0].size
self.stride = args[0].stride
self.ndim = args[0].ndim
self.is_contiguous = (memory_format,)
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
elif (
name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
):
result = TorchVariable(torch.mul, **options).call_function(
tx, args + [kwargs["alpha"]], {}
)
return self.call_method(tx, "add_", [result], {})
elif (
name == "addcdiv_"
and len(args) == 2
and len(kwargs) == 1
and "value" in kwargs
):
result = TorchVariable(torch.div, **options).call_function(tx, args, {})
result = TorchVariable(torch.mul, **options).call_function(
tx, [result, kwargs["value"]], {}
)
return self.call_method(tx, "add_", [result], {})
elif name == "__contains__":
# Rewrite __contains__ here so that downstream passes can trace through
# without dealing with unbacked symbool. Roughly the code we translate is:
# def __contains__(self, x):
# return (x == self).any().item()
result = TorchVariable(torch.eq, **options).call_function(
tx, [self, args[0]], {}
)
result = TorchVariable(torch.any, **options).call_function(tx, [result], {})
return result.call_method(tx, "item", [], {})
else:
# Convert x.new(torch.Size) into x.new_empty(torch.Size),
# as Tensor.new acts differently with a Size input versus a tuple input.
if (
name == "new"
and len(args) == 1
and isinstance(args[0], (SizeVariable, ShapeVariable))
and not config.dynamic_shapes
):
name = "new_empty"
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
class SymNodeVariable(VariableTracker):
"""
Represents a symbolic size, e.g., as returned by tensor.size(0)
"""
@classmethod
def create(cls, tx, proxy, sym_num, **options):
if "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == sym_num
if sym_num is None:
sym_num = get_fake_value(proxy.node, tx)
proxy.node.meta["example_value"] = sym_num
if isinstance(sym_num, (sympy.Integer, int)):
return ConstantVariable(int(sym_num))
return SymNodeVariable(proxy, sym_num, **options)
def __init__(self, proxy, sym_num, **kwargs):
super().__init__(**kwargs)
self.proxy = proxy
# TODO: Should we allow non SymTypes here? Today it is allowed
self.sym_num = sym_num
def python_type(self):
if isinstance(self.sym_num, SymTypes):
return self.sym_num.node.pytype
else:
return type(self.sym_num)
def unpack_var_sequence(self, tx):
super().unpack_var_sequence(tx)
def as_proxy(self):
return self.proxy
def evaluate_expr(self, output_graph=None):
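# Force the symbolic value to a concrete Python scalar, guarding on the result.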
return guard_scalar(self.sym_num)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .builder import wrap_fx_proxy
options = VariableTracker.propagate(self, args, kwargs.values())
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
class TensorWithTFOverrideVariable(VariableTracker):
"""
Represents a tensor subclass instance with a __torch_function__ override.
"""
def __init__(
self,
tensor_variable,
orig_tensor_variable_source,
subclass_torch_function__func,
subclass_type,
**kwargs,
):
super().__init__(**kwargs)
self.tensor_variable = tensor_variable
self.orig_tensor_variable_source = orig_tensor_variable_source
self.subclass_torch_function__func = subclass_torch_function__func
self.subclass_type = subclass_type
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
# This code block implements inlining the __torch_function__ override
# of `call_method`.
from . import GetAttrVariable
options = VariableTracker.propagate(self, args, kwargs.values())
# insert unwrapped version of self as the first argument
# TODO: This is wrong! When you call the internal __torch_function__,
# you still get the wrapped version of self, and if you call functions
# inside __torch_function__, they should come back here. If we unwrap
# the tensor immediately, that will not happen.
# See https://github.com/pytorch/torchdynamo/issues/1951
args = list(args)
args.insert(0, self.tensor_variable)
func_var = GetAttrVariable(self.tensor_variable, name)
unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
tx,
func_var,
self.orig_tensor_variable_source,
self.subclass_torch_function__func,
self.subclass_type,
options,
args,
kwargs,
)
# TODO(future PR): implement rewrapping conditional on method presence
# in `torch.overrides.get_default_nowrap_function()`. It's unclear how
# to do this easily in the current codebase since the resolution of
# `GetAttrVariable` depends on the type of the underlying object.
return TensorWithTFOverrideVariable(
unwrapped,
self.orig_tensor_variable_source,
self.subclass_torch_function__func,
self.subclass_type,
)
@staticmethod
def inline_torch_function_unwrapped(
tx,
original_func_var,
tensor_with_tf_override_source,
tf_func,
subclass_type,
options,
args,
kwargs,
):
"""
This function inlines the `__torch_function__` override for `original_func_var`.
For example, if the user code is
x1 = torch.sigmoid(x0)
And `x0` has an override, then:
* `original_func_var` will be a `VariableTracker` object wrapping `torch.sigmoid`
* `tensor_with_tf_override_source` will be the `Source` object from
the original tensor override instance in the beginning of the program
* `tf_func` will be the custom `__torch_function__` function
* `subclass_type` will be `type(x0)`
The caller is expected to properly massage args and kwargs before
passing them into this function.
The caller is responsible for wrapping the return value, if needed.
"""
from . import UserDefinedClassVariable
from .builder import TupleVariable, VariableBuilder
source = AttrSource(
AttrSource(tensor_with_tf_override_source, "__torch_function__"),
"__func__",
)
tf_func_var = VariableBuilder(tx, source)(tf_func)
type_var = UserDefinedClassVariable(subclass_type, **options)
# signature:
# def __torch_function__(cls, func, types, args=(), kwargs=None):
tf_args = (
type_var, # cls
original_func_var, # func
(type_var,), # types
TupleVariable(args), # args
kwargs, # kwargs
)
# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
with torch._C.DisableTorchFunctionSubclass():
return tx.inline_user_function_return(tf_func_var, tf_args, {})
class NumpyNdarrayVariable(TensorVariable):
"""
Represents a torch_np.ndarray, but backed by a torch Tensor. Use this for the Tensor.numpy() call.
"""
def __init__(
self,
proxy: torch.fx.Proxy,
**kwargs,
):
super().__init__(proxy, **kwargs)
def var_getattr(self, tx, name):
# NB: This INTENTIONALLY does not call super(), because there is
# no intrinsic reason ndarray properties are related to Tensor
# properties. The inheritance here is for implementation sharing.
from ..utils import numpy_attr_wrapper
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
result = None
options = VariableTracker.propagate(self)
import torch_np
example_value = self.proxy.node.meta["example_value"]
example_ndarray = torch_np.ndarray(example_value)
def insert_into_graph():
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}
),
**options,
)
if name in ["T", "real", "imag"]:
result = wrap_fx_proxy_cls(
target_cls=NumpyNdarrayVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
numpy_attr_wrapper,
(self.as_proxy(), name),
{},
),
**options,
)
# These are awkward to implement. The standard playbook for torch_np
# interop is to trace a call into the torch_np wrapper which works for
# Tensor operations. However, we don't want to do this for calls
# that don't return Tensors, because in those cases we may not want
# to trace the attribute access into the graph at all (it is sort
# of harmless to do so, because AOTAutograd will eliminate them,
# but it's best not to trace them in to begin with.) But in any
# case, tracing these into the graph is like trying to fit a square
# peg into a round hole; best not to do it. So instead we
# painstakingly implement these by hand.
#
# NB: only ALWAYS specialized attributes can go here; notably,
# size/shape not allowed!
elif name in ("ndim", "itemsize"):
return ConstantVariable(getattr(example_ndarray, name), **options)
elif name in ("shape", "stride"):
if not free_symbols(r := getattr(example_ndarray, name)):
return ConstantVariable(tuple(int(r) for r in r), **options)
return insert_into_graph()
elif name == "size":
if not free_symbols(r := example_ndarray.size):
return ConstantVariable(int(r), **options)
return insert_into_graph()
elif name in ["base", "flags", "dtype"]:
unimplemented(f"TODO: add support for ndarray.{name}")
if result is None:
raise NotImplementedError()
return result
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
options = VariableTracker.propagate([[self]], [args], [list(kwargs.values())])
from torch._dynamo.variables.builder import wrap_fx_proxy_cls
from ..utils import numpy_method_wrapper
result = wrap_fx_proxy_cls(
target_cls=NumpyNdarrayVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
numpy_method_wrapper(name),
*proxy_args_kwargs([self] + list(args), kwargs),
),
example_value=None,
**options,
)
return result
class UnspecializedPythonVariable(TensorVariable):
"""
This is a 1-element tensor representing an unspecialized python float/int.
"""
def __init__(self, proxy: torch.fx.Proxy, **kwargs):
raw_value = kwargs.pop("raw_value", None)
need_unwrap = kwargs.pop("need_unwrap", True)
super().__init__(proxy, **kwargs)
self.raw_value = raw_value
self.need_unwrap = need_unwrap
@classmethod
def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
# Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
return UnspecializedPythonVariable(
**dict(tensor_variable.__dict__),
raw_value=raw_value,
need_unwrap=need_unwrap,
)
def as_specialized(self, tx):
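# Specialize this value back to a constant: drop the corresponding graph input,
# turn any volatile guards into CONSTANT_MATCH guards, and return the raw value
# as a ConstantVariable.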
for graph_arg in tx.output.graphargs:
if graph_arg.source is self.source:
graph_arg.erase()
for g in self.guards:
if g.is_volatile:
g.create_fn = GuardBuilder.CONSTANT_MATCH
return ConstantVariable(value=self.raw_value, guards=self.guards)
class FakeItemVariable(TensorVariable):
"""An unspecialized python variable which prevents access to the underlying raw value.
This is needed if item is called on a FakeTensor."""
def __init__(self, proxy: torch.fx.Proxy, **kwargs):
need_unwrap = kwargs.pop("need_unwrap", False)
super().__init__(proxy, **kwargs)
self.need_unwrap = need_unwrap
@classmethod
def from_tensor_variable(cls, tensor_variable):
return FakeItemVariable(**dict(tensor_variable.__dict__))