diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
index 663d7f40bd9..dd1c4f276b0 100644
--- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
+++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -126,7 +126,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   OP_DECOMPOSE2(hsplit, int);
   OP_DECOMPOSE2(hsplit, array);
   OP_DECOMPOSE(hstack);
-  OP_DECOMPOSE(index_select_backward);
+  m.impl("index_select_backward", native::index_select_backward_symint);
   OP_DECOMPOSE(inner);
   OP_DECOMPOSE(inverse);
   OP_DECOMPOSE(instance_norm);
@@ -230,7 +230,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   OP_DECOMPOSE2(trapezoid, dx);
   OP_DECOMPOSE2(trapz, x);
   OP_DECOMPOSE2(trapz, dx);
-  OP_DECOMPOSE(value_selecting_reduction_backward);
+  m.impl("value_selecting_reduction_backward", native::value_selecting_reduction_backward_symint);
   OP_DECOMPOSE(var);
   OP_DECOMPOSE2(var, dim);
   OP_DECOMPOSE(var_mean);
diff --git a/aten/src/ATen/native/Dropout.cpp b/aten/src/ATen/native/Dropout.cpp
index 9c5cd25f458..95a731eebd5 100644
--- a/aten/src/ATen/native/Dropout.cpp
+++ b/aten/src/ATen/native/Dropout.cpp
@@ -12,9 +12,9 @@
 template <bool inplace>
 using Ctype = typename std::conditional<inplace, Tensor&, Tensor>::type;
 
 Tensor make_feature_noise(const Tensor& input) {
-  auto input_sizes = input.sizes();
+  auto input_sizes = input.sym_sizes();
   TORCH_CHECK(input.dim() >= 2, "Feature dropout requires at least 2 dimensions in the input");
-  std::vector<int64_t> sizes;
+  c10::SymDimVector sizes;
   sizes.reserve(input.dim());
   sizes.push_back(input_sizes[0]);
   sizes.push_back(input_sizes[1]);
@@ -22,7 +22,7 @@ Tensor make_feature_noise(const Tensor& input) {
     (void)i; //Suppress unused variable warning
     sizes.push_back(1);
   }
-  return input.new_empty(sizes);
+  return input.new_empty_symint(sizes);
 }
 
 bool is_fused_kernel_acceptable(const Tensor& input, double p) {
@@ -46,7 +46,7 @@ Tensor multiply(const Tensor& input, const Tensor& noise) {
 template <bool feature_dropout, bool alpha_dropout, bool inplace, typename T>
 Ctype<inplace> _dropout_impl(T& input, double p, bool train) {
   TORCH_CHECK(p >= 0 && p <= 1, "dropout probability has to be between 0 and 1, but got ", p);
-  if (p == 0 || !train || input.numel() == 0) {
+  if (p == 0 || !train || input.sym_numel() == 0) {
     return input;
   }
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index b960377863c..ffe3a212734 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -6,6 +6,7 @@
 #include
 
 #include
+#include
 
 #include
 
@@ -1245,8 +1246,6 @@ void _embedding_bag_cpu_out(
       fbgemm_kernel_cache);
 }
 
-// Assumes all input tensors are contiguous.
-// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
 Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
                               const Tensor &offsets_,
                               const Tensor &offset2bag,
@@ -1256,6 +1255,21 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
                               bool scale_grad_by_freq, int64_t mode,
                               bool sparse, const c10::optional<Tensor>& per_sample_weights_opt,
                               int64_t padding_idx) {
+  return at::native::_embedding_bag_backward_symint(
+      grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights_opt, padding_idx);
+}
+
+// Assumes all input tensors are contiguous.
+// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
+Tensor _embedding_bag_backward_symint(const Tensor &grad, const Tensor &indices_,
+                              const Tensor &offsets_,
+                              const Tensor &offset2bag,
+                              const Tensor &bag_size_,
+                              const Tensor &max_indices_,
+                              c10::SymInt num_weights,
+                              bool scale_grad_by_freq, int64_t mode,
+                              bool sparse, const c10::optional<Tensor>& per_sample_weights_opt,
+                              int64_t padding_idx) {
   // See [Note: hacky wrapper removal for optional tensor]
   c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
   const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
@@ -1292,11 +1306,11 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
   }
 
   if (sparse) {
-    return at::_embedding_bag_sparse_backward(
+    return at::_embedding_bag_sparse_backward_symint(
         grad, indices, offsets, offset2bag_, bag_size_, num_weights,
         scale_grad_by_freq, mode, per_sample_weights, padding_idx);
   } else {
-    return at::_embedding_bag_dense_backward(
+    return at::_embedding_bag_dense_backward_symint(
         grad, indices, offset2bag_, bag_size_, max_indices_, num_weights,
         scale_grad_by_freq, mode, per_sample_weights, padding_idx);
   }
@@ -1606,7 +1620,16 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu(
 
 Tensor _embedding_bag_sparse_backward(
     const Tensor &grad_, const Tensor &indices, const Tensor &offsets,
-    const Tensor &offset2bag, const Tensor &bag_size_, int64_t num_weights,
+    const Tensor &offset2bag, const Tensor &bag_size_, SymInt num_weights,
+    bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights_opt,
+    int64_t padding_idx) {
+  return at::native::_embedding_bag_sparse_backward_symint(grad_, indices, offsets, offset2bag, bag_size_, num_weights,
+    scale_grad_by_freq, mode, per_sample_weights_opt, padding_idx);
+}
+
+Tensor _embedding_bag_sparse_backward_symint(
+    const Tensor &grad_, const Tensor &indices, const Tensor &offsets,
+    const Tensor &offset2bag, const Tensor &bag_size_, SymInt num_weights,
     bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights_opt,
     int64_t padding_idx) {
   // See [Note: hacky wrapper removal for optional tensor]
@@ -1628,7 +1651,7 @@ Tensor _embedding_bag_sparse_backward(
     AT_ASSERT(mode == MODE_SUM);
     index_grad.mul_(per_sample_weights.unsqueeze(1));
   }
-  return native::embedding_backward_symint(index_grad, indices, c10::SymInt(num_weights), padding_idx,
+  return native::embedding_backward_symint(index_grad, indices, num_weights, padding_idx,
                                     scale_grad_by_freq, true);
 }
 }
diff --git a/aten/src/ATen/native/NonSymbolicBC.h b/aten/src/ATen/native/NonSymbolicBC.h
index d8a11cbc320..fd4325fca6d 100644
--- a/aten/src/ATen/native/NonSymbolicBC.h
+++ b/aten/src/ATen/native/NonSymbolicBC.h
@@ -10,4 +10,14 @@ namespace native {
 // In those cases, we will duplicate the signature here with non-symbolic ints, and also duplicate the C++ implementation.
 TORCH_API at::Tensor reshape(const at::Tensor& self, at::IntArrayRef proposed_shape);
 TORCH_API at::Tensor narrow(const at::Tensor& self, int64_t dim, int64_t start, int64_t length);
+TORCH_API at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, c10::optional<at::ScalarType> dtype=c10::nullopt, c10::optional<at::Layout> layout=c10::nullopt, c10::optional<at::Device> device=c10::nullopt, c10::optional<bool> pin_memory=c10::nullopt);
+// The below ops don't get a duplicated C++ implementation.
+// They are backward ops, which make them very unlikely to be called directly
+// by external code (at::native::trace_backward).
+// They get their own declaration for BC purposes however.
+TORCH_API at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
+TORCH_API at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const c10::optional<at::Tensor> & per_sample_weights, int64_t padding_idx=-1);
+TORCH_API at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim);
+TORCH_API at::Tensor trace_backward(const at::Tensor & grad, at::IntArrayRef sizes);
+TORCH_API at::Tensor index_select_backward(const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index);
 }}
diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp
index ec997d86aa1..736829eb6d1 100644
--- a/aten/src/ATen/native/PackedSequence.cpp
+++ b/aten/src/ATen/native/PackedSequence.cpp
@@ -96,18 +96,20 @@ std::tuple<Tensor, Tensor> _pack_padded_sequence(const Tensor& _input, const Ten
 // `grad` could be on arbitrary device and of arbitrary dtype, but `_batch_sizes`
 // is guaranteed to be a CPU int64 tensor.
 // See NOTE [ device and dtype of a PackedSequence ]
-Tensor _pack_padded_sequence_backward(const Tensor& grad, at::IntArrayRef input_size, const Tensor& _batch_sizes, bool batch_first) {
-  std::vector<int64_t> input_size_after_t = input_size.vec();
+Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_size, const Tensor& _batch_sizes, bool batch_first) {
+  std::vector<c10::SymInt> input_size_after_t = input_size.vec();
   if (batch_first) {
     TORCH_CHECK(input_size.size() >= 2);
     std::swap(input_size_after_t[0], input_size_after_t[1]);
   }
-  auto grad_input = at::zeros(input_size_after_t, grad.options());
+  auto grad_input = at::zeros_symint(input_size_after_t, grad.options());
   auto batch_sizes_t = _batch_sizes.contiguous();
   checkLongTensor(batch_sizes_t);
 
   int64_t offset = 0;
-  int64_t max_seq_len = batch_sizes_t.size(0);
+  // NOTE: this op advertises as CompositeImplicitAutograd, but uses data_ptr().
+  // we should fix this.
+  auto max_seq_len = batch_sizes_t.size(0);
   int64_t * batch_sizes = batch_sizes_t.data_ptr<int64_t>();
   for (const auto i : c10::irange(max_seq_len)) {
     grad_input[i].slice(0, 0, batch_sizes[i]).copy_(grad.slice(0, offset, offset + batch_sizes[i]));
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 91364904193..2bb01abd51b 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -2024,14 +2024,19 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
   return result.load();
 }
 
+Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const Tensor& indices, at::IntArrayRef sizes, bool keepdim) {
+  return at::native::value_selecting_reduction_backward_symint(grad, dim, indices, c10::fromIntArrayRefSlow(sizes), keepdim);
+}
+
+
 // max(dim), min(dim), topk(dim), mode(dim), are examples of reduction
 // functions that select values. value_selecting_reduction_backward is the
 // backward function for those operators; it propagates the grad to the
 // specific value locations referred to at `indices`.
-Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const Tensor& indices, IntArrayRef sizes, bool keepdim) {
+Tensor value_selecting_reduction_backward_symint(const Tensor& grad, int64_t dim, const Tensor& indices, c10::SymIntArrayRef sizes, bool keepdim) {
   auto inplace_scatter_if_not_tensor_subclass =
       [&](const Tensor& grad_out, const Tensor& indices_) {
-        auto grad_in = at::zeros(sizes, grad_out.options());
+        auto grad_in = at::zeros_symint(sizes, grad_out.options());
         if (areAnyTensorSubclassLike({grad, indices})) {
           return grad_in.scatter(dim, indices_, grad_out);
         }
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index d14a03c384f..7e882e38aa9 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -1257,13 +1257,17 @@ Tensor index_select_quantized_cpu_(const Tensor & self, int64_t dim, const Tenso
   return at::native::index_select_out_cpu_(self, dim, index, result);
 }
 
-Tensor index_select_backward(const Tensor& grad, IntArrayRef self_sizes, int64_t dim, const Tensor& index) {
+Tensor index_select_backward(const Tensor& grad, at::IntArrayRef self_sizes, int64_t dim, const Tensor& index) {
+  return at::native::index_select_backward_symint(grad, c10::fromIntArrayRefSlow(self_sizes), dim, index);
+}
+
+Tensor index_select_backward_symint(const Tensor& grad, c10::SymIntArrayRef self_sizes, int64_t dim, const Tensor& index) {
   // for composite compliance, use out-of-place variant of
   // `index_add` if index tensor is a Tensor Subclass.
   if (isTensorSubclassLike(index)) {
-    return grad.new_zeros(self_sizes, grad.options()).index_add(dim, index, grad);
+    return grad.new_zeros_symint(self_sizes, grad.options()).index_add(dim, index, grad);
   }
-  return grad.new_zeros(self_sizes, grad.options()).index_add_(dim, index, grad);
+  return grad.new_zeros_symint(self_sizes, grad.options()).index_add_(dim, index, grad);
 }
 
 Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp
index 3e67fa1efed..2af35c66a0b 100644
--- a/aten/src/ATen/native/TensorConversions.cpp
+++ b/aten/src/ATen/native/TensorConversions.cpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
diff --git a/aten/src/ATen/native/TriangularOps.cpp b/aten/src/ATen/native/TriangularOps.cpp
index d5f408a74f1..f98018d7fe5 100644
--- a/aten/src/ATen/native/TriangularOps.cpp
+++ b/aten/src/ATen/native/TriangularOps.cpp
@@ -168,12 +168,16 @@ TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) {
   compute_triu_tril<UPPER>(self, k, result);
 }
 
-Tensor trace_backward(const Tensor& grad, IntArrayRef sizes) {
+Tensor trace_backward(const Tensor& grad, at::IntArrayRef sizes) {
+  return at::native::trace_backward_symint(grad, c10::fromIntArrayRefSlow(sizes));
+}
+
+Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) {
   if (sizes.size() != 2) {
     throw std::runtime_error("expected matrix input");
   }
-  auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options());
+  auto grad_input = at::zeros_symint(sizes[0] * sizes[1], grad.options());
   auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
   // for composite compliance, use out-of-place variant of
   // `index_fill` if grad tensor is a Tensor Subclass.
@@ -182,7 +186,7 @@ Tensor trace_backward(const Tensor& grad, IntArrayRef sizes) {
   } else {
     grad_input.index_fill_(0, indices, grad);
   }
-  return grad_input.view(sizes);
+  return grad_input.view_symint(sizes);
 }
 
 } // namespace native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index ba47a5ea04e..6043832e89b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2079,11 +2079,15 @@
     CUDA: _embedding_bag_cuda
   autogen: _embedding_bag.out
 
-- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _embedding_bag_backward_symint
 
-- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _embedding_bag_sparse_backward_symint
 
-- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
   dispatch:
     CPU: _embedding_bag_dense_backward_cpu
     CUDA: _embedding_bag_dense_backward_cuda
@@ -2682,13 +2686,13 @@
     CUDA: _fft_c2r_cufft_out
 
 # Standard complex to complex FFT (forward or backward)
-- func: _fft_c2c(Tensor self, int[] dim, int normalization, bool forward) -> Tensor
+- func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
   variants: function
   dispatch:
     CPU: _fft_c2c_mkl
     CUDA: _fft_c2c_cufft
 
-- func: _fft_c2c.out(Tensor self, int[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
+- func: _fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   dispatch:
     CPU: _fft_c2c_mkl_out
@@ -3342,10 +3346,12 @@
 - func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
   device_check: NoCheck   # TensorIterator
 
-- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor
+- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor
   variants: function
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: value_selecting_reduction_backward_symint
 
 - func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
   variants: function, method
@@ -6258,7 +6264,9 @@
 
 - func: sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint
 
 - func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size) -> ()
 
@@ -6273,7 +6281,7 @@
     SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse
   autogen: _sparse_coo_tensor_with_dims.out
 
-- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: _sparse_coo_tensor_with_dims_and_tensors(SymInt sparse_dim, SymInt dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
   dispatch:
     SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse
   autogen: _sparse_coo_tensor_with_dims_and_tensors.out
@@ -6897,7 +6905,9 @@
     CompositeExplicitAutograd: _pack_padded_sequence
   autogen: _pack_padded_sequence.out
 
-- func: _pack_padded_sequence_backward(Tensor grad, int[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor
+- func: _pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _pack_padded_sequence_backward_symint
 
 - func: _pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)
 
@@ -7745,10 +7755,12 @@
     CUDA: trace_cuda
   autogen: trace.out
 
-- func: trace_backward(Tensor grad, int[] sizes) -> Tensor
+- func: trace_backward(Tensor grad, SymInt[] sizes) -> Tensor
   variants: function
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: trace_backward_symint
 
 - func: ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
@@ -8119,10 +8131,12 @@
 - func: index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor
   variants: method, function
 
-- func: index_select_backward(Tensor grad, int[] self_sizes, int dim, Tensor index) -> Tensor
+- func: index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor
   variants: function
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: index_select_backward_symint
 
 - func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -8995,7 +9009,7 @@
     CPU, CUDA, Meta: unfold
     QuantizedCPU, QuantizedCUDA: unfold
 
-- func: unfold_backward(Tensor grad_in, int[] input_sizes, int dim, int size, int step) -> Tensor
+- func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
   variants: function
   dispatch:
     CPU, CUDA: unfold_backward
@@ -10533,7 +10547,7 @@
 - func: adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor
   python_module: nn
 
-- func: _adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor
+- func: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
   dispatch:
     CPU: adaptive_avg_pool3d_cpu
     CUDA: adaptive_avg_pool3d_cuda
@@ -13232,7 +13246,7 @@
   dispatch:
     CompositeExplicitAutograd: alias_copy_out
 
-- func: to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor
+- func: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
   variants: method
   dispatch:
     NestedTensorCPU: NestedTensor_to_padded_tensor_generic
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
index a66d02770f9..33dbfa8016d 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
@@ -3,6 +3,7 @@
 #include
 #include
+#include
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp
index e8eea30cc12..625f5b1c0b0 100644
--- a/aten/src/ATen/native/sparse/SparseTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -407,13 +408,21 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
       options.pinned_memory_opt());
 }
 
+Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, at::IntArrayRef size,
+    c10::optional<ScalarType> dtype,
+    c10::optional<Layout> layout,
+    c10::optional<Device> device,
+    c10::optional<bool> pin_memory) {
+  return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory);
+}
+
 // NOTE: _sparse_coo_tensor_unsafe() differs from sparse_coo_tensor()
 // in that we don't check whether any indices are out of boundaries of `size`, thus avoiding a
 // copy from CUDA to CPU. However, this function should ONLY be used where we know that the indices
 // are guaranteed to be within bounds or if the caller is going to call
 // _validate_sparse_coo_tensor_args before using the tensor.
 // NB: Got rid of the size == NULL case
-Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, IntArrayRef size,
+Tensor _sparse_coo_tensor_unsafe_symint(const Tensor& indices, const Tensor& values_, c10::SymIntArrayRef size,
     c10::optional<ScalarType> dtype,
     c10::optional<Layout> layout,
     c10::optional<Device> device,
@@ -422,10 +431,10 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, I
 
   Tensor values = expand_values_if_needed(values_);
 
-  int64_t sparse_dim = indices.size(0);
-  int64_t dense_dim = values.dim() - 1;
+  auto sparse_dim = indices.sym_size(0);
+  auto dense_dim = values.dim() - 1;
 
-  return at::_sparse_coo_tensor_with_dims_and_tensors(
+  return at::_sparse_coo_tensor_with_dims_and_tensors_symint(
       sparse_dim,
       dense_dim,
       size,
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
index c88a9c6abfd..34a864a8fae 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu
index 39dcb5baf42..68e5cc01019 100644
--- a/aten/src/ATen/native/transformers/cuda/attention.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention.cu
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 822e7b351d9..514a2e1968d 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -961,7 +961,6 @@ aot_autograd_failures = {
 
 symbolic_aot_autograd_failures = {
     xfail('__rmatmul__', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('addbmm', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('addcdiv', ''),  # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
     xfail('addmv', ''),  # aten.addmv.default - couldn't find symbolic meta function/decomposition
     xfail('addr', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1022,7 +1021,6 @@ symbolic_aot_autograd_failures = {
     xfail('hsplit', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('index_put', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('index_select', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('inner', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kron', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kthvalue', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1093,7 +1091,6 @@ symbolic_aot_autograd_failures = {
     xfail('matrix_exp', ''),  # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo...
     xfail('max', 'reduction_no_dim'),  # aten.logical_or_.default - couldn't find symbolic meta function/dec...
     xfail('max', 'reduction_with_dim'),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('mean', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('median', ''),  # could not find kernel
     xfail('meshgrid', 'list_of_tensors'),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('meshgrid', 'variadic_tensors'),  # Cannot call numel() on tensor with symbolic sizes/strides
@@ -1144,8 +1141,6 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.interpolate', 'linear'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'nearest'),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.interpolate', 'trilinear'),  # Cannot call sizes() on tensor with symbolic sizes/st...
-    xfail('nn.functional.kl_div', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('nn.functional.l1_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.local_response_norm', ''),  # aten.fill.Scalar - couldn't find symbolic meta functio...
     xfail('nn.functional.max_pool1d', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.max_pool2d', ''),  # aten.max_pool2d_with_indices_backward.default - couldn't find s...
@@ -1159,7 +1154,6 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.mse_loss', ''),  # Unable to cast Python instance to C++ type (#define PYBIND11_DETA...
     xfail('nn.functional.multi_margin_loss', ''),  # could not find kernel
     xfail('nn.functional.multilabel_margin_loss', ''),  # could not find kernel
-    xfail('nn.functional.multilabel_soft_margin_loss', ''),  # Cannot call sizes() on tensor with symbolic si...
     xfail('nn.functional.nll_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.normalize', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.pad', 'circular'),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1169,12 +1163,9 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.pdist', ''),  # could not find kernel
     xfail('nn.functional.pixel_shuffle', ''),  # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta...
-    xfail('nn.functional.poisson_nll_loss', ''),  # aten.add_.Tensor - couldn't find symbolic meta function/d...
     xfail('nn.functional.prelu', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.rrelu', ''),  # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
     xfail('nn.functional.smooth_l1_loss', ''),  # could not find kernel
-    xfail('nn.functional.triplet_margin_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('nn.functional.triplet_margin_with_distance_loss', ''),  # Cannot call sizes() on tensor with symbo...
     xfail('nn.functional.unfold', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.upsample_bilinear', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.upsample_nearest', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1236,8 +1227,6 @@ symbolic_aot_autograd_failures = {
     xfail('triangular_solve', ''),  # aten.triangular_solve.default - couldn't find symbolic meta function/de...
     xfail('unbind', ''),  # tensor_split() received an invalid combination of arguments - got (FakeTensor, torch...
     xfail('unflatten', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('unfold', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('unfold_copy', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('var', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('var_mean', ''),  # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('view_as_complex', ''),  # aten.view_as_complex.default - couldn't find symbolic meta function/deco...
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 039f24ab49b..12ec7028cdf 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5299,7 +5299,7 @@ for shape in [(1,), ()]:
         out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x)
 
         out = a.mean()
-        self.assertEqual(out.grad_fn._saved_self_sizes, a.shape)  # IntArrayRef -> Tuple[int]
+        self.assertEqual(out.grad_fn._saved_self_sym_sizes, a.shape)  # SymIntArrayRef -> Tuple[int]
 
         a = torch.ones(2, 2, requires_grad=True)
         out = a * a
@@ -5328,8 +5328,8 @@ for shape in [(1,), ()]:
 
         a = torch.ones(1, 3, 3, requires_grad=True)
         out = torch.addbmm(a.squeeze(0), a, a)
-        self.assertEqual(out.grad_fn._saved_batch1_argsize_0, 1)  # int64_t
-        self.assertEqual(out.grad_fn._saved_batch1_argsize_1, 3)  # int64_t
+        self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_0, 1)  # SymInt
+        self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_1, 3)  # SymInt
 
         a = torch.ones(1, 1, 3, 3, requires_grad=True)
         out = torch.nn.functional.unfold(a, 3)
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 4544171f6ef..3632f725bb6 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1190,7 +1190,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.dropout', ''),  # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.embedding_bag', ''),  # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun...
     xfail('nn.functional.embedding', ''),  # argument 'size' must be tuple of ints, but found element of type tor...
-    xfail('nn.functional.feature_alpha_dropout', 'with_train'),  # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.fractional_max_pool2d', ''),  # argument 'size' must be tuple of ints, but found element of t...
     xfail('nn.functional.fractional_max_pool3d', ''),  # argument 'size' must be tuple of ints, but found element of t...
     xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 91de86590f0..c0ea69edaab 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -220,8 +220,8 @@
 - name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   self: maybe_multiply(grad, beta.conj())
-  batch1: maybe_multiply(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj())
-  batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })), alpha.conj())
+  batch1: maybe_multiply(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj())
+  batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) })), alpha.conj())
   result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha)
 
 - name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
@@ -558,11 +558,11 @@
 - name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
   self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode)
   other: div_tensor_other_backward(grad, self, other, rounding_mode)
-  result: "rounding_mode.has_value() ? result.new_zeros(result.sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p"
+  result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p"
 
 - name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
   self: div_tensor_self_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type(), rounding_mode)
-  result: "rounding_mode.has_value() ? result.new_zeros(result.sizes()) : self_t / other"
+  result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other"
 
 - name: dot(Tensor self, Tensor tensor) -> Tensor
   self: grad * tensor.conj()
@@ -828,7 +828,7 @@
   result: at::_index_put_impl_(self_t, indices, values_t, accumulate, unsafe)
 
 - name: index_select(Tensor self, int dim, Tensor index) -> Tensor
-  self: index_select_backward(grad, self.sizes(), dim, index)
+  self: index_select_backward_symint(grad, self.sym_sizes(), dim, index)
   index: non_differentiable
   result: auto_linear
 
@@ -845,7 +845,7 @@
   self: non_differentiable
 
 - name: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
@@ -999,10 +999,10 @@
   result: linalg_lu_solve_jvp(result, LU_p, pivots, LU_t, B_t, left, adjoint)
 
 - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
-  LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.size(-2), LU_data.size(-1))
+  LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1))
   LU_pivots: non_differentiable
-  L: "LU_data_t.size(-2) >= LU_data_t.size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow(-1, 0, LU_data_t.size(-2)).tril(-1)"
-  U: "LU_data_t.size(-1) >= LU_data_t.size(-2) ? LU_data_t.triu() : LU_data_t.narrow(-2, 0, LU_data_t.size(-1)).triu()"
+  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)"
+  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()"
   output_differentiability: [False, True, True]
 
 - name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
@@ -1018,7 +1018,7 @@
 - name: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
   self: grad.masked_fill(mask, 0)
-  source: masked_scatter_backward(grad, mask, source.sizes())
+  source: masked_scatter_backward(grad, mask, source.sym_sizes())
   mask: non_differentiable
   result: self_t.masked_scatter(mask, source_t)
 
@@ -1032,7 +1032,7 @@
   result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false)
 
 - name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
-  self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
   values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
 
 - name: max(Tensor self) -> Tensor
@@ -1050,7 +1050,7 @@
   result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
 
 - name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
-  self: grad.expand(self.sizes()) / self.numel()
+  self: grad.expand_symint(self.sym_sizes()) / self.sym_numel()
   result: auto_linear
 
 - name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
@@ -1080,15 +1080,15 @@
 # The backward implementation is correct in the sense that it returns the
 # subgradient on one side.
- name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - name: min(Tensor self) -> Tensor @@ -1119,7 +1119,7 @@ result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t) - name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), keepdim) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) - name: mul.Tensor(Tensor self, Tensor other) -> Tensor @@ -1232,16 +1232,16 @@ result: self_t.zero_() - name: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor - mean: at::zeros(mean.sizes(), grad.options()) + mean: at::zeros_symint(mean.sym_sizes(), grad.options()) result: auto_element_wise - name: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor - std: at::zeros(std.sizes(), grad.options()) + std: at::zeros_symint(std.sym_sizes(), grad.options()) result: auto_element_wise - name: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor - mean: at::zeros(mean.sizes(), grad.options()) - std: at::zeros(std.sizes(), grad.options()) + mean: at::zeros_symint(mean.sym_sizes(), grad.options()) + std: at::zeros_symint(std.sym_sizes(), grad.options()) result: zeros_like(mean_t) - name: linalg_householder_product(Tensor input, Tensor tau) -> Tensor @@ -1463,12 +1463,12 @@ output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user - name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) output_differentiability: [True, False] values: gather_with_keepdimed_indices(self_t, dim, indices, true) - name: sort.stable(Tensor self, *, bool? 
stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) output_differentiability: [True, False] values: gather_with_keepdimed_indices(self_t, dim, indices, true) @@ -1560,12 +1560,12 @@ # We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here - name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) - A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow(-1, 0, S.size(-1)) : grad_U, + A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow_symint(-1, 0, S.sym_size(-1)) : grad_U, grad_S, - full_matrices && grad_Vh.defined() ? grad_Vh.narrow(-2, 0, S.size(-1)) : grad_Vh, - full_matrices ? U.narrow(-1, 0, S.size(-1)) : U, + full_matrices && grad_Vh.defined() ? grad_Vh.narrow_symint(-2, 0, S.sym_size(-1)) : grad_Vh, + full_matrices ? U.narrow_symint(-1, 0, S.sym_size(-1)) : U, S, - full_matrices ? Vh.narrow(-2, 0, S.size(-1)) : Vh)" + full_matrices ? Vh.narrow_symint(-2, 0, S.sym_size(-1)) : Vh)" U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices) - name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) @@ -1616,12 +1616,12 @@ result: auto_element_wise - name: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) - self: value_selecting_reduction_backward(grad, dim, indices, self.sizes(), true) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) output_differentiability: [True, False] values: gather(self_t, dim, indices) - name: trace(Tensor self) -> Tensor - self: trace_backward(grad, self.sizes()) + self: trace_backward_symint(grad, self.sym_sizes()) result: auto_linear - name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) @@ -1670,10 +1670,10 @@ self: to_mkldnn_backward(grad, self) - name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) - self: unfold_backward(grad, self.sizes(), dimension, size, step) + self: unfold_backward_symint(grad, self.sym_sizes(), dimension, size, step) result: auto_linear -- name: unfold_backward(Tensor grad_in, int[] input_sizes, int dim, int size, int step) -> Tensor +- name: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor grad_in: grad.unfold(dim, size, step) result: auto_linear @@ -1772,7 +1772,7 @@ self: grad.to_dense().sparse_mask(mask).to_dense() mask: non_differentiable -- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor +- name: _sparse_coo_tensor_with_dims_and_tensors(SymInt sparse_dim, SymInt dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor values: sparse_constructor_values_backward(grad, indices) - name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor @@ -1787,7 +1787,7 @@ - name: values(Tensor(a) self) -> Tensor(a) dispatch: Default: - self: at::_sparse_coo_tensor_unsafe(self.indices(), grad, self.sizes())._coalesced_(true) + self: at::_sparse_coo_tensor_unsafe_symint(self.indices(), grad, self.sym_sizes())._coalesced_(true) AutogradNestedTensor: self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_offsets()) @@ -1844,10 +1844,10 @@ - name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) indices: non_differentiable offsets: non_differentiable - weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx) + weight: _embedding_bag_backward_symint(grad, indices, offsets, result1, result2, result3, weight.sym_size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx) per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode, padding_idx) -- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, int num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor +- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor indices: non_differentiable offset2bag: non_differentiable bag_size: non_differentiable @@ -2147,7 +2147,7 @@ self: _adaptive_avg_pool2d_backward(grad, self) result: auto_linear -- name: _adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor +- name: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor self: _adaptive_avg_pool3d_backward(grad, self) result: auto_linear @@ -2220,7 +2220,7 @@ # Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context # by convolution_backward instead of being passed along from the forward pass. - name: _convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor - input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32) - name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? 
bias_sizes, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) @@ -2236,10 +2236,10 @@ grad_output, input, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) - name: slow_conv_transpose2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int[2] dilation=1) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" - name: slow_conv_transpose3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int[3] dilation=1) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" - name: _slow_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> Tensor self, weight, bias: "grad.defined() ? _slow_conv2d_backward(grad, self, weight, kernel_size, stride, padding, grad_input_mask) : std::tuple()" @@ -2248,19 +2248,19 @@ grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask) - name: _conv_depthwise2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding, int[2] dilation) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad.contiguous(), self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" - name: conv_depthwise3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding, int[3] dilation) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad.contiguous(), self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" - name: slow_conv3d_forward(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias, int[3] stride, int[3] padding) -> Tensor - self, weight, bias: "grad.defined() ? 
convolution_backward(grad, self, weight, bias->sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple()" - name: slow_conv_dilated2d(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" - name: slow_conv_dilated3d(Tensor self, Tensor weight, int[3] kernel_size, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" - name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor self: im2col(grad, kernel_size, dilation, padding, stride) @@ -2276,7 +2276,7 @@ result: _adaptive_avg_pool2d_backward(grad_output_t, self_p) - name: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor - grad_output: _adaptive_avg_pool3d(grad, { grad_output.size(-3), grad_output.size(-2), grad_output.size(-1) }) + grad_output: _adaptive_avg_pool3d_symint(grad, { grad_output.sym_size(-3), grad_output.sym_size(-2), grad_output.sym_size(-1) }) self: zeros_like(self) result: _adaptive_avg_pool3d_backward(grad_output_t, self_p) @@ -2616,7 +2616,7 @@ - name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding, int[2] stride=1) -> Tensor # NNPACK does not support strided convolutions in the backwards path, which is the reason why we are using the closest available function that does here. - input, weight, bias: "grad.defined() ? convolution_backward(grad, input, weight, bias->sizes(), stride, padding, std::vector(padding.size(), 1), false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector(padding.size(), 1), false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" #LSTM MPS - name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor) @@ -2648,13 +2648,13 @@ # miopen - name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - self, weight, bias: "grad.defined() ? 
convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple()" - name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" - name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" - name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" @@ -2674,7 +2674,7 @@ # mkldnn - name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor - self, weight, bias: "grad.defined() ? convolution_backward(grad, self, weight, bias->sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" - name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask) @@ -2689,21 +2689,21 @@ self: mkldnn_adaptive_avg_pool2d_backward(grad, self) - name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor - self: grad.reshape(self.sizes()) + self: grad.reshape_symint(self.sym_sizes()) # Nested Tensor - name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor list: "grad.defined()? 
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index bb74b56d12a..5d8b77ebcdb 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -196,6 +196,11 @@ def meta_pad2d(self, padding):
     return self.new_empty((nbatch, nplane, output_h, output_w))


+@register_meta(aten.bernoulli_.float, register_dispatcher=False)
+def meta_bernoulli_(self, p=0.5, generator=None):
+    return self
+
+
 def dot_check(self, other):
     check(
         self.dim() == 1 and other.dim() == 1,
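The new meta kernel works because bernoulli_ mutates self in place without changing shape, dtype, or device, so at meta time there is nothing to compute and returning self suffices. The same one-liner would serve any metadata-preserving in-place op; the op below is a hypothetical illustration (not part of this diff) that reuses the file's own register_meta/aten imports:

    # Hypothetical sketch: an in-place RNG op leaves metadata untouched,
    # so its meta kernel is the identity on `self`.
    @register_meta(aten.exponential_.default, register_dispatcher=False)
    def meta_exponential_(self, lambd=1.0, generator=None):
        return self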
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index b1e43545857..0c93497bce0 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1622,21 +1622,21 @@ Tensor std_mean_backward(
 Tensor masked_scatter_backward(
     const Tensor& grad,
     const Tensor& mask,
-    IntArrayRef sizes) {
-  int64_t numel = 1;
+    c10::SymIntArrayRef sizes) {
+  c10::SymInt numel = 1;
   for (auto size : sizes) {
     numel *= size;
   }
   auto mask_selected = grad.masked_select(mask);
-  auto diff_nelem = numel - mask_selected.numel();
+  auto diff_nelem = numel - mask_selected.sym_numel();
   if (diff_nelem > 0) {
     // because mask_selected returns a 1-d tensor with size of masked elements
     // that are 1, we need to fill out the rest with zeros then reshape back to
     // tensor2's size.
-    auto zeros_fillin = at::zeros({diff_nelem}, grad.options());
+    auto zeros_fillin = at::zeros_symint({diff_nelem}, grad.options());
     mask_selected = at::cat({mask_selected, zeros_fillin}, 0);
   }
-  return mask_selected.view(sizes);
+  return mask_selected.view_symint(sizes);
 }

 Tensor cholesky_jvp(const Tensor& dA, const Tensor& L, bool upper) {
@@ -4371,10 +4371,10 @@ Tensor fft_c2r_backward(

 Tensor fft_r2c_backward(
     const Tensor& grad,
-    IntArrayRef dim,
+    at::IntArrayRef dim,
     int64_t normalization,
     bool onesided,
-    int64_t last_dim_size) {
+    c10::SymInt last_dim_size) {
   if (!onesided) {
     return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
   }
@@ -4389,16 +4389,17 @@ Tensor fft_r2c_backward(
   // (C2C ifft only take twosided inputs so we need to fill here)
   // 2. inverse C2C ifft
   // 3. discard the complex dim
-  auto half_sizes = grad.sizes();
-  at::DimVector new_grad_shape(half_sizes.begin(), half_sizes.end());
+  auto half_sizes = grad.sym_sizes();
+  std::vector<c10::SymInt> new_grad_shape(half_sizes.begin(), half_sizes.end());
   const auto last_dim = at::maybe_wrap_dim(dim.back(), half_sizes.size());
   new_grad_shape[last_dim] = last_dim_size;

-  const auto zero_length = last_dim_size - grad.size(dim.back());
+  const auto zero_length = last_dim_size - grad.sym_size(dim.back());
   auto complex_full_grad =
-      zero_length > 0 ? grad.new_zeros(new_grad_shape) : grad;
+      zero_length > 0 ? grad.new_zeros_symint(new_grad_shape) : grad;
   if (zero_length > 0) {
-    complex_full_grad.slice(last_dim, 0, half_sizes[last_dim]).copy_(grad);
+    complex_full_grad.slice_symint(last_dim, 0, half_sizes[last_dim])
+        .copy_(grad);
   }
   return at::real(
       at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false));
@@ -5622,8 +5623,8 @@ Tensor solve_jvp(
 Tensor lu_unpack_backward(
     const Tensor& L_grad,
     const Tensor& U_grad,
-    const int64_t m,
-    const int64_t n) {
+    const c10::SymInt m,
+    const c10::SymInt n) {
   if (!L_grad.defined() && !U_grad.defined()) {
     return {};
   }
@@ -5631,16 +5632,16 @@ Tensor lu_unpack_backward(

   // Getters for the principal and complementary part of the matrices
   const auto get_L1 = [m, k](const Tensor& L) {
-    return m == k ? L.tril(-1) : L.narrow(-2, 0, k).tril(-1);
+    return m == k ? L.tril(-1) : L.narrow_symint(-2, 0, k).tril(-1);
   };
   const auto get_L2 = [m, k](const Tensor& L) {
-    return L.narrow(-2, k, m - k);
+    return L.narrow_symint(-2, k, m - k);
   };
   const auto get_U1 = [n, k](const Tensor& U) {
-    return n == k ? U.triu() : U.narrow(-1, 0, k).triu();
+    return n == k ? U.triu() : U.narrow_symint(-1, 0, k).triu();
   };
   const auto get_U2 = [n, k](const Tensor& U) {
-    return U.narrow(-1, k, n - k);
+    return U.narrow_symint(-1, k, n - k);
   };

   if (L_grad.defined()) {
@@ -5657,20 +5658,22 @@ Tensor lu_unpack_backward(
     if (m >= n) {
       return L_grad.tril(-1);
     } else {
-      auto size = L_grad.sizes().vec();
+      auto size = L_grad.sym_sizes().vec();
       size.end()[-1] = n - m;
       return at::cat(
-          {L_grad.tril(-1), at::zeros(size, L_grad.options())}, /*dim=*/-1);
+          {L_grad.tril(-1), at::zeros_symint(size, L_grad.options())},
+          /*dim=*/-1);
     }
   } else {
     if (n >= m) {
       return U_grad.triu();
     } else {
-      auto size = U_grad.sizes().vec();
+      auto size = U_grad.sym_sizes().vec();
       size.end()[-2] = m - n;
       return at::cat(
-          {U_grad.triu(), at::zeros(size, U_grad.options())}, /*dim=*/-2);
+          {U_grad.triu(), at::zeros_symint(size, U_grad.options())},
+          /*dim=*/-2);
     }
   }
 }
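For reference, the shape logic that masked_scatter_backward now performs symbolically can be replayed in eager Python with concrete sizes (an illustrative sketch, not code from this diff):

    import torch

    grad = torch.arange(6.0).reshape(2, 3)
    mask = torch.tensor([[True, False, True], [False, False, True]])
    source_sizes = (4,)  # sizes of the source tensor fed to masked_scatter

    selected = grad.masked_select(mask)  # 1-d, one value per True in mask
    pad = torch.Size(source_sizes).numel() - selected.numel()
    if pad > 0:
        # fewer True entries than source elements: zero-fill the remainder
        selected = torch.cat([selected, selected.new_zeros(pad)])
    print(selected.view(source_sizes))  # tensor([0., 2., 5., 0.])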
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 8b83a0c22b1..562b015983f 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -356,7 +356,7 @@ Tensor std_mean_backward(
 at::Tensor masked_scatter_backward(
     const at::Tensor& grad,
     const at::Tensor& mask,
-    at::IntArrayRef sizes);
+    c10::SymIntArrayRef sizes);
 at::Tensor cholesky_backward(
     const at::Tensor& grad,
     bool upper,
@@ -668,10 +668,10 @@ Tensor fft_backward(
     IntArrayRef output_sizes);
 Tensor fft_r2c_backward(
     const Tensor& grad,
-    IntArrayRef dim,
+    at::IntArrayRef dim,
     int64_t normalization,
     bool onesided,
-    int64_t last_dim_size);
+    c10::SymInt last_dim_size);
 Tensor fft_c2r_backward(
     const Tensor& grad,
     IntArrayRef dim,
@@ -824,8 +824,8 @@ Tensor linalg_solve_jvp(
 Tensor lu_unpack_backward(
     const Tensor& L_grad,
     const Tensor& U_grad,
-    const int64_t m,
-    const int64_t n);
+    const c10::SymInt m,
+    const c10::SymInt n);
 Tensor linalg_det_backward(
     const Tensor& grad,
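The header changes just keep the declarations in sync with the new definitions; SymInt is passed by value since it is a small handle type. As a final illustration, the zero-filling step that fft_r2c_backward performs on onesided gradients looks like this in eager Python with concrete sizes (an illustrative sketch, not code from this diff):

    import torch

    grad = torch.randn(4, 5, dtype=torch.complex64)  # onesided grad, last dim = 8 // 2 + 1
    last_dim_size = 8                                # length of the original real signal
    zero_length = last_dim_size - grad.size(-1)      # columns dropped by the R2C transform

    full = grad.new_zeros(4, last_dim_size)          # twosided-width complex buffer
    full[..., : grad.size(-1)] = grad                # mirrors slice_symint(...).copy_(grad)
    out = torch.real(torch.fft.ifft(full, dim=-1))   # inverse C2C, then drop the complex dim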