[3/N] Avoid copy in std::get (#141843)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141843
Approved by: https://github.com/Skylion007
cyy 2024-12-06 20:13:33 +00:00 committed by PyTorch MergeBot
parent add4a42ea2
commit 1fa27f6e82
22 changed files with 130 additions and 152 deletions
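The pattern this series targets is easiest to see in isolation: `std::get` on a named tuple lvalue returns an lvalue reference, so `auto x = std::get<0>(t);` makes a copy, while an explicit `std::move` or a structured binding avoids it. A minimal, self-contained sketch of the difference (the `Widget` type and its counters are hypothetical, for illustration only):

```cpp
#include <iostream>
#include <tuple>
#include <utility>

// Hypothetical payload that counts copies/moves so the difference is visible.
struct Widget {
  Widget() = default;
  Widget(const Widget&) { ++copies; }
  Widget(Widget&&) noexcept { ++moves; }
  static inline int copies = 0;
  static inline int moves = 0;
};

std::tuple<Widget, Widget> make_pair_of_widgets() { return {}; }

int main() {
  auto t = make_pair_of_widgets();
  auto a = std::get<0>(t);             // copies: std::get on an lvalue yields Widget&
  auto b = std::move(std::get<1>(t));  // moves: cast the element to an rvalue first
  auto [c, d] = make_pair_of_widgets();  // binds straight into the returned tuple, no extra copy
  std::cout << "copies=" << Widget::copies << " moves=" << Widget::moves << '\n';
  (void)a; (void)b; (void)c; (void)d;
}
```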

View File

@ -76,7 +76,7 @@ void BatchedTensorImpl::checkInvariants() const {
} }
} }
// The following are publically exposed as methods of Tensor // The following are publicly exposed as methods of Tensor
IntArrayRef BatchedTensorImpl::strides_custom() const { IntArrayRef BatchedTensorImpl::strides_custom() const {
return strides_default(); return strides_default();
@ -113,7 +113,7 @@ const char* BatchedTensorImpl::tensorimpl_type_name() const {
return "BatchedTensorImpl"; return "BatchedTensorImpl";
} }
Tensor makeBatched(const Tensor& tensor, BatchDims bdims) { Tensor makeBatched(Tensor tensor, BatchDims bdims) {
TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor)); TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
auto tensor_dim = tensor.dim(); auto tensor_dim = tensor.dim();
TORCH_CHECK( TORCH_CHECK(
@ -124,15 +124,15 @@ Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
std::all_of(bdims.begin(), bdims.end(), std::all_of(bdims.begin(), bdims.end(),
[](const BatchDim& bdim) { return bdim.level() < kVmapNumLevels; }), [](const BatchDim& bdim) { return bdim.level() < kVmapNumLevels; }),
"We only support up to ", kVmapNumLevels, " nested vmaps"); "We only support up to ", kVmapNumLevels, " nested vmaps");
return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims)); return at::detail::make_tensor<BatchedTensorImpl>(std::move(tensor), std::move(bdims));
} }
Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) { Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim) {
const auto* batched = maybeGetBatchedImpl(tensor); const auto* batched = maybeGetBatchedImpl(tensor);
if (!batched) { if (!batched) {
BatchDims bdims; BatchDims bdims;
bdims.emplace_back(level, dim); bdims.emplace_back(level, dim);
return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims)); return at::detail::make_tensor<BatchedTensorImpl>(std::move(tensor), std::move(bdims));
} }
BatchDims new_bdims(batched->bdims().begin(), batched->bdims().end()); BatchDims new_bdims(batched->bdims().begin(), batched->bdims().end());
auto actual_bdim = batched->actualDim(dim, /*wrap_dim=*/true); auto actual_bdim = batched->actualDim(dim, /*wrap_dim=*/true);
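The signature change above follows the usual sink-argument idiom: take the `Tensor` by value and `std::move` it into the wrapper, so callers passing an rvalue pay a move (essentially a pointer swap for refcounted handles) instead of a refcount-bumping copy. A hedged sketch of the idiom with a stand-in handle type (`Handle`, `Wrapper`, and `make_wrapped` are illustrative, not PyTorch API):

```cpp
#include <memory>
#include <utility>

// Stand-in for a cheaply movable, refcounted handle like at::Tensor.
struct Handle { std::shared_ptr<int> impl = std::make_shared<int>(0); };

struct Wrapper {
  explicit Wrapper(Handle h) : inner(std::move(h)) {}
  Handle inner;
};

// Sink parameter: accepted by value, then moved into place.
Wrapper make_wrapped(Handle h) { return Wrapper(std::move(h)); }

int main() {
  Handle h;
  auto w1 = make_wrapped(h);             // caller keeps h: one copy, then moves
  auto w2 = make_wrapped(std::move(h));  // caller is done with h: moves all the way through
  (void)w1; (void)w2;
}
```

Compared with a `const Handle&` parameter, the by-value signature lets the rvalue path skip the copy entirely while the lvalue path costs the same.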

View File

@ -148,10 +148,10 @@ inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
} }
// Use this to construct a BatchedTensor from a regular Tensor // Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims); TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims);
// Adds a batch dim to `tensor`, returning a BatchedTensor // Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim); TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim);
// Checks if an inplace operation on self and other is "vmap compatible". // Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this. // See NOTE: [vmap-incompatible in-place operations] for the definition of this.

View File

@ -353,7 +353,7 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
auto t1 = at::zeros({1}); auto t1 = at::zeros({1});
auto t2 = at::zeros({1}); auto t2 = at::zeros({1});
-  std::tuple<at::Tensor&, at::Tensor&> tup = func.call<
+  auto [t1_out, t2_out] = func.call<
      std::tuple<at::Tensor&, at::Tensor&>, at::Scalar, at::Scalar, at::Tensor&, at::Tensor&
  >(dummy, CPU_TEST_SET, s1, s2, t1, t2);
@@ -361,11 +361,9 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
  EXPECT_EQ(t1.item().toFloat(), 1.0f);
  EXPECT_EQ(t2.item().toFloat(), 2.0f);
-  auto t1_out = std::get<0>(tup);
  EXPECT_EQ(t1_out.item().toFloat(), 1.0f);
  EXPECT_TRUE(t1_out.is_same(t1));
-  auto t2_out = std::get<1>(tup);
  EXPECT_EQ(t2_out.item().toFloat(), 2.0f);
  EXPECT_TRUE(t2_out.is_same(t2));
} }
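A structured binding over a tuple of references, as in the rewritten test, still names the original objects: each binding aliases the corresponding reference element, so identity checks keep passing. A small sketch with plain `int&` standing in for `at::Tensor&`:

```cpp
#include <cassert>
#include <tuple>

std::tuple<int&, int&> tie_two(int& a, int& b) { return {a, b}; }

int main() {
  int x = 1, y = 2;
  auto [xr, yr] = tie_two(x, y);  // xr and yr alias x and y; nothing is copied
  xr = 10;
  yr = 20;
  assert(x == 10 && y == 20);
  assert(&xr == &x && &yr == &y);  // identity preserved, analogous to Tensor::is_same
}
```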

View File

@ -218,47 +218,43 @@ static std::tuple<Tensor, std::optional<int64_t>> masked_select_backward_batch_r
static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule( static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule(
const Tensor& grad, std::optional<int64_t> grad_bdim, const Tensor& grad, std::optional<int64_t> grad_bdim,
-    const Tensor& x1, std::optional<int64_t> x1_bdim,
-    const Tensor& x2, std::optional<int64_t> x2_bdim,
+    Tensor x1, std::optional<int64_t> x1_bdim,
+    Tensor x2, std::optional<int64_t> x2_bdim,
    const double p,
    const Tensor& cdist, std::optional<int64_t> cdist_bdim) {
-  auto x1_ = x1;
  if (cdist_bdim && !x1_bdim) {
    // We need to make sure that x1 has batch dim if cdist has one
    // otherwise, we get
    // RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
    // but expected shape compatible with [4, 5]
    auto bs = cdist.size(*cdist_bdim);
-    x1_ = ensure_has_bdim(x1, false, bs);
-    x1_ = x1_.contiguous();
+    x1 = ensure_has_bdim(x1, false, bs).contiguous();
    x1_bdim = 0;
  }
  // We need to apply the same preprocessing on x1 and x2 as in the forward pass
  // _binary_pointwise_batch_rule
-  auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
-  x1_ = std::move(std::get<0>(x12));
-  auto& x2_ = std::get<1>(x12);
+  std::tie(x1, x2)= _binary_pointwise_helper(x1, x1_bdim, x2, x2_bdim);
auto grad_ = moveBatchDimToFront(grad, grad_bdim); auto grad_ = moveBatchDimToFront(grad, grad_bdim);
if ((x1_bdim || x2_bdim) && !grad_bdim) { if ((x1_bdim || x2_bdim) && !grad_bdim) {
// We need to make sure that grad has batch dim if x1 or x2 have one // We need to make sure that grad has batch dim if x1 or x2 have one
// Probably, there is an assumption on the strides. // Probably, there is an assumption on the strides.
// Otherwise grad input contains thrash values, e.g. -7.0816e+29, 7.0816e+29 // Otherwise grad input contains thrash values, e.g. -7.0816e+29, 7.0816e+29
auto bs = get_bdim_size2(x1_, 0, x2_, 0); auto bs = get_bdim_size2(x1, 0, x2, 0);
grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs); grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
grad_ = grad_.contiguous(); grad_ = grad_.contiguous();
} }
auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist); auto out = at::_cdist_backward(grad_, x1, x2, p, cdist);
std::optional<int64_t> out_bdim = std::nullopt; std::optional<int64_t> out_bdim = std::nullopt;
if (x1_bdim || x2_bdim) { if (x1_bdim || x2_bdim) {
out_bdim = 0; out_bdim = 0;
} }
return std::make_tuple(out, out_bdim); return std::make_tuple(std::move(out), out_bdim);
} }
static void fill__Tensor_batch_rule( static void fill__Tensor_batch_rule(
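Unpacking a returned tuple into variables that already exist is what `std::tie` is for; combined with moving the inputs, it lets the batch rule reuse `x1`/`x2` without the temporaries the old `std::get<0>/<1>` copies created. A hedged sketch of the pattern (the `normalize` helper is hypothetical):

```cpp
#include <string>
#include <tuple>
#include <utility>

// Hypothetical helper that returns adjusted versions of both inputs.
std::tuple<std::string, std::string> normalize(std::string a, std::string b) {
  a += "!";
  b += "?";
  return {std::move(a), std::move(b)};
}

int main() {
  std::string x1 = "left", x2 = "right";
  // Assign both results back into the existing variables in one statement;
  // std::ignore can stand in for elements the caller does not need.
  std::tie(x1, x2) = normalize(std::move(x1), std::move(x2));
}
```

Structured bindings cannot assign into pre-existing variables, so `std::tie` remains the right tool whenever the targets are function parameters or earlier locals.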

View File

@ -42,6 +42,7 @@ static Tensor padRight(const Tensor& tensor, std::optional<int64_t> has_bdim, in
} }
template<typename F, F Func> template<typename F, F Func>
static
std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>> std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
batch_norm_batch_rule( batch_norm_batch_rule(
const Tensor& input, std::optional<int64_t> input_bdim, const Tensor& input, std::optional<int64_t> input_bdim,
@ -70,10 +71,10 @@ batch_norm_batch_rule(
if (!input_bdim && !running_mean_bdim && !running_var_bdim) { if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
const auto dummy_weight = at::ones(input.size(1), input.options()); // cudnn and miopen require a weight const auto dummy_weight = at::ones(input.size(1), input.options()); // cudnn and miopen require a weight
const auto dummy_bias = at::zeros(input.size(1), input.options()); // without this, get "strides() called on undefined Tensor" on cuda const auto dummy_bias = at::zeros(input.size(1), input.options()); // without this, get "strides() called on undefined Tensor" on cuda
const auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps); auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
result0 = std::get<0>(result).transpose(0, 1); // [C, B, *] result0 = std::get<0>(result).transpose(0, 1); // [C, B, *]
mean = std::get<1>(result); mean = std::move(std::get<1>(result));
rstd = std::get<2>(result); rstd = std::move(std::get<2>(result));
} else { } else {
bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim); bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
auto input_ = moveBatchDimToFront(input, input_bdim); auto input_ = moveBatchDimToFront(input, input_bdim);
@ -95,12 +96,12 @@ batch_norm_batch_rule(
const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
const auto dummy_bias = at::zeros(input_.size(1), input_.options()); // without this, get "strides() called on undefined Tensor" on cuda const auto dummy_bias = at::zeros(input_.size(1), input_.options()); // without this, get "strides() called on undefined Tensor" on cuda
-    const auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
+    auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
    result0 = std::get<0>(result).transpose(0, 1); // [(B0, C), B, *]
+    mean = std::move(std::get<1>(result));
+    rstd = std::move(std::get<2>(result));
    result0 = reshape_dim_outof(0, bdim_size.value(), result0); // [B0, C, B, *]
-    mean = std::get<1>(result);
    mean = reshape_dim_outof(0, bdim_size.value(), mean); // [B0, C]
-    rstd = std::get<2>(result);
    rstd = reshape_dim_outof(0, bdim_size.value(), rstd); // [B0, C]
} }
@ -124,6 +125,7 @@ batch_norm_batch_rule(
} }
template<typename F, F Func> template<typename F, F Func>
static
std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule( std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim, const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
const at::Tensor & input, std::optional<int64_t> input_bdim, const at::Tensor & input, std::optional<int64_t> input_bdim,
@ -142,9 +144,9 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
TORCH_INTERNAL_ASSERT(!mean_bdim); TORCH_INTERNAL_ASSERT(!mean_bdim);
TORCH_INTERNAL_ASSERT(!rstd_bdim); TORCH_INTERNAL_ASSERT(!rstd_bdim);
const auto dummy_weight = at::ones(input.size(1), input.options()); const auto dummy_weight = at::ones(input.size(1), input.options());
const auto result = Func( auto result =Func(
grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false}); grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false});
return std::make_tuple(std::get<0>(result), std::nullopt); return {std::move(std::get<0>(result)), std::nullopt};
} }
auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim); auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
@ -196,6 +198,7 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
} }
template<typename F, F Func> template<typename F, F Func>
static
std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing( std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
const at::Tensor & grad_out, const at::Tensor & grad_out,
const at::Tensor & input, const at::Tensor & input,
@ -270,7 +273,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *] unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *]
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>( auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
grad_normalized_input_value, grad_normalized_input_bdim, grad_normalized_input_value, grad_normalized_input_bdim,
input_value, input_bdim, input_value, input_bdim,
running_mean_value, running_mean_bdim, running_mean_value, running_mean_bdim,
@ -278,7 +281,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
save_mean_value, save_mean_bdim, save_mean_value, save_mean_bdim,
save_rstd_value, save_rstd_bdim, save_rstd_value, save_rstd_bdim,
training, eps); training, eps);
grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level); grad_input = makeBatched(std::move(std::get<0>(results)), std::get<1>(results), cur_level);
} }
return std::make_tuple(grad_input, grad_weight, grad_bias); return std::make_tuple(grad_input, grad_weight, grad_bias);
} }
@ -312,16 +315,13 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
const auto bdim_size = input_value.size(*input_bdim); const auto bdim_size = input_value.size(*input_bdim);
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
-    result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level);
-    mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level);
-    rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level);
+    std::tie(result0, mean, rstd) = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
+    result0 = makeBatched(reshape_dim_outof(0, bdim_size, result0), 0, cur_level);
+    mean = makeBatched(reshape_dim_outof(0, bdim_size, mean), 0, cur_level);
+    rstd = makeBatched(reshape_dim_outof(0, bdim_size, rstd), 0, cur_level);
  } else {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
-    result0 = std::get<0>(result);
-    mean = std::get<1>(result);
-    rstd = std::get<2>(result);
+    std::tie(result0, mean, rstd) = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
} }
if (weight.defined()) { if (weight.defined()) {
@ -334,10 +334,10 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
result0 = result0 + padded_bias; result0 = result0 + padded_bias;
} }
return std::make_tuple(result0, mean, rstd); return std::make_tuple(std::move(result0), std::move(mean), std::move(rstd));
} }
static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule( static at::Tensor group_norm_backward_no_weight_bias_batch_rule(
const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim, const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
const at::Tensor & input, std::optional<int64_t> input_bdim, const at::Tensor & input, std::optional<int64_t> input_bdim,
const at::Tensor & mean, std::optional<int64_t> mean_bdim, const at::Tensor & mean, std::optional<int64_t> mean_bdim,
@ -359,15 +359,13 @@ static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_wei
mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G] mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G] rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]
-  const auto result = native_group_norm_backward(
+  auto result0 = std::get<0>(native_group_norm_backward(
    grad_out_.contiguous(),
    input_.contiguous(),
    mean_.contiguous(),
    rstd_.contiguous(),
-    std::nullopt, N * bdim_size, C, HxW, group, {true, false, false});
-  auto result0 = std::get<0>(result);
-  result0 = reshape_dim_outof(0, bdim_size, result0);
-  return std::make_tuple(result0, 0);
+    std::nullopt, N * bdim_size, C, HxW, group, {true, false, false}));
+  return reshape_dim_outof(0, bdim_size, result0);
} }
static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing( static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
@ -422,19 +420,19 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
unwrapTensorAtLevel(grad_normalized_input, cur_level); unwrapTensorAtLevel(grad_normalized_input, cur_level);
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto res = group_norm_backward_no_weight_bias_batch_rule( auto tensor = group_norm_backward_no_weight_bias_batch_rule(
grad_normalized_input_value, grad_normalized_input_bdim, grad_normalized_input_value, grad_normalized_input_bdim,
input_value, input_bdim, input_value, input_bdim,
mean_value, mean_bdim, mean_value, mean_bdim,
rstd_value, rstd_bdim, rstd_value, rstd_bdim,
N, C, HxW, group N, C, HxW, group
); );
grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level); grad_input = makeBatched(std::move(tensor), 0, cur_level);
} }
return std::make_tuple(grad_input, grad_weight, grad_bias); return std::make_tuple(grad_input, grad_weight, grad_bias);
} }
C10_ALWAYS_INLINE bool has_same_shape( static bool has_same_shape(
const Tensor& tensor, std::optional<int64_t> tensor_bdim, const Tensor& tensor, std::optional<int64_t> tensor_bdim,
c10::SymIntArrayRef normalized_shape) { c10::SymIntArrayRef normalized_shape) {
if (!tensor.defined()) { if (!tensor.defined()) {
@ -457,7 +455,7 @@ C10_ALWAYS_INLINE bool has_same_shape(
return true; return true;
} }
C10_ALWAYS_INLINE void check_same_shape( static C10_ALWAYS_INLINE void check_same_shape(
const Tensor& tensor, std::optional<int64_t> tensor_bdim, const Tensor& tensor, std::optional<int64_t> tensor_bdim,
c10::SymIntArrayRef normalized_shape, const std::string& name) { c10::SymIntArrayRef normalized_shape, const std::string& name) {
TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape), TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape),
@ -469,7 +467,7 @@ C10_ALWAYS_INLINE void check_same_shape(
} }
// Ugh, hard to deduplicate // Ugh, hard to deduplicate
C10_ALWAYS_INLINE void _check_layer_norm_inputs( static C10_ALWAYS_INLINE void _check_layer_norm_inputs(
SymIntArrayRef normalized_shape, SymIntArrayRef normalized_shape,
const Tensor& weight, std::optional<int64_t> weight_bdim, const Tensor& weight, std::optional<int64_t> weight_bdim,
const Tensor& bias, std::optional<int64_t> bias_bdim) { const Tensor& bias, std::optional<int64_t> bias_bdim) {
@ -493,11 +491,9 @@ native_layer_norm_batch_rule(
double eps) { double eps) {
auto input_ = moveBatchDimToFront(input, input_bdim); auto input_ = moveBatchDimToFront(input, input_bdim);
if (!weight_bdim && !bias_bdim) { if (!weight_bdim && !bias_bdim) {
-    const auto result = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
-    const auto mean = std::get<1>(result);
-    const auto rstd = std::get<2>(result);
+    auto [result0, mean, rstd] = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
    const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
-    return std::make_tuple(std::get<0>(result), 0, mean, stats_bdim, rstd, stats_bdim);
+    return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
} }
// See [Note: hacky wrapper removal for optional tensor] // See [Note: hacky wrapper removal for optional tensor]
@ -509,9 +505,7 @@ native_layer_norm_batch_rule(
const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim); const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
  const auto result = at::native_layer_norm_symint(input_, normalized_shape, std::nullopt, std::nullopt, eps);
-  auto result0 = std::get<0>(result);
-  const auto mean = std::get<1>(result);
-  const auto rstd = std::get<2>(result);
+  auto [result0, mean, rstd] = result;
const auto stats_bdim = compute_stat_bdim(input_bdim, mean); const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
if (weight.defined()) { if (weight.defined()) {
@ -638,7 +632,7 @@ static std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_p
unwrapTensorAtLevel(grad_normalized_input, cur_level); unwrapTensorAtLevel(grad_normalized_input, cur_level);
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto results = native_layer_norm_backward_no_weight_bias_batch_rule( auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
grad_normalized_input_value, grad_normalized_input_bdim, grad_normalized_input_value, grad_normalized_input_bdim,
input_value, input_bdim, input_value, input_bdim,
normalized_shape, normalized_shape,
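When only some elements of a local tuple are needed, `std::move(std::get<N>(result))` steals each element in place instead of copying it, which is what the batch-norm rules above switch to; it also requires that `result` itself not be `const`, hence the dropped `const auto`. A minimal sketch with a move-only element to show why the `const` has to go (types here are illustrative):

```cpp
#include <memory>
#include <string>
#include <tuple>
#include <utility>

std::tuple<std::string, std::unique_ptr<int>> produce() {
  return {"stats", std::make_unique<int>(42)};
}

int main() {
  auto result = produce();                     // not const: we intend to move out of it
  auto name = std::move(std::get<0>(result));  // moves the string out
  auto value = std::move(std::get<1>(result)); // a unique_ptr could not be copied at all
  // With `const auto result = produce();` the first line above would copy,
  // and the unique_ptr line would not compile.
  (void)name; (void)value;
}
```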

View File

@ -171,18 +171,18 @@ void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>&
TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed."); TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
} }
Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) { Tensor makeBatched(Tensor tensor, int64_t bdim, int64_t level) {
DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor); DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor);
auto* batched = maybeGetBatchedImpl(tensor); auto* batched = maybeGetBatchedImpl(tensor);
if (batched) { if (batched) {
auto batched_level = batched->level(); auto batched_level = batched->level();
TORCH_INTERNAL_ASSERT(level > batched_level, " batched_level: ", batched_level, " level: ", level); TORCH_INTERNAL_ASSERT(level > batched_level, " batched_level: ", batched_level, " level: ", level);
} }
return at::detail::make_tensor<BatchedTensorImpl>(key_set, tensor, bdim, level); return at::detail::make_tensor<BatchedTensorImpl>(key_set, std::move(tensor), bdim, level);
} }
Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level) { Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level) {
return makeBatched(tensor, dim, level); return makeBatched(std::move(tensor), dim, level);
} }
} // namespace at::functorch } // namespace at::functorch

View File

@ -144,10 +144,10 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
} }
// Use this to construct a BatchedTensor from a regular Tensor // Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level); TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level);
// Adds a batch dim to `tensor`, returning a BatchedTensor // Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level); TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level);
// Certain dispatch keys must be propagated to the BatchedTensor (or, in general, // Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
// any wrapper Tensor subclasses). This is because there are methods on Tensor // any wrapper Tensor subclasses). This is because there are methods on Tensor

View File

@ -22,20 +22,20 @@ void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* wh
) )
} }
Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim, int64_t level) { Tensor makeBatched(Tensor tensor, std::optional<int64_t> bdim, int64_t level) {
if (bdim.has_value()) { if (bdim.has_value()) {
TORCH_INTERNAL_ASSERT(*bdim >= 0); TORCH_INTERNAL_ASSERT(*bdim >= 0);
TORCH_INTERNAL_ASSERT(*bdim < tensor.dim()); TORCH_INTERNAL_ASSERT(*bdim < tensor.dim());
return makeBatched(tensor, bdim.value(), level); return makeBatched(std::move(tensor), bdim.value(), level);
} }
return tensor; return tensor;
} }
std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::optional<int64_t> bdim, int64_t level) { std::vector<Tensor> makeBatchedVector(std::vector<Tensor> tensors, std::optional<int64_t> bdim, int64_t level) {
std::vector<Tensor> res; std::vector<Tensor> res;
res.reserve(tensors.size()); res.reserve(tensors.size());
for (const auto & tensor : tensors) { for (auto & tensor : tensors) {
res.emplace_back(makeBatched(tensor, bdim, level)); res.emplace_back(makeBatched(std::move(tensor), bdim, level));
} }
return res; return res;
} }
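`makeBatchedVector` now takes its vector by value and moves each element out while iterating, so callers that pass a temporary hand their tensors over without per-element copies. A hedged sketch of the shape of that loop, using `std::string` as the element type:

```cpp
#include <string>
#include <utility>
#include <vector>

// By-value parameter: callers choose between copying (lvalue) and moving (rvalue).
std::vector<std::string> decorate_all(std::vector<std::string> items) {
  std::vector<std::string> out;
  out.reserve(items.size());
  for (auto& item : items) {
    // `item` lives in our own copy/moved-in vector, so it is safe to move from.
    out.emplace_back(std::move(item) + " [batched]");
  }
  return out;
}

int main() {
  std::vector<std::string> v{"a", "b", "c"};
  auto out = decorate_all(std::move(v));  // no element-wise copies on this path
  (void)out;
}
```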

View File

@ -29,7 +29,7 @@ namespace at::functorch {
void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* what); void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* what);
// Create a BatchedTensor given a tensor, bdim, and level // Create a BatchedTensor given a tensor, bdim, and level
TORCH_API Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim, int64_t level); TORCH_API Tensor makeBatched(Tensor tensor, std::optional<int64_t> bdim, int64_t level);
// Given a Tensor that may or may not be a BatchedTensor, unwrap it. // Given a Tensor that may or may not be a BatchedTensor, unwrap it.
// If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level // If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level
@ -38,7 +38,7 @@ TORCH_API Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim,
TORCH_API std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level); TORCH_API std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);
// Creates a vector of BatchedTensor // Creates a vector of BatchedTensor
TORCH_API std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::optional<int64_t> bdim, int64_t level); TORCH_API std::vector<Tensor> makeBatchedVector(std::vector<Tensor> tensors, std::optional<int64_t> bdim, int64_t level);
// Returns True if ANY tensor in tensors is batched at level // Returns True if ANY tensor in tensors is batched at level
TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level); TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level);

View File

@ -55,9 +55,7 @@ void unpack_bcsr(
memset(dst + i * C, zero_points[i], C * sizeof(int8_t)); memset(dst + i * C, zero_points[i], C * sizeof(int8_t));
} }
} }
-  const std::vector<int8_t>& weight_values = std::get<0>(bcsr);
-  const std::vector<int32_t>& row_indices = std::get<1>(bcsr);
-  const std::vector<int32_t>& col_indices = std::get<2>(bcsr);
+  const auto& [weight_values, row_indices, col_indices] = bcsr;
int64_t rowBlocks = (R + RB - 1) / RB; int64_t rowBlocks = (R + RB - 1) / RB;
for (int64_t i = 0; i < rowBlocks; ++i) { for (int64_t i = 0; i < rowBlocks; ++i) {
// For the current tile, rowBPtr starts from currentTileIdx // For the current tile, rowBPtr starts from currentTileIdx
@ -316,4 +314,4 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp(
} }
#endif // USE_PYTORCH_QNNPACK #endif // USE_PYTORCH_QNNPACK
} // namespace ao } // namespace ao::sparse

View File

@ -25,9 +25,9 @@ __global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T
// do random permutation inside each island. // do random permutation inside each island.
data += tid; data += tid;
auto seeds = at::cuda::philox::unpack(philox_args); auto [seed, offset] = at::cuda::philox::unpack(philox_args);
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state); curand_init(seed, tid, offset, &state);
for (int i = island_size - 1; i > 0; i--) { for (int i = island_size - 1; i > 0; i--) {
unsigned int r = curand(&state) % (i + 1); unsigned int r = curand(&state) % (i + 1);
if (i != r) { if (i != r) {

View File

@ -241,7 +241,8 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
auto outputs = at::native_batch_norm( auto outputs = at::native_batch_norm(
input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{}, input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
/*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps); /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
-  auto out = std::get<0>(outputs).view(input_shape);
+  auto& [out, mean, rstd] = outputs;
+  out = out.view(input_shape);
if (weight.defined() && bias.defined()) { if (weight.defined() && bias.defined()) {
out = bias.addcmul(out, weight, 1); out = bias.addcmul(out, weight, 1);
} else if (weight.defined()) { } else if (weight.defined()) {
@ -249,8 +250,6 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
} else if (bias.defined()) { } else if (bias.defined()) {
out = out.add(bias); out = out.add(bias);
} }
-  at::Tensor mean = std::get<1>(outputs);
-  at::Tensor rstd = std::get<2>(outputs);
std::vector<int64_t> stat_shape; std::vector<int64_t> stat_shape;
for (const auto idx : c10::irange(axis)) { for (const auto idx : c10::irange(axis)) {
stat_shape.push_back(input_shape[idx]); stat_shape.push_back(input_shape[idx]);
@ -260,7 +259,7 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
} }
mean = mean.view(stat_shape); mean = mean.view(stat_shape);
rstd = rstd.view(stat_shape); rstd = rstd.view(stat_shape);
return std::make_tuple(out, mean, rstd); return outputs;
} }
Tensor rms_norm_symint( Tensor rms_norm_symint(
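One detail worth noting in the layer-norm change: binding with `auto& [out, mean, rstd] = outputs;` creates references into the existing tuple rather than copies, and later writes such as `out = out.view(...)` update the tuple element itself, which is why the function can simply `return outputs;` at the end. A small sketch of that by-reference binding (illustrative types):

```cpp
#include <string>
#include <tuple>

std::tuple<std::string, int, int> analyze() { return {"raw", 1, 2}; }

int main() {
  auto outputs = analyze();
  auto& [label, lo, hi] = outputs;  // references into `outputs`, no copies
  label += " (scaled)";             // mutates std::get<0>(outputs) directly
  lo *= 10;
  hi *= 10;
  // `outputs` now holds {"raw (scaled)", 10, 20} and can be returned as-is.
  return std::get<1>(outputs) == 10 ? 0 : 1;
}
```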

View File

@ -110,7 +110,7 @@ void inline col_indices_and_values_resize_(const Tensor& input, int64_t nnz) {
template <typename scalar_t> template <typename scalar_t>
void mkl_result_copy_(const Tensor& input, sparse_matrix_t mkl_desc) { void mkl_result_copy_(const Tensor& input, sparse_matrix_t mkl_desc) {
sparse_index_base_t indexing = SPARSE_INDEX_BASE_ZERO; sparse_index_base_t indexing = SPARSE_INDEX_BASE_ZERO;
MKL_INT rows, cols; MKL_INT rows = 0, cols = 0;
MKL_INT *rows_start = nullptr, *rows_end = nullptr, *columns = nullptr; MKL_INT *rows_start = nullptr, *rows_end = nullptr, *columns = nullptr;
scalar_t* values = nullptr; scalar_t* values = nullptr;
at::mkl::sparse::export_csr( at::mkl::sparse::export_csr(
@ -194,7 +194,7 @@ void addmm_dense_result(
auto ldb = is_B_row_major ? B_strides[ndim - 2] : B_strides[ndim - 1]; auto ldb = is_B_row_major ? B_strides[ndim - 2] : B_strides[ndim - 1];
auto columns_C = mkl_int_cast(C.size(-1), "columns_C"); auto columns_C = mkl_int_cast(C.size(-1), "columns_C");
matrix_descr descrA; matrix_descr descrA{};
descrA.type = SPARSE_MATRIX_TYPE_GENERAL; descrA.type = SPARSE_MATRIX_TYPE_GENERAL;
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
@ -511,7 +511,7 @@ void addmv_out_sparse_csr(
c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_mkl(vec); c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_mkl(vec);
sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE; sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE;
matrix_descr descrA; matrix_descr descrA{};
descrA.type = SPARSE_MATRIX_TYPE_GENERAL; descrA.type = SPARSE_MATRIX_TYPE_GENERAL;
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
@ -651,7 +651,7 @@ void triangular_solve_out_sparse_csr(
c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_mkl(B, is_X_row_major); c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_mkl(B, is_X_row_major);
sparse_operation_t opA = transpose ? SPARSE_OPERATION_TRANSPOSE : SPARSE_OPERATION_NON_TRANSPOSE; sparse_operation_t opA = transpose ? SPARSE_OPERATION_TRANSPOSE : SPARSE_OPERATION_NON_TRANSPOSE;
matrix_descr descrA; matrix_descr descrA{};
descrA.type = SPARSE_MATRIX_TYPE_TRIANGULAR; descrA.type = SPARSE_MATRIX_TYPE_TRIANGULAR;
descrA.mode = upper ? SPARSE_FILL_MODE_UPPER : SPARSE_FILL_MODE_LOWER; descrA.mode = upper ? SPARSE_FILL_MODE_UPPER : SPARSE_FILL_MODE_LOWER;
descrA.diag = unitriangular ? SPARSE_DIAG_UNIT : SPARSE_DIAG_NON_UNIT; descrA.diag = unitriangular ? SPARSE_DIAG_UNIT : SPARSE_DIAG_NON_UNIT;

View File

@ -487,13 +487,36 @@ static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
auto layer_cx = cx[index]; auto layer_cx = cx[index];
auto reverse = (direction > 0); auto reverse = (direction > 0);
// bias won't be packed // bias won't be packed
-  auto outputs = at::mkldnn_rnn_layer(layer_input, layer_weights[0], layer_weights[1],
-      has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options().layout(at::Layout::Strided)),
-      has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options().layout(at::Layout::Strided)), layer_hx,
-      layer_cx, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train);
-  layer_output[direction] = std::move(std::get<0>(outputs));
-  layer_hy[index] = std::move(std::get<1>(outputs));
-  layer_cy[index] = std::move(std::get<2>(outputs));
+  std::tie(
+      layer_output[direction],
+      layer_hy[index],
+      layer_cy[index],
+      std::ignore) =
+      at::mkldnn_rnn_layer(
+          layer_input,
+          layer_weights[0],
+          layer_weights[1],
+          has_biases
+              ? layer_weights[2]
+              : at::zeros(
+                    layer_weights[0].sizes(),
+                    layer_weights[0].options().layout(at::Layout::Strided)),
+          has_biases
+              ? layer_weights[3]
+              : at::zeros(
+                    layer_weights[1].sizes(),
+                    layer_weights[1].options().layout(at::Layout::Strided)),
+          layer_hx,
+          layer_cx,
+          reverse,
+          batch_sizes,
+          mode,
+          hidden_size,
+          num_layers,
+          has_biases,
+          bidirectional,
+          batch_first,
+          train);
} }
layer_input = num_directions == 1 ? layer_output[0] layer_input = num_directions == 1 ? layer_output[0]
: at::cat(layer_output, /*output_channels*/-1); : at::cat(layer_output, /*output_channels*/-1);

View File

@ -258,11 +258,7 @@ int64_t get_nnz(const Tensor& nestedtensor) {
TensorOptions().device(at::kCUDA).dtype(at::kInt)); TensorOptions().device(at::kCUDA).dtype(at::kInt));
Nnz_q = output_batch_size * max_seqlen_batch_q; Nnz_q = output_batch_size * max_seqlen_batch_q;
} else { } else {
-    auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len_nnz(q_t);
-    cumulative_sequence_length_q =
-        std::get<0>(cumulative_and_max_q_and_nnz_q);
-    max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q);
-    Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q);
+    std::tie(cumulative_sequence_length_q, max_seqlen_batch_q, Nnz_q) = cumulative_and_max_seq_len_nnz(q_t);
} }
int64_t max_seqlen_batch_kv = 0, Nnz_kv = 0; int64_t max_seqlen_batch_kv = 0, Nnz_kv = 0;
@ -277,13 +273,10 @@ int64_t get_nnz(const Tensor& nestedtensor) {
TensorOptions().device(at::kCUDA).dtype(at::kInt)); TensorOptions().device(at::kCUDA).dtype(at::kInt));
Nnz_kv = output_batch_size * max_seqlen_batch_kv; Nnz_kv = output_batch_size * max_seqlen_batch_kv;
} else { } else {
-    auto cumulative_and_max_kv_and_nnz_kv = k_batch_size_needs_broadcast
+    std::tie(cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv) =
+        k_batch_size_needs_broadcast
        ? cumulative_and_max_seq_len_nnz(v_t)
        : cumulative_and_max_seq_len_nnz(k_t);
-    cumulative_sequence_length_kv =
-        std::get<0>(cumulative_and_max_kv_and_nnz_kv);
-    max_seqlen_batch_kv = std::get<1>(cumulative_and_max_kv_and_nnz_kv);
-    Nnz_kv = std::get<2>(cumulative_and_max_kv_and_nnz_kv);
} }
bool q_num_heads_needs_broadcast = q_num_heads != output_num_heads; bool q_num_heads_needs_broadcast = q_num_heads != output_num_heads;
@ -369,14 +362,14 @@ int64_t get_nnz(const Tensor& nestedtensor) {
} }
return std::make_tuple( return std::make_tuple(
query_buffer_reshaped, std::move(query_buffer_reshaped),
key_buffer_reshaped, std::move(key_buffer_reshaped),
value_buffer_reshaped, std::move(value_buffer_reshaped),
cumulative_sequence_length_q, std::move(cumulative_sequence_length_q),
cumulative_sequence_length_kv, std::move(cumulative_sequence_length_kv),
max_seqlen_batch_q, max_seqlen_batch_q,
max_seqlen_batch_kv, max_seqlen_batch_kv,
output_shape); std::move(output_shape));
} }
} // namespace } // namespace

View File

@ -441,21 +441,15 @@ TORCH_API int register_linear_params() {
[](SerializationType state) [](SerializationType state)
-> c10::intrusive_ptr< -> c10::intrusive_ptr<
LinearPackedParamsBase> { // __setstate__ LinearPackedParamsBase> { // __setstate__
-              at::Tensor weight;
-              std::optional<at::Tensor> bias;
-              weight = std::move(std::get<0>(state));
-              bias = std::move(std::get<1>(state));
#ifdef USE_FBGEMM
              if (at::globalContext().qEngine() == at::QEngine::FBGEMM ||
                  at::globalContext().qEngine() == at::QEngine::X86) {
+                const auto& weight = std::get<0>(state);
                if (weight.scalar_type() == at::kQInt8) {
-                  return PackedLinearWeight::prepack(
-                      std::move(weight), std::move(bias));
+                  return std::apply(PackedLinearWeight::prepack, std::move(state));
                } else if (weight.scalar_type() == at::kFloat) {
                  // NB: fp16 weight is serialized as float
-                  return PackedLinearWeightFp16::prepack(
-                      std::move(weight), std::move(bias));
+                  return std::apply(PackedLinearWeightFp16::prepack, std::move(state));
} else { } else {
TORCH_CHECK( TORCH_CHECK(
false, false,
@ -467,22 +461,22 @@ TORCH_API int register_linear_params() {
#endif // USE_FBGEMM #endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK #ifdef USE_PYTORCH_QNNPACK
if (at::globalContext().qEngine() == at::QEngine::QNNPACK) { if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
+                const auto& weight = std::get<0>(state);
                TORCH_CHECK(
                    weight.scalar_type() == at::kQInt8,
                    "QNNPACK only supports INT8 bit width currently. Got ",
                    c10::toString(weight.scalar_type()));
-                return PackedLinearWeightsQnnp::prepack(
-                    std::move(weight), std::move(bias));
+                return std::apply(PackedLinearWeightsQnnp::prepack, std::move(state));
              }
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
              if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
+                const auto& weight = std::get<0>(state);
                TORCH_CHECK(
                    weight.scalar_type() == at::kQInt8,
                    "ONEDNN only supports INT8 bit width currently. Got ",
                    c10::toString(weight.scalar_type()));
-                return PackedLinearWeightsOnednn::prepack(
-                    std::move(weight), std::move(bias));
+                return std::apply(PackedLinearWeightsOnednn::prepack, std::move(state));
} }
#endif // #if AT_MKLDNN_ENABLED() #endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(false, "Unknown qengine"); TORCH_CHECK(false, "Unknown qengine");
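The `__setstate__` rewrites above lean on `std::apply`, which unpacks a tuple into a callable's argument list; with `std::move(state)` the elements are forwarded as rvalues, so the weight/bias pair reaches `prepack` without intermediate named copies. A hedged sketch of that forwarding (the free `prepack` function here is a stand-in, not the quantized-linear API):

```cpp
#include <iostream>
#include <optional>
#include <string>
#include <tuple>
#include <utility>

// Stand-in for a prepack-style factory taking the tuple's element types.
std::string prepack(std::string weight, std::optional<std::string> bias) {
  return weight + (bias ? "+" + *bias : "");
}

int main() {
  std::tuple<std::string, std::optional<std::string>> state{"W", "b"};
  // Elements of `state` are moved straight into prepack's parameters.
  auto packed = std::apply(prepack, std::move(state));
  std::cout << packed << '\n';  // prints "W+b"
}
```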

View File

@ -366,11 +366,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
} }
#endif #endif
// shape: 3 x [B, num_head, T, dim_per_head] // shape: 3 x [B, num_head, T, dim_per_head]
-  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
+  auto [q, k, v] = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
-  auto& q = std::get<0>(q_k_v);
-  const auto& k = std::get<1>(q_k_v);
-  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG #ifndef NDEBUG
debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head}); debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head}); debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});

View File

@ -617,11 +617,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
} }
#endif #endif
// shape: 3 x [B, num_head, T, dim_per_head] // shape: 3 x [B, num_head, T, dim_per_head]
-  auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
+  auto [q, k, v] = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
  qkv = Tensor(); // Not used any more, allow free
-  auto& q = std::get<0>(q_k_v);
-  const auto& k = std::get<1>(q_k_v);
-  const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG #ifndef NDEBUG
debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head}); debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head}); debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
@ -1413,7 +1410,7 @@ __global__ void rand_uniform_kernel(
const int64_t head_id = blockIdx.y; const int64_t head_id = blockIdx.y;
const int64_t query_idx = threadIdx.x; const int64_t query_idx = threadIdx.x;
const auto seeds = at::cuda::philox::unpack(rng_engine_inputs); const auto [seed, offset] = at::cuda::philox::unpack(rng_engine_inputs);
const int dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) + const int dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) +
head_id * (n_queries * n_keys); head_id * (n_queries * n_keys);
@ -1421,9 +1418,9 @@ __global__ void rand_uniform_kernel(
curandStatePhilox4_32_10_t curand_state; curandStatePhilox4_32_10_t curand_state;
curand_init( curand_init(
std::get<0>(seeds), seed,
0, 0,
std::get<1>(seeds) + dropout_seq_start + query_start_idx, offset + dropout_seq_start + query_start_idx,
&curand_state); &curand_state);
for (int key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) { for (int key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) {

View File

@ -443,9 +443,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view); cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
} }
-  auto seeds = at::cuda::philox::unpack(params.philox_args);
-  unsigned long long seed = std::get<0>(seeds);
-  unsigned long long offset = std::get<1>(seeds);
+  const auto [seed, offset] = at::cuda::philox::unpack(params.philox_args);
pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t, pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h); bidb, bidh, tidx, params.h);

View File

@ -1341,7 +1341,7 @@ struct AttentionBackwardKernel {
if (kApplyDropout) { if (kApplyDropout) {
// See Note [Seed and Offset Device] // See Note [Seed and Offset Device]
auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); auto const [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs);
// each element of the attention matrix P with shape // each element of the attention matrix P with shape
// (batch_sz, n_heads, n_queries, n_keys) is associated with a single // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
// offset in RNG sequence. we initialize the RNG state with offset that // offset in RNG sequence. we initialize the RNG state with offset that
@ -1351,9 +1351,9 @@ struct AttentionBackwardKernel {
// rather than once per iteration. each iteration takes a copy of the // rather than once per iteration. each iteration takes a copy of the
// initialized RNG state and offsets it as needed. // initialized RNG state and offsets it as needed.
curand_init( curand_init(
std::get<0>(seeds), seed,
0, 0,
std::get<1>(seeds) + p.dropout_batch_head_rng_offset, offset + p.dropout_batch_head_rng_offset,
&rng_state_init); &rng_state_init);
} }

View File

@ -671,14 +671,13 @@ struct AttentionKernel {
curandStatePhilox4_32_10_t curand_state_init; curandStatePhilox4_32_10_t curand_state_init;
if (kSupportsDropout && p.use_dropout) { if (kSupportsDropout && p.use_dropout) {
-    const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
+    const auto [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs);
    if (p.rng_engine_inputs.captured_) {
      // See Note [Seed and Offset Device]
      // When we are in cuda graph capture mode the seed and offset are stored
      // on device We pass in int64_t* seed, and int64_t* offset to act as
      // scratch space for storing the rng state during the forward pass and
      // saving for backwards.
-      auto [seed, offset] = seeds;
*p.seed = seed; *p.seed = seed;
*p.extragraph_offset = offset; *p.extragraph_offset = offset;
} }
@ -691,9 +690,9 @@ struct AttentionKernel {
// rather than once per iteration. each iteration takes a copy of the // rather than once per iteration. each iteration takes a copy of the
// initialized RNG state and offsets it as needed. // initialized RNG state and offsets it as needed.
curand_init( curand_init(
std::get<0>(seeds), seed,
0, 0,
std::get<1>(seeds) + p.dropout_batch_head_rng_offset, offset + p.dropout_batch_head_rng_offset,
&curand_state_init); &curand_state_init);
} }

View File

@ -149,15 +149,7 @@ TORCH_LIBRARY(vulkan, m) {
}, },
// __setstate__ // __setstate__
      [](Conv2dOpContext::State state) {
-        return conv2d_clamp_prepack(
-            std::move(std::get<0>(state)),
-            std::move(std::get<1>(state)),
-            std::move(std::get<2>(state)),
-            std::move(std::get<3>(state)),
-            std::move(std::get<4>(state)),
-            std::get<5>(state),
-            std::get<6>(state),
-            std::get<7>(state));
+        return std::apply(conv2d_clamp_prepack, std::move(state));
}); });
} }