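# VariableTracker subclasses that model torch.Tensor values during Dynamo
# tracing: concrete tensors (TensorVariable), symbolic sizes
# (DynamicShapeVariable), __torch_function__ subclasses
# (TensorWithTFOverrideVariable), and 1-element tensors standing in for
# unspecialized Python/NumPy scalars (UnspecializedPythonVariable,
# UnspecializedNumpyVariable, FakeItemVariable).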

import itertools
import operator
from typing import Dict, List

import torch.fx
import torch.random

from .. import config, variables
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource

from ..utils import (
    get_fake_value,
    get_real_value,
    product,
    proxy_args_kwargs,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import ShapeVariable, SizeVariable


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = [
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
    ]
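
    # These fields hold plain metadata rather than nested VariableTrackers;
    # the assumption is that VariableTracker's generic walk/clone machinery
    # skips anything listed in _nonvar_fields.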

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be
        slow and memory-intensive.
        """
        return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        dtype=None,
        device=None,
        layout=None,
        ndim=None,
        size=None,
        stride=None,
        requires_grad=None,
        is_quantized=None,
        is_contiguous=None,
        is_sparse=None,
        class_type=torch.Tensor,
        specialized_value=None,
        **kwargs,
    ):
        super(TensorVariable, self).__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.layout = layout
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.specialized_value = specialized_value

    def as_proxy(self):
        return self.proxy

    def python_type(self):
        return self.class_type

    def call_isinstance(self, tensor_type):
        def check_type(ty):
            if ty not in tensortype_to_dtype:
                return issubclass(self.python_type(), ty)

            dtypes = tensortype_to_dtype[ty]
            return self.dtype in dtypes

        if type(tensor_type) is tuple:
            return any([check_type(ty) for ty in tensor_type])
        else:
            return check_type(tensor_type)
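
    # Illustrative example (assuming tensortype_to_dtype maps legacy tensor
    # types to dtype tuples, e.g. torch.FloatTensor -> (torch.float32, ...)):
    # for a float32 tensor, call_isinstance(torch.FloatTensor) is answered
    # from the recorded self.dtype alone, without touching a real tensor.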

    @staticmethod
    def specialize(value: torch.Tensor):
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "layout": value.layout,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        if not config.dynamic_shapes:
            props["size"] = tuple(value.size())
            props["stride"] = tuple(value.stride())
            props["is_contiguous"] = tuple(
                [
                    x
                    for x in torch._prims_common._memory_formats
                    if value.is_contiguous(memory_format=x)
                ]
            )
        return props
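
    # Illustrative example (assuming static shapes): for a contiguous
    # torch.randn(2, 3), specialize() returns roughly
    #     {"dtype": torch.float32, "device": device("cpu"), "ndim": 2,
    #      "size": (2, 3), "stride": (3, 1), "is_contiguous": (...), ...}
    # and the props are then passed as keyword arguments to TensorVariable.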

    def var_getattr(self, tx, name):
        from . import ConstantVariable, TorchVariable

        result = None
        options = VariableTracker.propagate(self)
        if name == "ndim" and self.ndim is not None:
            result = ConstantVariable(self.ndim, **options)
        elif name == "dtype" and self.dtype is not None:
            result = TorchVariable(self.dtype, **options)
        elif name == "device" and self.device is not None:
            result = TorchVariable(self.device, **options)
        elif name == "layout" and self.layout is not None:
            result = TorchVariable(self.layout, **options)
        elif name == "is_cuda" and self.device is not None:
            result = ConstantVariable(self.device.type == "cuda", **options)
        elif name == "shape" and self.size is not None:
            sizes = [variables.ConstantVariable(x) for x in self.size]
            result = ShapeVariable(sizes, **options)
        elif name == "requires_grad" and self.requires_grad is not None:
            result = ConstantVariable(self.requires_grad, **options)
        elif name == "is_quantized" and self.is_quantized is not None:
            result = ConstantVariable(self.is_quantized, **options)
        elif name == "is_sparse" and self.is_sparse is not None:
            result = ConstantVariable(self.is_sparse, **options)
        elif name == "shape" and self.size is None:
            result = self.call_method(tx, "size", [], {})
        elif name == "ndim" and self.ndim is None:
            result = self.call_method(tx, "dim", [], {})
        elif name == "data":
            result = self.call_method(tx, "detach", [], {})
        elif name == "T":
            args = [variables.ConstantVariable(i) for i in range(self.ndim - 1, -1, -1)]
            result = self.call_method(tx, "permute", args, {})

        if name == "__class__":
            return TorchVariable(self.python_type(), **options)

        # Add a guard for type matching; these guards are checked before
        # tensor guards. Otherwise, a <tensor>.<attr> guard could be evaluated
        # first and break if <tensor> is later changed to another type.
        if result is not None and self.source is not None:
            result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))

        if result is None:
            raise NotImplementedError()

        return result

    def unpack_var_sequence(self, tx, idxes=None):
        from .builder import wrap_fx_proxy

        if idxes is None:
            if self.size:
                idxes = range(self.size[0])
            else:
                return super(TensorVariable, self).unpack_var_sequence(tx)
        options = VariableTracker.propagate(self)
        return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]
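
    # Illustrative: unpacking a TensorVariable of known size (4, ...) yields
    # four new variables wrapping the indexed proxies self.as_proxy()[0]
    # through self.as_proxy()[3].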

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable, TupleVariable
        from .builder import wrap_fx_proxy

        kwargs = dict(kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "stride" and self.stride is not None:
            constant_result = ConstantVariable(self.stride, **options)
        elif name == "size" and self.size is not None:
            sizes = [variables.ConstantVariable(x) for x in self.size]
            constant_result = SizeVariable(sizes, **options)
        elif name == "size" and self.size is None and config.dynamic_shapes:
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )
        elif name in ("numel", "nelement") and self.size is not None:
            constant_result = ConstantVariable(product(self.size), **options)
        elif name in ("ndimension", "dim") and self.ndim is not None:
            constant_result = ConstantVariable(self.ndim, **options)
        elif name == "is_floating_point" and self.dtype is not None:
            constant_result = ConstantVariable(self.dtype.is_floating_point, **options)
        elif name == "is_contiguous" and self.is_contiguous is not None:
            if "memory_format" in kwargs:
                memory_format = kwargs.pop("memory_format").as_python_constant()
            else:
                memory_format = torch.contiguous_format
            constant_result = ConstantVariable(
                memory_format in self.is_contiguous, **options
            )
        elif (
            name == "type"
            and self.dtype is not None
            and len(args) == 0
            and isinstance(self.device, torch.device)
        ):
            tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
                0
            ]
            if self.device.type == "cuda":
                constant_result = ConstantVariable(
                    f"torch.cuda.{tensortype.__name__}", **options
                )
            else:
                constant_result = ConstantVariable(
                    f"torch.{tensortype.__name__}", **options
                )
        elif name == "get_device" and isinstance(self.device, torch.device):
            index = self.device.index if self.device.type != "cpu" else -1
            constant_result = ConstantVariable(index, **options)
        else:
            constant_result = None

        if constant_result:
            assert not kwargs, f"Tensor.{name}() unhandled kwargs"
            if len(args) == 1:
                return constant_result.getitem_const(args[0])
            elif args:
                return TupleVariable(
                    [constant_result.getitem_const(a) for a in args], **options
                )
            return constant_result
        elif (
            name == "repeat"
            and not all(
                x.is_python_constant() for x in itertools.chain(args, kwargs.values())
            )
            and not config.dynamic_shapes
        ):
            unimplemented("dynamic Tensor.repeat")
        elif name in ("tolist", "numpy", "backward", "data_ptr"):
            unimplemented(f"Tensor.{name}")
        elif name == "nonzero" and not config.dynamic_shapes:
            unimplemented(f"Tensor.{name}")
        elif name == "item":
            if config.capture_scalar_outputs:
                example_value = get_fake_value(self.proxy.node, tx)
                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_method",
                        "item",
                        (self.as_proxy(),),
                        {},
                    ),
                    example_value=example_value,
                    **options,
                )
            else:
                unimplemented(f"Tensor.{name}")
        elif name == "__len__":
            if self.size:
                assert not config.dynamic_shapes
                return ConstantVariable(self.size[0], **options)
            else:
                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function",
                        len,
                        (self.as_proxy(),),
                        {},
                    ),
                    **options,
                )
        elif name == "__setitem__":
            tx.output.guards.update(options["guards"])
            tx.output.create_proxy(
                "call_function",
                operator.setitem,
                *proxy_args_kwargs([self] + list(args), kwargs),
            )
            return ConstantVariable(None, **options)
        elif name in ("resize_", "resize_as_"):
            if "memory_format" in kwargs:
                memory_format = kwargs["memory_format"].as_python_constant()
            else:
                memory_format = torch.contiguous_format

            if name == "resize_":
                self.size = args[0].as_python_constant()
                self.is_contiguous = (memory_format,)
            else:
                assert isinstance(args[0], TensorVariable)
                if self.size and args[0].size:
                    if (
                        self.size == args[0].size
                        or memory_format is torch.preserve_format
                    ):
                        self.is_contiguous = args[0].is_contiguous
                    else:
                        self.size = args[0].size
                        self.stride = args[0].stride
                        self.ndim = args[0].ndim
                        self.is_contiguous = (memory_format,)

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )
        else:
            # Convert x.new(torch.Size) into x.new_empty(torch.Size), as
            # Tensor.new acts differently with a Size input versus a tuple input.
            if (
                name == "new"
                and len(args) == 1
                and isinstance(args[0], (SizeVariable, ShapeVariable))
                and not config.dynamic_shapes
            ):
                name = "new_empty"
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )


class DynamicShapeVariable(VariableTracker):
    """
    Represents a symbolic size, e.g., as returned by tensor.size(0)
    """

    @classmethod
    def create(cls, tx, proxy, dyn_shape, **options):
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == dyn_shape
        if dyn_shape is None:
            dyn_shape = get_fake_value(proxy.node, tx)
        proxy.node.meta["example_value"] = dyn_shape
        return DynamicShapeVariable(proxy, dyn_shape, **options)
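
    # Illustrative: `create` caches the symbolic value on the FX node as
    # "example_value", mirroring how tensor proxies carry their fake-tensor
    # examples, so repeated wrapping of the same proxy stays consistent.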

    def __init__(self, proxy, dyn_shape, **kwargs):
        super(DynamicShapeVariable, self).__init__(**kwargs)
        self.proxy = proxy
        self.dyn_shape = dyn_shape

    def python_type(self):
        return type(self.dyn_shape)

    def unpack_var_sequence(self, tx):
        return super(DynamicShapeVariable, self).unpack_var_sequence(tx)

    def as_proxy(self):
        return self.proxy

    def evaluate_expr(self, output_graph):
        if not isinstance(self.dyn_shape, torch.SymInt):
            return self.dyn_shape
        return output_graph.shape_env.evaluate_expr(self.dyn_shape.get_pyobj().expr)
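
    # Illustrative: under dynamic shapes, `tensor.size(0)` yields a
    # torch.SymInt backed by a sympy expression; evaluate_expr() asks the
    # shape environment to reduce that expression to a concrete value
    # (presumably installing guards along the way).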

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self, args, kwargs.values())

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self] + list(args), kwargs),
            ),
            **options,
        )


class TensorWithTFOverrideVariable(VariableTracker):
    """
    Represents a tensor subclass instance with a __torch_function__ override.
    """
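
    # A sketch of the kind of user class this models (illustrative only):
    #
    #     class LoggingTensor(torch.Tensor):
    #         @classmethod
    #         def __torch_function__(cls, func, types, args=(), kwargs=None):
    #             print(func)
    #             return super().__torch_function__(func, types, args, kwargs)
    #
    # Method calls on such an instance are routed through the override by
    # call_method() below.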

    def __init__(
        self,
        tensor_variable,
        orig_tensor_variable_source,
        subclass_torch_function__func,
        subclass_type,
        **kwargs,
    ):
        super(TensorWithTFOverrideVariable, self).__init__(**kwargs)
        self.tensor_variable = tensor_variable
        self.orig_tensor_variable_source = orig_tensor_variable_source
        self.subclass_torch_function__func = subclass_torch_function__func
        self.subclass_type = subclass_type

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # This code block implements inlining the __torch_function__ override
        # of `call_method`.
        from . import GetAttrVariable

        options = VariableTracker.propagate(self, args, kwargs.values())
        # insert unwrapped version of self as the first argument
        # TODO: This is wrong! When you call the internal __torch_function__,
        # you still get the wrapped version of self, and if you call functions
        # inside __torch_function__, they should come back here. If we unwrap
        # the tensor immediately, that will not happen.
        # See https://github.com/pytorch/torchdynamo/issues/1951
        args = list(args)
        args.insert(0, self.tensor_variable)
        func_var = GetAttrVariable(self.tensor_variable, name)

        unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
            tx,
            func_var,
            self.orig_tensor_variable_source,
            self.subclass_torch_function__func,
            self.subclass_type,
            options,
            args,
            kwargs,
        )

        # TODO(future PR): implement rewrapping conditional on method presence
        # in `torch.overrides.get_default_nowrap_functions()`. It's unclear how
        # to do this easily in the current codebase since the resolution of
        # `GetAttrVariable` depends on the type of the underlying object.

        return TensorWithTFOverrideVariable(
            unwrapped,
            self.orig_tensor_variable_source,
            self.subclass_torch_function__func,
            self.subclass_type,
        )

    @staticmethod
    def inline_torch_function_unwrapped(
        tx,
        original_func_var,
        tensor_with_tf_override_source,
        tf_func,
        subclass_type,
        options,
        args,
        kwargs,
    ):
        """
        This function inlines the `__torch_function__` override for `original_func_var`.
        For example, if the user code is:

            x1 = torch.sigmoid(x0)

        and `x0` has an override, then:
        * `original_func_var` will be a `VariableTracker` object wrapping `torch.sigmoid`
        * `tensor_with_tf_override_source` will be the `Source` object for
          the original tensor override instance at the beginning of the program
        * `tf_func` will be the custom `__torch_function__` function
        * `subclass_type` will be `type(x0)`

        The caller is expected to properly massage args and kwargs before
        passing them into this function.

        The caller is responsible for wrapping the return value, if needed.
        """
        from . import UserDefinedClassVariable
        from .builder import TupleVariable, VariableBuilder

        source = AttrSource(
            AttrSource(tensor_with_tf_override_source, "__torch_function__"),
            "__func__",
        )
        tf_func_var = VariableBuilder(tx, source)(tf_func)
        type_var = UserDefinedClassVariable(subclass_type, **options)

        # signature:
        # def __torch_function__(cls, func, types, args=(), kwargs=None):
        tf_args = (
            type_var,  # cls
            original_func_var,  # func
            (type_var,),  # types
            TupleVariable(args),  # args
            kwargs,  # kwargs
        )

        # Disable __torch_function__ here to prevent the clone of the
        # example tensor from going into the override.
        with torch._C.DisableTorchFunction():
            return tx.inline_user_function_return(tf_func_var, tf_args, {})


class UnspecializedNumpyVariable(TensorVariable):
    """
    This is a 1-element tensor representing an unspecialized NumPy float/int.
    """

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        raw_value = kwargs.pop("raw_value", None)
        super(UnspecializedNumpyVariable, self).__init__(proxy, **kwargs)
        self.raw_value = raw_value

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value):
        # Convert a `TensorVariable` instance into an `UnspecializedNumpyVariable` instance.
        return UnspecializedNumpyVariable(
            **dict(tensor_variable.__dict__), raw_value=raw_value
        )
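
    # The method below is the bail-out path when Dynamo decides the wrapped
    # scalar should instead be burned into the graph as a constant: the
    # corresponding graph input is erased and volatile guards are rewritten
    # to CONSTANT_MATCH (so, presumably, a changed value forces a recompile).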

    def as_specialized(self, tx):
        for graph_arg in tx.output.graphargs:
            if graph_arg.source is self.source:
                graph_arg.erase()

        for g in self.guards:
            if g.is_volatile:
                g.create_fn = GuardBuilder.CONSTANT_MATCH

        return ConstantVariable(value=self.raw_value, guards=self.guards)


class UnspecializedPythonVariable(TensorVariable):
    """
    This is a 1-element tensor representing an unspecialized Python float/int.
    """

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        raw_value = kwargs.pop("raw_value", None)
        need_unwrap = kwargs.pop("need_unwrap", True)
        super(UnspecializedPythonVariable, self).__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
        return UnspecializedPythonVariable(
            **dict(tensor_variable.__dict__),
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )

    def as_specialized(self, tx):
        for graph_arg in tx.output.graphargs:
            if graph_arg.source is self.source:
                graph_arg.erase()

        for g in self.guards:
            if g.is_volatile:
                g.create_fn = GuardBuilder.CONSTANT_MATCH

        return ConstantVariable(value=self.raw_value, guards=self.guards)


class FakeItemVariable(TensorVariable):
    """An unspecialized Python variable that prevents access to the underlying raw value.
    This is needed if `item` is called on a FakeTensor."""
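
    # Illustrative: with config.capture_scalar_outputs, `item()` on a fake
    # tensor cannot produce a real Python number, so the result is wrapped in
    # a FakeItemVariable to keep tracing going while blocking raw-value access.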

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        need_unwrap = kwargs.pop("need_unwrap", False)
        super(FakeItemVariable, self).__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))