This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch added overloads for each permutation of int/float and was unwieldy: https://github.com/pytorch/pytorch/pull/87722/ This PR takes a different approach.

The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. SymNode is type-erased: we no longer know statically in C++ whether we have an int or a float, and have to test it with the is_int()/is_float() virtual methods. This has a number of knock-on effects.

- We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have SymInt/SymFloat classes defined entirely in Python, which hold a SymNode (which corresponds to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python, and is wrapped into a C++ SymNode using PythonSymNode when it goes into C++. This implies a userland rename. In principle, it is also possible for the canonical implementation of SymNode to be written in C++, and then bound to Python with pybind11 (we have this code, although it is commented out.) However, I did not implement this as we currently have no C++ implementations of SymNode. Because we do return SymInt/SymFloat from C++ bindings, the C++ binding code needs to know how to find these classes. Currently, this is done just by manually importing torch and getting the attributes.
- Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how __torch_dispatch__ works.

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode. Note that this constructor is ambiguous if you pass in a subclass of SymNode, so an explicit downcast is necessary. This means toSymFloat/toSymInt are no more. This is a mild optimization as it means rvalue references work automatically.
- We uniformly use the caster for c10::SymInt/SymFloat, rather than going the long way via SymIntNode/SymFloatNode.
- Removed some unnecessary toSymInt/toSymFloat calls in normalize_* functions; pretty sure this doesn't do anything.
- guard_int is now a free function, since to guard on an int you cannot assume the method exists; a free function can handle both int and SymInt inputs.
- We clean up the magic method definition code for SymInt/SymFloat/SymNode. ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets plain methods; this is to help avoid confusion between the two types.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
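As a rough illustration of the split described above (user-facing SymInt/SymFloat carrying the magic methods, a type-erased SymNode carrying plain methods), here is a minimal, self-contained sketch. The names below (FakeSymNode, SymIntSketch, add, guard_int) are hypothetical, chosen for illustration only; they are not the actual torch API.

# Minimal sketch of the wrapper design: a Python-level SymInt-like class
# that holds a type-erased node object. Not the real torch implementation.


class FakeSymNode:
    """Stand-in for a type-erased symbolic node holding an int or a float."""

    def __init__(self, value):
        self._value = value

    def is_int(self):
        return isinstance(self._value, int)

    def is_float(self):
        return isinstance(self._value, float)

    def add(self, other):
        # Plain (non-magic) method on the node.
        return FakeSymNode(self._value + other._value)

    def guard_int(self):
        # Guarding produces a concrete int; here we simply return the value.
        return int(self._value)


class SymIntSketch:
    """User-facing wrapper: magic methods live here, not on the node."""

    def __init__(self, node):
        self.node = node

    def __add__(self, other):
        return SymIntSketch(self.node.add(other.node))

    def __int__(self):
        return self.node.guard_int()


if __name__ == "__main__":
    a = SymIntSketch(FakeSymNode(3))
    b = SymIntSketch(FakeSymNode(4))
    print(int(a + b))  # 7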
782 lines
28 KiB
Python
import copy
import functools
import itertools
import math
import numbers
import operator
from typing import Dict, List

import torch.fx
import torch.random

from ..utils import fake_tensors_available

if fake_tensors_available:
    from torch._subclasses import FakeTensor
    from torch._subclasses.fake_tensor import (
        DataDependentOutputException,
        DynamicOutputShapeException,
    )
    from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor_and_record

import torch.utils._python_dispatch as py_dispatch
from torch.fx.immutable_collections import immutable_list
from torch.utils._pytree import tree_map

from .. import config, variables
from ..exc import TorchRuntimeError, unimplemented, Unsupported
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import (
    clone_input,
    is_lazy_module,
    istype,
    preserve_rng_state,
    product,
    proxy_args_kwargs,
    tensortype_to_dtype,
)
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable
from .lists import ShapeVariable, SizeVariable


class _missing:
    pass


def _run_node(output_graph, node, args, kwargs, nnmodule):
    op = node.op
    if op == "call_function":
        return node.target(*args, **kwargs)
    elif op == "call_method":
        return getattr(args[0], node.target)(*args[1:], **kwargs)
    elif op == "call_module":
        assert nnmodule is not None
        return nnmodule(*args, **kwargs)
    elif op == "get_attr":
        return output_graph.get_submodule(node.target)
    raise AssertionError(op)


def _get_real_value(node, output_graph):
    """
    Run the actual computation represented by `node` and return the result.
    This will execute any dependent nodes in the graph as well.
    """
    cache = output_graph.real_value_cache
    if node in cache:
        return cache[node]

    op = node.op
    args, kwargs = torch.fx.node.map_arg(
        (node.args, node.kwargs),
        lambda n: _get_real_value(n, output_graph),
    )

    if op == "call_module":
        nn_module = output_graph.nn_modules[node.target]
        if not is_lazy_module(nn_module):
            nn_module = copy.deepcopy(nn_module)
        else:
            # In the case of a lazy module, we want to run
            # the pre-hooks which initialize it
            nn_module(*args, **kwargs)
    else:
        nn_module = None

    try:
        real_value = _run_node(output_graph, node, args, kwargs, nn_module)
        cache[node] = real_value
    except RuntimeError as e:
        raise TorchRuntimeError() from e
    return real_value


def _get_fake_value(node, tx):
    """
    Run the computation represented by `node` using fake tensors and return the result.
    """
    op = node.op
    fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx)
    from ..utils import wrap_fake_exception

    def visit(n: torch.fx.Node):
        return n.meta["example_value"]

    args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit)
    args = tree_map(fake_wrapper, args)
    kwargs = tree_map(fake_wrapper, kwargs)

    nnmodule = None
    if op == "call_module":
        nnmodule = tx.output.nn_modules[node.target]

        if not is_lazy_module(nnmodule):
            nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode)

    def context():
        if hasattr(py_dispatch, "enable_torch_dispatch_mode"):
            return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode)
        else:
            return tx.fake_mode

    if op == "call_module" and is_lazy_module(nnmodule):
        assert nnmodule is not None
        # In the case of a lazy module, we want to run
        # the pre-hooks which initialize it
        nnmodule(*args, **kwargs)
    try:
        with context():
            return wrap_fake_exception(
                lambda: _run_node(tx.output, node, args, kwargs, nnmodule)
            )
    except Unsupported:
        raise
    except RuntimeError as e:
        if isinstance(e, DataDependentOutputException):
            if config.capture_scalar_outputs and node.target == "item":
                return torch.zeros(size=(), dtype=args[0].dtype).item()
            else:
                unimplemented(f"data dependent operator: {e.func}")
        elif isinstance(e, DynamicOutputShapeException):
            unimplemented(f"dynamic shape operator: {e.func}")
        else:
            raise TorchRuntimeError() from e


def _clone_input(value):
    if isinstance(value, torch.Tensor):
        use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation
        # tensor subclasses will not be converted to FakeTensors and need to be cloned
        if not use_fake_tensors or not isinstance(value, FakeTensor):
            # NB: ensure strides are preserved
            value = clone_input(value)

    return value


class TensorVariable(VariableTracker):
    """A torch.Tensor input or an intermediate value in the FX graph"""

    _nonvar_fields = [
        "proxy",
        "dtype",
        "device",
        "ndim",
        "size",
        "stride",
        "requires_grad",
        "is_quantized",
        "is_contiguous",
    ]

    def get_real_value(self):
        """
        Get the actual value represented by this variable if computation is run
        using the user-provided inputs.
        NOTE: this runs actual tensor computation and may be
        slow and memory-intensive.
        """
        return _get_real_value(self.proxy.node, self.proxy.tracer)

    @classmethod
    def create(cls, tx, proxy, example_value=None, **options):
        # Propagate an example value for `proxy` (computing one via fake or
        # real tensor execution when none is provided) and wrap the result in
        # the appropriate VariableTracker subclass.
        if "guards" in options and options["guards"] is not None:
            tx.output.guards.update(options["guards"])

        assert "example_value" not in proxy.node.meta
        if not config.dynamic_propagation:
            if isinstance(example_value, torch.Tensor):
                options.update(cls.specialize(example_value))
            return cls(proxy, **options)

        use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation

        initial_example_value = example_value

        with preserve_rng_state():
            if example_value is None:
                if use_fake_tensors:
                    example_value = _get_fake_value(proxy.node, tx)
                else:
                    example_value = _get_real_value(proxy.node, tx.output)

            else:
                proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
                if use_fake_tensors:
                    fake_wrapper = functools.partial(
                        wrap_to_fake_tensor_and_record, tx=tx
                    )
                    example_value = fake_wrapper(example_value)

        if isinstance(example_value, torch.Tensor):
            is_parameter = isinstance(example_value, torch.nn.Parameter)
            should_specialize = options.pop("should_specialize", False)
            if is_parameter or should_specialize:
                specialized_value = initial_example_value
            else:
                specialized_value = None

            example_value = _clone_input(example_value)
            proxy.node.meta["example_value"] = example_value
            specialized_props = cls.specialize(example_value)
            if use_fake_tensors and isinstance(example_value, FakeTensor):
                specialized_props["class_type"] = (
                    torch.nn.Parameter if is_parameter else torch.Tensor
                )

            specialized_props["specialized_value"] = specialized_value

            options.update(specialized_props)
            return cls(proxy, **options)
        elif (
            hasattr(proxy.node.target, "__name__")
            and proxy.node.target.__name__ == "set_state"
            and isinstance(proxy.node.target.__self__, torch._C.Generator)
            or proxy.node.target == torch.random.set_rng_state
        ):
            from . import TorchVariable

            return TorchVariable(proxy.node.target)
        elif istype(example_value, (int, bool, float)) and config.dynamic_shapes:
            proxy.node.meta["example_value"] = example_value
            return DynamicShapeVariable(proxy, example_value, **options)
        elif istype(example_value, torch.Size) and config.dynamic_shapes:
            proxy.node.meta["example_value"] = example_value
            sizes = []
            for i, v in enumerate(example_value):
                proxy_i = proxy[i]
                proxy_i.node.meta["example_value"] = v
                sizes.append(DynamicShapeVariable(proxy_i, v))
            return SizeVariable(sizes, proxy, **options)
        elif istype(example_value, int) and proxy.node.target in (
            torch.seed,
            operator.mod,
            # some mac builds are missing torch.distributed.get_rank()
            getattr(torch.distributed, "get_rank", _missing),
            getattr(torch.distributed, "get_world_size", _missing),
        ):
            proxy.node.meta["example_value"] = example_value
            return DynamicShapeVariable(proxy, example_value, **options)
        elif istype(example_value, torch.Size) and all(
            [isinstance(x, int) for x in example_value]
        ):
            sizes = [variables.ConstantVariable(x) for x in example_value]
            return SizeVariable(sizes, **options)
        elif isinstance(example_value, (tuple, list)):
            unpacked = []
            for i, val in enumerate(example_value):
                if val is None:
                    # nn.MultiheadAttention() can return None, see issue #175
                    unpacked.append(
                        variables.ConstantVariable(None, **options),
                    )
                else:
                    unpacked.append(
                        cls.create(
                            tx,
                            proxy.tracer.create_proxy(
                                "call_function", operator.getitem, (proxy, i), {}
                            ),
                            example_value=val,
                            **options,
                        )
                    )
            if istype(example_value, tuple):
                return variables.TupleVariable(unpacked, **options)
            elif istype(example_value, (list, immutable_list)):
                return variables.ListVariable(
                    unpacked, mutable_local=MutableLocal(), **options
                )
            else:
                assert (
                    example_value.__class__.__module__ == "torch.return_types"
                    or hasattr(example_value, "_fields")
                ), "namedtuple?"
                return variables.NamedTupleVariable(
                    unpacked, example_value.__class__, **options
                )
        elif example_value is None or proxy.node.target is torch.manual_seed:
            return variables.ConstantVariable(None, **options)
        elif (
            isinstance(example_value, int)
            and proxy.node.target is torch._utils._element_size
        ):
            proxy.node.meta["example_value"] = example_value
            return variables.ConstantVariable(example_value, **options)
        elif (
            isinstance(example_value, numbers.Number)
            and (
                proxy.node.target == "item"
                or proxy.node.target in {math.sqrt, math.pow}
            )
            and config.capture_scalar_outputs
        ):
            if use_fake_tensors:
                # item raw value should not be accessed
                return FakeItemVariable.create(
                    tx=tx,
                    proxy=proxy,
                    example_value=torch.tensor(example_value),
                    **options,
                )
            else:
                return UnspecializedPythonVariable.create(
                    tx=tx,
                    proxy=proxy,
                    example_value=torch.tensor(example_value),
                    raw_value=None if use_fake_tensors else example_value,
                    need_unwrap=False,
                    **options,
                )
        elif (
            proxy.node.target == torch._C._DisableFuncTorch
            or proxy.node.target == torch.cuda._is_in_bad_fork
        ):
            from . import UserDefinedObjectVariable

            return UserDefinedObjectVariable(example_value)
        elif isinstance(example_value, torch.SymInt):
            proxy.node.meta["example_value"] = example_value
            return cls(proxy, **options)
        else:
            raise AssertionError(
                "torch.* op returned non-Tensor "
                + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}"
            )

    def __init__(
        self,
        proxy: torch.fx.Proxy,
        dtype=None,
        device=None,
        ndim=None,
        size=None,
        stride=None,
        requires_grad=None,
        is_quantized=None,
        is_contiguous=None,
        is_sparse=None,
        class_type=torch.Tensor,
        specialized_value=None,
        **kwargs,
    ):
        super(TensorVariable, self).__init__(**kwargs)
        self.proxy = proxy
        self.dtype = dtype
        self.device = device
        self.ndim = ndim
        self.size = size
        self.stride = stride
        self.requires_grad = requires_grad
        self.is_quantized = is_quantized
        self.is_contiguous = is_contiguous
        self.is_sparse = is_sparse
        self.class_type = class_type
        self.specialized_value = specialized_value

    def as_proxy(self):
        return self.proxy

    def python_type(self):
        return self.class_type

    def call_isinstance(self, tensor_type):
        def check_type(ty):
            if ty not in tensortype_to_dtype:
                return issubclass(self.python_type(), ty)

            dtypes = tensortype_to_dtype[ty]
            return self.dtype in dtypes

        if type(tensor_type) is tuple:
            return any([check_type(ty) for ty in tensor_type])
        else:
            return check_type(tensor_type)

    @staticmethod
    def specialize(value: torch.Tensor):
        props = {
            "dtype": value.dtype,
            "device": value.device,
            "ndim": int(value.ndim),
            "requires_grad": value.requires_grad,
            "is_quantized": value.is_quantized,
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
        if not config.dynamic_shapes:
            props["size"] = tuple(value.size())
            props["stride"] = tuple(value.stride())
            props["is_contiguous"] = value.is_contiguous()
        return props

    def var_getattr(self, tx, name):
        from . import ConstantVariable, TorchVariable

        result = None
        options = VariableTracker.propagate(self)
        if name == "ndim" and self.ndim is not None:
            result = ConstantVariable(self.ndim, **options)
        elif name == "dtype" and self.dtype is not None:
            result = TorchVariable(self.dtype, **options)
        elif name == "device" and self.device is not None:
            result = TorchVariable(self.device, **options)
        elif name == "is_cuda" and self.device is not None:
            result = ConstantVariable(self.device.type == "cuda", **options)
        elif name == "shape" and self.size is not None:
            sizes = [variables.ConstantVariable(x) for x in self.size]
            result = ShapeVariable(sizes, **options)
        elif name == "requires_grad" and self.requires_grad is not None:
            result = ConstantVariable(self.requires_grad, **options)
        elif name == "is_quantized" and self.is_quantized is not None:
            result = ConstantVariable(self.is_quantized, **options)
        elif name == "is_sparse" and self.is_sparse is not None:
            result = ConstantVariable(self.is_sparse, **options)
        elif name == "shape" and self.size is None:
            result = self.call_method(tx, "size", [], {})
        elif name == "ndim" and self.ndim is None:
            result = self.call_method(tx, "dim", [], {})

        if name == "__class__":
            return TorchVariable(self.python_type(), **options)

        # Add a guard for type matching, these guards are checked before tensor guards
        # In some cases, a <tensor>.<attr> guard can be evaluated first, and break if
        # <tensor> is later changed to another type
        if result is not None and self.source is not None:
            result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))

        if result is None:
            raise NotImplementedError()

        return result

    def unpack_var_sequence(self, tx):
        options = VariableTracker.propagate(self)
        if self.size:
            return [
                variables.BuiltinVariable(operator.getitem, **options).call_function(
                    tx, [self, variables.ConstantVariable(i)], {}
                )
                for i in range(self.size[0])
            ]

        return super(TensorVariable, self).unpack_var_sequence(tx)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from . import ConstantVariable, TupleVariable

        kwargs = dict(kwargs)

        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "stride" and self.stride is not None:
            constant_result = ConstantVariable(self.stride, **options)
        elif name == "size" and self.size is not None:
            sizes = [variables.ConstantVariable(x) for x in self.size]
            constant_result = SizeVariable(sizes, **options)
        elif name == "numel" and self.size is not None:
            constant_result = ConstantVariable(product(self.size), **options)
        elif name in ("ndimension", "dim") and self.ndim is not None:
            constant_result = ConstantVariable(self.ndim, **options)
        elif name == "is_floating_point" and self.dtype is not None:
            constant_result = ConstantVariable(self.dtype.is_floating_point, **options)
        elif name == "is_contiguous" and self.is_contiguous is not None:
            if (
                "memory_format" in kwargs
                and kwargs["memory_format"].as_python_constant()
                == torch.contiguous_format
            ):
                kwargs.pop("memory_format")
            constant_result = ConstantVariable(self.is_contiguous, **options)
        else:
            constant_result = None

        if constant_result:
            assert not kwargs, f"Tensor.{name}() unhandled kwargs"
            if len(args) == 1:
                return constant_result.getitem_const(args[0])
            elif args:
                return TupleVariable(
                    [constant_result.getitem_const(a) for a in args], **options
                )
            return constant_result
        elif (
            name == "repeat"
            and not all(
                x.is_python_constant() for x in itertools.chain(args, kwargs.values())
            )
            and not config.dynamic_shapes
        ):
            unimplemented("dynamic Tensor.repeat")
        elif name in ("tolist", "numpy", "backward"):
            unimplemented(f"Tensor.{name}")
        elif name == "nonzero" and not config.dynamic_shapes:
            unimplemented(f"Tensor.{name}")
        elif name == "item":
            if config.capture_scalar_outputs:
                return self.__class__.create(
                    tx,
                    tx.output.create_proxy(
                        "call_method", "item", (self.as_proxy(),), {}, current_tx=tx
                    ),
                    **options,
                )
            else:
                unimplemented(f"Tensor.{name}")
        elif name == "__len__":
            if self.size:
                assert not config.dynamic_shapes
                return ConstantVariable(self.size[0], **options)
            else:
                return self.__class__.create(
                    tx,
                    tx.output.create_proxy(
                        "call_function", len, (self.as_proxy(),), {}, current_tx=tx
                    ),
                    **options,
                )
        elif name == "__setitem__":
            tx.output.guards.update(options["guards"])
            tx.output.create_proxy(
                "call_function",
                operator.setitem,
                *proxy_args_kwargs([self] + args, kwargs),
                current_tx=tx,
            )
            return ConstantVariable(None, **options)
        else:
            # Convert x.new(torch.Size) into x.new_empty(torch.Size),
            # as Tensor.new acts differently with a Size input versus a tuple input.
            if (
                name == "new"
                and len(args) == 1
                and isinstance(args[0], (SizeVariable, ShapeVariable))
                and not config.dynamic_shapes
            ):
                name = "new_empty"

            return self.__class__.create(
                tx,
                tx.output.create_proxy(
                    "call_method",
                    name,
                    *proxy_args_kwargs([self] + args, kwargs),
                    current_tx=tx,
                ),
                **options,
            )


class DynamicShapeVariable(TensorVariable):
    """
    Represents a symbolic size, e.g., as returned by tensor.size(0)
    """

    def __init__(self, proxy, dyn_shape, **kwargs):
        super(DynamicShapeVariable, self).__init__(proxy, **kwargs)
        self.dyn_shape = dyn_shape

    def python_type(self):
        return type(self.dyn_shape)

    def unpack_var_sequence(self, tx):
        super(DynamicShapeVariable, self).unpack_var_sequence(tx)


class TensorWithTFOverrideVariable(VariableTracker):
    """
    Represents a tensor subclass instance with a __torch_function__ override.
    """

    def __init__(
        self,
        tensor_variable,
        orig_tensor_variable_source,
        subclass_torch_function__func,
        subclass_type,
        **kwargs,
    ):
        super(TensorWithTFOverrideVariable, self).__init__(**kwargs)
        self.tensor_variable = tensor_variable
        self.orig_tensor_variable_source = orig_tensor_variable_source
        self.subclass_torch_function__func = subclass_torch_function__func
        self.subclass_type = subclass_type

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # This code block implements inlining the __torch_function__ override
        # of `call_method`.
        from . import GetAttrVariable

        options = VariableTracker.propagate(self, args, kwargs.values())
        # insert unwrapped version of self as the first argument
        args = list(args)
        args.insert(0, self.tensor_variable)
        func_var = GetAttrVariable(self.tensor_variable, name)

        unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
            tx,
            func_var,
            self.orig_tensor_variable_source,
            self.subclass_torch_function__func,
            self.subclass_type,
            options,
            args,
            kwargs,
        )

        # TODO(future PR): implement rewrapping conditional on method presence
        # in `torch.overrides.get_default_nowrap_function()`. It's unclear how
        # to do this easily in the current codebase since the resolution of
        # `GetAttrVariable` depends on the type of the underlying object.

        return TensorWithTFOverrideVariable(
            unwrapped,
            self.orig_tensor_variable_source,
            self.subclass_torch_function__func,
            self.subclass_type,
        )

    @staticmethod
    def inline_torch_function_unwrapped(
        tx,
        original_func_var,
        tensor_with_tf_override_source,
        tf_func,
        subclass_type,
        options,
        args,
        kwargs,
    ):
        """
        This function inlines the `__torch_function__` override for `original_func_var`.
        For example, if the user code is

        x1 = torch.sigmoid(x0)

        And `x0` has an override, then:
        * `original_func_var` will be a `VariableTracker` object wrapping `torch.sigmoid`
        * `tensor_with_tf_override_source` will be the `Source` object from
          the original tensor override instance in the beginning of the program
        * `tf_func` will be the custom `__torch_function__` function
        * `subclass_type` will be `type(x0)`

        The caller is expected to properly massage args and kwargs before
        passing them into this function.

        The caller is responsible for wrapping the return value, if needed.
        """
        from . import UserDefinedClassVariable
        from .builder import TupleVariable, VariableBuilder

        source = AttrSource(
            AttrSource(tensor_with_tf_override_source, "__torch_function__"),
            "__func__",
        )
        tf_func_var = VariableBuilder(tx, source)(tf_func)
        type_var = UserDefinedClassVariable(subclass_type, **options)

        # signature:
        # def __torch_function__(cls, func, types, args=(), kwargs=None):
        tf_args = (
            type_var,  # cls
            original_func_var,  # func
            (type_var,),  # types
            TupleVariable(args),  # args
            kwargs,  # kwargs
        )

        # Disable __torch_function__ here to prevent the clone of the
        # example tensor from going into the override.
        with torch._C.DisableTorchFunction():
            return tx.inline_user_function_return(tf_func_var, tf_args, {})


class UnspecializedNumpyVariable(TensorVariable):
    """
    This is a 1-element tensor that represents an unspecialized numpy float/int.
    """

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        raw_value = kwargs.pop("raw_value", None)
        super(UnspecializedNumpyVariable, self).__init__(proxy, **kwargs)
        self.raw_value = raw_value

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value):
        # Convert a `TensorVariable` instance into an `UnspecializedNumpyVariable` instance.
        return UnspecializedNumpyVariable(
            **dict(tensor_variable.__dict__), raw_value=raw_value
        )

    def as_specialized(self, tx):
        for graph_arg in tx.output.graphargs:
            if graph_arg.source is self.source:
                graph_arg.erase()

        for g in self.guards:
            if g.is_volatile:
                g.create_fn = GuardBuilder.CONSTANT_MATCH

        return ConstantVariable(value=self.raw_value, guards=self.guards)


class UnspecializedPythonVariable(TensorVariable):
    """
    This is a 1-element tensor that represents an unspecialized python float/int.
    """

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        raw_value = kwargs.pop("raw_value", None)
        need_unwrap = kwargs.pop("need_unwrap", True)
        super(UnspecializedPythonVariable, self).__init__(proxy, **kwargs)
        self.raw_value = raw_value
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
        # Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
        return UnspecializedPythonVariable(
            **dict(tensor_variable.__dict__),
            raw_value=raw_value,
            need_unwrap=need_unwrap,
        )

    def as_specialized(self, tx):
        for graph_arg in tx.output.graphargs:
            if graph_arg.source is self.source:
                graph_arg.erase()

        for g in self.guards:
            if g.is_volatile:
                g.create_fn = GuardBuilder.CONSTANT_MATCH

        return ConstantVariable(value=self.raw_value, guards=self.guards)


class FakeItemVariable(TensorVariable):
    """An unspecialized python variable which prevents access to the underlying raw value.
    This is needed if item is called on a FakeTensor."""

    def __init__(self, proxy: torch.fx.Proxy, **kwargs):
        need_unwrap = kwargs.pop("need_unwrap", False)
        super(FakeItemVariable, self).__init__(proxy, **kwargs)
        self.need_unwrap = need_unwrap

    @classmethod
    def from_tensor_variable(cls, tensor_variable):
        return FakeItemVariable(**dict(tensor_variable.__dict__))