[1/2] Split cublasCommonArgs into its own file (#166313)

Summary:

* Factor out the `cublasCommonArgs` struct into its own header
* Prerequisite for factoring out the scaled-mm routines

Test Plan:

```
pytest -svv test/test_matmul_cuda.py
pytest -svv test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166313
Approved by: https://github.com/eqy, https://github.com/Skylion007
Author: Simon Layton, 2025-10-27 09:38:18 -07:00 (committed by PyTorch MergeBot)
Commit: acd936cc1a, parent: a4a0378e6b
2 changed files with 172 additions and 160 deletions

aten/src/ATen/native/cuda/Blas.cpp

@@ -19,6 +19,7 @@
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/GroupedMMUtils.h>
#include <ATen/native/cuda/cuBlasCommonArgs.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
@@ -57,169 +58,9 @@
namespace at::native {
namespace {
// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492
c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) {
  if (resolve_conj && tensor.is_conj()) {
    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
  } else {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  }
}
c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) {
  if (tensor.is_non_overlapping_and_dense()) { // common case
    transpose_tensor = tensor.is_contiguous();
    return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor);
  }
  IntArrayRef tensor_strides = tensor.strides();
  IntArrayRef tensor_sizes = tensor.sizes();
  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
    transpose_tensor = false;
    return resolve_conj_if_indicated(tensor, !transpose_result);
  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
    transpose_tensor = true;
    return resolve_conj_if_indicated(tensor, transpose_result);
  } else {
    transpose_tensor = true;
    return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
  }
}
c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) {
  if (tensor.is_non_overlapping_and_dense()) { // common case
    transpose_tensor = tensor.is_contiguous();
    return resolve_conj_if_indicated(tensor, true);
  }
  IntArrayRef tensor_strides = tensor.strides();
  IntArrayRef tensor_sizes = tensor.sizes();
  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
    transpose_tensor = false;
    return resolve_conj_if_indicated(tensor, true);
  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
    transpose_tensor = true;
    return resolve_conj_if_indicated(tensor, true);
  } else {
    transpose_tensor = true;
    return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
  }
}
using at::blas::ScalingType;
using at::blas::SwizzleType;
/**
 * @brief Prepares matrices for a cuBLAS operation
 *
 * This constructor prepares tensors for cuBLAS. The main wrinkle is that
 * PyTorch uses row-major matrices by default, while cuBLAS expects
 * column-major.
 *
 * @details
 * To produce row-major output while using cuBLAS,
 * we use the mathematical identity that (A × B)^T = B^T × A^T.
 *
 * Transpose in this context refers to cuBLAS's (Fortran) definition of
 * transpose: T = the operand is stored row-major, N = column-major.
 *
 * Example:
 * For matrices A (M×K, row-major) and B (K×N, row-major):
 * - Standard multiplication: A × B = (M×K) × (K×N) = M×N result (row-major)
 * - Using the transpose trick: B^T × A^T = (N×K) × (K×M) = N×M result (column-major)
 * - Since the output from cuBLAS is column-major, this is equivalent to an
 *   output of size M×N row-major, as expected
 *
 * The transpose flags are derived from the layouts of the passed-in tensors.
 *
 * If the operands are in packed float4 format, `k`, `lda` and `ldb` are adjusted
 * to their unpacked values to match what cuBLAS expects.
 *
 * @param mat1 First input matrix
 * @param mat2 Second input matrix
 * @param c Output matrix (result)
 * @param scale_a Optional scaling factor for first matrix
 * @param scale_b Optional scaling factor for second matrix
 * @param scale_result Optional scaling factor for result
 */
struct cublasCommonArgs {
  cublasCommonArgs(
      const Tensor& mat1,
      const Tensor& mat2,
      Tensor& c,
      const std::optional<Tensor>& scale_a = std::nullopt,
      const std::optional<Tensor>& scale_b = std::nullopt,
      const std::optional<Tensor>& scale_result = std::nullopt,
      const std::optional<ScalingType>& scaling_choice_a = std::nullopt,
      const std::optional<ScalingType>& scaling_choice_b = std::nullopt) {
    bool transpose_result = false, transpose_a = false, transpose_b = false;
    result = prepare_matrix_for_cublas(c, transpose_result);
    mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result);
    matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result);

    // Handle scale tensors if provided
    if (scale_a && scale_b) {
      // Because we return a row-major result, the gemm runs as B^T @ A^T by default;
      // check transpose_result to determine whether the scales must be swapped
      scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
      scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
      scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a;
      scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
      scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
      scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b;
    }

    if (scale_result) {
      scale_result_ptr = scale_result->data_ptr();
      scale_result_dtype = scale_result->scalar_type();
    }

    // Update transpose flags
    if (transpose_result) {
      transpose_a = !transpose_a;
      transpose_b = !transpose_b;
    }

    auto sizes_a = mata->sizes();
    auto sizes_b = matb->sizes();
    m = sizes_a[transpose_result ? 1 : 0];
    k = sizes_a[transpose_result ? 0 : 1];
    n = sizes_b[transpose_result ? 0 : 1];
    lda = mata->stride((transpose_a == transpose_result) ? 1 : 0);
    ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0);
    result_ld = result->stride(transpose_result ? 0 : 1);
    transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n';
    transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n';

    // cuBLAS expects unpacked values of `k`, `lda` and `ldb`; adjust for 4x2 packing
    // if the gemm operands are in packed float4
    if (mat1.dtype() == at::kFloat4_e2m1fn_x2 && mat2.dtype() == at::kFloat4_e2m1fn_x2) {
      k = k * 2;
      lda = lda * 2;
      ldb = ldb * 2;
    }
  }

  // Matrix members
  char transa, transb;
  int64_t m, n, k;
  int64_t lda, ldb, result_ld;
  c10::MaybeOwned<Tensor> mata, matb, result;

  // Scale members
  void* scale_mata_ptr = nullptr;
  void* scale_matb_ptr = nullptr;
  void* scale_result_ptr = nullptr;
  std::optional<c10::ScalarType> scale_mata_dtype;
  std::optional<ScalingType> scaling_mata_type;
  std::optional<c10::ScalarType> scale_matb_dtype;
  std::optional<ScalingType> scaling_matb_type;
  std::optional<c10::ScalarType> scale_result_dtype;
};
} // namespace
c10::MaybeOwned<Tensor> prepare_batch_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) {
  IntArrayRef tensor_strides = tensor.strides();
  c10::MaybeOwned<Tensor> tensor_;
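
As an aside, the transpose trick in the doc comment above can be demonstrated with a minimal, self-contained sketch in plain C++ (no cuBLAS dependency; `matmul_col_major` is a hypothetical stand-in for a column-major GEMM such as cublasSgemm): passing the row-major buffers of B and A, untouched, to a column-major multiply produces a buffer that reads back, row-major, as A × B.

```cpp
#include <array>
#include <cstdio>

// Multiply column-major a (m x k) by column-major b (k x n) into
// column-major out (m x n), the way cuBLAS interprets its operands.
void matmul_col_major(const float* a, const float* b, float* out,
                      int m, int k, int n) {
  for (int col = 0; col < n; ++col) {
    for (int row = 0; row < m; ++row) {
      float acc = 0.f;
      for (int i = 0; i < k; ++i) {
        acc += a[row + i * m] * b[i + col * k];  // a(row, i) * b(i, col)
      }
      out[row + col * m] = acc;
    }
  }
}

int main() {
  // A (2x3) and B (3x2), both stored row-major as PyTorch would store them.
  std::array<float, 6> A = {1, 2, 3,
                            4, 5, 6};
  std::array<float, 6> B = {7,  8,
                            9,  10,
                            11, 12};
  // Reinterpreted as column-major, A's buffer is A^T (3x2) and B's is B^T (2x3),
  // so computing B^T x A^T = (2x3) x (3x2) gives (A x B)^T in column-major.
  std::array<float, 4> C{};
  matmul_col_major(B.data(), A.data(), C.data(), /*m=*/2, /*k=*/3, /*n=*/2);
  // A column-major buffer of (A x B)^T, read row-major, is A x B itself.
  // Expected A x B = [[58, 64], [139, 154]].
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);
}
```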

aten/src/ATen/native/cuda/cuBlasCommonArgs.h (new file)

@@ -0,0 +1,171 @@
#pragma once
#include <ATen/core/Tensor.h>
namespace at::native {
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace {
// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492
c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) {
  if (resolve_conj && tensor.is_conj()) {
    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
  } else {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  }
}
c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) {
  if (tensor.is_non_overlapping_and_dense()) { // common case
    transpose_tensor = tensor.is_contiguous();
    return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor);
  }
  IntArrayRef tensor_strides = tensor.strides();
  IntArrayRef tensor_sizes = tensor.sizes();
  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
    transpose_tensor = false;
    return resolve_conj_if_indicated(tensor, !transpose_result);
  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
    transpose_tensor = true;
    return resolve_conj_if_indicated(tensor, transpose_result);
  } else {
    transpose_tensor = true;
    return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
  }
}
c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) {
  if (tensor.is_non_overlapping_and_dense()) { // common case
    transpose_tensor = tensor.is_contiguous();
    return resolve_conj_if_indicated(tensor, true);
  }
  IntArrayRef tensor_strides = tensor.strides();
  IntArrayRef tensor_sizes = tensor.sizes();
  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
    transpose_tensor = false;
    return resolve_conj_if_indicated(tensor, true);
  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
    transpose_tensor = true;
    return resolve_conj_if_indicated(tensor, true);
  } else {
    transpose_tensor = true;
    return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
  }
}
} // namespace
/**
 * @brief Prepares matrices for a cuBLAS operation
 *
 * This constructor prepares tensors for cuBLAS. The main wrinkle is that
 * PyTorch uses row-major matrices by default, while cuBLAS expects
 * column-major.
 *
 * @details
 * To produce row-major output while using cuBLAS,
 * we use the mathematical identity that (A × B)^T = B^T × A^T.
 *
 * Transpose in this context refers to cuBLAS's (Fortran) definition of
 * transpose: T = the operand is stored row-major, N = column-major.
 *
 * Example:
 * For matrices A (M×K, row-major) and B (K×N, row-major):
 * - Standard multiplication: A × B = (M×K) × (K×N) = M×N result (row-major)
 * - Using the transpose trick: B^T × A^T = (N×K) × (K×M) = N×M result (column-major)
 * - Since the output from cuBLAS is column-major, this is equivalent to an
 *   output of size M×N row-major, as expected
 *
 * The transpose flags are derived from the layouts of the passed-in tensors.
 *
 * If the operands are in packed float4 format, `k`, `lda` and `ldb` are adjusted
 * to their unpacked values to match what cuBLAS expects.
 *
 * @param mat1 First input matrix
 * @param mat2 Second input matrix
 * @param c Output matrix (result)
 * @param scale_a Optional scaling factor for first matrix
 * @param scale_b Optional scaling factor for second matrix
 * @param scale_result Optional scaling factor for result
 */
struct cublasCommonArgs {
  cublasCommonArgs(
      const Tensor& mat1,
      const Tensor& mat2,
      Tensor& c,
      const std::optional<Tensor>& scale_a = std::nullopt,
      const std::optional<Tensor>& scale_b = std::nullopt,
      const std::optional<Tensor>& scale_result = std::nullopt,
      const std::optional<ScalingType>& scaling_choice_a = std::nullopt,
      const std::optional<ScalingType>& scaling_choice_b = std::nullopt) {
    bool transpose_result = false, transpose_a = false, transpose_b = false;
    result = prepare_matrix_for_cublas(c, transpose_result);
    mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result);
    matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result);

    // Handle scale tensors if provided
    if (scale_a && scale_b) {
      // Because we return a row-major result, the gemm runs as B^T @ A^T by default;
      // check transpose_result to determine whether the scales must be swapped
      scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
      scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
      scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a;
      scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
      scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
      scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b;
    }

    if (scale_result) {
      scale_result_ptr = scale_result->data_ptr();
      scale_result_dtype = scale_result->scalar_type();
    }

    // Update transpose flags
    if (transpose_result) {
      transpose_a = !transpose_a;
      transpose_b = !transpose_b;
    }

    auto sizes_a = mata->sizes();
    auto sizes_b = matb->sizes();
    m = sizes_a[transpose_result ? 1 : 0];
    k = sizes_a[transpose_result ? 0 : 1];
    n = sizes_b[transpose_result ? 0 : 1];
    lda = mata->stride((transpose_a == transpose_result) ? 1 : 0);
    ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0);
    result_ld = result->stride(transpose_result ? 0 : 1);
    transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n';
    transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n';

    // cuBLAS expects unpacked values of `k`, `lda` and `ldb`; adjust for 4x2 packing
    // if the gemm operands are in packed float4
    if (mat1.dtype() == at::kFloat4_e2m1fn_x2 && mat2.dtype() == at::kFloat4_e2m1fn_x2) {
      k = k * 2;
      lda = lda * 2;
      ldb = ldb * 2;
    }
  }

  // Matrix members
  char transa, transb;
  int64_t m, n, k;
  int64_t lda, ldb, result_ld;
  c10::MaybeOwned<Tensor> mata, matb, result;

  // Scale members
  void* scale_mata_ptr = nullptr;
  void* scale_matb_ptr = nullptr;
  void* scale_result_ptr = nullptr;
  std::optional<c10::ScalarType> scale_mata_dtype;
  std::optional<ScalingType> scaling_mata_type;
  std::optional<c10::ScalarType> scale_matb_dtype;
  std::optional<ScalingType> scaling_matb_type;
  std::optional<c10::ScalarType> scale_result_dtype;
};
} // namespace at::native
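
For context on how the new header is consumed, here is a hypothetical call-site sketch (nothing below is taken from this diff; `launch_gemm` is an invented name): a GEMM wrapper includes the header, builds a `cublasCommonArgs`, and forwards its fields to a cublasGemmEx-style call.

```cpp
#include <ATen/core/Tensor.h>
#include <ATen/native/cuda/cuBlasCommonArgs.h>

// Hypothetical wrapper: prepares operands via cublasCommonArgs and reads out
// the fields a column-major cuBLAS GEMM call would consume.
void launch_gemm(const at::Tensor& mat1, const at::Tensor& mat2, at::Tensor& out) {
  at::native::cublasCommonArgs args(mat1, mat2, out);
  // args.transa / args.transb: 'n', 't' or 'c', per cuBLAS conventions
  // args.m, args.n, args.k: problem shape after the row-/column-major swap
  // args.lda, args.ldb, args.result_ld: leading dimensions of the prepared
  //   operands args.mata and args.matb, and of the output args.result
  // (for packed-float4 operands, k/lda/ldb are already unpacked here)
}
```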