Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)
The original author is @YifanShenSZ and the original PR is #82754.

# Summary:

Previous reshape ([#80981](https://github.com/pytorch/pytorch/pull/80981)) is fine for forward, but the backward needs improvement: it has to handle the "sometimes view, sometimes copy" behavior. This pull request fixes that by:
1. adding a new alias dispatch key `CompositeImplicitAutogradNestedTensor`, which ideally works as the nested-tensor version of `CompositeImplicitAutograd`;
2. registering `reshape_nested` to `reshape` via `CompositeImplicitAutogradNestedTensor`.

Side changes:
* add contiguous memory format support to `clone_nested`
* add `view_nested`
* add `reshape_as_nested`

Fixes issue [#83041](https://github.com/pytorch/pytorch/issues/83041).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82754

Test Plan: Imported from GitHub, without a `Test Plan:` line.

Static Docs Preview (executorch): [Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D39023822/V13/executorch/)

Reviewed By: albanD
Differential Revision: D39023822
Pulled By: drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84154
Approved by: https://github.com/bdhirsh, https://github.com/albanD
This commit is contained in:
parent 9bcad063d8
commit 673b35c847
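For orientation before the diff itself: the sketch below is not part of the commit. It is a usage illustration of the "sometimes view, sometimes copy" behavior described in the summary, assembled from the tests this PR adds, using the `torch.nested_tensor` constructor of that era (newer releases spell it `torch.nested.nested_tensor`); shapes and error behavior are assumptions drawn from the kernels shown further down.

```python
import torch

# Two ragged samples packed into one nested tensor (shapes (2, 20) and (3, 20)).
nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])

# Contiguous input: view returns a view of the same buffer.
# -1 means "inherit the old size", so the ragged dimension stays ragged.
nt_view = nt.view(2, -1, 5, 4)      # underlying tensors become (2, 5, 4) and (3, 5, 4)

# Non-contiguous input: view must fail, reshape silently falls back to a copy.
nt_t = nt_view.transpose(-1, -2)    # underlying tensors become (2, 4, 5) and (3, 4, 5)
try:
    nt_t.view(2, -1, 20)            # not viewable: a dimension spans two subspaces
except RuntimeError:
    pass
nt_copy = nt_t.reshape(2, -1, 20)   # clones to contiguous, then views
```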
@@ -51,6 +51,7 @@ generated_cpu_cpp = [
     "aten/src/ATen/RegisterSparseCsrCPU.cpp",
     "aten/src/ATen/RegisterZeroTensor.cpp",
     "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
+    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
     "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
     "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
     "aten/src/ATen/RegisterMeta.cpp",

@@ -66,6 +67,8 @@ generated_cpu_cpp = [
     "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
     "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
     "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
+    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
+    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",
     "aten/src/ATen/CompositeViewCopyKernels.cpp",
     "aten/src/ATen/FunctionalInverses.h",
     "aten/src/ATen/Functions.h",
@@ -307,6 +307,21 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   // For AutogradOther, we return ambiguousAutogradOtherKernel() if there's registration
   // to any of its backends.
   // See Note [Undefined in dispatchTable_] for the special handling for Undefined.
+
+  // If the dispatch key is included in CompositeImplicitAutogradNestedTensor,
+  // then we register it to nested-tensor kernel rather than
+  // regular-tensor CompositeImplicitAutograd kernel.
+  // We have no intention to change the behavior of Undefined,
+  // so this nested-tensor branch requires `dispatch_key != DispatchKey::Undefined`
+  // to let the original CompositeImplicitAutograd handle Undefined
+  if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) {
+    if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) {
+      if (!has_backend_kernel) {
+        return {*nested_registration, "nested kernel"};
+      }
+    }
+  }
+
   if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) {
     if (auto math_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutograd)) {
       if (dispatch_key == DispatchKey::AutogradOther
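Not part of the diff: a toy Python model of the precedence the hunk above adds to the dispatch-table computation. Every name here (`NESTED_BACKENDS`, `pick_kernel`, the registration dict) is invented for illustration; the real logic is the C++ `OperatorEntry` code shown above.

```python
NESTED_BACKENDS = {"NestedTensorCPU", "NestedTensorCUDA", "AutogradNestedTensor"}

def pick_kernel(dispatch_key, registrations):
    """Illustrative precedence only: direct backend kernel, then the new
    nested composite alias, then the regular CompositeImplicitAutograd alias."""
    if dispatch_key in registrations:              # a direct backend kernel always wins
        return registrations[dispatch_key], "backend kernel"
    if (dispatch_key != "Undefined"
            and dispatch_key in NESTED_BACKENDS
            and "CompositeImplicitAutogradNestedTensor" in registrations):
        return registrations["CompositeImplicitAutogradNestedTensor"], "nested kernel"
    if "CompositeImplicitAutograd" in registrations:
        return registrations["CompositeImplicitAutograd"], "math kernel"
    return None, "missing"

# e.g. reshape after this PR: registered for both composite aliases
regs = {"CompositeImplicitAutograd": "reshape",
        "CompositeImplicitAutogradNestedTensor": "reshape_nested"}
print(pick_kernel("NestedTensorCPU", regs))   # ('reshape_nested', 'nested kernel')
print(pick_kernel("CPU", regs))               # ('reshape', 'math kernel')
```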
@@ -1256,17 +1256,6 @@ Tensor alias_with_sizes_and_strides(
 }

 Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
-  // reshape has special autograd logic since it sometimes returns a view but sometimes does not
-  // we have to intercept here instead of using dispatcher
-  // otherwise we will see "autograd still running" kind of error in inference mode:
-  // * if we create a tensor in inference mode scope,
-  //   then pass it to a inference mode decorated function,
-  //   everything is fine
-  // * but if we create the input tensor not with inference mode,
-  //   then errors like "Cannot set version_counter for inference tensor" arise
-  if (self.is_nested()) {
-    return at::_reshape_nested(self, proposed_shape);
-  }
   if (self.is_sparse()) {
     AT_ERROR("reshape is not implemented for sparse tensors");
   }
@@ -4200,16 +4200,9 @@
   variants: function, method
   device_check: NoCheck
   device_guard: False
-
-- func: _reshape_nested(Tensor self, int[] shape) -> Tensor
   dispatch:
-    NestedTensorCPU, NestedTensorCUDA: _reshape_nested
-  autogen: _reshape_nested.out
-
-- func: _reshape_nested_backward(Tensor self, Tensor grad) -> Tensor
-  dispatch:
-    NestedTensorCPU, NestedTensorCUDA: _reshape_nested_backward
-  autogen: _reshape_nested_backward.out
+    CompositeImplicitAutograd: reshape
+    CompositeImplicitAutogradNestedTensor: reshape_nested

 # NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
 # They are not user-facing, hence the leading underscore. Please don't use it

@@ -4233,6 +4226,9 @@
   variants: method
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: reshape_as
+    CompositeImplicitAutogradNestedTensor: reshape_as_nested

 - func: round(Tensor self) -> Tensor
   device_check: NoCheck # TensorIterator

@@ -6889,6 +6885,7 @@
     Meta: view_meta
     ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
     MkldnnCPU: mkldnn_view
+    NestedTensorCPU, NestedTensorCUDA: view_nested

 # Warning: If you want to change the name or overload name of this
 # operator, you might also want to change the `isBlockListedSchema`
@@ -66,23 +66,6 @@ std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
   return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
 }

-Tensor _reshape_nested_backward(const Tensor& self, const Tensor& grad) {
-  auto self_ptr = get_nested_tensor_impl(self);
-  // TODO: this is to reproduce self_ptr->opt_sizes_
-  // if an accessor is provided in the future, can replace this
-  std::vector<int64_t> sizes;
-  for (int64_t i = 0; i < self_ptr->dim(); i++) {
-    c10::optional<int64_t> opt_size = self_ptr->opt_size(i);
-    if (opt_size.has_value()) {
-      sizes.push_back(*opt_size);
-    }
-    else {
-      sizes.push_back(-1);
-    }
-  }
-  return grad.reshape(sizes);
-}
-
 Tensor nested_softmax_backward(
     const Tensor& grad,
     const Tensor& output,
@@ -688,17 +688,38 @@ Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) {
 Tensor clone_nested(
     const Tensor& self,
     c10::optional<c10::MemoryFormat> optional_memory_format) {
-  auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
-  TORCH_CHECK(
-      memory_format == MemoryFormat::Preserve,
-      "clone_nested only supports memory format Preserve, but got ",
-      memory_format,
-      " instead.");
-  // TODO: The size doesn't necessarily need to be cloned, but it is more
-  // conservative. This is something we could revisit once we land a more
-  // efficient implementation of nested_size_tensor_.
-  return wrap_buffer(
-      get_buffer(self).clone(), get_nested_size_tensor(self).clone());
+  auto memory_format = optional_memory_format.value_or(c10::MemoryFormat::Preserve);
+  auto self_ptr = get_nested_tensor_impl(self);
+  if (memory_format == c10::MemoryFormat::Preserve ||
+      (memory_format == c10::MemoryFormat::Contiguous && self.is_contiguous())) {
+    const Tensor& buffer = self_ptr->get_buffer(),
+        sizemat = self_ptr->get_nested_size_tensor(),
+        stridemat = self_ptr->get_nested_stride_tensor();
+    const std::vector<int64_t>& offsets = self_ptr->get_offsets();
+    // TODO: The size and the stride do not necessarily need to be cloned,
+    // but it is more conservative.
+    // This is something we could revisit once we land a more
+    // efficient implementation of nested_size_tensor_ and nested_stride_tensor.
+    return wrap_buffer(buffer.clone(), sizemat.clone(), stridemat.clone(), std::vector<int64_t>(offsets));
+  }
+  // actually, memory format is contiguous and self is noncontiguous
+  else if (memory_format == c10::MemoryFormat::Contiguous) {
+    const Tensor& self_buffer = self_ptr->get_buffer(),
+        sizemat = self_ptr->get_nested_size_tensor();
+    Tensor output_buffer = at::empty_like(self_buffer);
+    Tensor output = wrap_buffer(output_buffer, sizemat);
+    std::vector<Tensor> self_unbind = self.unbind(),
+        output_unbind = output.unbind();
+    for (int64_t i = 0; i < self_ptr->size(0); i++) {
+      output_unbind[i].copy_(self_unbind[i]);
+    }
+    return output;
+  } else {
+    TORCH_CHECK(
+        false,
+        "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ",
+        memory_format);
+  }
 }

 at::Tensor NestedTensor_get_nested_size_tensor(const at::Tensor& self){
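Not part of the diff: a short usage sketch of the Contiguous support that `clone_nested` gains in the hunk above, mirroring the updated test expectations later in this commit; the `torch.nested_tensor` spelling is the constructor used by those tests.

```python
import torch

nt = torch.nested_tensor([torch.randn(2, 6), torch.randn(3, 6)])
nt_t = nt.transpose(-1, -2)                                    # non-contiguous nested tensor

preserved = nt_t.clone(memory_format=torch.preserve_format)    # keeps the strided layout
packed = nt_t.clone(memory_format=torch.contiguous_format)     # repacks; previously raised
assert packed.is_contiguous()

# .contiguous() on a non-contiguous nested tensor now works for the same reason.
# Anything other than Preserve/Contiguous is still rejected:
# nt_t.clone(memory_format=torch.channels_last)  # RuntimeError
```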
@@ -1008,7 +1029,7 @@ Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
       self, sizemat_transposed, stridemat_transposed, std::vector<int64_t>(self_ptr->get_offsets()));
 }

-// utilities supporting `_reshape_nested`
+// utilities supporting `view_nested` and `reshape_nested`
 namespace {
 // Args:
 //   sizes: the sizes of original nested tensor

@@ -1016,10 +1037,10 @@ namespace {
 //   proposed_shape: user proposed new shape
 //   op: the options for new size and stride matrices
 // Returns:
-//   whether reshape as view is possible (i.e. old buffer can be reused)
+//   whether viewable
 //   size matrix after reshape
-//   stride matrix after reshape (not fully populated if reshape as view is impossible)
-inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
+//   stride matrix after reshape (not fully populated if not viewable)
+inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
     const std::vector<IntArrayRef>& sizes,
     const std::vector<IntArrayRef>& strides,
     const IntArrayRef& proposed_shape,

@@ -1027,7 +1048,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
   int64_t ntensors = sizes.size(),
       ndims_underlying = sizes[0].size(),
       ndims_underlying_reshaped = proposed_shape.size() - 1;
-  bool reshape_as_view = true;
+  bool viewable = true;
   Tensor sizemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op),
       stridemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op);
   int64_t* sizemat_reshaped_ptr = sizemat_reshaped.data_ptr<int64_t>(),

@@ -1039,6 +1060,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
   std::vector<int64_t> size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end());
   // some negative sizes remain to be infered
   if (ndims_underlying < ndims_underlying_reshaped) {
+    int64_t numel = 1, numel_reshaped = 1;
     // replace negative sizes for old dimensions with old sizes
     for (int64_t idim = 0; idim < ndims_underlying; idim++) {
       int64_t& size_reshaped = size_reshaped_vector[idim];

@@ -1046,12 +1068,17 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
       if (size_reshaped == -1) {
         size_reshaped = size[idim];
       }
+      numel *= size[idim];
+      numel_reshaped *= size_reshaped;
     }
     // infer negative size for new dimension
     int64_t infer_index = -1;
     for (int64_t idim = ndims_underlying; idim < ndims_underlying_reshaped; idim++) {
       const int64_t& size_reshaped = size_reshaped_vector[idim];
-      if (size_reshaped == -1) {
+      if (size_reshaped >= 0) {
+        numel_reshaped *= size_reshaped;
+      }
+      else if (size_reshaped == -1) {
         if (infer_index > -1) {
           throw std::runtime_error("only one dimension can be inferred");
         }

@@ -1059,22 +1086,36 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
           infer_index = idim;
         }
       }
-      else if (size_reshaped < 0) {
+      else {
         AT_ERROR("invalid shape dimension ", size_reshaped);
       }
     }
-    // See Note [inference and inheritance semantics]
+    // See Note [Special size rule for nested tensor]
     TORCH_CHECK(infer_index == -1, "nested tensor does not infer shape");
+    TORCH_CHECK(
+        numel == numel_reshaped,
+        "shape '", proposed_shape, "' ",
+        "is invalid for input of size ", numel);
   }
   // all negative sizes can be replaced
   else {
+    int64_t numel = 1, numel_reshaped = 1;
     for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
       int64_t& size_reshaped = size_reshaped_vector[idim];
       TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
       if (size_reshaped == -1) {
         size_reshaped = size[idim];
       }
+      numel *= size[idim];
+      numel_reshaped *= size_reshaped;
     }
+    for (int64_t idim = ndims_underlying_reshaped; idim < ndims_underlying; idim++) {
+      numel *= size[idim];
+    }
+    TORCH_CHECK(
+        numel == numel_reshaped,
+        "shape '", proposed_shape, "' ",
+        "is invalid for input of size ", numel);
   }
   IntArrayRef size_reshaped(size_reshaped_vector);
   // compute reshaped stride

@@ -1092,7 +1133,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
     }
     // reshape as view is impossible
     else {
-      reshape_as_view = false;
+      viewable = false;
       // fill reshaped size into sizemat
       for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
         sizemat_reshaped_ptr[idim] = size_reshaped[idim];
@@ -1100,42 +1141,59 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
     sizemat_reshaped_ptr += ndims_underlying_reshaped;
   }
   }
-  return std::make_tuple(reshape_as_view, sizemat_reshaped, stridemat_reshaped);
-}
-
-// Args:
-//   nt_reshaped: the reshaped nested tensor to receive copies
-//   buffer: the original nested tensor buffer
-//   sizes: the original nested tensor sizes (may have gone through collapsing or splitting)
-//   strides: the original nested tensor strides (may have gone through collapsing or splitting)
-//   offsets: the original nested tensor offsets (may have gone through collapsing or splitting)
-inline void NestedTensor_reshape_copy(
-    Tensor& nt_reshaped,
-    const Tensor& buffer,
-    const std::vector<IntArrayRef>& sizes,
-    const std::vector<IntArrayRef>& strides,
-    const std::vector<int64_t>& offsets) {
-  auto nt_reshaped_ptr = get_nested_tensor_impl(nt_reshaped);
-  const Tensor& buffer_reshaped = nt_reshaped_ptr->get_buffer();
-  std::vector<IntArrayRef> sizes_reshaped = NestedTensor_get_sizes(nt_reshaped_ptr),
-      strides_reshaped = NestedTensor_get_strides(nt_reshaped_ptr);
-  const std::vector<int64_t>& offsets_reshaped = nt_reshaped_ptr->get_offsets();
-  for (int64_t i = 0; i < nt_reshaped_ptr->size(0); i++) {
-    buffer_reshaped.as_strided(sizes_reshaped[i], strides_reshaped[i], offsets_reshaped[i]).copy_(
-        // TODO: can we avoid allocating new memory for `buffer...reshape`
-        // I did not find anything like reshape_out
-        buffer.as_strided(sizes[i], strides[i], offsets[i]).reshape(sizes_reshaped[i]));
-  }
+  return std::make_tuple(viewable, sizemat_reshaped, stridemat_reshaped);
 }
 } // namespace

-// Special rules for reshape(nested tensor):
-// 1. Only 1 regular dimension can be collapsed with
-//    or splitted from the implicit batch dimension
-// 2. Instead of infering size, -1 means "inherit the old size", so:
-//    * negative size is legal for a ragged dimension
-//    * multiple sizes can be -1
-Tensor _reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
+// Note [Special size rule for nested tensor]
+// Instead of infering size, -1 means "inherit the old size", so:
+// * negative size is legal for a ragged dimension
+// * multiple sizes can be -1
+// In principle we could still infer a dimension,
+// we are designing a better semantics to include both inheritance and inference
+Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
+  TORCH_CHECK(
+      proposed_shape.size() > 0,
+      "shape '[]' is invalid for a nested tensor");
+  auto self_ptr = get_nested_tensor_impl(self);
+  // basic information before reshaping
+  int64_t ntensors = self_ptr->size(0);
+  TORCH_CHECK(
+      ntensors > 0,
+      "empty nested tensor cannot be reshaped");
+  // basic information after reshaping
+  int64_t ntensors_reshaped;
+  if (proposed_shape[0] >= 0) {
+    ntensors_reshaped = proposed_shape[0];
+  }
+  else if (proposed_shape[0] == -1) {
+    ntensors_reshaped = ntensors;
+  }
+  else {
+    AT_ERROR("invalid shape dimension ", proposed_shape[0]);
+  }
+  TORCH_CHECK(
+      ntensors == ntensors_reshaped,
+      "for now view cannot change the implicit batch dimension");
+  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
+      strides = NestedTensor_get_strides(self_ptr);
+  // reshaping underlying tensor dimensions does not change offset
+  // determine reshaped size and stride
+  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
+  bool viewable;
+  Tensor sizemat_reshaped, stridemat_reshaped;
+  std::tie(viewable, sizemat_reshaped, stridemat_reshaped) = NestedTensor_compute_size_stride(
+      sizes, strides, proposed_shape, sizemat.options());
+  TORCH_CHECK(
+      viewable,
+      "view size is not compatible with input tensor's size and stride "
+      "(at least one dimension spans across two contiguous subspaces). "
+      "Use .reshape(...) instead.");
+  return create_nested_view_tensor(self, sizemat_reshaped, stridemat_reshaped, std::vector<int64_t>(self_ptr->get_offsets()));
+}
+
+// See Note [Special size rule for nested tensor]
+Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
   TORCH_CHECK(
       proposed_shape.size() > 0,
       "shape '[]' is invalid for a nested tensor");

@@ -1161,23 +1219,37 @@ Tensor _reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
       "for now reshape cannot change the implicit batch dimension");
   std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
       strides = NestedTensor_get_strides(self_ptr);
-  const std::vector<int64_t>& offsets = self_ptr->get_offsets();
   // reshaping underlying tensor dimensions does not change offset
   // determine reshaped size and stride
-  const Tensor& buffer = self_ptr->get_buffer(),
-      & sizemat = self_ptr->get_nested_size_tensor();
-  bool reshape_as_view;
+  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
+  bool viewable;
   Tensor sizemat_reshaped, stridemat_reshaped;
-  std::tie(reshape_as_view, sizemat_reshaped, stridemat_reshaped) = NestedTensor_reshape_size_stride(
+  std::tie(viewable, sizemat_reshaped, stridemat_reshaped) = NestedTensor_compute_size_stride(
       sizes, strides, proposed_shape, sizemat.options());
-  if (reshape_as_view) {
-    return wrap_buffer(buffer, sizemat_reshaped, stridemat_reshaped, std::vector<int64_t>(offsets));
+  if (viewable) {
+    return self.view(proposed_shape);
   }
-  Tensor buffer_reshaped = buffer.new_empty(buffer.sizes());
-  Tensor output = wrap_buffer(buffer_reshaped, sizemat_reshaped);
-  NestedTensor_reshape_copy(output,
-      buffer, sizes, strides, offsets);
-  return output;
+  else {
+    return self.clone(at::MemoryFormat::Contiguous).view(proposed_shape);
+  }
+}
+
+Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
+  auto other_ptr = get_nested_tensor_impl(other);
+  // TODO: this is to reproduce other_ptr->opt_sizes_
+  // if an accessor is provided in the future, can replace this
+  std::vector<int64_t> sizes;
+  for (int64_t i = 0; i < other_ptr->dim(); i++) {
+    c10::optional<int64_t> opt_size = other_ptr->opt_size(i);
+    if (opt_size.has_value()) {
+      sizes.push_back(*opt_size);
+    }
+    else {
+      sizes.push_back(-1);
+    }
+  }
+  // reshape with other.opt_sizes_
+  return self.reshape(sizes);
 }

 } // namespace native
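Not part of the diff: a sketch of the new `reshape_as_nested` path added above. Per that kernel, the target's ragged dimensions are read back as -1 ("inherit"), so reshaping one nested tensor as another reduces to an ordinary nested reshape; the constructor spelling follows the tests of this era.

```python
import torch

a = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])
b = torch.nested_tensor([torch.randn(2, 5, 4), torch.randn(3, 5, 4)])

# reshape_as_nested rebuilds b's sizes as [2, -1, 5, 4] (the ragged dim becomes -1)
# and forwards to reshape, so this is equivalent to a.reshape(2, -1, 5, 4).
c = a.reshape_as(b)
```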
@@ -249,6 +249,7 @@ PT_BACKEND_HEADERS = [
     "CompositeExplicitAutograd",
     "CompositeExplicitAutogradNonFunctional",
     "CompositeImplicitAutograd",
+    "CompositeImplicitAutogradNestedTensor",
     "Meta",
 ]

@@ -307,6 +308,7 @@ def get_aten_generated_files(enabled_backends):
     src_files = [
         "RegisterBackendSelect.cpp",
         "RegisterCompositeImplicitAutograd.cpp",
+        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
         "RegisterCompositeExplicitAutograd.cpp",
         "RegisterCompositeExplicitAutogradNonFunctional.cpp",
         "CompositeViewCopyKernels.cpp",

@@ -327,6 +329,8 @@ def get_aten_generated_files(enabled_backends):
         "Operators_4.cpp",
         "CompositeImplicitAutogradFunctions.h",
         "CompositeImplicitAutogradFunctions_inl.h",
+        "CompositeImplicitAutogradNestedTensorFunctions.h",
+        "CompositeImplicitAutogradNestedTensorFunctions_inl.h",
         "CompositeExplicitAutogradFunctions.h",
         "CompositeExplicitAutogradFunctions_inl.h",
         "CompositeExplicitAutogradNonFunctionalFunctions.h",

@@ -364,7 +368,7 @@ def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends):
 def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends):
     return [
         ":{}[{}]".format(aten_rule_name, f)
-        for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
+        for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeImplicitAutogradNestedTensor.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
     ] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends)

 def get_aten_derived_type_srcs(enabled_backends):

@@ -1083,6 +1087,8 @@ def define_buck_targets(
         "CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]",
         "CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]",
         "CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]",
+        "CompositeImplicitAutogradNestedTensorFunctions.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions.h]",
+        "CompositeImplicitAutogradNestedTensorFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions_inl.h]",
         "FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]",
         "Functions.h": ":gen_aten[Functions.h]",
         "MethodOperators.h": ":gen_aten[MethodOperators.h]",
@@ -162,6 +162,8 @@ GENERATED_H_CORE = [
     "CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
     "CompositeImplicitAutogradFunctions.h",
     "CompositeImplicitAutogradFunctions_inl.h",
+    "CompositeImplicitAutogradNestedTensorFunctions.h",
+    "CompositeImplicitAutogradNestedTensorFunctions_inl.h",
     "MetaFunctions.h",
     "MetaFunctions_inl.h",
     "core/TensorBody.h",

@@ -193,6 +195,7 @@ GENERATED_CPP = [
     "RegisterSparseCsrCPU.cpp",
     "RegisterMkldnnCPU.cpp",
     "RegisterCompositeImplicitAutograd.cpp",
+    "RegisterCompositeImplicitAutogradNestedTensor.cpp",
     "RegisterZeroTensor.cpp",
     "RegisterMeta.cpp",
     "RegisterQuantizedMeta.cpp",
@@ -178,6 +178,8 @@ const char* toString(DispatchKey t) {
       return "Autograd";
     case DispatchKey::CompositeImplicitAutograd:
       return "CompositeImplicitAutograd";
+    case DispatchKey::CompositeImplicitAutogradNestedTensor:
+      return "CompositeImplicitAutogradNestedTensor";
     case DispatchKey::CompositeExplicitAutograd:
       return "CompositeExplicitAutograd";
     case DispatchKey::CompositeExplicitAutogradNonFunctional:

@@ -324,6 +326,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"Autograd", c10::DispatchKey::Autograd},
       {"CompositeImplicitAutograd",
        c10::DispatchKey::CompositeImplicitAutograd},
+      {"CompositeImplicitAutogradNestedTensor",
+       c10::DispatchKey::CompositeImplicitAutogradNestedTensor},
       {"CompositeExplicitAutograd",
        c10::DispatchKey::CompositeExplicitAutograd},
       {"CompositeExplicitAutogradNonFunctional",
@@ -439,6 +439,8 @@ enum class DispatchKey : uint16_t {
   Autograd,
   CompositeImplicitAutograd, // registered at
   // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp
+  CompositeImplicitAutogradNestedTensor, // registered at
+  // build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp
   CompositeExplicitAutograd, // registered at
   // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
   // See Note [CompositeExplicitAutogradNonFunctional Key]
@@ -55,6 +55,11 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
     // math_dispatch_keyset
     DispatchKeySet{DispatchKey::NestedTensor};

+constexpr DispatchKeySet nested_dispatch_keyset =
+    DispatchKeySet(
+        {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
+    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
+
 DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
   TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
   switch (t) {

@@ -67,6 +72,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
           DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
     case DispatchKey::CompositeImplicitAutograd:
       return math_dispatch_keyset;
+    case DispatchKey::CompositeImplicitAutogradNestedTensor:
+      return nested_dispatch_keyset;
     case DispatchKey::CompositeExplicitAutograd:
       return backend_dispatch_keyset;
     case DispatchKey::CompositeExplicitAutogradNonFunctional:

@@ -84,6 +91,9 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
     case DispatchKey::CompositeImplicitAutograd:
       // See Note [NestedTensor Not Included in Backend Keys]
      return math_dispatch_keyset.has(k);
+    case DispatchKey::CompositeImplicitAutogradNestedTensor:
+      // See Note [NestedTensor Not Included in Backend Keys]
+      return nested_dispatch_keyset.has(k);
     case DispatchKey::CompositeExplicitAutograd:
       // See Note [NestedTensor Not Included in Backend Keys]
       return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
@@ -128,6 +128,9 @@ ALLOW_LIST = [
     ("aten::nanmean.out", datetime.date(2022, 8, 30)),
     ("aten::nansum", datetime.date(2022, 8, 30)),
     ("aten::nansum.out", datetime.date(2022, 8, 30)),
+    # nested tensor temporary auxiliary ops
+    ("aten::_reshape_nested", datetime.date(9999, 1, 1)),
+    ("aten::_reshape_nested_backward", datetime.date(9999, 1, 1)),
     ("aten::sum.SymInt", datetime.date(2022, 11, 30)),
     ("aten::mps_linear", datetime.date(9999, 1, 1)),
     ("aten::_mps_linear", datetime.date(9999, 1, 1)),
@@ -230,11 +230,7 @@ class TestNestedTensor(TestCase):

         # Test non_contiguous case
         assert not nt_noncontiguous.is_contiguous()
-        self.assertRaisesRegex(
-            RuntimeError,
-            r"clone_nested only supports memory format Preserve, but got Contiguous instead.",
-            lambda: nt_noncontiguous.contiguous()
-        )
+        self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())

     @torch.inference_mode()
     def test_repr_string(self):

@@ -679,7 +675,6 @@ class TestNestedTensorDeviceType(TestCase):

     @dtypes(torch.float, torch.float16)
     @skipMeta
-    @torch.inference_mode()
     def test_clone(self, device, dtype):
         nt1 = self.random_nt(device, dtype, 4, (4, 4), (1, 1))
         nt2 = nt1.clone()

@@ -693,7 +688,7 @@ class TestNestedTensorDeviceType(TestCase):
             self.assertNotEqual(ub1[i], ub2[i])

         nt1.clone(memory_format=torch.preserve_format)
-        msg = "clone_nested only supports memory format Preserve, but got ChannelsLast instead."
+        msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
         with self.assertRaisesRegex(RuntimeError, msg):
             nt1.clone(memory_format=torch.channels_last)

@@ -1105,7 +1100,6 @@ class TestNestedTensorDeviceType(TestCase):
         )

     @dtypes(torch.float, torch.float16, torch.double)
-    @torch.inference_mode()
     def test_transpose(self, device, dtype):
         nt = self.random_nt(device, dtype, 4, (4, 4))
         # error case: transpose nested dimension

@@ -1150,7 +1144,74 @@ class TestNestedTensorDeviceType(TestCase):
         self.assertEqual(ptT, ptT_from_ntT)

     @dtypes(torch.float, torch.float16, torch.double)
-    @torch.inference_mode()
+    def test_view(self, device, dtype):
+        nt = self.random_nt(device, dtype, 4, (4, 4))
+        # error case: empty shape
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"shape '\[\]' is invalid for a nested tensor",
+            lambda: nt.view(())
+        )
+        # error case: empty nested tensor
+        nt_empty = torch.nested_tensor([])
+        self.assertRaisesRegex(
+            RuntimeError,
+            "empty nested tensor cannot be reshaped",
+            lambda: nt_empty.view(-1)
+        )
+        # error case: invalid proposed shape for underlying tensors
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"invalid shape dimension -2",
+            lambda: nt.view(-2, 2, 3)
+        )
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"shape '\[.*\]' is invalid for input of size [0-9]+",
+            lambda: nt.view(4, 2, 3)
+        )
+        # normal case
+        x0 = torch.randn((2, 20), device=device, dtype=dtype)
+        x1 = torch.randn((3, 20), device=device, dtype=dtype)
+        nt = torch.nested_tensor([x0, x1])
+        pt = nt.to_padded_tensor(0.0)
+        self.assertRaisesRegex(
+            RuntimeError,
+            r"for now view cannot change the implicit batch dimension",
+            lambda: nt.transpose(-1, -2).view(40, -1)
+        )
+        # inherit only the ragged dimension
+        # (2, 20) -> (2, 5, 4)
+        # (3, 20) -> (3, 5, 4)
+        nt1 = nt.view(2, -1, 5, 4)
+        # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
+        pt1 = pt.view(2, -1, 5, 4)
+        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
+        # also inherit regular dimension
+        nt2 = nt1.view(2, -1, -1, 2, 2)
+        pt2 = pt1.view(2, -1, 5, 2, 2)
+        self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
+
+    @dtypes(torch.float, torch.float16, torch.double)
+    def test_view_inference_mode_interaction(self, device, dtype):
+        # Construct in default mode and view while in inference mode
+        nt = torch.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype)
+        with torch.inference_mode():
+            ntT = nt.view(2, -1, 4, 5)
+            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
+            pt = nt.to_padded_tensor(0.0)
+            ptT = pt.view(2, -1, 4, 5)
+            self.assertEqual(ptT, ptT_from_ntT)
+        # Construct and view while in inference mode
+        with torch.inference_mode():
+            nt = torch.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype)
+            ntT = nt.view(2, -1, 4, 5)
+            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
+            pt = nt.to_padded_tensor(0.0)
+            ptT = pt.view(2, -1, 4, 5)
+            self.assertEqual(ptT, ptT_from_ntT)
+
+    @dtypes(torch.float, torch.float16, torch.double)
     def test_reshape(self, device, dtype):
         nt = self.random_nt(device, dtype, 4, (4, 4))
         # error case: empty shape
@@ -1356,10 +1356,6 @@
 # making it impossible (hard) to detect when it is actually a view.
 # - name: reshape(Tensor self, IntArrayRef shape)

-- name: _reshape_nested(Tensor self, int[] shape) -> Tensor
-  self: _reshape_nested_backward(self, grad)
-  result: auto_linear
-
 - name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
   self: grad.reshape(self.sizes())
   result: auto_linear

@@ -1732,10 +1728,14 @@
   # linear
   result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)

-# TODO: this derivative is not SymInt safe, need reshape_symint
 - name: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
-  self: grad.reshape(self.sizes())
-  result: auto_linear
+  dispatch:
+    Default:
+      self: grad.reshape(self.sizes())
+      result: auto_linear
+    AutogradNestedTensor:
+      self: grad.reshape_as(self)
+      result: auto_linear

 - name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
   output_differentiability: [False]
@@ -310,6 +310,7 @@ def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
     else:
         return [backend.dispatch_key for backend in backends] + [
             DispatchKey.CompositeImplicitAutograd,
+            DispatchKey.CompositeImplicitAutogradNestedTensor,
             DispatchKey.CompositeExplicitAutograd,
             DispatchKey.CompositeExplicitAutogradNonFunctional,
         ]

@@ -330,6 +331,8 @@ def get_static_dispatch_backend(
         return DispatchKey.CompositeExplicitAutogradNonFunctional
     elif f.has_composite_implicit_autograd_kernel:
         return DispatchKey.CompositeImplicitAutograd
+    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
+        return DispatchKey.CompositeImplicitAutogradNestedTensor
     return None


@@ -426,6 +429,8 @@ def generate_static_dispatch_fallback_call(
         return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
     elif f.has_composite_implicit_autograd_kernel:
         return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
+    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
+        return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
     else:
         return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
 {', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""

@@ -1241,6 +1246,11 @@ def compute_registration_declarations(
         "dispatch": str(
             {k for k, v in backend_indices.items() if v.has_kernel(f)}
             != {DispatchKey.CompositeImplicitAutograd}
+            and {k for k, v in backend_indices.items() if v.has_kernel(f)}
+            != {
+                DispatchKey.CompositeImplicitAutograd,
+                DispatchKey.CompositeImplicitAutogradNestedTensor,
+            }
         ),
         "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
     }

@@ -2145,6 +2155,13 @@ def gen_source_files(
             ns_grouped_native_functions[namespace].append(grouped_native_function)

         dispatch_namespace = str(dispatch_key).lower()
+
+        # CompositeImplicitAutogradNestdTensor does not currently user the helpers generated
+        # compilation will fail when `-Werror=unused-function` flag is set
+        gen_dispatch_helpers: bool = (
+            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
+        )
+
         dispatch_definitions = get_native_function_definitions(
             fm=fm,
             grouped_native_functions=grouped_native_functions,

@@ -2153,7 +2170,7 @@ def gen_source_files(
             selector=selector,
             rocm=rocm,
             skip_dispatcher_op_registration=skip_dispatcher_op_registration,
-            gen_dispatch_helpers=True,
+            gen_dispatch_helpers=gen_dispatch_helpers,
         )
         fm.write_with_template(
             f"Register{dispatch_key}.cpp",

@@ -2636,6 +2653,7 @@ def main() -> None:
         DispatchKey.CPU,
         DispatchKey.CUDA,
         DispatchKey.CompositeImplicitAutograd,
+        DispatchKey.CompositeImplicitAutogradNestedTensor,
         DispatchKey.CompositeExplicitAutograd,
         DispatchKey.CompositeExplicitAutogradNonFunctional,
         DispatchKey.Meta,
@@ -93,6 +93,7 @@ class DispatchKey(Enum):

     Autograd = auto()
     CompositeImplicitAutograd = auto()
+    CompositeImplicitAutogradNestedTensor = auto()
     CompositeExplicitAutograd = auto()
     CompositeExplicitAutogradNonFunctional = auto()

@@ -217,6 +218,7 @@ dispatch_keys = [
     DispatchKey.QuantizedCPU,
     DispatchKey.QuantizedCUDA,
     DispatchKey.CompositeImplicitAutograd,
+    DispatchKey.CompositeImplicitAutogradNestedTensor,
     DispatchKey.CompositeExplicitAutograd,
     DispatchKey.CompositeExplicitAutogradNonFunctional,
     DispatchKey.NestedTensorCPU,

@@ -237,6 +239,7 @@ def is_generic_dispatch_key(dk: DispatchKey) -> bool:
         DispatchKey.CompositeExplicitAutograd,
         DispatchKey.CompositeExplicitAutogradNonFunctional,
         DispatchKey.CompositeImplicitAutograd,
+        DispatchKey.CompositeImplicitAutogradNestedTensor,
     }


@@ -485,6 +488,7 @@ class NativeFunction:

     # Whether or not the NativeFunction contains a backend-agnostic kernel
     has_composite_implicit_autograd_kernel: bool
+    has_composite_implicit_autograd_nested_tensor_kernel: bool
     has_composite_explicit_autograd_kernel: bool
     has_composite_explicit_autograd_non_functional_kernel: bool

@@ -699,9 +703,15 @@ class NativeFunction:
             if d == DispatchKey.CompositeExplicitAutograd
             or d == DispatchKey.CompositeExplicitAutogradNonFunctional
             or d == DispatchKey.CompositeImplicitAutograd
+            or d == DispatchKey.CompositeImplicitAutogradNestedTensor
         ]

-        assert len(composites_in_dispatch) <= 1, (
+        assert len(composites_in_dispatch) <= 1 or (
+            len(composites_in_dispatch) == 2
+            and DispatchKey.CompositeImplicitAutograd in composites_in_dispatch
+            and DispatchKey.CompositeImplicitAutogradNestedTensor
+            in composites_in_dispatch
+        ), (
             "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
             "or CompositeImplicitAutograd on a single kernel; each "
             "strictly subsumes the other. If you wanted to provide an explicit autograd "

@@ -756,11 +766,23 @@ class NativeFunction:
             # Structured functions MUST have a dispatch table
             is_abstract = True
         else:
-            is_abstract = dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
+            is_abstract = (
+                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
+                and dispatch.keys()
+                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
+                and dispatch.keys()
+                != {
+                    DispatchKey.CompositeImplicitAutograd,
+                    DispatchKey.CompositeImplicitAutogradNestedTensor,
+                }
+            )

         has_composite_implicit_autograd_kernel = (
             DispatchKey.CompositeImplicitAutograd in dispatch.keys()
         )
+        has_composite_implicit_autograd_nested_tensor_kernel = (
+            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch.keys()
+        )
         has_composite_explicit_autograd_kernel = (
             DispatchKey.CompositeExplicitAutograd in dispatch.keys()
         )

@@ -808,6 +830,7 @@ class NativeFunction:
             cpp_no_default_args=cpp_no_default_args,
             is_abstract=is_abstract,
             has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
+            has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
             has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
             has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
             tags=tags,

@@ -899,6 +922,9 @@ class NativeFunction:
             self.has_composite_implicit_autograd_kernel
             or self.has_composite_explicit_autograd_kernel
             or self.has_composite_explicit_autograd_non_functional_kernel
+        ) or (
+            self.has_composite_implicit_autograd_kernel
+            and self.has_composite_implicit_autograd_nested_tensor_kernel
         )

     @property

@@ -976,7 +1002,10 @@ class NativeFunctionsGroup:
         if self.structured:
             # For now, structured composite kernels are not supported (need some
             # design work to figure out how to make the composite case work)
-            assert not self.out.has_composite_implicit_autograd_kernel
+            assert (
+                not self.out.has_composite_implicit_autograd_kernel
+                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
+            )

         assert self.functional.structured_delegate == self.out.func.name, (
             f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "

@@ -2510,6 +2539,14 @@ class NativeFunctionsViewGroup:
                 f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                 " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
             )
+        if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
+            if self.view_inplace is not None:
+                assert (
+                    self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
+                ), (
+                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
+                    " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
+                )

     def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
         yield self.view
@@ -336,6 +336,7 @@ def generate_function(
         cpp_no_default_args=set(),
         is_abstract=f.is_abstract,
         has_composite_implicit_autograd_kernel=False,
+        has_composite_implicit_autograd_nested_tensor_kernel=False,
         has_composite_explicit_autograd_kernel=True,
         has_composite_explicit_autograd_non_functional_kernel=False,
         # Every generated NativeFunction gets a "generated" tag, so it's easy to tell