Better reshape with autograd support (#82754) (#84154)

The original author is @YifanShenSZ and the original PR is #82754.
# Summary:
The previous reshape PR ([#80981](https://github.com/pytorch/pytorch/pull/80981)) works for the forward pass, but the backward pass needs improvement: it has to handle reshape's "sometimes view, sometimes copy" behavior.

This pull request fixes it by:
1. Add a new alias dispatch key, `CompositeImplicitAutogradNestedTensor`, which is intended to work as the nested-tensor counterpart of `CompositeImplicitAutograd`.
2. Register `reshape_nested` to `reshape` under `CompositeImplicitAutogradNestedTensor`.

Side changes:
* Add contiguous memory format support to `clone_nested`
* Add `view_nested`
* Add `reshape_as_nested`

Fixes issue [#83041](https://github.com/pytorch/pytorch/issues/83041).
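
To illustrate what this enables for users, here is a minimal sketch of reshaping a nested tensor with gradients flowing through it. This is a hypothetical example, not code from the PR; the constructor call and the explicit nested gradient passed to `backward` are assumptions about the nested-tensor API of this era.

```python
import torch

# Build a ragged nested tensor (components of shape (2, 6) and (3, 6))
# and ask for gradients on it directly.
nt = torch.nested_tensor([torch.randn(2, 6), torch.randn(3, 6)]).requires_grad_()

# reshape now dispatches to reshape_nested through the new
# CompositeImplicitAutogradNestedTensor alias key, so autograd derives the
# backward from the composite decomposition (a view of the buffer when the
# strides allow it, a contiguous copy otherwise).
out = nt.reshape(2, -1, 2, 3)

# Backward with an explicit nested gradient of the same nested shape
# (assumption: nested tensors accept an explicit `gradient` argument here).
grad_out = torch.nested_tensor([torch.ones(2, 2, 3), torch.ones(3, 2, 3)])
out.backward(grad_out)
print(nt.grad)  # expected: nested tensor of ones with the original component shapes
```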

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82754

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

**Static Docs Preview: executorch**
[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D39023822/V13/executorch/)


Reviewed By: albanD

Differential Revision: D39023822

Pulled By: drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84154
Approved by: https://github.com/bdhirsh, https://github.com/albanD
YifanShenSZ 2022-09-01 20:01:39 +00:00 committed by PyTorch MergeBot
parent 9bcad063d8
commit 673b35c847
17 changed files with 327 additions and 123 deletions

View File

@ -51,6 +51,7 @@ generated_cpu_cpp = [
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
"aten/src/ATen/RegisterZeroTensor.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
"aten/src/ATen/RegisterMeta.cpp",
@ -66,6 +67,8 @@ generated_cpu_cpp = [
"aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
"aten/src/ATen/CompositeImplicitAutogradFunctions.h",
"aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
"aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
"aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",
"aten/src/ATen/CompositeViewCopyKernels.cpp",
"aten/src/ATen/FunctionalInverses.h",
"aten/src/ATen/Functions.h",

View File

@ -307,6 +307,21 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// For AutogradOther, we return ambiguousAutogradOtherKernel() if there's registration
// to any of its backends.
// See Note [Undefined in dispatchTable_] for the special handling for Undefined.
// If the dispatch key is included in CompositeImplicitAutogradNestedTensor,
// then we register it to nested-tensor kernel rather than
// regular-tensor CompositeImplicitAutograd kernel.
// We have no intention to change the behavior of Undefined,
// so this nested-tensor branch requires `dispatch_key != DispatchKey::Undefined`
// to let the original CompositeImplicitAutograd handle Undefined
if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) {
if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) {
if (!has_backend_kernel) {
return {*nested_registration, "nested kernel"};
}
}
}
if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) {
if (auto math_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutograd)) {
if (dispatch_key == DispatchKey::AutogradOther
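
The precedence added above can be summarized by a small standalone sketch. This is plain Python with hypothetical names; it only mirrors the branch order of `computeDispatchTableEntryWithDebug`, not the real dispatcher data structures.

```python
def pick_kernel(dispatch_key, kernels, has_backend_kernel, is_included_in_alias):
    """Mirror of the branch order: the nested composite kernel wins over the
    generic CompositeImplicitAutograd kernel, but never for Undefined and
    never when a real backend kernel is registered."""
    if (dispatch_key != "Undefined"
            and is_included_in_alias(dispatch_key, "CompositeImplicitAutogradNestedTensor")
            and "CompositeImplicitAutogradNestedTensor" in kernels
            and not has_backend_kernel):
        return kernels["CompositeImplicitAutogradNestedTensor"]
    if (dispatch_key == "Undefined"
            or is_included_in_alias(dispatch_key, "CompositeImplicitAutograd")):
        return kernels.get("CompositeImplicitAutograd")
    return None
```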

View File

@ -1256,17 +1256,6 @@ Tensor alias_with_sizes_and_strides(
}
Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
// reshape has special autograd logic since it sometimes returns a view but sometimes does not
// we have to intercept here instead of using dispatcher
// otherwise we will see "autograd still running" kind of error in inference mode:
// * if we create a tensor in inference mode scope,
// then pass it to a inference mode decorated function,
// everything is fine
// * but if we create the input tensor not with inference mode,
// then errors like "Cannot set version_counter for inference tensor" arise
if (self.is_nested()) {
return at::_reshape_nested(self, proposed_shape);
}
if (self.is_sparse()) {
AT_ERROR("reshape is not implemented for sparse tensors");
}
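
The inference-mode scenario the removed comment worries about is now exercised by `test_view_inference_mode_interaction` in the test diff below; a condensed, hypothetical repro of the pattern:

```python
import torch

# Tensor created *outside* inference mode...
nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])

with torch.inference_mode():
    # ...then reshaped *inside* inference mode. The old manual intercept existed
    # to dodge errors like "Cannot set version_counter for inference tensor";
    # with reshape_nested registered on the dispatcher this path is expected to
    # just work.
    out = nt.reshape(2, -1, 4, 5)
```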

View File

@ -4200,16 +4200,9 @@
variants: function, method
device_check: NoCheck
device_guard: False
- func: _reshape_nested(Tensor self, int[] shape) -> Tensor
dispatch:
NestedTensorCPU, NestedTensorCUDA: _reshape_nested
autogen: _reshape_nested.out
- func: _reshape_nested_backward(Tensor self, Tensor grad) -> Tensor
dispatch:
NestedTensorCPU, NestedTensorCUDA: _reshape_nested_backward
autogen: _reshape_nested_backward.out
CompositeImplicitAutograd: reshape
CompositeImplicitAutogradNestedTensor: reshape_nested
# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
# They are not user-facing, hence the leading underscore. Please don't use it
@ -4233,6 +4226,9 @@
variants: method
device_check: NoCheck
device_guard: False
dispatch:
CompositeImplicitAutograd: reshape_as
CompositeImplicitAutogradNestedTensor: reshape_as_nested
- func: round(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
@ -6889,6 +6885,7 @@
Meta: view_meta
ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
MkldnnCPU: mkldnn_view
NestedTensorCPU, NestedTensorCUDA: view_nested
# Warning: If you want to change the name or overload name of this
# operator, you might also want to change the `isBlockListedSchema`
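
With `reshape_as` and `view` registered for nested tensors above, the user-facing calls look roughly like this (hypothetical sketch; shapes chosen only for illustration):

```python
import torch

nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])

# view_nested: the target shape must be expressible without copying the buffer.
ntv = nt.view(2, -1, 4, 5)

# reshape_as_nested: borrow another nested tensor's (possibly ragged) shape;
# ragged dimensions are translated to -1 internally, meaning "inherit".
other = torch.nested_tensor([torch.randn(2, 4, 5), torch.randn(3, 4, 5)])
ntr = nt.reshape_as(other)
```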

View File

@ -66,23 +66,6 @@ std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}
Tensor _reshape_nested_backward(const Tensor& self, const Tensor& grad) {
auto self_ptr = get_nested_tensor_impl(self);
// TODO: this is to reproduce self_ptr->opt_sizes_
// if an accessor is provided in the future, can replace this
std::vector<int64_t> sizes;
for (int64_t i = 0; i < self_ptr->dim(); i++) {
c10::optional<int64_t> opt_size = self_ptr->opt_size(i);
if (opt_size.has_value()) {
sizes.push_back(*opt_size);
}
else {
sizes.push_back(-1);
}
}
return grad.reshape(sizes);
}
Tensor nested_softmax_backward(
const Tensor& grad,
const Tensor& output,

View File

@ -688,17 +688,38 @@ Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) {
Tensor clone_nested(
const Tensor& self,
c10::optional<c10::MemoryFormat> optional_memory_format) {
auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
auto memory_format = optional_memory_format.value_or(c10::MemoryFormat::Preserve);
auto self_ptr = get_nested_tensor_impl(self);
if (memory_format == c10::MemoryFormat::Preserve ||
(memory_format == c10::MemoryFormat::Contiguous && self.is_contiguous())) {
const Tensor& buffer = self_ptr->get_buffer(),
sizemat = self_ptr->get_nested_size_tensor(),
stridemat = self_ptr->get_nested_stride_tensor();
const std::vector<int64_t>& offsets = self_ptr->get_offsets();
// TODO: The size and the stride do not necessarily need to be cloned,
// but it is more conservative.
// This is something we could revisit once we land a more
// efficient implementation of nested_size_tensor_ and nested_stride_tensor.
return wrap_buffer(buffer.clone(), sizemat.clone(), stridemat.clone(), std::vector<int64_t>(offsets));
}
// actually, memory format is contiguous and self is noncontiguous
else if (memory_format == c10::MemoryFormat::Contiguous) {
const Tensor& self_buffer = self_ptr->get_buffer(),
sizemat = self_ptr->get_nested_size_tensor();
Tensor output_buffer = at::empty_like(self_buffer);
Tensor output = wrap_buffer(output_buffer, sizemat);
std::vector<Tensor> self_unbind = self.unbind(),
output_unbind = output.unbind();
for (int64_t i = 0; i < self_ptr->size(0); i++) {
output_unbind[i].copy_(self_unbind[i]);
}
return output;
} else {
TORCH_CHECK(
memory_format == MemoryFormat::Preserve,
"clone_nested only supports memory format Preserve, but got ",
memory_format,
" instead.");
// TODO: The size doesn't necessarily need to be cloned, but it is more
// conservative. This is something we could revisit once we land a more
// efficient implementation of nested_size_tensor_.
return wrap_buffer(
get_buffer(self).clone(), get_nested_size_tensor(self).clone());
false,
"Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ",
memory_format);
}
}
at::Tensor NestedTensor_get_nested_size_tensor(const at::Tensor& self){
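
For users, the practical effect of the `clone_nested` change above is that `.contiguous()` (i.e. `clone(memory_format=torch.contiguous_format)`) now works on a noncontiguous nested tensor instead of raising. A rough sketch, mirroring the updated test further down:

```python
import torch

nt = torch.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
nt_t = nt.transpose(-1, -2)        # a noncontiguous nested view
assert not nt_t.is_contiguous()

# Previously: "clone_nested only supports memory format Preserve".
# Now: a fresh contiguous buffer is allocated and filled component by component.
nt_c = nt_t.contiguous()
assert nt_c.is_contiguous()
```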
@ -1008,7 +1029,7 @@ Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
self, sizemat_transposed, stridemat_transposed, std::vector<int64_t>(self_ptr->get_offsets()));
}
// utilities supporting `_reshape_nested`
// utilities supporting `view_nested` and `reshape_nested`
namespace {
// Args:
// sizes: the sizes of original nested tensor
@ -1016,10 +1037,10 @@ namespace {
// proposed_shape: user proposed new shape
// op: the options for new size and stride matrices
// Returns:
// whether reshape as view is possible (i.e. old buffer can be reused)
// whether viewable
// size matrix after reshape
// stride matrix after reshape (not fully populated if reshape as view is impossible)
inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
// stride matrix after reshape (not fully populated if not viewable)
inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
const std::vector<IntArrayRef>& sizes,
const std::vector<IntArrayRef>& strides,
const IntArrayRef& proposed_shape,
@ -1027,7 +1048,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
int64_t ntensors = sizes.size(),
ndims_underlying = sizes[0].size(),
ndims_underlying_reshaped = proposed_shape.size() - 1;
bool reshape_as_view = true;
bool viewable = true;
Tensor sizemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op),
stridemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op);
int64_t* sizemat_reshaped_ptr = sizemat_reshaped.data_ptr<int64_t>(),
@ -1039,6 +1060,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
std::vector<int64_t> size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end());
// some negative sizes remain to be inferred
if (ndims_underlying < ndims_underlying_reshaped) {
int64_t numel = 1, numel_reshaped = 1;
// replace negative sizes for old dimensions with old sizes
for (int64_t idim = 0; idim < ndims_underlying; idim++) {
int64_t& size_reshaped = size_reshaped_vector[idim];
@ -1046,12 +1068,17 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
if (size_reshaped == -1) {
size_reshaped = size[idim];
}
numel *= size[idim];
numel_reshaped *= size_reshaped;
}
// infer negative size for new dimension
int64_t infer_index = -1;
for (int64_t idim = ndims_underlying; idim < ndims_underlying_reshaped; idim++) {
const int64_t& size_reshaped = size_reshaped_vector[idim];
if (size_reshaped == -1) {
if (size_reshaped >= 0) {
numel_reshaped *= size_reshaped;
}
else if (size_reshaped == -1) {
if (infer_index > -1) {
throw std::runtime_error("only one dimension can be inferred");
}
@ -1059,22 +1086,36 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
infer_index = idim;
}
}
else if (size_reshaped < 0) {
else {
AT_ERROR("invalid shape dimension ", size_reshaped);
}
}
// See Note [inference and inheritance semantics]
// See Note [Special size rule for nested tensor]
TORCH_CHECK(infer_index == -1, "nested tensor does not infer shape");
TORCH_CHECK(
numel == numel_reshaped,
"shape '", proposed_shape, "' ",
"is invalid for input of size ", numel);
}
// all negative sizes can be replaced
else {
int64_t numel = 1, numel_reshaped = 1;
for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
int64_t& size_reshaped = size_reshaped_vector[idim];
TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
if (size_reshaped == -1) {
size_reshaped = size[idim];
}
numel *= size[idim];
numel_reshaped *= size_reshaped;
}
for (int64_t idim = ndims_underlying_reshaped; idim < ndims_underlying; idim++) {
numel *= size[idim];
}
TORCH_CHECK(
numel == numel_reshaped,
"shape '", proposed_shape, "' ",
"is invalid for input of size ", numel);
}
IntArrayRef size_reshaped(size_reshaped_vector);
// compute reshaped stride
@ -1092,7 +1133,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
}
// reshape as view is impossible
else {
reshape_as_view = false;
viewable = false;
// fill reshaped size into sizemat
for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
sizemat_reshaped_ptr[idim] = size_reshaped[idim];
@ -1100,42 +1141,59 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
sizemat_reshaped_ptr += ndims_underlying_reshaped;
}
}
return std::make_tuple(reshape_as_view, sizemat_reshaped, stridemat_reshaped);
}
// Args:
// nt_reshaped: the reshaped nested tensor to receive copies
// buffer: the original nested tensor buffer
// sizes: the original nested tensor sizes (may have gone through collapsing or splitting)
// strides: the original nested tensor strides (may have gone through collapsing or splitting)
// offsets: the original nested tensor offsets (may have gone through collapsing or splitting)
inline void NestedTensor_reshape_copy(
Tensor& nt_reshaped,
const Tensor& buffer,
const std::vector<IntArrayRef>& sizes,
const std::vector<IntArrayRef>& strides,
const std::vector<int64_t>& offsets) {
auto nt_reshaped_ptr = get_nested_tensor_impl(nt_reshaped);
const Tensor& buffer_reshaped = nt_reshaped_ptr->get_buffer();
std::vector<IntArrayRef> sizes_reshaped = NestedTensor_get_sizes(nt_reshaped_ptr),
strides_reshaped = NestedTensor_get_strides(nt_reshaped_ptr);
const std::vector<int64_t>& offsets_reshaped = nt_reshaped_ptr->get_offsets();
for (int64_t i = 0; i < nt_reshaped_ptr->size(0); i++) {
buffer_reshaped.as_strided(sizes_reshaped[i], strides_reshaped[i], offsets_reshaped[i]).copy_(
// TODO: can we avoid allocating new memory for `buffer...reshape`
// I did not find anything like reshape_out
buffer.as_strided(sizes[i], strides[i], offsets[i]).reshape(sizes_reshaped[i]));
}
return std::make_tuple(viewable, sizemat_reshaped, stridemat_reshaped);
}
} // namespace
// Special rules for reshape(nested tensor):
// 1. Only 1 regular dimension can be collapsed with
// or splitted from the implicit batch dimension
// 2. Instead of infering size, -1 means "inherit the old size", so:
// Note [Special size rule for nested tensor]
// Instead of inferring size, -1 means "inherit the old size", so:
// * negative size is legal for a ragged dimension
// * multiple sizes can be -1
Tensor _reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
// In principle we could still infer a dimension,
// we are designing a better semantics to include both inheritance and inference
Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
TORCH_CHECK(
proposed_shape.size() > 0,
"shape '[]' is invalid for a nested tensor");
auto self_ptr = get_nested_tensor_impl(self);
// basic information before reshaping
int64_t ntensors = self_ptr->size(0);
TORCH_CHECK(
ntensors > 0,
"empty nested tensor cannot be reshaped");
// basic information after reshaping
int64_t ntensors_reshaped;
if (proposed_shape[0] >= 0) {
ntensors_reshaped = proposed_shape[0];
}
else if (proposed_shape[0] == -1) {
ntensors_reshaped = ntensors;
}
else {
AT_ERROR("invalid shape dimension ", proposed_shape[0]);
}
TORCH_CHECK(
ntensors == ntensors_reshaped,
"for now view cannot change the implicit batch dimension");
std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
strides = NestedTensor_get_strides(self_ptr);
// reshaping underlying tensor dimensions does not change offset
// determine reshaped size and stride
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
bool viewable;
Tensor sizemat_reshaped, stridemat_reshaped;
std::tie(viewable, sizemat_reshaped, stridemat_reshaped) = NestedTensor_compute_size_stride(
sizes, strides, proposed_shape, sizemat.options());
TORCH_CHECK(
viewable,
"view size is not compatible with input tensor's size and stride "
"(at least one dimension spans across two contiguous subspaces). "
"Use .reshape(...) instead.");
return create_nested_view_tensor(self, sizemat_reshaped, stridemat_reshaped, std::vector<int64_t>(self_ptr->get_offsets()));
}
// See Note [Special size rule for nested tensor]
Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
TORCH_CHECK(
proposed_shape.size() > 0,
"shape '[]' is invalid for a nested tensor");
@ -1161,23 +1219,37 @@ Tensor _reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
"for now reshape cannot change the implicit batch dimension");
std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
strides = NestedTensor_get_strides(self_ptr);
const std::vector<int64_t>& offsets = self_ptr->get_offsets();
// reshaping underlying tensor dimensions does not change offset
// determine reshaped size and stride
const Tensor& buffer = self_ptr->get_buffer(),
& sizemat = self_ptr->get_nested_size_tensor();
bool reshape_as_view;
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
bool viewable;
Tensor sizemat_reshaped, stridemat_reshaped;
std::tie(reshape_as_view, sizemat_reshaped, stridemat_reshaped) = NestedTensor_reshape_size_stride(
std::tie(viewable, sizemat_reshaped, stridemat_reshaped) = NestedTensor_compute_size_stride(
sizes, strides, proposed_shape, sizemat.options());
if (reshape_as_view) {
return wrap_buffer(buffer, sizemat_reshaped, stridemat_reshaped, std::vector<int64_t>(offsets));
if (viewable) {
return self.view(proposed_shape);
}
Tensor buffer_reshaped = buffer.new_empty(buffer.sizes());
Tensor output = wrap_buffer(buffer_reshaped, sizemat_reshaped);
NestedTensor_reshape_copy(output,
buffer, sizes, strides, offsets);
return output;
else {
return self.clone(at::MemoryFormat::Contiguous).view(proposed_shape);
}
}
Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
auto other_ptr = get_nested_tensor_impl(other);
// TODO: this is to reproduce other_ptr->opt_sizes_
// if an accessor is provided in the future, can replace this
std::vector<int64_t> sizes;
for (int64_t i = 0; i < other_ptr->dim(); i++) {
c10::optional<int64_t> opt_size = other_ptr->opt_size(i);
if (opt_size.has_value()) {
sizes.push_back(*opt_size);
}
else {
sizes.push_back(-1);
}
}
// reshape with other.opt_sizes_
return self.reshape(sizes);
}
} // namespace native
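
Putting `reshape_nested` together: when the computed size/stride matrices admit a view it defers to `self.view(proposed_shape)`, otherwise it copies via `clone(Contiguous)` followed by `view`. A hypothetical illustration of the two paths (shapes chosen for illustration only):

```python
import torch

nt = torch.nested_tensor([torch.randn(2, 5, 4), torch.randn(3, 5, 4)])

# Contiguous input with mergeable trailing dims: the strides are
# view-compatible, so reshape takes the self.view(...) path and shares storage.
v = nt.reshape(2, -1, 20)

# After a transpose the trailing dims are no longer mergeable in memory, so the
# same target shape falls back to clone(Contiguous) + view, i.e. a copy.
c = nt.transpose(-1, -2).reshape(2, -1, 20)
```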

View File

@ -249,6 +249,7 @@ PT_BACKEND_HEADERS = [
"CompositeExplicitAutograd",
"CompositeExplicitAutogradNonFunctional",
"CompositeImplicitAutograd",
"CompositeImplicitAutogradNestedTensor",
"Meta",
]
@ -307,6 +308,7 @@ def get_aten_generated_files(enabled_backends):
src_files = [
"RegisterBackendSelect.cpp",
"RegisterCompositeImplicitAutograd.cpp",
"RegisterCompositeImplicitAutogradNestedTensor.cpp",
"RegisterCompositeExplicitAutograd.cpp",
"RegisterCompositeExplicitAutogradNonFunctional.cpp",
"CompositeViewCopyKernels.cpp",
@ -327,6 +329,8 @@ def get_aten_generated_files(enabled_backends):
"Operators_4.cpp",
"CompositeImplicitAutogradFunctions.h",
"CompositeImplicitAutogradFunctions_inl.h",
"CompositeImplicitAutogradNestedTensorFunctions.h",
"CompositeImplicitAutogradNestedTensorFunctions_inl.h",
"CompositeExplicitAutogradFunctions.h",
"CompositeExplicitAutogradFunctions_inl.h",
"CompositeExplicitAutogradNonFunctionalFunctions.h",
@ -364,7 +368,7 @@ def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends):
def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends):
return [
":{}[{}]".format(aten_rule_name, f)
for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeImplicitAutogradNestedTensor.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends)
def get_aten_derived_type_srcs(enabled_backends):
@ -1083,6 +1087,8 @@ def define_buck_targets(
"CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]",
"CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]",
"CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]",
"CompositeImplicitAutogradNestedTensorFunctions.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions.h]",
"CompositeImplicitAutogradNestedTensorFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions_inl.h]",
"FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]",
"Functions.h": ":gen_aten[Functions.h]",
"MethodOperators.h": ":gen_aten[MethodOperators.h]",

View File

@ -162,6 +162,8 @@ GENERATED_H_CORE = [
"CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
"CompositeImplicitAutogradFunctions.h",
"CompositeImplicitAutogradFunctions_inl.h",
"CompositeImplicitAutogradNestedTensorFunctions.h",
"CompositeImplicitAutogradNestedTensorFunctions_inl.h",
"MetaFunctions.h",
"MetaFunctions_inl.h",
"core/TensorBody.h",
@ -193,6 +195,7 @@ GENERATED_CPP = [
"RegisterSparseCsrCPU.cpp",
"RegisterMkldnnCPU.cpp",
"RegisterCompositeImplicitAutograd.cpp",
"RegisterCompositeImplicitAutogradNestedTensor.cpp",
"RegisterZeroTensor.cpp",
"RegisterMeta.cpp",
"RegisterQuantizedMeta.cpp",

View File

@ -178,6 +178,8 @@ const char* toString(DispatchKey t) {
return "Autograd";
case DispatchKey::CompositeImplicitAutograd:
return "CompositeImplicitAutograd";
case DispatchKey::CompositeImplicitAutogradNestedTensor:
return "CompositeImplicitAutogradNestedTensor";
case DispatchKey::CompositeExplicitAutograd:
return "CompositeExplicitAutograd";
case DispatchKey::CompositeExplicitAutogradNonFunctional:
@ -324,6 +326,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"Autograd", c10::DispatchKey::Autograd},
{"CompositeImplicitAutograd",
c10::DispatchKey::CompositeImplicitAutograd},
{"CompositeImplicitAutogradNestedTensor",
c10::DispatchKey::CompositeImplicitAutogradNestedTensor},
{"CompositeExplicitAutograd",
c10::DispatchKey::CompositeExplicitAutograd},
{"CompositeExplicitAutogradNonFunctional",

View File

@ -439,6 +439,8 @@ enum class DispatchKey : uint16_t {
Autograd,
CompositeImplicitAutograd, // registered at
// build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp
CompositeImplicitAutogradNestedTensor, // registered at
// build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp
CompositeExplicitAutograd, // registered at
// build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
// See Note [CompositeExplicitAutogradNonFunctional Key]

View File

@ -55,6 +55,11 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
// math_dispatch_keyset
DispatchKeySet{DispatchKey::NestedTensor};
constexpr DispatchKeySet nested_dispatch_keyset =
DispatchKeySet(
{DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
switch (t) {
@ -67,6 +72,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
case DispatchKey::CompositeImplicitAutograd:
return math_dispatch_keyset;
case DispatchKey::CompositeImplicitAutogradNestedTensor:
return nested_dispatch_keyset;
case DispatchKey::CompositeExplicitAutograd:
return backend_dispatch_keyset;
case DispatchKey::CompositeExplicitAutogradNonFunctional:
@ -84,6 +91,9 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
case DispatchKey::CompositeImplicitAutograd:
// See Note [NestedTensor Not Included in Backend Keys]
return math_dispatch_keyset.has(k);
case DispatchKey::CompositeImplicitAutogradNestedTensor:
// See Note [NestedTensor Not Included in Backend Keys]
return nested_dispatch_keyset.has(k);
case DispatchKey::CompositeExplicitAutograd:
// See Note [NestedTensor Not Included in Backend Keys]
return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);

View File

@ -128,6 +128,9 @@ ALLOW_LIST = [
("aten::nanmean.out", datetime.date(2022, 8, 30)),
("aten::nansum", datetime.date(2022, 8, 30)),
("aten::nansum.out", datetime.date(2022, 8, 30)),
# nested tensor temporary auxiliary ops
("aten::_reshape_nested", datetime.date(9999, 1, 1)),
("aten::_reshape_nested_backward", datetime.date(9999, 1, 1)),
("aten::sum.SymInt", datetime.date(2022, 11, 30)),
("aten::mps_linear", datetime.date(9999, 1, 1)),
("aten::_mps_linear", datetime.date(9999, 1, 1)),

View File

@ -230,11 +230,7 @@ class TestNestedTensor(TestCase):
# Test non_contiguous case
assert not nt_noncontiguous.is_contiguous()
self.assertRaisesRegex(
RuntimeError,
r"clone_nested only supports memory format Preserve, but got Contiguous instead.",
lambda: nt_noncontiguous.contiguous()
)
self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())
@torch.inference_mode()
def test_repr_string(self):
@ -679,7 +675,6 @@ class TestNestedTensorDeviceType(TestCase):
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_clone(self, device, dtype):
nt1 = self.random_nt(device, dtype, 4, (4, 4), (1, 1))
nt2 = nt1.clone()
@ -693,7 +688,7 @@ class TestNestedTensorDeviceType(TestCase):
self.assertNotEqual(ub1[i], ub2[i])
nt1.clone(memory_format=torch.preserve_format)
msg = "clone_nested only supports memory format Preserve, but got ChannelsLast instead."
msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
with self.assertRaisesRegex(RuntimeError, msg):
nt1.clone(memory_format=torch.channels_last)
@ -1105,7 +1100,6 @@ class TestNestedTensorDeviceType(TestCase):
)
@dtypes(torch.float, torch.float16, torch.double)
@torch.inference_mode()
def test_transpose(self, device, dtype):
nt = self.random_nt(device, dtype, 4, (4, 4))
# error case: transpose nested dimension
@ -1150,7 +1144,74 @@ class TestNestedTensorDeviceType(TestCase):
self.assertEqual(ptT, ptT_from_ntT)
@dtypes(torch.float, torch.float16, torch.double)
@torch.inference_mode()
def test_view(self, device, dtype):
nt = self.random_nt(device, dtype, 4, (4, 4))
# error case: empty shape
self.assertRaisesRegex(
RuntimeError,
r"shape '\[\]' is invalid for a nested tensor",
lambda: nt.view(())
)
# error case: empty nested tensor
nt_empty = torch.nested_tensor([])
self.assertRaisesRegex(
RuntimeError,
"empty nested tensor cannot be reshaped",
lambda: nt_empty.view(-1)
)
# error case: invalid proposed shape for underlying tensors
self.assertRaisesRegex(
RuntimeError,
r"invalid shape dimension -2",
lambda: nt.view(-2, 2, 3)
)
self.assertRaisesRegex(
RuntimeError,
r"shape '\[.*\]' is invalid for input of size [0-9]+",
lambda: nt.view(4, 2, 3)
)
# normal case
x0 = torch.randn((2, 20), device=device, dtype=dtype)
x1 = torch.randn((3, 20), device=device, dtype=dtype)
nt = torch.nested_tensor([x0, x1])
pt = nt.to_padded_tensor(0.0)
self.assertRaisesRegex(
RuntimeError,
r"for now view cannot change the implicit batch dimension",
lambda: nt.transpose(-1, -2).view(40, -1)
)
# inherit only the ragged dimension
# (2, 20) -> (2, 5, 4)
# (3, 20) -> (3, 5, 4)
nt1 = nt.view(2, -1, 5, 4)
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
pt1 = pt.view(2, -1, 5, 4)
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
# also inherit regular dimension
nt2 = nt1.view(2, -1, -1, 2, 2)
pt2 = pt1.view(2, -1, 5, 2, 2)
self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)
@dtypes(torch.float, torch.float16, torch.double)
def test_view_inference_mode_interaction(self, device, dtype):
# Construct in default mode and view while in inference mode
nt = torch.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype)
with torch.inference_mode():
ntT = nt.view(2, -1, 4, 5)
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
pt = nt.to_padded_tensor(0.0)
ptT = pt.view(2, -1, 4, 5)
self.assertEqual(ptT, ptT_from_ntT)
# Construct and view while in inference mode
with torch.inference_mode():
nt = torch.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype)
ntT = nt.view(2, -1, 4, 5)
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
pt = nt.to_padded_tensor(0.0)
ptT = pt.view(2, -1, 4, 5)
self.assertEqual(ptT, ptT_from_ntT)
@dtypes(torch.float, torch.float16, torch.double)
def test_reshape(self, device, dtype):
nt = self.random_nt(device, dtype, 4, (4, 4))
# error case: empty shape

View File

@ -1356,10 +1356,6 @@
# making it impossible (hard) to detect when it is actually a view.
# - name: reshape(Tensor self, IntArrayRef shape)
- name: _reshape_nested(Tensor self, int[] shape) -> Tensor
self: _reshape_nested_backward(self, grad)
result: auto_linear
- name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
self: grad.reshape(self.sizes())
result: auto_linear
@ -1732,10 +1728,14 @@
# linear
result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)
# TODO: this derivative is not SymInt safe, need reshape_symint
- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
dispatch:
Default:
self: grad.reshape(self.sizes())
result: auto_linear
AutogradNestedTensor:
self: grad.reshape_as(self)
result: auto_linear
- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
output_differentiability: [False]

View File

@ -310,6 +310,7 @@ def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
else:
return [backend.dispatch_key for backend in backends] + [
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeImplicitAutogradNestedTensor,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
]
@ -330,6 +331,8 @@ def get_static_dispatch_backend(
return DispatchKey.CompositeExplicitAutogradNonFunctional
elif f.has_composite_implicit_autograd_kernel:
return DispatchKey.CompositeImplicitAutograd
elif f.has_composite_implicit_autograd_nested_tensor_kernel:
return DispatchKey.CompositeImplicitAutogradNestedTensor
return None
@ -426,6 +429,8 @@ def generate_static_dispatch_fallback_call(
return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
elif f.has_composite_implicit_autograd_kernel:
return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
elif f.has_composite_implicit_autograd_nested_tensor_kernel:
return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
else:
return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
{', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""
@ -1241,6 +1246,11 @@ def compute_registration_declarations(
"dispatch": str(
{k for k, v in backend_indices.items() if v.has_kernel(f)}
!= {DispatchKey.CompositeImplicitAutograd}
and {k for k, v in backend_indices.items() if v.has_kernel(f)}
!= {
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeImplicitAutogradNestedTensor,
}
),
"default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
}
@ -2145,6 +2155,13 @@ def gen_source_files(
ns_grouped_native_functions[namespace].append(grouped_native_function)
dispatch_namespace = str(dispatch_key).lower()
# CompositeImplicitAutogradNestedTensor does not currently use the generated helpers;
# compilation will fail when the `-Werror=unused-function` flag is set
gen_dispatch_helpers: bool = (
dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
)
dispatch_definitions = get_native_function_definitions(
fm=fm,
grouped_native_functions=grouped_native_functions,
@ -2153,7 +2170,7 @@ def gen_source_files(
selector=selector,
rocm=rocm,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
gen_dispatch_helpers=True,
gen_dispatch_helpers=gen_dispatch_helpers,
)
fm.write_with_template(
f"Register{dispatch_key}.cpp",
@ -2636,6 +2653,7 @@ def main() -> None:
DispatchKey.CPU,
DispatchKey.CUDA,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeImplicitAutogradNestedTensor,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
DispatchKey.Meta,

View File

@ -93,6 +93,7 @@ class DispatchKey(Enum):
Autograd = auto()
CompositeImplicitAutograd = auto()
CompositeImplicitAutogradNestedTensor = auto()
CompositeExplicitAutograd = auto()
CompositeExplicitAutogradNonFunctional = auto()
@ -217,6 +218,7 @@ dispatch_keys = [
DispatchKey.QuantizedCPU,
DispatchKey.QuantizedCUDA,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeImplicitAutogradNestedTensor,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
DispatchKey.NestedTensorCPU,
@ -237,6 +239,7 @@ def is_generic_dispatch_key(dk: DispatchKey) -> bool:
DispatchKey.CompositeExplicitAutograd,
DispatchKey.CompositeExplicitAutogradNonFunctional,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeImplicitAutogradNestedTensor,
}
@ -485,6 +488,7 @@ class NativeFunction:
# Whether or not the NativeFunction contains a backend-agnostic kernel
has_composite_implicit_autograd_kernel: bool
has_composite_implicit_autograd_nested_tensor_kernel: bool
has_composite_explicit_autograd_kernel: bool
has_composite_explicit_autograd_non_functional_kernel: bool
@ -699,9 +703,15 @@ class NativeFunction:
if d == DispatchKey.CompositeExplicitAutograd
or d == DispatchKey.CompositeExplicitAutogradNonFunctional
or d == DispatchKey.CompositeImplicitAutograd
or d == DispatchKey.CompositeImplicitAutogradNestedTensor
]
assert len(composites_in_dispatch) <= 1, (
assert len(composites_in_dispatch) <= 1 or (
len(composites_in_dispatch) == 2
and DispatchKey.CompositeImplicitAutograd in composites_in_dispatch
and DispatchKey.CompositeImplicitAutogradNestedTensor
in composites_in_dispatch
), (
"cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
"or CompositeImplicitAutograd on a single kernel; each "
"strictly subsumes the other. If you wanted to provide an explicit autograd "
@ -756,11 +766,23 @@ class NativeFunction:
# Structured functions MUST have a dispatch table
is_abstract = True
else:
is_abstract = dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
is_abstract = (
dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
and dispatch.keys()
!= {DispatchKey.CompositeImplicitAutogradNestedTensor}
and dispatch.keys()
!= {
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeImplicitAutogradNestedTensor,
}
)
has_composite_implicit_autograd_kernel = (
DispatchKey.CompositeImplicitAutograd in dispatch.keys()
)
has_composite_implicit_autograd_nested_tensor_kernel = (
DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch.keys()
)
has_composite_explicit_autograd_kernel = (
DispatchKey.CompositeExplicitAutograd in dispatch.keys()
)
@ -808,6 +830,7 @@ class NativeFunction:
cpp_no_default_args=cpp_no_default_args,
is_abstract=is_abstract,
has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
tags=tags,
@ -899,6 +922,9 @@ class NativeFunction:
self.has_composite_implicit_autograd_kernel
or self.has_composite_explicit_autograd_kernel
or self.has_composite_explicit_autograd_non_functional_kernel
) or (
self.has_composite_implicit_autograd_kernel
and self.has_composite_implicit_autograd_nested_tensor_kernel
)
@property
@ -976,7 +1002,10 @@ class NativeFunctionsGroup:
if self.structured:
# For now, structured composite kernels are not supported (need some
# design work to figure out how to make the composite case work)
assert not self.out.has_composite_implicit_autograd_kernel
assert (
not self.out.has_composite_implicit_autograd_kernel
and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
)
assert self.functional.structured_delegate == self.out.func.name, (
f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
@ -2510,6 +2539,14 @@ class NativeFunctionsViewGroup:
f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
" both have CompositeImplicitAutograd kernels, or both not have composite kernels."
)
if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
if self.view_inplace is not None:
assert (
self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
), (
f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
" both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
)
def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
yield self.view

View File

@ -336,6 +336,7 @@ def generate_function(
cpp_no_default_args=set(),
is_abstract=f.is_abstract,
has_composite_implicit_autograd_kernel=False,
has_composite_implicit_autograd_nested_tensor_kernel=False,
has_composite_explicit_autograd_kernel=True,
has_composite_explicit_autograd_non_functional_kernel=False,
# Every generated NativeFunction gets a "generated" tag, so it's easy to tell