The original author is @YifanShenSZ and the original PR is: #82754

# Summary:
Previous reshape https://github.com/pytorch/pytorch/pull/80981 is ok for forward, but needs improvement for backward: we need to handle the "sometimes view sometimes copy" behavior. This pull request fixes it by:
1. adding a new alias dispatch key `CompositeImplicitAutogradNestedTensor`, which ideally would work as the nested-tensor version of `CompositeImplicitAutograd`
2. registering `reshape_nested` to `reshape` via `CompositeImplicitAutogradNestedTensor`

Side changes:
* add contiguous memory format support to `clone_nested`
* add `view_nested`
* add `reshape_as_nested`

Fixes issue https://github.com/pytorch/pytorch/issues/83041

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82754

Test Plan: Imported from GitHub, without a `Test Plan:` line.

**Static Docs Preview: executorch**
|[Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D39023822/V13/executorch/)|
|**Modified Pages**|

Reviewed By: albanD

Differential Revision: D39023822

Pulled By: drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84154
Approved by: https://github.com/bdhirsh, https://github.com/albanD
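To make the "sometimes view sometimes copy" behavior concrete, here is a minimal sketch (assuming the `torch.nested_tensor` factory used by the tests in this PR; newer builds expose it as `torch.nested.nested_tensor`). When the proposed shape is stride-compatible with the input, `reshape` on a nested tensor returns a view over the same buffer; otherwise it materializes a contiguous copy first, so autograd sees either a plain view or a clone-plus-view and no dedicated `_reshape_nested_backward` is needed anymore.

```python
import torch

# Two ragged "rows" packed into one nested tensor (the batch dim is implicit).
nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])

# Contiguous input and a stride-compatible target shape: reshape can
# return a view over the same buffer (the "sometimes view" case).
nt_as_view = nt.reshape(2, -1, 5, 4)

# -1 means "inherit the old size" for the ragged dimension (see
# Note [Special size rule for nested tensor] added below), so each
# component goes (2, 20) -> (2, 5, 4) and (3, 20) -> (3, 5, 4).
```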
This commit is contained in:
parent
9bcad063d8
commit
673b35c847
@@ -51,6 +51,7 @@ generated_cpu_cpp = [
    "aten/src/ATen/RegisterSparseCsrCPU.cpp",
    "aten/src/ATen/RegisterZeroTensor.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
    "aten/src/ATen/RegisterMeta.cpp",
@@ -66,6 +67,8 @@ generated_cpu_cpp = [
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",
    "aten/src/ATen/CompositeViewCopyKernels.cpp",
    "aten/src/ATen/FunctionalInverses.h",
    "aten/src/ATen/Functions.h",
@@ -307,6 +307,21 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
  // For AutogradOther, we return ambiguousAutogradOtherKernel() if there's registration
  // to any of its backends.
  // See Note [Undefined in dispatchTable_] for the special handling for Undefined.

  // If the dispatch key is included in CompositeImplicitAutogradNestedTensor,
  // then we register it to nested-tensor kernel rather than
  // regular-tensor CompositeImplicitAutograd kernel.
  // We have no intention to change the behavior of Undefined,
  // so this nested-tensor branch requires `dispatch_key != DispatchKey::Undefined`
  // to let the original CompositeImplicitAutograd handle Undefined
  if (dispatch_key != DispatchKey::Undefined && isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutogradNestedTensor)) {
    if (auto nested_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutogradNestedTensor)) {
      if (!has_backend_kernel) {
        return {*nested_registration, "nested kernel"};
      }
    }
  }

  if (dispatch_key == DispatchKey::Undefined || isIncludedInAlias(dispatch_key, DispatchKey::CompositeImplicitAutograd)) {
    if (auto math_registration = getKernelForDispatchKey(DispatchKey::CompositeImplicitAutograd)) {
      if (dispatch_key == DispatchKey::AutogradOther
@@ -1256,17 +1256,6 @@ Tensor alias_with_sizes_and_strides(
}

Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
  // reshape has special autograd logic since it sometimes returns a view but sometimes does not
  // we have to intercept here instead of using dispatcher
  // otherwise we will see "autograd still running" kind of error in inference mode:
  // * if we create a tensor in inference mode scope,
  //   then pass it to a inference mode decorated function,
  //   everything is fine
  // * but if we create the input tensor not with inference mode,
  //   then errors like "Cannot set version_counter for inference tensor" arise
  if (self.is_nested()) {
    return at::_reshape_nested(self, proposed_shape);
  }
  if (self.is_sparse()) {
    AT_ERROR("reshape is not implemented for sparse tensors");
  }
@@ -4200,16 +4200,9 @@
  variants: function, method
  device_check: NoCheck
  device_guard: False

- func: _reshape_nested(Tensor self, int[] shape) -> Tensor
  dispatch:
    NestedTensorCPU, NestedTensorCUDA: _reshape_nested
  autogen: _reshape_nested.out

- func: _reshape_nested_backward(Tensor self, Tensor grad) -> Tensor
  dispatch:
    NestedTensorCPU, NestedTensorCUDA: _reshape_nested_backward
  autogen: _reshape_nested_backward.out
    CompositeImplicitAutograd: reshape
    CompositeImplicitAutogradNestedTensor: reshape_nested

# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
# They are not user-facing, hence the leading underscore. Please don't use it
@@ -4233,6 +4226,9 @@
  variants: method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeImplicitAutograd: reshape_as
    CompositeImplicitAutogradNestedTensor: reshape_as_nested

- func: round(Tensor self) -> Tensor
  device_check: NoCheck   # TensorIterator
@@ -6889,6 +6885,7 @@
    Meta: view_meta
    ZeroTensor, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view
    MkldnnCPU: mkldnn_view
    NestedTensorCPU, NestedTensorCUDA: view_nested

# Warning: If you want to change the name or overload name of this
# operator, you might also want to change the `isBlockListedSchema`
@@ -66,23 +66,6 @@ std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

Tensor _reshape_nested_backward(const Tensor& self, const Tensor& grad) {
  auto self_ptr = get_nested_tensor_impl(self);
  // TODO: this is to reproduce self_ptr->opt_sizes_
  //       if an accessor is provided in the future, can replace this
  std::vector<int64_t> sizes;
  for (int64_t i = 0; i < self_ptr->dim(); i++) {
    c10::optional<int64_t> opt_size = self_ptr->opt_size(i);
    if (opt_size.has_value()) {
      sizes.push_back(*opt_size);
    }
    else {
      sizes.push_back(-1);
    }
  }
  return grad.reshape(sizes);
}

Tensor nested_softmax_backward(
    const Tensor& grad,
    const Tensor& output,
@@ -688,17 +688,38 @@ Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) {
Tensor clone_nested(
    const Tensor& self,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
  auto memory_format = optional_memory_format.value_or(c10::MemoryFormat::Preserve);
  auto self_ptr = get_nested_tensor_impl(self);
  if (memory_format == c10::MemoryFormat::Preserve ||
      (memory_format == c10::MemoryFormat::Contiguous && self.is_contiguous())) {
    const Tensor& buffer = self_ptr->get_buffer(),
        sizemat = self_ptr->get_nested_size_tensor(),
        stridemat = self_ptr->get_nested_stride_tensor();
    const std::vector<int64_t>& offsets = self_ptr->get_offsets();
    // TODO: The size and the stride do not necessarily need to be cloned,
    //       but it is more conservative.
    //       This is something we could revisit once we land a more
    //       efficient implementation of nested_size_tensor_ and nested_stride_tensor.
    return wrap_buffer(buffer.clone(), sizemat.clone(), stridemat.clone(), std::vector<int64_t>(offsets));
  }
  // actually, memory format is contiguous and self is noncontiguous
  else if (memory_format == c10::MemoryFormat::Contiguous) {
    const Tensor& self_buffer = self_ptr->get_buffer(),
        sizemat = self_ptr->get_nested_size_tensor();
    Tensor output_buffer = at::empty_like(self_buffer);
    Tensor output = wrap_buffer(output_buffer, sizemat);
    std::vector<Tensor> self_unbind = self.unbind(),
        output_unbind = output.unbind();
    for (int64_t i = 0; i < self_ptr->size(0); i++) {
      output_unbind[i].copy_(self_unbind[i]);
    }
    return output;
  } else {
    TORCH_CHECK(
        memory_format == MemoryFormat::Preserve,
        "clone_nested only supports memory format Preserve, but got ",
        memory_format,
        " instead.");
    // TODO: The size doesn't necessarily need to be cloned, but it is more
    //       conservative. This is something we could revisit once we land a more
    //       efficient implementation of nested_size_tensor_.
    return wrap_buffer(
        get_buffer(self).clone(), get_nested_size_tensor(self).clone());
        false,
        "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ",
        memory_format);
  }
}

at::Tensor NestedTensor_get_nested_size_tensor(const at::Tensor& self){
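A hedged usage sketch of the contiguous-memory-format support added to `clone_nested` above. It assumes `transpose` produces a non-contiguous nested view and that `.contiguous()` routes through this clone path, as exercised by the test changes later in this diff.

```python
import torch

nt = torch.nested_tensor([torch.randn(2, 6), torch.randn(3, 6)])
ntT = nt.transpose(-1, -2)  # non-contiguous nested view

# Preserve keeps the existing (possibly non-contiguous) layout.
same_layout = ntT.clone(memory_format=torch.preserve_format)

# Contiguous is now supported: the buffer is re-packed component by component.
packed = ntT.clone(memory_format=torch.contiguous_format)
packed2 = ntT.contiguous()  # goes through the same path

# Anything else (e.g. torch.channels_last) still raises a RuntimeError.
```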
@@ -1008,7 +1029,7 @@ Tensor transpose_nested(const Tensor& self, int64_t dim0, int64_t dim1) {
      self, sizemat_transposed, stridemat_transposed, std::vector<int64_t>(self_ptr->get_offsets()));
}

// utilities supporting `_reshape_nested`
// utilities supporting `view_nested` and `reshape_nested`
namespace {
// Args:
//     sizes: the sizes of original nested tensor
@@ -1016,10 +1037,10 @@ namespace {
//     proposed_shape: user proposed new shape
//     op: the options for new size and stride matrices
// Returns:
//     whether reshape as view is possible (i.e. old buffer can be reused)
//     whether viewable
//     size matrix after reshape
//     stride matrix after reshape (not fully populated if reshape as view is impossible)
inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
//     stride matrix after reshape (not fully populated if not viewable)
inline std::tuple<bool, Tensor, Tensor> NestedTensor_compute_size_stride(
    const std::vector<IntArrayRef>& sizes,
    const std::vector<IntArrayRef>& strides,
    const IntArrayRef& proposed_shape,
@@ -1027,7 +1048,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
  int64_t ntensors = sizes.size(),
      ndims_underlying = sizes[0].size(),
      ndims_underlying_reshaped = proposed_shape.size() - 1;
  bool reshape_as_view = true;
  bool viewable = true;
  Tensor sizemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op),
      stridemat_reshaped = at::empty({ntensors, ndims_underlying_reshaped}, op);
  int64_t* sizemat_reshaped_ptr = sizemat_reshaped.data_ptr<int64_t>(),
@@ -1039,6 +1060,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
    std::vector<int64_t> size_reshaped_vector(proposed_shape.begin() + 1, proposed_shape.end());
    // some negative sizes remain to be infered
    if (ndims_underlying < ndims_underlying_reshaped) {
      int64_t numel = 1, numel_reshaped = 1;
      // replace negative sizes for old dimensions with old sizes
      for (int64_t idim = 0; idim < ndims_underlying; idim++) {
        int64_t& size_reshaped = size_reshaped_vector[idim];
@@ -1046,12 +1068,17 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
        if (size_reshaped == -1) {
          size_reshaped = size[idim];
        }
        numel *= size[idim];
        numel_reshaped *= size_reshaped;
      }
      // infer negative size for new dimension
      int64_t infer_index = -1;
      for (int64_t idim = ndims_underlying; idim < ndims_underlying_reshaped; idim++) {
        const int64_t& size_reshaped = size_reshaped_vector[idim];
        if (size_reshaped == -1) {
        if (size_reshaped >= 0) {
          numel_reshaped *= size_reshaped;
        }
        else if (size_reshaped == -1) {
          if (infer_index > -1) {
            throw std::runtime_error("only one dimension can be inferred");
          }
@@ -1059,22 +1086,36 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
          infer_index = idim;
        }
      }
      else if (size_reshaped < 0) {
        else {
          AT_ERROR("invalid shape dimension ", size_reshaped);
        }
      }
      // See Note [inference and inheritance semantics]
      // See Note [Special size rule for nested tensor]
      TORCH_CHECK(infer_index == -1, "nested tensor does not infer shape");
      TORCH_CHECK(
          numel == numel_reshaped,
          "shape '", proposed_shape, "' ",
          "is invalid for input of size ", numel);
    }
    // all negative sizes can be replaced
    else {
      int64_t numel = 1, numel_reshaped = 1;
      for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
        int64_t& size_reshaped = size_reshaped_vector[idim];
        TORCH_CHECK(size_reshaped >= -1, "invalid shape dimension ", size_reshaped);
        if (size_reshaped == -1) {
          size_reshaped = size[idim];
        }
        numel *= size[idim];
        numel_reshaped *= size_reshaped;
      }
      for (int64_t idim = ndims_underlying_reshaped; idim < ndims_underlying; idim++) {
        numel *= size[idim];
      }
      TORCH_CHECK(
          numel == numel_reshaped,
          "shape '", proposed_shape, "' ",
          "is invalid for input of size ", numel);
    }
    IntArrayRef size_reshaped(size_reshaped_vector);
    // compute reshaped stride
@@ -1092,7 +1133,7 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
    }
    // reshape as view is impossible
    else {
      reshape_as_view = false;
      viewable = false;
      // fill reshaped size into sizemat
      for (int64_t idim = 0; idim < ndims_underlying_reshaped; idim++) {
        sizemat_reshaped_ptr[idim] = size_reshaped[idim];
@@ -1100,42 +1141,59 @@ inline std::tuple<bool, Tensor, Tensor> NestedTensor_reshape_size_stride(
      sizemat_reshaped_ptr += ndims_underlying_reshaped;
    }
  }
  return std::make_tuple(reshape_as_view, sizemat_reshaped, stridemat_reshaped);
}

// Args:
//     nt_reshaped: the reshaped nested tensor to receive copies
//     buffer: the original nested tensor buffer
//     sizes: the original nested tensor sizes (may have gone through collapsing or splitting)
//     strides: the original nested tensor strides (may have gone through collapsing or splitting)
//     offsets: the original nested tensor offsets (may have gone through collapsing or splitting)
inline void NestedTensor_reshape_copy(
    Tensor& nt_reshaped,
    const Tensor& buffer,
    const std::vector<IntArrayRef>& sizes,
    const std::vector<IntArrayRef>& strides,
    const std::vector<int64_t>& offsets) {
  auto nt_reshaped_ptr = get_nested_tensor_impl(nt_reshaped);
  const Tensor& buffer_reshaped = nt_reshaped_ptr->get_buffer();
  std::vector<IntArrayRef> sizes_reshaped = NestedTensor_get_sizes(nt_reshaped_ptr),
      strides_reshaped = NestedTensor_get_strides(nt_reshaped_ptr);
  const std::vector<int64_t>& offsets_reshaped = nt_reshaped_ptr->get_offsets();
  for (int64_t i = 0; i < nt_reshaped_ptr->size(0); i++) {
    buffer_reshaped.as_strided(sizes_reshaped[i], strides_reshaped[i], offsets_reshaped[i]).copy_(
        // TODO: can we avoid allocating new memory for `buffer...reshape`
        //       I did not find anything like reshape_out
        buffer.as_strided(sizes[i], strides[i], offsets[i]).reshape(sizes_reshaped[i]));
  }
  return std::make_tuple(viewable, sizemat_reshaped, stridemat_reshaped);
}
} // namespace

// Special rules for reshape(nested tensor):
// 1. Only 1 regular dimension can be collapsed with
//    or splitted from the implicit batch dimension
// 2. Instead of infering size, -1 means "inherit the old size", so:
// Note [Special size rule for nested tensor]
// Instead of infering size, -1 means "inherit the old size", so:
// * negative size is legal for a ragged dimension
// * multiple sizes can be -1
Tensor _reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
// In principle we could still infer a dimension,
// we are designing a better semantics to include both inheritance and inference
Tensor view_nested(const Tensor& self, IntArrayRef proposed_shape) {
  TORCH_CHECK(
      proposed_shape.size() > 0,
      "shape '[]' is invalid for a nested tensor");
  auto self_ptr = get_nested_tensor_impl(self);
  // basic information before reshaping
  int64_t ntensors = self_ptr->size(0);
  TORCH_CHECK(
      ntensors > 0,
      "empty nested tensor cannot be reshaped");
  // basic information after reshaping
  int64_t ntensors_reshaped;
  if (proposed_shape[0] >= 0) {
    ntensors_reshaped = proposed_shape[0];
  }
  else if (proposed_shape[0] == -1) {
    ntensors_reshaped = ntensors;
  }
  else {
    AT_ERROR("invalid shape dimension ", proposed_shape[0]);
  }
  TORCH_CHECK(
      ntensors == ntensors_reshaped,
      "for now view cannot change the implicit batch dimension");
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
      strides = NestedTensor_get_strides(self_ptr);
  // reshaping underlying tensor dimensions does not change offset
  // determine reshaped size and stride
  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
  bool viewable;
  Tensor sizemat_reshaped, stridemat_reshaped;
  std::tie(viewable, sizemat_reshaped, stridemat_reshaped) = NestedTensor_compute_size_stride(
      sizes, strides, proposed_shape, sizemat.options());
  TORCH_CHECK(
      viewable,
      "view size is not compatible with input tensor's size and stride "
      "(at least one dimension spans across two contiguous subspaces). "
      "Use .reshape(...) instead.");
  return create_nested_view_tensor(self, sizemat_reshaped, stridemat_reshaped, std::vector<int64_t>(self_ptr->get_offsets()));
}

// See Note [Special size rule for nested tensor]
Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
  TORCH_CHECK(
      proposed_shape.size() > 0,
      "shape '[]' is invalid for a nested tensor");
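A small illustration of Note [Special size rule for nested tensor], drawn from the `test_view` case added later in this diff (the `torch.nested_tensor` factory used by those tests is assumed):

```python
import torch

nt = torch.nested_tensor([torch.randn(2, 20), torch.randn(3, 20)])

# -1 means "inherit the old size", so it is legal on the ragged dimension
# and several -1s are allowed, unlike regular view/reshape size inference.
nt1 = nt.view(2, -1, 5, 4)       # (2, 20) -> (2, 5, 4), (3, 20) -> (3, 5, 4)
nt2 = nt1.view(2, -1, -1, 2, 2)  # inherit the regular 5 with a second -1

# Changing the implicit batch dimension is rejected, e.g.
# nt.transpose(-1, -2).view(40, -1) raises
# "for now view cannot change the implicit batch dimension".
```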
@@ -1161,23 +1219,37 @@ Tensor _reshape_nested(const Tensor& self, IntArrayRef proposed_shape) {
      "for now reshape cannot change the implicit batch dimension");
  std::vector<IntArrayRef> sizes = NestedTensor_get_sizes(self_ptr),
      strides = NestedTensor_get_strides(self_ptr);
  const std::vector<int64_t>& offsets = self_ptr->get_offsets();
  // reshaping underlying tensor dimensions does not change offset
  // determine reshaped size and stride
  const Tensor& buffer = self_ptr->get_buffer(),
      & sizemat = self_ptr->get_nested_size_tensor();
  bool reshape_as_view;
  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
  bool viewable;
  Tensor sizemat_reshaped, stridemat_reshaped;
  std::tie(reshape_as_view, sizemat_reshaped, stridemat_reshaped) = NestedTensor_reshape_size_stride(
  std::tie(viewable, sizemat_reshaped, stridemat_reshaped) = NestedTensor_compute_size_stride(
      sizes, strides, proposed_shape, sizemat.options());
  if (reshape_as_view) {
    return wrap_buffer(buffer, sizemat_reshaped, stridemat_reshaped, std::vector<int64_t>(offsets));
  if (viewable) {
    return self.view(proposed_shape);
  }
  Tensor buffer_reshaped = buffer.new_empty(buffer.sizes());
  Tensor output = wrap_buffer(buffer_reshaped, sizemat_reshaped);
  NestedTensor_reshape_copy(output,
      buffer, sizes, strides, offsets);
  return output;
  else {
    return self.clone(at::MemoryFormat::Contiguous).view(proposed_shape);
  }
}

Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
  auto other_ptr = get_nested_tensor_impl(other);
  // TODO: this is to reproduce other_ptr->opt_sizes_
  //       if an accessor is provided in the future, can replace this
  std::vector<int64_t> sizes;
  for (int64_t i = 0; i < other_ptr->dim(); i++) {
    c10::optional<int64_t> opt_size = other_ptr->opt_size(i);
    if (opt_size.has_value()) {
      sizes.push_back(*opt_size);
    }
    else {
      sizes.push_back(-1);
    }
  }
  // reshape with other.opt_sizes_
  return self.reshape(sizes);
}

} // namespace native
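For orientation, a rough Python rendering of the `reshape_nested` control flow above. This is only a sketch of the "viewable or copy" decision, not the actual implementation: it approximates `NestedTensor_compute_size_stride` by simply attempting the view and falling back to a contiguous clone.

```python
import torch

def reshape_nested_reference(nt: torch.Tensor, *shape):
    # Sometimes a view: if the proposed shape is compatible with the
    # existing buffer strides, view_nested succeeds and aliases the buffer.
    try:
        return nt.view(*shape)
    # Sometimes a copy: otherwise re-pack into a contiguous buffer first,
    # mirroring `self.clone(at::MemoryFormat::Contiguous).view(proposed_shape)`.
    except RuntimeError:
        return nt.clone(memory_format=torch.contiguous_format).view(*shape)
```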
@@ -249,6 +249,7 @@ PT_BACKEND_HEADERS = [
    "CompositeExplicitAutograd",
    "CompositeExplicitAutogradNonFunctional",
    "CompositeImplicitAutograd",
    "CompositeImplicitAutogradNestedTensor",
    "Meta",
]

@@ -307,6 +308,7 @@ def get_aten_generated_files(enabled_backends):
    src_files = [
        "RegisterBackendSelect.cpp",
        "RegisterCompositeImplicitAutograd.cpp",
        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
        "RegisterCompositeExplicitAutograd.cpp",
        "RegisterCompositeExplicitAutogradNonFunctional.cpp",
        "CompositeViewCopyKernels.cpp",
@@ -327,6 +329,8 @@ def get_aten_generated_files(enabled_backends):
        "Operators_4.cpp",
        "CompositeImplicitAutogradFunctions.h",
        "CompositeImplicitAutogradFunctions_inl.h",
        "CompositeImplicitAutogradNestedTensorFunctions.h",
        "CompositeImplicitAutogradNestedTensorFunctions_inl.h",
        "CompositeExplicitAutogradFunctions.h",
        "CompositeExplicitAutogradFunctions_inl.h",
        "CompositeExplicitAutogradNonFunctionalFunctions.h",
@@ -364,7 +368,7 @@ def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends):
def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends):
    return [
        ":{}[{}]".format(aten_rule_name, f)
        for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
        for f in ["RegisterCompositeImplicitAutograd.cpp", "RegisterCompositeImplicitAutogradNestedTensor.cpp", "RegisterCompositeExplicitAutograd.cpp", "RegisterCompositeExplicitAutogradNonFunctional.cpp", "RegisterSchema.cpp", "RegisterBackendSelect.cpp", "CompositeViewCopyKernels.cpp"]
    ] + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends)

def get_aten_derived_type_srcs(enabled_backends):
@@ -1083,6 +1087,8 @@ def define_buck_targets(
        "CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]",
        "CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]",
        "CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]",
        "CompositeImplicitAutogradNestedTensorFunctions.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions.h]",
        "CompositeImplicitAutogradNestedTensorFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions_inl.h]",
        "FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]",
        "Functions.h": ":gen_aten[Functions.h]",
        "MethodOperators.h": ":gen_aten[MethodOperators.h]",
@@ -162,6 +162,8 @@ GENERATED_H_CORE = [
    "CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
    "CompositeImplicitAutogradFunctions.h",
    "CompositeImplicitAutogradFunctions_inl.h",
    "CompositeImplicitAutogradNestedTensorFunctions.h",
    "CompositeImplicitAutogradNestedTensorFunctions_inl.h",
    "MetaFunctions.h",
    "MetaFunctions_inl.h",
    "core/TensorBody.h",
@@ -193,6 +195,7 @@ GENERATED_CPP = [
    "RegisterSparseCsrCPU.cpp",
    "RegisterMkldnnCPU.cpp",
    "RegisterCompositeImplicitAutograd.cpp",
    "RegisterCompositeImplicitAutogradNestedTensor.cpp",
    "RegisterZeroTensor.cpp",
    "RegisterMeta.cpp",
    "RegisterQuantizedMeta.cpp",
@@ -178,6 +178,8 @@ const char* toString(DispatchKey t) {
      return "Autograd";
    case DispatchKey::CompositeImplicitAutograd:
      return "CompositeImplicitAutograd";
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      return "CompositeImplicitAutogradNestedTensor";
    case DispatchKey::CompositeExplicitAutograd:
      return "CompositeExplicitAutograd";
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
@@ -324,6 +326,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
      {"Autograd", c10::DispatchKey::Autograd},
      {"CompositeImplicitAutograd",
       c10::DispatchKey::CompositeImplicitAutograd},
      {"CompositeImplicitAutogradNestedTensor",
       c10::DispatchKey::CompositeImplicitAutogradNestedTensor},
      {"CompositeExplicitAutograd",
       c10::DispatchKey::CompositeExplicitAutograd},
      {"CompositeExplicitAutogradNonFunctional",
@@ -439,6 +439,8 @@ enum class DispatchKey : uint16_t {
  Autograd,
  CompositeImplicitAutograd, // registered at
  // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp
  CompositeImplicitAutogradNestedTensor, // registered at
  // build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp
  CompositeExplicitAutograd, // registered at
  // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
  // See Note [CompositeExplicitAutogradNonFunctional Key]
@@ -55,6 +55,11 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
    // math_dispatch_keyset
    DispatchKeySet{DispatchKey::NestedTensor};

constexpr DispatchKeySet nested_dispatch_keyset =
    DispatchKeySet(
        {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
@@ -67,6 +72,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::CompositeImplicitAutograd:
      return math_dispatch_keyset;
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      return nested_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutograd:
      return backend_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
@@ -84,6 +91,9 @@ bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
    case DispatchKey::CompositeImplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return math_dispatch_keyset.has(k);
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      // See Note [NestedTensor Not Included in Backend Keys]
      return nested_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
@@ -128,6 +128,9 @@ ALLOW_LIST = [
    ("aten::nanmean.out", datetime.date(2022, 8, 30)),
    ("aten::nansum", datetime.date(2022, 8, 30)),
    ("aten::nansum.out", datetime.date(2022, 8, 30)),
    # nested tensor temporary auxiliary ops
    ("aten::_reshape_nested", datetime.date(9999, 1, 1)),
    ("aten::_reshape_nested_backward", datetime.date(9999, 1, 1)),
    ("aten::sum.SymInt", datetime.date(2022, 11, 30)),
    ("aten::mps_linear", datetime.date(9999, 1, 1)),
    ("aten::_mps_linear", datetime.date(9999, 1, 1)),
@@ -230,11 +230,7 @@ class TestNestedTensor(TestCase):

        # Test non_contiguous case
        assert not nt_noncontiguous.is_contiguous()
        self.assertRaisesRegex(
            RuntimeError,
            r"clone_nested only supports memory format Preserve, but got Contiguous instead.",
            lambda: nt_noncontiguous.contiguous()
        )
        self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())

    @torch.inference_mode()
    def test_repr_string(self):
@@ -679,7 +675,6 @@ class TestNestedTensorDeviceType(TestCase):

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_clone(self, device, dtype):
        nt1 = self.random_nt(device, dtype, 4, (4, 4), (1, 1))
        nt2 = nt1.clone()
@@ -693,7 +688,7 @@ class TestNestedTensorDeviceType(TestCase):
            self.assertNotEqual(ub1[i], ub2[i])

        nt1.clone(memory_format=torch.preserve_format)
        msg = "clone_nested only supports memory format Preserve, but got ChannelsLast instead."
        msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
        with self.assertRaisesRegex(RuntimeError, msg):
            nt1.clone(memory_format=torch.channels_last)

@@ -1105,7 +1100,6 @@ class TestNestedTensorDeviceType(TestCase):
        )

    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_transpose(self, device, dtype):
        nt = self.random_nt(device, dtype, 4, (4, 4))
        # error case: transpose nested dimension
@@ -1150,7 +1144,74 @@ class TestNestedTensorDeviceType(TestCase):
        self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_view(self, device, dtype):
        nt = self.random_nt(device, dtype, 4, (4, 4))
        # error case: empty shape
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[\]' is invalid for a nested tensor",
            lambda: nt.view(())
        )
        # error case: empty nested tensor
        nt_empty = torch.nested_tensor([])
        self.assertRaisesRegex(
            RuntimeError,
            "empty nested tensor cannot be reshaped",
            lambda: nt_empty.view(-1)
        )
        # error case: invalid proposed shape for underlying tensors
        self.assertRaisesRegex(
            RuntimeError,
            r"invalid shape dimension -2",
            lambda: nt.view(-2, 2, 3)
        )
        self.assertRaisesRegex(
            RuntimeError,
            r"shape '\[.*\]' is invalid for input of size [0-9]+",
            lambda: nt.view(4, 2, 3)
        )
        # normal case
        x0 = torch.randn((2, 20), device=device, dtype=dtype)
        x1 = torch.randn((3, 20), device=device, dtype=dtype)
        nt = torch.nested_tensor([x0, x1])
        pt = nt.to_padded_tensor(0.0)
        self.assertRaisesRegex(
            RuntimeError,
            r"for now view cannot change the implicit batch dimension",
            lambda: nt.transpose(-1, -2).view(40, -1)
        )
        # inherit only the ragged dimension
        # (2, 20) -> (2, 5, 4)
        # (3, 20) -> (3, 5, 4)
        nt1 = nt.view(2, -1, 5, 4)
        # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
        pt1 = pt.view(2, -1, 5, 4)
        self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
        # also inherit regular dimension
        nt2 = nt1.view(2, -1, -1, 2, 2)
        pt2 = pt1.view(2, -1, 5, 2, 2)
        self.assertEqual(noncontiguous_to_padded_tensor(nt2), pt2)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_view_inference_mode_interaction(self, device, dtype):
        # Construct in default mode and view while in inference mode
        nt = torch.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype)
        with torch.inference_mode():
            ntT = nt.view(2, -1, 4, 5)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = nt.to_padded_tensor(0.0)
            ptT = pt.view(2, -1, 4, 5)
            self.assertEqual(ptT, ptT_from_ntT)
        # Construct and view while in inference mode
        with torch.inference_mode():
            nt = torch.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype)
            ntT = nt.view(2, -1, 4, 5)
            ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
            pt = nt.to_padded_tensor(0.0)
            ptT = pt.view(2, -1, 4, 5)
            self.assertEqual(ptT, ptT_from_ntT)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_reshape(self, device, dtype):
        nt = self.random_nt(device, dtype, 4, (4, 4))
        # error case: empty shape
@@ -1356,10 +1356,6 @@
# making it impossible (hard) to detect when it is actually a view.
# - name: reshape(Tensor self, IntArrayRef shape)

- name: _reshape_nested(Tensor self, int[] shape) -> Tensor
  self: _reshape_nested_backward(self, grad)
  result: auto_linear

- name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
  self: grad.reshape(self.sizes())
  result: auto_linear
@@ -1732,10 +1728,14 @@
  # linear
  result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)

# TODO: this derivative is not SymInt safe, need reshape_symint
- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
  dispatch:
    Default:
      self: grad.reshape(self.sizes())
      result: auto_linear
    AutogradNestedTensor:
      self: grad.reshape_as(self)
      result: auto_linear

- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
  output_differentiability: [False]
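As a sketch of what the new `AutogradNestedTensor` branch computes (this is not the generated autograd code, only a hand-written reference under the assumption that `reshape_as` on a nested tensor dispatches to the `reshape_as_nested` kernel added above):

```python
import torch

def view_backward_nested_reference(grad: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
    # Rebuild the input's per-component sizes (with -1 standing in for the
    # ragged dimension) and reshape the incoming gradient back to them,
    # which is exactly `grad.reshape_as(self)` from the formula above.
    return grad.reshape_as(inp)
```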
@@ -310,6 +310,7 @@ def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
    else:
        return [backend.dispatch_key for backend in backends] + [
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]
@@ -330,6 +331,8 @@ def get_static_dispatch_backend(
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None


@@ -426,6 +429,8 @@ def generate_static_dispatch_fallback_call(
        return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
    else:
        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
{', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""
@@ -1241,6 +1246,11 @@ def compute_registration_declarations(
        "dispatch": str(
            {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {DispatchKey.CompositeImplicitAutograd}
            and {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {
                DispatchKey.CompositeImplicitAutograd,
                DispatchKey.CompositeImplicitAutogradNestedTensor,
            }
        ),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
@@ -2145,6 +2155,13 @@ def gen_source_files(
            ns_grouped_native_functions[namespace].append(grouped_native_function)

        dispatch_namespace = str(dispatch_key).lower()

        # CompositeImplicitAutogradNestdTensor does not currently user the helpers generated
        # compilation will fail when `-Werror=unused-function` flag is set
        gen_dispatch_helpers: bool = (
            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
        )

        dispatch_definitions = get_native_function_definitions(
            fm=fm,
            grouped_native_functions=grouped_native_functions,
@@ -2153,7 +2170,7 @@ def gen_source_files(
            selector=selector,
            rocm=rocm,
            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
            gen_dispatch_helpers=True,
            gen_dispatch_helpers=gen_dispatch_helpers,
        )
        fm.write_with_template(
            f"Register{dispatch_key}.cpp",
@@ -2636,6 +2653,7 @@ def main() -> None:
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
@@ -93,6 +93,7 @@ class DispatchKey(Enum):

    Autograd = auto()
    CompositeImplicitAutograd = auto()
    CompositeImplicitAutogradNestedTensor = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()

@@ -217,6 +218,7 @@ dispatch_keys = [
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    DispatchKey.NestedTensorCPU,
@@ -237,6 +239,7 @@ def is_generic_dispatch_key(dk: DispatchKey) -> bool:
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
    }

@@ -485,6 +488,7 @@ class NativeFunction:

    # Whether or not the NativeFunction contains a backend-agnostic kernel
    has_composite_implicit_autograd_kernel: bool
    has_composite_implicit_autograd_nested_tensor_kernel: bool
    has_composite_explicit_autograd_kernel: bool
    has_composite_explicit_autograd_non_functional_kernel: bool

@@ -699,9 +703,15 @@ class NativeFunction:
            if d == DispatchKey.CompositeExplicitAutograd
            or d == DispatchKey.CompositeExplicitAutogradNonFunctional
            or d == DispatchKey.CompositeImplicitAutograd
            or d == DispatchKey.CompositeImplicitAutogradNestedTensor
        ]

        assert len(composites_in_dispatch) <= 1, (
        assert len(composites_in_dispatch) <= 1 or (
            len(composites_in_dispatch) == 2
            and DispatchKey.CompositeImplicitAutograd in composites_in_dispatch
            and DispatchKey.CompositeImplicitAutogradNestedTensor
            in composites_in_dispatch
        ), (
            "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
            "or CompositeImplicitAutograd on a single kernel; each "
            "strictly subsumes the other. If you wanted to provide an explicit autograd "
@@ -756,11 +766,23 @@ class NativeFunction:
            # Structured functions MUST have a dispatch table
            is_abstract = True
        else:
            is_abstract = dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
            is_abstract = (
                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
                and dispatch.keys()
                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
                and dispatch.keys()
                != {
                    DispatchKey.CompositeImplicitAutograd,
                    DispatchKey.CompositeImplicitAutogradNestedTensor,
                }
            )

        has_composite_implicit_autograd_kernel = (
            DispatchKey.CompositeImplicitAutograd in dispatch.keys()
        )
        has_composite_implicit_autograd_nested_tensor_kernel = (
            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch.keys()
        )
        has_composite_explicit_autograd_kernel = (
            DispatchKey.CompositeExplicitAutograd in dispatch.keys()
        )
@@ -808,6 +830,7 @@ class NativeFunction:
            cpp_no_default_args=cpp_no_default_args,
            is_abstract=is_abstract,
            has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
            has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
            has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
            has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
            tags=tags,
@@ -899,6 +922,9 @@ class NativeFunction:
            self.has_composite_implicit_autograd_kernel
            or self.has_composite_explicit_autograd_kernel
            or self.has_composite_explicit_autograd_non_functional_kernel
        ) or (
            self.has_composite_implicit_autograd_kernel
            and self.has_composite_implicit_autograd_nested_tensor_kernel
        )

    @property
@@ -976,7 +1002,10 @@ class NativeFunctionsGroup:
        if self.structured:
            # For now, structured composite kernels are not supported (need some
            # design work to figure out how to make the composite case work)
            assert not self.out.has_composite_implicit_autograd_kernel
            assert (
                not self.out.has_composite_implicit_autograd_kernel
                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
            )

        assert self.functional.structured_delegate == self.out.func.name, (
            f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
@@ -2510,6 +2539,14 @@ class NativeFunctionsViewGroup:
                f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
            )
        if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
            if self.view_inplace is not None:
                assert (
                    self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
                ), (
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                    " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
                )

    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
        yield self.view
@@ -336,6 +336,7 @@ def generate_function(
        cpp_no_default_args=set(),
        is_abstract=f.is_abstract,
        has_composite_implicit_autograd_kernel=False,
        has_composite_implicit_autograd_nested_tensor_kernel=False,
        has_composite_explicit_autograd_kernel=True,
        has_composite_explicit_autograd_non_functional_kernel=False,
        # Every generated NativeFunction gets a "generated" tag, so it's easy to tell