# Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00).
#
# Last change: AOTAutograd's handling of resize_() isn't fully robust (and, on top
# of that, functionalization can give up and raise an error if the tensor being
# resized has outstanding views). Given that, and given that resize_() is rare,
# I updated dynamo to graph break on resize_() instead.
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/111553
# Approved by: https://github.com/ezyang
#
# 991 lines, 38 KiB, Python
import functools
|
|
import inspect
|
|
import operator
|
|
import types
|
|
from typing import Dict, List
|
|
|
|
try:
|
|
import numpy as np
|
|
except ModuleNotFoundError:
|
|
np = None
|
|
|
|
|
|
import sympy
|
|
|
|
import torch._numpy as tnp
|
|
|
|
import torch.fx
|
|
import torch.random
|
|
from torch._dynamo import compiled_autograd
|
|
|
|
from torch.fx.experimental.symbolic_shapes import free_symbols, guard_scalar, SymTypes
|
|
|
|
from .. import config, variables
|
|
from .._trace_wrapped_higher_order_op import trace_wrapped
|
|
|
|
from ..exc import unimplemented
|
|
from ..guards import GuardBuilder
|
|
from ..source import AttrSource
|
|
from ..utils import (
|
|
fqn,
|
|
get_custom_getattr,
|
|
get_fake_value,
|
|
get_real_value,
|
|
guard_if_dyn,
|
|
object_has_getattribute,
|
|
product,
|
|
proxy_args_kwargs,
|
|
tensortype_to_dtype,
|
|
)
|
|
from .base import VariableTracker
|
|
from .constant import ConstantVariable
|
|
from .lists import SizeVariable
|
|
|
|
# Spelling of each rich-comparison operator -> the `operator` function that
# implements it, for comparisons involving tensors.
supported_tensor_comparison_ops = dict(
    [
        (">", operator.gt),
        ("<", operator.lt),
        (">=", operator.ge),
        ("<=", operator.le),
        ("==", operator.eq),
        ("!=", operator.ne),
    ]
)

# Operators allowed when comparing against constants. Unlike the tensor table
# this includes the identity checks `is` / `is not`, and has no ordering ops.
supported_const_comparison_ops = dict(
    [
        ("is", operator.is_),
        ("is not", operator.is_not),
        ("==", operator.eq),
        ("!=", operator.ne),
    ]
)
|
|
|
|
|
|
class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph.

    Besides the FX proxy, the tracker caches tensor metadata (dtype, device,
    layout, ndim, requires_grad, ... and, when the shape is fully static,
    size/stride/contiguity; see ``specialize``). Attribute reads and a number
    of metadata methods are answered from this cache as constants at trace
    time instead of being recorded as graph nodes.
    """

    # Attributes that hold raw Python/torch values rather than VariableTrackers.
    # NOTE(review): presumably consulted by VariableTracker when walking/copying
    # child trackers -- confirm against base.py.
    _nonvar_fields = [
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
    ]

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be
        slow and memory-intensive.
        """
        # Resolves to the module-level `get_real_value` helper, not this method.
        return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        *,
        dtype,
        device,
        layout,
        ndim,
        requires_grad,
        is_quantized,
        is_sparse,
        class_type,
        size=None,
        stride=None,
        is_contiguous=None,
        specialized_value=None,
        **kwargs,
    ):
        """Record the FX proxy plus cached tensor metadata.

        ``size``/``stride``/``is_contiguous`` default to ``None``, meaning
        "not statically known"; methods below then fall back to the fake
        tensor in ``proxy.node.meta["example_value"]`` or to emitting graph
        nodes. Remaining ``kwargs`` go to ``VariableTracker.__init__``.
        """
        super().__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.layout = layout
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.specialized_value = specialized_value

    def as_proxy(self):
        """Return the FX proxy for this tensor."""
        return self.proxy

    def python_type(self):
        """Return the Python class of the traced tensor (``class_type``)."""
        return self.class_type

    def call_isinstance(self, tensor_type):
        """Evaluate ``isinstance(tensor, tensor_type)`` without running it.

        ``tensor_type`` may be a single type or a tuple of types. Types that
        appear in ``tensortype_to_dtype`` (legacy typed-tensor classes) match
        by dtype; any other type matches via ``issubclass`` on
        ``python_type()``.
        """

        def check_type(ty):
            if ty not in tensortype_to_dtype:
                return issubclass(self.python_type(), ty)

            dtypes = tensortype_to_dtype[ty]
            return self.dtype in dtypes

        if type(tensor_type) is tuple:
            return any(check_type(ty) for ty in tensor_type)
        else:
            return check_type(tensor_type)

    @staticmethod
    def specialize(value: torch.Tensor):
        """Extract the constructor metadata kwargs from a (fake) tensor.

        ``size``, ``stride`` and ``is_contiguous`` are only included when
        ``value`` has no free symbols, i.e. its shape is fully static.
        ``is_contiguous`` is the tuple of memory formats in which ``value`` is
        contiguous.
        """
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "layout": value.layout,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        if not free_symbols(value):
            # this is a fully static shape, and the keys on props here inform specialization.
            # We have to cast to int here, because these might get accessed as ConstantVariable, which has
            # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
            # already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
            # constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
            # I'd like to keep it around for now.
            props["size"] = tuple([int(s) for s in value.size()])
            props["stride"] = tuple(value.stride())
            props["is_contiguous"] = tuple(
                [
                    x
                    for x in torch._prims_common._memory_formats
                    if value.is_contiguous(memory_format=x)
                ]
            )
        return props

    def dynamic_getattr(self, tx, name):
        """Resolve ``name`` by reading it off the real object behind ``source``.

        The source is ``eval``-ed against the frame's locals (``L``) and
        globals (``G``). Raises ``NotImplementedError`` (meaning "fall back")
        when there is no source, evaluation fails or yields ``None``, the
        object customizes ``__getattribute__``/``__getattr__``, or the
        attribute is callable. Otherwise the value is wrapped with
        ``VariableBuilder`` and guarded with ``HASATTR``.
        """
        if not self.source:
            raise NotImplementedError()

        # For local source, we associate the real value. We use this real value
        # for implementing getattr fallthrough on the variable tracker base class.

        # Note - this scope construction is mirrored in guards
        # A subsequent PR will introduce a util.
        scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
        try:
            # We raise in case we get a typerror bug w/ SuperSource.
            # SuperSource has bugs in it atm, and can produce code like
            # eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__,
            # L['mod'].model.model.encoder.embed_positions)", scope)
            # Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
            _input_associated_real_value = eval(self.source.name(), scope)
        except Exception as exc:
            raise NotImplementedError() from exc

        if _input_associated_real_value is None:
            raise NotImplementedError()

        if object_has_getattribute(_input_associated_real_value):
            raise NotImplementedError()

        if get_custom_getattr(_input_associated_real_value):
            raise NotImplementedError()

        real_value = getattr(_input_associated_real_value, name)
        if callable(real_value):
            # Callables have more nuanced handling, and we should let the existing system delegate here.
            # Raising was past behavior and so should always be sound to fall back.
            # Note - at a certain point we may want to handle
            raise NotImplementedError()

        # GuardBuilder is already imported at module level.
        from .builder import VariableBuilder

        attr_source = AttrSource(self.source, name)
        has_attr_guard = attr_source.make_guard(GuardBuilder.HASATTR)
        return (
            VariableBuilder(tx, attr_source)(real_value)
            .add_options(self)
            .add_guard(has_attr_guard)
        )

    def var_getattr(self, tx, name):
        """Trace ``tensor.<name>``.

        Resolution order:
          1. cached metadata (ndim, dtype, device, layout, is_cuda, shape,
             requires_grad, is_quantized, is_sparse) -> constants;
             ``shape``/``ndim`` fall back to ``size()``/``dim()`` when not
             cached, and ``data`` maps to ``detach()``;
          2. ``__class__`` -> the tensor's Python type;
          3. on a tensor with a source, aten ops tagged ``inplace_view``
             return a ``DelayGraphBreakVariable`` so the break happens at the
             call site;
          4. other getset-descriptor attributes of ``torch.Tensor`` are traced
             generically as a graph getattr;
          5. finally ``dynamic_getattr``.
        Raises ``NotImplementedError`` if nothing produced a result.
        """
        from . import ConstantVariable, TorchVariable

        if tx.strict_checks_enabled:
            if name in self._strict_mode_banned_ops():
                unimplemented(f"Illegal getattr invocation {name} in strict mode")

        result = None
        options = VariableTracker.propagate(self)
        if name == "ndim" and self.ndim is not None:
            result = ConstantVariable.create(self.ndim, **options)
        elif name == "dtype" and self.dtype is not None:
            result = TorchVariable(self.dtype, **options)
        elif name == "device" and self.device is not None:
            result = TorchVariable(self.device, **options)
        elif name == "layout" and self.layout is not None:
            result = TorchVariable(self.layout, **options)
        elif name == "is_cuda" and self.device is not None:
            result = ConstantVariable.create(self.device.type == "cuda", **options)
        elif name == "shape" and self.size is not None:
            sizes = [variables.ConstantVariable.create(x) for x in self.size]
            result = SizeVariable(sizes, **options)
        elif name == "requires_grad" and self.requires_grad is not None:
            result = ConstantVariable.create(self.requires_grad, **options)
        elif name == "is_quantized" and self.is_quantized is not None:
            result = ConstantVariable.create(self.is_quantized, **options)
        elif name == "is_sparse" and self.is_sparse is not None:
            result = ConstantVariable.create(self.is_sparse, **options)
        elif name == "shape" and self.size is None:
            result = self.call_method(tx, "size", [], {})
        elif name == "ndim" and self.ndim is None:
            result = self.call_method(tx, "dim", [], {})
        elif name == "data":
            result = self.call_method(tx, "detach", [], {})
        if name == "__class__":
            return TorchVariable(self.python_type(), **options)

        # Add a guard for type matching, these guards are checked before tensor guards
        # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
        # <tensor> is later changed to another type
        if result is not None and self.source is not None:
            result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))

        # It's hard to get inplace view (metadata mutation) on graph input work properly across
        # dynamo/aot/inductor, just fall back.
        if self.source is not None and hasattr(torch.ops.aten, name):
            fn = getattr(torch.ops.aten, name)
            if (
                hasattr(fn, "overloads")
                and hasattr(fn, fn.overloads()[0])
                and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags
            ):
                # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc.
                return variables.misc.DelayGraphBreakVariable()

        # For attributes (not methods) that were not caught in the special handling above,
        # (e.g. tensor.real), we handle these generically, assuming that the output type is
        # a tensor.
        if result is None:

            def try_generic_attr_handling():
                from .builder import wrap_fx_proxy
                from .misc import GetAttrVariable

                try:
                    static_attr = inspect.getattr_static(torch.Tensor, name)
                except AttributeError:
                    return None

                # Make sure this is an attribute, not a method.
                # type(torch.Tensor.H) should be "getset_descriptor"
                # This is a because of CPython implementation, see THPVariableType:
                # these attributes are implemented under tp_getset, which appear
                # as `getset_descriptor`s, (compared to, say, methods which appear
                # as `method_descriptor`s)
                if type(static_attr) != types.GetSetDescriptorType:
                    return None

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=GetAttrVariable.create_getattr_proxy(self.as_proxy(), name),
                    **options,
                )

            result = try_generic_attr_handling()

        if result is None:
            result = self.dynamic_getattr(tx, name)

        if result is None:
            raise NotImplementedError()
        return result

    def has_unpack_var_sequence(self, tx):
        """A tensor can be iterated (unpacked) only if it has at least one dim."""
        return self.ndim > 0

    def unpack_var_sequence(self, tx, idxes=None):
        """Unpack along dim 0 into one tracker per ``self[i]``.

        ``idxes`` selects which indices to produce; by default it is
        ``range(len(tensor))``. When the length is symbolic it is guarded to a
        concrete value via ``SymNodeVariable.evaluate_expr``.
        """
        from .builder import wrap_fx_proxy_cls

        options = VariableTracker.propagate(self)
        if idxes is None:
            if self.size:
                length = self.size[0]
            else:
                dyn_length = self.call_method(
                    tx, "size", [ConstantVariable.create(0)], {}
                )
                # SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
                # symbolic_shapes, but that end up as int/sympy.Integer
                assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
                if isinstance(dyn_length, SymNodeVariable):
                    length = dyn_length.evaluate_expr(tx.output)
                else:
                    length = dyn_length.value
            idxes = range(length)
        return [
            wrap_fx_proxy_cls(
                target_cls=type(self), tx=tx, proxy=self.as_proxy()[i], **options
            )
            for i in idxes
        ]

    def _strict_mode_banned_ops(self):
        """Names that may not be accessed/called when ``tx.strict_checks_enabled``."""
        return torch._dynamo.config._autograd_backward_strict_mode_banned_ops

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace ``tensor.<name>(*args, **kwargs)``.

        Metadata queries (size/stride/numel/dim/is_floating_point/
        is_contiguous/type/get_device) are constant-folded when the answer is
        statically known. A handful of methods get bespoke handling (numpy,
        tolist, __setitem__, __contains__, add_/addcdiv_ with scalar kwargs,
        redistribute, register_hook), some graph break via ``unimplemented``
        (backward, data_ptr, resize_, resize_as_, item without
        capture_scalar_outputs, a requires_grad_ that changes requires_grad),
        and everything else is recorded as a ``call_method`` graph node.
        """
        if tx.strict_checks_enabled:
            if name in self._strict_mode_banned_ops():
                unimplemented(f"Illegal method invocation {name} in strict mode")
        from . import ConstantVariable, TorchVariable, TupleVariable
        from .builder import wrap_fx_proxy

        # Copy so branches below can pop() handled kwargs without mutating
        # the caller's dict.
        kwargs = dict(kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())

        if name in ("stride", "size"):
            dim_var = None
            if len(args) == 1:
                dim_var = args[0]
            elif "dim" in kwargs:
                dim_var = kwargs["dim"]
            else:
                assert not args and not kwargs, f"Tensor.{name}() unhandled args/kwargs"

            dim = guard_if_dyn(dim_var)

            def make_const_size_variable(x, **options):
                return SizeVariable(
                    [ConstantVariable.create(y, **options) for y in x], **options
                )

            RetVariable = (
                make_const_size_variable if name == "size" else ConstantVariable.create
            )

            # Technically, this should not be necessary, but I'm including it
            # for enhanced BC, in case example_value is sometimes not set
            # (it really should always be set though!)
            if (r := getattr(self, name)) is not None:
                if dim is None:
                    return RetVariable(r, **options)
                else:
                    return ConstantVariable.create(r[dim], **options)

            # It might still be constant! Consult the fake tensor and see
            if (fake := self.proxy.node.meta.get("example_value")) is not None:
                if dim is None:
                    fake_r = getattr(fake, name)()
                    if not free_symbols(fake_r):
                        # int conversion for safety, in case a SymInt refined
                        # to constant
                        return RetVariable(tuple(int(r) for r in fake_r), **options)
                else:
                    fake_r = getattr(fake, name)(dim)
                    if not free_symbols(fake_r):
                        return ConstantVariable.create(int(fake_r), **options)

            # Oops, it's not constant. Do the dynamic shapes path.
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )

        elif name in ("numel", "nelement"):
            if self.size is not None:
                return ConstantVariable.create(product(self.size), **options)

            # It might still be constant! Consult the fake tensor and see
            if (fake := self.proxy.node.meta.get("example_value")) is not None:
                fake_r = fake.numel()
                if not free_symbols(fake_r):
                    return ConstantVariable.create(int(fake_r), **options)

            assert not kwargs, f"Tensor.{name}() unhandled kwargs"

            # Oops, it's not constant. Do the dynamic shapes path.
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    "numel",
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )

        elif name in ("ndimension", "dim") and self.ndim is not None:
            constant_result = ConstantVariable.create(self.ndim, **options)
        elif name == "is_floating_point" and self.dtype is not None:
            constant_result = ConstantVariable.create(
                self.dtype.is_floating_point, **options
            )
        elif name == "is_contiguous" and self.is_contiguous is not None:
            if "memory_format" in kwargs:
                memory_format = kwargs.pop("memory_format").as_python_constant()
            else:
                memory_format = torch.contiguous_format
            constant_result = ConstantVariable.create(
                memory_format in self.is_contiguous, **options
            )
        elif (
            name == "type"
            and self.dtype is not None
            and len(args) == 0
            and isinstance(self.device, torch.device)
        ):
            tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
                0
            ]
            if self.device.type == "cuda":
                constant_result = ConstantVariable.create(
                    f"torch.cuda.{tensortype.__name__}", **options
                )
            else:
                constant_result = ConstantVariable.create(
                    f"torch.{tensortype.__name__}", **options
                )
        elif (
            name == "type"
            and len(args) == 1
            and fqn(type(args[0].as_python_constant())) == "torch.tensortype"
        ):
            # torch.FloatTensor, etc. are all of type "torch.tensortype".
            # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
            # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
            tensor_type = args[0].as_python_constant()
            tensor_type_const = ConstantVariable.create(fqn(tensor_type), **options)
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self, tensor_type_const], kwargs),
                ),
                **options,
            )
        elif name == "get_device" and isinstance(self.device, torch.device):
            index = self.device.index if self.device.type != "cpu" else -1
            constant_result = ConstantVariable.create(index, **options)
        else:
            constant_result = None

        if constant_result:
            assert not kwargs, f"Tensor.{name}() unhandled kwargs"
            # TODO: I think this branch is dead
            if len(args) == 1:
                return constant_result.getitem_const(args[0])
            elif args:
                return TupleVariable(
                    [constant_result.getitem_const(a) for a in args], **options
                )
            return constant_result
        elif name == "numpy":
            if not config.trace_numpy:
                unimplemented("Tensor.numpy(). config.trace_numpy is False")
            if not np:
                unimplemented("Tensor.numpy(). NumPy is not available")
            assert not args, "Tensor.numpy() doesn't take args."
            if self.layout != torch.strided:
                raise TypeError(
                    f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first"
                )
            # We don't check that the tensor is on CPU when force is False, as this
            # allows us to execute NumPy code on CUDA.
            # We don't check that requires_grad=False as we are currently doing an
            # unconditional detach.
            # TODO: We may want to avoid detaching if `requires_grad=True`
            # and `force=False` to allow computing gradients.
            force = "force" in kwargs and kwargs["force"].as_python_constant()
            proxy = tx.output.create_proxy(
                "call_method", "detach", *proxy_args_kwargs([self], {})
            )
            if force:
                # Chain .cpu() onto the *detached* value. Building it from
                # `self` (as before) silently discarded the detach node above.
                # TODO Add resolve_conj and resolve_neg once we support complex tensors
                proxy = tx.output.create_proxy("call_method", "cpu", (proxy,), {})
            return NumpyNdarrayVariable.create(tx, proxy, **options)
        elif name == "tolist":
            from .builder import SourcelessBuilder

            def tolist(tensor, sub_proxy):
                # Recursively mirror Tensor.tolist(): each scalar element
                # becomes a SymNodeVariable backed by a fresh unbacked symint.
                def wrap(i, sub_proxy):
                    return SymNodeVariable.create(
                        tx,
                        sub_proxy.item(),
                        sym_num=tx.output.shape_env.create_unbacked_symint(),
                    )

                if tensor.dtype not in [
                    torch.int8,
                    torch.int16,
                    torch.int32,
                    torch.int64,
                ]:
                    unimplemented("Input tensor for tolist must be an integer tensor")

                if tensor.dim() == 0:
                    return wrap(tensor, sub_proxy)

                if tensor.dim() == 1:
                    return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]

                return [
                    tolist(sub_tensor, sub_proxy=sub_proxy[i])
                    for i, sub_tensor in enumerate(tensor)
                ]

            tensor = self.as_proxy().node.meta["example_value"]
            out = tolist(tensor, self.as_proxy())
            return SourcelessBuilder()(tx, out).add_options(options)
        elif name in ("backward", "data_ptr"):
            unimplemented(f"Tensor.{name}")
        elif name == "item" and not config.capture_scalar_outputs:
            unimplemented(f"Tensor.{name}")
        elif name == "__len__":
            return self.call_method(
                tx, "size", [ConstantVariable.create(0, **options)], {}
            )
        elif name == "__setitem__":
            key, value = args

            def has_bool_key(v):
                if isinstance(v, TensorVariable):
                    return v.dtype in (torch.bool, torch.int8)
                elif isinstance(v, TupleVariable):
                    return any(has_bool_key(item) for item in v.items)
                else:
                    return False

            if (
                not config.capture_dynamic_output_shape_ops
                and has_bool_key(key)
                and isinstance(value, TensorVariable)
                and value.requires_grad
            ):
                unimplemented(
                    "boolean masking setitem backwards requires dynamic shapes"
                )
            # setitem is traced purely for its side effect, so its guards
            # are registered directly on the output graph.
            tx.output.guards.update(options["guards"])
            tx.output.create_proxy(
                "call_function",
                operator.setitem,
                *proxy_args_kwargs([self] + list(args), kwargs),
            )
            return ConstantVariable.create(None, **options)
        elif name in ("resize_", "resize_as_"):
            # Handling resizing in its full generality is difficult.
            unimplemented(f"Tensor.{name}")
        elif (
            name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
        ):
            # x.add_(y, alpha=a) -> x.add_(y * a)
            result = TorchVariable(torch.mul, **options).call_function(
                tx, args + [kwargs["alpha"]], {}
            )
            return self.call_method(tx, "add_", [result], {})
        elif (
            name == "addcdiv_"
            and len(args) == 2
            and len(kwargs) == 1
            and "value" in kwargs
        ):
            # x.addcdiv_(t1, t2, value=v) -> x.add_((t1 / t2) * v)
            result = TorchVariable(torch.div, **options).call_function(tx, args, {})
            result = TorchVariable(torch.mul, **options).call_function(
                tx, [result, kwargs["value"]], {}
            )
            return self.call_method(tx, "add_", [result], {})
        elif name == "__contains__":
            # Rewrite __contains__ here so that downstream passes can trace through
            # without dealing with unbacked symbool. Roughly the code we translate is:
            # def __contains__(self, x):
            #     return (x == self).any().item()
            result = TorchVariable(torch.eq, **options).call_function(
                tx, [self, args[0]], {}
            )
            result = TorchVariable(torch.any, **options).call_function(tx, [result], {})
            return result.call_method(tx, "item", [], {})
        elif name == "redistribute":
            # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
            # and rewrite args to have only proxyable args, then insert call_function
            args_as_value = [x.as_python_constant() for x in args]
            kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

            def redistribute_fn_with_prim_types(x):
                return x.redistribute(*args_as_value, **kwargs_as_value)

            # attach the same function name for better debugging
            redistribute_fn_with_prim_types.__name__ = f"prim_{name}"

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    redistribute_fn_with_prim_types,
                    *proxy_args_kwargs([self], {}),
                ),
                **options,
            )
        elif name == "register_hook":
            # see [On tensor.register_hook]
            assert len(args) == 1
            fn_var = args[0]
            if not isinstance(
                fn_var,
                (
                    variables.functions.FunctoolsPartialVariable,
                    variables.UserFunctionVariable,
                    variables.TorchVariable,
                    variables.NNModuleVariable,
                ),
            ):
                unimplemented("Unexpected callable type passed to register_hook")

            # Guards from the fn_var
            options.update(VariableTracker.propagate(fn_var))

            if isinstance(fn_var, variables.NestedUserFunctionVariable):
                # NestedUserFunctionVariable don't carry their fn, but reconstruction builds it
                # This should not be onerous to support when needed.
                unimplemented("NYI - lambda variables as hooks")
            elif isinstance(fn_var, variables.functions.FunctoolsPartialVariable):
                fn = fn_var.as_python_constant()
            else:
                fn = fn_var.fn

            handle_variable = variables.user_defined.RemovableHandleVariable(
                mutable_local=variables.base.MutableLocal(),
                **options,
            )

            if not self.source:
                # Intermediary
                src = fn_var.source
                if (
                    not src
                    and isinstance(fn_var, variables.functions.FunctoolsPartialVariable)
                    and fn_var.func.source
                ):
                    src = fn_var.func.source

                if not src:
                    unimplemented("No source for register_hook target fn")

                tx.output.guards.add(src.make_guard(GuardBuilder.ID_MATCH))

                if not compiled_autograd.compiled_autograd_enabled:
                    # TODO(voz):
                    # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                    # python state.
                    # We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run
                    # them in a compiled bwd without re-entering dynamo as compiled_autograd does.
                    #
                    # Discussion point 1 - Should we bypass this if nopython/fullgraph = True?
                    #   No. Because this was going to be a graph break anyway - this check does not
                    # introduce new graph breaks where there were none.
                    #
                    # Discussion point 2 - Should we defer this check to backwards?
                    #   No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user
                    # would have no recourse - their forward traces just fine, but will fail at backwards unless
                    # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today)
                    # then they have nothing they can do except disable compile.
                    unimplemented(
                        "Compilation of intermediate hooks requires compiled autograd"
                    )

                # This wraps our user provided fn with a function that intercedes and
                # uses our `invoke` higher order op to record a hook invocation in bwd graph.
                fn = functools.partial(trace_wrapped, fn=fn)

                def _register_hook_trampoline(tensor):
                    tensor.register_hook(fn)
                    return tensor

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function",
                        _register_hook_trampoline,
                        (self.as_proxy(),),
                        {},
                    ),
                    **options,
                )

            # Tensor with a source: record the hook as a side effect instead
            # of putting it in the graph.
            tx.output.side_effects.register_hook(self, fn_var, handle_variable)
            return handle_variable
        elif name == "requires_grad_" and self.as_proxy().node.meta[
            "example_value"
        ].requires_grad != (args[0].value if len(args) > 0 else True):
            # Only graph break when requires_grad_ would actually flip the flag.
            unimplemented("Tensor.requires_grad_")

        else:
            # Convert x.new(torch.Size) into x.new_empty(torch.Size),
            # as Tensor.new acts differently with a Size input versus a tuple input.
            if name == "new" and len(args) == 1 and isinstance(args[0], SizeVariable):
                name = "new_empty"
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )

    def rename(self, tx, name):
        """Rename the underlying FX node as well as the tracker."""
        self.proxy.node._rename(name)
        return super().rename(tx, name)
|
|
|
|
|
|
class SymNodeVariable(VariableTracker):
    """
    Represents a symbolic size, e.g., as returned by tensor.size(0)
    """

    @classmethod
    def create(cls, tx, proxy, sym_num, **options):
        """Build a tracker for the symbolic value produced by ``proxy``.

        If ``sym_num`` is ``None`` it is computed from the node via
        ``get_fake_value``. The value is stored as the node's
        ``example_value``. An integer result (``int``/``sympy.Integer``)
        collapses to a ``ConstantVariable``.
        """
        # Resolve sym_num first so the consistency check below compares
        # against the real value; previously a None sym_num always failed the
        # assert whenever the node already carried an example_value.
        if sym_num is None:
            sym_num = get_fake_value(proxy.node, tx)
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == sym_num
        proxy.node.meta["example_value"] = sym_num

        if isinstance(sym_num, (sympy.Integer, int)):
            return ConstantVariable.create(int(sym_num))

        return SymNodeVariable(proxy, sym_num, **options)

    def __init__(self, proxy, sym_num, **kwargs):
        super().__init__(**kwargs)
        self.proxy = proxy
        # TODO: Should we allow non SymTypes here? Today it is allowed
        self.sym_num = sym_num

    def python_type(self):
        """Python type the symbolic value stands for (e.g. int for a SymInt)."""
        if isinstance(self.sym_num, SymTypes):
            return self.sym_num.node.pytype
        else:
            return type(self.sym_num)

    def unpack_var_sequence(self, tx):
        # Defer to the base implementation; the result was previously dropped.
        return super().unpack_var_sequence(tx)

    def as_proxy(self):
        """Return the FX proxy that produces this value."""
        return self.proxy

    def evaluate_expr(self, output_graph=None):
        """Concretize the symbolic value via ``guard_scalar`` (installs a guard)."""
        return guard_scalar(self.sym_num)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Record ``<symnode>.<name>(*args, **kwargs)`` as a graph node."""
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self, args, kwargs.values())

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self] + list(args), kwargs),
            ),
            **options,
        )
|
|
|
|
|
|
class NumpyNdarrayVariable(TensorVariable):
    """
    Traced stand-in for an ``np.ndarray`` whose storage is a torch Tensor
    (wrapped as ``torch._numpy.ndarray``). Produced by ``Tensor.numpy()``.
    """

    @staticmethod
    def create(tx, proxy, **options):
        """Wrap ``proxy`` as a ``NumpyNdarrayVariable``."""
        from .builder import wrap_fx_proxy_cls

        return wrap_fx_proxy_cls(
            target_cls=NumpyNdarrayVariable, tx=tx, proxy=proxy, **options
        )

    def var_getattr(self, tx, name):
        """Trace ``ndarray.<name>``.

        ``T``/``real``/``imag`` stay ndarrays; ``ndim``/``itemsize`` are always
        constants; ``shape``/``stride``/``size`` are constants when static and
        otherwise recorded in the graph. ``base``/``flags``/``dtype`` graph
        break; anything else raises ``NotImplementedError``.
        """
        # NB: This INTENTIONALLY does not call super(), because there is
        # no intrinsic reason ndarray properties are related to Tensor
        # properties. The inheritance here is for implementation sharing.
        from ..utils import numpy_attr_wrapper
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self)
        fake_tensor = self.as_proxy().node.meta["example_value"]
        fake_ndarray = tnp.ndarray(fake_tensor)

        def attr_node():
            # Graph node computing `numpy_attr_wrapper(self, name)`.
            return tx.output.create_proxy(
                "call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}
            )

        def insert_into_graph():
            return wrap_fx_proxy(tx, attr_node(), **options)

        if name in ["T", "real", "imag"]:
            return NumpyNdarrayVariable.create(tx, attr_node(), **options)

        # These are awkward to implement. The standard playbook for torch._numpy
        # interop is to trace a call into the torch._numpy wrapper which works for
        # Tensor operations. However, we don't want to do this for calls
        # that don't return Tensors, because in those cases we may not want
        # to trace the attribute access into the graph at all (it is sort
        # of harmless to do so, because AOTAutograd will eliminate them,
        # but it's best not to trace them in to begin with.) But in any
        # case, tracing these into the graph is like trying to fit a square
        # peg into a round hole; best not to do it. So instead we
        # painstakingly implement these by hand
        #
        # NB: only ALWAYS specialized attributes can go here; notably,
        # size/shape not allowed!
        if name in ("ndim", "itemsize"):
            return ConstantVariable.create(getattr(fake_ndarray, name), **options)

        if name in ("shape", "stride"):
            dims = getattr(fake_ndarray, name)
            if free_symbols(dims):
                return insert_into_graph()
            return ConstantVariable.create(tuple(int(d) for d in dims), **options)

        if name == "size":
            element_count = fake_ndarray.size
            if free_symbols(element_count):
                return insert_into_graph()
            return ConstantVariable.create(int(element_count), **options)

        if name in ["base", "flags", "dtype"]:
            unimplemented(f"TODO: add support for ndarray.{name}")

        raise NotImplementedError()

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace ``ndarray.<name>(...)`` through ``numpy_method_wrapper``.

        ``__len__``, ``size`` and ``tolist`` are handled by ``TensorVariable``.
        """
        options = VariableTracker.propagate([[self]], [args], [list(kwargs.values())])
        from ..utils import numpy_method_wrapper

        tensor_handled = ("__len__", "size", "tolist")
        if name in tensor_handled:
            # delegate back to TensorVariable
            return super().call_method(tx, name, args, kwargs)

        wrapped_method = numpy_method_wrapper(name)
        node = tx.output.create_proxy(
            "call_function",
            wrapped_method,
            *proxy_args_kwargs([self] + list(args), kwargs),
        )
        return NumpyNdarrayVariable.create(tx, node, **options)

    def python_type(self):
        """Report ``np.ndarray`` as the Python type."""
        return np.ndarray
|
|
|
|
|
|
class UnspecializedPythonVariable(TensorVariable):
    """
    A python float/int that was not specialized, traced as a 1-element tensor.
    """

    def __init__(
        self, proxy: torch.fx.Proxy, *, raw_value=None, need_unwrap=True, **kwargs
    ):
        """``raw_value`` keeps the original Python number; every other kwarg
        is forwarded to ``TensorVariable``."""
        super().__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        """Re-type a ``TensorVariable`` as an ``UnspecializedPythonVariable``,
        copying all of its instance attributes."""
        copied_attrs = dict(vars(tensor_variable))
        return UnspecializedPythonVariable(
            **copied_attrs,
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )

    def as_specialized(self, tx):
        """Turn this back into a constant holding ``raw_value``.

        Graph args whose source is this variable's source are erased, and
        volatile guards are switched to ``CONSTANT_MATCH``.
        """
        matching_args = (
            graph_arg
            for graph_arg in tx.output.graphargs
            if graph_arg.source is self.source
        )
        for graph_arg in matching_args:
            graph_arg.erase()

        for volatile_guard in filter(lambda guard: guard.is_volatile, self.guards):
            volatile_guard.create_fn = GuardBuilder.CONSTANT_MATCH

        return ConstantVariable.create(value=self.raw_value, guards=self.guards)
|
|
|
|
|
|
class FakeItemVariable(TensorVariable):
    """Unspecialized python value whose raw value must stay hidden.

    Used when ``.item()`` is called on a FakeTensor, where no real value
    exists to expose.
    """

    def __init__(self, proxy: torch.fx.Proxy, *, need_unwrap=False, **kwargs):
        super().__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        """Re-type a ``TensorVariable`` as a ``FakeItemVariable``, copying all
        of its instance attributes."""
        copied_attrs = dict(vars(tensor_variable))
        return FakeItemVariable(**copied_attrs)
|
|
|
|
|
|
class TensorSubclassVariable(VariableTracker):
    """A torch.Tensor subclass (the class object itself) seen during tracing."""

    def __init__(self, value, *args, **kwargs):
        # `value` is the subclass type; set before base init, as originally.
        self.value = value
        super().__init__(*args, **kwargs)

    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        """Calling the subclass on a single traced tensor wraps that tensor in
        a ``TensorWithTFOverrideVariable`` bound to the subclass's
        ``__torch_function__``; any other call uses the default handling."""
        wraps_single_tensor = len(args) == 1 and isinstance(args[0], TensorVariable)
        if not wraps_single_tensor:
            return super().call_function(tx, args, kwargs)

        from .builder import VariableBuilder
        from .torch_function import TensorWithTFOverrideVariable

        tf_source = AttrSource(self.source, "__torch_function__")
        torch_fn = VariableBuilder(tx, tf_source)(self.value.__torch_function__)
        return TensorWithTFOverrideVariable.create(tx, args[0], torch_fn, self.value)
|