[ROCm] [TunableOp] Enable logging of BLAS parameters (#147034)

This PR supports a logging feature that is being requested.
```
PYTORCH_TUNABLEOP_BLAS_LOG=1
```
Enables the logging of BLAS parameters with either offline of online (in-situ) tuning.

The BLAS parameters are written to the CSV file.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147034
Approved by: https://github.com/jeffdaily
This commit is contained in:
Nichols A. Romero 2025-03-07 09:32:56 +00:00 committed by PyTorch MergeBot
parent 243b47e2ec
commit bb84a23c22
5 changed files with 256 additions and 5 deletions

View File

@ -45,6 +45,201 @@ inline char BlasOpToString(BlasOp op) {
return 'N';
}
template <typename T>
inline const char* BLASTypeName(T v) {
return "unknown";
}
template <>
inline const char* BLASTypeName(float v) {
return "f32_r";
}
template <>
inline const char* BLASTypeName(double v) {
return "f64_r";
}
template <>
inline const char* BLASTypeName(BFloat16 v) {
return "bf16_r";
}
template <>
inline const char* BLASTypeName(Half v) {
return "f16_r";
}
//https://github.com/ROCm/hipBLASLt/blob/develop/library/src/include/auxiliary.hpp#L175
template <>
inline const char* BLASTypeName(Float8_e4m3fn v) {
return "f8_r";
}
template <>
inline const char* BLASTypeName(Float8_e5m2 v) {
return "bf8_r";
}
template <>
inline const char* BLASTypeName(Float8_e4m3fnuz v) {
return "f8_fnuz_r";
}
template <>
inline const char* BLASTypeName(Float8_e5m2fnuz v) {
return "bf8_fnuz_r";
}
template <>
inline const char* BLASTypeName(c10::complex<double> v) {
return "f64_r";
}
template <>
inline const char* BLASTypeName(c10::complex<float> v) {
return "f32_r";
}
inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
std::string BLASType;
switch (scalar_type) {
case c10::ScalarType::Float:{
BLASType = "f32_r";
break;
}
case c10::ScalarType::Double:{
BLASType = "f64_r";
break;
}
case c10::ScalarType::BFloat16:{
BLASType = "bf16_r";
break;
}
case c10::ScalarType::Half: {
BLASType = "f16_r";
break;
}
case c10::ScalarType::Float8_e4m3fn: {
BLASType = "f8_r";
break;
}
case c10::ScalarType::Float8_e5m2: {
BLASType = "bf8_r";
break;
}
case c10::ScalarType::Float8_e4m3fnuz: {
BLASType = "f8_fnuz_r";
break;
}
case c10::ScalarType::Float8_e5m2fnuz: {
BLASType = "bf8_fnuz_r";
break;
}
case c10::ScalarType::ComplexFloat:{
BLASType = "f32_c";
break;
}
case c10::ScalarType::ComplexDouble:{
BLASType = "f64_c";
break;
}
default:
BLASType = "unknown";
}
return BLASType;
}
// Similar to Compute Type in GemmRocblas.h
template <typename T>
inline std::string ComputeTypeFor() {
return "Unknown ComputeType";
}
// This is a union of the compute types for
// ROCBLAS and hipBLASLt.
template <>
inline std::string ComputeTypeFor<float>() {
if (!at::globalContext().allowTF32CuBLAS()) {
return "f32_r";
} else {
return "xf32_r";
}
}
template <>
inline std::string ComputeTypeFor<double>() {
return "f64_r";
}
template <>
inline std::string ComputeTypeFor<Half>() {
return "f32_r";
}
template <>
inline std::string ComputeTypeFor<BFloat16>() {
return "f32_r";
}
template <>
inline std::string ComputeTypeFor<c10::complex<float>>() {
return "f32_c";
}
template <>
inline std::string ComputeTypeFor<c10::complex<double>>() {
return "f64_c";
}
template <>
inline std::string ComputeTypeFor<Float8_e4m3fn>() {
return "f32_r";
}
template <>
inline std::string ComputeTypeFor<Float8_e5m2>() {
return "f32_r";
}
template <>
inline std::string ComputeTypeFor<Float8_e4m3fnuz>() {
return "f32_r";
}
template <>
inline std::string ComputeTypeFor<Float8_e5m2fnuz>() {
return "f32_r";
}
// Convert opmath_type<T> to string
template <typename T>
inline std::string to_string_opmath(const at::opmath_type<T>& value) {
if constexpr (std::is_same_v<at::opmath_type<T>, c10::complex<float>> ||
std::is_same_v<at::opmath_type<T>, c10::complex<double>>) {
return fmt::format("({:.4f}, {:.4f})", value.real(), value.imag());
} else {
return fmt::format("{:.4f}", value);
}
}
// convert activation epilogue to string
inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivationEpilogue& value) {
switch (value) {
case at::cuda::blas::GEMMAndBiasActivationEpilogue::None:
return std::string("None");
break;
case at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU:
return std::string("RELU");
break;
case cuda::blas::GEMMAndBiasActivationEpilogue::GELU:
return std::string("GELU");
break;
default:
return std::string("unknown");
}
}
namespace detail {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
@ -87,6 +282,15 @@ template <typename T>
struct GemmParams : OpParams {
GemmParams() = default;
std::string BLASSignature() const override {
std::string alpha_str = to_string_opmath<T>(alpha);
std::string beta_str = to_string_opmath<T>(beta);
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
"alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, bias_type: %s, compute_type: %s }",
m, n, k, lda, ldb, ldc, ldc, alpha_str, beta_str, transa, transb,
BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>(), ComputeTypeFor<T>());
}
std::string Signature() const override {
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc);
}
@ -172,6 +376,15 @@ private:
template <typename T>
struct GemmAndBiasParams : OpParams {
std::string BLASSignature() const override {
std::string alpha_str = to_string_opmath<T>(alpha);
std::string activation_str = to_string_epilogue(activation);
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
"alpha: %s, transA: %c, transB: %c, batch_count: 1, a_type: %s, b_type: %s, c_type: %s, d_type: %s, activation: %s, bias_type: %s, scale_type: %s, compute_type: %s }",
m, n, k, lda, ldb, ldc, ldc, alpha_str, transa, transb,
BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), activation_str, BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>(), ComputeTypeFor<T>());
}
std::string Signature() const override {
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, lda, ldb, ldc);
}
@ -258,6 +471,15 @@ private:
template <typename T>
struct GemmStridedBatchedParams : OpParams {
std::string BLASSignature() const override {
std::string alpha_str = to_string_opmath<T>(alpha);
std::string beta_str = to_string_opmath<T>(beta);
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: %ld, stride_b: %ld, stride_c: %ld, stride_d: %ld, "
"alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: %ld, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }",
m, n, k, lda, ldb, ldc, ldc, stride_a, stride_b, stride_c, stride_c, alpha_str, beta_str, transa, transb, batch,
BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), ComputeTypeFor<T>(), ComputeTypeFor<T>());
}
std::string Signature() const override {
return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, batch, lda, ldb, ldc);
}
@ -351,6 +573,15 @@ template <typename T>
struct ScaledGemmParams : OpParams {
ScaledGemmParams() = default;
std::string BLASSignature() const override {
// Excluding use_fast_accum and use_rowise booleans for now
return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: 0, stride_b: 0, stride_c: 0, stride_d: 0, "
"transA: %c, transB: %c, batch_count: 1, scaleA: f32_r, scaleB: f32_r, a_type: %s, b_type: %s, c_type: %s, d_type: %s, bias_type: %s, scale_type: %s, compute_type: %s }",
m, n, k, lda, ldb, ldc, ldc, transa, transb,
ScalarTypeToBLASType(a_dtype), ScalarTypeToBLASType(b_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(c_dtype), ScalarTypeToBLASType(bias_dtype),
ComputeTypeFor<T>(), ComputeTypeFor<T>());
}
std::string Signature() const override {
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d", transa, transb, m, n, k, lda, ldb, ldc, use_rowwise);
}

View File

@ -154,6 +154,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS | Default is 0, meaning it is not used. |
| PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED | Default is 1. Set to 0 to disable. |
| PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE | Default (or < 0) is to query L2 cache size. Set to 0 to disable. Otherwise, set to the number of MiB to use for the pool of operator parameters. For example, setting this to the size of your device's memory cache will guarantee that every tuning iteration will use a cold cache. |
| PYTORCH_TUNABLEOP_BLAS_LOG | Default is 0. Set to 1 to enable. Write BLAS paramters to tuning CSV file. |
### Python Interface
All python APIs exist in the `torch.cuda.tunable` module.

View File

@ -46,7 +46,13 @@ TuningContext* getTuningContext() {
}
std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) {
return stream << entry.key_ << "," << entry.time_;
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
if (!blaslog) {
return stream << entry.key_ << "," << entry.time_;
}
else {
return stream << entry.key_ << "," << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
}
}
// TuningResultsManager
@ -107,7 +113,8 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin
AddImpl(op_signature, params_signature, std::move(best), it->second);
}
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature, const std::string& params_signature) {
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
const std::string& params_signature, const std::string& blas_signature) {
std::scoped_lock l{lock_};
if (!untuned_file.good()) {
TORCH_WARN_ONCE("failed to open file for writing; untuned gemm will not be saved");
@ -127,7 +134,13 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
}
if (isNew) {
untuned_file << op_signature << "," << params_signature << std::endl;
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
if (!blaslog) {
untuned_file << op_signature << "," << params_signature << std::endl;
}
else {
untuned_file << op_signature << "," << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
}
TUNABLE_LOG3("Untuned,", op_signature, ",", params_signature);
}
}

View File

@ -40,6 +40,7 @@ enum TORCH_CUDA_CPP_API TuningStatus {
class TORCH_CUDA_CPP_API ResultEntry {
public:
explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {}
explicit ResultEntry(std::string key, double time, const std::string& blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(blas_sig) {}
bool operator==(const ResultEntry& other) { return key_ == other.key_; }
bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
operator std::string () { return key_; }
@ -52,6 +53,7 @@ class TORCH_CUDA_CPP_API ResultEntry {
private:
std::string key_;
double time_;
std::string blas_sig_;
};
typedef std::unordered_map<std::string, ResultEntry> KernelMap;
@ -99,7 +101,8 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
size_t GetSize();
void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature, const std::string& params_signature);
void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
const std::string& params_signature, const std::string& blas_signature);
private:
std::mutex lock_;
ResultsMap results_;

View File

@ -118,6 +118,7 @@ class TunableOp {
auto& mgr = ctx->GetTuningResultsManager();
auto op_sig = Signature();
auto params_sig = params->Signature();
auto blas_sig = params->BLASSignature();
result = mgr.Lookup(op_sig, params_sig);
// If there is not previous tuning result been found, we do the tuning iff tuning is enabled
if (result == ResultEntry::Null()) {
@ -127,7 +128,7 @@ class TunableOp {
}
else if (ctx->IsRecordUntunedEnabled()) {
// or record the gemm into file
mgr.RecordUntuned(ctx->GetUntunedFile(), op_sig, params_sig);
mgr.RecordUntuned(ctx->GetUntunedFile(), op_sig, params_sig, blas_sig);
}
}
}
@ -224,6 +225,7 @@ class TunableOp {
TuningContext* ctx = getTuningContext();
auto op_sig = Signature();
auto params_sig = params->Signature();
auto blas_sig = params->BLASSignature();
TUNABLE_LOG2("finding fastest for ", op_sig, '(', params_sig, ')', " out of ", op_names_.size(), " candidates");
auto min_duration_ms = std::numeric_limits<double>::infinity();
std::string id_name = "Default";
@ -422,6 +424,7 @@ class TunableOp {
struct OpParams {
virtual ~OpParams() = default;
virtual std::string Signature() const = 0;
virtual std::string BLASSignature() const = 0;
};
} // namespace at::cuda::tunable