Allow OpaqueTensorImpl to be used for views (#151028)

Summary:
When creating an `OpaqueTensorImpl`, currently there's only an option to create it for a non-view tensor, but it can be useful to create one for view tensors as well.

View tensors should contain the same autograd parameters as the original tensor, whereas non-view tensors get created with whatever `inference_mode` option is currently enabled. For this reason, `TensorImpl` has a special view constructor that takes `TensorImpl::ImplType` as its first parameter, so adding a new constructor to `OpaqueTensorImpl` that does the same thing allows us to create views with it.

Test Plan: CI

Reviewed By: scottxu0730

Differential Revision: D71748460

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151028
Approved by: https://github.com/scottxu0730, https://github.com/chaos5958
This commit is contained in:
Pat Vignola 2025-04-11 20:07:47 +00:00 committed by PyTorch MergeBot
parent bb60e82672
commit 05236b5045

View File

@ -29,12 +29,20 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl {
bool is_non_overlapping_and_dense = true)
: TensorImpl(key_set, data_type, device),
opaque_handle_(std::move(opaque_handle)) {
set_storage_access_should_throw();
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
sizes_and_strides_.set_sizes(sizes);
refresh_numel();
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
constructor_impl(sizes, is_non_overlapping_and_dense);
}
OpaqueTensorImpl(
TensorImpl::ImplType impl_type,
c10::Storage&& storage,
at::DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
OpaqueHandle opaque_handle,
c10::IntArrayRef sizes,
bool is_non_overlapping_and_dense = true)
: TensorImpl(impl_type, std::move(storage), key_set, data_type),
opaque_handle_(std::move(opaque_handle)) {
constructor_impl(sizes, is_non_overlapping_and_dense);
}
// Destructor doesn't call release_resources because it's
@ -181,6 +189,17 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl {
return "OpaqueTensorImpl";
}
void constructor_impl(
c10::IntArrayRef sizes,
bool is_non_overlapping_and_dense) {
set_storage_access_should_throw();
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
sizes_and_strides_.set_sizes(sizes);
refresh_numel();
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
}
OpaqueHandle opaque_handle_;
};