Disallow {FakeTensor,FunctionalTensor}.data_ptr (#122514)

This PR:
- disallows FakeTensor.data_ptr when it is called inside PT2 or fx tracing.
- disallows FunctionalTensor.data_ptr (the Python FunctionalTensor is only used
  in PT2).

The motivation is that the leading cause of segfaults when using custom ops
with PT2 is calling .data_ptr on a FunctionalTensor or FakeTensor: these
tensors carry no real data, so the returned pointer is meaningless and
dereferencing it in a kernel crashes.
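For illustration, here is a minimal sketch of the new behavior (not one of this
PR's exact tests, but modeled on them): inside fx/PT2 tracing, .data_ptr on a
fake input now raises instead of handing back a bogus pointer.
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    x.data_ptr()  # under fake-tensor tracing this used to return a bogus pointer
    return x.clone()

try:
    make_fx(f, tracing_mode="fake")(torch.randn(3))
except RuntimeError as e:
    assert "Cannot access data pointer" in str(e)
```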

This change is BC-breaking. If your code broke as a result of it, that is
because the code had a bug (these .data_ptr calls should never have been
made). You can either fix the bug (recommended) or restore the previous
behavior with:
```
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```
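If you take the recommended route instead, the usual fix is to keep the
raw-pointer code behind an opaque custom op, so tracing stops at the op
boundary and the data pointer is only read at runtime on real tensors. A rough
sketch (the op name mylib::scale and the kernel body are illustrative only,
not part of this PR):
```
import torch

torch.library.define("mylib::scale", "(Tensor x, float alpha) -> Tensor")

@torch.library.impl("mylib::scale", "CPU")
def _(x, alpha):
    # Runtime kernel: only ever sees real tensors, so reading x.data_ptr()
    # (e.g. to hand it to a C/CUDA kernel) is safe here.
    return x * alpha

@torch.library.impl("mylib::scale", "Meta")
def _(x, alpha):
    # Shape/dtype propagation used during tracing: never touches data.
    return torch.empty_like(x)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return torch.ops.mylib.scale(x, 2.0)

f(torch.randn(3))
```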

Test Plan:
- existing tests

Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122514
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
Authored by rzou on 2024-03-26 10:41:25 -07:00; committed by PyTorch MergeBot
commit c81c9ba472 (parent 04399a3091)
15 changed files with 208 additions and 14 deletions

View File

@@ -11,6 +11,16 @@ C10_API std::array<StorageImplCreateHelper, at::COMPILE_TIME_MAX_DEVICE_TYPES>
static ska::flat_hash_set<c10::DeviceType> DeviceTypeAllowList{
DeviceType::PrivateUse1};
void throwNullDataPtrError() {
TORCH_CHECK(
false,
"Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). "
"If you're using torch.compile/export/fx, it is likely that we are erroneously "
"tracing into a custom kernel. To fix this, please wrap the custom kernel into "
"an opaque custom op. Please see the following for details: "
"https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ");
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification.
// Only if the devicetype is in the allowlist,

View File

@@ -16,6 +16,8 @@
namespace c10 {
C10_API void throwNullDataPtrError();
// A storage represents the underlying backing data buffer for a
// tensor. This concept was inherited from the original Torch7
// codebase; we'd kind of like to get rid of the concept
@@ -59,6 +61,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
TORCH_INTERNAL_ASSERT(
allocator_, "For resizable storage, allocator must be provided");
}
refresh_has_data_ptr_check();
}
StorageImpl(
@@ -118,12 +121,22 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
return resizable_;
}
at::DataPtr& mutable_data_ptr() {
maybe_materialize_cow();
const at::DataPtr& data_ptr() const {
return data_ptr_;
}
const at::DataPtr& data_ptr() const {
at::DataPtr& mutable_data_ptr() {
if (C10_UNLIKELY(has_data_ptr_check_)) {
if (throw_on_mutable_data_ptr_) {
throwNullDataPtrError();
}
maybe_materialize_cow();
}
return data_ptr_;
}
// Returns the data_ptr. Bypasses all checks.
at::DataPtr& _mutable_data_ptr_no_checks() {
return data_ptr_;
}
@@ -137,6 +150,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
void set_data_ptr_noswap(at::DataPtr&& data_ptr) {
data_ptr_ = std::move(data_ptr);
refresh_has_data_ptr_check();
}
const void* data() const {
@@ -144,7 +158,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
}
void* mutable_data() {
maybe_materialize_cow();
if (C10_UNLIKELY(has_data_ptr_check_)) {
if (throw_on_mutable_data_ptr_) {
throwNullDataPtrError();
}
maybe_materialize_cow();
}
return data_ptr_.mutable_get();
}
@@ -222,6 +241,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
void set_throw_on_mutable_data_ptr() {
throw_on_mutable_data_ptr_ = true;
refresh_has_data_ptr_check();
}
protected:
// materialize_cow_storage needs to call set_data_ptr_no_materlize_cow
friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
@@ -231,13 +255,22 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
at::DataPtr set_data_ptr_no_materialize_cow(at::DataPtr&& data_ptr) {
at::DataPtr old_data_ptr(std::move(data_ptr_));
data_ptr_ = std::move(data_ptr);
refresh_has_data_ptr_check();
return old_data_ptr;
}
private:
void refresh_has_data_ptr_check() {
has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_;
}
inline bool is_cow() const {
return c10::impl::cow::is_cow_data_ptr(data_ptr_);
}
// Triggers a copy if this is a copy-on-write tensor.
void maybe_materialize_cow() {
if (data_ptr_.get_deleter() == impl::cow::cow_deleter) {
if (is_cow()) {
impl::cow::materialize_cow_storage(*this);
}
}
@@ -249,6 +282,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
// Identifies that Storage was received from another process and doesn't have
// local to process cuda memory allocation
bool received_cuda_;
// All special checks in data/data_ptr calls are guarded behind this single
// boolean. This is for performance: .data/.data_ptr calls are commonly in the
// hot-path.
bool has_data_ptr_check_ = false;
// If we should throw when mutable_data_ptr() or mutable_data() is called.
bool throw_on_mutable_data_ptr_ = false;
Allocator* allocator_;
impl::PyObjectSlot pyobj_slot_;
};

View File

@@ -81,7 +81,7 @@ c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
if (has_simple_data_ptr(storage)) {
// Case 1) We have a simple data pointer: wrap it.
std::unique_ptr<void, DeleterFnPtr> original_ctx =
storage.mutable_data_ptr().move_context();
storage._mutable_data_ptr_no_checks().move_context();
// Save this for the result.
new_data_ptr = make_data_ptr(

View File

@@ -1,4 +1,5 @@
# Owner(s): ["module: dynamo"]
import copy
import re
import unittest
from textwrap import dedent
@@ -12,6 +13,8 @@ import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dynamo.testing import CompileCounter, expectedFailureDynamic, rand_strided
from torch._functorch.aot_autograd import _aot_export_function, create_functional_call
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.profiler import profile
from torch.testing._internal.common_utils import compare_equal_outs_and_grads
@@ -1104,6 +1107,77 @@ SeqNr|OrigAten|SrcFn
self.assertEqual(x, x_opt)
self.assertEqual(z.grad, z_opt.grad)
def test_data_ptr_access_copy(self):
with FakeTensorMode(_allow_unsafe_data_ptr_access=False):
x = torch.randn(3)
y = copy.copy(x)
self.assertEqual(y.shape, x.shape)
def test_data_ptr_access_fails_in_forward(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)
@torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
def _(x):
x.data_ptr()
return x.clone()
x = torch.randn(3)
def data_ptr_graph_input(x):
r0 = torch.ops.mylib.foo(x)
return r0
def data_ptr_graph_intermediate(x):
y = x.clone()
r0 = torch.ops.mylib.foo(y)
return r0
tests = [data_ptr_graph_input, data_ptr_graph_intermediate]
def ctx():
return self.assertRaisesRegex(
RuntimeError, "Cannot access data pointer"
)
for f in tests:
with ctx():
make_fx(f, tracing_mode="fake")(x)
with ctx():
make_fx(f, tracing_mode="symbolic")(x)
with ctx():
torch.compile(f, backend="eager", fullgraph=True)(x)
def test_data_ptr_access_fails_in_backward(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)
backward_called = False
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone()
@staticmethod
def backward(ctx, grad):
nonlocal backward_called
backward_called = True
grad.data_ptr()
return grad.clone()
@torch.library.impl("mylib::foo", "CompositeImplicitAutograd", lib=lib)
def _(x):
return Foo.apply(x)
def f(x):
return torch.ops.mylib.foo(x)
x = torch.randn(3, requires_grad=True)
with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer"):
y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
self.assertTrue(backward_called)
# We don't know how to catch multiple mutations to the same memory location
@unittest.expectedFailure
def test_aot_autograd_expand_mutation_error(self):

View File

@@ -2,6 +2,7 @@
import functools
import sys
import unittest
from unittest import skipIf as skipif
@@ -15,6 +16,7 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TEST_WITH_TORCHDYNAMO,
TestCase,
xpassIfTorchDynamo,
@@ -46,7 +48,8 @@ class TestDLPack(TestCase):
del y
assert sys.getrefcount(x) == 2
@xpassIfTorchDynamo # (reason="pytorch does not raise")
@unittest.expectedFailure
@skipIfTorchDynamo("I can't figure out how to get __dlpack__ into trace_rules.py")
def test_dunder_dlpack_stream(self):
x = np.arange(5)
x.__dlpack__(stream=None)

View File

@@ -1472,6 +1472,7 @@ def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ...
def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ...
def _functionalization_reapply_views_tls() -> _bool: ...
def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...
class DispatchKey(Enum):
${dispatch_key_hints}

View File

@@ -296,6 +296,7 @@ class OutputGraph:
shape_env=shape_env,
# TODO (tmanlaibaatar) Remove this once we always lift params and buffers
allow_non_fake_inputs=True if self.export else False,
_allow_unsafe_data_ptr_access=False,
)
self.tracing_context: TracingContext = TracingContext(fake_mode)
self.init_ambient_guards()
@@ -1138,6 +1139,7 @@ class OutputGraph:
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
backend_fake_mode = torch._subclasses.FakeTensorMode(
shape_env=old_fake_mode.shape_env,
_allow_unsafe_data_ptr_access=False,
)
# TODO(voz): Ostensibily, this should be scoped and
# restore back to old_fake_mode, but doing so currently violates

View File

@@ -104,6 +104,8 @@ manual_torch_name_rule_map = {
"torch.compiler.is_compiling": TorchInGraphFunctionVariable,
"torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
"torch.autograd._profiler_enabled": SkipFunctionVariable,
"torch._C._to_dlpack": SkipFunctionVariable,
"torch.to_dlpack": SkipFunctionVariable,
# We graph break on RNG state setters or getters like
# `torch.get_rng_state` or `torch.set_rng_state`. These functions
# are not aten operations and therefore they are completely ignored
@@ -1187,7 +1189,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._test_only_populate_upgraders",
"torch._C._test_only_remove_entry_to_op_version_map",
"torch._C._test_only_remove_upgraders",
"torch._C._to_dlpack",
"torch._C._to_functionality_key",
"torch._C._tracer_set_force_outplace",
"torch._C._tracer_set_get_unique_name_fn",

View File

@@ -433,6 +433,8 @@ class FakeTensor(torch.Tensor):
dispatch_device=True,
device_for_backend_keys=device,
)
if not fake_mode._allow_unsafe_data_ptr_access:
torch._C._set_throw_on_mutable_data_ptr(self)
assert elem.device.type == "meta", elem.device.type
device = device if isinstance(device, torch.device) else torch.device(device)
@@ -759,8 +761,10 @@ class FakeTensorMode(TorchDispatchMode):
allow_non_fake_inputs=False,
shape_env=None,
static_shapes=None,
_allow_unsafe_data_ptr_access=True,
):
log.debug("create_mode 0x%x", id(self))
self._allow_unsafe_data_ptr_access = _allow_unsafe_data_ptr_access
self.allow_fallback_kernels = allow_fallback_kernels
self.fake_tensor_converter = FakeTensorConverter()
if static_shapes is not None:

View File

@@ -120,6 +120,7 @@ class FunctionalTensor(torch.Tensor):
False, # dispatch_layout
extra_dispatch_keys, # _extra_dispatch_keys
)
torch._C._set_throw_on_mutable_data_ptr(out)
out.elem = elem
return out

View File

@@ -378,9 +378,18 @@ class Tensor(torch._C.TensorBase):
)
return (torch._utils._rebuild_nested_tensor, args_nested)
elif (
self.data_ptr() == 0
and type(self) is not torch.Tensor
type(self) is not torch.Tensor
and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
and (
isinstance(
self,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
)
or self.data_ptr() == 0
)
):
arg_wrapper_subclass = (
type(self),

View File

@@ -536,11 +536,33 @@ _allow_mutation_on_saved_tensors_enabled = False
def _get_tid(t) -> Tuple[int, int, int]:
return (id(t), t.data_ptr(), t._version)
# FIXME: This is almost definitely a bug.
if isinstance(
t,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
):
data_ptr = 0
else:
data_ptr = t.data_ptr()
return (id(t), data_ptr, t._version)
def _get_sid(t) -> Tuple[int, int]:
return (t.data_ptr(), t._version)
# FIXME: This is almost definitely a bug.
if isinstance(
t,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
):
data_ptr = 0
else:
data_ptr = t.data_ptr()
return (data_ptr, t._version)
class _Handle:

View File

@@ -838,6 +838,20 @@ void initDispatchBindings(PyObject* module) {
return a.sizes(); // NB: NOT sym_size
});
m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) {
if (!t.unsafeGetTensorImpl()->has_storage()) {
// If the Tensor doesn't have a storage, then accessing .data_ptr()
// will already raise an error.
return;
}
// Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr
// will throw.
t.unsafeGetTensorImpl()
->storage()
.unsafeGetStorageImpl()
->set_throw_on_mutable_data_ptr();
});
using c10::impl::TorchDispatchModeKey;
py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
.value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)

View File

@@ -127,7 +127,19 @@ def _register_tensor_wrapper(tensor) -> None:
# Tensor storage -> work mapping is maintained in C++
return
global data_ptr_to_work
data_ptr = tensor.elem.data_ptr()
# FIXME: This is almost definitely a bug.
if isinstance(
tensor.elem,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
):
data_ptr = 0
else:
data_ptr = tensor.elem.data_ptr()
# Note: we should NEVER try to trace this, bc it registers runtime stuff during trace.
# Instead, backends must call this themselves when implementing traced collectives.
wait_reg = data_ptr_to_work.get(data_ptr, None)

View File

@@ -1101,6 +1101,7 @@ def make_fx(f,
allow_non_fake_inputs=_allow_non_fake_inputs,
shape_env=ShapeEnv(),
static_shapes=True,
_allow_unsafe_data_ptr_access=False,
)
elif tracing_mode == "symbolic":
import torch._dynamo
@@ -1110,7 +1111,8 @@ def make_fx(f,
fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=False,
allow_non_fake_inputs=_allow_non_fake_inputs,
shape_env=shape_env)
shape_env=shape_env,
_allow_unsafe_data_ptr_access=False)
else:
shape_env = fake_tensor_mode.shape_env
assert shape_env is not None, "shape_env should be set if tracing with 'symbolic'"