[3/N] Avoid copy in std::get (#141843)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141843
Approved by: https://github.com/Skylion007
This commit is contained in:
parent add4a42ea2
commit 1fa27f6e82
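The hunks below all chase the same inefficiency: calling std::get on a *named* tuple yields an lvalue, so assigning the element into a new variable copies it (for Tensor that is a refcount bump, for vectors and strings a real allocation). The fixes either move out of the tuple explicitly, destructure it with structured bindings or std::tie, or forward the whole tuple with std::apply. A minimal standalone sketch of the core rule, with illustrative names that are not part of the PR:

#include <string>
#include <tuple>
#include <utility>

std::tuple<std::string, std::string> make_two_strings() {
  return {"first", "second"};
}

void demo() {
  auto t = make_two_strings();
  std::string a = std::get<0>(t);             // copies: t is an lvalue, get returns std::string&
  std::string b = std::move(std::get<1>(t));  // moves: the cast to rvalue has to be explicit
  std::string c = std::get<0>(make_two_strings());  // moves: get on an rvalue tuple returns std::string&&
  (void)a; (void)b; (void)c;
}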
@@ -76,7 +76,7 @@ void BatchedTensorImpl::checkInvariants() const {
   }
 }

-// The following are publically exposed as methods of Tensor
+// The following are publicly exposed as methods of Tensor

 IntArrayRef BatchedTensorImpl::strides_custom() const {
   return strides_default();
@@ -113,7 +113,7 @@ const char* BatchedTensorImpl::tensorimpl_type_name() const {
   return "BatchedTensorImpl";
 }

-Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
+Tensor makeBatched(Tensor tensor, BatchDims bdims) {
   TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
   auto tensor_dim = tensor.dim();
   TORCH_CHECK(
@@ -124,15 +124,15 @@ Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
       std::all_of(bdims.begin(), bdims.end(),
           [](const BatchDim& bdim) { return bdim.level() < kVmapNumLevels; }),
       "We only support up to ", kVmapNumLevels, " nested vmaps");
-  return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
+  return at::detail::make_tensor<BatchedTensorImpl>(std::move(tensor), std::move(bdims));
 }

-Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
+Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim) {
   const auto* batched = maybeGetBatchedImpl(tensor);
   if (!batched) {
     BatchDims bdims;
     bdims.emplace_back(level, dim);
-    return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
+    return at::detail::make_tensor<BatchedTensorImpl>(std::move(tensor), std::move(bdims));
   }
   BatchDims new_bdims(batched->bdims().begin(), batched->bdims().end());
   auto actual_bdim = batched->actualDim(dim, /*wrap_dim=*/true);
@@ -148,10 +148,10 @@ inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
 }

 // Use this to construct a BatchedTensor from a regular Tensor
-TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
+TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims);

 // Adds a batch dim to `tensor`, returning a BatchedTensor
-TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
+TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim);

 // Checks if an inplace operation on self and other is "vmap compatible".
 // See NOTE: [vmap-incompatible in-place operations] for the definition of this.
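Related to the const Tensor& → Tensor signature changes above: taking a "sink" argument by value and moving from it lets callers that pass a temporary avoid any copy, while callers that pass an lvalue pay exactly the one copy they always needed. A sketch of the idiom with stand-in types (Handle and Wrapper are illustrative, not ATen):

#include <utility>

struct Handle {};  // stand-in for an expensive-to-copy, cheap-to-move type

struct Wrapper {
  Handle h_;
  explicit Wrapper(Handle h) : h_(std::move(h)) {}  // take by value, then move into place
};

Wrapper wrap(Handle h) {
  // Callers passing a temporary pay zero copies (only moves);
  // callers passing an lvalue pay the single copy made at the call site.
  return Wrapper(std::move(h));
}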
@@ -353,7 +353,7 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
   auto t1 = at::zeros({1});
   auto t2 = at::zeros({1});

-  std::tuple<at::Tensor&, at::Tensor&> tup = func.call<
+  auto [t1_out, t2_out] = func.call<
       std::tuple<at::Tensor&, at::Tensor&>, at::Scalar, at::Scalar, at::Tensor&, at::Tensor&
   >(dummy, CPU_TEST_SET, s1, s2, t1, t2);

@@ -361,11 +361,9 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
   EXPECT_EQ(t1.item().toFloat(), 1.0f);
   EXPECT_EQ(t2.item().toFloat(), 2.0f);

-  auto t1_out = std::get<0>(tup);
   EXPECT_EQ(t1_out.item().toFloat(), 1.0f);
   EXPECT_TRUE(t1_out.is_same(t1));

-  auto t2_out = std::get<1>(tup);
   EXPECT_EQ(t2_out.item().toFloat(), 2.0f);
   EXPECT_TRUE(t2_out.is_same(t2));
 }
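In the test above, func.call returns std::tuple<at::Tensor&, at::Tensor&>; a structured binding names the two reference elements directly, so no Tensor is copied and the separate std::get<0>/std::get<1> lines become unnecessary. The same mechanics with plain int&, as a sketch:

#include <tuple>

void bind_references(int& x, int& y) {
  std::tuple<int&, int&> tup(x, y);
  auto [x_ref, y_ref] = tup;  // the bindings are the stored int& members:
  x_ref = 1;                  // nothing is copied, and writes go through to x and y
  y_ref = 2;
}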
@@ -218,47 +218,43 @@ static std::tuple<Tensor, std::optional<int64_t>> masked_select_backward_batch_r

 static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule(
     const Tensor& grad, std::optional<int64_t> grad_bdim,
-    const Tensor& x1, std::optional<int64_t> x1_bdim,
-    const Tensor& x2, std::optional<int64_t> x2_bdim,
+    Tensor x1, std::optional<int64_t> x1_bdim,
+    Tensor x2, std::optional<int64_t> x2_bdim,
     const double p,
     const Tensor& cdist, std::optional<int64_t> cdist_bdim) {

-  auto x1_ = x1;
   if (cdist_bdim && !x1_bdim) {
     // We need to make sure that x1 has batch dim if cdist has one
     // otherwise, we get
     // RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
     // but expected shape compatible with [4, 5]
     auto bs = cdist.size(*cdist_bdim);
-    x1_ = ensure_has_bdim(x1, false, bs);
-    x1_ = x1_.contiguous();
+    x1 = ensure_has_bdim(x1, false, bs).contiguous();
     x1_bdim = 0;
   }

   // We need to apply the same preprocessing on x1 and x2 as in the forward pass
   // _binary_pointwise_batch_rule
-  auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
-  x1_ = std::move(std::get<0>(x12));
-  auto& x2_ = std::get<1>(x12);
+  std::tie(x1, x2) = _binary_pointwise_helper(x1, x1_bdim, x2, x2_bdim);

   auto grad_ = moveBatchDimToFront(grad, grad_bdim);
   if ((x1_bdim || x2_bdim) && !grad_bdim) {
     // We need to make sure that grad has batch dim if x1 or x2 have one
     // Probably, there is an assumption on the strides.
     // Otherwise grad input contains thrash values, e.g. -7.0816e+29, 7.0816e+29
-    auto bs = get_bdim_size2(x1_, 0, x2_, 0);
+    auto bs = get_bdim_size2(x1, 0, x2, 0);
     grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
     grad_ = grad_.contiguous();
   }

-  auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist);
+  auto out = at::_cdist_backward(grad_, x1, x2, p, cdist);

   std::optional<int64_t> out_bdim = std::nullopt;
   if (x1_bdim || x2_bdim) {
     out_bdim = 0;
   }

-  return std::make_tuple(out, out_bdim);
+  return std::make_tuple(std::move(out), out_bdim);
 }

 static void fill__Tensor_batch_rule(
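cdist_backward_batch_rule now assigns the helper's results straight back into its by-value parameters with std::tie, instead of keeping a named tuple and copying or moving elements out of it. Because the right-hand side is a temporary tuple, its elements are move-assigned. A reduced sketch (std::vector<int> stands in for Tensor, and preprocess is an illustrative name):

#include <tuple>
#include <utility>
#include <vector>

std::tuple<std::vector<int>, std::vector<int>> preprocess(std::vector<int> a,
                                                          std::vector<int> b) {
  return {std::move(a), std::move(b)};
}

void sketch(std::vector<int> a, std::vector<int> b) {
  // Assign straight into the existing locals; the right-hand side is a temporary
  // tuple, so both vectors are move-assigned rather than copied.
  std::tie(a, b) = preprocess(std::move(a), std::move(b));
}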
@@ -42,6 +42,7 @@ static Tensor padRight(const Tensor& tensor, std::optional<int64_t> has_bdim, in
 }

 template<typename F, F Func>
+static
 std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
 batch_norm_batch_rule(
     const Tensor& input, std::optional<int64_t> input_bdim,
@@ -70,10 +71,10 @@ batch_norm_batch_rule(
   if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
     const auto dummy_weight = at::ones(input.size(1), input.options()); // cudnn and miopen require a weight
     const auto dummy_bias = at::zeros(input.size(1), input.options()); // without this, get "strides() called on undefined Tensor" on cuda
-    const auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
+    auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
     result0 = std::get<0>(result).transpose(0, 1); // [C, B, *]
-    mean = std::get<1>(result);
-    rstd = std::get<2>(result);
+    mean = std::move(std::get<1>(result));
+    rstd = std::move(std::get<2>(result));
   } else {
     bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
     auto input_ = moveBatchDimToFront(input, input_bdim);
@@ -95,12 +96,12 @@ batch_norm_batch_rule(

     const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
     const auto dummy_bias = at::zeros(input_.size(1), input_.options()); // without this, get "strides() called on undefined Tensor" on cuda
-    const auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
+    auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
     result0 = std::get<0>(result).transpose(0, 1); // [(B0, C), B, *]
+    mean = std::move(std::get<1>(result));
+    rstd = std::move(std::get<2>(result));
     result0 = reshape_dim_outof(0, bdim_size.value(), result0); // [B0, C, B, *]
-    mean = std::get<1>(result);
     mean = reshape_dim_outof(0, bdim_size.value(), mean); // [B0, C]
-    rstd = std::get<2>(result);
     rstd = reshape_dim_outof(0, bdim_size.value(), rstd); // [B0, C]
   }

@@ -124,6 +125,7 @@ batch_norm_batch_rule(
 }

 template<typename F, F Func>
+static
 std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
@@ -142,9 +144,9 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
     TORCH_INTERNAL_ASSERT(!mean_bdim);
     TORCH_INTERNAL_ASSERT(!rstd_bdim);
     const auto dummy_weight = at::ones(input.size(1), input.options());
-    const auto result = Func(
+    auto result = Func(
         grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false});
-    return std::make_tuple(std::get<0>(result), std::nullopt);
+    return {std::move(std::get<0>(result)), std::nullopt};
   }

   auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
@@ -196,6 +198,7 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
 }

 template<typename F, F Func>
+static
 std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
     const at::Tensor & grad_out,
     const at::Tensor & input,
@@ -270,7 +273,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
       unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *]

   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-  const auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
+  auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
       grad_normalized_input_value, grad_normalized_input_bdim,
       input_value, input_bdim,
      running_mean_value, running_mean_bdim,
@@ -278,7 +281,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
      save_mean_value, save_mean_bdim,
      save_rstd_value, save_rstd_bdim,
      training, eps);
-  grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
+  grad_input = makeBatched(std::move(std::get<0>(results)), std::get<1>(results), cur_level);
   }
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
@@ -312,16 +315,13 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
     const auto bdim_size = input_value.size(*input_bdim);

     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
-    result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level);
-    mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level);
-    rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level);
+    std::tie(result0, mean, rstd) = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
+    result0 = makeBatched(reshape_dim_outof(0, bdim_size, result0), 0, cur_level);
+    mean = makeBatched(reshape_dim_outof(0, bdim_size, mean), 0, cur_level);
+    rstd = makeBatched(reshape_dim_outof(0, bdim_size, rstd), 0, cur_level);
   } else {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
-    result0 = std::get<0>(result);
-    mean = std::get<1>(result);
-    rstd = std::get<2>(result);
+    std::tie(result0, mean, rstd) = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
   }

   if (weight.defined()) {
@@ -334,10 +334,10 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
     result0 = result0 + padded_bias;
   }

-  return std::make_tuple(result0, mean, rstd);
+  return std::make_tuple(std::move(result0), std::move(mean), std::move(rstd));
 }

-static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
+static at::Tensor group_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
     const at::Tensor & mean, std::optional<int64_t> mean_bdim,
@@ -359,15 +359,13 @@ static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_wei
   mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
   rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]

-  const auto result = native_group_norm_backward(
+  auto result0 = std::get<0>(native_group_norm_backward(
       grad_out_.contiguous(),
       input_.contiguous(),
       mean_.contiguous(),
       rstd_.contiguous(),
-      std::nullopt, N * bdim_size, C, HxW, group, {true, false, false});
-  auto result0 = std::get<0>(result);
-  result0 = reshape_dim_outof(0, bdim_size, result0);
-  return std::make_tuple(result0, 0);
+      std::nullopt, N * bdim_size, C, HxW, group, {true, false, false}));
+  return reshape_dim_outof(0, bdim_size, result0);
 }

 static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
@@ -422,19 +420,19 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
         unwrapTensorAtLevel(grad_normalized_input, cur_level);

     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto res = group_norm_backward_no_weight_bias_batch_rule(
+    auto tensor = group_norm_backward_no_weight_bias_batch_rule(
         grad_normalized_input_value, grad_normalized_input_bdim,
         input_value, input_bdim,
         mean_value, mean_bdim,
         rstd_value, rstd_bdim,
         N, C, HxW, group
     );
-    grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
+    grad_input = makeBatched(std::move(tensor), 0, cur_level);
   }
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }

-C10_ALWAYS_INLINE bool has_same_shape(
+static bool has_same_shape(
     const Tensor& tensor, std::optional<int64_t> tensor_bdim,
     c10::SymIntArrayRef normalized_shape) {
   if (!tensor.defined()) {
@@ -457,7 +455,7 @@ C10_ALWAYS_INLINE bool has_same_shape(
   return true;
 }

-C10_ALWAYS_INLINE void check_same_shape(
+static C10_ALWAYS_INLINE void check_same_shape(
     const Tensor& tensor, std::optional<int64_t> tensor_bdim,
     c10::SymIntArrayRef normalized_shape, const std::string& name) {
   TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape),
@@ -469,7 +467,7 @@ C10_ALWAYS_INLINE void check_same_shape(
 }

 // Ugh, hard to deduplicate
-C10_ALWAYS_INLINE void _check_layer_norm_inputs(
+static C10_ALWAYS_INLINE void _check_layer_norm_inputs(
     SymIntArrayRef normalized_shape,
     const Tensor& weight, std::optional<int64_t> weight_bdim,
     const Tensor& bias, std::optional<int64_t> bias_bdim) {
@@ -493,11 +491,9 @@ native_layer_norm_batch_rule(
     double eps) {
   auto input_ = moveBatchDimToFront(input, input_bdim);
   if (!weight_bdim && !bias_bdim) {
-    const auto result = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
-    const auto mean = std::get<1>(result);
-    const auto rstd = std::get<2>(result);
+    auto [result0, mean, rstd] = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
     const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
-    return std::make_tuple(std::get<0>(result), 0, mean, stats_bdim, rstd, stats_bdim);
+    return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
   }

   // See [Note: hacky wrapper removal for optional tensor]
@@ -509,9 +505,7 @@ native_layer_norm_batch_rule(

   const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
   const auto result = at::native_layer_norm_symint(input_, normalized_shape, std::nullopt, std::nullopt, eps);
-  auto result0 = std::get<0>(result);
-  const auto mean = std::get<1>(result);
-  const auto rstd = std::get<2>(result);
+  auto [result0, mean, rstd] = result;
   const auto stats_bdim = compute_stat_bdim(input_bdim, mean);

   if (weight.defined()) {
@@ -638,7 +632,7 @@ static std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_p
         unwrapTensorAtLevel(grad_normalized_input, cur_level);

     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
+    auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
         grad_normalized_input_value, grad_normalized_input_bdim,
         input_value, input_bdim,
         normalized_shape,
@@ -171,18 +171,18 @@ void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>&
   TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
 }

-Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) {
+Tensor makeBatched(Tensor tensor, int64_t bdim, int64_t level) {
   DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor);
   auto* batched = maybeGetBatchedImpl(tensor);
   if (batched) {
     auto batched_level = batched->level();
     TORCH_INTERNAL_ASSERT(level > batched_level, " batched_level: ", batched_level, " level: ", level);
   }
-  return at::detail::make_tensor<BatchedTensorImpl>(key_set, tensor, bdim, level);
+  return at::detail::make_tensor<BatchedTensorImpl>(key_set, std::move(tensor), bdim, level);
 }

-Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level) {
-  return makeBatched(tensor, dim, level);
+Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level) {
+  return makeBatched(std::move(tensor), dim, level);
 }

 } // namespace at::functorch
@@ -144,10 +144,10 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
 }

 // Use this to construct a BatchedTensor from a regular Tensor
-TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);
+TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level);

 // Adds a batch dim to `tensor`, returning a BatchedTensor
-TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
+TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level);

 // Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
 // any wrapper Tensor subclasses). This is because there are methods on Tensor
@@ -22,20 +22,20 @@ void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* wh
   )
 }

-Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim, int64_t level) {
+Tensor makeBatched(Tensor tensor, std::optional<int64_t> bdim, int64_t level) {
   if (bdim.has_value()) {
     TORCH_INTERNAL_ASSERT(*bdim >= 0);
     TORCH_INTERNAL_ASSERT(*bdim < tensor.dim());
-    return makeBatched(tensor, bdim.value(), level);
+    return makeBatched(std::move(tensor), bdim.value(), level);
   }
   return tensor;
 }

-std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::optional<int64_t> bdim, int64_t level) {
+std::vector<Tensor> makeBatchedVector(std::vector<Tensor> tensors, std::optional<int64_t> bdim, int64_t level) {
   std::vector<Tensor> res;
   res.reserve(tensors.size());
-  for (const auto & tensor : tensors) {
-    res.emplace_back(makeBatched(tensor, bdim, level));
+  for (auto & tensor : tensors) {
+    res.emplace_back(makeBatched(std::move(tensor), bdim, level));
   }
   return res;
 }
@@ -29,7 +29,7 @@ namespace at::functorch {
 void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* what);

 // Create a BatchedTensor given a tensor, bdim, and level
-TORCH_API Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim, int64_t level);
+TORCH_API Tensor makeBatched(Tensor tensor, std::optional<int64_t> bdim, int64_t level);

 // Given a Tensor that may or may not be a BatchedTensor, unwrap it.
 // If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level
@@ -38,7 +38,7 @@ TORCH_API Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim,
 TORCH_API std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);

 // Creates a vector of BatchedTensor
-TORCH_API std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::optional<int64_t> bdim, int64_t level);
+TORCH_API std::vector<Tensor> makeBatchedVector(std::vector<Tensor> tensors, std::optional<int64_t> bdim, int64_t level);

 // Returns True if ANY tensor in tensors is batched at level
 TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level);
@@ -55,9 +55,7 @@ void unpack_bcsr(
       memset(dst + i * C, zero_points[i], C * sizeof(int8_t));
     }
   }
-  const std::vector<int8_t>& weight_values = std::get<0>(bcsr);
-  const std::vector<int32_t>& row_indices = std::get<1>(bcsr);
-  const std::vector<int32_t>& col_indices = std::get<2>(bcsr);
+  const auto& [weight_values, row_indices, col_indices] = bcsr;
   int64_t rowBlocks = (R + RB - 1) / RB;
   for (int64_t i = 0; i < rowBlocks; ++i) {
     // For the current tile, rowBPtr starts from currentTileIdx
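In unpack_bcsr the three vectors are now named through a const auto& structured binding, so none of them is copied; each name is a reference into the tuple the caller already owns. Sketch (BcsrTuple is an assumed alias for illustration, not the real typedef):

#include <cstdint>
#include <tuple>
#include <vector>

using BcsrTuple =
    std::tuple<std::vector<int8_t>, std::vector<int32_t>, std::vector<int32_t>>;

void sketch(const BcsrTuple& bcsr) {
  // One declaration, three names, zero copies: each name is a const reference
  // into the tuple that was passed in.
  const auto& [weight_values, row_indices, col_indices] = bcsr;
  (void)weight_values; (void)row_indices; (void)col_indices;
}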
@@ -316,4 +314,4 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp(
 }
 #endif // USE_PYTORCH_QNNPACK

-} // namespace ao
+} // namespace ao::sparse
@@ -25,9 +25,9 @@ __global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T

   // do random permutation inside each island.
   data += tid;
-  auto seeds = at::cuda::philox::unpack(philox_args);
+  auto [seed, offset] = at::cuda::philox::unpack(philox_args);
   curandStatePhilox4_32_10_t state;
-  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
+  curand_init(seed, tid, offset, &state);
   for (int i = island_size - 1; i > 0; i--) {
     unsigned int r = curand(&state) % (i + 1);
     if (i != r) {
@@ -241,7 +241,8 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
   auto outputs = at::native_batch_norm(
       input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
       /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
-  auto out = std::get<0>(outputs).view(input_shape);
+  auto& [out, mean, rstd] = outputs;
+  out = out.view(input_shape);
   if (weight.defined() && bias.defined()) {
     out = bias.addcmul(out, weight, 1);
   } else if (weight.defined()) {
@@ -249,8 +250,6 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
   } else if (bias.defined()) {
     out = out.add(bias);
   }
-  at::Tensor mean = std::get<1>(outputs);
-  at::Tensor rstd = std::get<2>(outputs);
   std::vector<int64_t> stat_shape;
   for (const auto idx : c10::irange(axis)) {
     stat_shape.push_back(input_shape[idx]);
@@ -260,7 +259,7 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
   }
   mean = mean.view(stat_shape);
   rstd = rstd.view(stat_shape);
-  return std::make_tuple(out, mean, rstd);
+  return outputs;
 }

 Tensor rms_norm_symint(
@@ -110,7 +110,7 @@ void inline col_indices_and_values_resize_(const Tensor& input, int64_t nnz) {
 template <typename scalar_t>
 void mkl_result_copy_(const Tensor& input, sparse_matrix_t mkl_desc) {
   sparse_index_base_t indexing = SPARSE_INDEX_BASE_ZERO;
-  MKL_INT rows, cols;
+  MKL_INT rows = 0, cols = 0;
   MKL_INT *rows_start = nullptr, *rows_end = nullptr, *columns = nullptr;
   scalar_t* values = nullptr;
   at::mkl::sparse::export_csr(
@@ -194,7 +194,7 @@ void addmm_dense_result(
   auto ldb = is_B_row_major ? B_strides[ndim - 2] : B_strides[ndim - 1];
   auto columns_C = mkl_int_cast(C.size(-1), "columns_C");

-  matrix_descr descrA;
+  matrix_descr descrA{};
   descrA.type = SPARSE_MATRIX_TYPE_GENERAL;

   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
@@ -511,7 +511,7 @@ void addmv_out_sparse_csr(
   c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_mkl(vec);

   sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE;
-  matrix_descr descrA;
+  matrix_descr descrA{};
   descrA.type = SPARSE_MATRIX_TYPE_GENERAL;

   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
@@ -651,7 +651,7 @@ void triangular_solve_out_sparse_csr(
   c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_mkl(B, is_X_row_major);

   sparse_operation_t opA = transpose ? SPARSE_OPERATION_TRANSPOSE : SPARSE_OPERATION_NON_TRANSPOSE;
-  matrix_descr descrA;
+  matrix_descr descrA{};
   descrA.type = SPARSE_MATRIX_TYPE_TRIANGULAR;
   descrA.mode = upper ? SPARSE_FILL_MODE_UPPER : SPARSE_FILL_MODE_LOWER;
   descrA.diag = unitriangular ? SPARSE_DIAG_UNIT : SPARSE_DIAG_NON_UNIT;
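Not every change in this diff is about tuples: the `matrix_descr descrA{};` and `MKL_INT rows = 0, cols = 0;` edits are initialization hygiene, making sure no field is read uninitialized if a later call leaves some member unset. A minimal illustration (MklDescr is a stand-in aggregate, not the real MKL type):

// MklDescr stands in for a C aggregate such as MKL's matrix_descr.
struct MklDescr { int type; int mode; int diag; };

void sketch() {
  MklDescr a;    // members are indeterminate until each one is assigned
  MklDescr b{};  // value-initialized: every member starts at zero
  (void)a; (void)b;
}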
@@ -487,13 +487,36 @@ static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
       auto layer_cx = cx[index];
       auto reverse = (direction > 0);
       // bias won't be packed
-      auto outputs = at::mkldnn_rnn_layer(layer_input, layer_weights[0], layer_weights[1],
-          has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options().layout(at::Layout::Strided)),
-          has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options().layout(at::Layout::Strided)), layer_hx,
-          layer_cx, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train);
-      layer_output[direction] = std::move(std::get<0>(outputs));
-      layer_hy[index] = std::move(std::get<1>(outputs));
-      layer_cy[index] = std::move(std::get<2>(outputs));
+      std::tie(
+          layer_output[direction],
+          layer_hy[index],
+          layer_cy[index],
+          std::ignore) =
+          at::mkldnn_rnn_layer(
+              layer_input,
+              layer_weights[0],
+              layer_weights[1],
+              has_biases
+                  ? layer_weights[2]
+                  : at::zeros(
+                        layer_weights[0].sizes(),
+                        layer_weights[0].options().layout(at::Layout::Strided)),
+              has_biases
+                  ? layer_weights[3]
+                  : at::zeros(
+                        layer_weights[1].sizes(),
+                        layer_weights[1].options().layout(at::Layout::Strided)),
+              layer_hx,
+              layer_cx,
+              reverse,
+              batch_sizes,
+              mode,
+              hidden_size,
+              num_layers,
+              has_biases,
+              bidirectional,
+              batch_first,
+              train);
     }
     layer_input = num_directions == 1 ? layer_output[0]
                                       : at::cat(layer_output, /*output_channels*/-1);
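The mkldnn_rnn hunk assigns three of the four outputs of at::mkldnn_rnn_layer directly into pre-existing slots and discards the remaining one with std::ignore, so no intermediate tuple variable survives. The same shape in miniature (run_layer and the containers are illustrative; assumes the vectors are non-empty):

#include <string>
#include <tuple>
#include <vector>

std::tuple<std::string, std::string, std::string, std::string> run_layer() {
  return {"output", "hy", "cy", "workspace"};
}

void sketch(std::vector<std::string>& output, std::vector<std::string>& hy,
            std::vector<std::string>& cy) {
  // Three results are move-assigned into existing slots; the unused fourth
  // element is dropped via std::ignore.
  std::tie(output[0], hy[0], cy[0], std::ignore) = run_layer();
}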
@@ -258,11 +258,7 @@ int64_t get_nnz(const Tensor& nestedtensor) {
         TensorOptions().device(at::kCUDA).dtype(at::kInt));
     Nnz_q = output_batch_size * max_seqlen_batch_q;
   } else {
-    auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len_nnz(q_t);
-    cumulative_sequence_length_q =
-        std::get<0>(cumulative_and_max_q_and_nnz_q);
-    max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q);
-    Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q);
+    std::tie(cumulative_sequence_length_q, max_seqlen_batch_q, Nnz_q) = cumulative_and_max_seq_len_nnz(q_t);
   }

   int64_t max_seqlen_batch_kv = 0, Nnz_kv = 0;
@@ -277,13 +273,10 @@ int64_t get_nnz(const Tensor& nestedtensor) {
         TensorOptions().device(at::kCUDA).dtype(at::kInt));
     Nnz_kv = output_batch_size * max_seqlen_batch_kv;
   } else {
-    auto cumulative_and_max_kv_and_nnz_kv = k_batch_size_needs_broadcast
+    std::tie(cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv) =
+        k_batch_size_needs_broadcast
         ? cumulative_and_max_seq_len_nnz(v_t)
         : cumulative_and_max_seq_len_nnz(k_t);
-    cumulative_sequence_length_kv =
-        std::get<0>(cumulative_and_max_kv_and_nnz_kv);
-    max_seqlen_batch_kv = std::get<1>(cumulative_and_max_kv_and_nnz_kv);
-    Nnz_kv = std::get<2>(cumulative_and_max_kv_and_nnz_kv);
   }

   bool q_num_heads_needs_broadcast = q_num_heads != output_num_heads;
@@ -369,14 +362,14 @@ int64_t get_nnz(const Tensor& nestedtensor) {
   }

   return std::make_tuple(
-      query_buffer_reshaped,
-      key_buffer_reshaped,
-      value_buffer_reshaped,
-      cumulative_sequence_length_q,
-      cumulative_sequence_length_kv,
+      std::move(query_buffer_reshaped),
+      std::move(key_buffer_reshaped),
+      std::move(value_buffer_reshaped),
+      std::move(cumulative_sequence_length_q),
+      std::move(cumulative_sequence_length_kv),
       max_seqlen_batch_q,
       max_seqlen_batch_kv,
-      output_shape);
+      std::move(output_shape));
 }

 } // namespace
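std::make_tuple copies lvalue arguments, so the nested-tensor preprocessing above wraps each movable local in std::move when building its return value; plain integers are left alone. Sketch (illustrative names):

#include <string>
#include <tuple>
#include <utility>

std::tuple<std::string, std::string, int> build() {
  std::string a = "query";
  std::string b = "key";
  int n = 3;
  // a and b are lvalues here, so make_tuple would copy them; std::move lets the
  // tuple steal their buffers. Scalars like n are copied trivially either way.
  return std::make_tuple(std::move(a), std::move(b), n);
}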
@@ -441,21 +441,15 @@ TORCH_API int register_linear_params() {
       [](SerializationType state)
           -> c10::intrusive_ptr<
               LinearPackedParamsBase> { // __setstate__
-        at::Tensor weight;
-        std::optional<at::Tensor> bias;
-        weight = std::move(std::get<0>(state));
-        bias = std::move(std::get<1>(state));
-
 #ifdef USE_FBGEMM
         if (at::globalContext().qEngine() == at::QEngine::FBGEMM ||
             at::globalContext().qEngine() == at::QEngine::X86) {
+          const auto& weight = std::get<0>(state);
           if (weight.scalar_type() == at::kQInt8) {
-            return PackedLinearWeight::prepack(
-                std::move(weight), std::move(bias));
+            return std::apply(PackedLinearWeight::prepack, std::move(state));
           } else if (weight.scalar_type() == at::kFloat) {
             // NB: fp16 weight is serialized as float
-            return PackedLinearWeightFp16::prepack(
-                std::move(weight), std::move(bias));
+            return std::apply(PackedLinearWeightFp16::prepack, std::move(state));
           } else {
             TORCH_CHECK(
                 false,
@@ -467,22 +461,22 @@ TORCH_API int register_linear_params() {
 #endif // USE_FBGEMM
 #ifdef USE_PYTORCH_QNNPACK
         if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
+          const auto& weight = std::get<0>(state);
           TORCH_CHECK(
               weight.scalar_type() == at::kQInt8,
               "QNNPACK only supports INT8 bit width currently. Got ",
               c10::toString(weight.scalar_type()));
-          return PackedLinearWeightsQnnp::prepack(
-              std::move(weight), std::move(bias));
+          return std::apply(PackedLinearWeightsQnnp::prepack, std::move(state));
         }
 #endif // USE_PYTORCH_QNNPACK
 #if AT_MKLDNN_ENABLED()
         if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
+          const auto& weight = std::get<0>(state);
           TORCH_CHECK(
               weight.scalar_type() == at::kQInt8,
               "ONEDNN only supports INT8 bit width currently. Got ",
               c10::toString(weight.scalar_type()));
-          return PackedLinearWeightsOnednn::prepack(
-              std::move(weight), std::move(bias));
+          return std::apply(PackedLinearWeightsOnednn::prepack, std::move(state));
         }
 #endif // #if AT_MKLDNN_ENABLED()
         TORCH_CHECK(false, "Unknown qengine");
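The __setstate__ lambdas above (and the Vulkan one at the end of the diff) forward the whole deserialized tuple to prepack with std::apply; because the tuple is passed as an rvalue, each element is forwarded as an rvalue and can be moved into the callee. A self-contained sketch with made-up types (the real prepack functions take at::Tensor and std::optional<at::Tensor>):

#include <optional>
#include <string>
#include <tuple>
#include <utility>

struct Packed {
  std::string weight;
  std::optional<std::string> bias;
  static Packed prepack(std::string weight, std::optional<std::string> bias) {
    return Packed{std::move(weight), std::move(bias)};
  }
};

Packed restore(std::tuple<std::string, std::optional<std::string>> state) {
  // std::apply expands the tuple into prepack's argument list; since the tuple
  // is an rvalue here, both elements are moved, not copied.
  return std::apply(Packed::prepack, std::move(state));
}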
@@ -366,11 +366,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
   }
 #endif
   // shape: 3 x [B, num_head, T, dim_per_head]
-  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
+  auto [q, k, v] = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
   qkv = Tensor(); // Not used any more, allow free
-  auto& q = std::get<0>(q_k_v);
-  const auto& k = std::get<1>(q_k_v);
-  const auto& v = std::get<2>(q_k_v);
 #ifndef NDEBUG
   debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
   debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
@@ -617,11 +617,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
   }
 #endif
   // shape: 3 x [B, num_head, T, dim_per_head]
-  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
+  auto [q, k, v] = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
   qkv = Tensor(); // Not used any more, allow free
-  auto& q = std::get<0>(q_k_v);
-  const auto& k = std::get<1>(q_k_v);
-  const auto& v = std::get<2>(q_k_v);
 #ifndef NDEBUG
   debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
   debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
@@ -1413,7 +1410,7 @@ __global__ void rand_uniform_kernel(
   const int64_t head_id = blockIdx.y;
   const int64_t query_idx = threadIdx.x;

-  const auto seeds = at::cuda::philox::unpack(rng_engine_inputs);
+  const auto [seed, offset] = at::cuda::philox::unpack(rng_engine_inputs);

   const int dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) +
       head_id * (n_queries * n_keys);
@@ -1421,9 +1418,9 @@ __global__ void rand_uniform_kernel(

   curandStatePhilox4_32_10_t curand_state;
   curand_init(
-      std::get<0>(seeds),
+      seed,
       0,
-      std::get<1>(seeds) + dropout_seq_start + query_start_idx,
+      offset + dropout_seq_start + query_start_idx,
       &curand_state);

   for (int key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) {
@@ -443,9 +443,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }

-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    unsigned long long seed = std::get<0>(seeds);
-    unsigned long long offset = std::get<1>(seeds);
+    const auto [seed, offset] = at::cuda::philox::unpack(params.philox_args);
     pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t,
                                    bidb, bidh, tidx, params.h);

@@ -1341,7 +1341,7 @@ struct AttentionBackwardKernel {

     if (kApplyDropout) {
       // See Note [Seed and Offset Device]
-      auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
+      auto const [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs);
       // each element of the attention matrix P with shape
       // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
       // offset in RNG sequence. we initialize the RNG state with offset that
@@ -1351,9 +1351,9 @@ struct AttentionBackwardKernel {
      // rather than once per iteration. each iteration takes a copy of the
      // initialized RNG state and offsets it as needed.
      curand_init(
-         std::get<0>(seeds),
+         seed,
          0,
-         std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
+         offset + p.dropout_batch_head_rng_offset,
          &rng_state_init);
     }

@@ -671,14 +671,13 @@ struct AttentionKernel {

     curandStatePhilox4_32_10_t curand_state_init;
     if (kSupportsDropout && p.use_dropout) {
-      const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
+      const auto [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs);
       if (p.rng_engine_inputs.captured_) {
         // See Note [Seed and Offset Device]
         // When we are in cuda graph capture mode the seed and offset are stored
         // on device We pass in int64_t* seed, and int64_t* offset to act as
         // scratch space for storing the rng state during the forward pass and
         // saving for backwards.
-        auto [seed, offset] = seeds;
         *p.seed = seed;
         *p.extragraph_offset = offset;
       }
@@ -691,9 +690,9 @@ struct AttentionKernel {
      // rather than once per iteration. each iteration takes a copy of the
      // initialized RNG state and offsets it as needed.
      curand_init(
-         std::get<0>(seeds),
+         seed,
          0,
-         std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
+         offset + p.dropout_batch_head_rng_offset,
          &curand_state_init);
     }

@@ -149,15 +149,7 @@ TORCH_LIBRARY(vulkan, m) {
       },
       // __setstate__
      [](Conv2dOpContext::State state) {
-       return conv2d_clamp_prepack(
-           std::move(std::get<0>(state)),
-           std::move(std::get<1>(state)),
-           std::move(std::get<2>(state)),
-           std::move(std::get<3>(state)),
-           std::move(std::get<4>(state)),
-           std::get<5>(state),
-           std::get<6>(state),
-           std::get<7>(state));
+       return std::apply(conv2d_clamp_prepack, std::move(state));
      });
 }
