pytorch/torch/_dynamo/variables/tensor.py
Michael Voznesensky 02f6a8126e Support a simple subset of functions as backward hooks on intermediate tensors (#109537)
The main thrust of the initial effort here was to capture `register_hook` calls on tensors in compile regions. The first part of this was done in https://github.com/pytorch/pytorch/pull/108903, where we added support for `register_hook` on input tensors.

The distinction between inputs and intermediaries comes down to implementation differences.

There are 2 kinds of hooks:

1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries and outputs).

Note: Because dynamo's handling of residuals makes outputs straightforward, they could actually be handled as if they were inputs; for the purposes of this PR, however, we refer to hooks as either hooks on inputs (sourced) or hooks on intermediaries (not sourced). A sketch of the distinction follows.
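
For concreteness, a minimal sketch of the distinction (names are illustrative; the intermediary case also needs the compiled-autograd constraint described below):

```python
import torch

def scale_grad(grad):
    return grad * 2

@torch.compile(backend="eager")
def fn(x, w):
    # `x` and `w` are frame inputs: they have sources, so hooks on them are
    # handled by the residual-based mechanism from the earlier PR.
    x.register_hook(scale_grad)
    h = x * w  # `h` is an intermediary: created inside the region, no source
    # Hooks on `h` are what this PR handles; they additionally require
    # compiled autograd (see the constraints below).
    h.register_hook(scale_grad)
    return h.sum()
```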

**The plan:**

For tensors w/ a source: (The PR above)
We record registered hooks, store them as globals, and associate them with the tensor in residuals. When dynamo creates the frame that stitches our PT2-modified bytecode together with the original eager code, the generated bytecode calls register_hook, roughly as sketched below. Registering hooks in residuals is sound because (a) it happens right after a PT2 frame region ends, and (b) we know the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame, so it is guaranteed to be around to invoke register_hook on. As long as we guard on the identity of the lifted function, this is sound.
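
Conceptually, the stitched frame for the sourced case behaves roughly like the hand-written approximation below; `__compiled_fn_0` and `__hook_0` are placeholder names for the compiled region and the globally stashed hook, not actual generated names.

```python
# Rough eager-level picture of the bytecode dynamo emits for a hook registered
# on an input tensor; this is a simplification, not the literal codegen.
def transformed_frame(x):
    out = __compiled_fn_0(x)   # run the PT2-compiled portion of the frame
    x.register_hook(__hook_0)  # residual: register the stashed hook in eager,
                               # guarded on the identity of __hook_0
    return out
```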

For tensors w/o a source: (This PR)

Ostensibly, the most correct and complete solution would be to smuggle hooks into a runtime wrapper in aot_autograd, where everything the hooks close over is lifted to inputs as necessary and passed alongside the user-provided function. This is necessary so that we can properly trace out and capture all the mutations within the user-defined hook at backward time.

That is too complicated for a first step, so we limited the scope of this initial PR to a simple subset of hooks (see the sketch after this list):

- Hooks must have a source (i.e., already known to us; not a lambda or a function defined inside the compiled region)
- We must be tracing under compiled autograd
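
For example, something like the following satisfies both constraints, assuming the `torch._dynamo.compiled_autograd.enable(compiler_fn)` context manager is the way compiled autograd gets turned on (a sketch of one possible setup, not the only one):

```python
import torch
from torch._dynamo import compiled_autograd

def my_hook(grad):                  # module-level fn, so it has a source
    return grad * 3

@torch.compile(backend="eager")
def fn(x):
    y = x * x                       # intermediary tensor (no source)
    y.register_hook(my_hook)        # allowed: sourced fn + compiled autograd
    return y.sum()

def backward_compiler(gm):
    # compile the backward graph that compiled_autograd generates
    return torch.compile(gm, backend="eager")

x = torch.randn(4, requires_grad=True)
with compiled_autograd.enable(backward_compiler):
    out = fn(x)                     # forward traced with compiled autograd on
    out.backward()                  # hook runs inside the compiled backward
```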

**The flow**:

We use the HOP added in https://github.com/pytorch/pytorch/pull/109690/files, referred to below as the HOP.

1) We intercept register_hook calls and wrap the user-defined fn in the HOP
2) We write a `_register_hook_trampoline` into the graph: a local helper that registers the wrapped hook on the tensor and is invoked as a call_function in the dynamo graph (see the sketch after this list)
3) aot_autograd inlines through it during its trace, and sees the HOP
4) the HOP preserves itself in the graph - it does not get traced into
5) During backwards, compiled_autograd installs the HOP under a hook call
6) When compiled_autograd enters compilation over its generated graph, dynamo traces the contents of the hook
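
In code, steps 1 and 2 amount to roughly the following (a condensed sketch of the `register_hook` branch in `TensorVariable.call_method` further down in this file; `user_hook_fn` stands in for the hook extracted from the user's call):

```python
import functools
from torch._dynamo._trace_wrapped_higher_order_op import trace_wrapped

# Wrap the user hook so that, at backward time, the HOP records a hook
# invocation in the compiled_autograd graph rather than being traced into.
wrapped = functools.partial(trace_wrapped, fn=user_hook_fn)

def _register_hook_trampoline(tensor):
    tensor.register_hook(wrapped)
    return tensor

# dynamo then emits: call_function(_register_hook_trampoline, (tensor_proxy,), {})
```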

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109537
Approved by: https://github.com/ezyang
2023-10-11 01:35:37 +00:00


import functools
import inspect
import operator
import types
from typing import Dict, List
try:
import numpy as np
except ModuleNotFoundError:
np = None
import sympy
import torch._numpy as tnp
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
from torch.fx.experimental.symbolic_shapes import free_symbols, guard_scalar, SymTypes
from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import (
fqn,
get_custom_getattr,
get_fake_value,
get_real_value,
guard_if_dyn,
object_has_getattribute,
product,
proxy_args_kwargs,
tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable
supported_tensor_comparison_ops = {
">": operator.gt,
"<": operator.lt,
">=": operator.ge,
"<=": operator.le,
"==": operator.eq,
"!=": operator.ne,
}
supported_const_comparison_ops = {
"is": operator.is_,
"is not": operator.is_not,
"==": operator.eq,
"!=": operator.ne,
}
class TensorVariable(VariableTracker):
"""A torch.Tensor input or an intermediate value in the FX graph"""
_nonvar_fields = [
"proxy",
"dtype",
"device",
"layout",
"ndim",
"size",
"stride",
"requires_grad",
"is_quantized",
"is_contiguous",
]
def get_real_value(self):
"""
Get the actual value represented by this variable if computation is run
using the user-provided inputs.
NOTE: this runs actual tensor computation and may be
slow and memory-intensive.
"""
return get_real_value(self.proxy.node, self.proxy.tracer)
def __init__(
self,
proxy: torch.fx.Proxy,
*,
dtype,
device,
layout,
ndim,
requires_grad,
is_quantized,
is_sparse,
class_type,
size=None,
stride=None,
is_contiguous=None,
specialized_value=None,
**kwargs,
):
super().__init__(**kwargs)
self.proxy = proxy
self.dtype = dtype
self.device = device
self.layout = layout
self.ndim = ndim
self.size = size
self.stride = stride
self.requires_grad = requires_grad
self.is_quantized = is_quantized
self.is_contiguous = is_contiguous
self.is_sparse = is_sparse
self.class_type = class_type
self.specialized_value = specialized_value
def as_proxy(self):
return self.proxy
def python_type(self):
return self.class_type
def call_isinstance(self, tensor_type):
def check_type(ty):
if ty not in tensortype_to_dtype:
return issubclass(self.python_type(), ty)
dtypes = tensortype_to_dtype[ty]
return self.dtype in dtypes
if type(tensor_type) is tuple:
return any(check_type(ty) for ty in tensor_type)
else:
return check_type(tensor_type)
@staticmethod
def specialize(value: torch.Tensor):
props = {
"dtype": value.dtype,
"device": value.device,
"layout": value.layout,
"ndim": int(value.ndim),
"requires_grad": value.requires_grad,
"is_quantized": value.is_quantized,
"is_sparse": value.is_sparse,
"class_type": type(value),
}
if not free_symbols(value):
# this is a fully static shape, and the keys on props here inform specialization.
# We have to cast to int here, because these might get accessed as ConstantVariable, which has
# a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
# already. We could remove the discrepancy here, by having ConstantVariable be more permissive for
# constant backed SymInts, but that assert being strict has led to some good signal in hunting bugs, and
# I'd like to keep it around for now.
props["size"] = tuple([int(s) for s in value.size()])
props["stride"] = tuple(value.stride())
props["is_contiguous"] = tuple(
[
x
for x in torch._prims_common._memory_formats
if value.is_contiguous(memory_format=x)
]
)
return props
def dynamic_getattr(self, tx, name):
if not self.source:
raise NotImplementedError()
# For local source, we associate the real value. We use this real value
# for implementing getattr fallthrough on the variable tracker base class.
# Note - this scope construction is mirrored in guards
# A subsequent PR will introduce a util.
scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
try:
# We raise in case we hit a TypeError bug w/ SuperSource.
# SuperSource has bugs in it atm, and can produce code like
# eval("super(L['mod'].model.model.encoder.embed_positions.forward__class__,
# L['mod'].model.model.encoder.embed_positions)", scope)
# Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
_input_associated_real_value = eval(self.source.name(), scope)
except Exception:
raise NotImplementedError()
if _input_associated_real_value is None:
raise NotImplementedError()
if object_has_getattribute(_input_associated_real_value):
raise NotImplementedError()
if get_custom_getattr(_input_associated_real_value):
raise NotImplementedError()
real_value = getattr(_input_associated_real_value, name)
if callable(real_value):
# Callables have more nuanced handling, and we should let the existing system delegate here.
# Raising was the past behavior, so it should always be sound to fall back.
# Note - at a certain point we may want to handle
raise NotImplementedError()
from ..guards import GuardBuilder
from .builder import VariableBuilder
attr_source = AttrSource(self.source, name)
has_attr_guard = attr_source.make_guard(GuardBuilder.HASATTR)
return (
VariableBuilder(tx, attr_source)(real_value)
.add_options(self)
.add_guard(has_attr_guard)
)
def var_getattr(self, tx, name):
from . import ConstantVariable, TorchVariable
if tx.strict_checks_enabled:
if name in self._strict_mode_banned_ops():
unimplemented(f"Illegal getattr invocation {name} in strict mode")
result = None
options = VariableTracker.propagate(self)
if name == "ndim" and self.ndim is not None:
result = ConstantVariable.create(self.ndim, **options)
elif name == "dtype" and self.dtype is not None:
result = TorchVariable(self.dtype, **options)
elif name == "device" and self.device is not None:
result = TorchVariable(self.device, **options)
elif name == "layout" and self.layout is not None:
result = TorchVariable(self.layout, **options)
elif name == "is_cuda" and self.device is not None:
result = ConstantVariable.create(self.device.type == "cuda", **options)
elif name == "shape" and self.size is not None:
sizes = [variables.ConstantVariable.create(x) for x in self.size]
result = SizeVariable(sizes, **options)
elif name == "requires_grad" and self.requires_grad is not None:
result = ConstantVariable.create(self.requires_grad, **options)
elif name == "is_quantized" and self.is_quantized is not None:
result = ConstantVariable.create(self.is_quantized, **options)
elif name == "is_sparse" and self.is_sparse is not None:
result = ConstantVariable.create(self.is_sparse, **options)
elif name == "shape" and self.size is None:
result = self.call_method(tx, "size", [], {})
elif name == "ndim" and self.ndim is None:
result = self.call_method(tx, "dim", [], {})
elif name == "data":
result = self.call_method(tx, "detach", [], {})
if name == "__class__":
return TorchVariable(self.python_type(), **options)
# Add a guard for type matching, these guards are checked before tensor guards
# In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
# <tensor> is later changed to another type
if result is not None and self.source is not None:
result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
# It's hard to get an inplace view (metadata mutation) on a graph input to work properly across
# dynamo/aot/inductor, so just fall back.
if self.source is not None and hasattr(torch.ops.aten, name):
fn = getattr(torch.ops.aten, name)
if (
hasattr(fn, "overloads")
and hasattr(fn, fn.overloads()[0])
and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags
):
# Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc.
return variables.misc.DelayGraphBreakVariable()
# For attributes (not methods) that were not caught in the special handling above,
# (e.g. tensor.real), we handle these generically, assuming that the output type is
# a tensor.
if result is None:
def try_generic_attr_handling():
from .builder import wrap_fx_proxy
from .misc import GetAttrVariable
try:
static_attr = inspect.getattr_static(torch.Tensor, name)
except AttributeError:
return None
# Make sure this is an attribute, not a method.
# type(torch.Tensor.H) should be "getset_descriptor"
# This is because of the CPython implementation; see THPVariableType:
# these attributes are implemented under tp_getset, which appear
# as `getset_descriptor`s, (compared to, say, methods which appear
# as `method_descriptor`s)
if type(static_attr) != types.GetSetDescriptorType:
return None
return wrap_fx_proxy(
tx=tx,
proxy=GetAttrVariable.create_getattr_proxy(self.as_proxy(), name),
**options,
)
result = try_generic_attr_handling()
if result is None:
result = self.dynamic_getattr(tx, name)
if result is None:
raise NotImplementedError()
return result
def has_unpack_var_sequence(self, tx):
return self.ndim > 0
def unpack_var_sequence(self, tx, idxes=None):
from .builder import wrap_fx_proxy
options = VariableTracker.propagate(self)
if idxes is None:
if self.size:
length = self.size[0]
else:
dyn_length = self.call_method(
tx, "size", [ConstantVariable.create(0)], {}
)
# SymNodeVariable for symbolic sizes, ConstantVariable for constants OR values produced through
# symbolic_shapes, but that end up as int/sympy.Integer
assert isinstance(dyn_length, (SymNodeVariable, ConstantVariable))
if isinstance(dyn_length, SymNodeVariable):
length = dyn_length.evaluate_expr(tx.output)
else:
length = dyn_length.value
idxes = range(length)
return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]
def _strict_mode_banned_ops(self):
return torch._dynamo.config._autograd_backward_strict_mode_banned_ops
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if tx.strict_checks_enabled:
if name in self._strict_mode_banned_ops():
unimplemented(f"Illegal method invocation {name} in strict mode")
from . import ConstantVariable, TorchVariable, TupleVariable
from .builder import wrap_fx_proxy
kwargs = dict(kwargs)
options = VariableTracker.propagate(self, args, kwargs.values())
if name in ("stride", "size"):
dim_var = None
if len(args) == 1:
dim_var = args[0]
elif "dim" in kwargs:
dim_var = kwargs["dim"]
else:
assert not args and not kwargs, f"Tensor.{name}() unhandled args/kwargs"
dim = guard_if_dyn(dim_var)
def make_const_size_variable(x, **options):
return SizeVariable(
[ConstantVariable.create(y, **options) for y in x], **options
)
RetVariable = (
make_const_size_variable if name == "size" else ConstantVariable.create
)
# Technically, this should not be necessary, but I'm including it
# for enhanced BC, in case example_value is sometimes not set
# (it really should always be set though!)
if (r := getattr(self, name)) is not None:
if dim is None:
return RetVariable(r, **options)
else:
return ConstantVariable.create(r[dim], **options)
# It might still be constant! Consult the fake tensor and see
if (fake := self.proxy.node.meta.get("example_value")) is not None:
if dim is None:
fake_r = getattr(fake, name)()
if not free_symbols(fake_r):
# int conversion for safety, in case a SymInt refined
# to constant
return RetVariable(tuple(int(r) for r in fake_r), **options)
else:
fake_r = getattr(fake, name)(dim)
if not free_symbols(fake_r):
return ConstantVariable.create(int(fake_r), **options)
# Oops, it's not constant. Do the dynamic shapes path.
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
elif name in ("numel", "nelement"):
if self.size is not None:
return ConstantVariable.create(product(self.size), **options)
# It might still be constant! Consult the fake tensor and see
if (fake := self.proxy.node.meta.get("example_value")) is not None:
fake_r = fake.numel()
if not free_symbols(fake_r):
return ConstantVariable.create(int(fake_r), **options)
assert not kwargs, f"Tensor.{name}() unhandled kwargs"
# Oops, it's not constant. Do the dynamic shapes path.
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
"numel",
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
elif name in ("ndimension", "dim") and self.ndim is not None:
constant_result = ConstantVariable.create(self.ndim, **options)
elif name == "is_floating_point" and self.dtype is not None:
constant_result = ConstantVariable.create(
self.dtype.is_floating_point, **options
)
elif name == "is_contiguous" and self.is_contiguous is not None:
if "memory_format" in kwargs:
memory_format = kwargs.pop("memory_format").as_python_constant()
else:
memory_format = torch.contiguous_format
constant_result = ConstantVariable.create(
memory_format in self.is_contiguous, **options
)
elif (
name == "type"
and self.dtype is not None
and len(args) == 0
and isinstance(self.device, torch.device)
):
tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
0
]
if self.device.type == "cuda":
constant_result = ConstantVariable.create(
f"torch.cuda.{tensortype.__name__}", **options
)
else:
constant_result = ConstantVariable.create(
f"torch.{tensortype.__name__}", **options
)
elif (
name == "type"
and len(args) == 1
and fqn(type(args[0].as_python_constant())) == "torch.tensortype"
):
# torch.FloatTensor, etc. are all of type "torch.tensortype".
# torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
# So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
tensor_type = args[0].as_python_constant()
tensor_type_const = ConstantVariable.create(fqn(tensor_type), **options)
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self, tensor_type_const], kwargs),
),
**options,
)
elif name == "get_device" and isinstance(self.device, torch.device):
index = self.device.index if self.device.type != "cpu" else -1
constant_result = ConstantVariable.create(index, **options)
else:
constant_result = None
if constant_result:
assert not kwargs, f"Tensor.{name}() unhandled kwargs"
# TODO: I think this branch is dead
if len(args) == 1:
return constant_result.getitem_const(args[0])
elif args:
return TupleVariable(
[constant_result.getitem_const(a) for a in args], **options
)
return constant_result
elif name == "numpy":
if not config.trace_numpy:
unimplemented("Tensor.numpy(). config.trace_numpy is False")
if not np:
unimplemented("Tensor.numpy(). NumPy is not available")
assert not args, "Tensor.numpy() doesn't take args."
if self.layout != torch.strided:
raise TypeError(
f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first"
)
# We don't check that the tensor is on CPU when force is False, as this
# allows us to execute NumPy code on CUDA.
# We don't check that requires_grad=False as we are currently doing an
# unconditional detach.
# TODO: We may want to avoid detaching if `requires_grad=True`
# and `force=False` to allow computing gradients.
force = "force" in kwargs and kwargs["force"].as_python_constant()
proxy = tx.output.create_proxy(
"call_method", "detach", *proxy_args_kwargs([self], {})
)
if force:
# TODO Add resolve_conj and resolve_neg once we support complex tensors
proxy = tx.output.create_proxy(
"call_method", "cpu", *proxy_args_kwargs([self], {})
)
return NumpyNdarrayVariable.create(tx, proxy, **options)
elif name == "tolist":
from .builder import SourcelessBuilder
def tolist(tensor, sub_proxy):
def wrap(i, sub_proxy):
return SymNodeVariable.create(
tx,
sub_proxy.item(),
sym_num=tx.output.shape_env.create_unbacked_symint(),
)
if tensor.dtype not in [
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]:
unimplemented("Input tensor for tolist must be an integer tensor")
if tensor.dim() == 0:
return wrap(tensor, sub_proxy)
if tensor.dim() == 1:
return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]
return [
tolist(sub_tensor, sub_proxy=sub_proxy[i])
for i, sub_tensor in enumerate(tensor)
]
tensor = self.as_proxy().node.meta["example_value"]
out = tolist(tensor, self.as_proxy())
return SourcelessBuilder()(tx, out).add_options(options)
elif name in ("backward", "data_ptr"):
unimplemented(f"Tensor.{name}")
elif name == "item" and not config.capture_scalar_outputs:
unimplemented(f"Tensor.{name}")
elif name == "__len__":
return self.call_method(
tx, "size", [ConstantVariable.create(0, **options)], {}
)
elif name == "__setitem__":
key, value = args
def has_bool_key(v):
if isinstance(v, TensorVariable):
return v.dtype in (torch.bool, torch.int8)
elif isinstance(v, TupleVariable):
return any(has_bool_key(item) for item in v.items)
else:
return False
if (
not config.capture_dynamic_output_shape_ops
and has_bool_key(key)
and isinstance(value, TensorVariable)
and value.requires_grad
):
unimplemented(
"boolean masking setitem backwards requires dynamic shapes"
)
tx.output.guards.update(options["guards"])
tx.output.create_proxy(
"call_function",
operator.setitem,
*proxy_args_kwargs([self] + list(args), kwargs),
)
return ConstantVariable.create(None, **options)
elif name in ("resize_", "resize_as_"):
if "memory_format" in kwargs:
memory_format = kwargs["memory_format"].as_python_constant()
else:
memory_format = torch.contiguous_format
if name == "resize_":
self.size = args[0].as_python_constant()
self.is_contiguous = (memory_format,)
else:
assert isinstance(args[0], TensorVariable)
if self.size and args[0].size:
if (
self.size == args[0].size
or memory_format is torch.preserve_format
):
self.is_contiguous = args[0].is_contiguous
else:
self.size = args[0].size
self.stride = args[0].stride
self.ndim = args[0].ndim
self.is_contiguous = (memory_format,)
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
elif (
name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
):
result = TorchVariable(torch.mul, **options).call_function(
tx, args + [kwargs["alpha"]], {}
)
return self.call_method(tx, "add_", [result], {})
elif (
name == "addcdiv_"
and len(args) == 2
and len(kwargs) == 1
and "value" in kwargs
):
result = TorchVariable(torch.div, **options).call_function(tx, args, {})
result = TorchVariable(torch.mul, **options).call_function(
tx, [result, kwargs["value"]], {}
)
return self.call_method(tx, "add_", [result], {})
elif name == "__contains__":
# Rewrite __contains__ here so that downstream passes can trace through
# without dealing with unbacked symbool. Roughly the code we translate is:
# def __contains__(self, x):
# return (x == self).any().item()
result = TorchVariable(torch.eq, **options).call_function(
tx, [self, args[0]], {}
)
result = TorchVariable(torch.any, **options).call_function(tx, [result], {})
return result.call_method(tx, "item", [], {})
elif name == "redistribute":
# rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
# and rewrite args to have only proxyable args, then insert call_function
args_as_value = [x.as_python_constant() for x in args]
kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}
def redistribute_fn_with_prim_types(x):
return x.redistribute(*args_as_value, **kwargs_as_value)
# attach the same function name for better debugging
redistribute_fn_with_prim_types.__name__ = f"prim_{name}"
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
redistribute_fn_with_prim_types,
*proxy_args_kwargs([self], {}),
),
**options,
)
elif name == "register_hook":
# see [On tensor.register_hook]
assert len(args) == 1
fn_var = args[0]
if not isinstance(
fn_var,
(
variables.functions.FunctoolsPartialVariable,
variables.UserFunctionVariable,
variables.TorchVariable,
variables.NNModuleVariable,
),
):
unimplemented("Unexpected callable type passed to register_hook")
# Guards from the fn_var
options.update(VariableTracker.propagate(fn_var))
if isinstance(fn_var, variables.NestedUserFunctionVariable):
# NestedUserFunctionVariables don't carry their fn, but reconstruction builds it
# This should not be onerous to support when needed.
unimplemented("NYI - lambda variables as hooks")
elif isinstance(fn_var, variables.functions.FunctoolsPartialVariable):
fn = fn_var.as_python_constant()
name = fn_var.func.fn.__name__
else:
fn = fn_var.fn
name = fn_var.fn.__name__
handle_variable = variables.user_defined.RemovableHandleVariable(
mutable_local=variables.base.MutableLocal(),
**options,
)
if not self.source:
# Intermediary
src = fn_var.source
if (
not src
and isinstance(fn_var, variables.functions.FunctoolsPartialVariable)
and fn_var.func.source
):
src = fn_var.func.source
if not src:
unimplemented("No source for register_hook target fn")
tx.output.guards.add(src.make_guard(GuardBuilder.ID_MATCH))
if not compiled_autograd.compiled_autograd_enabled:
# TODO(voz):
# We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
# python state.
# We *Must* be in compiled_autograd here because backward hooks can contain anything, and it is unsafe to run
# them in a compiled bwd without re-entering dynamo as compiled_autograd does.
#
# Discussion point 1 - Should we bypass this if nopython/fullgraph = True?
# No. Because this was going to be a graph break anyway - this check does not
# introduce new graph breaks where there were none.
#
# Discussion point 2 - Should we defer this check to backwards?
# No. Because compiled autograd is not yet ready for prime time. As such, if we defer, a user
# would have no recourse - their forward traces just fine, but will fail at backwards unless
# compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today)
# then they have nothing they can do except disable compile.
unimplemented(
"Compilation of intermediate hooks requires compiled autograd"
)
# This wraps our user provided fn with a function that intercedes and
# uses our `invoke` higher order op to record a hook invocation in bwd graph.
fn = functools.partial(trace_wrapped, fn=fn)
def _register_hook_trampoline(tensor):
tensor.register_hook(fn)
return tensor
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_function",
_register_hook_trampoline,
(self.as_proxy(),),
{},
),
**options,
)
tx.output.side_effects.register_hook(self, fn_var, handle_variable)
return handle_variable
elif name == "requires_grad_" and self.as_proxy().node.meta[
"example_value"
].requires_grad != (args[0].value if len(args) > 0 else True):
unimplemented("Tensor.requires_grad_")
else:
# Convert x.new(torch.Size) into x.new_empty(torch.Size),
# as Tensor.new acts differently with a Size input versus a tuple input.
if name == "new" and len(args) == 1 and isinstance(args[0], SizeVariable):
name = "new_empty"
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
def rename(self, tx, name):
self.proxy.node._rename(name)
return super().rename(tx, name)
class SymNodeVariable(VariableTracker):
"""
Represents a symbolic size, e.g., as returned by tensor.size(0)
"""
@classmethod
def create(cls, tx, proxy, sym_num, **options):
if "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == sym_num
if sym_num is None:
sym_num = get_fake_value(proxy.node, tx)
proxy.node.meta["example_value"] = sym_num
if isinstance(sym_num, (sympy.Integer, int)):
return ConstantVariable.create(int(sym_num))
return SymNodeVariable(proxy, sym_num, **options)
def __init__(self, proxy, sym_num, **kwargs):
super().__init__(**kwargs)
self.proxy = proxy
# TODO: Should we allow non SymTypes here? Today it is allowed
self.sym_num = sym_num
def python_type(self):
if isinstance(self.sym_num, SymTypes):
return self.sym_num.node.pytype
else:
return type(self.sym_num)
def unpack_var_sequence(self, tx):
super().unpack_var_sequence(tx)
def as_proxy(self):
return self.proxy
def evaluate_expr(self, output_graph=None):
return guard_scalar(self.sym_num)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .builder import wrap_fx_proxy
options = VariableTracker.propagate(self, args, kwargs.values())
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
class TensorWithTFOverrideVariable(VariableTracker):
"""
Represents a tensor subclass instance with a __torch_function__ override.
"""
@staticmethod
def create(
tx,
tensor_variable,
orig_tensor_variable_source,
torch_function_fn,
subclass_type,
**kwargs,
):
var = TensorWithTFOverrideVariable(
tensor_variable,
orig_tensor_variable_source,
torch_function_fn,
subclass_type,
**kwargs,
)
# stash the subclass type to rewrap an output tensor if needed
tx.output.install_global(var.global_class_name(), subclass_type)
return var
def __init__(
self,
tensor_variable,
orig_tensor_variable_source,
subclass_torch_function__func,
subclass_type,
**kwargs,
):
super().__init__(**kwargs)
self.tensor_variable = tensor_variable
self.orig_tensor_variable_source = orig_tensor_variable_source
self.subclass_torch_function__func = subclass_torch_function__func
self.subclass_type = subclass_type
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
# This code block implements inlining the __torch_function__ override
# of `call_method`.
from . import GetAttrVariable
options = VariableTracker.propagate(self, args, kwargs.values())
# insert unwrapped version of self as the first argument
# TODO: This is wrong! When you call the internal __torch_function__,
# you still get the wrapped version of self, and if you call functions
# inside __torch_function__, they should come back here. If we unwrap
# the tensor immediately, that will not happen.
# See https://github.com/pytorch/torchdynamo/issues/1951
args = list(args)
args.insert(0, self.tensor_variable)
func_var = GetAttrVariable(self.tensor_variable, name)
unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
tx,
func_var,
self.orig_tensor_variable_source,
self.subclass_torch_function__func,
self.subclass_type,
options,
args,
kwargs,
)
# TODO(future PR): implement rewrapping conditional on method presence
# in `torch.overrides.get_default_nowrap_function()`. It's unclear how
# to do this easily in the current codebase since the resolution of
# `GetAttrVariable` depends on the type of the underlying object.
return TensorWithTFOverrideVariable(
unwrapped,
self.orig_tensor_variable_source,
self.subclass_torch_function__func,
self.subclass_type,
)
def global_class_name(self):
return f"__subclass_{self.subclass_type.__name__}"
@staticmethod
def inline_torch_function_unwrapped(
tx,
original_func_var,
tensor_with_tf_override_source,
tf_func,
subclass_type,
options,
args,
kwargs,
):
"""
This function inlines the `__torch_function__` override for `original_func_var`.
For example, if the user code is
x1 = torch.sigmoid(x0)
And `x0` has an override, then:
* `original_func_var` will be a `VariableTracker` object wrapping `torch.sigmoid`
* `tensor_with_tf_override_source` will be the `Source` object from
the original tensor override instance in the beginning of the program
* `tf_func` will be the custom `__torch_function__` function
* `subclass_type` will be `type(x0)`
The caller is expected to properly massage args and kwargs before
passing them into this function.
The caller is responsible for wrapping the return value, if needed.
"""
from . import UserDefinedClassVariable
from .builder import TupleVariable, VariableBuilder
source = AttrSource(
AttrSource(tensor_with_tf_override_source, "__torch_function__"),
"__func__",
)
tf_func_var = VariableBuilder(tx, source)(tf_func)
type_var = UserDefinedClassVariable(subclass_type, **options)
# signature:
# def __torch_function__(cls, func, types, args=(), kwargs=None):
tf_args = (
type_var, # cls
original_func_var, # func
(type_var,), # types
TupleVariable(args), # args
kwargs, # kwargs
)
# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
with torch._C.DisableTorchFunctionSubclass():
return tx.inline_user_function_return(tf_func_var, tf_args, {})
class NumpyNdarrayVariable(TensorVariable):
"""
Represents an np.ndarray, but backed by a torch Tensor via torch._numpy.ndarray.
Use this for Tensor.numpy() calls.
"""
@staticmethod
def create(tx, proxy, **options):
from .builder import wrap_fx_proxy_cls
return wrap_fx_proxy_cls(
target_cls=NumpyNdarrayVariable,
tx=tx,
proxy=proxy,
**options,
)
def var_getattr(self, tx, name):
# NB: This INTENTIONALLY does not call super(), because there is
# no intrinsic reason ndarray properties are related to Tensor
# properties. The inheritance here is for implementation sharing.
from ..utils import numpy_attr_wrapper
from .builder import wrap_fx_proxy
result = None
options = VariableTracker.propagate(self)
example_value = self.as_proxy().node.meta["example_value"]
example_ndarray = tnp.ndarray(example_value)
def insert_into_graph():
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_function", numpy_attr_wrapper, (self.as_proxy(), name), {}
),
**options,
)
if name in ["T", "real", "imag"]:
proxy = tx.output.create_proxy(
"call_function",
numpy_attr_wrapper,
(self.as_proxy(), name),
{},
)
result = NumpyNdarrayVariable.create(tx, proxy, **options)
# These are awkward to implement. The standard playbook for torch._numpy
# interop is to trace a call into the torch._numpy wrapper which works for
# Tensor operations. However, we don't want to do this for calls
# that don't return Tensors, because in those cases we may not want
# to trace the attribute access into the graph at all (it is sort
# of harmless to do so, because AOTAutograd will eliminate them,
# but it's best not to trace them in to begin with.) But in any
# case, tracing these into the graph is like trying to fit a square
# peg into a round hole; best not to do it. So instead we
# painstakingly implement these by hand
#
# NB: only ALWAYS specialized attributes can go here; notably,
# size/shape not allowed!
elif name in ("ndim", "itemsize"):
return ConstantVariable.create(getattr(example_ndarray, name), **options)
elif name in ("shape", "stride"):
if not free_symbols(r := getattr(example_ndarray, name)):
return ConstantVariable.create(tuple(int(r) for r in r), **options)
return insert_into_graph()
elif name == "size":
if not free_symbols(r := example_ndarray.size):
return ConstantVariable.create(int(r), **options)
return insert_into_graph()
elif name in ["base", "flags", "dtype"]:
unimplemented(f"TODO: add support for ndarray.{name}")
if result is None:
raise NotImplementedError()
return result
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
options = VariableTracker.propagate([[self]], [args], [list(kwargs.values())])
from ..utils import numpy_method_wrapper
if name in ["__len__", "size"]:
# delegate back to TensorVariable
return super().call_method(tx, name, args, kwargs)
proxy = tx.output.create_proxy(
"call_function",
numpy_method_wrapper(name),
*proxy_args_kwargs([self] + list(args), kwargs),
)
return NumpyNdarrayVariable.create(tx, proxy, **options)
def python_type(self):
return np.ndarray
class UnspecializedPythonVariable(TensorVariable):
"""
This is a 1-element tensor that represents an unspecialized Python float/int.
"""
def __init__(self, proxy: torch.fx.Proxy, **kwargs):
raw_value = kwargs.pop("raw_value", None)
need_unwrap = kwargs.pop("need_unwrap", True)
super().__init__(proxy, **kwargs)
self.raw_value = raw_value
self.need_unwrap = need_unwrap
@classmethod
def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
# Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
return UnspecializedPythonVariable(
**dict(tensor_variable.__dict__),
raw_value=raw_value,
need_unwrap=need_unwrap,
)
def as_specialized(self, tx):
for graph_arg in tx.output.graphargs:
if graph_arg.source is self.source:
graph_arg.erase()
for g in self.guards:
if g.is_volatile:
g.create_fn = GuardBuilder.CONSTANT_MATCH
return ConstantVariable.create(value=self.raw_value, guards=self.guards)
class FakeItemVariable(TensorVariable):
"""An unspecialized python variable which prevents access to the underlying raw value.
This is needed if item is called on a FakeTensor."""
def __init__(self, proxy: torch.fx.Proxy, **kwargs):
need_unwrap = kwargs.pop("need_unwrap", False)
super().__init__(proxy, **kwargs)
self.need_unwrap = need_unwrap
@classmethod
def from_tensor_variable(cls, tensor_variable):
return FakeItemVariable(**dict(tensor_variable.__dict__))
class TensorSubclassVariable(VariableTracker):
def __init__(self, value, *args, **kwargs):
self.value = value
super().__init__(*args, **kwargs)
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
if len(args) == 1 and isinstance(args[0], TensorVariable):
return TensorWithTFOverrideVariable.create(
tx,
args[0],
args[0].source,
self.value.__torch_function__.__func__,
self.value,
)
return super().call_function(tx, args, kwargs)