mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This reverts commit cc444e75d5.
Reverted https://github.com/pytorch/pytorch/pull/147878 on behalf of https://github.com/wdvr due to temporary reverting to revert an older one in the stack ([comment](https://github.com/pytorch/pytorch/pull/147878#issuecomment-2683515567))
This commit is contained in:
parent
2df9a8d72d
commit
05bc8fe62e
|
|
@ -1537,11 +1537,8 @@ void scaled_gemm(
|
||||||
// rowwise isn't supported using cublaslt or older hipblaslt
|
// rowwise isn't supported using cublaslt or older hipblaslt
|
||||||
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
|
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
|
||||||
#endif
|
#endif
|
||||||
// do not remove {}, this scope corresponds to the else condition in the USE_ROCM section above
|
|
||||||
{
|
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
|
||||||
}
|
|
||||||
if (result_scale_ptr != nullptr) {
|
if (result_scale_ptr != nullptr) {
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user