From 673b35c847ee6ba67367ba27ff8597c8ae382257 Mon Sep 17 00:00:00 2001
From: YifanShenSZ
Date: Thu, 1 Sep 2022 20:01:39 +0000
Subject: [PATCH] Better reshape with autograd support (#82754) (#84154)

The original author is @YifanShenSZ and the original PR is: #82754

# Summary:
The previous reshape PR ([#80981](https://github.com/pytorch/pytorch/pull/80981)) works for forward, but backward needs improvement: it has to handle the "sometimes view, sometimes copy" behavior (see the sketch after the lists below). This pull request fixes that by:
1. adding a new alias dispatch key `CompositeImplicitAutogradNestedTensor`, which ideally works as the nested-tensor version of `CompositeImplicitAutograd`
2. registering `reshape_nested` to `reshape` under `CompositeImplicitAutogradNestedTensor`

Side changes:
* add contiguous memory format support to `clone_nested`
* add `view_nested`
* add `reshape_as_nested`

Fixes issue [#83041](https://github.com/pytorch/pytorch/issues/83041)
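A minimal sketch of the resulting user-facing behavior, using the shapes from the new tests in `test/test_nestedtensor.py` (`torch.nested_tensor` is the constructor those tests use):

```python
import torch

# Two ragged components: shapes (2, 20) and (3, 20), as in the new tests.
nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])

# Splitting regular dimensions is expressible as a view; -1 inherits
# the ragged dimension instead of inferring it.
nt_view = nt.view(2, -1, 5, 4)        # components (2, 5, 4) and (3, 5, 4)

# reshape now dispatches to reshape_nested through the new alias key:
# it returns a view when possible and otherwise falls back to clone + view.
nt_reshaped = nt.reshape(2, -1, 5, 4)
```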
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82754

Test Plan: Imported from GitHub, without a `Test Plan:` line.

Reviewed By: albanD

Differential Revision: D39023822

Pulled By: drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84154
Approved by: https://github.com/bdhirsh, https://github.com/albanD
---
 BUILD.bazel                                   |   3 +
 aten/src/ATen/core/dispatch/OperatorEntry.cpp |  15 ++
 aten/src/ATen/native/TensorShape.cpp          |  11 -
 aten/src/ATen/native/native_functions.yaml    |  15 +-
 .../native/nested/NestedTensorBackward.cpp    |  17 --
 .../ATen/native/nested/NestedTensorMath.cpp   | 202 ++++++++++++------
 buckbuild.bzl                                 |   8 +-
 build.bzl                                     |   3 +
 c10/core/DispatchKey.cpp                      |   4 +
 c10/core/DispatchKey.h                        |   2 +
 c10/core/DispatchKeySet.cpp                   |  10 +
 .../check_forward_backward_compatibility.py   |   3 +
 test/test_nestedtensor.py                     |  79 ++++++-
 tools/autograd/derivatives.yaml               |  14 +-
 torchgen/gen.py                               |  20 +-
 torchgen/model.py                             |  43 +++-
 torchgen/native_function_generation.py        |   1 +
 17 files changed, 327 insertions(+), 123 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 66c4af5c1fd..b0decc7b0c3 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -51,6 +51,7 @@ generated_cpu_cpp = [
     "aten/src/ATen/RegisterSparseCsrCPU.cpp",
     "aten/src/ATen/RegisterZeroTensor.cpp",
     "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
+    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
     "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
     "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
     "aten/src/ATen/RegisterMeta.cpp",
@@ -66,6 +67,8 @@
     "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
     "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
     "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
+    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
+    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",
     "aten/src/ATen/CompositeViewCopyKernels.cpp",
     "aten/src/ATen/FunctionalInverses.h",
     "aten/src/ATen/Functions.h",
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index e10de4d9f85..139880c6d7f 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -307,6 +307,21 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
   // For AutogradOther, we return ambiguousAutogradOtherKernel() if there's registration
   // to any of its backends.
   // See Note [Undefined in dispatchTable_] for the special handling for Undefined.
+
+  // If the dispatch key is included in CompositeImplicitAutogradNestedTensor,
+  // then we register it to the nested-tensor kernel rather than the
+  // regular-tensor CompositeImplicitAutograd kernel.
+  // We have no intention to change the behavior of Undefined,
+  // so this nested-tensor branch requires `dispatch_key != DispatchKey::Undefined`
+  // to let the original CompositeImplicitAutograd logic handle Undefined.
+  if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) {
+    if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) {
+      if (!has_backend_kernel) {
+        return {*nested_registration, "nested kernel"};
+      }
+    }
+  }
+
   if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) {
     if (auto math_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutograd)) {
       if (dispatch_key == DispatchKey::AutogradOther
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 6d332c3cc15..5541a6ed826 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -1256,17 +1256,6 @@ Tensor alias_with_sizes_and_strides(
 }
 
 Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
-  // reshape has special autograd logic since it sometimes returns a view but sometimes does not
-  // we have to intercept here instead of using dispatcher
-  // otherwise we will see "autograd still running" kind of error in inference mode:
-  // * if we create a tensor in inference mode scope,
-  //   then pass it to a inference mode decorated function,
-  //   everything is fine
-  // * but if we create the input tensor not with inference mode,
-  //   then errors like "Cannot set version_counter for inference tensor" arise
-  if (self.is_nested()) {
-    return at::_reshape_nested(self, proposed_shape);
-  }
   if (self.is_sparse()) {
     AT_ERROR("reshape is not implemented for sparse tensors");
   }
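The workaround removed above existed because reshape's view-or-copy behavior interacted badly with inference mode when routed through the dispatcher. With the alias-key registration the dispatcher handles it; a sketch of the interaction that the new `test_view_inference_mode_interaction` test (added below) exercises, with shapes taken from that test:

```python
import torch

# Construct outside inference mode, then view inside it; this combination
# previously produced "Cannot set version_counter for inference tensor".
nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])
with torch.inference_mode():
    ntT = nt.view(2, -1, 4, 5)  # works: view dispatches through the new key
```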
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index cb16ec60ed9..5df05b5ec56 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4200,16 +4200,9 @@
   variants: function, method
   device_check: NoCheck
   device_guard: False
-
-- func: _reshape_nested(Tensor self, int[] shape) -> Tensor
   dispatch:
-    NestedTensorCPU, NestedTensorCUDA: _reshape_nested
-  autogen: _reshape_nested.out
-
-- func: _reshape_nested_backward(Tensor self, Tensor grad) -> Tensor
-  dispatch:
-    NestedTensorCPU, NestedTensorCUDA: _reshape_nested_backward
-  autogen: _reshape_nested_backward.out
+    CompositeImplicitAutograd: reshape
+    CompositeImplicitAutogradNestedTensor: reshape_nested
 
 # NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
 # They are not user-facing, hence the leading underscore. Please don't use it
@@ -4233,6 +4226,9 @@
   variants: method
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: reshape_as
+    CompositeImplicitAutogradNestedTensor: reshape_as_nested
 
 - func: round(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
@@ -6889,6 +6885,7 @@
     Meta: view_meta
     ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
     MkldnnCPU: mkldnn_view
+    NestedTensorCPU, NestedTensorCUDA: view_nested
 
 # Warning: If you want to change the name or overload name of this
 # operator, you might also want to change the `isBlockListedSchema`
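With the `reshape_as` registration above, `Tensor.reshape_as` works between nested tensors by inheriting `other`'s per-dimension sizes, substituting -1 at ragged dimensions (see `reshape_as_nested` below). A sketch with illustrative shapes; the variable names are hypothetical:

```python
import torch

nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])
other = torch.nested_tensor([torch.randn(2, 5, 4), torch.randn(3, 5, 4)])

# reshape_as_nested reads other's opt_sizes as [2, -1, 5, 4]
# (dim 1 is ragged: 2 vs 3) and calls self.reshape(2, -1, 5, 4).
nt2 = nt.reshape_as(other)
```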
diff --git a/aten/src/ATen/native/nested/NestedTensorBackward.cpp b/aten/src/ATen/native/nested/NestedTensorBackward.cpp
index ec96fdfaf4c..949a2240513 100644
--- a/aten/src/ATen/native/nested/NestedTensorBackward.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorBackward.cpp
@@ -66,23 +66,6 @@ std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
   return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
 }
 
-Tensor _reshape_nested_backward(const Tensor& self, const Tensor& grad) {
-  auto self_ptr = get_nested_tensor_impl(self);
-  // TODO: this is to reproduce self_ptr->opt_sizes_
-  //       if an accessor is provided in the future, can replace this
-  std::vector<int64_t> sizes;
-  for (int64_t i = 0; i < self_ptr->dim(); i++) {
-    c10::optional<int64_t> opt_size = self_ptr->opt_size(i);
-    if (opt_size.has_value()) {
-      sizes.push_back(*opt_size);
-    }
-    else {
-      sizes.push_back(-1);
-    }
-  }
-  return grad.reshape(sizes);
-}
-
 Tensor nested_softmax_backward(
     const Tensor& grad,
     const Tensor& output,
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index d819bceadbb..4e2861caace 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -688,17 +688,38 @@ Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) {
 Tensor clone_nested(
     const Tensor& self,
     c10::optional<c10::MemoryFormat> optional_memory_format) {
-  auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
-  TORCH_CHECK(
-      memory_format == MemoryFormat::Preserve,
-      "clone_nested only supports memory format Preserve, but got ",
-      memory_format,
-      " instead.");
-  // TODO: The size doesn't necessarily need to be cloned, but it is more
-  // conservative. This is something we could revisit once we land a more
-  // efficient implementation of nested_size_tensor_.
-  return wrap_buffer(
-      get_buffer(self).clone(), get_nested_size_tensor(self).clone());
+  auto memory_format = optional_memory_format.value_or(c10::MemoryFormat::Preserve);
+  auto self_ptr = get_nested_tensor_impl(self);
+  if (memory_format == c10::MemoryFormat::Preserve ||
+      (memory_format == c10::MemoryFormat::Contiguous && self.is_contiguous())) {
+    const Tensor& buffer = self_ptr->get_buffer(),
+        sizemat = self_ptr->get_nested_size_tensor(),
+        stridemat = self_ptr->get_nested_stride_tensor();
+    const std::vector<int64_t>& offsets = self_ptr->get_offsets();
+    // TODO: The size and the stride do not necessarily need to be cloned,
+    //       but it is more conservative.
+    //       This is something we could revisit once we land a more
+    //       efficient implementation of nested_size_tensor_ and nested_stride_tensor.
+    return wrap_buffer(buffer.clone(), sizemat.clone(), stridemat.clone(), std::vector<int64_t>(offsets));
+  }
+  // actually, memory format is contiguous and self is noncontiguous
+  else if (memory_format == c10::MemoryFormat::Contiguous) {
+    const Tensor& self_buffer = self_ptr->get_buffer(),
+        sizemat = self_ptr->get_nested_size_tensor();
+    Tensor output_buffer = at::empty_like(self_buffer);
+    Tensor output = wrap_buffer(output_buffer, sizemat);
+    std::vector<Tensor> self_unbind = self.unbind(),
+        output_unbind = output.unbind();
+    for (int64_t i = 0; i < self_ptr->size(0); i++) {
+      output_unbind[i].copy_(self_unbind[i]);
+    }
+    return output;
+  } else {
+    TORCH_CHECK(
+        false,
+        "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ",
+        memory_format);
+  }
 }
 
 at::Tensor NestedTensor_get_nested_size_tensor(const at::Tensor& self){
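The Contiguous branch above is what lets `contiguous()` succeed on a noncontiguous nested tensor; the test changes below replace the old "only supports Preserve" assertion accordingly. A sketch:

```python
import torch

nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])
nt_noncontiguous = nt.transpose(-1, -2)

nt_c = nt_noncontiguous.contiguous()                  # now supported
nt_p = nt.clone(memory_format=torch.preserve_format)  # Preserve, as before

# Any other format still raises, e.g.:
#   nt.clone(memory_format=torch.channels_last)  -> RuntimeError
```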
AT_ERROR("invalid shape dimension ", size_reshaped); } } - // See Note [inference and inheritance semantics] + // See Note [Special size rule for nested tensor] TORCH_CHECK(infer_index == -1, "nested tensor does not infer shape"); + TORCH_CHECK( + numel == numel_reshaped, + "shape '", proposed_shape, "' ", + "is invalid for input of size ", numel); } // all negative sizes can be replaced else { + int64_t numel = 1, numel_reshaped = 1; for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) { int64_t& size_reshaped = size_reshaped_vector[idim]; TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped); if (size_reshaped == -1) { size_reshaped = size[idim]; } + numel *= size[idim]; + numel_reshaped *= size_reshaped; } + for (int64_t idim = ndims_underlying_reshaped; idim < ndims_underlying; idim++) { + numel *= size[idim]; + } + TORCH_CHECK( + numel == numel_reshaped, + "shape '", proposed_shape, "' ", + "is invalid for input of size ", numel); } IntArrayRef size_reshaped(size_reshaped_vector); // compute reshaped stride @@ -1092,7 +1133,7 @@ inline std::tuple NestedTensor_reshape_size_stride( } // reshape as view is impossible else { - reshape_as_view = false; + viewable = false; // fill reshaped size into sizemat for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) { sizemat_reshaped_ptr[idim] = size_reshaped[idim]; @@ -1100,42 +1141,59 @@ inline std::tuple NestedTensor_reshape_size_stride( sizemat_reshaped_ptr += ndims_underlying_reshaped; } } - return std::make_tuple(reshape_as_view, sizemat_reshaped, stridemat_reshaped); -} - -// Args: -// nt_reshaped: the reshaped nested tensor to receive copies -// buffer: the original nested tensor buffer -// sizes: the original nested tensor sizes (may have gone through collapsing or splitting) -// strides: the original nested tensor strides (may have gone through collapsing or splitting) -// offsets: the original nested tensor offsets (may have gone through collapsing or splitting) -inline void NestedTensor_reshape_copy( - Tensor& nt_reshaped, - const Tensor& buffer, - const std::vector& sizes, - const std::vector& strides, - const std::vector& offsets) { - auto nt_reshaped_ptr = get_nested_tensor_impl(nt_reshaped); - const Tensor& buffer_reshaped = nt_reshaped_ptr->get_buffer(); - std::vector sizes_reshaped = NestedTensor_get_sizes(nt_reshaped_ptr), - strides_reshaped = NestedTensor_get_strides(nt_reshaped_ptr); - const std::vector& offsets_reshaped = nt_reshaped_ptr->get_offsets(); - for (int64_t i = 0; i < nt_reshaped_ptr->size(0); i++) { - buffer_reshaped.as_strided(sizes_reshaped[i], strides_reshaped[i], offsets_reshaped[i]).copy_( - // TODO: can we avoid allocating new memory for `buffer...reshape` - // I did not find anything like reshape_out - buffer.as_strided(sizes[i], strides[i], offsets[i]).reshape(sizes_reshaped[i])); - } + return std::make_tuple(viewable, sizemat_reshaped, stridemat_reshaped); } } // namespace -// Special rules for reshape(nested tensor): -// 1. Only 1 regular dimension can be collapsed with -// or splitted from the implicit batch dimension -// 2. 
diff --git a/buckbuild.bzl b/buckbuild.bzl
index 76e0db976be..d4349d6a75a 100644
--- a/buckbuild.bzl
+++ b/buckbuild.bzl
@@ -249,6 +249,7 @@ PT_BACKEND_HEADERS = [
     "CompositeExplicitAutograd",
     "CompositeExplicitAutogradNonFunctional",
     "CompositeImplicitAutograd",
+    "CompositeImplicitAutogradNestedTensor",
     "Meta",
 ]
 
@@ -307,6 +308,7 @@ def get_aten_generated_files(enabled_backends):
     src_files = [
         "RegisterBackendSelect.cpp",
         "RegisterCompositeImplicitAutograd.cpp",
+        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
         "RegisterCompositeExplicitAutograd.cpp",
         "RegisterCompositeExplicitAutogradNonFunctional.cpp",
         "CompositeViewCopyKernels.cpp",
@@ -327,6 +329,8 @@ def get_aten_generated_files(enabled_backends):
         "Operators_4.cpp",
         "CompositeImplicitAutogradFunctions.h",
         "CompositeImplicitAutogradFunctions_inl.h",
+        "CompositeImplicitAutogradNestedTensorFunctions.h",
+        "CompositeImplicitAutogradNestedTensorFunctions_inl.h",
         "CompositeExplicitAutogradFunctions.h",
         "CompositeExplicitAutogradFunctions_inl.h",
         "CompositeExplicitAutogradNonFunctionalFunctions.h",
instead."); + return create_nested_view_tensor(self, sizemat_reshaped, stridemat_reshaped, std::vector(self_ptr->get_offsets())); +} + +// See Note [Special size rule for nested tensor] +Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) { TORCH_CHECK( proposed_shape.size() > 0, "shape '[]' is invalid for a nested tensor"); @@ -1161,23 +1219,37 @@ Tensor _reshape_nested(const Tensor& self, IntArrayRef proposed_shape) { "for now reshape cannot change the implicit batch dimension"); std::vector sizes = NestedTensor_get_sizes(self_ptr), strides = NestedTensor_get_strides(self_ptr); - const std::vector& offsets = self_ptr->get_offsets(); // reshaping underlying tensor dimensions does not change offset // determine reshaped size and stride - const Tensor& buffer = self_ptr->get_buffer(), - & sizemat = self_ptr->get_nested_size_tensor(); - bool reshape_as_view; + const Tensor& sizemat = self_ptr->get_nested_size_tensor(); + bool viewable; Tensor sizemat_reshaped, stridemat_reshaped; - std::tie(reshape_as_view, sizemat_reshaped, stridemat_reshaped) = NestedTensor_reshape_size_stride( + std::tie(viewable, sizemat_reshaped, stridemat_reshaped) = NestedTensor_compute_size_stride( sizes, strides, proposed_shape, sizemat.options()); - if (reshape_as_view) { - return wrap_buffer(buffer, sizemat_reshaped, stridemat_reshaped, std::vector(offsets)); + if (viewable) { + return self.view(proposed_shape); } - Tensor buffer_reshaped = buffer.new_empty(buffer.sizes()); - Tensor output = wrap_buffer(buffer_reshaped, sizemat_reshaped); - NestedTensor_reshape_copy(output, - buffer, sizes, strides, offsets); - return output; + else { + return self.clone(at::MemoryFormat::Contiguous).view(proposed_shape); + } +} + +Tensor reshape_as_nested(const Tensor& self, const Tensor& other) { + auto other_ptr = get_nested_tensor_impl(other); + // TODO: this is to reproduce other_ptr->opt_sizes_ + // if an accessor is provided in the future, can replace this + std::vector sizes; + for (int64_t i = 0; i < other_ptr->dim(); i++) { + c10::optional opt_size = other_ptr->opt_size(i); + if (opt_size.has_value()) { + sizes.push_back(*opt_size); + } + else { + sizes.push_back(-1); + } + } + // reshape with other.opt_sizes_ + return self.reshape(sizes); } } // namespace native diff --git a/buckbuild.bzl b/buckbuild.bzl index 76e0db976be..d4349d6a75a 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -249,6 +249,7 @@ PT_BACKEND_HEADERS = [ "CompositeExplicitAutograd", "CompositeExplicitAutogradNonFunctional", "CompositeImplicitAutograd", + "CompositeImplicitAutogradNestedTensor", "Meta", ] @@ -307,6 +308,7 @@ def get_aten_generated_files(enabled_backends): src_files = [ "RegisterBackendSelect.cpp", "RegisterCompositeImplicitAutograd.cpp", + "RegisterCompositeImplicitAutogradNestedTensor.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "CompositeViewCopyKernels.cpp", @@ -327,6 +329,8 @@ def get_aten_generated_files(enabled_backends): "Operators_4.cpp", "CompositeImplicitAutogradFunctions.h", "CompositeImplicitAutogradFunctions_inl.h", + "CompositeImplicitAutogradNestedTensorFunctions.h", + "CompositeImplicitAutogradNestedTensorFunctions_inl.h", "CompositeExplicitAutogradFunctions.h", "CompositeExplicitAutogradFunctions_inl.h", "CompositeExplicitAutogradNonFunctionalFunctions.h", @@ -364,7 +368,7 @@ def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends): def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends): return [ 
":{}[{}]".format(aten_rule_name, f) - for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"] + for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeImplicitAutogradNestedTensor.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"] ] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends) def get_aten_derived_type_srcs(enabled_backends): @@ -1083,6 +1087,8 @@ def define_buck_targets( "CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]", "CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]", "CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]", + "CompositeImplicitAutogradNestedTensorFunctions.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions.h]", + "CompositeImplicitAutogradNestedTensorFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions_inl.h]", "FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]", "Functions.h": ":gen_aten[Functions.h]", "MethodOperators.h": ":gen_aten[MethodOperators.h]", diff --git a/build.bzl b/build.bzl index 5715e34786d..f3b04a3f25a 100644 --- a/build.bzl +++ b/build.bzl @@ -162,6 +162,8 @@ GENERATED_H_CORE = [ "CompositeExplicitAutogradNonFunctionalFunctions_inl.h", "CompositeImplicitAutogradFunctions.h", "CompositeImplicitAutogradFunctions_inl.h", + "CompositeImplicitAutogradNestedTensorFunctions.h", + "CompositeImplicitAutogradNestedTensorFunctions_inl.h", "MetaFunctions.h", "MetaFunctions_inl.h", "core/TensorBody.h", @@ -193,6 +195,7 @@ GENERATED_CPP = [ "RegisterSparseCsrCPU.cpp", "RegisterMkldnnCPU.cpp", "RegisterCompositeImplicitAutograd.cpp", + "RegisterCompositeImplicitAutogradNestedTensor.cpp", "RegisterZeroTensor.cpp", "RegisterMeta.cpp", "RegisterQuantizedMeta.cpp", diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 4423d578b52..343262f7b72 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -178,6 +178,8 @@ const char* toString(DispatchKey t) { return "Autograd"; case DispatchKey::CompositeImplicitAutograd: return "CompositeImplicitAutograd"; + case DispatchKey::CompositeImplicitAutogradNestedTensor: + return "CompositeImplicitAutogradNestedTensor"; case DispatchKey::CompositeExplicitAutograd: return "CompositeExplicitAutograd"; case DispatchKey::CompositeExplicitAutogradNonFunctional: @@ -324,6 +326,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) { {"Autograd", c10::DispatchKey::Autograd}, {"CompositeImplicitAutograd", c10::DispatchKey::CompositeImplicitAutograd}, + {"CompositeImplicitAutogradNestedTensor", + c10::DispatchKey::CompositeImplicitAutogradNestedTensor}, {"CompositeExplicitAutograd", c10::DispatchKey::CompositeExplicitAutograd}, {"CompositeExplicitAutogradNonFunctional", diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 2f1f1fc5f77..8843cd5472d 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -439,6 +439,8 @@ enum class DispatchKey : uint16_t { Autograd, CompositeImplicitAutograd, // registered at // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp + CompositeImplicitAutogradNestedTensor, // registered at + // 
build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp CompositeExplicitAutograd, // registered at // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp // See Note [CompositeExplicitAutogradNonFunctional Key] diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 3cc564bc04a..a8f60451be3 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -55,6 +55,11 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | // math_dispatch_keyset DispatchKeySet{DispatchKey::NestedTensor}; +constexpr DispatchKeySet nested_dispatch_keyset = + DispatchKeySet( + {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) | + DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); + DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); switch (t) { @@ -67,6 +72,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); case DispatchKey::CompositeImplicitAutograd: return math_dispatch_keyset; + case DispatchKey::CompositeImplicitAutogradNestedTensor: + return nested_dispatch_keyset; case DispatchKey::CompositeExplicitAutograd: return backend_dispatch_keyset; case DispatchKey::CompositeExplicitAutogradNonFunctional: @@ -84,6 +91,9 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) { case DispatchKey::CompositeImplicitAutograd: // See Note [NestedTensor Not Included in Backend Keys] return math_dispatch_keyset.has(k); + case DispatchKey::CompositeImplicitAutogradNestedTensor: + // See Note [NestedTensor Not Included in Backend Keys] + return nested_dispatch_keyset.has(k); case DispatchKey::CompositeExplicitAutograd: // See Note [NestedTensor Not Included in Backend Keys] return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k); diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index bdbfc449931..226fe04c221 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -128,6 +128,9 @@ ALLOW_LIST = [ ("aten::nanmean.out", datetime.date(2022, 8, 30)), ("aten::nansum", datetime.date(2022, 8, 30)), ("aten::nansum.out", datetime.date(2022, 8, 30)), + # nested tensor temporary auxiliary ops + ("aten::_reshape_nested", datetime.date(9999, 1, 1)), + ("aten::_reshape_nested_backward", datetime.date(9999, 1, 1)), ("aten::sum.SymInt", datetime.date(2022, 11, 30)), ("aten::mps_linear", datetime.date(9999, 1, 1)), ("aten::_mps_linear", datetime.date(9999, 1, 1)), diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 90ff9417b4d..1357e07dcfe 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -230,11 +230,7 @@ class TestNestedTensor(TestCase): # Test non_contiguous case assert not nt_noncontiguous.is_contiguous() - self.assertRaisesRegex( - RuntimeError, - r"clone_nested only supports memory format Preserve, but got Contiguous instead.", - lambda: nt_noncontiguous.contiguous() - ) + self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous()) @torch.inference_mode() def test_repr_string(self): @@ -679,7 +675,6 @@ class TestNestedTensorDeviceType(TestCase): @dtypes(torch.float, torch.float16) @skipMeta - @torch.inference_mode() def test_clone(self, device, dtype): nt1 = self.random_nt(device, dtype, 4, (4, 4), (1, 1)) nt2 = 
nt1.clone() @@ -693,7 +688,7 @@ class TestNestedTensorDeviceType(TestCase): self.assertNotEqual(ub1[i], ub2[i]) nt1.clone(memory_format=torch.preserve_format) - msg = "clone_nested only supports memory format Preserve, but got ChannelsLast instead." + msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast" with self.assertRaisesRegex(RuntimeError, msg): nt1.clone(memory_format=torch.channels_last) @@ -1105,7 +1100,6 @@ class TestNestedTensorDeviceType(TestCase): ) @dtypes(torch.float, torch.float16, torch.double) - @torch.inference_mode() def test_transpose(self, device, dtype): nt = self.random_nt(device, dtype, 4, (4, 4)) # error case: transpose nested dimension @@ -1150,7 +1144,74 @@ class TestNestedTensorDeviceType(TestCase): self.assertEqual(ptT, ptT_from_ntT) @dtypes(torch.float, torch.float16, torch.double) - @torch.inference_mode() + def test_view(self, device, dtype): + nt = self.random_nt(device, dtype, 4, (4, 4)) + # error case: empty shape + self.assertRaisesRegex( + RuntimeError, + r"shape '\[\]' is invalid for a nested tensor", + lambda: nt.view(()) + ) + # error case: empty nested tensor + nt_empty = torch.nested_tensor([]) + self.assertRaisesRegex( + RuntimeError, + "empty nested tensor cannot be reshaped", + lambda: nt_empty.view(-1) + ) + # error case: invalid proposed shape for underlying tensors + self.assertRaisesRegex( + RuntimeError, + r"invalid shape dimension -2", + lambda: nt.view(-2, 2, 3) + ) + self.assertRaisesRegex( + RuntimeError, + r"shape '\[.*\]' is invalid for input of size [0-9]+", + lambda: nt.view(4, 2, 3) + ) + # normal case + x0 = torch.randn((2, 20), device=device, dtype=dtype) + x1 = torch.randn((3, 20), device=device, dtype=dtype) + nt = torch.nested_tensor([x0, x1]) + pt = nt.to_padded_tensor(0.0) + self.assertRaisesRegex( + RuntimeError, + r"for now view cannot change the implicit batch dimension", + lambda: nt.transpose(-1, -2).view(40, -1) + ) + # inherit only the ragged dimension + # (2, 20) -> (2, 5, 4) + # (3, 20) -> (3, 5, 4) + nt1 = nt.view(2, -1, 5, 4) + # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4) + pt1 = pt.view(2, -1, 5, 4) + self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1) + # also inherit regular dimension + nt2 = nt1.view(2, -1, -1, 2, 2) + pt2 = pt1.view(2, -1, 5, 2, 2) + self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2) + + @dtypes(torch.float, torch.float16, torch.double) + def test_view_inference_mode_interaction(self, device, dtype): + # Construct in default mode and view while in inference mode + nt = torch.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype) + with torch.inference_mode(): + ntT = nt.view(2, -1, 4, 5) + ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) + pt = nt.to_padded_tensor(0.0) + ptT = pt.view(2, -1, 4, 5) + self.assertEqual(ptT, ptT_from_ntT) + # Construct and view while in inference mode + with torch.inference_mode(): + nt = torch.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype) + ntT = nt.view(2, -1, 4, 5) + ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) + pt = nt.to_padded_tensor(0.0) + ptT = pt.view(2, -1, 4, 5) + self.assertEqual(ptT, ptT_from_ntT) + + @dtypes(torch.float, torch.float16, torch.double) def test_reshape(self, device, dtype): nt = self.random_nt(device, dtype, 4, (4, 4)) # error case: empty shape diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index dd649bc9c08..da74b31a313 100644 --- 
a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1356,10 +1356,6 @@
 #   making it impossible (hard) to detect when it is actually a view.
 # - name: reshape(Tensor self, IntArrayRef shape)
 
-- name: _reshape_nested(Tensor self, int[] shape) -> Tensor
-  self: _reshape_nested_backward(self, grad)
-  result: auto_linear
-
 - name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
   self: grad.reshape(self.sizes())
   result: auto_linear
@@ -1732,10 +1728,14 @@
 #   linear
   result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)
 
-# TODO: this derivative is not SymInt safe, need reshape_symint
 - name: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
-  self: grad.reshape(self.sizes())
-  result: auto_linear
+  dispatch:
+    Default:
+      self: grad.reshape(self.sizes())
+      result: auto_linear
+    AutogradNestedTensor:
+      self: grad.reshape_as(self)
+      result: auto_linear
 
 - name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
   output_differentiability: [False]
diff --git a/torchgen/gen.py b/torchgen/gen.py
index 660e54dfdf3..4fa73a02150 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -310,6 +310,7 @@ def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
     else:
         return [backend.dispatch_key for backend in backends] + [
             DispatchKey.CompositeImplicitAutograd,
+            DispatchKey.CompositeImplicitAutogradNestedTensor,
             DispatchKey.CompositeExplicitAutograd,
             DispatchKey.CompositeExplicitAutogradNonFunctional,
         ]
@@ -330,6 +331,8 @@ def get_static_dispatch_backend(
         return DispatchKey.CompositeExplicitAutogradNonFunctional
     elif f.has_composite_implicit_autograd_kernel:
         return DispatchKey.CompositeImplicitAutograd
+    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
+        return DispatchKey.CompositeImplicitAutogradNestedTensor
     return None
 
@@ -426,6 +429,8 @@ def generate_static_dispatch_fallback_call(
         return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
     elif f.has_composite_implicit_autograd_kernel:
         return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
+    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
+        return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
     else:
         return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
{', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""
@@ -1241,6 +1246,11 @@ def compute_registration_declarations(
         "dispatch": str(
             {k for k, v in backend_indices.items() if v.has_kernel(f)}
             != {DispatchKey.CompositeImplicitAutograd}
+            and {k for k, v in backend_indices.items() if v.has_kernel(f)}
+            != {
+                DispatchKey.CompositeImplicitAutograd,
+                DispatchKey.CompositeImplicitAutogradNestedTensor,
+            }
         ),
         "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
     }
@@ -2145,6 +2155,13 @@ def gen_source_files(
             ns_grouped_native_functions[namespace].append(grouped_native_function)
 
         dispatch_namespace = str(dispatch_key).lower()
+
+        # CompositeImplicitAutogradNestedTensor does not currently use the generated helpers;
+        # compilation will fail when the `-Werror=unused-function` flag is set
+        gen_dispatch_helpers: bool = (
+            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
+        )
+
         dispatch_definitions = get_native_function_definitions(
             fm=fm,
             grouped_native_functions=grouped_native_functions,
@@ -2153,7 +2170,7 @@
             selector=selector,
             rocm=rocm,
skip_dispatcher_op_registration=skip_dispatcher_op_registration, - gen_dispatch_helpers=True, + gen_dispatch_helpers=gen_dispatch_helpers, ) fm.write_with_template( f"Register{dispatch_key}.cpp", @@ -2636,6 +2653,7 @@ def main() -> None: DispatchKey.CPU, DispatchKey.CUDA, DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, DispatchKey.CompositeExplicitAutograd, DispatchKey.CompositeExplicitAutogradNonFunctional, DispatchKey.Meta, diff --git a/torchgen/model.py b/torchgen/model.py index ee8f48afdaa..81fc05760af 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -93,6 +93,7 @@ class DispatchKey(Enum): Autograd = auto() CompositeImplicitAutograd = auto() + CompositeImplicitAutogradNestedTensor = auto() CompositeExplicitAutograd = auto() CompositeExplicitAutogradNonFunctional = auto() @@ -217,6 +218,7 @@ dispatch_keys = [ DispatchKey.QuantizedCPU, DispatchKey.QuantizedCUDA, DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, DispatchKey.CompositeExplicitAutograd, DispatchKey.CompositeExplicitAutogradNonFunctional, DispatchKey.NestedTensorCPU, @@ -237,6 +239,7 @@ def is_generic_dispatch_key(dk: DispatchKey) -> bool: DispatchKey.CompositeExplicitAutograd, DispatchKey.CompositeExplicitAutogradNonFunctional, DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, } @@ -485,6 +488,7 @@ class NativeFunction: # Whether or not the NativeFunction contains a backend-agnostic kernel has_composite_implicit_autograd_kernel: bool + has_composite_implicit_autograd_nested_tensor_kernel: bool has_composite_explicit_autograd_kernel: bool has_composite_explicit_autograd_non_functional_kernel: bool @@ -699,9 +703,15 @@ class NativeFunction: if d == DispatchKey.CompositeExplicitAutograd or d == DispatchKey.CompositeExplicitAutogradNonFunctional or d == DispatchKey.CompositeImplicitAutograd + or d == DispatchKey.CompositeImplicitAutogradNestedTensor ] - assert len(composites_in_dispatch) <= 1, ( + assert len(composites_in_dispatch) <= 1 or ( + len(composites_in_dispatch) == 2 + and DispatchKey.CompositeImplicitAutograd in composites_in_dispatch + and DispatchKey.CompositeImplicitAutogradNestedTensor + in composites_in_dispatch + ), ( "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, " "or CompositeImplicitAutograd on a single kernel; each " "strictly subsumes the other. 
If you wanted to provide an explicit autograd " @@ -756,11 +766,23 @@ class NativeFunction: # Structured functions MUST have a dispatch table is_abstract = True else: - is_abstract = dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} + is_abstract = ( + dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} + and dispatch.keys() + != {DispatchKey.CompositeImplicitAutogradNestedTensor} + and dispatch.keys() + != { + DispatchKey.CompositeImplicitAutograd, + DispatchKey.CompositeImplicitAutogradNestedTensor, + } + ) has_composite_implicit_autograd_kernel = ( DispatchKey.CompositeImplicitAutograd in dispatch.keys() ) + has_composite_implicit_autograd_nested_tensor_kernel = ( + DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch.keys() + ) has_composite_explicit_autograd_kernel = ( DispatchKey.CompositeExplicitAutograd in dispatch.keys() ) @@ -808,6 +830,7 @@ class NativeFunction: cpp_no_default_args=cpp_no_default_args, is_abstract=is_abstract, has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel, + has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel, has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel, has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel, tags=tags, @@ -899,6 +922,9 @@ class NativeFunction: self.has_composite_implicit_autograd_kernel or self.has_composite_explicit_autograd_kernel or self.has_composite_explicit_autograd_non_functional_kernel + ) or ( + self.has_composite_implicit_autograd_kernel + and self.has_composite_implicit_autograd_nested_tensor_kernel ) @property @@ -976,7 +1002,10 @@ class NativeFunctionsGroup: if self.structured: # For now, structured composite kernels are not supported (need some # design work to figure out how to make the composite case work) - assert not self.out.has_composite_implicit_autograd_kernel + assert ( + not self.out.has_composite_implicit_autograd_kernel + and not self.out.has_composite_implicit_autograd_nested_tensor_kernel + ) assert self.functional.structured_delegate == self.out.func.name, ( f"{self.functional.func.name} delegates to {self.functional.structured_delegate} " @@ -2510,6 +2539,14 @@ class NativeFunctionsViewGroup: f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" " both have CompositeImplicitAutograd kernels, or both not have composite kernels." ) + if self.view.has_composite_implicit_autograd_nested_tensor_kernel: + if self.view_inplace is not None: + assert ( + self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel + ), ( + f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" + " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels." + ) def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]: yield self.view diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 70755d4c0b4..bf8503ed640 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -336,6 +336,7 @@ def generate_function( cpp_no_default_args=set(), is_abstract=f.is_abstract, has_composite_implicit_autograd_kernel=False, + has_composite_implicit_autograd_nested_tensor_kernel=False, has_composite_explicit_autograd_kernel=True, has_composite_explicit_autograd_non_functional_kernel=False, # Every generated NativeFunction gets a "generated" tag, so it's easy to tell