#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

// Helper functions for autogenerated code
// These used to be inlined into the codegened Functions.cpp

namespace torch {
namespace autograd {
namespace generated {
namespace details {

using at::Tensor;
using at::Scalar;
using at::IntArrayRef;
using at::TensorList;

bool isDefined(const c10::optional<Tensor>& t) {
  return t.has_value() && t->defined();
}

bool isFwGradDefined(const c10::optional<Tensor>& t) {
  return t.has_value() && t->defined() && t->fw_grad(/*level */ 0).defined();
}

Tensor toLegacyTensor(const c10::optional<Tensor>& t) {
  return t.has_value() ? *t : Tensor();
}

Tensor toLegacyFwGrad(const c10::optional<Tensor>& t) {
  return (t.has_value() && t->defined()) ? t->fw_grad(/*level */ 0) : Tensor();
}

Tensor toLegacyPrimal(const c10::optional<Tensor>& t) {
  return (t.has_value() && t->defined()) ? t->_fw_primal(/*level */ 0) : Tensor();
}

void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output");
  out[range.first] = t;
}

void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output");
  std::copy(t.begin(), t.end(), out.begin() + range.first);
}

Tensor copysign_tensor_self_backward(const Tensor & grad, const Tensor & self, const Tensor & result) {
  auto ratio = result / self;
  ratio.masked_fill_(self == 0, 0);
  return grad * ratio;
}

Tensor not_implemented(const char* name) {
  throw std::runtime_error(
      std::string("the derivative for '") + name + "' is not implemented");
}

Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
  bool is_one = false;
  if (s.isFloatingPoint()) {
    is_one = s.toDouble() == 1;
  } else if (s.isIntegral(true)) {
    is_one = s.toLong() == 1;
  }

  if (is_one) {
    return t;
  } else {
    return t * s;
  }
}

int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
  int64_t size = 1;
  if (sizes.size() == 0) {
    return 1;
  }
  for (auto d : dim) {
    d = at::maybe_wrap_dim(d, sizes.size());
    size *= sizes[d];
  }
  return size;
}

static Tensor wrapped_scalar_tensor(Scalar scalar) {
  auto tensor = scalar_to_tensor(scalar);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
  return tensor;
}

Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
  if (!at::isComplexType(self_st) && gradient_result.is_complex()) {
    // R -> C
    return at::real(gradient_result);
  }
  return gradient_result;
}

Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
  if (!self.is_complex() && gradient_result.is_complex()) {
    // R -> C
    return at::real(gradient_result);
  }
  return gradient_result;
}

Tensor restore_reduced_dims(const Tensor &output, IntArrayRef dims, bool keepdim) {
  if (keepdim) {
    return output;
  }
  int64_t total_dims = output.dim() + dims.size();
  std::vector<int64_t> target_shape(total_dims, 0);
  for (int64_t i : dims) {
    if (i < 0) {
      i = total_dims + i;
    }
    target_shape[i] = 1;
  }
  int64_t j = 0;
  for (int64_t i : output.sizes()) {
    while (target_shape[j] > 0) j++;
    target_shape[j++] = i;
  }
  return output.reshape(target_shape);
}

Tensor scale_grad_by_count(const Tensor &grad, const Tensor &mask, IntArrayRef dims) {
  return (grad / mask.sum(dims, true)) * mask;
}

std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res) {
  if (!grad.defined()) {
    return
std::tuple(Tensor(), Tensor()); } // handle case at 0 where we return a subgradient containing 0 Tensor ratio = grad / res; ratio.masked_fill_(res == 0, 0); return std::tuple{ x1 * ratio.sum(-1, true) - ratio.matmul(x2), x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.transpose(-2, -1).matmul(x1)}; } Tensor norm_backward(const Tensor& grad, const Tensor& self, const optional & p_, const Tensor& norm) { return norm_backward(grad, self, p_, norm, {}, true); } Tensor norm_backward(Tensor grad, const Tensor& self, const optional & p_, Tensor norm, IntArrayRef dim, bool keepdim) { size_t ndim = self.sizes().size(); double p = p_.value_or(2.0).toDouble(); Tensor self_scaled; Tensor scale_v; if (!keepdim && self.dim() != 0) { grad = unsqueeze_multiple(grad, dim, ndim); norm = unsqueeze_multiple(norm, dim, ndim); } if (p == 0.0) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else if (p == 1.0) { return self.sgn() * grad; } else if (p == 2.0) { self_scaled = self; scale_v = grad / norm; } else if (std::isinf(p)) { Tensor is_eq_max = (self.abs() == norm).logical_or_(self.isnan().logical_and_(norm.isnan())).type_as(self); self_scaled = self.sign() * is_eq_max; Tensor nb_max = is_eq_max.count_nonzero(dim); if (self.dim() != 0) { nb_max = unsqueeze_multiple(nb_max, dim, ndim); } scale_v = grad / nb_max; } else if (p < 2.0) { self_scaled = self.sgn() * self.abs().pow(p - 1); scale_v = grad / norm.pow(p - 1); } else { self_scaled = self * self.abs().pow(p - 2); scale_v = grad / norm.pow(p - 1); } // handle case at 0 where we return a subgradient containing 0 scale_v.masked_fill_(norm == 0, 0); return self_scaled * scale_v; } Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) { if (exponent.equal(0.0)) { return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } else { auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); }; Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble()); return handle_r_to_c(self, out); } } Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) { auto out = at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * (exponent * self.pow(exponent - 1)).conj()); return handle_r_to_c(self, out); } // Caveats: // We define d(a^b)/db at a = 0 and b < 0 to be -inf. This is due to // d(a^b)/db -> -inf for a fixed b as a -> +0 // Currently, tensorflow defines d(a^b)/db = nan for a = 0 and b < 0. // // We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as // d(a^b)/db = 0 for a > 0 and b -> +0. // Currently, tensorflow agrees with us. 
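// Worked limit behind the first convention above (illustrative): since
// d(a^b)/db = a^b * ln(a), for a fixed b < 0 (e.g. b = -1, giving ln(a) / a)
// we have a^b -> +inf and ln(a) -> -inf as a -> +0, so the derivative tends
// to -inf, which is the value assigned at a = 0, b < 0.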
Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) { Tensor cond; if (exponent.is_complex()) { auto is_real_exp = at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0); cond = at::logical_and(self == 0, is_real_exp); } else { cond = at::logical_and(self == 0, exponent >= 0); } auto out = grad * at::where(cond, at::zeros({}, grad.options()), (result * self.log()).conj()); return handle_r_to_c(exponent, out); } Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) { auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); }; if (base.equal(0.0)) { auto cond = [](auto exp) { if (exp.is_complex()) { return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0); } else { return exp >=0; } }; auto out = grad * at::where(cond(exponent), at::zeros({}, grad.options()), grad_lambda(result, base)); return handle_r_to_c(exponent, out); } else { auto out = grad * grad_lambda(result, base); return handle_r_to_c(exponent, out); } } Tensor angle_backward(Tensor grad, const Tensor& self) { if (self.is_complex()) { return at::where(self == 0.0, at::zeros({}, self.options()), grad * self / self.abs().pow(2) * Scalar(c10::complex{0.0, 1.0})); } else { return at::zeros_like(self, at::MemoryFormat::Preserve); } } Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) { Tensor args = at::arange(-p / 2. + 0.5, 0.5, 0.5, self.options()); args = args.add(self.unsqueeze(-1)); return grad * args.digamma_().sum(-1); } Tensor sgn_backward(Tensor result, Tensor grad, Tensor self) { if (self.is_complex()) { auto abs = at::abs(self); // C -> C // https://arxiv.org/pdf/1701.00392.pdf Section 4.20 return at::where(abs == 0.0, at::zeros({}, grad.options()), (grad/abs - (at::real(grad/self) * result))); } else { return at::zeros_like(self, at::MemoryFormat::Preserve); } } Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) { auto out = grad * other.conj(); return handle_r_to_c(self_st, out); } Tensor div_tensor_self_backward(Tensor grad, Tensor other, ScalarType self_st) { auto result = grad / other.conj(); return handle_r_to_c(self_st, result); } Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) { auto result = -grad * ((self / other) / other).conj(); return handle_r_to_c(other, result); } Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) { // invert the permutation auto ndims = fwd_dims.size(); std::vector dims(ndims); for (size_t i = 0; i < ndims; i++) { dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i; } return grad.permute(dims); } Tensor rad2deg_backward(const Tensor& grad) { constexpr double M_180_PI = 57.295779513082320876798154814105170332405472466564; return at::mul(grad, wrapped_scalar_tensor(Scalar(M_180_PI))); } Tensor deg2rad_backward(const Tensor& grad) { constexpr double M_PI_180 = 0.017453292519943295769236907684886127134428718885417; return at::mul(grad, wrapped_scalar_tensor(Scalar(M_PI_180))); } Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) { auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims); Tensor res = t; for (size_t i = 0; i < n_dims; i++){ if (dims_to_unsqueeze[i]) { res = res.unsqueeze(i); } } return res; } Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) { if (!keepdim && sizes.size() > 0) { if (dims.size()==1) { return grad.unsqueeze(dims[0]).expand(sizes); } else { Tensor res = unsqueeze_multiple(grad, dims, 
sizes.size()); return res.expand(sizes); } } else { return grad.expand(sizes); } } Tensor nansum_backward(const Tensor & grad, const Tensor & self, IntArrayRef dims, bool keepdim) { auto sizes = self.sizes(); if (!keepdim && sizes.size() > 0) { if (dims.size()==1) { return grad.unsqueeze(dims[0]).expand(sizes) * self.isnan().logical_not(); } else { Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); return res.expand(sizes) * self.isnan().logical_not(); } } else { return grad.expand(sizes) * self.isnan().logical_not(); } } std::vector reverse_list(const IntArrayRef list) { auto result = std::vector(); result.reserve(list.size()); for (auto iter = list.rbegin(); iter != list.rend(); iter++) { result.push_back(*iter); } return result; } Tensor reverse_dim(const Tensor& t, int64_t dim) { Tensor index = at::arange(t.size(dim) - 1, -1, -1, t.options().dtype(at::kLong)); return t.index_select(dim, index); } Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) { if (inp.size(dim) == 1) { return grad; } auto ones_size = inp.sizes().vec(); ones_size[dim] = 1; Tensor ones = at::ones(ones_size, grad.options()); Tensor exclusive_normal_nocp = at::cat({ones, inp.narrow(dim, 0, inp.size(dim) - 1)}, dim); Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim); Tensor narrow_reverse = reverse_dim(inp.narrow(dim, 1, inp.size(dim) - 1), dim); Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim); Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim); return grad * (exclusive_normal * exclusive_reverse); } // note that the gradient for prod is equivalent to: // cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.: // input: [ a, b, c] // cumprod(exclusive, normal): [1 , a, a * b] // cumprod(exclusive, reverse): [b * c, c, 1] // product: [b * c, a * c, a * b] // and this is safe under input with 0s. 
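// A small numeric check of the identity above (illustrative values):
//   input:                       [ 2, 0, 3]
//   cumprod(exclusive, normal):  [ 1, 2, 0]
//   cumprod(exclusive, reverse): [ 0, 3, 1]
//   elementwise product:         [ 0, 6, 0]
// which equals [b*c, a*c, a*b], the gradient of prod(input) w.r.t. each
// element, without ever dividing by the zero entry.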
Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) {
  if (input.dim() == 0) {
    return grad;
  }
  Tensor zero_idx = (input == 0).nonzero();
  if (zero_idx.numel() == 0) {
    return (grad * result) / input;
  } else if (zero_idx.size(0) > 1) {
    return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  } else {
    return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input);
  }
}

Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) {
  if (input.dim() == 0) {
    return grad;
  }
  dim = at::maybe_wrap_dim(dim, input.sizes().size());
  if (!keepdim && input.dim() != 1) {
    grad = grad.unsqueeze(dim);
    result = result.unsqueeze(dim);
  }

  Tensor zero_mask = (input == 0);
  Tensor slice_zero_count = zero_mask.sum(dim, true);
  int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
  if (total_zeros == 0) {
    return (grad * result) / input;
  } else {
    return prod_safe_zeros_backward(grad, input, dim);
  }
}

Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
  return at::linalg_solve(A.conj().transpose(-2, -1), grad);
}

Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
  Tensor grad_self = solve_backward_self(grad, self, A);
  if (self.ndimension() == 2 && A.ndimension() == 2) {
    return -at::mm(grad_self, solution.conj().transpose(-2, -1));
  }
  // if self was unsqueezed from (..., M) to (..., M, 1)
  auto batched_rhs_shape = IntArrayRef(A.sizes().data(), A.dim() - 1);  // A.shape[:-1]
  bool is_rhs_broadcasted = self.dim() == 1 ||
      (A.dim() - 1 == self.dim() && self.sizes().equals(batched_rhs_shape));
  if (is_rhs_broadcasted) {
    return -at::matmul(grad_self.unsqueeze(-1), solution.unsqueeze(-1).conj().transpose(-2, -1));
  }
  return -at::matmul(grad_self, solution.conj().transpose(-2, -1));
}

Tensor cumsum_backward(const Tensor & x, int64_t dim) {
  // Need to check numel to see if there are no values (such as shape [0, 2]), and dim to see if x is a scalar.
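  // Illustrative walk-through of the computation below (x here is the
  // incoming gradient): for x = [1, 2, 3] along dim 0,
  //   ret     = cumsum(-x)        = [-1, -3, -6]
  //   ret_sum = last slice of ret = [-6]
  //   ret -= ret_sum              -> [ 5,  3,  0]
  //   ret += x                    -> [ 6,  5,  3]
  // i.e. ret[i] = sum_{j >= i} x[j], the reverse cumulative sum, which is the
  // gradient of cumsum with respect to its input.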
if (x.dim() == 0 || x.numel() == 0) { return x; } auto ret = at::cumsum(-x, dim); auto ret_sum = ret.narrow(dim, ret.size(dim) - 1, 1).clone(at::MemoryFormat::Preserve); ret -= ret_sum.expand(ret.sizes()); ret += x; return ret; } Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) { if (!keepdim && self.dim() != 0) { grad = unsqueeze_multiple(grad, dim, self.sizes().size()); result = unsqueeze_multiple(result, dim, self.sizes().size()); } return grad * (self - result).exp(); } Tensor logcumsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim) { if (grad.dim() == 0 || grad.numel() == 0) { return grad; } // Reference: https://github.com/tensorflow/tensorflow/blob/ // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863 return AT_DISPATCH_FLOATING_TYPES( at::typeMetaToScalarType(grad.dtype()), "logcumsumexp_backward", [grad, self, result, dim]() { auto grad_min = at::empty_like(grad); grad_min.fill_(std::numeric_limits::lowest()); auto log_grad_positive = at::where(grad > 0, grad.log(), grad_min); auto log_grad_negative = at::where(grad < 0, (-grad).log(), grad_min); auto reverse_logcumsumexp = [dim](auto x) { return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim}); }; auto output_pos = (reverse_logcumsumexp(log_grad_positive - result) + self).exp(); auto output_neg = (reverse_logcumsumexp(log_grad_negative - result) + self).exp(); return output_pos - output_neg; }); } Tensor unbind_backward(const variable_list& grads, int64_t dim) { IntArrayRef sizes; at::TensorOptions o; for (auto v : grads) { if (v.defined()) { sizes = v.sizes(); o = static_cast(v).options(); break; } } auto grads_tensors = fmap(grads, [&](const Variable& v) { return ( v.defined() ? static_cast(v) : at::zeros({}, o).expand(sizes)); }); return at::stack(grads_tensors, dim); } Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) { auto result = self; int64_t nDims = sizes.size(); for (int64_t dim = 0; dim < nDims; dim++) { if (sizes[dim] == 1) { result = result.unsqueeze(dim); } } return result; } Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) { dim = at::maybe_wrap_dim(dim, sizes.size()); // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided // unsqueezing in the backward. if (sizes.size() > 0 && sizes[dim] == 1) { return self.unsqueeze(dim); } return self; } std::vector cat_tensors_backward(const Tensor & grad, const std::vector> &sizes, int64_t dim) { std::vector grad_inputs(sizes.size()); if (!grad.defined()) { return grad_inputs; } dim = at::legacy_cat_wrap_dim(dim, sizes); int64_t accumulate = 0; for (size_t i = 0; i < sizes.size(); ++i) { auto& shape = sizes[i]; // If input was empty tensor, gradInput should be empty tensor. if (shape == std::vector({0})) { grad_inputs[i] = at::zeros({0}, grad.options()); continue; } auto size = shape[dim]; accumulate += size; grad_inputs[i] = grad.narrow(dim, accumulate - size, size); } return grad_inputs; } Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional & min, const optional & max) { // clamp: gradients not defined on min and max, so we return the subgradient 1 for these cases. 
if (max && min) { return grad * ((self >= *min) * (self <= *max)).type_as(grad); } else if (min) { return grad * (self >= *min).type_as(grad); } else if (max) { return grad * (self <= *max).type_as(grad); } else { return grad; } } // This function is used by load_derivatives.py to replace tensor.strides() // calls that appear in derivative formulas. If the tensor has requires_grad // set, this function returns its strides or throws an error if the tensor // is sparse. If requires_grad is not set, an empty array is returned since // there will be no backward pass. // // This function only supports the case where `input` is the tensor whose // single derivative is being calculated. // // This function does not support `self` derivatives for inplace functions. // // Args: // input Tensor to call .strides() on // input_name Name of `input` tensor, from derivative formula at::IntArrayRef strides_or_error(const Tensor & input, c10::string_view const & input_name) { // TODO: Ideally, this function would never be called if requires_grad is // not set. Once codegen is updated to avoid the call, we can remove this // check. if (input.requires_grad()) { TORCH_CHECK( !input.is_sparse(), "The backward pass for this operation requires the '", input_name, "' tensor to be strided, but a sparse tensor was given instead. ", "Please either use a strided tensor or set requires_grad=False for '", input_name, "'"); return input.strides(); } else { return IntArrayRef({}); } } Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides, const Scalar & alpha) { // if input was column-major, return grad as column-order for efficiency if (mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0]) { return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha); } else { return maybe_multiply(grad.mm(mat2.t().conj()), alpha); } } Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef sizes, IntArrayRef strides, const Scalar & alpha) { // if input was column-major, return grad as column-order for efficiency if (strides[0] == 1 && strides[1] == sizes[0]) { if (mat1.is_sparse()) { // Since mm(dense, sparse) doesn't exist, // pass a transposed output matrix to the underlying "addmm" // function directly. int64_t out_rows = mat1.size(1); int64_t out_cols = grad.size(1); Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true); Tensor r = at::empty({out_cols, out_rows}, grad.options()).t(); at::addmm_out(r, t, mat1.t(), grad, alpha, 1); return r; } return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha); } else { return maybe_multiply(mat1.t().conj().mm(grad), alpha); } } Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse_, const Tensor& dense, const Scalar& alpha) { AT_ASSERT(sparse_.is_sparse()); auto sparse = sparse_.coalesce(); Tensor grad_sparse = maybe_multiply(grad.mm(dense.t()), alpha); return grad_sparse.sparse_mask(sparse); } // This function return a new SparseTensor with values from Tensor `input` filtered by indices of `mask` // and values are ignored. `input` and `mask` are sparse matrices, a sparse tensor with sparse_dim=2 and dense_dim=2, // and they must have the same shape. // Note that the `output` must have the same `indices` as the `mask` so we are using just a clone. // However, to get `values` we have to use specific helper function for CPU/CUDA and use the `mask` data to filter `values` // That's why we created this `_sparse_matrix_mask_helper` function. 
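// Illustrative example of the masking described above (made-up values, with
// sparse_dim = 2 and scalar values for brevity):
//   input:  indices = [[0, 1, 1], [0, 0, 2]], values = [10, 20, 30]
//   mask:   indices = [[0, 1],    [0, 2]],    values ignored
//   output: indices = [[0, 1],    [0, 2]],    values = [10, 30]
// i.e. the output reuses the mask's indices and takes the matching entries
// of `input` (zero where `input` stores nothing at a masked index).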
Tensor _sparse_matrix_mask(const Tensor& input, const Tensor& mask){ Tensor output = at::empty_like(mask); Tensor mask_indices = mask._indices().clone(); Tensor r_values; if (mask._nnz() == 0) { r_values = at::zeros_like(mask._values()); } else { r_values = _sparse_matrix_mask_helper(input, mask_indices.contiguous()); } at::sparse::get_sparse_impl(output)->set_indices_and_values_unsafe(mask_indices, r_values); return output; } Tensor sparse_sparse_matmul_backward( const Tensor& grad, const Tensor& a, const Tensor& b, int64_t grad_order) { /* To implement the backward algorithm for sparse matrix-matrix matmul (SPMM) we can start from the following definition for dense tensors: c = a @ b then a_grad = c_grad @ b^T b_grad = a^T @ c_grad So for sparse matrices we can use the following definition: if grad_order == 0: a_grad = sparse_matrix_mask(c_grad @ b^T, mask=a) else: b_grad = sparse_matrix_mask(a^T @ c_grad, mask=b) */ TORCH_CHECK( grad_order == 0 || grad_order == 1, ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function"); if (grad_order == 0) { auto a_grad = _sparse_sparse_matmul(grad, b.t()); return _sparse_matrix_mask(a_grad.coalesce(), a.coalesce()); } auto b_grad = _sparse_sparse_matmul(a.t(), grad); return _sparse_matrix_mask(b_grad.coalesce(), b.coalesce()); } Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) { auto transposed_sizes = self.transpose(dim, 0).sizes().vec(); auto flatten = [&](const Tensor & t) { return t.transpose(dim, 0).contiguous().view({t.size(dim), -1}); }; auto unflatten = [&](const Tensor & t) { return t.contiguous().view(transposed_sizes).transpose(dim, 0); }; // renorm computes the norm over all dimensions except `dim`, which is why // we need the flatten and unflatten business. TODO: simplify this when we // add support for norm over multiple dimensions. auto self_flat = flatten(self); auto grad_flat = flatten(grad); auto norm_flat = self_flat.norm(p, 1, true); auto grad_output = (self_flat * grad_flat).sum(1, true); auto nb = norm_backward(grad_output, self_flat, p, norm_flat, 1, true); auto invnorm = (norm_flat + 1e-7).reciprocal(); auto grad_norm = unflatten(maxnorm * invnorm * (grad_flat - invnorm * nb)); auto norm = unflatten(norm_flat.expand_as(self_flat)); // TODO: remove the detach once comparison ops no longer require grad auto mask = Variable(norm < maxnorm).detach(); return at::where(mask, grad, grad_norm); } Tensor repeat_backward(Tensor grad, IntArrayRef repeats, IntArrayRef input_shape) { auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0); if (find_iter != repeats.cend()) { return at::zeros(input_shape, grad.options()); } const auto input_dims = input_shape.size(); int64_t num_unsqueezed = grad.dim() - input_dims; for (int64_t i = 0; i < num_unsqueezed; ++i) { grad = grad.sum(0, false); } at::DimVector grad_size, sum_dims; for (size_t dim = 0; dim < input_dims; ++dim) { int64_t repeat = repeats[dim + num_unsqueezed]; // Reshape gradient (repeat > 1) // Index: [..., dim , ...] [..., dim , dim+1 , ...] // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...] // The gradient tensor at 'dim' is reshaped to 'repeat' times of input tensor. // Then, sum up gradients over repeated tensors along 'dim', and reduce shape // from 'repeat * dimsize/repeat' to 'dimsize/repeat' ('input_dimsize'). 
// Example: // Size(3, 2) Size(6, 2) // [[v1_0, v1_1], // [v1_2, v1_3], // [[v0, v1], repeat(2, 1) [v1_4, v1_5], // [v2, v3], -------------> [v2_0, v2_1], // [v4, v5]] [v2_2, v2_3], // [v2_4, v2_5]] // // input grad (3, 2) reshape (2, 3, 2) output grad (6, 2) // [[[g1_0, g1_1], [[g1_0, g1_1], // [g1_2, g1_3], [g1_2, g1_3], // [[g1_0+g2_0, g1_1+g2_1], [g1_4, g1_5]], [g1_4, g1_5], // [g1_0+g2_0, g1_1+g2_1], [g2_0, g2_1], // [g1_0+g2_0, g1_1+g2_1]] [[g2_0, g2_1], [g2_2, g2_3], // [g2_2, g2_3], [g2_4, g2_5]] // [g2_4, g2_5]]] // If gradient tensor is reshaped to [..., dimsize/repeat, repeat, ...] and then // sum over 'dim+1'. The gradient for input is not correctly aligned with input. // Example: // input grad (3, 2) reshape (3, 2, 2) output grad (6, 2) // [[[g1_0, g1_1], // [g1_2, g1_3]], [[g1_0, g1_1], // [g1_2, g1_3], // [[g1_0+g1_2, g1_1+g1_3], [[g1_4, g1_5], [g1_4, g1_5], // [g1_4+g2_0, g1_5+g2_1], [g2_0, g2_1]], [g2_0, g2_1], // [g2_2+g2_4, g2_3+g2_5]] [g2_2, g2_3], // [[g2_2, g2_3], [g2_4, g2_5]] // [g2_4, g2_5]]] if (repeat != 1) { grad_size.push_back(repeat); sum_dims.push_back(grad_size.size() - 1); } // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == 1) grad_size.push_back(input_shape[dim]); } // One-time Reshape & Sum // Reshape gradient to grad_size: // 1. If repeat equals to 1, append input size at that dimension, // 2. If repeat is larger than 1, append both repeat and input size at that dimension. // Sum over all "repeat" dimensions from sum_dims: // Example: // Input Size (2, 3, 4, 5) // repeat [4, 1, 9, 3] // output/grad Size (8, 3, 36, 15) // grad_size [4, 2, 3, 9, 4, 3, 5] // sum_dims [0, 3, 5] // When repeat 1 time over all original dimensions, the empty sum_dims will reduce // the whole grad tensor into a scalar rather than keeping original dimensions. if (!sum_dims.empty()) { grad = grad.reshape(grad_size); grad = grad.sum(sum_dims); } return grad; } // p1m == 1 - p Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) { if (grad.requires_grad()) { // Use autograd-friendly backward if double backward is required return grad * (mask.type_as(grad) * (1. / p1m)); } else { return at::_masked_scale(grad, mask, 1. / p1m); } } Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) { if (input.is_cuda()) { auto mask = (input == value).logical_or_(input.isnan().logical_and_(value.isnan())); return mask * (grad / mask.sum()); } else { auto mask = value.isnan().item() ? 
input.isnan() : input == value; return grad.new_zeros(input.sizes(), input.options()).masked_fill_(mask, grad / mask.sum()); } } Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) { return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean()); } Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) { if (self.dim() == 0) { return var_backward(grad, self, unbiased); } if (!keepdim && self.dim() > 1) { grad = unsqueeze_multiple(grad, dim, self.sizes().size()); } return (2.0 / (_safe_size(self.sizes(), dim) - unbiased)) * grad * (self - self.mean(dim, true)); } Tensor std_backward(const Tensor & result, const Tensor & grad, const Tensor & self, bool unbiased) { return var_backward((grad / (result * 2)).masked_fill_(result == 0, 0), self, unbiased); } Tensor std_backward(const Tensor & result, Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) { return var_backward((grad / (result * 2)).masked_fill_(result == 0, 0), self, dim, unbiased, keepdim); } Tensor mean_backward(Tensor grad, const IntArrayRef sizes, IntArrayRef dim, bool keepdim) { return sum_backward(grad, sizes, dim, keepdim) / _safe_size(sizes, dim); } Tensor mean_backward(Tensor grad, const IntArrayRef sizes, int numel) { return grad.expand(sizes) / numel; } Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, IntArrayRef dim, bool unbiased, bool keepdim, bool is_std) { Tensor grad; if (grads[0].defined()) { grad = is_std ? std_backward(r1, grads[0], self, dim, unbiased, keepdim) : var_backward(grads[0], self, dim, unbiased, keepdim); } if (grads[1].defined()) { Tensor mean_grad = mean_backward(grads[1], self.sizes(), dim, keepdim); grad = grads[0].defined() ? grad + mean_grad : mean_grad; } return grad; } Tensor var_std_mean_backward(const variable_list& grads, const Tensor & self, const Tensor & r1, const Tensor & r2, bool unbiased, bool is_std) { Tensor grad; if (grads[0].defined()) { grad = is_std ? std_backward(r1, grads[0], self, unbiased) : var_backward(grads[0], self, unbiased); } if (grads[1].defined()) { Tensor mean_grad = mean_backward(grads[1], self.sizes(), self.numel()); grad = grads[0].defined() ? grad + mean_grad : mean_grad; } return grad; } Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) { int64_t numel = 1; for (auto size : sizes) { numel *= size; } auto mask_selected = grad.masked_select(mask); auto diff_nelem = numel - mask_selected.numel(); if (diff_nelem > 0) { // because mask_selected returns a 1-d tensor with size of masked elements that are 1, // we need to fill out the rest with zeros then reshape back to tensor2's size. auto zeros_fillin = at::zeros({diff_nelem}, grad.options()); mask_selected = at::cat({mask_selected, zeros_fillin}, 0); } return mask_selected.view(sizes); } Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) { // cf. Iain Murray (2016); arXiv 1602.07527 // This gradient is symmetric, and not triangular. // Cholesky additionally assumes that the input is symmetric, which is a subspace of // R^{n x n}, and hence the derivative is not well-defined for off-diagonal // elements. We resolve this by taking the gradient of the functionally independent // elements of the matrix (i.e., the lower triangular portion of the input) and then // reflect it on the upper triangular portion, thereby symmetrizing the gradient of // the cholesky operation. 
The motivation behind this choice is that symmetric gradient // leads to stable gradient updates, and retains symmetry of the updated matrix if it // were updated by a gradient based algorithm. if (upper) { L = L.transpose(-1, -2).conj(); grad = grad.transpose(-1, -2).conj(); } auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false)); auto phi = at::matmul(L.transpose(-1, -2).conj(), grad); phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5); auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2).conj(), phi), L_inverse); return grad_input.add(grad_input.transpose(-1, -2).conj()).mul_(0.5); // Symmetrizing the gradient } Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) { Tensor grad_L; if (grad.defined()) { Tensor common_term = grad + grad.transpose(-2, -1); common_term = at::matmul(inverse, at::matmul(common_term, inverse)); if (upper) { grad_L = -at::matmul(L, common_term); } else { grad_L = -at::matmul(common_term, L); } } else { grad_L = at::zeros({1}, L.options()).expand_as(L); } return grad_L; } Tensor split_with_sizes_backward(const std::vector &grads, IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { dim = at::maybe_wrap_dim(dim, sizes.size()); // it's possible some of the grads are not defined (represents tensors of all 0s). // Since at::cat can't handle those, let's define them std::vector grads_all_defined(grads.size()); for (size_t j = 0; j < grads.size(); ++j) { if (grads[j].defined()) { grads_all_defined[j] = grads[j]; } else { auto length = split_sizes[j]; auto grad_size = sizes.vec(); grad_size[dim] = length; grads_all_defined[j] = at::zeros(grad_size, options); } } auto ret = at::cat(grads_all_defined, dim); return ret; } Tensor split_backward(const std::vector &grads, int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { dim = at::maybe_wrap_dim(dim, sizes.size()); int64_t dim_size = sizes[dim]; int64_t num_splits = grads.size(); std::vector split_sizes(num_splits, split_size); split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); return split_with_sizes_backward(grads, split_sizes, dim, sizes, options); } Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) { AT_ASSERT(indices.dim() >= dim); auto size = indices.sizes().slice(0, indices.dim() - dim).vec(); size.push_back(-1); auto indices_view = indices.view(size); const auto memory_format = indices.suggest_memory_format(); return grad.contiguous(memory_format).view(size).gather(-1, indices_view).view(indices.sizes()); } Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) { auto& gO = grad_output; auto input_size = input.size(dim) / 2; auto first_half = input.narrow(dim, 0, input_size); auto second_half = input.narrow(dim, input_size, input_size); auto sig_second_half = second_half.sigmoid(); auto one_sub_sig_second_half = 1 - sig_second_half; auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half; auto ggI_first_half = grad.narrow(dim, 0, input_size); auto ggI_second_half = grad.narrow(dim, input_size, input_size); auto ggI_second_half_times_first_half = ggI_second_half * first_half; auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig; auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig; auto gI_second_half = ggI_second_half_times_first_half * gO * 
second_order_sh + ggI_first_half * gO * sig_one_sub_sig; return at::cat({gI_first_half, gI_second_half}, dim); } Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) { if (dim < 0) dim += input.dim(); auto sizes = input.sizes().vec(); sizes[dim] /= 2; auto tmp = grad * glu_backward(at::ones(sizes, input.options()), input, dim); return tmp.narrow(dim, 0, sizes[dim]) + tmp.narrow(dim, sizes[dim], sizes[dim]); } Tensor infinitely_differentiable_silu_backward( const Tensor& grad_output, const Tensor& input) { const Tensor sigmoid = input.sigmoid(); return grad_output * sigmoid * (1.0 + input * (1.0 - sigmoid)); } Tensor infinitely_differentiable_logit_backward( const Tensor& grad, const Tensor& self, c10::optional eps) { if (eps) { const double lo = eps.value(); const double hi = 1.0 - lo; return at::where( at::logical_and(self >= lo, self <= hi), grad / (self * (1.0 - self)), at::zeros({}, self.options())); } else { return at::where( at::logical_and(self >= 0.0, self <= 1.0), grad / (self * (1.0 - self)), at::empty({}, self.options()) .fill_(std::numeric_limits::quiet_NaN())); } } Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) { auto result = kl_div_backward(grad, input, target, at::Reduction::None, log_target); if (reduction == at::Reduction::Mean) { return result.mean(); } else if (reduction == at::Reduction::Sum) { return result.sum(); } return result; } // Compute derivatives for targets. Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) { Tensor grad_target; if (!log_target) { grad_target = grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.); } else { grad_target = grad_output.mul(target.add(1).sub_(self).mul_(target.exp())); } if (reduction == at::Reduction::Mean) { grad_target.div_(target.numel()); } return grad_target; } Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const c10::optional& weight, const c10::optional& pos_weight, int64_t reduction) { Tensor grad_target; if (isDefined(pos_weight)) { grad_target = (1. 
- self.sigmoid()).log_().sub_(pos_weight->mul(self.sigmoid().log_())).mul_(grad_output); } else { grad_target = self.mul(-grad_output); } if (isDefined(weight)) { grad_target.mul_(*weight); } if (reduction == at::Reduction::Mean) { grad_target.div_(target.numel()); } return grad_target; } Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) { auto z = input.sigmoid(); return grad * (z - 1) * z; } Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) { auto gO = grad_output; auto ggI = grad; auto ggI_output = ggI * output; auto ggI_out_sum = ggI_output.sum(dim, true); auto ggI_out_sum_output = ggI_out_sum * output; auto gO_out_sum = (gO * output).sum(dim, true); // gI calculation auto gI_t0 = ggI_output * (gO - gO_out_sum); auto gI_t1 = output * ((ggI_output * gO).sum(dim, true).sub_(gO_out_sum * ggI_out_sum)); auto gI_t2 = ggI_out_sum_output * gO; auto gI_t3 = ggI_out_sum_output * gO_out_sum; return gI_t0 - gI_t1 - gI_t2 + gI_t3; } Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) { auto z = output.exp(); return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad); } // NOTE: [How to write vmap-compatible backward formulas] // // See NOTE: [vmap-incompatible in-place operations] for what it means for an // in-place operation to be incompatible with vmap. // // If an in-place operation used in a backward formula is vmap-incompatible, // then as developers we have the following options: // // - If the in-place operation directly followed the creation of a tensor with // a factory function like at::zeros(...), we should replace the factory with a // corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call // propagates the batch dims to the resulting tensor. // For example: // Before: at::zeros(input.sizes(), grad.options()).copy_(grad) // After: grad.new_zeros(input.sizes()).copy_(grad) // // - If the in-place operation followed some sequence of operations, if the // we want to be able to vmap over the backward formula as-is (this is // usually the case for simple (<15loc) backward formulas), then use // inplaceIsVmapCompatible to guard the operation. For example: // c = a * b // Before: c.mul_(grad) // After: c = at::inplaceIsVmapCompatible(c, grad) ? c.mul_(grad) : c * grad // // - If we don't want to vmap directly over the backward formula (e.g., if the // backward formula is too complicated or has a lot of vmap-incompatible // operations, then register the backward formula as an operator and eventually // write a batching rule for it. 
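// A minimal, self-contained sketch of the "guard the in-place op" pattern from
// the NOTE above. `vmap_safe_mul_example` is a hypothetical helper that is not
// used elsewhere in this file; the real backward below
// (binary_cross_entropy_double_backward) applies the same pattern to `gI`.
static Tensor vmap_safe_mul_example(const Tensor& a, const Tensor& b, const Tensor& grad) {
  auto c = a * b;
  // Multiply in place only when vmap can handle it; otherwise fall back to the
  // out-of-place product so the formula stays vmap-compatible.
  return at::inplaceIsVmapCompatible(c, grad) ? c.mul_(grad) : c * grad;
}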
Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional& weight, int64_t reduction) { auto eps = 1e-12; auto inp_pl_eps = input + eps; auto one_m_inp_pl_eps = 1 - input + eps; // gradient wrt input auto gI = (input * input - 2 * input * target + target) / (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2)); if (at::inplaceIsVmapCompatible(gI, grad)) { gI *= (grad * grad_output); } else { gI = gI * (grad * grad_output); } if (isDefined(weight)) { gI *= *weight; } if (reduction == at::Reduction::Mean) { return gI / input.numel(); } else if (reduction == at::Reduction::Sum) { return gI.sum(); } return gI; } Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, const c10::optional& weight, int64_t reduction) { auto eps = 1e-12; // gradient wrt grad_output auto ggO = (input - target) / ((input + eps) * (1 - input + eps)); if (at::inplaceIsVmapCompatible(ggO, grad)) { ggO *= grad; } else { ggO = ggO * grad; } if (isDefined(weight)) { ggO *= *weight; } if (reduction == at::Reduction::Mean) { return ggO / input.numel(); } else if (reduction == at::Reduction::Sum) { return ggO.sum(); } return ggO; } Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) { auto output = l1_loss_backward(grad, input, target, at::Reduction::None); if (reduction == at::Reduction::Mean) { return output.mean(); } else if (reduction == at::Reduction::Sum) { return output.sum(); } return output; } Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double beta) { // special case to protect against a divide-by-zero. 
if (beta == 0) { return at::zeros(grad.sizes(), grad.options()); } auto d = (input - target).abs(); auto grad_input = grad * (d < beta).type_as(grad) / beta; if (reduction == at::Reduction::Mean) { grad_input /= input.numel(); } return grad_input; } Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double beta) { if (reduction == at::Reduction::None) { return smooth_l1_loss_backward(grad, input, target, reduction, beta); } auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction, beta); return (r * grad).sum(); } Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) { auto grad_input = 2 * grad; if (reduction == at::Reduction::Mean) { grad_input /= input.numel(); } return grad_input; } Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { if (reduction == at::Reduction::None) { return mse_loss_backward(grad, input, target, reduction); } auto r = mse_loss_backward(ones_like(grad_output), input, target, reduction); return (r * grad).sum(); } Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) { auto z = (input * -target).exp(); auto zplus1 = z + 1; auto grad_input = grad * (target * target) * z / (zplus1 * zplus1); if (reduction == at::Reduction::Mean) { grad_input /= input.numel(); } return grad_input; } Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) { if (reduction == at::Reduction::None) { return soft_margin_loss_backward(grad, input, target, reduction); } auto r = soft_margin_loss_backward(ones_like(grad_output), input, target, reduction); return (r * grad).sum(); } Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) { auto x = (input * beta); return sigmoid_backward(grad, x.sigmoid()) * (x < threshold).type_as(grad) * beta; } // NOTE [ as_strided Backward and layout-aware/agnostic autograd ] // // `storage_offset` is ignored for simplicity in this note. If you just want the // full algorithm without explanation, scroll down to bottom of this note. // // Implementing the backward of as_strided is tricky because you have to deal // with mappings that map one memory location to multiple indices, i.e., the // output tensor has multiple indices pointing to **overlapping** memory // addresses. This can happen in all in all sorts of weird cases. For example, // // x = torch.randn(15) // x.as_strided([3, 3], [1, 0]) # "expand" case // x.as_strided([3, 3], [2, 1]) # "size too large" case // x.as_strided([3, 2], [3, 6]) # res[2, 0] points to 2*3 + 0*6 = 6 // # res[0, 1] points to 0*3 + 1*6 = 6 // // Here is the general strategy we apply in implementing as_strided backward: // 0. ??? (optimization step. we will talk about this later) // 1. Create some underlying flattened tensor as if it is the base tensor // representing the contiguous memory storage for both input and output. // 2. Use the output geometry to scatter (or index_add) the gradients into // this storage tensor. // 3. ??? (fix for input tensor with overlapping memory. we will talk about // this later) // 4. Return the as_strided view of the storage tensor using input geometry. 
// // In step (2), if the output tensor does't have overlapping memory, we can // safely scatter (`storage.as_strided(output_geometry).copy_(grad)`); // otherwise, we must use `index_add` as gradients at different indices may need // to be summed to a single location. // // For example, in this case: // // x = torch.randn(3) // y = x.as_strided([3, 3], [1, 0]) # "expand" case // # size [ 3, 3] // # stride [ 1, 0] // y.backward() # step (1): contiguous storagte tensor `s` of size 3, which // is large enough to be used as underlying storage // for `x` and `y`. // s = [ 0, 0, 0] // # step (2): since `y` has overlapping memory, index_add grad // into `s` basing on `y`'s geometry, i.e., // s[i * y.stride(0) + j * y.stride(1)] += gy[i, j]. // s = [ 3, 3, 3] // # step (4): as_strided view `s` using `x`'s geometry // s = [ 3, 3, 3] // grad_input = s.as_strided(x.size(), x.stride()) // = s.as_strided([3], [1]) // = [ 3, 3, 3] // // This is exactly what we would get if using `expand`. However, here the input // tensor doesn't have overlapping memory. If it does, we must add an extra step // before (4). Considering this case: // // t = torch.randn(3) // x = t.expand(3, 3) # input with overlapping memory // # size [3, 3] // # stride [0, 1] // y = x.as_strided([1], [1]) # contiguous output // # size [1] // # stride [1] // y.backward() # step (1): contiguous storage tensor `s` of size 3, which // is large enough to be used as underlying storage // for `x` and `y`. // s = [ 0, 0, 0] // # step (2): scatter grad into `s` basing on `y`'s geometry // s = [ 1, 0, 0] // # step (4): as_strided view `s` using `x`'s geometry // s = [ 1, 0, 0] // grad_input = s.as_strided([3, 3], [0, 1]) // = s.as_strided([3, 3], [0, 1]) // = [[ 1, 0, 0], // [ 1, 0, 0], // [ 1, 0, 0]] // Is this result correct? // // `x.as_strided([1], [1])` call is obviously equivalent with // `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second // gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific // case, indexing `x` at any index in first column is also equivalent, and // yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is // an `x.size(1)`-times difference between these gradients computed from other // PyTorch ops and the gradient we got from as_strided. // // You might conclude that the gradients from as_strided is wrong. However, // let's first see why they are actually reasonable. Consider the pointwise // perturbations by `delta` anywhere in the first column of `x`. It will lead to // a `delta` change in the same memory location, and then `y` will change by // `delta`. So one can say the gradient should be exactly 1 at the first column, // as given by our above procedure. // // In the above computation of numerical gradients, they only match the // analytical results because strides and memory locations are considered in the // forward pass, i.e., this op (including both forward and backward) is // layout-aware. // // However, in PyTorch, most (probably all) other ops (forward and backward) are // layout-agnostic. E.g., // // t = torch.randn(1) // x = t.expand(2) // y = x.sum() // y.backward() // // Layout-agnostic autograd (as it is currently in PyTorch) will give you // // gy = 1 // gx = [ 1, 1] # SumBackward: torch.ones_like(x) // gt = [ 2] # ExpandBackward: gx.sum() // // Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta` // (the other will also change by `delta`), `y` will change by `2 * delta`. 
So // the gradients, if strides are taken into consideration, should be 2. // // Layout-aware autograd should give you // // gy = 1 // gx = [ 2, 2] # Because the backward considers the fact that the input `x` // # is already expanded. // gt = [ 2] # Layout-aware backward of expand is just a slicing because // # the previous backward should have already taken care of // # strides and made sure that gradients are the same along the // # expanded dimension. // // As shown above, these two types are not compatible. Therefore, we must either // make as_strided layout-agnostic, or make all other ops layout-aware. // // It is difficult to support layout-aware autograd (at least in the current // codebase structure), because it would mean // 1. storing tensor geometries of every input tensor for backward // 2. depending on input geometry, the gradient computed from backward change // 3. ideally enforcing gradient of T to always have same strides as T // (although these two methods only differ when it comes to overlapping memory) // // Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e., // giving the same output regardless of the input layout. We consider // `input.stride()` as a separate independent fixed argument `input_stride`. // Then, `as_strided(input, size, stride)` can be thought of as: // 1. "Scatter" each value of `input` into a "storage" using storage location // computed from the value's index in `input`, `input.size()` and // `input_stride`, but if N values end up in the same location, the value // is average of those N values (they will be the same value anyways). // // Formal description: // Denote the set of all input indices that pointing to the same storage // location `storage[n]` as `S(n)`, i.e., // // S(n) = { index : == n, index is valid given input.size() }, // // where `` is the dot product between `x` and `y`. // // Then, the process is: // // storage[n] = Avg { S(n) } // // Note that all values in `S(n)` are the same (they point to the same // memory location anyways, so this step doesn't change anything, but // effectively avoids having the denpendency on the layout of `input`. // I.e., the result holds fixed regardless of the layout of `input`, as // long as `input_stride` is fixed. // // NOTE: for forward pass, we can equivalently simply selet any one of // `S(n)` as `storage[n]`. However, cosnidering this as an average // operation makes backward easier (so all values in set // `{ grad_input[i] : i in S(n) }` are the same, and it can use the // same geometry as input). // 2. As usual, return the as_strided view of `storage` using required output // `size` and `stride`. // // To backward through this layout-agnostic version, we simply add the following // step: // .... (scatter gradients into the storage tensor using output geometry) // 3. For all storage location n, `storage[n] /= |S(n)|`. // .... (return as_strided view of the storage tensor using input geometry) // // Finally, we note that these general operations are expensive, so we apply the // following optimizations: // Add step (0): For all output dimension `d` with output stride 0, sum the // gradients along dimension `d` (don't keepdim), and remove // dimension `d` from output size and stride. // (An optimization for "expand" cases so we may avoid step (3)) // Only apply step (3) when input tensor has overlapping memory. // // FULL ALGORITHM: // 0. 
For all output dimension `d` with output stride 0, sum the gradients // along dimension `d` (don't keepdim), and remove dimension `d` from // output size and stride. // 1. Create some underlying flattened tensor as if it is the base tensor // representing the contiguous memory storage for both input and output. // 2. Use the output geometry to scatter (or index_add) the gradients into // this storage tensor `storage`. // 3. If input tensor has overlapping memory, // For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the // number of indices in input geometry pointing to the same storage // location `i` (i.e., `|S(i)|` in equations above). // 4. Return the as_strided view of the storage tensor using input geometry. // // See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to // roughly detech overlapping memory. // NOTE [ Detecting Memory Overlap Within A Strided Tensor ] // // Checking memory overlap within a strided tensor is the special case of // detecting memory overlap of two strided tensors, where the two tensors start // at the same memory address. The later is HARD (see #8212). // // But even this special case isn't simple. This note describes a check for a // even more constrained simple case where we can be certain that there is no // overlap. // // The checking algorithm can be described as: // 0. Return [ pass check ] if any dimension has size 0 // 1. Ignore all dimensions that have size 1 // 2. If no remaining dimensions, return [ pass check ] // 3. Sort the remaining dimensions according to the strides decreasingly // 4. Check that for each dimension k, // // stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i] // // That is equivalent to, after reordering the dimensions so strides are // in decreasing order, checking that stride of each dimension is larger // than the maximum memory offset in a slice at that dimension. // // Obviously this check passes for contiguous tensors ( the dimensions will be // already sorted with LHS = stride[0] = \prod size[i] being exactly 1 larger // than RHS ). Similarly, the check passes for tensors contiguous in all but // the last dimension, and LHS = stride[0] = stride[-1] * \prod size[i] being // exactly stride[-1] larger than RHS. (*) // // We will show that these view operations, including all our view operations // *except for* general as_strided and unfold, also preserve this invariant: // // alias: Obviously preserves // // expand: All changed dimensions are removed in step (1) // // view: Consider the input dimensions as grouped into consecutive // dimension "blocks", where dimensions are contiguous in each one. // one. view only works when the output dimensions can also be // grouped into the same consecutive blocks of same ordering. // // NB: this means that the number of elements and stride of the // last dimension in each block is the same in input and // output. (**) // // Notation: // Consider a single such block B, // ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ... // start--^^^^ ^^^^^^^^^^^^--end // Each B[i] denotes a dimension index such that B[i] = B[0] + i. // // We first show that in a tensor (i.e., input) satisfies the // invariant, after sorting, the dimensions within each block // still remain consecutive. (***) // // After removing dimensions of size 1, the dimensions within a // block is already sorted by strides in descending order. So // sorting all dimensions will not change the relative ordering // among them. 
// // Assume that some block B is not consecutive after sorting, // i.e., there exists a dimension d between B[0] and B[-1] in // sorted order. // // By (*), we know that // stride[B[0]] // = \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]] // < \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[d] // <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d] // <= \sum{j > B[0]} (size[j] - 1) * stride[j], // // where the first < comes from sorting and // the second <= comes from the fact that dimension d // exists after step (1) and // thus must have size greater // than 1 // the third <= comes from the fact that each term in // the sum is non-negative // // Then we have a countradiction as the invariant must not be // satisfied at B[0]. So the original proposition is true. // // Now that we established the above claim (***), we consider the // view operation as first sorting the dimensions (i.e., blocks), // apply the original view (since it only cares dimensions being // consecutive and contiguous withtin each block), and then undo // the sort. // // Consider a single block B in the output, // ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ... // start--^^^^ ^^^^^^^^^^^^--end // // By (*), we know that for all i // stride[i] = stride[B[-1]] + // \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]] // // Then the invariant is obviously satisfied at every dimension // in this block if it is satisfied at dimnesion B[-1]. It only // remains to show that it is satisfied at the last dimension in // each block. // // Since the same blocks are present in both input and output // with the same ordering, we will abuse the notation in the // following statements. // // By (*), we know that the following holds for both input and // output, for any block B: // \sum_{i > B[-1]} (size[i] - 1) * stride[i] // = \sum_{block B' after B} \prod_{j in B'} size[B[j]] * stride[B'[-1]] // = \sum_{block B' after B} numel(B') * stride[B'[-1]]. // ^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^ // By (**), we know that, this quantity in the above equation // remains the same in input and output. So both // \sum_{i > B[-1]} (size[i] - 1) * stride[i] // and // stride[B[-1]] // are the same in input and output. // // These two quantities are exactly the LHS and RHS of the // invariant inequality. Since by assumption the invariant is // satisfied in input at B[-1], it is also satisfied in output at // B[-1]. This concludes the proof. // // squeeze: Special case of view // // unsqueeze: Special case of view // // slice: Consider slicing dimension i with step = k >= 1. // // Let stride' and size' be the output strides and sizes. We have // // stride'[i] = k * stride[i] // size'[i] <= floor(size[i] / k) // // If size'[i] = 1, invariant is obviously satisfied as we are // just removing a dimension (afte step (1)). // // Assume size'[i] > 1. // // By assumption, the invariant is satisfied at every dimension // in input. // // For any dimension j, if stride[j] > stride[i], we have // stride'[j] = stride[j] // > (size[i] - 1) * stride[i] // = (size[i] / k * k - 1) * k * stride[i] / k // = (size[i] / k - 1 / k) * stride'[i] // >= (size'[i] - 1 / k) * stride'[i] // >= stride'[i]. // // If stride[j] < stride[i], we have // stride'[j] = stride[j] < stride[i] <= stride'[i]. // // So the sorting order remains unchanged after slice. 
//
//              Since
//                  (size'[i] - 1) * stride'[i]
//                  =  (floor(size[i] / k) - 1) * k * stride[i]
//                  <= (size[i] / k - 1) * k * stride[i]
//                  =  (size[i] - k) * stride[i]
//                  <= (size[i] - 1) * stride[i],
//              the term from this dimension i in the invariant inequality at
//              other dimensions can only decrease after slice. So the
//              invariant is preserved.
//
//  narrow:     Special case of slice.
//
//  select:     narrow + squeeze.
//
//  permute:    Sorting makes permutation of dimensions irrelevant.
//
//  transpose:  Sorting makes swapping dimensions irrelevant.
//
//  diagonal:   Effectively merging two dimensions i and j into a new
//              dimension k s.t.
//                  stride'[k]  = stride[i] + stride[j]
//                  size'[k]   <= min(size[i], size[j]),
//              where stride and size are on the input, and stride' and size'
//              are on the output.
//
//              Assume that size[i] > 1 and size[j] > 1. If either has size 1,
//              then this is an unsqueeze on that dimension.
//
//              WLOG, say stride[i] <= stride[j].
//
//              Each dimension d in input with stride[d] > stride[j] has
//                  stride'[d] =  stride[d]
//                             >  (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
//                             >= stride[i] + stride[j]
//                             =  stride'[k].
//              So, considering the sorted dimensions, this is effectively
//              removing i, and replacing j with k.
//
//              For dimensions d with stride[i] < stride[d] < stride[j], the
//              term from dimension i is removed in the invariant inequality.
//              For dimensions d with stride[d] > stride[j], we have
//                  (size'[k] - 1) * stride'[k]
//                  <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
//                  <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
//              so the term from i and j in the invariant can only decrease.
//
//              So this is generally relaxing the constraint, and thus it
//              preserves the invariant.

// This implements steps (2)~(4) of the algorithm in
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// Helper for as_strided_backward
static inline bool _maybe_overlapping_memory(IntArrayRef sizes, IntArrayRef strides) {
  if (sizes.size() > 0) {
    std::vector<std::size_t> argsort(sizes.size());
    std::iota(argsort.begin(), argsort.end(), 0);
    std::sort(argsort.begin(), argsort.end(),
        [&](std::size_t i, std::size_t j) { return strides[i] < strides[j]; });

    int64_t max_index_in_slice = 0;
    for (auto i : argsort) {
      auto stride_ = strides[i];
      if (stride_ <= max_index_in_slice) {
        return true;
      }
      max_index_in_slice += stride_ * (sizes[i] - 1);
    }
  }
  return false;
}

// Returns the minimum storage size needed to contain a tensor of sizes,
// strides, and storage_offset
// Helper for as_strided_backward
static inline int64_t _min_storage_size(IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  int64_t storage_size = storage_offset + 1;
  int64_t dim = sizes.size();
  for (int64_t i = 0; i < dim; i++) {
    auto size_i = sizes[i];
    if (size_i == 0) {
      return storage_offset;
    }
    storage_size += (size_i - 1) * strides[i];
  }
  return storage_size;
}

// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_) {
  // For output geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  //   reduce grad on expanded dims (stride=0, size>1)
  // Step (0)     for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset());
  auto odim = grad.dim();
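  // For example (illustrative numbers): if the forward as_strided produced
  // sizes [2, 1, 3] and strides [0, 7, 1], the loop below sums `grad` over
  // dim 0 (a 0 stride means the same memory location was read for every index
  // of that dim), squeezes the size-1 dim 1, and keeps out_sizes_ = [3],
  // out_strides_ = [1] as the effective output geometry.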
std::vector out_sizes_, out_strides_; out_sizes_.reserve(odim); out_strides_.reserve(odim); for (int64_t i = odim - 1; i >= 0; i--) { auto size_i = sizes[i]; auto stride_i = strides[i]; if (size_i == 0) { return at::zeros(input_geometry.sizes(), grad.options()); } else if (size_i == 1) { grad = grad.squeeze(i); } else if (stride_i == 0) { grad = grad.sum(i, false); } else { out_sizes_.insert(out_sizes_.begin(), size_i); out_strides_.insert(out_strides_.begin(), stride_i); } } // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ] // on output geometry auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_); // For input geometry, // check for size 0 dimensions, // skip size 1 dimensions, // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ] // on input geometry auto idim = input_geometry.dim(); IntArrayRef inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides(); std::vector inp_sizes_, inp_strides_; inp_sizes_.reserve(idim); inp_strides_.reserve(idim); for (int64_t i = idim - 1; i >= 0; i--) { auto size_i = inp_sizes[i]; auto stride_i = inp_strides[i]; if (size_i == 0) { return at::zeros(input_geometry.sizes(), grad.options()); } else if (size_i != 1) { inp_sizes_.insert(inp_sizes_.begin(), size_i); inp_strides_.insert(inp_strides_.begin(), stride_i); } } // Step (1)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ] // on input geometry auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_); // Rest of this function implements // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ] // TODO: Raise if not all output values are visible in input geometry. // Technically speaking, if you treat those values as constants, not // raising is fine, and mathematically correct. However, these values // really are contained in some base tensor, and by treating them as // constants we are ignoring this tight dependency. Therefore, it is // more sensible to raise here. 
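  // A concrete illustration of the two overlap flags (made-up numbers):
  //   - sizes [3, 4] with strides [4, 1] (contiguous) pass the check above:
  //     after sorting by stride, 1 > 0 and 4 > (4 - 1) * 1, so no two indices
  //     can hit the same storage location;
  //   - sizes [2, 2] with strides [1, 1] fail it: 1 <= (2 - 1) * 1, and indeed
  //     elements (0, 1) and (1, 0) both live at storage location 1.
  // When out_maybe_overlap is set, the scatter in Step (2) below must
  // accumulate with index_add_ instead of copy_; when inp_maybe_overlap is
  // set, Step (3) divides each storage location by the number of times the
  // input geometry visits it.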
// Step (1): create underlying tensor as "storage" auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset); auto inp_effective_offset = input_geometry.storage_offset() - shared_offset; auto out_effective_offset = storage_offset - shared_offset; auto base_size = std::max( _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset), _min_storage_size(out_sizes_, out_strides_, out_effective_offset) ); auto storage = grad.new_zeros({base_size}); // prepare indices tensor if we will do index_add_ later c10::optional flatten_full_indices; if (inp_maybe_overlap || out_maybe_overlap) { flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong)); } // Step (2): use output geometry to scatter gradients into storage if (out_maybe_overlap) { auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset); storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1)); } else { // assume that new tensors have 0 storage offset storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad); } // Step (3): if input tensor has overlapping memory, divide scattered gradient // at storage[i] by the number of times i shows up in input geometry if (inp_maybe_overlap) { auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1); count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices)); storage.div_(count); // this will give nan outside visible range } // Step (4): return as_strided view of the storage tensor with input geometry return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset); } std::tuple atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array output_mask) { if (!grad.defined()) { return std::tuple{Tensor(), Tensor()}; } auto recip = (self * self + other * other).reciprocal(); return std::tuple{ output_mask[0] ? grad * other * recip : Tensor(), output_mask[1] ? grad * -self * recip : Tensor() }; } // TODO: Seriously consider writing the derivative formulas for // each output separately; there is not all that much sharing // of computation going on here. std::tuple prelu_double_backward( const Tensor & grad_grad_input, const Tensor & grad_grad_weight, const Tensor & grad_out, const Tensor & input_, const Tensor & weight_) { if (!(grad_grad_input.defined() || grad_grad_weight.defined() || grad_out.defined())) { return std::tuple(Tensor(), Tensor(), Tensor()); } auto input = input_.contiguous(); auto weight = weight_.contiguous(); // Zero-fill undefined grads (TODO: do this more efficiently) auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto positive_mask = (input > 0).type_as(ggI); auto nonpositive_mask = (input <= 0).type_as(ggW); // Explanation: Let input be i, weight be w, grad_output be gO. // f(i, w) = i if i > 0 // = w * i if i <= 0 // gI = df/di * gO = gO if i > 0 gW = df/dw * gO = 0 if i > 0 // = gO * w if i <= 0 = gO * i if i <= 0 // The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly. 
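  // Writing pos = (i > 0) and neg = (i <= 0), the first backward above is
  //   gI = gO * (pos + neg * w),   gW = sum_over_non_channel_dims(gO * neg * i),
  // so differentiating once more and contracting with the incoming tangents
  // ggI (for the input) and ggW (for the weight) gives what the two branches
  // below compute:
  //   ggO = ggI * (pos + neg * w) + ggW * (neg * i)
  //   gI' = ggW * gO * neg              (new grad w.r.t. the input)
  //   gW' = sum(ggI * gO * neg)         (new grad w.r.t. the weight, reduced
  //                                      to the weight's shape)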
if (weight.numel() == 1) { // from PReLU.forward: num_parameters == 0 is used indicate that a // single weight is shared among all input channels. // this is a little tricky because PReLU currently doesn't take a shape so the weight may be // 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable // and tensor are merged). So, use weight and ggW as 0-dim in this case. bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1); auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight; auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW; auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input); auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input); return std::tuple( ggO, ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask, (ggI * gO * nonpositive_mask).sum().expand_as(weight) ); } else { // Expand ggW to match size of ggI; a simple expand doesn't work because // ggW is the size of the input channel (dim==1 unless there is only 1 dimension). For example, // let ggI be size (3,4,5,6,7) and ggW be size (4). Then we unsqueeze ggW to be size (4,1,1,1) // so the expand succeeds. auto dims_to_unsqueeze = std::max(input.dim() - 2, 0); auto ggW_expanded = ggW; for (int64_t i = 0; i < dims_to_unsqueeze; i++) { ggW_expanded = ggW_expanded.unsqueeze(1); } ggW_expanded = ggW_expanded.expand_as(ggI); auto gI = ggW_expanded * gO * nonpositive_mask; auto gW = ggI * gO * nonpositive_mask; if (input.dim() > 1) { gW = gW.sum(0); } while (gW.dim() > 1) { gW = gW.sum(1); } Tensor ggO; if (gO.requires_grad()) { // expand weight as input as in ggW/ggI above auto weight_expanded = weight; for (int64_t i = 0; i < dims_to_unsqueeze; i++) { weight_expanded = weight_expanded.unsqueeze(1); } weight_expanded = weight_expanded.expand_as(input); auto mask = positive_mask + nonpositive_mask * weight_expanded; ggO = ggI * mask + ggW_expanded * nonpositive_mask * input; } return std::tuple{ggO, gI, gW}; } } // https://j-towns.github.io/papers/svd-derivative.pdf // // This makes no assumption on the signs of sigma. Tensor svd_backward(const std::vector &grads, const Tensor& self, bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) { TORCH_CHECK(compute_uv, "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)"); auto m = self.size(-2); auto n = self.size(-1); auto k = sigma.size(-1); auto gsigma = grads[1]; auto u = raw_u; // Currently torch.svd for complex dtypes returns the conjugate of V, // while the backward formula is derived with just V (without the conjugation) // therefore here we need to conjugate the V output of SVD and grads[2]. // Once https://github.com/pytorch/pytorch/issues/45821 is resolved // extra .conj(), that are marked below in the code, shall be removed. auto v = raw_v.conj(); // TODO: remove .conj() auto gu = grads[0]; auto gv = grads[2].conj(); // TODO: remove .conj() if (!some) { // We ignore the free subspace here because possible base vectors cancel // each other, e.g., both -v and +v are valid base for a dimension. // Don't assume behavior of any particular implementation of svd. 
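    // For example, for an m x n = 4 x 2 input with some == false, raw_u is
    // 4 x 4 while sigma has k = 2 entries; only the first k columns of U (and
    // of gu / gv, when they are defined) enter the formulas below, so narrow
    // everything down to k columns here.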
u = raw_u.narrow(-1, 0, k); v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj() if (gu.defined()) { gu = gu.narrow(-1, 0, k); } if (gv.defined()) { gv = gv.narrow(-1, 0, k); } } auto vh = v.conj().transpose(-2, -1); Tensor sigma_term; if (gsigma.defined()) { gsigma = gsigma.to(self.dtype()); // computes u @ diag(gsigma) @ vh sigma_term = at::matmul(u * gsigma.unsqueeze(-2), vh); } else { sigma_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } // in case that there are no gu and gv, we can avoid the series of kernel // calls below if (!gv.defined() && !gu.defined()) { return sigma_term; } auto uh = u.conj().transpose(-2, -1); auto im = at::eye(m, self.options()); auto in = at::eye(n, self.options()); auto sigma_mat = sigma.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype()); auto sigma_mat_inv = sigma.pow(-1).diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).to(self.dtype()); auto sigma_sq = sigma.pow(2); auto F = sigma_sq.unsqueeze(-2) - sigma_sq.unsqueeze(-1); // The following two lines invert values of F, and fills the diagonal with 0s. // Notice that F currently has 0s on diagonal. So we fill diagonal with +inf // first to prevent nan from appearing in backward of this function. F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); F = F.pow(-1); Tensor u_term, v_term; if (gu.defined()) { auto guh = gu.conj().transpose(-2, -1); u_term = at::matmul(u, at::matmul(F.mul(at::matmul(uh, gu) - at::matmul(guh, u)), sigma_mat)); if (m > k) { u_term = u_term + at::matmul(im - at::matmul(u, uh), at::matmul(gu, sigma_mat_inv)); } u_term = at::matmul(u_term, vh); } else { u_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } if (gv.defined()) { auto gvh = gv.conj().transpose(-2, -1); v_term = at::matmul(sigma_mat, at::matmul(F.mul(at::matmul(vh, gv) - at::matmul(gvh, v)), vh)); if (n > k) { v_term = v_term + at::matmul(sigma_mat_inv, at::matmul(gvh, in - at::matmul(v, vh))); } v_term = at::matmul(u, v_term); } else { v_term = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } // for complex-valued input there is an additional term // https://giggleliu.github.io/2019/04/02/einsumbp.html // https://arxiv.org/abs/1909.02659 if (self.is_complex() && gu.defined()) { // computes L = Identity.mul(uh @ gu) Tensor L = at::matmul(uh, gu).diagonal(0, -2, -1).diag_embed(0, -2, -1); L = L - L.conj().transpose(-2, -1); Tensor imag_term = 0.5 * at::matmul(at::matmul(at::matmul(u, L), sigma_mat_inv), vh); return u_term + sigma_term + v_term + imag_term; } return u_term + sigma_term + v_term; } // "An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation" // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf Tensor eig_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, const Tensor& lambda, const Tensor& v) { // This gradient only works for real eigenvalues at the moment. TORCH_CHECK(eigenvectors, "eig_backward: Setting eigenvectors to false in torch.eig doesn't compute eigenvectors ", "and hence we cannot compute backward. 
Please use torch.eig(eigenvectors=True)"); auto zeros = at::zeros({1}, lambda.options()); TORCH_CHECK( at::allclose(lambda.slice(/*dim=*/-1, /*start=*/1, /*end=*/2), zeros), "eig_backward: Backward calculation does not support complex eigenvalues at the moment."); auto glambda = grads[0]; auto gv = grads[1]; auto vt = v.transpose(-2, -1); Tensor result; // contribution from the eigenvectors if (gv.defined()) { auto rlambda = lambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1); auto hm = rlambda.transpose(-2,-1) - rlambda; hm.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); hm.pow_(-1.0); auto gvortho = gv - at::sum(gv * v, /*dim=*/-2, /*keepdim=*/true) * v; auto B = hm * at::matmul(vt, gvortho); auto A = at::matmul(B, vt); std::tie(result, std::ignore) = at::solve(A, vt); } // contribution from eigenvalues if (glambda.defined()) { auto grlambda = glambda.slice(/*dim=*/-1, /*start=*/0, /*end=*/1) * vt; auto A = at::matmul(v, grlambda); auto vvt = at::matmul(v, vt); if (result.defined()) { Tensor result1; std::tie(result1, std::ignore) = at::solve(A, vvt); result = result.add(result1); } else { std::tie(result, std::ignore) = at::solve(A, vvt); } } return result; } // http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf Tensor symeig_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) { // This gradient is symmetric, and not triangular. // symeig operates only on symmetric inputs, which is a subspace of // R^{n x n}, and hence the derivative is not well-defined for off-diagonal // elements. We resolve this by taking the gradient of the functionally independent // elements of the matrix (i.e., the lower triangular portion of the input) and then // reflect it on the upper triangular portion, thereby symmetrizing the gradient of // the symeig operation. The motivation behind this choice is that symmetric gradient // leads to stable gradient updates, and retains symmetry of the updated matrix if it // were updated by a gradient based algorithm. TORCH_CHECK(eigenvectors, "symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ", "and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)"); auto glambda = grads[0]; auto gv = grads[1]; auto vh = v.conj().transpose(-2, -1); Tensor result; if (gv.defined()) { Tensor F = lambda.unsqueeze(-2) - lambda.unsqueeze(-1); F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); F.pow_(-1); result = at::matmul(v, at::matmul(F * at::matmul(vh, gv), vh)); } else { result = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } if (glambda.defined()) { glambda = glambda.to(self.dtype()); // computes v @ diag(glambda) @ vh Tensor glambda_term = at::matmul(v * glambda.unsqueeze(-2), vh); if (at::inplaceIsVmapCompatible(result, glambda_term)) { result.add_(glambda_term); } else { result = result + glambda_term; } } return result.add(result.conj().transpose(-2, -1)).mul_(0.5); } Tensor linalg_qr_backward(const std::vector &grads, const Tensor& self, std::string mode, const Tensor& q, const Tensor& r){ bool compute_q, reduced; std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); TORCH_CHECK(compute_q, "linalg_qr_backward: cannot compute backward if mode='r'. 
" "Please use torch.linalg.qr(..., mode='reduced')"); auto square_deep_case_backward = [](const Tensor& grad_Q, const Tensor& grad_R, const Tensor& A, const Tensor& Q, const Tensor& R) -> Tensor { // For square and deep (tall) case we refer: // Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra. // https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition) // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks. // https://arxiv.org/abs/1903.09650 Section 3. QR factorization // For derivations of complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html // Compute R grad_R^H Tensor R_term; if (grad_R.defined()) { R_term = at::matmul(R, grad_R.conj().transpose(-2, -1)); } else { // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } // Compute grad_Q^H Q Tensor Q_term; if (grad_Q.defined()) { Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q); } else { // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } Tensor M = R_term - Q_term; // Compute M = (tril(M) + tril(M).conj().transpose(-2, -1)) * 0.5 Identity Tensor M_tril = at::tril(M); M = M_tril + M_tril.conj().transpose(-2, -1); M.diagonal(0, -2, -1).mul_(0.5); Tensor rhs_term; if (grad_Q.defined()) { rhs_term = grad_Q + at::matmul(Q, M); } else { rhs_term = at::matmul(Q, M); } // We want to compute: (rhs_term @ R^{-H}) // Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_solve_1^H)^H // Since R is upper triangular, we can do this using // triangular_solve(rhs_term^H, R)^H Tensor grad_A; std::tie(grad_A, std::ignore) = at::triangular_solve( rhs_term.conj().transpose(-2, -1), R, /*upper=*/true, /*transpose=*/false, /*unitriangular=*/false); return grad_A.conj().transpose(-2, -1); }; auto m = self.size(-2); auto n = self.size(-1); TORCH_CHECK( ((m <= n && (!reduced)) || reduced), "The derivative is not implemented when nrows > ncols and complete QR. "); auto grad_Q = grads[0]; auto grad_R = grads[1]; if (m >= n) { return square_deep_case_backward(grad_Q, grad_R, self, q, r); } else { // For wide (m < n) input matrices A, partition A = [X|Y] and R = [U|V] // X and U are square full rank matrices. We will partition grads, // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y]. // To obtain grad_X we reuse the gradient formula from the square case. // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U), // where grad_Q_prime = grad_Q + Y @ grad_V^H // and grad_Y = Q @ grad_V. // Then concatenate grads to get grad_A = [grad_X | grad_Y]. auto Y = self.narrow(-1, m, n - m); auto U = r.narrow(-1, 0, m); Tensor grad_Y, grad_X, grad_V, grad_Q_prime; if (grad_R.defined()) { grad_V = grad_R.narrow(-1, m, n - m); // reuse grad_R to store grad_U grad_R = grad_R.narrow(-1, 0, m); // grad_Q_prime starts with the value of Y @ grad_V^H grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1)); } else { // when grad_R is not defined then grad_V and grad_Q_prime // get initialized with zeros grad_V = at::zeros_like(Y, LEGACY_CONTIGUOUS_MEMORY_FORMAT); grad_Q_prime = at::zeros_like(q, LEGACY_CONTIGUOUS_MEMORY_FORMAT); } if (grad_Q.defined()) { // add the grad_Q term into grad_Q_prime when defined o/w is 0 grad_Q_prime = grad_Q_prime + grad_Q; } // Calculate grad_X using the helper. 
Grad_R contains the grad_U value grad_X = square_deep_case_backward(grad_Q_prime, grad_R, self, q, U); grad_Y = at::matmul(q, grad_V); // Concatenate grad_X and grad_Y to get grad_A. return at::cat({grad_X, grad_Y}, -1); } } // Invertible case is derived from Jacobi's formula, and also can be found at: // http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) { auto singular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { Tensor u, sigma, v; std::tie(u, sigma, v) = self.svd(); auto gsigma = prod_backward(grad.unsqueeze(-1), sigma, det.unsqueeze(-1)); return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self, const Tensor& det) -> Tensor { return unsqueeze_multiple(grad * det, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1); }; if (self.dim() == 2) { if (det.item() == 0) { return singular_case_backward(grad, self, det); } else { return nonsingular_case_backward(grad, self, det); } } else { auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det)); c10::optional first_nonzero_det_index = nonzero_det_indices[0]; if (first_nonzero_det_index->size(0) == det.numel()) { // all determinants are nonzero (non-singular) return nonsingular_case_backward(grad, self, det); } auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0)); c10::optional first_zero_det_index = zero_det_indices[0]; if (first_zero_det_index->size(0) == det.numel()) { // all determinants are zero (singular) return singular_case_backward(grad, self, det); } Tensor grad_det = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // invertible case grad_det.index_put_(/*indices=*/nonzero_det_indices, /*value=*/nonsingular_case_backward(grad.index(nonzero_det_indices), self.index(nonzero_det_indices), det.index(nonzero_det_indices))); // non-invertible case, uses SVD grad_det.index_put_(/*indices=*/zero_det_indices, /*value=*/singular_case_backward(grad.index(zero_det_indices), self.index(zero_det_indices), det.index(zero_det_indices))); return grad_det; } } Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) { auto singular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor { Tensor u, sigma, v; std::tie(u, sigma, v) = self.svd(); // logdet = \sum log(sigma) auto gsigma = grad.unsqueeze(-1).div(sigma); return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad, const Tensor& self) -> Tensor { return unsqueeze_multiple(grad, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1); }; if (self.dim() == 2) { if (logdet.item() != -INFINITY) { return nonsingular_case_backward(grad, self); } else { return singular_case_backward(grad, self); } } else { auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY)); c10::optional first_finite_logdet_index = finite_logdet_indices[0]; if (first_finite_logdet_index->size(0) == logdet.numel()) { // all log determinants are finite (non-singular) return nonsingular_case_backward(grad, self); } auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY)); c10::optional first_neginf_logdet_index = neginf_logdet_indices[0]; if (first_neginf_logdet_index->size(0) == logdet.numel()) { // all log determinants are -inf (singular) return 
singular_case_backward(grad, self); } Tensor grad_logdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // invertible case grad_logdet.index_put_(/*indices=*/finite_logdet_indices, /*value=*/nonsingular_case_backward(grad.index(finite_logdet_indices), self.index(finite_logdet_indices))); // non-invertible case, uses SVD grad_logdet.index_put_(/*indices=*/neginf_logdet_indices, /*value=*/singular_case_backward(grad.index(neginf_logdet_indices), self.index(neginf_logdet_indices))); return grad_logdet; } } Tensor slogdet_backward(const Tensor& grad_logabsdet, const Tensor& self, const Tensor& signdet, const Tensor& logabsdet) { auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { Tensor u, sigma, v; std::tie(u, sigma, v) = self.svd(); // sigma has all non-negative entries (also with at least one zero entry) // so logabsdet = \sum log(abs(sigma)) // but det = 0, so backward logabsdet = \sum log(sigma) auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma); return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v); }; auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor { return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().transpose(-2, -1); }; if (self.dim() == 2) { if (signdet.item() == 0) { return singular_case_backward(grad_logabsdet, self); } else { return nonsingular_case_backward(grad_logabsdet, self); } } else { auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet)); c10::optional first_nonzero_signdet_index = nonzero_signdet_indices[0]; if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular) return nonsingular_case_backward(grad_logabsdet, self); } auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0)); c10::optional first_zero_signdet_index = zero_signdet_indices[0]; if (first_zero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are -inf (singular) return singular_case_backward(grad_logabsdet, self); } Tensor grad_slogdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // invertible case grad_slogdet.index_put_(/*indices=*/nonzero_signdet_indices, /*value=*/nonsingular_case_backward(grad_logabsdet.index(nonzero_signdet_indices), self.index(nonzero_signdet_indices))); // non-invertible case, uses SVD grad_slogdet.index_put_(/*indices=*/zero_signdet_indices, /*value=*/singular_case_backward(grad_logabsdet.index(zero_signdet_indices), self.index(zero_signdet_indices))); return grad_slogdet; } } // Reference: // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf // Sec. 2.3.1 Matrix inverse product std::tuple triangular_solve_backward( const Tensor & grad_x, const Tensor & grad_m, const Tensor & b, const Tensor & a, const Tensor & x, const bool upper, const bool transpose, const bool unitriangular, std::array output_mask) { Tensor grad_b, grad_a; if (grad_x.defined() || grad_m.defined()) { if (grad_x.defined()) { grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular)); if (output_mask[1]) { grad_a = transpose ? 
-x.conj().matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2).conj()); if (upper) { grad_a = grad_a.triu((int) unitriangular); } else { grad_a = grad_a.tril(-((int) unitriangular)); } } } if (!grad_a.defined()) { grad_a = at::zeros({1}, a.options()).expand_as(a); } if (!grad_b.defined()) { grad_b = at::zeros({1}, b.options()).expand_as(b); } if (output_mask[1] && grad_m.defined()) { grad_a = grad_a.add(grad_m); } } return std::tuple{grad_b, grad_a}; } std::tuple cholesky_solve_backward( const Tensor& grad_x, const Tensor& self, const Tensor& input2, const Tensor& result, const bool upper) { Tensor grad_self, grad_input2; if (grad_x.defined()) { grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper); Tensor common_term = at::matmul(grad_self, result.conj().transpose(-2, -1)); common_term = common_term + common_term.conj().transpose(-2, -1); if (upper) { grad_input2 = -at::matmul(input2, common_term); } else { grad_input2 = -at::matmul(common_term, input2); } } return std::tuple{grad_self, grad_input2}; } Tensor fft_c2r_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization) { // Forward is C2R (onesided) // Think of onesided C2R irfft as // 1. fill the other half by conjugate symmetry // 2. inverse C2C ifft // 3. discard the complex dimension // So backward is // 1. R2C rfft (essentially add dummy complex dimension, and dft) // 2. accumulate gradient by conjugate symmetry // since rfft results follow conjugate symmetry, we only need to // double some entries from onesided rfft results, i.e., the ones with // their reflected indices also landing out of the onesided range. So // consider the index of last dim: // i. idx = 0. // Reflected to (N - 0) % N = 0. Not doubled. // ii 0 < idx < floor(N/2) (last). // N > N - idx > ceil(N/2) // Reflected to () // iii. idx = floor(N/2) = N/2 (last) when N even. // Reflected to (N - N/2) % N = N/2. Not doubled. // iv. idx = floor(N/2) = (N-1)/2 (last) when N odd. // Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled. // Therefore, needs to double // idx = 1, 2, ..., N/2 - 1 when N even // idx = 1, 2, ..., (N-1)/2 when N odd // that is // idx = 1, 2, ..., N - (floor(N/2) + 1) // = 1, 2, ..., N - onesided_length auto gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true); auto double_length = grad.size(dim.back()) - gI.size(dim.back()); if (double_length > 0) { // also covers case when signal size is zero gI.narrow(dim.back(), 1, double_length).mul_(2); } return gI; } Tensor fft_r2c_backward(const Tensor& grad, IntArrayRef dim, int64_t normalization, bool onesided, int64_t last_dim_size) { if (!onesided) { return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false)); } // Forward is R2C (onesided) // Think of onesided R2C rfft as // 1. view as complex numbers (fill complex dim with zeros) // 2. C2C fft // 3. discard half of results // So backward is // 1. fill the other half with zeros (with `zero_grad_shape` below) // (C2C ifft only take twosided inputs so we need to fill here) // 2. inverse C2C ifft // 3. discard the complex dim auto half_sizes = grad.sizes(); at::DimVector new_grad_shape(half_sizes.begin(), half_sizes.end()); const auto last_dim = at::maybe_wrap_dim(dim.back(), half_sizes.size()); new_grad_shape[last_dim] = last_dim_size; const auto zero_length = last_dim_size - grad.size(dim.back()); auto complex_full_grad = zero_length > 0 ? 
at::zeros(new_grad_shape, grad.options()) : grad; if (zero_length > 0) { complex_full_grad.slice(last_dim, 0, half_sizes[last_dim]).copy_(grad); } return at::real(at::_fft_c2c(complex_full_grad, dim, normalization, /*forward=*/false)); } // Helper for batchnorm_double_backward Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) { auto r = to_sum.sum(0, keepdim); int64_t start_point_exclusive = keepdim ? 1 : 0; for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) { r = r.sum(dim, keepdim); } return r; } // Helper for batchnorm_double_backward // similar to expand_as below, but doesn't do the expand_as; operates as if // reductions were done with keepdim=True Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) { auto src_expanded = src; while (src_expanded.sizes().size() < target.sizes().size() - 1) { src_expanded = src_expanded.unsqueeze(1); } if (src_expanded.sizes().size() == target.sizes().size() - 1) { src_expanded = src_expanded.unsqueeze(0); } return src_expanded; } // Helper for batchnorm_double_backward // because gamma/ggG/ggB are 1-dimensional and represent dim==1, we can't // do a straight expansion because it won't follow the broadcasting rules. Tensor expand_as_dim1(const Tensor& src, const Tensor& target) { auto src_expanded = src; while (src_expanded.sizes().size() < target.sizes().size() - 1) { src_expanded = src_expanded.unsqueeze(1); } return src_expanded.expand_as(target); } std::tuple batchnorm_double_backward( const Tensor & input, const c10::optional & gamma, const Tensor & ggI, const Tensor & ggG, const Tensor & ggB, const Tensor & gO, const c10::optional & running_mean, const c10::optional & running_var, bool training, double eps, const c10::optional & save_mean, const c10::optional & save_invstd, std::array output_mask) { bool affine = isDefined(gamma); // TODO: Do we have a ScalarOrTensor type? Would such a thing exist? Tensor gamma_expanded; Tensor ggG_expanded, ggB_expanded; if (affine) { gamma_expanded = expand_as_dim1(*gamma, input); if (ggG.defined()) { ggG_expanded = expand_as_dim1(ggG, input); } if (ggB.defined()) { ggB_expanded = expand_as_dim1(ggB, input); } } else { gamma_expanded = at::ones({}, input.options()); } // define some terms we will reuse auto M = input.size(0); for (auto s : input.sizes().slice(2)) { M *= s; } // for half inputs, save_mean, save_invstd are float (ideally, we would cast // everything else, but not now) auto mu = unsqueeze_dim1(training ? toLegacyTensor(save_mean).to(input.scalar_type()) : toLegacyTensor(running_mean), input); auto input_sub_mu = input - mu; auto sigma2_eps_neg_1_2 = unsqueeze_dim1( training ? toLegacyTensor(save_invstd).to(input.scalar_type()) : toLegacyTensor(running_var).add(Scalar(eps)).pow(-0.5), input); auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2); auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3); // calculate gI auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2; auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu); auto gO_sum = sum_exclude_dim1(gO); Tensor gI; if (ggI.defined() && training) { auto ggI_sum = sum_exclude_dim1(ggI); auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu); auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_( (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. 
/ M)); auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M); auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO); auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI); gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t)); } // add contribution of gamma term to gI Tensor gI_G_term; if (affine && ggG.defined()) { if (training) { auto t0 = gO * sigma2_eps_neg_1_2; auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M); auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu)).div_(-M); gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2)); gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term; } else { gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO; gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term; } } // this is the first backward's grad_input auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor { auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M); auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_( input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu)); return h0 * h1; }; // calculate gG Tensor gG; if (affine && ggI.defined()) { if (training) { // gG is just the first backwards with the gamma term removed (then shaped properly) gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options())); gG = sum_exclude_dim1(gG, false); } else { gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false); } } // calculate ggO Tensor ggO; // contribution of input term if (ggI.defined()) { if (training) { ggO = first_back_grad_input(ggI, gamma_expanded); } else { ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded; } } if (ggG.defined()) { auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2; ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term; } if (ggB.defined()) { auto ggO_B_term = ggB_expanded; ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term; } if (output_mask[1] && !gG.defined()) { AT_ASSERTM(affine, "gamma should always be defined when it requires grad"); } return std::tuple{gI, gG, ggO}; } std::tuple infinitely_differentiable_native_layer_norm_backward( const Tensor& dY, const Tensor& dmean, const Tensor& drstd, const Tensor& X, const Tensor& mean, const Tensor& rstd, const c10::optional& gamma, IntArrayRef normalized_shape, double eps, std::array grad_input_mask) { const int normalized_ndim = normalized_shape.size(); const auto input_shape = X.sizes(); const auto input_ndim = X.dim(); const int axis = input_ndim - normalized_ndim; const int64_t M = at::prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis); const int64_t N = at::prod_intlist(input_shape.cbegin() + axis, input_shape.cend()); Tensor dX; Tensor dgamma; Tensor dbeta; const Tensor X_tensor = X.reshape({M, N}); const Tensor mean_tensor = mean.reshape({M, 1}); const Tensor rstd_tensor = rstd.reshape({M, 1}); const double s = 1.0 / static_cast(N); Tensor dY_tensor; if (dY.defined()) { dY_tensor = dY.reshape({M, N}); } if (grad_input_mask[0]) { Tensor gamma_tensor; if (isDefined(gamma)) { gamma_tensor = gamma->reshape({1, N}); } Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor; Tensor var; Tensor dvar; if (drstd.defined()) { var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0); dvar = -0.5 * rstd_cube * drstd.view({M, 1}); } Tensor ds; Tensor db; if (dY.defined()) { ds = (isDefined(gamma) ? dY_tensor * X_tensor * gamma_tensor : dY_tensor * X_tensor) .sum(1) .unsqueeze_(-1); db = (isDefined(gamma) ? 
dY_tensor * gamma_tensor : dY_tensor) .sum(1) .unsqueeze_(-1); const Tensor& a = rstd_tensor; const Tensor b = (db * mean_tensor - ds) * rstd_cube * s; const Tensor c = -b * mean_tensor - db * rstd_tensor * s; if (isDefined(gamma)) { dX = a * dY_tensor * gamma_tensor + b * X_tensor + c; } else { dX = a * dY_tensor + b * X_tensor + c; } if (dmean.defined() && drstd.defined()) { dX += var_std_mean_backward( {dvar, dmean.view({M, 1})}, X_tensor, var, mean_tensor, {1}, false, true, false); } dX = dX.reshape_as(X); } else if (dmean.defined() && drstd.defined()) { dX = var_std_mean_backward( {dvar, dmean.view({M, 1})}, X_tensor, var, mean_tensor, {1}, false, true, false) .reshape_as(X); } } if (grad_input_mask[1] && dY.defined()) { dgamma = (dY_tensor * (X_tensor - mean_tensor) * rstd_tensor) .sum(0) .reshape_as(toLegacyTensor(gamma)); } if (grad_input_mask[2] && dY.defined()) { dbeta = dY_tensor.sum(0).reshape_as(toLegacyTensor(gamma)); } return std::make_tuple(dX, dgamma, dbeta); } std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, const Tensor& dmean, const Tensor& drstd, const Tensor& X, const Tensor& mean, const Tensor& rstd, const c10::optional& gamma, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps, std::array grad_input_mask) { const int64_t G = group; const int64_t D = C / G; const double s = 1.0 / static_cast(D * HxW); Tensor dX; Tensor dgamma; Tensor dbeta; const Tensor X_tensor = X.reshape({N, G, D, HxW}); const Tensor mean_tensor = mean.reshape({N, G, 1, 1}); const Tensor rstd_tensor = rstd.reshape({N, G, 1, 1}); Tensor dY_tensor; Tensor ds; Tensor db; if (dY.defined()) { dY_tensor = dY.reshape({N, G, D, HxW}); ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1); db = dY_tensor.sum(3).unsqueeze_(-1); } if (grad_input_mask[0]) { Tensor gamma_tensor; if (isDefined(gamma)) { gamma_tensor = gamma->reshape({1, G, D, 1}); } const Tensor var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0); const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor; Tensor dvar; if (drstd.defined()) { dvar = -0.5 * rstd_cube * drstd.view({N, G, 1, 1}); } if (dY.defined()) { const Tensor a = isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor; Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2)) .unsqueeze_(-2); Tensor c = (isDefined(gamma) ? 
(db * gamma_tensor).sum(2) : db.sum(2)) .unsqueeze_(-2); b = (c * mean_tensor - b) * rstd_cube * s; c = -b * mean_tensor - c * rstd_tensor * s; dX = a * dY_tensor + b * X_tensor + c; if (dmean.defined() && drstd.defined()) { dX += var_std_mean_backward( {dvar, dmean.view({N, G, 1, 1})}, X_tensor, var, mean_tensor, {2, 3}, false, true, false); } dX = dX.reshape_as(X); } else if (dmean.defined() && drstd.defined()) { dX = var_std_mean_backward( {dvar, dmean.view({N, G, 1, 1})}, X_tensor, var, mean_tensor, {2, 3}, false, true, false) .reshape_as(X); } } if (grad_input_mask[1] && dY.defined()) { dgamma = ((ds - db * mean_tensor) * rstd_tensor).sum(0).reshape_as(toLegacyTensor(gamma)); } if (grad_input_mask[2] && dY.defined()) { dbeta = db.sum(0).reshape_as(toLegacyTensor(gamma)); } return std::make_tuple(dX, dgamma, dbeta); } std::tuple _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3, IntArrayRef expand1, IntArrayRef expand2, IntArrayRef expand3, IntArrayRef sumdim, int64_t unroll_dim, std::array grad_mask) { Tensor grad_i1, grad_i2, grad_i3; if (grad_out.defined()) { if (grad_mask[0]) grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1); if (grad_mask[1]) grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2); if (grad_mask[2]) grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3); } return std::tuple(grad_i1, grad_i2, grad_i3); } Tensor log1p_backward(const Tensor& grad, const Tensor& self) { if (self.is_sparse()) { AT_ERROR( "log1p of a sparse tensor is made to be non-differentiable since ", "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ", "Use a different mathematical operation which preserves sparsity of gradients, ", "or report a bug if you think this is an error."); } return grad / (self + 1).conj(); } Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices, IntArrayRef values_shape) { // TODO: improve this backward by writing a kernel (maybe) auto dense_grad = sparse_grad_out.is_sparse() ? sparse_grad_out.to_dense() : sparse_grad_out; auto full_size = sparse_grad_out.sizes(); auto flattened_grad_shape = values_shape.vec(); flattened_grad_shape[0] = at::prod_intlist(full_size.slice(0, indices.size(0))); auto flattened_dense_grad = dense_grad.view(flattened_grad_shape); auto flattened_indices = at::sparse::flatten_indices(indices, full_size); return flattened_dense_grad.index_select(0, flattened_indices); } // Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads]) Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) { auto negated_pad = pad.vec(); std::transform(negated_pad.cbegin(), negated_pad.cend(), negated_pad.begin(), std::negate()); return at::constant_pad_nd(grad, negated_pad, 0); } Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) { // since first backward takes care of scaling by frequency, // we don't need to worry about it here. 
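  // For example, with indices = [[0, 2], [1, 0]], padding_idx = 0, and grad
  // (the incoming grad-of-grad w.r.t. the weight) of shape [num_embeddings, D],
  // the code below selects rows {0, 2, 1, 0}, zeroes the rows whose index
  // equals padding_idx, and views the result as shape [2, 2, D].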
auto gg_weight = grad.index_select(0, indices.reshape(-1)); // reshape gradient as per the shape of indices auto size = indices.sizes().vec(); size.push_back(-1); if (padding_idx >= 0) { gg_weight.masked_fill_((indices == padding_idx).reshape({-1, 1}), 0); } return gg_weight.view(size); } Tensor index_backward(Tensor zeros_like_self, const torch::List>& indices, const Tensor& grad) { return at::_index_put_impl_(zeros_like_self, indices, grad, true, true); } Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) { if (zero_infinity) { return at::where( loss.unsqueeze(0).unsqueeze(2) == 0, at::zeros({0}, raw_grad.options()), raw_grad * grad_out.unsqueeze(0).unsqueeze(2)); } else { return raw_grad * grad_out.unsqueeze(0).unsqueeze(2); } } bool any_variable_defined(variable_list& variables) { for (auto variable : variables) { if (variable.defined()) { return true; } } return false; } } // namespace details } // namespace generated } // namespace autograd } // namespace torch
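// Illustrative usage sketch (assuming a separate test translation unit that
// links against ATen and the helpers above; not compiled here): a central
// finite-difference spot-check of atan2_backward's first output.
//
//   using torch::autograd::generated::details::atan2_backward;
//
//   at::Tensor self  = at::rand({5}, at::kDouble) + 0.5;
//   at::Tensor other = at::rand({5}, at::kDouble) + 0.5;
//   at::Tensor grad  = at::ones({5}, at::kDouble);
//
//   // analytic gradient w.r.t. `self`: grad * other / (self^2 + other^2)
//   at::Tensor dself = std::get<0>(atan2_backward(grad, self, other, {{true, true}}));
//
//   // central finite difference in the `self` argument
//   double eps = 1e-6;
//   at::Tensor dself_fd =
//       (at::atan2(self + eps, other) - at::atan2(self - eps, other)) / (2 * eps);
//
//   TORCH_CHECK(at::allclose(dself, dself_fd, /*rtol=*/1e-4, /*atol=*/1e-6));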