[3/N] Avoid copy in std::get (#141843)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141843
Approved by: https://github.com/Skylion007
cyy 2024-12-06 20:13:33 +00:00 committed by PyTorch MergeBot
parent add4a42ea2
commit 1fa27f6e82
22 changed files with 130 additions and 152 deletions
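The changes below apply three related C++ patterns to avoid needless Tensor refcount bumps and tuple-element copies: sink parameters are taken by value and std::move-d into the wrapped constructor, elements of a named local tuple are moved out with std::move(std::get&lt;I&gt;(t)), and repeated std::get calls are replaced with structured bindings. A minimal sketch of the three patterns, using hypothetical Widget/make_widgets/consume placeholders rather than PyTorch types:

#include <string>
#include <tuple>
#include <utility>

struct Widget { std::string payload; };

// Sink parameter: taken by value so callers can move into it.
void consume(Widget w) { w.payload.clear(); }

std::tuple<Widget, Widget> make_widgets() {
  return {Widget{"a"}, Widget{"b"}};
}

int main() {
  // (1) Structured bindings name the tuple's elements directly,
  //     avoiding repeated std::get<N>(...) copies into fresh locals.
  auto [first, second] = make_widgets();

  // (2) For a named local tuple, move out the element you hand off instead of copying it.
  auto pair = make_widgets();
  consume(std::move(std::get<0>(pair)));

  // (3) Pass-by-value + std::move forwards an rvalue into the sink without a copy.
  consume(std::move(first));
  consume(std::move(second));
  return 0;
}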

View File

@ -76,7 +76,7 @@ void BatchedTensorImpl::checkInvariants() const {
}
}
// The following are publically exposed as methods of Tensor
// The following are publicly exposed as methods of Tensor
IntArrayRef BatchedTensorImpl::strides_custom() const {
return strides_default();
@ -113,7 +113,7 @@ const char* BatchedTensorImpl::tensorimpl_type_name() const {
return "BatchedTensorImpl";
}
Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
Tensor makeBatched(Tensor tensor, BatchDims bdims) {
TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
auto tensor_dim = tensor.dim();
TORCH_CHECK(
@ -124,15 +124,15 @@ Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
std::all_of(bdims.begin(), bdims.end(),
[](const BatchDim& bdim) { return bdim.level() < kVmapNumLevels; }),
"We only support up to ", kVmapNumLevels, " nested vmaps");
return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
return at::detail::make_tensor<BatchedTensorImpl>(std::move(tensor), std::move(bdims));
}
Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim) {
const auto* batched = maybeGetBatchedImpl(tensor);
if (!batched) {
BatchDims bdims;
bdims.emplace_back(level, dim);
return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
return at::detail::make_tensor<BatchedTensorImpl>(std::move(tensor), std::move(bdims));
}
BatchDims new_bdims(batched->bdims().begin(), batched->bdims().end());
auto actual_bdim = batched->actualDim(dim, /*wrap_dim=*/true);

View File

@ -148,10 +148,10 @@ inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
}
// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims);
// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim);
// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.

View File

@ -353,7 +353,7 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
auto t1 = at::zeros({1});
auto t2 = at::zeros({1});
std::tuple<at::Tensor&, at::Tensor&> tup = func.call<
auto [t1_out, t2_out] = func.call<
std::tuple<at::Tensor&, at::Tensor&>, at::Scalar, at::Scalar, at::Tensor&, at::Tensor&
>(dummy, CPU_TEST_SET, s1, s2, t1, t2);
@ -361,11 +361,9 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
EXPECT_EQ(t1.item().toFloat(), 1.0f);
EXPECT_EQ(t2.item().toFloat(), 2.0f);
auto t1_out = std::get<0>(tup);
EXPECT_EQ(t1_out.item().toFloat(), 1.0f);
EXPECT_TRUE(t1_out.is_same(t1));
auto t2_out = std::get<1>(tup);
EXPECT_EQ(t2_out.item().toFloat(), 2.0f);
EXPECT_TRUE(t2_out.is_same(t2));
}

View File

@ -218,47 +218,43 @@ static std::tuple<Tensor, std::optional<int64_t>> masked_select_backward_batch_r
static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule(
const Tensor& grad, std::optional<int64_t> grad_bdim,
const Tensor& x1, std::optional<int64_t> x1_bdim,
const Tensor& x2, std::optional<int64_t> x2_bdim,
Tensor x1, std::optional<int64_t> x1_bdim,
Tensor x2, std::optional<int64_t> x2_bdim,
const double p,
const Tensor& cdist, std::optional<int64_t> cdist_bdim) {
auto x1_ = x1;
if (cdist_bdim && !x1_bdim) {
// We need to make sure that x1 has batch dim if cdist has one
// otherwise, we get
// RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
// but expected shape compatible with [4, 5]
auto bs = cdist.size(*cdist_bdim);
x1_ = ensure_has_bdim(x1, false, bs);
x1_ = x1_.contiguous();
x1 = ensure_has_bdim(x1, false, bs).contiguous();
x1_bdim = 0;
}
// We need to apply the same preprocessing on x1 and x2 as in the forward pass
// _binary_pointwise_batch_rule
auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
x1_ = std::move(std::get<0>(x12));
auto& x2_ = std::get<1>(x12);
std::tie(x1, x2) = _binary_pointwise_helper(x1, x1_bdim, x2, x2_bdim);
auto grad_ = moveBatchDimToFront(grad, grad_bdim);
if ((x1_bdim || x2_bdim) && !grad_bdim) {
// We need to make sure that grad has batch dim if x1 or x2 have one
// Probably, there is an assumption on the strides.
// Otherwise grad input contains thrash values, e.g. -7.0816e+29, 7.0816e+29
auto bs = get_bdim_size2(x1_, 0, x2_, 0);
auto bs = get_bdim_size2(x1, 0, x2, 0);
grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
grad_ = grad_.contiguous();
}
auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist);
auto out = at::_cdist_backward(grad_, x1, x2, p, cdist);
std::optional<int64_t> out_bdim = std::nullopt;
if (x1_bdim || x2_bdim) {
out_bdim = 0;
}
return std::make_tuple(out, out_bdim);
return std::make_tuple(std::move(out), out_bdim);
}
static void fill__Tensor_batch_rule(

View File

@ -42,6 +42,7 @@ static Tensor padRight(const Tensor& tensor, std::optional<int64_t> has_bdim, in
}
template<typename F, F Func>
static
std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
batch_norm_batch_rule(
const Tensor& input, std::optional<int64_t> input_bdim,
@ -70,10 +71,10 @@ batch_norm_batch_rule(
if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
const auto dummy_weight = at::ones(input.size(1), input.options()); // cudnn and miopen require a weight
const auto dummy_bias = at::zeros(input.size(1), input.options()); // without this, get "strides() called on undefined Tensor" on cuda
const auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
result0 = std::get<0>(result).transpose(0, 1); // [C, B, *]
mean = std::get<1>(result);
rstd = std::get<2>(result);
mean = std::move(std::get<1>(result));
rstd = std::move(std::get<2>(result));
} else {
bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
auto input_ = moveBatchDimToFront(input, input_bdim);
@ -95,12 +96,12 @@ batch_norm_batch_rule(
const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
const auto dummy_bias = at::zeros(input_.size(1), input_.options()); // without this, get "strides() called on undefined Tensor" on cuda
const auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
result0 = std::get<0>(result).transpose(0, 1); // [(B0, C), B, *]
mean = std::move(std::get<1>(result));
rstd = std::move(std::get<2>(result));
result0 = reshape_dim_outof(0, bdim_size.value(), result0); // [B0, C, B, *]
mean = std::get<1>(result);
mean = reshape_dim_outof(0, bdim_size.value(), mean); // [B0, C]
rstd = std::get<2>(result);
rstd = reshape_dim_outof(0, bdim_size.value(), rstd); // [B0, C]
}
@ -124,6 +125,7 @@ batch_norm_batch_rule(
}
template<typename F, F Func>
static
std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
const at::Tensor & input, std::optional<int64_t> input_bdim,
@ -142,9 +144,9 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
TORCH_INTERNAL_ASSERT(!mean_bdim);
TORCH_INTERNAL_ASSERT(!rstd_bdim);
const auto dummy_weight = at::ones(input.size(1), input.options());
const auto result = Func(
auto result = Func(
grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false});
return std::make_tuple(std::get<0>(result), std::nullopt);
return {std::move(std::get<0>(result)), std::nullopt};
}
auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
@ -196,6 +198,7 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
}
template<typename F, F Func>
static
std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
const at::Tensor & grad_out,
const at::Tensor & input,
@ -270,7 +273,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *]
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
grad_normalized_input_value, grad_normalized_input_bdim,
input_value, input_bdim,
running_mean_value, running_mean_bdim,
@ -278,7 +281,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
save_mean_value, save_mean_bdim,
save_rstd_value, save_rstd_bdim,
training, eps);
grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
grad_input = makeBatched(std::move(std::get<0>(results)), std::get<1>(results), cur_level);
}
return std::make_tuple(grad_input, grad_weight, grad_bias);
}
@ -312,16 +315,13 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
const auto bdim_size = input_value.size(*input_bdim);
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto result = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level);
mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level);
rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level);
std::tie(result0, mean, rstd) = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
result0 = makeBatched(reshape_dim_outof(0, bdim_size, result0), 0, cur_level);
mean = makeBatched(reshape_dim_outof(0, bdim_size, mean), 0, cur_level);
rstd = makeBatched(reshape_dim_outof(0, bdim_size, rstd), 0, cur_level);
} else {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto result = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
result0 = std::get<0>(result);
mean = std::get<1>(result);
rstd = std::get<2>(result);
std::tie(result0, mean, rstd) = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
}
if (weight.defined()) {
@ -334,10 +334,10 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
result0 = result0 + padded_bias;
}
return std::make_tuple(result0, mean, rstd);
return std::make_tuple(std::move(result0), std::move(mean), std::move(rstd));
}
static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
static at::Tensor group_norm_backward_no_weight_bias_batch_rule(
const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
const at::Tensor & input, std::optional<int64_t> input_bdim,
const at::Tensor & mean, std::optional<int64_t> mean_bdim,
@ -359,15 +359,13 @@ static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_wei
mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]
const auto result = native_group_norm_backward(
auto result0 = std::get<0>(native_group_norm_backward(
grad_out_.contiguous(),
input_.contiguous(),
mean_.contiguous(),
rstd_.contiguous(),
std::nullopt, N * bdim_size, C, HxW, group, {true, false, false});
auto result0 = std::get<0>(result);
result0 = reshape_dim_outof(0, bdim_size, result0);
return std::make_tuple(result0, 0);
std::nullopt, N * bdim_size, C, HxW, group, {true, false, false}));
return reshape_dim_outof(0, bdim_size, result0);
}
static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
@ -422,19 +420,19 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
unwrapTensorAtLevel(grad_normalized_input, cur_level);
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto res = group_norm_backward_no_weight_bias_batch_rule(
auto tensor = group_norm_backward_no_weight_bias_batch_rule(
grad_normalized_input_value, grad_normalized_input_bdim,
input_value, input_bdim,
mean_value, mean_bdim,
rstd_value, rstd_bdim,
N, C, HxW, group
);
grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
grad_input = makeBatched(std::move(tensor), 0, cur_level);
}
return std::make_tuple(grad_input, grad_weight, grad_bias);
}
C10_ALWAYS_INLINE bool has_same_shape(
static bool has_same_shape(
const Tensor& tensor, std::optional<int64_t> tensor_bdim,
c10::SymIntArrayRef normalized_shape) {
if (!tensor.defined()) {
@ -457,7 +455,7 @@ C10_ALWAYS_INLINE bool has_same_shape(
return true;
}
C10_ALWAYS_INLINE void check_same_shape(
static C10_ALWAYS_INLINE void check_same_shape(
const Tensor& tensor, std::optional<int64_t> tensor_bdim,
c10::SymIntArrayRef normalized_shape, const std::string& name) {
TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape),
@ -469,7 +467,7 @@ C10_ALWAYS_INLINE void check_same_shape(
}
// Ugh, hard to deduplicate
C10_ALWAYS_INLINE void _check_layer_norm_inputs(
static C10_ALWAYS_INLINE void _check_layer_norm_inputs(
SymIntArrayRef normalized_shape,
const Tensor& weight, std::optional<int64_t> weight_bdim,
const Tensor& bias, std::optional<int64_t> bias_bdim) {
@ -493,11 +491,9 @@ native_layer_norm_batch_rule(
double eps) {
auto input_ = moveBatchDimToFront(input, input_bdim);
if (!weight_bdim && !bias_bdim) {
const auto result = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
const auto mean = std::get<1>(result);
const auto rstd = std::get<2>(result);
auto [result0, mean, rstd] = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
return std::make_tuple(std::get<0>(result), 0, mean, stats_bdim, rstd, stats_bdim);
return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
}
// See [Note: hacky wrapper removal for optional tensor]
@ -509,9 +505,7 @@ native_layer_norm_batch_rule(
const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
const auto result = at::native_layer_norm_symint(input_, normalized_shape, std::nullopt, std::nullopt, eps);
auto result0 = std::get<0>(result);
const auto mean = std::get<1>(result);
const auto rstd = std::get<2>(result);
auto [result0, mean, rstd] = result;
const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
if (weight.defined()) {
@ -638,7 +632,7 @@ static std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_p
unwrapTensorAtLevel(grad_normalized_input, cur_level);
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
const auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
grad_normalized_input_value, grad_normalized_input_bdim,
input_value, input_bdim,
normalized_shape,

View File

@ -171,18 +171,18 @@ void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>&
TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
}
Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) {
Tensor makeBatched(Tensor tensor, int64_t bdim, int64_t level) {
DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor);
auto* batched = maybeGetBatchedImpl(tensor);
if (batched) {
auto batched_level = batched->level();
TORCH_INTERNAL_ASSERT(level > batched_level, " batched_level: ", batched_level, " level: ", level);
}
return at::detail::make_tensor<BatchedTensorImpl>(key_set, tensor, bdim, level);
return at::detail::make_tensor<BatchedTensorImpl>(key_set, std::move(tensor), bdim, level);
}
Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level) {
return makeBatched(tensor, dim, level);
Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level) {
return makeBatched(std::move(tensor), dim, level);
}
} // namespace at::functorch

View File

@ -144,10 +144,10 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
}
// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);
TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level);
// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level);
// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
// any wrapper Tensor subclasses). This is because there are methods on Tensor

View File

@ -22,20 +22,20 @@ void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* wh
)
}
Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim, int64_t level) {
Tensor makeBatched(Tensor tensor, std::optional<int64_t> bdim, int64_t level) {
if (bdim.has_value()) {
TORCH_INTERNAL_ASSERT(*bdim >= 0);
TORCH_INTERNAL_ASSERT(*bdim < tensor.dim());
return makeBatched(tensor, bdim.value(), level);
return makeBatched(std::move(tensor), bdim.value(), level);
}
return tensor;
}
std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::optional<int64_t> bdim, int64_t level) {
std::vector<Tensor> makeBatchedVector(std::vector<Tensor> tensors, std::optional<int64_t> bdim, int64_t level) {
std::vector<Tensor> res;
res.reserve(tensors.size());
for (const auto & tensor : tensors) {
res.emplace_back(makeBatched(tensor, bdim, level));
for (auto & tensor : tensors) {
res.emplace_back(makeBatched(std::move(tensor), bdim, level));
}
return res;
}

View File

@ -29,7 +29,7 @@ namespace at::functorch {
void vmap_check_escaped(const std::optional<DynamicLayer> &layer, const char* what);
// Create a BatchedTensor given a tensor, bdim, and level
TORCH_API Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim, int64_t level);
TORCH_API Tensor makeBatched(Tensor tensor, std::optional<int64_t> bdim, int64_t level);
// Given a Tensor that may or may not be a BatchedTensor, unwrap it.
// If `tensor` is not a BatchedTensor, or is a BatchedTensor but the level
@ -38,7 +38,7 @@ TORCH_API Tensor makeBatched(const Tensor& tensor, std::optional<int64_t> bdim,
TORCH_API std::tuple<Tensor, std::optional<int64_t>> unwrapTensorAtLevel(const Tensor& tensor, int64_t level);
// Creates a vector of BatchedTensor
TORCH_API std::vector<Tensor> makeBatchedVector(const std::vector<Tensor>& tensors, std::optional<int64_t> bdim, int64_t level);
TORCH_API std::vector<Tensor> makeBatchedVector(std::vector<Tensor> tensors, std::optional<int64_t> bdim, int64_t level);
// Returns True if ANY tensor in tensors is batched at level
TORCH_API bool isBatchedAtLevel(ITensorListRef tensors, int64_t level);

View File

@ -55,9 +55,7 @@ void unpack_bcsr(
memset(dst + i * C, zero_points[i], C * sizeof(int8_t));
}
}
const std::vector<int8_t>& weight_values = std::get<0>(bcsr);
const std::vector<int32_t>& row_indices = std::get<1>(bcsr);
const std::vector<int32_t>& col_indices = std::get<2>(bcsr);
const auto& [weight_values, row_indices, col_indices] = bcsr;
int64_t rowBlocks = (R + RB - 1) / RB;
for (int64_t i = 0; i < rowBlocks; ++i) {
// For the current tile, rowBPtr starts from currentTileIdx
@ -316,4 +314,4 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp(
}
#endif // USE_PYTORCH_QNNPACK
} // namespace ao
} // namespace ao::sparse

View File

@ -25,9 +25,9 @@ __global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T
// do random permutation inside each island.
data += tid;
auto seeds = at::cuda::philox::unpack(philox_args);
auto [seed, offset] = at::cuda::philox::unpack(philox_args);
curandStatePhilox4_32_10_t state;
curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
curand_init(seed, tid, offset, &state);
for (int i = island_size - 1; i > 0; i--) {
unsigned int r = curand(&state) % (i + 1);
if (i != r) {

View File

@ -241,7 +241,8 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
auto outputs = at::native_batch_norm(
input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
/*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
auto out = std::get<0>(outputs).view(input_shape);
auto& [out, mean, rstd] = outputs;
out = out.view(input_shape);
if (weight.defined() && bias.defined()) {
out = bias.addcmul(out, weight, 1);
} else if (weight.defined()) {
@ -249,8 +250,6 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
} else if (bias.defined()) {
out = out.add(bias);
}
at::Tensor mean = std::get<1>(outputs);
at::Tensor rstd = std::get<2>(outputs);
std::vector<int64_t> stat_shape;
for (const auto idx : c10::irange(axis)) {
stat_shape.push_back(input_shape[idx]);
@ -260,7 +259,7 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
}
mean = mean.view(stat_shape);
rstd = rstd.view(stat_shape);
return std::make_tuple(out, mean, rstd);
return outputs;
}
Tensor rms_norm_symint(
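In the layer-norm hunk above, the structured binding uses auto&, so out, mean and rstd are references into the outputs tuple: reassigning them rewrites the tuple's own elements, which is why the function can simply end with return outputs. A small stand-alone illustration of that behaviour, using plain ints in place of Tensors:

#include <cassert>
#include <tuple>

std::tuple<int, int, int> adjust() {
  auto outputs = std::make_tuple(1, 2, 3);
  auto& [a, b, c] = outputs;  // references into the tuple, not copies
  a += 10;                    // mutates std::get<0>(outputs)
  b += 10;
  c += 10;
  return outputs;             // already holds the adjusted values
}

int main() {
  assert(adjust() == std::make_tuple(11, 12, 13));
  return 0;
}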

View File

@ -110,7 +110,7 @@ void inline col_indices_and_values_resize_(const Tensor& input, int64_t nnz) {
template <typename scalar_t>
void mkl_result_copy_(const Tensor& input, sparse_matrix_t mkl_desc) {
sparse_index_base_t indexing = SPARSE_INDEX_BASE_ZERO;
MKL_INT rows, cols;
MKL_INT rows = 0, cols = 0;
MKL_INT *rows_start = nullptr, *rows_end = nullptr, *columns = nullptr;
scalar_t* values = nullptr;
at::mkl::sparse::export_csr(
@ -194,7 +194,7 @@ void addmm_dense_result(
auto ldb = is_B_row_major ? B_strides[ndim - 2] : B_strides[ndim - 1];
auto columns_C = mkl_int_cast(C.size(-1), "columns_C");
matrix_descr descrA;
matrix_descr descrA{};
descrA.type = SPARSE_MATRIX_TYPE_GENERAL;
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
@ -511,7 +511,7 @@ void addmv_out_sparse_csr(
c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_mkl(vec);
sparse_operation_t opA = SPARSE_OPERATION_NON_TRANSPOSE;
matrix_descr descrA;
matrix_descr descrA{};
descrA.type = SPARSE_MATRIX_TYPE_GENERAL;
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
@ -651,7 +651,7 @@ void triangular_solve_out_sparse_csr(
c10::MaybeOwned<Tensor> B_ = prepare_dense_matrix_for_mkl(B, is_X_row_major);
sparse_operation_t opA = transpose ? SPARSE_OPERATION_TRANSPOSE : SPARSE_OPERATION_NON_TRANSPOSE;
matrix_descr descrA;
matrix_descr descrA{};
descrA.type = SPARSE_MATRIX_TYPE_TRIANGULAR;
descrA.mode = upper ? SPARSE_FILL_MODE_UPPER : SPARSE_FILL_MODE_LOWER;
descrA.diag = unitriangular ? SPARSE_DIAG_UNIT : SPARSE_DIAG_NON_UNIT;

View File

@ -487,13 +487,36 @@ static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
auto layer_cx = cx[index];
auto reverse = (direction > 0);
// bias won't be packed
auto outputs = at::mkldnn_rnn_layer(layer_input, layer_weights[0], layer_weights[1],
has_biases ? layer_weights[2] : at::zeros(layer_weights[0].sizes(), layer_weights[0].options().layout(at::Layout::Strided)),
has_biases ? layer_weights[3] : at::zeros(layer_weights[1].sizes(), layer_weights[1].options().layout(at::Layout::Strided)), layer_hx,
layer_cx, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train);
layer_output[direction] = std::move(std::get<0>(outputs));
layer_hy[index] = std::move(std::get<1>(outputs));
layer_cy[index] = std::move(std::get<2>(outputs));
std::tie(
layer_output[direction],
layer_hy[index],
layer_cy[index],
std::ignore) =
at::mkldnn_rnn_layer(
layer_input,
layer_weights[0],
layer_weights[1],
has_biases
? layer_weights[2]
: at::zeros(
layer_weights[0].sizes(),
layer_weights[0].options().layout(at::Layout::Strided)),
has_biases
? layer_weights[3]
: at::zeros(
layer_weights[1].sizes(),
layer_weights[1].options().layout(at::Layout::Strided)),
layer_hx,
layer_cx,
reverse,
batch_sizes,
mode,
hidden_size,
num_layers,
has_biases,
bidirectional,
batch_first,
train);
}
layer_input = num_directions == 1 ? layer_output[0]
: at::cat(layer_output, /*output_channels*/-1);

View File

@ -258,11 +258,7 @@ int64_t get_nnz(const Tensor& nestedtensor) {
TensorOptions().device(at::kCUDA).dtype(at::kInt));
Nnz_q = output_batch_size * max_seqlen_batch_q;
} else {
auto cumulative_and_max_q_and_nnz_q = cumulative_and_max_seq_len_nnz(q_t);
cumulative_sequence_length_q =
std::get<0>(cumulative_and_max_q_and_nnz_q);
max_seqlen_batch_q = std::get<1>(cumulative_and_max_q_and_nnz_q);
Nnz_q = std::get<2>(cumulative_and_max_q_and_nnz_q);
std::tie(cumulative_sequence_length_q, max_seqlen_batch_q, Nnz_q) = cumulative_and_max_seq_len_nnz(q_t);
}
int64_t max_seqlen_batch_kv = 0, Nnz_kv = 0;
@ -277,13 +273,10 @@ int64_t get_nnz(const Tensor& nestedtensor) {
TensorOptions().device(at::kCUDA).dtype(at::kInt));
Nnz_kv = output_batch_size * max_seqlen_batch_kv;
} else {
auto cumulative_and_max_kv_and_nnz_kv = k_batch_size_needs_broadcast
std::tie(cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv) =
k_batch_size_needs_broadcast
? cumulative_and_max_seq_len_nnz(v_t)
: cumulative_and_max_seq_len_nnz(k_t);
cumulative_sequence_length_kv =
std::get<0>(cumulative_and_max_kv_and_nnz_kv);
max_seqlen_batch_kv = std::get<1>(cumulative_and_max_kv_and_nnz_kv);
Nnz_kv = std::get<2>(cumulative_and_max_kv_and_nnz_kv);
}
bool q_num_heads_needs_broadcast = q_num_heads != output_num_heads;
@ -369,14 +362,14 @@ int64_t get_nnz(const Tensor& nestedtensor) {
}
return std::make_tuple(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_kv,
std::move(query_buffer_reshaped),
std::move(key_buffer_reshaped),
std::move(value_buffer_reshaped),
std::move(cumulative_sequence_length_q),
std::move(cumulative_sequence_length_kv),
max_seqlen_batch_q,
max_seqlen_batch_kv,
output_shape);
std::move(output_shape));
}
} // namespace

View File

@ -441,21 +441,15 @@ TORCH_API int register_linear_params() {
[](SerializationType state)
-> c10::intrusive_ptr<
LinearPackedParamsBase> { // __setstate__
at::Tensor weight;
std::optional<at::Tensor> bias;
weight = std::move(std::get<0>(state));
bias = std::move(std::get<1>(state));
#ifdef USE_FBGEMM
if (at::globalContext().qEngine() == at::QEngine::FBGEMM ||
at::globalContext().qEngine() == at::QEngine::X86) {
const auto& weight = std::get<0>(state);
if (weight.scalar_type() == at::kQInt8) {
return PackedLinearWeight::prepack(
std::move(weight), std::move(bias));
return std::apply(PackedLinearWeight::prepack, std::move(state));
} else if (weight.scalar_type() == at::kFloat) {
// NB: fp16 weight is serialized as float
return PackedLinearWeightFp16::prepack(
std::move(weight), std::move(bias));
return std::apply(PackedLinearWeightFp16::prepack, std::move(state));
} else {
TORCH_CHECK(
false,
@ -467,22 +461,22 @@ TORCH_API int register_linear_params() {
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
const auto& weight = std::get<0>(state);
TORCH_CHECK(
weight.scalar_type() == at::kQInt8,
"QNNPACK only supports INT8 bit width currently. Got ",
c10::toString(weight.scalar_type()));
return PackedLinearWeightsQnnp::prepack(
std::move(weight), std::move(bias));
return std::apply(PackedLinearWeightsQnnp::prepack, std::move(state));
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
const auto& weight = std::get<0>(state);
TORCH_CHECK(
weight.scalar_type() == at::kQInt8,
"ONEDNN only supports INT8 bit width currently. Got ",
c10::toString(weight.scalar_type()));
return PackedLinearWeightsOnednn::prepack(
std::move(weight), std::move(bias));
return std::apply(PackedLinearWeightsOnednn::prepack, std::move(state));
}
#endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(false, "Unknown qengine");

View File

@ -366,11 +366,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
}
#endif
// shape: 3 x [B, num_head, T, dim_per_head]
auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
auto [q, k, v] = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
qkv = Tensor(); // Not used any more, allow free
auto& q = std::get<0>(q_k_v);
const auto& k = std::get<1>(q_k_v);
const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});

View File

@ -617,11 +617,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
}
#endif
// shape: 3 x [B, num_head, T, dim_per_head]
auto q_k_v = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
auto [q, k, v] = _transform_bias_rescale_qkv(qkv, qkv_bias, num_head);
qkv = Tensor(); // Not used any more, allow free
auto& q = std::get<0>(q_k_v);
const auto& k = std::get<1>(q_k_v);
const auto& v = std::get<2>(q_k_v);
#ifndef NDEBUG
debug_assert_shape(__LINE__, q, {B, num_head, T, dim_per_head});
debug_assert_shape(__LINE__, k, {B, num_head, T, dim_per_head});
@ -1413,7 +1410,7 @@ __global__ void rand_uniform_kernel(
const int64_t head_id = blockIdx.y;
const int64_t query_idx = threadIdx.x;
const auto seeds = at::cuda::philox::unpack(rng_engine_inputs);
const auto [seed, offset] = at::cuda::philox::unpack(rng_engine_inputs);
const int dropout_seq_start = batch_id * (n_heads * n_queries * n_keys) +
head_id * (n_queries * n_keys);
@ -1421,9 +1418,9 @@ __global__ void rand_uniform_kernel(
curandStatePhilox4_32_10_t curand_state;
curand_init(
std::get<0>(seeds),
seed,
0,
std::get<1>(seeds) + dropout_seq_start + query_start_idx,
offset + dropout_seq_start + query_start_idx,
&curand_state);
for (int key_start_idx = 0; key_start_idx < n_keys; key_start_idx += 4) {

View File

@ -443,9 +443,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
}
auto seeds = at::cuda::philox::unpack(params.philox_args);
unsigned long long seed = std::get<0>(seeds);
unsigned long long offset = std::get<1>(seeds);
const auto [seed, offset] = at::cuda::philox::unpack(params.philox_args);
pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);

View File

@ -1341,7 +1341,7 @@ struct AttentionBackwardKernel {
if (kApplyDropout) {
// See Note [Seed and Offset Device]
auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
auto const [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs);
// each element of the attention matrix P with shape
// (batch_sz, n_heads, n_queries, n_keys) is associated with a single
// offset in RNG sequence. we initialize the RNG state with offset that
@ -1351,9 +1351,9 @@ struct AttentionBackwardKernel {
// rather than once per iteration. each iteration takes a copy of the
// initialized RNG state and offsets it as needed.
curand_init(
std::get<0>(seeds),
seed,
0,
std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
offset + p.dropout_batch_head_rng_offset,
&rng_state_init);
}

View File

@ -671,14 +671,13 @@ struct AttentionKernel {
curandStatePhilox4_32_10_t curand_state_init;
if (kSupportsDropout && p.use_dropout) {
const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
const auto [seed, offset] = at::cuda::philox::unpack(p.rng_engine_inputs);
if (p.rng_engine_inputs.captured_) {
// See Note [Seed and Offset Device]
// When we are in cuda graph capture mode the seed and offset are stored
// on device We pass in int64_t* seed, and int64_t* offset to act as
// scratch space for storing the rng state during the forward pass and
// saving for backwards.
auto [seed, offset] = seeds;
*p.seed = seed;
*p.extragraph_offset = offset;
}
@ -691,9 +690,9 @@ struct AttentionKernel {
// rather than once per iteration. each iteration takes a copy of the
// initialized RNG state and offsets it as needed.
curand_init(
std::get<0>(seeds),
seed,
0,
std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
offset + p.dropout_batch_head_rng_offset,
&curand_state_init);
}

View File

@ -149,15 +149,7 @@ TORCH_LIBRARY(vulkan, m) {
},
// __setstate__
[](Conv2dOpContext::State state) {
return conv2d_clamp_prepack(
std::move(std::get<0>(state)),
std::move(std::get<1>(state)),
std::move(std::get<2>(state)),
std::move(std::get<3>(state)),
std::move(std::get<4>(state)),
std::get<5>(state),
std::get<6>(state),
std::get<7>(state));
return std::apply(conv2d_clamp_prepack, std::move(state));
});
}
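Several __setstate__ hunks above (register_linear_params and this Vulkan one) replace hand-written unpacking of state via std::get with std::apply, which expands the tuple into the callee's parameter list; passing the tuple as an rvalue lets movable elements be forwarded without copies. A minimal sketch under hypothetical names (prepack and the three-field state stand in for the real prepack functions and serialization state types):

#include <string>
#include <tuple>
#include <utility>

// Hypothetical prepack-style function taking the already-unpacked state.
int prepack(std::string weight, std::string bias, int groups) {
  return static_cast<int>(weight.size() + bias.size()) + groups;
}

int main() {
  std::tuple<std::string, std::string, int> state{"w", "bias", 2};
  // std::apply expands the tuple elements into prepack's argument list;
  // moving the tuple forwards the strings as rvalues instead of copying them.
  const int packed = std::apply(prepack, std::move(state));
  return packed == 7 ? 0 : 1;
}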