Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
symintify all of derivatives.yaml (#86610)
Big-bang PR to symintify **all** `.sizes()` calls in derivatives.yaml, which will be needed for symbolic tracing. The one exception is `split()`, which is tougher to land because it requires internal changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86610
Approved by: https://github.com/albanD
This commit is contained in:
parent d7bbb61f6b
commit 34c86adec4
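The whole PR follows one mechanical recipe: schema types in native_functions.yaml and derivatives.yaml go from `int`/`int[]` to `SymInt`/`SymInt[]`, backward formulas call `.sym_sizes()`/`.sym_size(n)`/`.sym_numel()` and the `*_symint` factory/view entry points instead of their concrete-int counterparts, and each touched op keeps a thin non-symbolic wrapper (declared in `NonSymbolicBC.h`) so existing C++ callers keep compiling. A minimal sketch of that wrapper pattern, distilled from the `index_select_backward` hunks below (the `my_` names are placeholders, and the real implementation also carries a tensor-subclass branch):

```cpp
#include <ATen/ATen.h>

// SymInt variant: allocates the zero gradient with symbolic sizes via
// new_zeros_symint, so the shape stays symbolic under tracing, then
// scatter-adds grad back into the selected rows.
at::Tensor my_index_select_backward_symint(
    const at::Tensor& grad, c10::SymIntArrayRef self_sizes,
    int64_t dim, const at::Tensor& index) {
  return grad.new_zeros_symint(self_sizes, grad.options())
      .index_add_(dim, index, grad);
}

// Backward-compat shim: keeps the old IntArrayRef signature and forwards to
// the SymInt overload; c10::fromIntArrayRefSlow copies the ints into SymInts.
at::Tensor my_index_select_backward(
    const at::Tensor& grad, at::IntArrayRef self_sizes,
    int64_t dim, const at::Tensor& index) {
  return my_index_select_backward_symint(
      grad, c10::fromIntArrayRefSlow(self_sizes), dim, index);
}
```

On the schema side, the same split shows up as a `CompositeImplicitAutograd: *_symint` dispatch entry on the now-`SymInt[]` signature, which is why almost every yaml hunk below pairs a signature change with a new `dispatch:` block.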
@@ -126,7 +126,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   OP_DECOMPOSE2(hsplit, int);
   OP_DECOMPOSE2(hsplit, array);
   OP_DECOMPOSE(hstack);
-  OP_DECOMPOSE(index_select_backward);
+  m.impl("index_select_backward", native::index_select_backward_symint);
   OP_DECOMPOSE(inner);
   OP_DECOMPOSE(inverse);
   OP_DECOMPOSE(instance_norm);
@@ -230,7 +230,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   OP_DECOMPOSE2(trapezoid, dx);
   OP_DECOMPOSE2(trapz, x);
   OP_DECOMPOSE2(trapz, dx);
-  OP_DECOMPOSE(value_selecting_reduction_backward);
+  m.impl("value_selecting_reduction_backward", native::value_selecting_reduction_backward_symint);
   OP_DECOMPOSE(var);
   OP_DECOMPOSE2(var, dim);
   OP_DECOMPOSE(var_mean);
@@ -12,9 +12,9 @@ template<bool inplace>
 using Ctype = typename std::conditional<inplace, Tensor&, Tensor>::type;
 
 Tensor make_feature_noise(const Tensor& input) {
-  auto input_sizes = input.sizes();
+  auto input_sizes = input.sym_sizes();
   TORCH_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
-  std::vector<int64_t> sizes;
+  c10::SymDimVector sizes;
   sizes.reserve(input.dim());
   sizes.push_back(input_sizes[0]);
   sizes.push_back(input_sizes[1]);
@@ -22,7 +22,7 @@ Tensor make_feature_noise(const Tensor& input) {
     (void)i; //Suppress unused variable warning
     sizes.push_back(1);
   }
-  return input.new_empty(sizes);
+  return input.new_empty_symint(sizes);
 }
 
 bool is_fused_kernel_acceptable(const Tensor& input, double p) {
@@ -46,7 +46,7 @@ Tensor multiply(const Tensor& input, const Tensor& noise) {
 template<bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
 Ctype<inplace> _dropout_impl(T& input, double p, bool train) {
   TORCH_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);
-  if (p == 0 || !train || input.numel() == 0) {
+  if (p == 0 || !train || input.sym_numel() == 0) {
     return input;
   }
 
@@ -6,6 +6,7 @@
 #include <ATen/TensorSubclassLikeUtils.h>
 
 #include <ATen/native/CPUBlas.h>
+#include <ATen/native/NonSymbolicBC.h>
 
 #include <c10/util/irange.h>
 
@@ -1245,8 +1246,6 @@ void _embedding_bag_cpu_out(
       fbgemm_kernel_cache);
 }
 
-// Assumes all input tensors are contiguous.
-// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
 Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
                                const Tensor &offsets_,
                                const Tensor &offset2bag,
@@ -1256,6 +1255,21 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
                                bool scale_grad_by_freq, int64_t mode,
                                bool sparse, const c10::optional<Tensor>& per_sample_weights_opt,
                                int64_t padding_idx) {
+  return at::native::_embedding_bag_backward_symint(
+      grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights_opt, padding_idx);
+}
+
+// Assumes all input tensors are contiguous.
+// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
+Tensor _embedding_bag_backward_symint(const Tensor &grad, const Tensor &indices_,
+                               const Tensor &offsets_,
+                               const Tensor &offset2bag,
+                               const Tensor &bag_size_,
+                               const Tensor &max_indices_,
+                               c10::SymInt num_weights,
+                               bool scale_grad_by_freq, int64_t mode,
+                               bool sparse, const c10::optional<Tensor>& per_sample_weights_opt,
+                               int64_t padding_idx) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
   const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
@@ -1292,11 +1306,11 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
   }
 
   if (sparse) {
-    return at::_embedding_bag_sparse_backward(
+    return at::_embedding_bag_sparse_backward_symint(
         grad, indices, offsets, offset2bag_, bag_size_, num_weights,
         scale_grad_by_freq, mode, per_sample_weights, padding_idx);
   } else {
-    return at::_embedding_bag_dense_backward(
+    return at::_embedding_bag_dense_backward_symint(
         grad, indices, offset2bag_, bag_size_, max_indices_, num_weights,
         scale_grad_by_freq, mode, per_sample_weights, padding_idx);
   }
@@ -1606,7 +1620,16 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu(
 
 Tensor _embedding_bag_sparse_backward(
     const Tensor &grad_, const Tensor &indices, const Tensor &offsets,
-    const Tensor &offset2bag, const Tensor &bag_size_, int64_t num_weights,
+    const Tensor &offset2bag, const Tensor &bag_size_, SymInt num_weights,
+    bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights_opt,
+    int64_t padding_idx) {
+  return at::native::_embedding_bag_sparse_backward_symint(grad_, indices, offsets, offset2bag, bag_size_, num_weights,
+      scale_grad_by_freq, mode, per_sample_weights_opt, padding_idx);
+}
+
+Tensor _embedding_bag_sparse_backward_symint(
+    const Tensor &grad_, const Tensor &indices, const Tensor &offsets,
+    const Tensor &offset2bag, const Tensor &bag_size_, SymInt num_weights,
     bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights_opt,
     int64_t padding_idx) {
   // See [Note: hacky wrapper removal for optional tensor]
@@ -1628,7 +1651,7 @@ Tensor _embedding_bag_sparse_backward(
     AT_ASSERT(mode == MODE_SUM);
     index_grad.mul_(per_sample_weights.unsqueeze(1));
   }
-  return native::embedding_backward_symint(index_grad, indices, c10::SymInt(num_weights), padding_idx,
+  return native::embedding_backward_symint(index_grad, indices, num_weights, padding_idx,
                                            scale_grad_by_freq, true);
 }
 }
@@ -10,4 +10,14 @@ namespace native {
 // In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
 TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
 TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
+TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt);
+// The below ops don't get a duplicated C++ implementation.
+// They are backward ops, which make them very unlikely to be called directly
+// by external code (at::native::trace_backward).
+// They get their own declaration for BC purposes however.
+TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
+TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
+TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
+TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
+TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
 }}
@@ -96,18 +96,20 @@ std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Ten
 // `grad` could be on arbitrary device and of arbitrary dtype, but `_batch_sizes`
 // is guaranteed to be a CPU int64 tensor.
 // See NOTE [ device and dtype of a PackedSequence ]
-Tensor _pack_padded_sequence_backward(const Tensor& grad, at::IntArrayRef input_size, const Tensor& _batch_sizes, bool batch_first) {
-  std::vector<int64_t> input_size_after_t = input_size.vec();
+Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_size, const Tensor& _batch_sizes, bool batch_first) {
+  std::vector<c10::SymInt> input_size_after_t = input_size.vec();
   if (batch_first) {
     TORCH_CHECK(input_size.size() >= 2);
     std::swap(input_size_after_t[0], input_size_after_t[1]);
   }
-  auto grad_input = at::zeros(input_size_after_t, grad.options());
+  auto grad_input = at::zeros_symint(input_size_after_t, grad.options());
   auto batch_sizes_t = _batch_sizes.contiguous();
   checkLongTensor(batch_sizes_t);
 
   int64_t offset = 0;
-  int64_t max_seq_len = batch_sizes_t.size(0);
+  // NOTE: this op advertises as CompositeImplicitAutograd, but uses data_ptr().
+  // we should fix this.
+  auto max_seq_len = batch_sizes_t.size(0);
   int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
   for (const auto i : c10::irange(max_seq_len)) {
     grad_input[i].slice(0, 0, batch_sizes[i]).copy_(grad.slice(0, offset, offset + batch_sizes[i]));
@@ -2024,14 +2024,19 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
   return result.load();
 }
 
+Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const Tensor& indices, at::IntArrayRef sizes, bool keepdim) {
+  return at::native::value_selecting_reduction_backward_symint(grad, dim, indices, c10::fromIntArrayRefSlow(sizes), keepdim);
+}
+
+
 // max(dim), min(dim), topk(dim), mode(dim), are examples of reduction
 // functions that select values. value_selecting_reduction_backward is the
 // backward function for those operators; it propagates the grad to the
 // specific value locations referred to at `indices`.
-Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const Tensor& indices, IntArrayRef sizes, bool keepdim) {
+Tensor value_selecting_reduction_backward_symint(const Tensor& grad, int64_t dim, const Tensor& indices, c10::SymIntArrayRef sizes, bool keepdim) {
   auto inplace_scatter_if_not_tensor_subclass =
       [&](const Tensor& grad_out, const Tensor& indices_) {
-        auto grad_in = at::zeros(sizes, grad_out.options());
+        auto grad_in = at::zeros_symint(sizes, grad_out.options());
         if (areAnyTensorSubclassLike({grad, indices})) {
           return grad_in.scatter(dim, indices_, grad_out);
         }
@@ -1257,13 +1257,17 @@ Tensor index_select_quantized_cpu_(const Tensor & self, int64_t dim, const Tenso
   return at::native::index_select_out_cpu_(self, dim, index, result);
 }
 
-Tensor index_select_backward(const Tensor& grad, IntArrayRef self_sizes, int64_t dim, const Tensor& index) {
+Tensor index_select_backward(const Tensor& grad, at::IntArrayRef self_sizes, int64_t dim, const Tensor& index) {
+  return at::native::index_select_backward_symint(grad, c10::fromIntArrayRefSlow(self_sizes), dim, index);
+}
+
+Tensor index_select_backward_symint(const Tensor& grad, c10::SymIntArrayRef self_sizes, int64_t dim, const Tensor& index) {
   // for composite compliance, use out-of-place variant of
   // `index_add` if index tensor is a Tensor Subclass.
   if (isTensorSubclassLike(index)) {
-    return grad.new_zeros(self_sizes, grad.options()).index_add(dim, index, grad);
+    return grad.new_zeros_symint(self_sizes, grad.options()).index_add(dim, index, grad);
   }
-  return grad.new_zeros(self_sizes, grad.options()).index_add_(dim, index, grad);
+  return grad.new_zeros_symint(self_sizes, grad.options()).index_add_(dim, index, grad);
 }
 
 Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
@@ -8,6 +8,7 @@
 #include <ATen/SparseTensorUtils.h>
 #include <ATen/core/ATen_fwd.h>
 #include <ATen/native/IndexingUtils.h>
+#include <ATen/native/NonSymbolicBC.h>
 #include <c10/core/impl/DeviceGuardImplInterface.h>
 #include <numeric>
 
@@ -168,12 +168,16 @@ TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) {
   compute_triu_tril<UpperTriangle>(self, k, result);
 }
 
-Tensor trace_backward(const Tensor& grad, IntArrayRef sizes) {
+Tensor trace_backward(const Tensor& grad, at::IntArrayRef sizes) {
+  return at::native::trace_backward_symint(grad, c10::fromIntArrayRefSlow(sizes));
+}
+
+Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) {
   if (sizes.size() != 2) {
     throw std::runtime_error("expected matrix input");
   }
 
-  auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options());
+  auto grad_input = at::zeros_symint(sizes[0] * sizes[1], grad.options());
   auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
   // for composite compliance, use out-of-place variant of
   // `index_fill` if grad tensor is a Tensor Subclass.
@@ -182,7 +186,7 @@ Tensor trace_backward(const Tensor& grad, IntArrayRef sizes) {
   } else {
     grad_input.index_fill_(0, indices, grad);
   }
-  return grad_input.view(sizes);
+  return grad_input.view_symint(sizes);
 }
 
 } // namespace native
@@ -2079,11 +2079,15 @@
     CUDA: _embedding_bag_cuda
   autogen: _embedding_bag.out
 
-- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _embedding_bag_backward_symint
 
-- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _embedding_bag_sparse_backward_symint
 
-- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
   dispatch:
     CPU: _embedding_bag_dense_backward_cpu
     CUDA: _embedding_bag_dense_backward_cuda
@@ -2682,13 +2686,13 @@
     CUDA: _fft_c2r_cufft_out
 
 # Standard complex to complex FFT (forward or backward)
-- func: _fft_c2c(Tensor self, int[] dim, int normalization, bool forward) -> Tensor
+- func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
   variants: function
   dispatch:
     CPU: _fft_c2c_mkl
     CUDA: _fft_c2c_cufft
 
-- func: _fft_c2c.out(Tensor self, int[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
+- func: _fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
     CPU: _fft_c2c_mkl_out
@@ -3342,10 +3346,12 @@
 - func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
   device_check: NoCheck # TensorIterator
 
-- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor
+- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor
   variants: function
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: value_selecting_reduction_backward_symint
 
 - func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
   variants: function, method
@@ -6258,7 +6264,9 @@
 
 - func: sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint
 
 - func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size) -> ()
 
@@ -6273,7 +6281,7 @@
     SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse
   autogen: _sparse_coo_tensor_with_dims.out
 
-- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: _sparse_coo_tensor_with_dims_and_tensors(SymInt sparse_dim, SymInt dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
   dispatch:
     SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse
   autogen: _sparse_coo_tensor_with_dims_and_tensors.out
@@ -6897,7 +6905,9 @@
     CompositeExplicitAutograd: _pack_padded_sequence
   autogen: _pack_padded_sequence.out
 
-- func: _pack_padded_sequence_backward(Tensor grad, int[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor
+- func: _pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _pack_padded_sequence_backward_symint
 
 - func: _pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)
 
@@ -7745,10 +7755,12 @@
     CUDA: trace_cuda
   autogen: trace.out
 
-- func: trace_backward(Tensor grad, int[] sizes) -> Tensor
+- func: trace_backward(Tensor grad, SymInt[] sizes) -> Tensor
   variants: function
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: trace_backward_symint
 
 - func: ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
@@ -8119,10 +8131,12 @@
 - func: index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor
   variants: method, function
 
-- func: index_select_backward(Tensor grad, int[] self_sizes, int dim, Tensor index) -> Tensor
+- func: index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor
   variants: function
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: index_select_backward_symint
 
 - func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -8995,7 +9009,7 @@
     CPU, CUDA, Meta: unfold
     QuantizedCPU, QuantizedCUDA: unfold
 
-- func: unfold_backward(Tensor grad_in, int[] input_sizes, int dim, int size, int step) -> Tensor
+- func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
   variants: function
   dispatch:
     CPU, CUDA: unfold_backward
@@ -10533,7 +10547,7 @@
 - func: adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor
   python_module: nn
 
-- func: _adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor
+- func: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
   dispatch:
     CPU: adaptive_avg_pool3d_cpu
     CUDA: adaptive_avg_pool3d_cuda
@@ -13232,7 +13246,7 @@
   dispatch:
     CompositeExplicitAutograd: alias_copy_out
 
-- func: to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor
+- func: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
   variants: method
   dispatch:
     NestedTensorCPU: NestedTensor_to_padded_tensor_generic
@@ -3,6 +3,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/NestedTensorImpl.h>
+#include <ATen/native/NonSymbolicBC.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/NativeFunctions.h>
@@ -9,6 +9,7 @@
 #include <ATen/SparseTensorImpl.h>
 #include <ATen/SparseTensorUtils.h>
 #include <ATen/native/IndexingUtils.h>
+#include <ATen/native/NonSymbolicBC.h>
 #include <ATen/NamedTensorUtils.h>
 
 #include <ATen/native/Copy.h>
@@ -407,13 +408,34 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
       options.pinned_memory_opt());
 }
 
+Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, at::IntArrayRef size,
+    c10::optional<ScalarType> dtype,
+    c10::optional<Layout> layout,
+    c10::optional<Device> device,
+    c10::optional<bool> pin_memory) {
+  return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+
+  Tensor values = expand_values_if_needed(values_);
+
+  auto sparse_dim = indices.size(0);
+  auto dense_dim = values.dim() - 1;
+
+  return at::_sparse_coo_tensor_with_dims_and_tensors(
+      sparse_dim,
+      dense_dim,
+      size,
+      indices,
+      values,
+      values.options().layout(kSparse));
+}
+
 // NOTE: _sparse_coo_tensor_unsafe() differs from sparse_coo_tensor()
 // in that we don't check whether any indices are out of boundaries of `size`, thus avoiding a
 // copy from CUDA to CPU. However, this function should ONLY be used where we know that the indices
 // are guaranteed to be within bounds or if the caller is going to call
 // _validate_sparse_coo_tensor_args before using the tensor.
 // NB: Got rid of the size == NULL case
-Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, IntArrayRef size,
+Tensor _sparse_coo_tensor_unsafe_symint(const Tensor& indices, const Tensor& values_, c10::SymIntArrayRef size,
     c10::optional<ScalarType> dtype,
     c10::optional<Layout> layout,
     c10::optional<Device> device,
@@ -422,10 +444,10 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, I
 
   Tensor values = expand_values_if_needed(values_);
 
-  int64_t sparse_dim = indices.size(0);
-  int64_t dense_dim = values.dim() - 1;
+  auto sparse_dim = indices.sym_size(0);
+  auto dense_dim = values.dim() - 1;
 
-  return at::_sparse_coo_tensor_with_dims_and_tensors(
+  return at::_sparse_coo_tensor_with_dims_and_tensors_symint(
       sparse_dim,
       dense_dim,
       size,
@@ -7,6 +7,7 @@
 #include <ATen/cuda/ThrustAllocator.h>
 #include <ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh>
 #include <ATen/native/cuda/SortingCommon.cuh>
+#include <ATen/native/NonSymbolicBC.h>
 #include <ATen/SparseTensorUtils.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/accumulate.h>
@@ -9,6 +9,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/KernelUtils.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/native/NonSymbolicBC.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/cuda/MemoryAccess.cuh>
 #include <ATen/native/cuda/PersistentSoftmax.cuh>
@@ -961,7 +961,6 @@ aot_autograd_failures = {
 
 symbolic_aot_autograd_failures = {
     xfail('__rmatmul__', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('addbmm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('addcdiv', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
    xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition
     xfail('addr', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1022,7 +1021,6 @@ symbolic_aot_autograd_failures = {
     xfail('hsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('index_put', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('index_select', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1093,7 +1091,6 @@ symbolic_aot_autograd_failures = {
     xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo...
     xfail('max', 'reduction_no_dim'), # aten.logical_or_.default - couldn't find symbolic meta function/dec...
     xfail('max', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('mean', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('median', ''), # could not find kernel
     xfail('meshgrid', 'list_of_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('meshgrid', 'variadic_tensors'), # Cannot call numel() on tensor with symbolic sizes/strides
@@ -1144,8 +1141,6 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.interpolate', 'linear'), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'nearest'), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st...
-    xfail('nn.functional.kl_div', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('nn.functional.l1_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.local_response_norm', ''), # aten.fill.Scalar - couldn't find symbolic meta functio...
     xfail('nn.functional.max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.max_pool2d', ''), # aten.max_pool2d_with_indices_backward.default - couldn't find s...
@@ -1159,7 +1154,6 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.mse_loss', ''), # Unable to cast Python instance to C++ type (#define PYBIND11_DETA...
     xfail('nn.functional.multi_margin_loss', ''), # could not find kernel
     xfail('nn.functional.multilabel_margin_loss', ''), # could not find kernel
-    xfail('nn.functional.multilabel_soft_margin_loss', ''), # Cannot call sizes() on tensor with symbolic si...
     xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.normalize', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.pad', 'circular'), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1169,12 +1163,9 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.pdist', ''), # could not find kernel
     xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
     xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta...
-    xfail('nn.functional.poisson_nll_loss', ''), # aten.add_.Tensor - couldn't find symbolic meta function/d...
     xfail('nn.functional.prelu', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
     xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel
-    xfail('nn.functional.triplet_margin_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('nn.functional.triplet_margin_with_distance_loss', ''), # Cannot call sizes() on tensor with symbo...
     xfail('nn.functional.unfold', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1236,8 +1227,6 @@ symbolic_aot_autograd_failures = {
     xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de...
     xfail('unbind', ''), # tensor_split() received an invalid combination of arguments - got (FakeTensor, torch...
     xfail('unflatten', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('unfold', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('unfold_copy', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('var', ''), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('var_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/deco...
@@ -5299,7 +5299,7 @@ for shape in [(1,), ()]:
         out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x)
 
         out = a.mean()
-        self.assertEqual(out.grad_fn._saved_self_sizes, a.shape)  # IntArrayRef -> Tuple[int]
+        self.assertEqual(out.grad_fn._saved_self_sym_sizes, a.shape)  # IntArrayRef -> Tuple[int]
 
         a = torch.ones(2, 2, requires_grad=True)
         out = a * a
@@ -5328,8 +5328,8 @@ for shape in [(1,), ()]:
 
         a = torch.ones(1, 3, 3, requires_grad=True)
         out = torch.addbmm(a.squeeze(0), a, a)
-        self.assertEqual(out.grad_fn._saved_batch1_argsize_0, 1)  # int64_t
-        self.assertEqual(out.grad_fn._saved_batch1_argsize_1, 3)  # int64_t
+        self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_0, 1)  # int64_t
+        self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_1, 3)  # int64_t
 
         a = torch.ones(1, 1, 3, 3, requires_grad=True)
         out = torch.nn.functional.unfold(a, 3)
@@ -1190,7 +1190,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.dropout', ''), # Tensors of type TensorImpl do not have numel
    xfail('nn.functional.embedding_bag', ''), # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun...
     xfail('nn.functional.embedding', ''), # argument 'size' must be tuple of ints, but found element of type tor...
-    xfail('nn.functional.feature_alpha_dropout', 'with_train'), # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t...
     xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
     xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
@@ -220,8 +220,8 @@
 
 - name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   self: maybe_multiply(grad, beta.conj())
-  batch1: maybe_multiply(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj())
-  batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })), alpha.conj())
+  batch1: maybe_multiply(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj())
+  batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) })), alpha.conj())
   result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha)
 
 - name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
@@ -558,11 +558,11 @@
 - name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
   self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode)
   other: div_tensor_other_backward(grad, self, other, rounding_mode)
-  result: "rounding_mode.has_value() ? result.new_zeros(result.sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p"
+  result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p"
 
 - name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
   self: div_tensor_self_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type(), rounding_mode)
-  result: "rounding_mode.has_value() ? result.new_zeros(result.sizes()) : self_t / other"
+  result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other"
 
 - name: dot(Tensor self, Tensor tensor) -> Tensor
   self: grad * tensor.conj()
@@ -828,7 +828,7 @@
   result: at::_index_put_impl_(self_t, indices, values_t, accumulate, unsafe)
 
 - name: index_select(Tensor self, int dim, Tensor index) -> Tensor
-  self: index_select_backward(grad, self.sizes(), dim, index)
+  self: index_select_backward_symint(grad, self.sym_sizes(), dim, index)
   index: non_differentiable
   result: auto_linear
 
@@ -845,7 +845,7 @@
   self: non_differentiable
 
 - name: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
@@ -999,10 +999,10 @@
   result: linalg_lu_solve_jvp(result, LU_p, pivots, LU_t, B_t, left, adjoint)
 
 - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
-  LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.size(-2), LU_data.size(-1))
+  LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1))
   LU_pivots: non_differentiable
-  L: "LU_data_t.size(-2) >= LU_data_t.size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow(-1, 0, LU_data_t.size(-2)).tril(-1)"
-  U: "LU_data_t.size(-1) >= LU_data_t.size(-2) ? LU_data_t.triu() : LU_data_t.narrow(-2, 0, LU_data_t.size(-1)).triu()"
+  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)"
+  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()"
   output_differentiability: [False, True, True]
 
 - name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
@@ -1018,7 +1018,7 @@
 
 - name: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
   self: grad.masked_fill(mask, 0)
-  source: masked_scatter_backward(grad, mask, source.sizes())
+  source: masked_scatter_backward(grad, mask, source.sym_sizes())
   mask: non_differentiable
   result: self_t.masked_scatter(mask, source_t)
 
@@ -1032,7 +1032,7 @@
   result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false)
 
 - name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: max(Tensor self) -> Tensor
@@ -1050,7 +1050,7 @@
   result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
 
 - name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
-  self: grad.expand(self.sizes()) / self.numel()
+  self: grad.expand_symint(self.sym_sizes()) / self.sym_numel()
   result: auto_linear
 
 - name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
@@ -1080,15 +1080,15 @@
 # The backward implementation is correct in the sense that it returns the
 # subgradient on one side.
 - name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: min(Tensor self) -> Tensor
@@ -1119,7 +1119,7 @@
   result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t)
 
 - name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: mul.Tensor(Tensor self, Tensor other) -> Tensor
@@ -1232,16 +1232,16 @@
   result: self_t.zero_()
 
 - name: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
-  mean: at::zeros(mean.sizes(), grad.options())
+  mean: at::zeros_symint(mean.sym_sizes(), grad.options())
   result: auto_element_wise
 
 - name: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor
-  std: at::zeros(std.sizes(), grad.options())
+  std: at::zeros_symint(std.sym_sizes(), grad.options())
   result: auto_element_wise
 
 - name: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
-  mean: at::zeros(mean.sizes(), grad.options())
-  std: at::zeros(std.sizes(), grad.options())
+  mean: at::zeros_symint(mean.sym_sizes(), grad.options())
+  std: at::zeros_symint(std.sym_sizes(), grad.options())
   result: zeros_like(mean_t)
 
 - name: linalg_householder_product(Tensor input, Tensor tau) -> Tensor
@@ -1463,12 +1463,12 @@
   output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user
 
 - name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true)
   output_differentiability: [True, False]
   values: gather_with_keepdimed_indices(self_t, dim, indices, true)
 
 - name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true)
   output_differentiability: [True, False]
   values: gather_with_keepdimed_indices(self_t, dim, indices, true)
 
@@ -1560,12 +1560,12 @@
 
 # We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here
 - name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
-  A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow(-1, 0, S.size(-1)) : grad_U,
+  A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow_symint(-1, 0, S.sym_size(-1)) : grad_U,
                    grad_S,
-                   full_matrices && grad_Vh.defined() ? grad_Vh.narrow(-2, 0, S.size(-1)) : grad_Vh,
-                   full_matrices ? U.narrow(-1, 0, S.size(-1)) : U,
+                   full_matrices && grad_Vh.defined() ? grad_Vh.narrow_symint(-2, 0, S.sym_size(-1)) : grad_Vh,
+                   full_matrices ? U.narrow_symint(-1, 0, S.sym_size(-1)) : U,
                    S,
-                   full_matrices ? Vh.narrow(-2, 0, S.size(-1)) : Vh)"
+                   full_matrices ? Vh.narrow_symint(-2, 0, S.sym_size(-1)) : Vh)"
   U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices)
 
 - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
@@ -1616,12 +1616,12 @@
   result: auto_element_wise
 
 - name: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true)
   output_differentiability: [True, False]
   values: gather(self_t, dim, indices)
 
 - name: trace(Tensor self) -> Tensor
-  self: trace_backward(grad, self.sizes())
+  self: trace_backward_symint(grad, self.sym_sizes())
   result: auto_linear
 
 - name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
@@ -1670,10 +1670,10 @@
   self: to_mkldnn_backward(grad, self)
 
 - name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)
-  self: unfold_backward(grad, self.sizes(), dimension, size, step)
+  self: unfold_backward_symint(grad, self.sym_sizes(), dimension, size, step)
   result: auto_linear
 
-- name: unfold_backward(Tensor grad_in, int[] input_sizes, int dim, int size, int step) -> Tensor
+- name: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
   grad_in: grad.unfold(dim, size, step)
   result: auto_linear
 
@@ -1772,7 +1772,7 @@
   self: grad.to_dense().sparse_mask(mask).to_dense()
   mask: non_differentiable
 
-- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- name: _sparse_coo_tensor_with_dims_and_tensors(SymInt sparse_dim, SymInt dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
   values: sparse_constructor_values_backward(grad, indices)
 
 - name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor
@@ -1787,7 +1787,7 @@
 - name: values(Tensor(a) self) -> Tensor(a)
   dispatch:
     Default:
-      self: at::_sparse_coo_tensor_unsafe(self.indices(), grad, self.sizes())._coalesced_(true)
+      self: at::_sparse_coo_tensor_unsafe_symint(self.indices(), grad, self.sym_sizes())._coalesced_(true)
     AutogradNestedTensor:
       self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_offsets())
 
@@ -1844,10 +1844,10 @@
 - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
   indices: non_differentiable
   offsets: non_differentiable
-  weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx)
+  weight: _embedding_bag_backward_symint(grad, indices, offsets, result1, result2, result3, weight.sym_size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx)
   per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode, padding_idx)
 
-- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
   indices: non_differentiable
   offset2bag: non_differentiable
   bag_size: non_differentiable
@@ -2147,7 +2147,7 @@
   self: _adaptive_avg_pool2d_backward(grad, self)
   result: auto_linear
 
-- name: _adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor
+- name: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
   self: _adaptive_avg_pool3d_backward(grad, self)
   result: auto_linear
 
@@ -2220,7 +2220,7 @@
 # Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context
 # by convolution_backward instead of being passed along from the forward pass.
 - name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
-  input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
   result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)
 
 - name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
@@ -2236,10 +2236,10 @@
   grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask)
 
 - name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: _slow_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> Tensor
   self, weight, bias: "grad.defined() ? _slow_conv2d_backward(grad, self, weight, kernel_size, stride, padding, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
@@ -2248,19 +2248,19 @@
   grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask)
 
 - name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad.contiguous(), self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, int[3] dilation) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad.contiguous(), self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
 
 - name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
   self: im2col(grad, kernel_size, dilation, padding, stride)
@@ -2276,7 +2276,7 @@
   result: _adaptive_avg_pool2d_backward(grad_output_t, self_p)
 
 - name: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor
-  grad_output: _adaptive_avg_pool3d(grad, { grad_output.size(-3), grad_output.size(-2), grad_output.size(-1) })
+  grad_output: _adaptive_avg_pool3d_symint(grad, { grad_output.sym_size(-3), grad_output.sym_size(-2), grad_output.sym_size(-1) })
   self: zeros_like(self)
   result: _adaptive_avg_pool3d_backward(grad_output_t, self_p)
 
@@ -2616,7 +2616,7 @@

 - name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding, int[2] stride=1) -> Tensor
   # NNPACK does not support strided convolutions in the backward path, so we use the closest available function that does.
-  input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<int64_t>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

 # LSTM MPS
 - name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
@@ -2648,13 +2648,13 @@
 # miopen

 - name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

 - name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

 - name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

 - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
   input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple<Tensor, Tensor, Tensor>()"
@@ -2674,7 +2674,7 @@

 # mkldnn
 - name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
-  self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector<int64_t>(padding.size(), 0), groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

 - name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
   self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask)
@@ -2689,21 +2689,21 @@
   self: mkldnn_adaptive_avg_pool2d_backward(grad, self)

 - name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
-  self: grad.reshape(self.sizes())
+  self: grad.reshape_symint(self.sym_sizes())

 # Nested Tensor
 - name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   list: "grad.defined()? at::unbind(grad) : std::vector<Tensor>(list.size())"

 - name: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor
-  t: grad.to_padded_tensor(0, t.sizes())
+  t: grad.to_padded_tensor_symint(0, t.sym_sizes())
   mask: non_differentiable

 - name: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
   padded: _nested_from_padded_backward(grad, padded, fuse_transform_0213)
   cpu_nested_shape_example: non_differentiable

-- name: to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor
+- name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
   self: at::_nested_from_padded(grad, self._nested_tensor_size())
   padding: non_differentiable
@@ -2714,15 +2714,15 @@

 # fft
 - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
-  self: fft_r2c_backward(grad, dim, normalization, onesided, self.size(dim.back()))
+  self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))
   result: auto_linear

 - name: _fft_c2r(Tensor self, int[] dim, int normalization, int last_dim_size) -> Tensor
   self: fft_c2r_backward(grad, dim, normalization)
   result: auto_linear

-- name: _fft_c2c(Tensor self, int[] dim, int normalization, bool forward) -> Tensor
-  self: _fft_c2c(grad, dim, normalization, !forward)
+- name: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
+  self: _fft_c2c_symint(grad, dim, normalization, !forward)
   result: auto_linear

 - name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
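Note: an aside on why `_fft_c2c`'s backward is just the opposite-direction transform (standard background, not part of the diff). The unnormalized length-$N$ DFT matrix $F$ has entries $F_{jk} = e^{-2\pi i jk/N}$, and since $F$ is symmetric its adjoint is

$$
F^{H} = \overline{F}^{\,T} = \overline{F} = \left(e^{+2\pi i jk/N}\right)_{jk},
$$

which is exactly the unnormalized inverse-direction transform. Flipping `forward` while passing the same `normalization` argument therefore implements the adjoint of the linear map.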
@@ -2746,7 +2746,7 @@

 # PackedSequence helpers
 - name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
-  input: _pack_padded_sequence_backward(grad, input.sizes(), result1, batch_first)
+  input: _pack_padded_sequence_backward_symint(grad, input.sym_sizes(), result1, batch_first)

 # TH wrappers
 - name: eq.Scalar(Tensor self, Scalar other) -> Tensor
@@ -2808,12 +2808,12 @@
 - name: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor
   dispatch:
     Default:
-      self: grad.expand(self.sizes()) + 1
+      self: grad.expand_symint(self.sym_sizes()) + 1
       result: auto_linear
     AutogradNestedTensor:
       self: grad.mul(grad)
     AutogradCUDA:
-      self: grad.expand(self.sizes()) * 2
+      self: grad.expand_symint(self.sym_sizes()) * 2

 - name: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor
   dispatch:
@@ -196,6 +196,11 @@ def meta_pad2d(self, padding):
     return self.new_empty((nbatch, nplane, output_h, output_w))


+@register_meta(aten.bernoulli_.float, register_dispatcher=False)
+def meta_bernoulli_(self, p=0.5, generator=None):
+    return self
+
+
 def dot_check(self, other):
     check(
         self.dim() == 1 and other.dim() == 1,
@@ -1622,21 +1622,21 @@ Tensor std_mean_backward(
 Tensor masked_scatter_backward(
     const Tensor& grad,
     const Tensor& mask,
-    IntArrayRef sizes) {
-  int64_t numel = 1;
+    c10::SymIntArrayRef sizes) {
+  c10::SymInt numel = 1;
   for (auto size : sizes) {
     numel *= size;
   }
   auto mask_selected = grad.masked_select(mask);
-  auto diff_nelem = numel - mask_selected.numel();
+  auto diff_nelem = numel - mask_selected.sym_numel();
   if (diff_nelem > 0) {
     // masked_select returns a 1-d tensor with one element per true mask
     // entry, so we fill out the rest with zeros and then reshape back to
     // tensor2's size.
-    auto zeros_fillin = at::zeros({diff_nelem}, grad.options());
+    auto zeros_fillin = at::zeros_symint({diff_nelem}, grad.options());
     mask_selected = at::cat({mask_selected, zeros_fillin}, 0);
   }
-  return mask_selected.view(sizes);
+  return mask_selected.view_symint(sizes);
 }

 Tensor cholesky_jvp(const Tensor& dA, const Tensor& L, bool upper) {
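Note: the zero-padding above is the crux — `masked_select` only returns gradients for the elements the mask actually covered, and every remaining slot of the source gets zero gradient. A concrete-int sketch of the same shape logic (a standalone illustration, not the PR's code; it assumes only ATen):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Source had shape {2, 3}; grad flows back through masked_scatter.
  auto grad = at::ones({2, 3});
  auto mask = at::rand({2, 3}).gt(0.5);           // data-dependent mask
  auto selected = grad.masked_select(mask);       // 1-D, one entry per true
  auto missing = grad.numel() - selected.numel(); // slots not covered by mask
  auto padded = at::cat({selected, at::zeros({missing}, grad.options())}, 0);
  std::cout << padded.view({2, 3}) << "\n";       // back to the source shape
  return 0;
}
```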
@@ -4371,10 +4371,10 @@ Tensor fft_c2r_backward(

 Tensor fft_r2c_backward(
     const Tensor& grad,
-    IntArrayRef dim,
+    at::IntArrayRef dim,
     int64_t normalization,
     bool onesided,
-    int64_t last_dim_size) {
+    c10::SymInt last_dim_size) {
   if (!onesided) {
     return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
   }
@@ -4389,16 +4389,17 @@ Tensor fft_r2c_backward(
   // (C2C ifft only takes two-sided inputs, so we need to fill here)
   // 2. inverse C2C ifft
   // 3. discard the complex dim
-  auto half_sizes = grad.sizes();
-  at::DimVector new_grad_shape(half_sizes.begin(), half_sizes.end());
+  auto half_sizes = grad.sym_sizes();
+  std::vector<c10::SymInt> new_grad_shape(half_sizes.begin(), half_sizes.end());
   const auto last_dim = at::maybe_wrap_dim(dim.back(), half_sizes.size());
   new_grad_shape[last_dim] = last_dim_size;

-  const auto zero_length = last_dim_size - grad.size(dim.back());
+  const auto zero_length = last_dim_size - grad.sym_size(dim.back());
   auto complex_full_grad =
-      zero_length > 0 ? grad.new_zeros(new_grad_shape) : grad;
+      zero_length > 0 ? grad.new_zeros_symint(new_grad_shape) : grad;
   if (zero_length > 0) {
-    complex_full_grad.slice(last_dim, 0, half_sizes[last_dim]).copy_(grad);
+    complex_full_grad.slice_symint(last_dim, 0, half_sizes[last_dim])
+        .copy_(grad);
   }
   return at::real(
       at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false));
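Note: steps 1–3 in the comments above are the whole trick — a one-sided rfft gradient only covers the first half (plus one) of the transformed dim, so it is zero-padded out to the two-sided length before the inverse C2C transform. A concrete-int sketch of just the padding step (a standalone illustration, not the PR's code; it assumes the transformed dim is the last one and shows only the shape logic):

```cpp
#include <ATen/ATen.h>

// Rebuild a two-sided gradient layout from a one-sided one by
// zero-padding the last dimension out to last_dim_size, then copying
// the known half into the leading slots.
at::Tensor twosided_from_onesided(const at::Tensor& grad, int64_t last_dim_size) {
  auto shape = grad.sizes().vec();
  const int64_t last_dim = static_cast<int64_t>(shape.size()) - 1;
  const int64_t zero_length = last_dim_size - shape[last_dim];
  shape[last_dim] = last_dim_size;
  auto full = zero_length > 0 ? grad.new_zeros(shape) : grad;
  if (zero_length > 0) {
    full.slice(last_dim, 0, grad.size(last_dim)).copy_(grad);
  }
  return full;
}
```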
@@ -5622,8 +5623,8 @@ Tensor solve_jvp(
 Tensor lu_unpack_backward(
     const Tensor& L_grad,
     const Tensor& U_grad,
-    const int64_t m,
-    const int64_t n) {
+    const c10::SymInt m,
+    const c10::SymInt n) {
   if (!L_grad.defined() && !U_grad.defined()) {
     return {};
   }
@@ -5631,16 +5632,16 @@ Tensor lu_unpack_backward(

   // Getters for the principal and complementary parts of the matrices
   const auto get_L1 = [m, k](const Tensor& L) {
-    return m == k ? L.tril(-1) : L.narrow(-2, 0, k).tril(-1);
+    return m == k ? L.tril(-1) : L.narrow_symint(-2, 0, k).tril(-1);
   };
   const auto get_L2 = [m, k](const Tensor& L) {
-    return L.narrow(-2, k, m - k);
+    return L.narrow_symint(-2, k, m - k);
   };
   const auto get_U1 = [n, k](const Tensor& U) {
-    return n == k ? U.triu() : U.narrow(-1, 0, k).triu();
+    return n == k ? U.triu() : U.narrow_symint(-1, 0, k).triu();
   };
   const auto get_U2 = [n, k](const Tensor& U) {
-    return U.narrow(-1, k, n - k);
+    return U.narrow_symint(-1, k, n - k);
   };

   if (L_grad.defined()) {
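Note: reading the four `narrow` calls as a block partition (with $k$ presumably $\min(m, n)$, defined just above this hunk; batch dimensions omitted — this is a restatement of the code, not text from the PR):

$$
L = \begin{bmatrix} L_1 \\ L_2 \end{bmatrix}, \qquad
U = \begin{bmatrix} U_1 & U_2 \end{bmatrix}, \qquad
L_1, U_1 \in \mathbb{R}^{k \times k}, \quad
L_2 \in \mathbb{R}^{(m-k) \times k}, \quad
U_2 \in \mathbb{R}^{k \times (n-k)}.
$$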
@@ -5657,20 +5658,22 @@ Tensor lu_unpack_backward(
     if (m >= n) {
       return L_grad.tril(-1);
     } else {
-      auto size = L_grad.sizes().vec();
+      auto size = L_grad.sym_sizes().vec();
       size.end()[-1] = n - m;
       return at::cat(
-          {L_grad.tril(-1), at::zeros(size, L_grad.options())}, /*dim=*/-1);
+          {L_grad.tril(-1), at::zeros_symint(size, L_grad.options())},
+          /*dim=*/-1);
     }
   } else {
     if (n >= m) {
       return U_grad.triu();
     } else {
-      auto size = U_grad.sizes().vec();
+      auto size = U_grad.sym_sizes().vec();
       size.end()[-2] = m - n;
       return at::cat(
-          {U_grad.triu(), at::zeros(size, U_grad.options())}, /*dim=*/-2);
+          {U_grad.triu(), at::zeros_symint(size, U_grad.options())},
+          /*dim=*/-2);
     }
   }
 }
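Note: the wide case above has a simple block reading — when $m < n$ and only `L_grad` is defined, the gradient with respect to the packed LU tensor is the strictly lower triangle of $\bar{L}$, right-padded with an $m \times (n-m)$ zero block (and the transposed analogue holds for $\bar{U}$ when $n < m$). This is a restatement of the `at::cat` call, not a formula quoted from the PR:

$$
\frac{\partial \ell}{\partial \mathrm{LU}}
= \begin{bmatrix} \operatorname{tril}_{-1}(\bar{L}) & 0_{m \times (n-m)} \end{bmatrix}.
$$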
@@ -356,7 +356,7 @@ Tensor std_mean_backward(
 at::Tensor masked_scatter_backward(
     const at::Tensor& grad,
     const at::Tensor& mask,
-    at::IntArrayRef sizes);
+    c10::SymIntArrayRef sizes);
 at::Tensor cholesky_backward(
     const at::Tensor& grad,
     bool upper,
@@ -668,10 +668,10 @@ Tensor fft_backward(
     IntArrayRef output_sizes);
 Tensor fft_r2c_backward(
     const Tensor& grad,
-    IntArrayRef dim,
+    at::IntArrayRef dim,
     int64_t normalization,
     bool onesided,
-    int64_t last_dim_size);
+    c10::SymInt last_dim_size);
 Tensor fft_c2r_backward(
     const Tensor& grad,
     IntArrayRef dim,
@@ -824,8 +824,8 @@ Tensor linalg_solve_jvp(
 Tensor lu_unpack_backward(
     const Tensor& L_grad,
     const Tensor& U_grad,
-    const int64_t m,
-    const int64_t n);
+    const c10::SymInt m,
+    const c10::SymInt n);

 Tensor linalg_det_backward(
     const Tensor& grad,