[ROCm][TunableOp] hipblaslt tf32 support (#145946)

TF32 is supported by hipblaslt; support was added in #143549. This PR extends that integration to the TunableOp feature.
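
As a usage note (not part of this commit), here is a minimal C++ sketch of how the new path would be exercised. It assumes TunableOp is enabled through the existing PYTORCH_TUNABLEOP_ENABLED=1 environment variable and uses the existing at::globalContext().setAllowTF32CuBLAS(true) setter, which corresponds to the allowTF32CuBLAS() check in the diff below; shapes and sizes are arbitrary.

#include <ATen/ATen.h>
#include <ATen/Context.h>

int main() {
  // Opt in to TF32 for float32 GEMMs. TunableOp consults this same global
  // flag (at::globalContext().allowTF32CuBLAS()) when it enumerates hipBLASLt
  // candidates, so they are built with HIPBLAS_COMPUTE_32F_FAST_TF32.
  at::globalContext().setAllowTF32CuBLAS(true);

  // With PYTORCH_TUNABLEOP_ENABLED=1 set, this float32 matmul is routed
  // through TunableOp and tuned against TF32-capable hipBLASLt solutions;
  // rocBLAS candidates return FAIL because rocBLAS has no TF32 support.
  auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
  auto a = at::randn({1024, 1024}, opts);
  auto b = at::randn({1024, 1024}, opts);
  auto c = at::matmul(a, b);
  return 0;
}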

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145946
Approved by: https://github.com/pruthvistony, https://github.com/echen4096, https://github.com/yoyoyocmu

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
commit 0c8ec26d3b (parent ab45aaca97)
Jeff Daily authored 2025-03-12 21:17:08 +00:00, committed by PyTorch MergeBot
2 changed files with 15 additions and 2 deletions


@@ -498,7 +498,11 @@ class HipblasltGemmOp : public Callable<ParamsT> {
           mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
     }
-    HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
+    hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
+    if (at::globalContext().allowTF32CuBLAS()) {
+      computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
+    }
+    HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
     matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
     matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
@@ -611,6 +615,11 @@ auto GetHipBlasLtTypeStringAndOps() {
   auto in_out_datatype = HipDataTypeFor<CT>();
   std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
+  hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
+  if (at::globalContext().allowTF32CuBLAS()) {
+    computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
+  }
   hipblasLtHandle_t handle;
   TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
   TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
@@ -621,7 +630,7 @@ auto GetHipBlasLtTypeStringAndOps() {
                                                    b_datatype,
                                                    in_out_datatype,
                                                    in_out_datatype,
-                                                   HIPBLAS_COMPUTE_32F,
+                                                   computeType,
                                                    heuristic_result));
   TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));


@@ -141,6 +141,8 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {
   TuningStatus Call(const GemmParams<T>* params) override {
     auto input_output_type = RocBlasDataTypeFor<T>();
+    if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
+      return FAIL; // no support for TF32 in rocBLAS
     auto compute_type = RocBlasComputeTypeFor<T>();
     auto h_a = DoCastForHalfOrBfloat16(params->alpha);
     auto h_b = DoCastForHalfOrBfloat16(params->beta);
@@ -207,6 +209,8 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
   TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
     auto input_output_type = RocBlasDataTypeFor<T>();
+    if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
+      return FAIL; // no support for TF32 in rocBLAS
     auto compute_type = RocBlasComputeTypeFor<T>();
     auto h_a = DoCastForHalfOrBfloat16(params->alpha);
     auto h_b = DoCastForHalfOrBfloat16(params->beta);