Add unbacked symints support; item works now (#90624)
The big idea is to add `create_unbacked_symfloat` and `create_unbacked_symint` to ShapeEnv, allowing you to allocate symbolic floats/ints corresponding to data you don't know about at compile time. Then, instead of immediately erroring out when you try to call local_scalar_dense on a FakeTensor, we create a fresh symint/symfloat and return that.

There are a bunch of odds and ends that need to be handled:

* A number of `numel` calls are converted to `sym_numel`.
* When we finally return from item(), we need to ensure we actually produce a SymInt/SymFloat when appropriate. The previous binding code assumed you would always get back a plain Python number. I add a pybind11 binding for Scalar (to PyObject only) and refactor the code to use that. There is some trickiness where you are NOT allowed to go through c10::SymInt if there isn't actually any SymInt involved; see the comment in the new pybind11 Scalar caster.
* One of our unit tests tripped an implicit data-dependent access, which occurs when you pass a Tensor as an argument to a sizes parameter. This is also converted to support symbolic shapes.
* We now support tracking bare SymInt/SymFloat returns in proxy tensor mode (this was already in the symbolic-shapes branch).
* Whenever we allocate an unbacked symint, we record the stack trace it was allocated at. These get printed when you attempt data-dependent access on the symint (e.g., when you try to guard on it).
* Subtlety: unbacked symints are not necessarily > 1. I added a test for this.

These unbacked symints are not very useful right now, as you will almost always immediately raise an error later when you try to guard on them. The next logical step is adding an assertion refinement system that lets ShapeEnv learn facts about unbacked symints so it can do a better job eliding guards that are unnecessary.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90624
Approved by: https://github.com/Skylion007, https://github.com/voznesenskym
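As a quick illustration of the end-to-end behavior, here is a minimal sketch that mirrors the `test_item` test added in this patch (the import paths are the ones the test itself uses); it is not part of the diff:

```python
# Minimal sketch, mirroring the new test_item test in this patch.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    r = a.item()   # data-dependent value -> unbacked symfloat during symbolic tracing
    return r * a

# With tracing_mode="symbolic", .item() no longer errors out; the traced graph
# records a call to torch.ops.aten._local_scalar_dense.default instead.
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(1))
print(gm.code)
```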
This commit is contained in:
parent 6702345416
commit f7365eca90
@@ -15,7 +15,7 @@ namespace at {
 namespace native {
 
 Scalar item(const Tensor& self) {
-  int64_t numel = self.numel();
+  auto numel = self.sym_numel();
   TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
   if (self.is_sparse()) {
     if (self._nnz() == 0) return Scalar(0);

@@ -10,7 +10,7 @@ from collections.abc import Iterable
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
-from torch._subclasses.fake_tensor import DynamicOutputShapeException
+from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException
 
 from torch._decomp import decomposition_table
 from torch.fx.experimental.symbolic_shapes import sym_float, eval_guards, fx_placeholder_vals

@@ -423,12 +423,15 @@ def forward(self, x_1):
         def f(a, b):
             return torch.allclose(a, b)
 
-        self.assertRaisesRegex(
-            RuntimeError, "data-dependent",
-            lambda: make_fx(f, tracing_mode=self.tracing_mode)(
+        def test_f():
+            make_fx(f, tracing_mode=self.tracing_mode)(
                 torch.zeros(3), torch.zeros(3)
             )
-        )
+
+        if self.tracing_mode == "symbolic":
+            self.assertRaises(DataDependentOutputException, test_f)
+        else:
+            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
 
     def test_constant_proxy_tensor_mut(self):
         def f():

@@ -454,7 +457,7 @@ def forward(self, x_1):
         def f():
             val = torch.tensor([2])
             blowup = val.repeat(1000)
-            return blowup.sum().item()
+            return bool(blowup.sum().item() == 2)
 
         self.assertRaisesRegex(
             RuntimeError, "data-dependent",

@@ -465,7 +468,7 @@ def forward(self, x_1):
         def f():
             val = torch.tensor([2.0])
             val.normal_()
-            return val.item()
+            return bool(val.item() == 2.1)
 
         self.assertRaisesRegex(
             RuntimeError, "data-dependent",

@@ -847,6 +850,18 @@ def forward(self, a_1):
     empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
     return empty""")
 
+    def test_item(self):
+        def f(a):
+            r = a.item()
+            return r * a
+
+        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, a_1):
+    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
+    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense); a_1 = _local_scalar_dense = None
+    return mul""")
+
 
     def test_neg_shape(self):
         def f(a):

@@ -47,6 +47,7 @@
 #include <ATen/Functions.h>
 #else
 $ops_headers
+#include <ATen/ops/_local_scalar_dense.h>
 #endif
 
 using at::DeviceGuard;

@@ -317,7 +318,7 @@ static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non
 static double dispatch_to_CDouble(const Tensor & self) {
   pybind11::gil_scoped_release no_gil;
   OptionalDeviceGuard device_guard(device_of(self));
-  if (self.numel() != 1) {
+  if (self.sym_numel() != 1) {
     throw ValueError("only one element tensors can be converted to Python scalars");
   }
   return self.item<double>();

@@ -326,7 +327,7 @@ static double dispatch_to_CDouble(const Tensor & self) {
 static c10::complex<double> dispatch_to_CComplexDouble(const Tensor & self) {
   pybind11::gil_scoped_release no_gil;
   OptionalDeviceGuard device_guard(device_of(self));
-  if (self.numel() != 1) {
+  if (self.sym_numel() != 1) {
     throw ValueError("only one element tensors can be converted to Python scalars");
   }
   return self.item<c10::complex<double>>();

@@ -335,21 +336,12 @@ static c10::complex<double> dispatch_to_CComplexDouble(const Tensor & self) {
 static int64_t dispatch_to_CLong(const Tensor & self) {
   pybind11::gil_scoped_release no_gil;
   OptionalDeviceGuard device_guard(device_of(self));
-  if (self.numel() != 1) {
+  if (self.sym_numel() != 1) {
     throw ValueError("only one element tensors can be converted to Python scalars");
   }
   return self.item<int64_t>();
 }
 
-static bool dispatch_to_Bool(const Tensor & self) {
-  pybind11::gil_scoped_release no_gil;
-  OptionalDeviceGuard device_guard(device_of(self));
-  if (self.numel() != 1) {
-    throw ValueError("only one element tensors can be converted to Python scalars");
-  }
-  return self.item<bool>();
-}
-
 static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
   if (check_has_torch_function(self)) {

@@ -399,7 +391,7 @@ static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
   auto& self_ = THPVariable_Unpack(self);
   // TODO: change the condition to `self_.dim() != 0` once we expose scalars
   // in PyTorch.
-  if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true) || self_.numel() != 1) {
+  if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true) || self_.sym_numel() != 1) {
     throw TypeError("only integer tensors of a single element can be converted to an index");
   }
   return wrap(dispatch_to_CLong(self_));

@@ -883,15 +875,7 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args)
   }
   jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW);
   auto& self_ = THPVariable_Unpack(self);
-  if (self_.is_floating_point()) {
-    return wrap(dispatch_to_CDouble(self_));
-  } else if (self_.is_complex()) {
-    return wrap(dispatch_to_CComplexDouble(self_));
-  } else if (self_.scalar_type() == ScalarType::Bool) {
-    return wrap(dispatch_to_Bool(self_));
-  } else {
-    return wrap(dispatch_to_CLong(self_));
-  }
+  return py::cast(self_.item()).release().ptr();
   END_HANDLE_TH_ERRORS
 }
 

@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Un
 
 import torch
 from torch._ops import OpOverload
+from torch._prims_common import is_float_dtype, is_integer_dtype
 from torch._subclasses.meta_utils import MetaConverter, WeakTensorRefKey
 from torch.fx.operator_schemas import normalize_function
 from torch.multiprocessing.reductions import StorageWeakRef

@@ -375,6 +376,20 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)
 
 
+@register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
+def local_scalar_dense(fake_mode, func, arg):
+    if fake_mode.shape_env is None:
+        # Without symints/symfloats, cannot handle this
+        raise DataDependentOutputException(func)
+    if is_float_dtype(arg.dtype):
+        return fake_mode.shape_env.create_unbacked_symfloat()
+    elif is_integer_dtype(arg.dtype):
+        return fake_mode.shape_env.create_unbacked_symint()
+    else:
+        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
+
+
+# NB: this must be ordered after local_scalar_dense
 @register_op_impl(
     lambda func: torch.Tag.data_dependent_output in func.tags  # type: ignore[attr-defined]
 )

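For context, a hedged sketch of how the op impl above is exercised; the `shape_env=` constructor argument is assumed from the `fake_mode.shape_env` attribute used in the hunk, and the sketch is not part of the patch:

```python
# Hedged sketch (assumes FakeTensorMode accepts shape_env=, as implied by
# fake_mode.shape_env above); behavior may vary across PyTorch versions.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with fake_mode:
    t = torch.tensor([3])  # constructed inside the mode, so it is a FakeTensor
    # Dispatches to the local_scalar_dense impl registered above; with an
    # integer dtype this yields an unbacked SymInt instead of raising.
    s = torch.ops.aten._local_scalar_dense.default(t)
```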
@@ -138,6 +138,10 @@ c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
   return c10::fromIntArrayRefKnownNonNegative(sizes_custom());
 }
 
+c10::SymInt LTCTensorImpl::sym_numel_custom() const {
+  return numel_custom();
+}
+
 void LTCTensorImpl::setup_size_properties() {
   size_t generation = tensor_->generation();
   if (generation != generation_) {

@@ -48,6 +48,7 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
 
   c10::SymIntArrayRef sym_sizes_custom() const override;
   c10::SymIntArrayRef sym_strides_custom() const override;
+  c10::SymInt sym_numel_custom() const override;
 
  private:
   void setup_size_properties();

@@ -79,5 +79,39 @@ py::handle type_caster<c10::SymFloat>::cast(
   }
 }
 
+bool type_caster<c10::Scalar>::load(py::handle src, bool) {
+  TORCH_INTERNAL_ASSERT(
+      0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
+}
+
+py::handle type_caster<c10::Scalar>::cast(
+    const c10::Scalar& scalar,
+    return_value_policy /* policy */,
+    handle /* parent */) {
+  if (scalar.isIntegral(/*includeBool*/ false)) {
+    // We have to be careful here; we cannot unconditionally route through
+    // SymInt because integer data from Tensors can easily be MIN_INT or
+    // very negative, which conflicts with the allocated range.
+    if (scalar.isSymbolic()) {
+      return py::cast(scalar.toSymInt()).release();
+    } else {
+      return py::cast(scalar.toLong()).release();
+    }
+  } else if (scalar.isFloatingPoint()) {
+    // This isn't strictly necessary but we add it for symmetry
+    if (scalar.isSymbolic()) {
+      return py::cast(scalar.toSymFloat()).release();
+    } else {
+      return py::cast(scalar.toDouble()).release();
+    }
+  } else if (scalar.isBoolean()) {
+    return py::cast(scalar.toBool()).release();
+  } else if (scalar.isComplex()) {
+    return py::cast(scalar.toComplexDouble()).release();
+  } else {
+    TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
+  }
+}
+
 } // namespace detail
 } // namespace pybind11

@@ -203,6 +203,20 @@ struct type_caster<c10::DispatchKey>
   }
 };
 
+template <>
+struct TORCH_PYTHON_API type_caster<c10::Scalar> {
+ public:
+  PYBIND11_TYPE_CASTER(
+      c10::Scalar,
+      _("Union[Number, torch.SymInt, torch.SymFloat]"));
+  bool load(py::handle src, bool);
+
+  static py::handle cast(
+      const c10::Scalar& si,
+      return_value_policy /* policy */,
+      handle /* parent */);
+};
+
 template <>
 struct TORCH_PYTHON_API type_caster<c10::SymInt> {
  public:

@@ -565,9 +565,9 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
               var.dtype().toScalarType(), /*include_bool*/ true)) {
         throw_intlist_exception(this, i, obj, idx);
       }
-      // TODO: ideally, if this was a fake tensor this would
-      // result in a SymInt, but we don't have the API to do this
-      res.push_back(var.item<int64_t>());
+      auto scalar = var.item();
+      TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
+      res.push_back(scalar.toSymInt());
     } else {
       try {
         if (is_symint(py::handle(obj))) {

@@ -183,6 +183,10 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
         if isinstance(e, torch.Tensor):
             track_tensor(e, proxy, tracer=tracer, constant=constant)
             set_meta(proxy, e)
+        elif isinstance(e, py_sym_types):
+            # NB: eagerly set meta here, so that the numbering is in order
+            set_meta(proxy, e)
+            set_proxy_slot(e.node, tracer, lambda: proxy)
         elif isinstance(e, list):
            # example use case: allreduce_ returns ([tensor], work)
            for idx, ee in enumerate(e):

@@ -202,7 +206,7 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
             set_meta(proxy_res, inner_res)
         for idx, e in enumerate(inner_res):
             wrap_with_proxy(e, proxy_res[idx], get_constant(idx))
-    elif isinstance(inner_res, torch.Tensor):
+    elif isinstance(inner_res, py_sym_types + (torch.Tensor,)):
         wrap_with_proxy(inner_res, proxy_res, constant)
 
     return inner_res

@@ -281,10 +285,13 @@ def proxy_call(proxy_mode, func, args, kwargs):
             )
         with maybe_disable_fake_tensor_mode():
             return func(*const_args, **const_kwargs)
-        raise RuntimeError(
-            f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
-            "It's likely that this is caused by data-dependent control flow or similar."
-        )
+        # For symbolic tracing, we return a SymInt/SymFloat and try to
+        # get further in the trace
+        if proxy_mode.tracing_mode != "symbolic":
+            raise RuntimeError(
+                f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
+                "It's likely that this is caused by data-dependent control flow or similar."
+            )
     proxy_args, proxy_kwargs = pytree.tree_map_only(
         (SymInt, SymFloat),
         fetch_sym_proxy(proxy_mode.tracer),

@@ -471,8 +478,9 @@ def wrap_key(f, tensors, tracer):
 
 
 class ProxyTorchDispatchMode(TorchDispatchMode):
-    def __init__(self, tracer):
+    def __init__(self, tracer, tracing_mode):
         self.tracer = tracer
+        self.tracing_mode = tracing_mode
         self.enable_tracing = True
         self.sym_mode = ProxySymDispatchMode(tracer)
         self.trace_state = {}

@@ -575,7 +583,7 @@ class DecompositionInterpreter(torch.fx.Interpreter):
         self.decomposition_table = decomposition_table
         if self.decomposition_table is None:
             self.decomposition_table = {}
-        self.mode = ProxyTorchDispatchMode(self.tracer)
+        self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
 
     def placeholder(self, target, args, kwargs):
         out = super().placeholder(target, args, kwargs)

@@ -652,7 +660,7 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
     if tracing_mode == "symbolic":
         python_dispatcher_mode = enable_python_dispatcher()
 
-    proxy_mode = ProxyTorchDispatchMode(fx_tracer)
+    proxy_mode = ProxyTorchDispatchMode(fx_tracer, tracing_mode)
 
     arg_count = 0
 

@@ -1,6 +1,7 @@
 import torch
 from typing import Set, Dict, List, Type, Optional, cast
 import sys
+import itertools
 import operator
 import builtins
 import math

@@ -450,12 +451,14 @@ def _lru_cache(fn, maxsize=None):
 # name get interned. This is bad for us as we want the metadata (snames)
 # to vary across different invocations and not leak.
 class Symbol(sympy.Dummy):
-    __slots__: List[str] = ['snames']
+    __slots__: List[str] = ['snames', 'stack']
     snames: List[str]
+    stack: Optional[str]
 
     def __new__(cls, *args, **kwargs):
         self = super().__new__(cls, *args, **kwargs)
         self.snames = []
+        self.stack = None
         return self
 
 

@@ -489,6 +492,8 @@ class ShapeEnv(object):
         # they get assigned the same symbolic variable
         self.val_to_var: Dict[int, "sympy.Expr"] = {0: sympy.Integer(0), 1: sympy.Integer(1)}
         self.tls = threading.local()
+        self.unbacked_symfloat_counter = itertools.count()
+        self.unbacked_symint_counter = itertools.count()
 
     def _suppress_guards_tls(self):
         return getattr(self.tls, "suppress_guards", False)

@@ -557,6 +562,16 @@ class ShapeEnv(object):
     def create_symintnode(self, sym: "sympy.Expr"):
         return SymInt(SymNode(sym, self, int))
 
+    def create_unbacked_symfloat(self):
+        symbol = Symbol(f"f{next(self.unbacked_symfloat_counter)}")
+        symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
+        return SymFloat(SymNode(symbol, self, float))
+
+    def create_unbacked_symint(self):
+        symbol = Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
+        symbol.stack = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
+        return SymInt(SymNode(symbol, self, int))
+
     # This is guaranteed to return a symbol or its negation is a sympy.Symbol,
     # but there may be a replacement that allows it to be immediately
     # simplified

@@ -776,6 +791,8 @@ class ShapeEnv(object):
         new_shape_env = {
             k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1
             for idx, k in enumerate(symbols)
+            # Do not assume unbacked symints are > 1
+            if k in self.var_to_val
         }
         new_expr = expr.xreplace(new_shape_env)
         floor_div_replace = {}

@@ -823,9 +840,27 @@ class ShapeEnv(object):
         your code is still valid for arbitrary shapes (such as optimization decisions)
         """
         result_expr = sympy.expand(expr).xreplace(self.var_to_val)
-        assert len(result_expr.free_symbols) == 0, "Size hint has variables we don't have underlying values for"
+        if len(result_expr.free_symbols) != 0:
+            raise self._make_data_dependent_error(result_expr)
         return result_expr
 
+    def _make_data_dependent_error(self, expr):
+        # TODO: in a Dynamo context, having user code, and having the
+        # name of the local, will be much better
+        accesses = '\n\n'.join(
+            f"Data dependent variable '{s}' allocated at:\n{s.stack}"
+            for s in expr.free_symbols
+        )
+        return RuntimeError(
+            f"\n\n{accesses}\n"
+            "RuntimeError: It appears that you're trying to get a value out of symbolic int/float "
+            "whose value is data-dependent (and thus we do not know the true value.) "
+            f"The expression we were trying to evaluate is {expr}. "
+            "Scroll up to see where each of these data-dependent accesses originally occurred."
+            # TODO: Help text about how to use our runtime tests to fix this
+            # problem
+        )
+
     @_lru_cache
     def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
         """
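Tying the ShapeEnv pieces together, a hedged usage sketch (not part of the patch) of `create_unbacked_symint` and the new data-dependent error path:

```python
# Hedged sketch of the API added in this patch; not part of the diff.
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()   # no backing value; allocation stack recorded on the Symbol
try:
    if i0 > 1:   # guarding needs a size hint, which an unbacked symbol cannot provide
        pass
except RuntimeError as e:
    # The message comes from _make_data_dependent_error and includes the
    # stack trace recorded when the unbacked symint was allocated.
    print(e)
```

This is exactly the failure mode the commit message describes: unbacked symints trace fine until something tries to guard on them, which is where the planned assertion refinement system would come in.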