#include <ATen/native/nested/NestedTensorMath.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/layer_norm.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <c10/core/DeviceType.h>

#include <utility>

namespace at::native {

// See Note [nested tensor matmul] in NestedTensorMath.cpp
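// Gradients follow the standard matmul rule, applied to the nested operands:
//   grad_self  = grad @ other^T   (when grad_input_mask[0] is set)
//   grad_other = self^T @ grad    (when grad_input_mask[1] is set)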
std::tuple<Tensor, Tensor> matmul_backward_nested(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other,
    std::array<bool, 2> grad_input_mask) {
  if (!grad.defined()) {
    return std::make_tuple(Tensor(), Tensor());
  }
  Tensor grad_self, grad_other;
  if (grad_input_mask[0]) {
    grad_self = at::matmul(grad, other.transpose(-1, -2));
  }
  if (grad_input_mask[1]) {
    grad_other = at::matmul(self.transpose(-1, -2), grad);
  }
  return std::make_tuple(grad_self, grad_other);
}

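// Backward of at::linear for a nested input. The nested tensors are viewed as
// flat [total_rows, features] buffers, so each gradient is a single mm:
//   grad_input  = grad_output @ weight   (re-wrapped with the input's nested sizes)
//   grad_weight = grad_output^T @ input
//   grad_bias   = grad_output.sum(0)
// The incoming nested grad_output must be (made) contiguous.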
std::tuple<Tensor, Tensor, Tensor> nested_linear_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    std::array<bool, 3> output_mask) {
  if (!grad_output.defined()) {
    return std::tuple<Tensor, Tensor, Tensor>{Tensor(), Tensor(), Tensor()};
  }
  Tensor grad_input, grad_weight, grad_bias;
  auto grad_output_contiguous = grad_output.contiguous();
  auto* nt_grad_output = get_nested_tensor_impl(grad_output_contiguous);
  auto* nt_input = get_nested_tensor_impl(input);
  TORCH_INTERNAL_ASSERT(nt_grad_output != nullptr);
  TORCH_INTERNAL_ASSERT(nt_input != nullptr);
  TORCH_INTERNAL_ASSERT(nested_tensor_impl_is_contiguous(nt_grad_output));
  auto grad_output_buffer = nt_grad_output->get_buffer();
  auto input_buffer = nt_input->get_buffer();

  auto reshaped_grad = grad_output_buffer.reshape({-1, weight.size(0)});

  if (output_mask[0]) {
    auto grad_input_buffer = at::mm(reshaped_grad, weight).view({-1});
    auto grad_input_nt_size = nt_input->get_nested_sizes().clone();
    grad_input = wrap_buffer(grad_input_buffer, grad_input_nt_size);
  }
  if (output_mask[1]) {
    grad_weight =
        at::mm(reshaped_grad.t(), input_buffer.reshape({-1, weight.size(1)}));
  }
  if (output_mask[2]) {
    grad_bias = reshaped_grad.sum(0);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

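// Backward of softmax over a nested tensor. Each component of grad and output
// is unbound and fed to _softmax_backward_data_out, writing into a freshly
// allocated nested result with the same nested sizes. dim is shifted by one
// because unbinding removes the leading (nested) dimension.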
Tensor nested_softmax_backward(
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    ScalarType input_dtype) {
  TORCH_INTERNAL_ASSERT(grad.is_nested(), "Should be nested grad")
  TORCH_INTERNAL_ASSERT(output.is_nested(), "Should be nested output")

  auto output_ptr = get_nested_tensor_impl(output);
  auto grad_ptr = get_nested_tensor_impl(grad);
  int64_t ntensors = output_ptr->size(0);
  if (ntensors == 0) {
    return grad.clone();
  }
  int64_t positive_dim = at::maybe_wrap_dim(dim, output_ptr->dim());

  // Get the info about the output
  const Tensor &output_buffer = output_ptr->get_buffer(),
               &output_sizemat = output_ptr->get_nested_sizes();

  // Get the info about the grad
  const Tensor& grad_sizemat = grad_ptr->get_nested_sizes();

  TORCH_INTERNAL_ASSERT(output_sizemat.equal(grad_sizemat));
  Tensor grad_output =
      wrap_buffer(at::empty_like(output_buffer), output_sizemat.clone());

  // Unbind nt into individual tensor slices for calculating the derivative
  std::vector<Tensor> grad_output_unbind{grad_output.unbind()},
      grad_unbind{grad.unbind()}, output_unbind{output.unbind()};

  for (const auto i : c10::irange(ntensors)) {
    at::_softmax_backward_data_out(
        grad_output_unbind[i],
        grad_unbind[i],
        output_unbind[i],
        positive_dim - 1,
        input_dtype);
  }
  return grad_output;
}

// Rudimentary sum backward assuming the conditions in #82387
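// The gradient of a sum over the last (ragged) dimension is a broadcast: every
// element of a reduced segment receives that segment's output gradient. The
// dispatch below walks the flat buffers segment by segment to perform the copy.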
Tensor _nested_sum_backward_cpu(
    const Tensor& grad,
    const Tensor& nested_self,
    OptionalIntArrayRef opt_dims,
    bool keepdim) {
  auto nt_self = get_nested_tensor_impl(nested_self);
  auto nt_grad = get_nested_tensor_impl(grad);
  const Tensor& grad_buffer = nt_grad->get_buffer();
  const Tensor& self_buffer = nt_self->get_buffer();
  auto grad_sizes = nt_grad->get_nested_sizes();
  auto self_sizes = nt_self->get_nested_sizes();
  int64_t ntensors = nt_self->size(0);
  const Tensor& self_grad_buffer = self_buffer.new_empty(self_buffer.sizes());

  auto num_segments = at::prod(grad_sizes, -1);
  auto segment_lengths = self_sizes.select(1, -1);

  // This logic assumes for now that
  // (1) all the gradient nested tensors are contiguous
  // (2) the gradient nested tensors are stored contiguously in the buffer
  AT_DISPATCH_ALL_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, self_grad_buffer.scalar_type(), "nested_sum_dim_cpu", [&]() {
        auto* self_grad_data = self_grad_buffer.data_ptr<scalar_t>();
        const auto* output_grad_data = grad_buffer.const_data_ptr<scalar_t>();
        int64_t out_idx = 0, in_idx = 0;
        for (const auto i : c10::irange(ntensors)) {
          int64_t segments = num_segments[i].item<int64_t>();
          int64_t segment_length = segment_lengths[i].item<int64_t>();
          for (auto j = 0; j < segments; j++) {
            scalar_t output_grad = output_grad_data[out_idx];
            for (auto k = 0; k < segment_length; k++) {
              self_grad_data[in_idx] = output_grad;
              in_idx += 1;
            }
            out_idx += 1;
          }
        }
      });

  return wrap_buffer(self_grad_buffer, self_sizes);
}

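// Backward of select on a nested tensor: start from an all-zero gradient with
// the same nested structure as self, then copy the incoming grad into the
// selected slice.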
Tensor _nested_select_backward_symint(
    const Tensor& grad,
    const Tensor& nested_self,
    int64_t dim,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    c10::SymInt index) {
  auto nt_self = get_nested_tensor_impl(nested_self);
  const Tensor& self_buffer = nt_self->get_buffer();
  const auto self_sizes = nt_self->get_nested_sizes();
  const Tensor& self_grad_buffer = self_buffer.new_zeros(self_buffer.sizes());

  auto nt_grad = wrap_buffer(self_grad_buffer, self_sizes);
  nt_grad.select_symint(dim, std::move(index)).copy_(grad);

  return nt_grad;
}

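// The elementwise backwards below delegate to map_nt_binary (see
// NestedTensorUtils.h), which applies the captured ATen backward op to the two
// nested tensors and returns a nested result with matching nested sizes.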
Tensor gelu_backwards_nested(const Tensor& grad, const Tensor& self, c10::string_view approximate) {
  auto partial_gelu_backward = [approximate](auto&& PH1, auto&& PH2) {
    return at::gelu_backward(
        std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), approximate);
  };
  return map_nt_binary(grad, self, partial_gelu_backward);
}

// relu's derivative is expressed via threshold_backward (with threshold 0),
// hence the relu naming convention below.
Tensor threshold_backwards_nested(const Tensor& grad_output, const Tensor& input, const Scalar& threshold) {
  auto partial_relu_backward = [threshold](auto&& PH1, auto&& PH2) {
    return at::threshold_backward(
        std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), threshold);
  };
  return map_nt_binary(grad_output, input, partial_relu_backward);
}

// silu_backward(Tensor grad_output, Tensor self) -> Tensor
Tensor silu_backward_nested(const Tensor& grad_output, const Tensor& self) {
  auto partial_silu_backward = [](auto&& PH1, auto&& PH2) {
    return at::silu_backward(
        std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
  };
  return map_nt_binary(grad_output, self, partial_silu_backward);
}

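// Backward of layer_norm for nested input. M is the total number of normalized
// rows across the nested components and N the number of elements per row; the
// flat grad and input buffers are handed to the regular LayerNormBackwardKernel,
// and dInput is re-wrapped with the input's nested sizes.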
std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_nested(
    const Tensor& grad,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<Tensor>& bias_opt /* optional */,
    std::array<bool, 3> grad_input_mask) {
  // For NestedTensors weight and bias are non nested.
  auto* nt_impl_grad = get_nested_tensor_impl(grad);
  auto* nt_impl_input = get_nested_tensor_impl(input);
  const auto& weight = *weight_opt;
  const auto& bias = *bias_opt;
  const auto& sizes = nt_impl_input->get_nested_sizes();
  auto M_N = _check_nested_layer_norm_inputs(
      *nt_impl_input, normalized_shape, weight, bias);
  auto M = M_N.first;
  auto N = M_N.second;

  auto gamma = weight.expect_contiguous();
  auto beta = bias.expect_contiguous();

  Tensor dInput;
  Tensor dgamma;
  Tensor dbeta;
  auto input_buffer = nt_impl_input->get_buffer();
  auto grad_buffer = nt_impl_grad->get_buffer();
  // NOLINTNEXTLINE(bugprone-branch-clone)
  if (grad_input_mask[0]) {
    dInput = at::native::empty_like(
        input_buffer,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  } else {
    dInput = at::native::zeros_like(
        input_buffer,
        std::nullopt /* dtype */,
        std::nullopt /* layout */,
        std::nullopt /* device */,
        std::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[1]) {
    dgamma = M > 0 ? at::native::empty_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous)
                   : at::native::zeros_like(
                         *gamma,
                         std::nullopt /* dtype */,
                         std::nullopt /* layout */,
                         std::nullopt /* device */,
                         std::nullopt /* pin_memory */,
                         at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[2]) {
    dbeta = M > 0 ? at::native::empty_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous)
                  : at::native::zeros_like(
                        *beta,
                        std::nullopt /* dtype */,
                        std::nullopt /* layout */,
                        std::nullopt /* device */,
                        std::nullopt /* pin_memory */,
                        at::MemoryFormat::Contiguous);
  }
  if (M > 0) {
    LayerNormBackwardKernel(
        input_buffer.is_cuda() ? kCUDA : kCPU,
        grad_buffer,
        input_buffer,
        mean,
        rstd,
        *gamma,
        M,
        N,
        &dInput,
        &dgamma,
        &dbeta);
  }
  return std::make_tuple(
      wrap_buffer(dInput, sizes), std::move(dgamma), std::move(dbeta));
}

} // namespace at::native