[ROCm][TunableOp] hipblaslt tf32 support (#145946)
TF32 is supported by hipblaslt; support was added by #143549. This PR expands that integration to the TunableOp feature.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145946
Approved by: https://github.com/pruthvistony, https://github.com/echen4096, https://github.com/yoyoyocmu
Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
This commit is contained in:
parent ab45aaca97
commit 0c8ec26d3b
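For context (not part of this commit): a minimal sketch of how the new path is exercised from Python on a ROCm build. It uses the standard TunableOp and TF32 controls (torch.cuda.tunable, torch.backends.cuda.matmul.allow_tf32); the results file mentioned in the last comment is the library default, not something this PR changes.

# Sketch: tune float32 GEMMs with TF32 enabled on a ROCm build of PyTorch.
import torch

# This is the flag at::globalContext().allowTF32CuBLAS() reads on the C++ side.
torch.backends.cuda.matmul.allow_tf32 = True

# Enable TunableOp and online tuning; PYTORCH_TUNABLEOP_ENABLED=1 and
# PYTORCH_TUNABLEOP_TUNING=1 are the equivalent environment switches.
torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(True)

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
c = a @ b  # first occurrence of this GEMM shape triggers tuning

torch.cuda.tunable.write_file()  # persist the selected solutions to the default CSV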
@@ -498,7 +498,11 @@ class HipblasltGemmOp : public Callable<ParamsT> {
         mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
     }
 
-    HipBlasLtMatmulDescriptor matmul(HIPBLAS_COMPUTE_32F, HIP_R_32F);
+    hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
+    if (at::globalContext().allowTF32CuBLAS()) {
+      computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
+    }
+    HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
     matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
     matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
 
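The at::globalContext().allowTF32CuBLAS() check above is driven by PyTorch's existing Python-level TF32 switches; a hedged sketch of the two equivalent toggles (standard API, nothing introduced by this PR):

import torch

# Either flip the matmul TF32 flag directly...
torch.backends.cuda.matmul.allow_tf32 = True

# ...or use the coarser precision knob; "high" permits TF32 for float32 matmuls
# and is reflected by the same global-context flag the hunk above consults.
torch.set_float32_matmul_precision("high")
print(torch.backends.cuda.matmul.allow_tf32)  # True

# "highest" restores strict FP32, so hipblaslt descriptors go back to HIPBLAS_COMPUTE_32F.
torch.set_float32_matmul_precision("highest")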
@@ -611,6 +615,11 @@ auto GetHipBlasLtTypeStringAndOps() {
   auto in_out_datatype = HipDataTypeFor<CT>();
   std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
 
+  hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
+  if (at::globalContext().allowTF32CuBLAS()) {
+    computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
+  }
+
   hipblasLtHandle_t handle;
   TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle));
   TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle,
@@ -621,7 +630,7 @@ auto GetHipBlasLtTypeStringAndOps() {
       b_datatype,
       in_out_datatype,
       in_out_datatype,
-      HIPBLAS_COMPUTE_32F,
+      computeType,
       heuristic_result));
   TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle));
 
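With computeType forwarded to hipblaslt_ext::getAllAlgos, the candidate set that TunableOp benchmarks matches the precision the user asked for. A small sketch of inspecting what gets chosen, assuming torch.cuda.tunable.get_results() reports entries pairing op/params signatures with the winning solution name:

import torch

torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(True)

a = torch.randn(2048, 2048, device="cuda", dtype=torch.float32)
b = torch.randn(2048, 2048, device="cuda", dtype=torch.float32)
(a @ b).sum().item()  # force this GEMM shape to be tuned

# Print the tuning table gathered so far (op signature, params, chosen solution, time).
for entry in torch.cuda.tunable.get_results():
    print(entry)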
@@ -141,6 +141,8 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {
 
   TuningStatus Call(const GemmParams<T>* params) override {
     auto input_output_type = RocBlasDataTypeFor<T>();
+    if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
+      return FAIL;  // no support for TF32 in rocBLAS
     auto compute_type = RocBlasComputeTypeFor<T>();
     auto h_a = DoCastForHalfOrBfloat16(params->alpha);
     auto h_b = DoCastForHalfOrBfloat16(params->beta);
@@ -207,6 +209,8 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
 
   TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
     auto input_output_type = RocBlasDataTypeFor<T>();
+    if (at::globalContext().allowTF32CuBLAS() && input_output_type == rocblas_datatype_f32_r)
+      return FAIL;  // no support for TF32 in rocBLAS
     auto compute_type = RocBlasComputeTypeFor<T>();
     auto h_a = DoCastForHalfOrBfloat16(params->alpha);
     auto h_b = DoCastForHalfOrBfloat16(params->beta);
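Because rocBLAS has no TF32 path, the two Call() overrides above report FAIL whenever TF32 is requested for an fp32 GEMM, so only hipblaslt candidates remain in the tuning pool; fp16/bf16 GEMMs are untouched since their input_output_type is not rocblas_datatype_f32_r. A hedged sketch of the user-visible split:

import torch

torch.cuda.tunable.enable(True)
torch.cuda.tunable.tuning_enable(True)
torch.backends.cuda.matmul.allow_tf32 = True

# fp32 GEMM with TF32 requested: rocBLAS candidates bail out with FAIL,
# so the tuner times hipblaslt solutions only.
a = torch.randn(4096, 8192, device="cuda", dtype=torch.float32)
b = torch.randn(8192, 2048, device="cuda", dtype=torch.float32)
c = a @ b

# fp16 GEMM: the new check does not apply, so rocBLAS and hipblaslt both compete.
ah, bh = a.half(), b.half()
ch = ah @ bh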