// pytorch/aten/src/ATen/native/nested/NestedTensorBackward.cpp

#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/layer_norm.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <c10/core/DeviceType.h>
#include <utility>
namespace at::native {

// See Note [nested tensor matmul] in NestedTensorMath.cpp
std::tuple<Tensor, Tensor> matmul_backward_nested(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other,
    std::array<bool, 2> grad_input_mask) {
  if (!grad.defined()) {
    return std::make_tuple(Tensor(), Tensor());
  }
  Tensor grad_self, grad_other;
  if (grad_input_mask[0]) {
    grad_self = at::matmul(grad, other.transpose(-1, -2));
  }
  if (grad_input_mask[1]) {
    grad_other = at::matmul(self.transpose(-1, -2), grad);
  }
  return std::make_tuple(grad_self, grad_other);
}
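// The identities used above: for Y = A @ B, the chain rule gives
// dA = dY @ B^T and dB = A^T @ dY. at::matmul already handles nested
// tensors, so one call applies the identity to every constituent pair.
// A minimal illustrative sketch (shapes chosen arbitrarily):
//
//   auto a = at::_nested_tensor_from_tensor_list(
//       {at::randn({3, 4}), at::randn({5, 4})});
//   auto b = at::_nested_tensor_from_tensor_list(
//       {at::randn({4, 6}), at::randn({4, 6})});
//   auto dy = at::matmul(a, b); // stand-in for an upstream gradient
//   auto grads = matmul_backward_nested(dy, a, b, {true, true});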
std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    std::array<bool, 3> output_mask) {
  if (!grad_output.defined()) {
    return std::tuple<Tensor, Tensor, Tensor>{Tensor(), Tensor(), Tensor()};
  }
  Tensor grad_input, grad_weight, grad_bias;
  auto grad_output_contiguous = grad_output.contiguous();
  auto* nt_grad_output = get_nested_tensor_impl(grad_output_contiguous);
  auto* nt_input = get_nested_tensor_impl(input);
  TORCH_INTERNAL_ASSERT(nt_grad_output != nullptr);
  TORCH_INTERNAL_ASSERT(nt_input != nullptr);
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(nt_grad_output));
  auto grad_output_buffer = nt_grad_output->get_buffer();
  auto input_buffer = nt_input->get_buffer();
  auto reshaped_grad = grad_output_buffer.reshape({-1, weight.size(0)});
  if (output_mask[0]) {
    auto grad_input_buffer = at::mm(reshaped_grad, weight).view({-1});
    auto grad_input_nt_size = nt_input->get_nested_sizes().clone();
    grad_input = wrap_buffer(grad_input_buffer, grad_input_nt_size);
  }
  if (output_mask[1]) {
    grad_weight =
        at::mm(reshaped_grad.t(), input_buffer.reshape({-1, weight.size(1)}));
  }
  if (output_mask[2]) {
    grad_bias = reshaped_grad.sum(0);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}
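// The dense identities behind the buffer manipulation above: flattening the
// ragged leading dimensions gives x of shape [total_rows, in_features], and
// for y = x @ W^T + b,
//   dx = dy @ W        -> [total_rows, in_features]
//   dW = dy^T @ x      -> [out_features, in_features]
//   db = dy.sum(0)     -> [out_features]
// Since all constituents share the trailing feature dimension, each term is
// a single flat mm / sum over the packed buffer.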
Tensor nested_softmax_backward(
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    ScalarType input_dtype) {
  TORCH_INTERNAL_ASSERT(grad.is_nested(), "Should be nested grad");
  TORCH_INTERNAL_ASSERT(output.is_nested(), "Should be nested output");
  auto output_ptr = get_nested_tensor_impl(output);
  auto grad_ptr = get_nested_tensor_impl(grad);
  int64_t ntensors = output_ptr->size(0);
  if (ntensors == 0) {
    return grad.clone();
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, output_ptr->dim());
  // Get the info about the output
  const Tensor &output_buffer = output_ptr->get_buffer(),
      &output_sizemat = output_ptr->get_nested_sizes();
  // Get the info about the grad
  const Tensor& grad_sizemat = grad_ptr->get_nested_sizes();
  TORCH_INTERNAL_ASSERT(output_sizemat.equal(grad_sizemat));
  Tensor grad_output =
      wrap_buffer(at::empty_like(output_buffer), output_sizemat.clone());
  // Unbind nt into individual tensor slices for calculating the derivative
  std::vector<Tensor> grad_output_unbind{grad_output.unbind()},
      grad_unbind{grad.unbind()}, output_unbind{output.unbind()};
  for (const auto i : c10::irange(ntensors)) {
    at::_softmax_backward_data_out(
        grad_output_unbind[i],
        grad_unbind[i],
        output_unbind[i],
        positive_dim - 1,
        input_dtype);
  }
  return grad_output;
}
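// Each per-constituent call computes the standard softmax backward
//   dx = y * (dy - (dy * y).sum(dim, /*keepdim=*/true)),
// where y is the saved softmax output. `positive_dim - 1` compensates for
// unbind() having stripped the leading nested dimension from each slice.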
// Rudimentary sum backward assuming the conditions in #82387
Tensor _nested_sum_backward_cpu(
    const Tensor& grad,
    const Tensor& nested_self,
    OptionalIntArrayRef opt_dims,
    bool keepdim) {
  auto nt_self = get_nested_tensor_impl(nested_self);
  auto nt_grad = get_nested_tensor_impl(grad);
  const Tensor& grad_buffer = nt_grad->get_buffer();
  const Tensor& self_buffer = nt_self->get_buffer();
  auto grad_sizes = nt_grad->get_nested_sizes();
  auto self_sizes = nt_self->get_nested_sizes();
  int64_t ntensors = nt_self->size(0);
  const Tensor& self_grad_buffer = self_buffer.new_empty(self_buffer.sizes());

  auto num_segments = at::prod(grad_sizes, -1);
  auto segment_lengths = self_sizes.select(1, -1);

  // This logic assumes for now that
  // (1) all the gradient nested tensors are contiguous
  // (2) the gradient nested tensors are stored contiguously in the buffer
  AT_DISPATCH_ALL_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      self_grad_buffer.scalar_type(),
      "nested_sum_dim_cpu",
      [&]() {
        auto* self_grad_data = self_grad_buffer.data_ptr<scalar_t>();
        const auto* output_grad_data = grad_buffer.const_data_ptr<scalar_t>();
        int64_t out_idx = 0, in_idx = 0;
        for (const auto i : c10::irange(ntensors)) {
          int64_t segments = num_segments[i].item<int64_t>();
          int64_t segment_length = segment_lengths[i].item<int64_t>();
          for (int64_t j = 0; j < segments; j++) {
            scalar_t output_grad = output_grad_data[out_idx];
            for (int64_t k = 0; k < segment_length; k++) {
              self_grad_data[in_idx] = output_grad;
              in_idx += 1;
            }
            out_idx += 1;
          }
        }
      });

  return wrap_buffer(self_grad_buffer, self_sizes);
}
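// The backward of a sum over the last dimension is a broadcast: each element
// of the output gradient is replicated across the segment_length input
// positions it was reduced from. E.g. a [2, 3] constituent summed over -1
// yields a [2] gradient (num_segments = 2, segment_length = 3), and each of
// the two gradient values fills three consecutive slots of the buffer.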
Tensor _nested_select_backward_symint(
    const Tensor& grad,
    const Tensor& nested_self,
    int64_t dim,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    c10::SymInt index) {
  auto nt_self = get_nested_tensor_impl(nested_self);
  const Tensor& self_buffer = nt_self->get_buffer();
  const auto self_sizes = nt_self->get_nested_sizes();
  const Tensor& self_grad_buffer = self_buffer.new_zeros(self_buffer.sizes());
  auto nt_grad = wrap_buffer(self_grad_buffer, self_sizes);
  nt_grad.select_symint(dim, std::move(index)).copy_(grad);
  return nt_grad;
}
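// select backward is a scatter: the incoming gradient is copied into the
// selected slice of an otherwise-zero tensor shaped like the nested input,
// since the non-selected elements contributed nothing to the output.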
Tensor gelu_backwards_nested(
    const Tensor& grad,
    const Tensor& self,
    c10::string_view approximate) {
  auto partial_gelu_backward = [approximate](auto&& PH1, auto&& PH2) {
    return at::gelu_backward(
        std::forward<decltype(PH1)>(PH1),
        std::forward<decltype(PH2)>(PH2),
        approximate);
  };
  return map_nt_binary(grad, self, partial_gelu_backward);
}
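// For the exact ('none') variant, gelu(x) = x * Phi(x) with Phi the standard
// normal CDF, so gelu'(x) = Phi(x) + x * phi(x) (phi the normal pdf); the
// dense at::gelu_backward kernel evaluates this (or its tanh approximation)
// per element.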
// relu has no dedicated backward; autograd routes it through
// threshold_backward, hence the naming below.
Tensor threshold_backwards_nested(
    const Tensor& grad_output,
    const Tensor& input,
    const Scalar& threshold) {
  auto partial_relu_backward = [threshold](auto&& PH1, auto&& PH2) {
    return at::threshold_backward(
        std::forward<decltype(PH1)>(PH1),
        std::forward<decltype(PH2)>(PH2),
        threshold);
  };
  return map_nt_binary(grad_output, input, partial_relu_backward);
}
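// threshold_backward passes the incoming gradient through wherever the input
// exceeds the threshold and zeroes it elsewhere; with threshold = 0 that is
// exactly the relu derivative.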
// silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
Tensor silu_backward_nested(const Tensor& grad_output, const Tensor& self) {
  auto partial_silu_backward = [](auto&& PH1, auto&& PH2) {
    return at::silu_backward(
        std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
  };
  return map_nt_binary(grad_output, self, partial_silu_backward);
}
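// silu(x) = x * sigmoid(x), so silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
// All three elementwise backwards above share one pattern: wrap the dense
// backward kernel in a two-argument lambda and let map_nt_binary apply it
// across the nested (grad, input) pair, producing a nested result with the
// input's ragged shape.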
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_nested(
    const Tensor& grad,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
  // For NestedTensors weight and bias are non-nested.
  auto* nt_impl_grad = get_nested_tensor_impl(grad);
  auto* nt_impl_input = get_nested_tensor_impl(input);
  const auto& weight = *weight_opt;
  const auto& bias = *bias_opt;
  const auto& sizes = nt_impl_input->get_nested_sizes();
  auto M_N = _check_nested_layer_norm_inputs(
      *nt_impl_input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;
  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor dInput;
  Tensor dgamma;
  Tensor dbeta;
  auto input_buffer = nt_impl_input->get_buffer();
  auto grad_buffer = nt_impl_grad->get_buffer();
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (grad_input_mask[0]) {
    dInput = at::native::empty_like(
        input_buffer,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  } else {
    dInput = at::native::zeros_like(
        input_buffer,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[1]) {
    dgamma = M > 0 ? at::native::empty_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous)
                   : at::native::zeros_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[2]) {
    dbeta = M > 0 ? at::native::empty_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous)
                  : at::native::zeros_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous);
  }
  if (M > 0) {
    LayerNormBackwardKernel(
        input_buffer.is_cuda() ? kCUDA : kCPU,
        grad_buffer,
        input_buffer,
        mean,
        rstd,
        *gamma,
        M,
        N,
        &dInput,
        &dgamma,
        &dbeta);
  }
  return std::make_tuple(
      wrap_buffer(dInput, sizes), std::move(dgamma), std::move(dbeta));
}
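// The nested layer norm backward reduces to the dense 2D case: M is in
// effect the total number of normalization groups across all constituents
// (the buffer's numel divided by N) and N = prod(normalized_shape), so the
// packed buffers can be handed to the regular LayerNormBackwardKernel as an
// [M, N] problem. Only dInput has ragged structure and needs rewrapping;
// dgamma and dbeta stay dense.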
} // namespace at::native