Disallow {FakeTensor,FunctionalTensor}.data_ptr (#122514)
This PR:
- disallows FakeTensor.data_ptr when it is called inside PT2 or fx tracing.
- disallows FunctionalTensor.data_ptr (the Python FunctionalTensor is only used in PT2).

The motivation behind this is that the leading cause of segfaults when using custom ops with PT2 is calling .data_ptr on a FunctionalTensor or FakeTensor.

This change is BC-breaking. If your code broke as a result of it, that is because the code had a bug (these .data_ptr calls should never happen). You can either fix the bug (recommended) or get the previous behavior back with:
```
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```

Test Plan:
- existing tests

Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122514

Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
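For readers hitting the new error, a minimal sketch of the recommended fix follows: register the kernel as an opaque custom op so PT2 never traces into the body that touches .data_ptr. It uses the same torch.library.define/impl APIs exercised by the tests in this diff; the op name mylib::my_kernel, the CPU kernel body, and the use of torch.library.impl_abstract for the fake/meta implementation are illustrative assumptions, not part of this PR.

```
import torch

# Hypothetical op; the schema and kernel body are illustrative only.
torch.library.define("mylib::my_kernel", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::my_kernel", "CPU")
def _my_kernel_cpu(x):
    # A real kernel may read x.data_ptr(); PT2 treats the op as opaque and
    # never traces into this body, so the pointer here is always a real one.
    _ = x.data_ptr()
    return x.clone()

@torch.library.impl_abstract("mylib::my_kernel")
def _my_kernel_fake(x):
    # Runs under FakeTensorMode during tracing: shape/dtype propagation only,
    # no data-pointer access.
    return torch.empty_like(x)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return torch.ops.mylib.my_kernel(x)

print(f(torch.randn(3)))  # the op stays opaque instead of raising
```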
This commit is contained in:
parent 04399a3091
commit c81c9ba472
@@ -11,6 +11,16 @@ C10_API std::array<StorageImplCreateHelper, at::COMPILE_TIME_MAX_DEVICE_TYPES>
static ska::flat_hash_set<c10::DeviceType> DeviceTypeAllowList{
    DeviceType::PrivateUse1};

void throwNullDataPtrError() {
  TORCH_CHECK(
      false,
      "Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). "
      "If you're using torch.compile/export/fx, it is likely that we are erroneously "
      "tracing into a custom kernel. To fix this, please wrap the custom kernel into "
      "an opaque custom op. Please see the following for details: "
      "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ");
}

void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
  // Allowlist verification.
  // Only if the devicetype is in the allowlist,
@@ -16,6 +16,8 @@

namespace c10 {

C10_API void throwNullDataPtrError();

// A storage represents the underlying backing data buffer for a
// tensor. This concept was inherited from the original Torch7
// codebase; we'd kind of like to get rid of the concept
@@ -59,6 +61,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
      TORCH_INTERNAL_ASSERT(
          allocator_, "For resizable storage, allocator must be provided");
    }
    refresh_has_data_ptr_check();
  }

  StorageImpl(
@@ -118,12 +121,22 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
    return resizable_;
  }

  at::DataPtr& mutable_data_ptr() {
    maybe_materialize_cow();
  const at::DataPtr& data_ptr() const {
    return data_ptr_;
  }

  const at::DataPtr& data_ptr() const {
  at::DataPtr& mutable_data_ptr() {
    if (C10_UNLIKELY(has_data_ptr_check_)) {
      if (throw_on_mutable_data_ptr_) {
        throwNullDataPtrError();
      }
      maybe_materialize_cow();
    }
    return data_ptr_;
  }

  // Returns the data_ptr. Bypasses all checks.
  at::DataPtr& _mutable_data_ptr_no_checks() {
    return data_ptr_;
  }
@@ -137,6 +150,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {

  void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
    data_ptr_ = std::move(data_ptr);
    refresh_has_data_ptr_check();
  }

  const void* data() const {
@@ -144,7 +158,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
  }

  void* mutable_data() {
    maybe_materialize_cow();
    if (C10_UNLIKELY(has_data_ptr_check_)) {
      if (throw_on_mutable_data_ptr_) {
        throwNullDataPtrError();
      }
      maybe_materialize_cow();
    }
    return data_ptr_.mutable_get();
  }
@@ -222,6 +241,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
    return &pyobj_slot_;
  }

  void set_throw_on_mutable_data_ptr() {
    throw_on_mutable_data_ptr_ = true;
    refresh_has_data_ptr_check();
  }

 protected:
  // materialize_cow_storage needs to call set_data_ptr_no_materlize_cow
  friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
@@ -231,13 +255,22 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
  at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) {
    at::DataPtr old_data_ptr(std::move(data_ptr_));
    data_ptr_ = std::move(data_ptr);
    refresh_has_data_ptr_check();
    return old_data_ptr;
  }

 private:
  void refresh_has_data_ptr_check() {
    has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_;
  }

  inline bool is_cow() const {
    return c10::impl::cow::is_cow_data_ptr(data_ptr_);
  }

  // Triggers a copy if this is a copy-on-write tensor.
  void maybe_materialize_cow() {
    if (data_ptr_.get_deleter() == impl::cow::cow_deleter) {
    if (is_cow()) {
      impl::cow::materialize_cow_storage(*this);
    }
  }
@@ -249,6 +282,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
  // Identifies that Storage was received from another process and doesn't have
  // local to process cuda memory allocation
  bool received_cuda_;
  // All special checks in data/data_ptr calls are guarded behind this single
  // boolean. This is for performance: .data/.data_ptr calls are commonly in the
  // hot-path.
  bool has_data_ptr_check_ = false;
  // If we should throw when mutable_data_ptr() or mutable_data() is called.
  bool throw_on_mutable_data_ptr_ = false;
  Allocator* allocator_;
  impl::PyObjectSlot pyobj_slot_;
};
@@ -81,7 +81,7 @@ c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
  if (has_simple_data_ptr(storage)) {
    // Case 1) We have a simple data pointer: wrap it.
    std::unique_ptr<void, DeleterFnPtr> original_ctx =
        storage.mutable_data_ptr().move_context();
        storage._mutable_data_ptr_no_checks().move_context();

    // Save this for the result.
    new_data_ptr = make_data_ptr(
@@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
import copy
import re
import unittest
from textwrap import dedent
@@ -12,6 +13,8 @@ import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dynamo.testing import CompileCounter, expectedFailureDynamic, rand_strided
from torch._functorch.aot_autograd import _aot_export_function, create_functional_call
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.profiler import profile
from torch.testing._internal.common_utils import compare_equal_outs_and_grads
@@ -1104,6 +1107,77 @@ SeqNr|OrigAten|SrcFn
        self.assertEqual(x, x_opt)
        self.assertEqual(z.grad, z_opt.grad)

    def test_data_ptr_access_copy(self):
        with FakeTensorMode(_allow_unsafe_data_ptr_access=False):
            x = torch.randn(3)
            y = copy.copy(x)
        self.assertEqual(y.shape, x.shape)

    def test_data_ptr_access_fails_in_forward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                x.data_ptr()
                return x.clone()

            x = torch.randn(3)

            def data_ptr_graph_input(x):
                r0 = torch.ops.mylib.foo(x)
                return r0

            def data_ptr_graph_intermediate(x):
                y = x.clone()
                r0 = torch.ops.mylib.foo(y)
                return r0

            tests = [data_ptr_graph_input, data_ptr_graph_intermediate]

            def ctx():
                return self.assertRaisesRegex(
                    RuntimeError, "Cannot access data pointer"
                )

            for f in tests:
                with ctx():
                    make_fx(f, tracing_mode="fake")(x)
                with ctx():
                    make_fx(f, tracing_mode="symbolic")(x)
                with ctx():
                    torch.compile(f, backend="eager", fullgraph=True)(x)

    def test_data_ptr_access_fails_in_backward(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)

            backward_called = False

            class Foo(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    return x.clone()

                @staticmethod
                def backward(ctx, grad):
                    nonlocal backward_called
                    backward_called = True
                    grad.data_ptr()
                    return grad.clone()

            @torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
            def _(x):
                return Foo.apply(x)

            def f(x):
                return torch.ops.mylib.foo(x)

            x = torch.randn(3, requires_grad=True)
            with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
                y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
            self.assertTrue(backward_called)

    # We don't know how to catch multiple mutations to the same memory location
    @unittest.expectedFailure
    def test_aot_autograd_expand_mutation_error(self):
@@ -2,6 +2,7 @@

import functools
import sys
import unittest

from unittest import skipIf as skipif
@@ -15,6 +16,7 @@ from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    xpassIfTorchDynamo,
@@ -46,7 +48,8 @@ class TestDLPack(TestCase):
        del y
        assert sys.getrefcount(x) == 2

    @xpassIfTorchDynamo  # (reason="pytorch does not raise")
    @unittest.expectedFailure
    @skipIfTorchDynamo("I can't figure out how to get __dlpack__ into trace_rules.py")
    def test_dunder_dlpack_stream(self):
        x = np.arange(5)
        x.__dlpack__(stream=None)
@@ -1472,6 +1472,7 @@ def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ...
def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ...
def _functionalization_reapply_views_tls() -> _bool: ...
def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...

class DispatchKey(Enum):
    ${dispatch_key_hints}
@@ -296,6 +296,7 @@ class OutputGraph:
            shape_env=shape_env,
            # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
            allow_non_fake_inputs=True if self.export else False,
            _allow_unsafe_data_ptr_access=False,
        )
        self.tracing_context: TracingContext = TracingContext(fake_mode)
        self.init_ambient_guards()
@@ -1138,6 +1139,7 @@ class OutputGraph:
            # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
            backend_fake_mode = torch._subclasses.FakeTensorMode(
                shape_env=old_fake_mode.shape_env,
                _allow_unsafe_data_ptr_access=False,
            )
            # TODO(voz): Ostensibily, this should be scoped and
            # restore back to old_fake_mode, but doing so currently violates
@@ -104,6 +104,8 @@ manual_torch_name_rule_map = {
    "torch.compiler.is_compiling": TorchInGraphFunctionVariable,
    "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
    "torch.autograd._profiler_enabled": SkipFunctionVariable,
    "torch._C._to_dlpack": SkipFunctionVariable,
    "torch.to_dlpack": SkipFunctionVariable,
    # We graph break on RNG state setters or getters like
    # `torch.get_rng_state` or `torch.set_rng_state`. These functions
    # are not aten operations and therefore they are completely ignored
@@ -1187,7 +1189,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch._C._test_only_populate_upgraders",
        "torch._C._test_only_remove_entry_to_op_version_map",
        "torch._C._test_only_remove_upgraders",
        "torch._C._to_dlpack",
        "torch._C._to_functionality_key",
        "torch._C._tracer_set_force_outplace",
        "torch._C._tracer_set_get_unique_name_fn",
@@ -433,6 +433,8 @@ class FakeTensor(torch.Tensor):
            dispatch_device=True,
            device_for_backend_keys=device,
        )
        if not fake_mode._allow_unsafe_data_ptr_access:
            torch._C._set_throw_on_mutable_data_ptr(self)

        assert elem.device.type == "meta", elem.device.type
        device = device if isinstance(device, torch.device) else torch.device(device)
@@ -759,8 +761,10 @@ class FakeTensorMode(TorchDispatchMode):
        allow_non_fake_inputs=False,
        shape_env=None,
        static_shapes=None,
        _allow_unsafe_data_ptr_access=True,
    ):
        log.debug("create_mode 0x%x", id(self))
        self._allow_unsafe_data_ptr_access = _allow_unsafe_data_ptr_access
        self.allow_fallback_kernels = allow_fallback_kernels
        self.fake_tensor_converter = FakeTensorConverter()
        if static_shapes is not None:
@@ -120,6 +120,7 @@ class FunctionalTensor(torch.Tensor):
            False,  # dispatch_layout
            extra_dispatch_keys,  # _extra_dispatch_keys
        )
        torch._C._set_throw_on_mutable_data_ptr(out)
        out.elem = elem
        return out
@@ -378,9 +378,18 @@ class Tensor(torch._C.TensorBase):
            )
            return (torch._utils._rebuild_nested_tensor, args_nested)
        elif (
            self.data_ptr() == 0
            and type(self) is not torch.Tensor
            type(self) is not torch.Tensor
            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
            and (
                isinstance(
                    self,
                    (
                        torch._subclasses.fake_tensor.FakeTensor,
                        torch._subclasses.functional_tensor.FunctionalTensor,
                    ),
                )
                or self.data_ptr() == 0
            )
        ):
            arg_wrapper_subclass = (
                type(self),
@@ -536,11 +536,33 @@ _allow_mutation_on_saved_tensors_enabled = False


def _get_tid(t) -> Tuple[int, int, int]:
    return (id(t), t.data_ptr(), t._version)
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (id(t), data_ptr, t._version)


def _get_sid(t) -> Tuple[int, int]:
    return (t.data_ptr(), t._version)
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (data_ptr, t._version)


class _Handle:
@@ -838,6 +838,20 @@ void initDispatchBindings(PyObject* module) {
        return a.sizes(); // NB: NOT sym_size
      });

  m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    // Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr
    // will throw.
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_throw_on_mutable_data_ptr();
  });

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
@@ -127,7 +127,19 @@ def _register_tensor_wrapper(tensor) -> None:
        # Tensor storage -> work mapping is maintained in C++
        return
    global data_ptr_to_work
    data_ptr = tensor.elem.data_ptr()

    # FIXME: This is almost definitely a bug.
    if isinstance(
        tensor.elem,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = tensor.elem.data_ptr()

    # Note: we should NEVER try to trace this, bc it registers runtime stuff during trace.
    # Instead, backends must call this themselves when implementing traced collectives.
    wait_reg = data_ptr_to_work.get(data_ptr, None)
@@ -1101,6 +1101,7 @@ def make_fx(f,
            allow_non_fake_inputs=_allow_non_fake_inputs,
            shape_env=ShapeEnv(),
            static_shapes=True,
            _allow_unsafe_data_ptr_access=False,
        )
    elif tracing_mode == "symbolic":
        import torch._dynamo
@@ -1110,7 +1111,8 @@ def make_fx(f,
        fake_tensor_mode = FakeTensorMode(
            allow_fallback_kernels=False,
            allow_non_fake_inputs=_allow_non_fake_inputs,
            shape_env=shape_env)
            shape_env=shape_env,
            _allow_unsafe_data_ptr_access=False)
    else:
        shape_env = fake_tensor_mode.shape_env
        assert shape_env is not None, "shape_env should be set if tracing with 'symbolic'"