**Background:** Before this PR, dynamo support for tensor attributes (e.g. `x.H`, `x.T`, ...) had to be implemented individually, one attribute at a time. This could lead to errors, e.g. if the implementation in [variables/tensor.py](21c7c7c72f/torch/_dynamo/variables/tensor.py (L160)) differed from a direct call to the attribute, and dynamo tracing would simply fail for attributes that were not special-cased in tensor.py. This PR adds generic support for tensor attributes that return tensors, without needing to handle each one specially (notably `x.real` and `x.imag`, which previously weren't supported).
**In this PR:** The attribute access is traced directly as a proxy for a `"call_function"` node with `target=getattr`, which is fed into `wrap_fx_proxy`. This produces a `TensorVariable` for the returned attribute.
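For illustration, here is a minimal plain-`torch.fx` sketch of the kind of node this produces; it hand-builds the graph rather than going through the dynamo code path, and the variable names are purely illustrative:

```python
import torch
import torch.fx

# Build, by hand, the kind of node dynamo now emits for `x.real`:
# a "call_function" node whose target is the builtin `getattr`.
graph = torch.fx.Graph()
x = graph.placeholder("x")
real = graph.call_function(getattr, (x, "real"))
graph.output(real)

gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm.graph)  # the printed graph contains a call_function node targeting builtins.getattr
print(gm(torch.randn(2, 2, dtype=torch.complex64)))  # runs getattr(x, "real")
```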
This also removes the implementations for `H`, `T`, `mH`, and `mT`, which were broken (previously `torch.relu(x.T)` would fail). They now fall back to this default implementation, under which `torch.relu(x.T)` passes.
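A rough usage sketch of the pattern that previously failed and now traces (assuming a dynamo build from around this PR; `torch._dynamo.optimize("eager")` is used only for illustration):

```python
import torch
import torch._dynamo as dynamo

@dynamo.optimize("eager")
def f(x):
    # x.T is no longer special-cased; it is traced as a generic getattr
    # that yields a tensor, so feeding it into torch.relu works.
    return torch.relu(x.T)

print(f(torch.randn(3, 4)).shape)  # torch.Size([4, 3])
```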
**Further context:**
* Ed's original suggestion in [90463](https://github.com/pytorch/pytorch/pull/90463#discussion_r1043398340) was to use `torch.Tensor.H.__get__(x)`. I wasn't able to get this to work: fx compilation fails with `getset_descriptor does not have attribute __module__`. The `__module__` attribute, which is available on most Python objects, is not available on `getset_descriptor` objects (these attributes are implemented in C++ on torch.Tensor, so they don't satisfy some assumptions made by fx); see the snippet after this list.
* Although tensor attributes and methods (like `x.relu()`) both go through this code path, this PR only handles attributes (see the `getset_descriptor` check in variables/tensor.py, and the snippet after this list). Methods are already handled by GetAttrVariable.
* Prior to this PR, we already returned GetAttrVariables for unsupported attrs: the parent caller would catch the NotImplementedError and fall back to returning a GetAttrVariable. But if that GetAttrVariable was ever passed into a torch.\* function (as it quite possibly could be, since most of these attrs are tensors), it would fail because its proxy node was missing an [example_value](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/utils.py#L1017). So before this PR, for some tensor `x`, `x.real` would work fine, but `torch.relu(x.real)` would fail.
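The attribute/method distinction mentioned above can be seen with a small snippet (illustrative only; the `getset_descriptor` behavior is what the check in tensor.py relies on):

```python
import inspect
import types

import torch

# C++-implemented tensor attributes (tp_getset in THPVariableType) show up as
# getset_descriptors; methods do not, so this type check separates the two.
h = inspect.getattr_static(torch.Tensor, "H")
relu = inspect.getattr_static(torch.Tensor, "relu")

print(isinstance(h, types.GetSetDescriptorType))     # True  -> attribute
print(isinstance(relu, types.GetSetDescriptorType))  # False -> method

# Unlike most Python objects, getset_descriptors carry no __module__, which is
# why the torch.Tensor.H.__get__(x) approach tripped up fx.
print(hasattr(h, "__module__"))  # False
```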
**Testing:** added tests in test_misc.py for `x.real`, `x.imag`, `x.T`, and `x.real.T`.
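A rough sketch of the kind of check such a test performs (illustrative only, not the exact code in test_misc.py):

```python
import torch
import torch._dynamo as dynamo

def fn(x):
    # exercises the generically handled attributes x.real, x.imag, and .T
    return torch.relu(x.real) + x.imag.T

x = torch.randn(4, 4, dtype=torch.complex64)
compiled_fn = dynamo.optimize("eager", nopython=True)(fn)
torch.testing.assert_close(compiled_fn(x), fn(x))
```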
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91840
Approved by: https://github.com/ezyang
632 lines | 23 KiB | Python
import inspect
import itertools
import operator
import types
from typing import Dict, List

import torch.fx
import torch.random

from .. import config, variables
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource

from ..utils import (
    fqn,
    get_fake_value,
    get_real_value,
    HAS_NUMPY,
    np,
    product,
    proxy_args_kwargs,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import ShapeVariable, SizeVariable


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = [
        "proxy",
        "dtype",
        "device",
        "layout",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
    ]

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be
        slow and memory-intensive.
        """
        return get_real_value(self.proxy.node, self.proxy.tracer)

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        dtype=None,
        device=None,
        layout=None,
        ndim=None,
        size=None,
        stride=None,
        requires_grad=None,
        is_quantized=None,
        is_contiguous=None,
        is_sparse=None,
        class_type=torch.Tensor,
        specialized_value=None,
        **kwargs,
    ):
        super(TensorVariable, self).__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.layout = layout
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.specialized_value = specialized_value

    def as_proxy(self):
        return self.proxy

    def python_type(self):
        return self.class_type

    def call_isinstance(self, tensor_type):
        def check_type(ty):
            if ty not in tensortype_to_dtype:
                return issubclass(self.python_type(), ty)

            dtypes = tensortype_to_dtype[ty]
            return self.dtype in dtypes

        if type(tensor_type) is tuple:
            return any([check_type(ty) for ty in tensor_type])
        else:
            return check_type(tensor_type)

    @staticmethod
    def specialize(value: torch.Tensor):
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "layout": value.layout,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        if not config.dynamic_shapes:
            props["size"] = tuple(value.size())
            props["stride"] = tuple(value.stride())
            props["is_contiguous"] = tuple(
                [
                    x
                    for x in torch._prims_common._memory_formats
                    if value.is_contiguous(memory_format=x)
                ]
            )
        return props

    def var_getattr(self, tx, name):
        from . import ConstantVariable, TorchVariable

        result = None
        options = VariableTracker.propagate(self)
        if name == "ndim" and self.ndim is not None:
            result = ConstantVariable(self.ndim, **options)
        elif name == "dtype" and self.dtype is not None:
            result = TorchVariable(self.dtype, **options)
        elif name == "device" and self.device is not None:
            result = TorchVariable(self.device, **options)
        elif name == "layout" and self.layout is not None:
            result = TorchVariable(self.layout, **options)
        elif name == "is_cuda" and self.device is not None:
            result = ConstantVariable(self.device.type == "cuda", **options)
        elif name == "shape" and self.size is not None:
            sizes = [variables.ConstantVariable(x) for x in self.size]
            result = ShapeVariable(sizes, **options)
        elif name == "requires_grad" and self.requires_grad is not None:
            result = ConstantVariable(self.requires_grad, **options)
        elif name == "is_quantized" and self.is_quantized is not None:
            result = ConstantVariable(self.is_quantized, **options)
        elif name == "is_sparse" and self.is_sparse is not None:
            result = ConstantVariable(self.is_sparse, **options)
        elif name == "shape" and self.size is None:
            result = self.call_method(tx, "size", [], {})
        elif name == "ndim" and self.ndim is None:
            result = self.call_method(tx, "dim", [], {})
        elif name == "data":
            result = self.call_method(tx, "detach", [], {})
        if name == "__class__":
            return TorchVariable(self.python_type(), **options)

        # Add a guard for type matching, these guards are checked before tensor guards
        # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
        # <tensor> is later changed to another type
        if result is not None and self.source is not None:
            result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))

        # For attributes (not methods) that were not caught in the special handling above,
        # (e.g. tensor.real), we handle these generically, assuming that the output type is
        # a tensor.
        if result is None:

            def try_generic_attr_handling():
                from .builder import wrap_fx_proxy
                from .misc import GetAttrVariable

                try:
                    static_attr = inspect.getattr_static(torch.Tensor, name)
                except NameError:
                    return None

                # Make sure this is an attribute, not a method.
                # type(torch.Tensor.H) should be "getset_descriptor"
                # This is a because of CPython implementation, see THPVariableType:
                # these attributes are implemented under tp_getset, which appear
                # as `getset_descriptor`s, (compared to, say, methods which appear
                # as `method_descriptor`s)
                if type(static_attr) != types.GetSetDescriptorType:
                    return None

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=GetAttrVariable.create_getattr_proxy(self.as_proxy(), name),
                    **options,
                )

            result = try_generic_attr_handling()

        if result is None:
            raise NotImplementedError()

        return result

    def unpack_var_sequence(self, tx, idxes=None):
        from .builder import wrap_fx_proxy

        if idxes is None:
            if self.size:
                idxes = range(self.size[0])
            else:
                return super(TensorVariable, self).unpack_var_sequence(tx)
        options = VariableTracker.propagate(self)
        return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable, TorchVariable, TupleVariable
        from .builder import wrap_fx_proxy

        kwargs = dict(kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "stride" and self.stride is not None:
            constant_result = ConstantVariable(self.stride, **options)
        elif name == "size" and self.size is not None:
            sizes = [variables.ConstantVariable(x) for x in self.size]
            constant_result = SizeVariable(sizes, **options)
        elif name == "size" and self.size is None and config.dynamic_shapes:
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )
        elif name in ("numel", "nelement") and self.size is not None:
            constant_result = ConstantVariable(product(self.size), **options)
        elif name in ("ndimension", "dim") and self.ndim is not None:
            constant_result = ConstantVariable(self.ndim, **options)
        elif name == "is_floating_point" and self.dtype is not None:
            constant_result = ConstantVariable(self.dtype.is_floating_point, **options)
        elif name == "is_contiguous" and self.is_contiguous is not None:
            if "memory_format" in kwargs:
                memory_format = kwargs.pop("memory_format").as_python_constant()
            else:
                memory_format = torch.contiguous_format
            constant_result = ConstantVariable(
                memory_format in self.is_contiguous, **options
            )
        elif (
            name == "type"
            and self.dtype is not None
            and len(args) == 0
            and isinstance(self.device, torch.device)
        ):
            tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
                0
            ]
            if self.device.type == "cuda":
                constant_result = ConstantVariable(
                    f"torch.cuda.{tensortype.__name__}", **options
                )
            else:
                constant_result = ConstantVariable(
                    f"torch.{tensortype.__name__}", **options
                )
        elif (
            name == "type"
            and len(args) == 1
            and fqn(type(args[0].as_python_constant())) == "torch.tensortype"
        ):
            # torch.FloatTensor, etc. are all of type "torch.tensortype".
            # torch.fx's tracer fails on these types, because it doesn't support arguments of torch.tensortype type.
            # So, we pass it in as a string (which is also supported, see above implementation for .type() with 0 args)
            tensor_type = args[0].as_python_constant()
            tensor_type_const = ConstantVariable(fqn(tensor_type), **options)
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self, tensor_type_const], kwargs),
                ),
                **options,
            )
        elif name == "get_device" and isinstance(self.device, torch.device):
            index = self.device.index if self.device.type != "cpu" else -1
            constant_result = ConstantVariable(index, **options)
        else:
            constant_result = None

        if constant_result:
            assert not kwargs, f"Tensor.{name}() unhandled kwargs"
            if len(args) == 1:
                return constant_result.getitem_const(args[0])
            elif args:
                return TupleVariable(
                    [constant_result.getitem_const(a) for a in args], **options
                )
            return constant_result
        elif (
            name == "repeat"
            and not all(
                x.is_python_constant() for x in itertools.chain(args, kwargs.values())
            )
            and not config.dynamic_shapes
        ):
            unimplemented("dynamic Tensor.repeat")
        elif name in ("tolist", "numpy", "backward", "data_ptr"):
            unimplemented(f"Tensor.{name}")
        elif name == "nonzero" and not config.dynamic_shapes:
            unimplemented(f"Tensor.{name}")
        elif name == "item" and not config.capture_scalar_outputs:
            unimplemented(f"Tensor.{name}")
        elif (
            name == "item"
            and config.capture_scalar_outputs
            and not config.dynamic_shapes
        ):
            raise AssertionError(
                "To capture_scalar_outputs, you must also set dynamic_shapes = True"
            )
        elif name == "__len__":
            return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
        elif name == "__setitem__":
            tx.output.guards.update(options["guards"])
            tx.output.create_proxy(
                "call_function",
                operator.setitem,
                *proxy_args_kwargs([self] + list(args), kwargs),
            )
            return ConstantVariable(None, **options)
        elif name in ("resize_", "resize_as_"):
            if "memory_format" in kwargs:
                memory_format = kwargs["memory_format"].as_python_constant()
            else:
                memory_format = torch.contiguous_format

            if name == "resize_":
                self.size = args[0].as_python_constant()
                self.is_contiguous = (memory_format,)
            else:
                assert isinstance(args[0], TensorVariable)
                if self.size and args[0].size:
                    if (
                        self.size == args[0].size
                        or memory_format is torch.preserve_format
                    ):
                        self.is_contiguous = args[0].is_contiguous
                    else:
                        self.size = args[0].size
                        self.stride = args[0].stride
                        self.ndim = args[0].ndim
                        self.is_contiguous = (memory_format,)

            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )
        elif (
            name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
        ):
            result = TorchVariable(torch.mul, **options).call_function(
                tx, args + [kwargs["alpha"]], {}
            )
            return self.call_method(tx, "add_", [result], {})
        elif (
            name == "addcdiv_"
            and len(args) == 2
            and len(kwargs) == 1
            and "value" in kwargs
        ):
            result = TorchVariable(torch.div, **options).call_function(tx, args, {})
            result = TorchVariable(torch.mul, **options).call_function(
                tx, [result, kwargs["value"]], {}
            )
            return self.call_method(tx, "add_", [result], {})
        else:
            # Convert x.new(torch.Size) into x.new_empty(torch.Size),
            # as Tensor.new acts differently with a Size input versus a tuple input.
            if (
                name == "new"
                and len(args) == 1
                and isinstance(args[0], (SizeVariable, ShapeVariable))
                and not config.dynamic_shapes
            ):
                name = "new_empty"
            return wrap_fx_proxy(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + list(args), kwargs),
                ),
                **options,
            )


class DynamicShapeVariable(VariableTracker):
    """
    Represents a symbolic size, e.g., as returned by tensor.size(0)
    """

    @classmethod
    def create(cls, tx, proxy, dyn_shape, **options):
        if "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == dyn_shape
        if dyn_shape is None:
            dyn_shape = get_fake_value(proxy.node, tx)
        proxy.node.meta["example_value"] = dyn_shape
        return DynamicShapeVariable(proxy, dyn_shape, **options)

    def __init__(self, proxy, dyn_shape, **kwargs):
        super(DynamicShapeVariable, self).__init__(**kwargs)
        self.proxy = proxy
        self.dyn_shape = dyn_shape

    def python_type(self):
        return type(self.dyn_shape)

    def unpack_var_sequence(self, tx):
        super(DynamicShapeVariable, self).unpack_var_sequence(tx)

    def as_proxy(self):
        return self.proxy

    def evaluate_expr(self, output_graph):
        if not isinstance(self.dyn_shape, torch.SymInt):
            return self.dyn_shape
        return output_graph.shape_env.evaluate_expr(self.dyn_shape.node.expr)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self, args, kwargs.values())

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self] + list(args), kwargs),
            ),
            **options,
        )


class TensorWithTFOverrideVariable(VariableTracker):
    """
    Represents a tensor subclass instance with a __torch_function__ override.
    """

    def __init__(
        self,
        tensor_variable,
        orig_tensor_variable_source,
        subclass_torch_function__func,
        subclass_type,
        **kwargs,
    ):
        super(TensorWithTFOverrideVariable, self).__init__(**kwargs)
        self.tensor_variable = tensor_variable
        self.orig_tensor_variable_source = orig_tensor_variable_source
        self.subclass_torch_function__func = subclass_torch_function__func
        self.subclass_type = subclass_type

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # This code block implements inlining the __torch_function__ override
        # of `call_method`.
        from . import GetAttrVariable

        options = VariableTracker.propagate(self, args, kwargs.values())
        # insert unwrapped version of self as the first argument
        # TODO: This is wrong! When you call the internal __torch_function__,
        # you still get the wrapped version of self, and if you call functions
        # inside __torch_function__, they should come back here. If we unwrap
        # the tensor immediately, that will not happen.
        # See https://github.com/pytorch/torchdynamo/issues/1951
        args = list(args)
        args.insert(0, self.tensor_variable)
        func_var = GetAttrVariable(self.tensor_variable, name)

        unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
            tx,
            func_var,
            self.orig_tensor_variable_source,
            self.subclass_torch_function__func,
            self.subclass_type,
            options,
            args,
            kwargs,
        )

        # TODO(future PR): implement rewrapping conditional on method presence
        # in `torch.overrides.get_default_nowrap_function()`. It's unclear how
        # to do this easily in the current codebase since the resolution of
        # `GetAttrVariable` depends on the type of the underlying object.

        return TensorWithTFOverrideVariable(
            unwrapped,
            self.orig_tensor_variable_source,
            self.subclass_torch_function__func,
            self.subclass_type,
        )

    @staticmethod
    def inline_torch_function_unwrapped(
        tx,
        original_func_var,
        tensor_with_tf_override_source,
        tf_func,
        subclass_type,
        options,
        args,
        kwargs,
    ):
        """
        This function inlines the `__torch_function__` override for `original_func_var`.
        For example, if the user code is

        x1 = torch.sigmoid(x0)

        And `x0` has an override, then:
        * `original_func_var` will be a `VariableTracker` object wrapping `torch.sigmoid`
        * `tensor_with_tf_override_source` will be the `Source` object from
          the original tensor override instance in the beginning of the program
        * `tf_func` will be the custom `__torch_function__` function
        * `subclass_type` will be `type(x0)`

        The caller is expected to properly massage args and kwargs before
        passing them into this function.

        The caller is responsible for wrapping the return value, if needed.
        """
        from . import UserDefinedClassVariable
        from .builder import TupleVariable, VariableBuilder

        source = AttrSource(
            AttrSource(tensor_with_tf_override_source, "__torch_function__"),
            "__func__",
        )
        tf_func_var = VariableBuilder(tx, source)(tf_func)
        type_var = UserDefinedClassVariable(subclass_type, **options)

        # signature:
        # def __torch_function__(cls, func, types, args=(), kwargs=None):
        tf_args = (
            type_var,  # cls
            original_func_var,  # func
            (type_var,),  # types
            TupleVariable(args),  # args
            kwargs,  # kwargs
        )

        # Disable __torch_function__ here to prevent the clone of the
        # example tensor from going into the override.
        with torch._C.DisableTorchFunctionSubclass():
            return tx.inline_user_function_return(tf_func_var, tf_args, {})


class UnspecializedPythonVariable(TensorVariable):
    """
    This is a 1-element tensor represents unspecialized python float/int.
    """

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        raw_value = kwargs.pop("raw_value", None)
        if HAS_NUMPY and isinstance(raw_value, np.number):
            raw_values = raw_value.item()
        need_unwrap = kwargs.pop("need_unwrap", True)
        super(UnspecializedPythonVariable, self).__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
        return UnspecializedPythonVariable(
            **dict(tensor_variable.__dict__),
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )

    def as_specialized(self, tx):
        for graph_arg in tx.output.graphargs:
            if graph_arg.source is self.source:
                graph_arg.erase()

        for g in self.guards:
            if g.is_volatile:
                g.create_fn = GuardBuilder.CONSTANT_MATCH

        return ConstantVariable(value=self.raw_value, guards=self.guards)


class FakeItemVariable(TensorVariable):
    """An unspecialized python variable which prevents access to the underlying raw value.
    This is needed if item is called on a FakeTensor."""

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        need_unwrap = kwargs.pop("need_unwrap", False)
        super(FakeItemVariable, self).__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))