mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Allow OpaqueTensorImpl to be used for views (#151028)
Summary: When creating an `OpaqueTensorImpl`, currently there's only an option to create it for a non-view tensor, but it can be useful to create one for view tensors as well. View tensors should contain the same autograd parameters as the original tensor, whereas non-view tensors get created with whatever `inference_mode` option is currently enabled. For this reason, `TensorImpl` has a special view constructor that takes `TensorImpl::ImplType` as its first parameter, so adding a new constructor to `OpaqueTensorImpl` that does the same thing allows us to create views with it. Test Plan: CI Reviewed By: scottxu0730 Differential Revision: D71748460 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151028 Approved by: https://github.com/scottxu0730, https://github.com/chaos5958
This commit is contained in:
parent
bb60e82672
commit
05236b5045
|
|
@ -29,12 +29,20 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl {
|
|||
bool is_non_overlapping_and_dense = true)
|
||||
: TensorImpl(key_set, data_type, device),
|
||||
opaque_handle_(std::move(opaque_handle)) {
|
||||
set_storage_access_should_throw();
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
||||
sizes_and_strides_.set_sizes(sizes);
|
||||
refresh_numel();
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
|
||||
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
|
||||
constructor_impl(sizes, is_non_overlapping_and_dense);
|
||||
}
|
||||
|
||||
OpaqueTensorImpl(
|
||||
TensorImpl::ImplType impl_type,
|
||||
c10::Storage&& storage,
|
||||
at::DispatchKeySet key_set,
|
||||
const caffe2::TypeMeta data_type,
|
||||
OpaqueHandle opaque_handle,
|
||||
c10::IntArrayRef sizes,
|
||||
bool is_non_overlapping_and_dense = true)
|
||||
: TensorImpl(impl_type, std::move(storage), key_set, data_type),
|
||||
opaque_handle_(std::move(opaque_handle)) {
|
||||
constructor_impl(sizes, is_non_overlapping_and_dense);
|
||||
}
|
||||
|
||||
// Destructor doesn't call release_resources because it's
|
||||
|
|
@ -181,6 +189,17 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl {
|
|||
return "OpaqueTensorImpl";
|
||||
}
|
||||
|
||||
void constructor_impl(
|
||||
c10::IntArrayRef sizes,
|
||||
bool is_non_overlapping_and_dense) {
|
||||
set_storage_access_should_throw();
|
||||
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
||||
sizes_and_strides_.set_sizes(sizes);
|
||||
refresh_numel();
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
|
||||
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
|
||||
}
|
||||
|
||||
OpaqueHandle opaque_handle_;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user