Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "torch._scaled_mm with MXFP8 (#147548)"
This reverts commit 12b9674cb6.
Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - similar to previous, see below ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2684134336))
This commit is contained in:
parent 4216478250
commit a84db75e1b
@@ -14,7 +14,6 @@
 #include <c10/macros/Export.h>
 #include <c10/util/env.h>
 #include <c10/util/irange.h>
-#include <c10/core/ScalarType.h>
 
 #ifdef USE_ROCM
 #include <hipblaslt/hipblaslt-ext.hpp>
@@ -1504,12 +1503,10 @@ void scaled_gemm(
 const void* mat1_scale_ptr,
 int64_t mat1_ld,
 ScalarType mat1_dtype,
-ScalarType mat1_scale_dtype,
 const void* mat2_ptr,
 const void* mat2_scale_ptr,
 int64_t mat2_ld,
 ScalarType mat2_dtype,
-ScalarType mat2_scale_dtype,
 const void* bias_ptr,
 ScalarType bias_dtype,
 void* result_ptr,
@@ -1537,8 +1534,10 @@ void scaled_gemm(
 // rowwise isn't supported using cublaslt or older hipblaslt
 TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt");
 #endif
-computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
-computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
+{
+computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
+computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
+}
 if (result_scale_ptr != nullptr) {
 computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
 }
@@ -1561,15 +1560,6 @@ void scaled_gemm(
 computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
 }
 
-if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
-#if CUDA_VERSION >= 12080
-computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
-computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
-#else
-TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
-#endif // CUDA_VERSION >= 12080
-}
-
 auto stream = c10::cuda::getCurrentCUDAStream();
 size_t workspaceSize = 0;
 auto workspace_ptr = _getWorkspace(workspaceSize);
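Note: the reverted block above only enables the UE8M0 block-scale modes when PyTorch is built against CUDA 12.8 or newer. A minimal Python-side sketch of an equivalent runtime guard (hypothetical helper, not part of this PR):

    import torch

    # torch.version.cuda is a string like "12.8" on CUDA builds, or None on CPU-only/ROCm builds
    cuda_ver = torch.version.cuda
    supports_e8m0_scales = (
        cuda_ver is not None
        and tuple(int(v) for v in cuda_ver.split(".")[:2]) >= (12, 8)
    )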
@@ -130,12 +130,10 @@ void scaled_gemm(
 const void* mat1_scale_ptr,
 int64_t mat1_ld,
 ScalarType mat1_dtype,
-ScalarType mat1_scale_dtype,
 const void* mat2_ptr,
 const void* mat2_scale_ptr,
 int64_t mat2_ld,
 ScalarType mat2_dtype,
-ScalarType mat2_scale_dtype,
 const void* bias_ptr,
 ScalarType bias_dtype,
 void* result_ptr,
@@ -10,7 +10,6 @@
 #pragma once
 
 #include <string>
-#include <c10/core/ScalarType.h>
 
 #include <ATen/cuda/tunable/TunableOp.h>
 #include <ATen/cuda/CUDABlas.h>
@@ -425,12 +424,10 @@ struct ScaledGemmParams : OpParams {
 const void* a_scale_ptr{};
 int64_t lda{};
 ScalarType a_dtype{};
-ScalarType a_scale_dtype{};
 const void* b{};
 const void* b_scale_ptr{};
 int64_t ldb{};
 ScalarType b_dtype{};
-ScalarType b_scale_dtype{};
 const void* bias_ptr{};
 ScalarType bias_dtype{};
 void* c{};
@@ -95,12 +95,10 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
 params->a_scale_ptr,
 params->lda,
 params->a_dtype,
-params->a_scale_dtype,
 params->b,
 params->b_scale_ptr,
 params->ldb,
 params->b_dtype,
-params->b_scale_dtype,
 params->bias_ptr,
 params->bias_dtype,
 params->c,
@@ -1,5 +1,4 @@
 #include <cstdint>
-#include <c10/util/typeid.h>
 #include <c10/util/Exception.h>
 #include <c10/core/Scalar.h>
 #include <c10/core/ScalarType.h>
@@ -96,33 +95,11 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
 }
 
 struct cublasCommonArgs {
-cublasCommonArgs(
-const Tensor& mat1,
-const Tensor& mat2,
-Tensor& c,
-const std::optional<Tensor>& scale_a = c10::nullopt,
-const std::optional<Tensor>& scale_b = c10::nullopt,
-const std::optional<Tensor>& scale_result = c10::nullopt) {
+cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) {
 bool transpose_result = false, transpose_mat1 = false, transpose_mat2 = false;
 result = prepare_matrix_for_cublas(c, transpose_result);
 mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
 matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);
-
-// Handle scale tensors if provided
-if (scale_a && scale_b) {
-// By default since we return in row-major we run the gemm
-// as B.T @ A.T, check transpose_result to determine if we flip the scales
-scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
-scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
-scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
-scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
-}
-
-if (scale_result) {
-scale_result_ptr = scale_result->data_ptr();
-scale_result_dtype = scale_result->scalar_type();
-}
-
 auto mat1_sizes = mat1.sizes();
 auto mat2_sizes = mat2.sizes();
 if (transpose_result) {
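Note: the removed scale arguments above exist because the result is returned row-major by running the underlying gemm as B.T @ A.T, so the per-operand scales swap roles when the result is transposed (see the removed comment). A minimal sketch of that identity with plain float tensors, not the fp8 path:

    import torch

    A, B = torch.randn(4, 8), torch.randn(8, 3)
    sa, sb = 2.0, 0.5                      # per-operand scales
    C = (A * sa) @ (B * sb)                # the scaled product we want
    # in the transposed product B becomes the first operand, so it carries B's scale
    C_t = (B.t() * sb) @ (A.t() * sa)
    assert torch.allclose(C, C_t.t())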
@@ -138,23 +115,13 @@ struct cublasCommonArgs {
 lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0);
 ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0);
 result_ld = result->stride(transpose_result ? 0 : 1);
 transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
 transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
 }
-
-// Matrix members
 char transa, transb;
 int64_t m, n, k;
 int64_t lda, ldb, result_ld;
 c10::MaybeOwned<Tensor> mata, matb, result;
-
-// Scale members
-void* scale_mata_ptr = nullptr;
-void* scale_matb_ptr = nullptr;
-void* scale_result_ptr = nullptr;
-std::optional<c10::ScalarType> scale_mata_dtype;
-std::optional<c10::ScalarType> scale_matb_dtype;
-std::optional<c10::ScalarType> scale_result_dtype;
 };
 } // namespace
 
@@ -936,10 +903,9 @@ static bool _scaled_mm_is_fnuz() {
 
 namespace{
 
-enum class ScalingType : std::uint8_t {
+enum class ScalingType {
 TensorWise,
 RowWise,
-BlockWise,
 Error
 };
 /*
@@ -947,13 +913,10 @@ enum class ScalingType : std::uint8_t {
 * ---------------------------
 * Conditions and corresponding Scaling Types:
 *
-* - If scale tensors are Float8_e8m0fnu:
-* - Returns BlockWise (with additional size checks).
-*
 * - If scale_a.numel() == 1 && scale_b.numel() == 1:
 * - Returns TensorWise.
 *
-* - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
+* - Else if scale_a.dim() == 1 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
 * - Returns RowWise.
 *
 * - Otherwise:
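Note: the comment block above describes how the scale tensor shapes select the scaling path. A minimal tensor-wise sketch of the public entry point (assumes an fp8-capable CUDA device, e.g. SM 8.9+; shapes chosen as multiples of 16):

    import torch

    M, K, N = 32, 64, 16
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)        # row-major operand
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()    # column-major operand
    scale_a = torch.tensor(1.0, device="cuda")   # numel() == 1 -> TensorWise
    scale_b = torch.tensor(1.0, device="cuda")
    out = torch._scaled_mm(A, B, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)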
@@ -966,40 +929,7 @@ ScalingType get_scaling_type(
 const at::Tensor& scale_a,
 const at::Tensor& scale_b,
 int64_t dim_m,
-int64_t dim_k,
 int64_t dim_n) {
-// Check for BlockWise scaling (FP8_E8M0 types)
-if (scale_a.scalar_type() == scale_b.scalar_type() &&
-scale_a.scalar_type() == at::kFloat8_e8m0fnu) {
-constexpr int64_t BLOCK_SIZE_K = 32;
-constexpr int64_t BLOCK_SIZE_MN = 128;
-
-auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
-auto num_k_blocks = ceil_div(dim_k, BLOCK_SIZE_K);
-auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;
-
-// TODO: We might want to enforce some structure on the shapes of the scale
-// tensors
-
-// Check expected sizes for block-wise scaling
-auto expected_a_size =
-BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks;
-auto expected_b_size =
-BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;
-
-TORCH_CHECK(scale_a.numel() == expected_a_size,
-"For BlockWise scaling: Expected scale_a size to be ",
-expected_a_size, " but got ", scale_a.numel());
-TORCH_CHECK(scale_b.numel() == expected_b_size,
-"For BlockWise scaling: Expected scale_b size to be ",
-expected_b_size, " but got ", scale_b.numel());
-
-TORCH_CHECK(
-scale_a.is_contiguous() && scale_b.is_contiguous(),
-"For BlockWise scaling: Both scale_a and scale_b must be contiguous");
-
-return ScalingType::BlockWise;
-}
 // Both Per-Tensor and Row-wise scaling expect fp32 tensors
 TORCH_CHECK(
 scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
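Note: a worked instance of the removed expected-size formula (pure arithmetic; same M, K, N as the removed error-message test further down this diff):

    def ceil_div(a, b):
        return (a + b - 1) // b

    M, K, N = 1024, 512, 2048
    BLOCK_SIZE_K, BLOCK_SIZE_MN = 32, 128
    num_k_blocks = ceil_div(K, BLOCK_SIZE_K)               # 16
    padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4    # 16
    expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks  # 16384
    expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks  # 32768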
@@ -1097,7 +1027,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
 
 // Check what type of scaling we are doing based on inputs
-ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1));
+ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
 TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
 
 TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
@@ -1190,7 +1120,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 }
 #endif
 
-cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result);
+cublasCommonArgs args(mat1, mat2, out);
 const auto out_dtype_ = args.result->scalar_type();
 TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
 
@@ -1300,7 +1230,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 }
 else
 #endif
 {
 at::cuda::blas::scaled_gemm(
 args.transa,
 args.transb,
@@ -1308,19 +1238,17 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
 args.n,
 args.k,
 args.mata->data_ptr(),
-args.scale_mata_ptr,
+scale_a.data_ptr(),
 args.lda,
 args.mata->scalar_type(),
-args.scale_mata_dtype.value(),
 args.matb->data_ptr(),
-args.scale_matb_ptr,
+scale_b.data_ptr(),
 args.ldb,
 args.matb->scalar_type(),
-args.scale_matb_dtype.value(),
 bias ? bias->data_ptr(): nullptr,
 bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
 args.result->data_ptr(),
-args.scale_result_ptr,
+scale_result ? scale_result->data_ptr() : nullptr,
 args.result_ld,
 out_dtype_,
 use_fast_accum,
@@ -19,8 +19,7 @@ from torch.testing._internal.common_cuda import (
 SM53OrLater,
 SM89OrLater,
 _get_torch_cuda_version,
-PLATFORM_SUPPORTS_FP8,
-PLATFORM_SUPPORTS_MX_GEMM
+PLATFORM_SUPPORTS_FP8
 )
 from torch.testing._internal.common_device_type import (
 dtypes,
@@ -251,7 +250,6 @@ class TestMatmulCuda(TestCase):
 
 
 f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
-mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
 
 if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
 e4m3_type = torch.float8_e4m3fnuz
@@ -368,79 +366,6 @@ def to_fp8_saturated(
 
 return x.to(fp8_dtype)
 
-# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
-def ceil_div(a, b):
-return (a + b - 1) // b
-
-def to_blocked(input_matrix) -> torch.Tensor:
-"""
-Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
-
-See:
-https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
-
-Args:
-input_matrix: Input tensor of shape (H, W)
-
-Returns:
-Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
-"""
-rows, cols = input_matrix.shape
-n_row_blocks = ceil_div(rows, 128)
-n_col_blocks = ceil_div(cols, 4)
-
-# Calculate the padded shape
-padded_rows = n_row_blocks * 128
-padded_cols = n_col_blocks * 4
-
-padded = input_matrix
-# Ideally we would use torch.nn.pad but it doesn't support float8_e8m0fnu for now
-if (rows, cols) != (padded_rows, padded_cols):
-padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
-padded[:rows, :cols] = input_matrix
-
-# Rearrange the blocks
-blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
-rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
-
-return rearranged.flatten()
-
-def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-"""Computes the error between two tensors in dB.
-
-For more details see:
-https://en.wikipedia.org/wiki/Signal-to-noise_ratio
-
-Args:
-x: The original tensor.
-y: The tensor to compare to the original tensor.
-"""
-Ps = torch.norm(x)
-Pn = torch.norm(x - y)
-return 20 * torch.log10(Ps / Pn)
-
-# largest power of 2 representable in `torch.float8_e4m3fn`
-F8E4M3_LARGEST_POW2 = 8
-# max value of `torch.float8_e4m3fn` (448)
-F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
-# exponent bias of `torch.float8_e8m0fnu`
-F8E8M0_EXP_BIAS = 127
-
-def data_to_mx_scale(x, block_size):
-# simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-# section 6.3, not all edge cases (such as NaN) are handled/tested
-orig_shape = x.shape
-x = x.reshape(-1, block_size)
-max_abs = torch.amax(torch.abs(x), 1)
-largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
-scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
-scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
-scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
-scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
-scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
-return scale_e8m0_biased.reshape(orig_shape[0], -1)
-
-
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
 class TestFP8MatmulCuda(TestCase):
 
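Note: the removed compute_error helper above scores results as a signal-to-noise ratio in dB, and the removed numerics test accepts anything above 22 dB. A minimal sketch of the same metric on an fp8 round-trip (assumes CPU casts to float8_e4m3fn are available in this torch build):

    import torch

    x = torch.randn(256, 256)
    y = x.to(torch.float8_e4m3fn).to(torch.float32)   # lossy fp8 round-trip as the noisy tensor
    sqnr_db = 20 * torch.log10(torch.norm(x) / torch.norm(x - y))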
@@ -843,287 +768,6 @@ class TestFP8MatmulCuda(TestCase):
 self.assertEqual(out_dtype, out_fp8.dtype)
 self.assertEqual(out_fp32, out_fp8.to(torch.float))
 
-@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
-@parametrize("test_case_name", [
-"a_eye_b_eye",
-"a_ones_b_ones",
-"a_ones_modified_b_ones",
-"a_ones_b_ones_modified",
-"a_scale_modified_b_ones",
-"a_ones_b_scale_modified",
-"data_random_scales_one",
-"data_random_scales_from_data",
-])
-@parametrize("fast_accum", [False, True])
-@parametrize("mkn", [
-# Nice shapes
-(128, 128, 128),
-(256, 256, 256),
-(128, 256, 512),
-(256, 512, 128),
-(512, 128, 256),
-
-# Non block multiples
-(65, 96, 112),
-(197, 224, 272),
-# K not multiple of 32
-(197, 240, 272),
-
-# Very unbalanced
-(1023, 64, 48),
-(31, 1024, 64),
-(45, 96, 1024),
-
-# Mixed large and small
-(2, 1024, 128),
-(127, 96, 1024),
-(1025, 128, 96)
-], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
-def test_blockwise_mxfp8_numerics(self, test_case_name, fast_accum, mkn) -> None:
-# inspiration: https://github.com/pytorch/ao/pull/1625
-
-device = "cuda"
-M, K, N = mkn
-BLOCK_SIZE = 32
-require_exact_match = True
-
-def ceil_div(a, b):
-return (a + b - 1) // b
-
-if test_case_name == "a_eye_b_eye":
-if not ((M == K) and (M == N)):
-return unittest.skip("this test is only defined for M == K == N, skipping")
-A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
-B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
-
-A = A_ref.to(torch.float8_e4m3fn)
-B = B_ref.to(torch.float8_e4m3fn)
-
-A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-# convert to swizzled format
-A_scale = to_blocked(A_scale)
-B_scale = to_blocked(B_scale)
-
-elif test_case_name == "a_ones_b_ones":
-A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
-B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
-
-A = A_ref.to(torch.float8_e4m3fn)
-B = B_ref.to(torch.float8_e4m3fn)
-
-A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-# convert to swizzled format
-A_scale = to_blocked(A_scale)
-B_scale = to_blocked(B_scale)
-
-elif test_case_name == "a_ones_modified_b_ones":
-A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
-B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
-
-A = A_ref.to(torch.float8_e4m3fn)
-B = B_ref.to(torch.float8_e4m3fn)
-
-A_ref[1][0:BLOCK_SIZE] = 2
-A[1][0:BLOCK_SIZE] = 2
-
-A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-# convert to swizzled format
-A_scale = to_blocked(A_scale)
-B_scale = to_blocked(B_scale)
-
-elif test_case_name == "a_ones_b_ones_modified":
-A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
-B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
-
-A = A_ref.to(torch.float8_e4m3fn)
-B = B_ref.to(torch.float8_e4m3fn)
-
-B_ref[1][0:BLOCK_SIZE] = 2
-B[1][0:BLOCK_SIZE] = 2
-
-A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-# convert to swizzled format
-A_scale = to_blocked(A_scale)
-B_scale = to_blocked(B_scale)
-
-elif test_case_name == "a_scale_modified_b_ones":
-A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
-B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
-
-A = A_ref.to(torch.float8_e4m3fn)
-B = B_ref.to(torch.float8_e4m3fn)
-
-A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-
-A_ref[1][0:BLOCK_SIZE] = 4
-A[1][0:BLOCK_SIZE] = 2
-A_scale[1][0] = 2
-
-# convert to swizzled format
-A_scale = to_blocked(A_scale)
-B_scale = to_blocked(B_scale)
-
-elif test_case_name == "a_ones_b_scale_modified":
-A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
-B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
-
-A = A_ref.to(torch.float8_e4m3fn)
-B = B_ref.to(torch.float8_e4m3fn)
-
-A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-
-B_ref[1][0:BLOCK_SIZE] = 4
-B[1][0:BLOCK_SIZE] = 2
-B_scale[1][0] = 2
-
-# convert to swizzled format
-A_scale = to_blocked(A_scale)
-B_scale = to_blocked(B_scale)
-
-elif test_case_name == "data_random_scales_one":
-require_exact_match = False
-# scales all-ones, element data random while being exactly representable in float8_e4m3fn
-
-# generate integers in [0, 255] and interpret as float8_e4m3fn
-A_ref = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
-B_ref = torch.randint(0, 255, (N, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
-# modification: don't allow NaN values
-A_ref[torch.isnan(A_ref)] = 0
-B_ref[torch.isnan(B_ref)] = 0
-
-A = A_ref.to(torch.float8_e4m3fn)
-B = B_ref.to(torch.float8_e4m3fn)
-
-A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
-
-# convert to swizzled format
-A_scale = to_blocked(A_scale)
-B_scale = to_blocked(B_scale)
-
-elif test_case_name == "data_random_scales_from_data":
-if not K % BLOCK_SIZE == 0:
-return unittest.skip(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
-require_exact_match = False
-# random data, scales from data
-A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
-B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000
-
-# Calculate scales based on the inputs
-A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
-B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
-
-max_val = F8E4M3_MAX_VAL
-min_val = -1 * max_val
-
-A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
-A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
-B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
-B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
-
-# convert to swizzled format
-A_scale = to_blocked(A_scale)
-B_scale = to_blocked(B_scale)
-
-C_ref = A_ref @ B_ref.t()
-
-C = torch._scaled_mm(
-A,
-B.t(),
-A_scale,
-B_scale,
-out_dtype=torch.bfloat16,
-use_fast_accum=fast_accum,
-)
-
-if require_exact_match:
-torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
-else:
-sqnr = compute_error(C_ref, C)
-assert sqnr.item() > 22.0
-
-@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-@skipIfRocm()
-def test_blockwise_mxfloat8_error_messages(self, device) -> None:
-M, K, N = (1024, 512, 2048)
-BLOCK_SIZE_K = 32
-BLOCK_SIZE_MN = 128
-fill_value = 0.5
-
-x = torch.full((M, K), fill_value, device=device)
-y = torch.full((N, K), fill_value, device=device)
-
-x_fp8 = x.to(e4m3_type)
-y_fp8 = y.to(e4m3_type).t()
-
-def ceil_div(a, b):
-return (a + b - 1) // b
-
-num_k_blocks = ceil_div(K, BLOCK_SIZE_K)
-padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
-expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks
-expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks
-
-
-# Test wrong scale tensor size for scale_a with correct dtype
-with self.assertRaisesRegex(
-RuntimeError,
-re.escape(
-f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} "
-f"but got {expected_a_size - 1}"
-),
-):
-incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=torch.float8_e8m0fnu)
-correct_size_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
-torch._scaled_mm(
-x_fp8,
-y_fp8,
-scale_a=incorrect_size_a,
-scale_b=correct_size_b,
-out_dtype=torch.bfloat16,
-)
-
-# Test wrong scale tensor size for scale_b with correct dtype
-with self.assertRaisesRegex(
-RuntimeError,
-re.escape(
-f"For BlockWise scaling: Expected scale_b size to be {expected_b_size} "
-f"but got {expected_b_size + 1}"
-),
-):
-correct_size_a = torch.ones(expected_a_size, device=device, dtype=torch.float8_e8m0fnu)
-incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=torch.float8_e8m0fnu)
-torch._scaled_mm(
-x_fp8,
-y_fp8,
-scale_a=correct_size_a,
-scale_b=incorrect_size_b,
-out_dtype=torch.bfloat16,
-)
-
-# Test non-contiguous scale tensors with correct dtype
-with self.assertRaisesRegex(
-RuntimeError,
-re.escape(
-"For BlockWise scaling: Both scale_a and scale_b must be contiguous"
-),
-):
-non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=torch.float8_e8m0fnu)[::2]
-contiguous_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
-torch._scaled_mm(
-x_fp8,
-y_fp8,
-scale_a=non_contiguous_a,
-scale_b=contiguous_b,
-out_dtype=torch.bfloat16,
-)
-
 
 @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
 @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
||||||
|
|
@ -32,7 +32,6 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
|
||||||
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
|
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
|
||||||
SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
|
SM89OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9))
|
||||||
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
|
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
|
||||||
SM100OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0))
|
|
||||||
|
|
||||||
IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
|
IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10
|
||||||
and torch.cuda.get_device_capability()[1] > 0)
|
and torch.cuda.get_device_capability()[1] > 0)
|
||||||
|
|
@@ -102,7 +101,6 @@ def evaluate_platform_supports_fp8():
 
 PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8())
 
-PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: TEST_CUDA and SM100OrLater)
 
 if TEST_NUMBA:
 try:
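Note: PLATFORM_SUPPORTS_MX_GEMM (removed above) simply gated MX gemm tests on SM 10.0+ devices. An equivalent standalone check, sketched here without the LazyVal wrapper:

    import torch

    has_cuda = torch.cuda.is_available()
    supports_mx_gemm = has_cuda and torch.cuda.get_device_capability() >= (10, 0)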