Speed up FP precision lookup (#164044)

This commit simplifies the precision lookup and setting logic
by reducing the number of branches and using a custom hash
function. Fixes #161822. The issue described in #163709 still
persists. This is meant as a short term fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164044
Approved by: https://github.com/ngimel, https://github.com/eqy
This commit is contained in:
Lakshay Garg 2025-10-03 21:35:17 +00:00 committed by PyTorch MergeBot
parent 8d53d788fe
commit f006aee601
16 changed files with 170 additions and 132 deletions

View File

@ -40,41 +40,6 @@ namespace {
->conv
->rnn
*/
const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
{"generic", {{"ieee", "tf32", "bf16", "none"}}},
{"mkldnn", {{"ieee", "tf32", "bf16", "none"}}},
{"cuda", {{"ieee", "tf32", "none"}}}};
// Check whether the backend and op are legal
void check_fp32_prec_backend_and_op(
const std::string& backend,
const std::string& op) {
static std::vector<std::string> backends = {"generic", "mkldnn", "cuda"};
static std::vector<std::string> operators = {"conv", "matmul", "rnn", "all"};
TORCH_CHECK(
std::find(backends.begin(), backends.end(), backend) != backends.end(),
"Invalid backend: ",
backend);
TORCH_CHECK(
std::find(operators.begin(), operators.end(), op) != operators.end(),
"Invalid operator: ",
op);
if (backend == "generic") {
TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op);
}
}
// Return whether the precision is supported by backends
bool validate_fp32_prec(
const std::string& backend,
const std::string& precision) {
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
auto precisions = iterp->second;
bool valid = std::find(precisions.begin(), precisions.end(), precision) !=
precisions.end();
return valid;
}
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
@ -86,6 +51,54 @@ void check_fp32_prec_backend_and_op(
}
} // namespace
Float32Backend str2backend(const std::string& name) {
if (name == "generic")
return Float32Backend::GENERIC;
else if (name == "cuda")
return Float32Backend::CUDA;
else if (name == "mkldnn")
return Float32Backend::MKLDNN;
TORCH_CHECK(false, "Unknown backend: ", name);
}
Float32Op str2op(const std::string& name) {
if (name == "all")
return Float32Op::ALL;
else if (name == "conv")
return Float32Op::CONV;
else if (name == "rnn")
return Float32Op::RNN;
else if (name == "matmul")
return Float32Op::MATMUL;
TORCH_CHECK(false, "Unknown op: ", name);
}
Float32Precision str2precision(const std::string& name) {
if (name == "none")
return Float32Precision::NONE;
else if (name == "ieee")
return Float32Precision::IEEE;
else if (name == "tf32")
return Float32Precision::TF32;
else if (name == "bf16")
return Float32Precision::BF16;
TORCH_CHECK(false, "Unknown precision: ", name);
}
std::string precision2str(Float32Precision prec) {
switch (prec) {
case Float32Precision::NONE:
return "none";
case Float32Precision::IEEE:
return "ieee";
case Float32Precision::TF32:
return "tf32";
case Float32Precision::BF16:
return "bf16";
}
TORCH_CHECK(false, "Invalid enum Float32Precision(", static_cast<int>(prec), ")");
}
Context::Context() = default;
// TODO: This could be bad juju if someone calls globalContext() in the
@ -179,10 +192,10 @@ void Context::setUserEnabledNNPACK(bool e) {
enabled_nnpack = e;
}
bool Context::allowTF32CuDNN(const std::string& op) const {
if (op.empty()){
bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
if (!op.has_value()) {
bool allow_tf32_rnn = float32Precision(Float32Backend::CUDA, Float32Op::RNN) == Float32Precision::TF32;
bool allow_tf32_conv = float32Precision(Float32Backend::CUDA, Float32Op::CONV) == Float32Precision::TF32;
TORCH_CHECK(
allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn,
"PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,",
@ -191,15 +204,15 @@ bool Context::allowTF32CuDNN(const std::string& op) const {
"We suggest only using the new API to set the TF32 flag(s). See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
} else {
return float32Precision("cuda", op) == "tf32";
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
}
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn;
}
void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision("cuda", "rnn", b ? "tf32" : "none");
setFloat32Precision("cuda", "conv", b ? "tf32" : "none");
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
@ -305,7 +318,7 @@ void Context::setImmediateMiopen(bool b) {
bool Context::allowTF32CuBLAS() const {
bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
bool allow_tf32_new = float32Precision(Float32Backend::CUDA, Float32Op::MATMUL) == Float32Precision::TF32;
TORCH_CHECK(
legacy_allow_tf32 == allow_tf32_new,
"PyTorch is checking whether allow_tf32_new is enabled for cuBlas matmul,",
@ -318,17 +331,17 @@ bool Context::allowTF32CuBLAS() const {
void Context::setAllowTF32CuBLAS(bool b) {
float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, b ? Float32Precision::TF32 : Float32Precision::IEEE);
}
Float32MatmulPrecision Context::float32MatmulPrecision() const {
bool invalid = float32Precision("cuda", "matmul") == "tf32" &&
bool invalid = float32Precision(Float32Backend::CUDA, Float32Op::MATMUL) == Float32Precision::TF32 &&
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST;
invalid = invalid ||
(float32Precision("mkldnn", "matmul") == "bf16" &&
(float32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL) == Float32Precision::BF16 &&
float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
invalid = invalid ||
(float32Precision("mkldnn", "matmul") == "tf32" &&
(float32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL) == Float32Precision::TF32 &&
float32_matmul_precision != at::Float32MatmulPrecision::HIGH);
TORCH_CHECK(
!invalid,
@ -340,15 +353,26 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
return float32_matmul_precision;
}
std::string Context::float32Precision(const std::string& backend, const std::string& op) const {
check_fp32_prec_backend_and_op(backend, op);
auto precision = fp32_precision.find(backend)->second.find(op)->second;
if (precision == "none")
precision = fp32_precision.find(backend)->second.find("all")->second;
if (precision == "none")
precision = fp32_precision.find("generic")->second.find("all")->second;
bool valid_prec = validate_fp32_prec(backend, precision);
return valid_prec ? precision : "none";
Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op) const {
std::pair<Float32Backend, Float32Op> key{backend, op};
auto it = fp32_precision.find(key);
TORCH_CHECK(it != fp32_precision.end(), "Invalid (backend, op) pair: (", backend, ", ", op, ")");
Float32Precision precision = it->second;
if (precision == Float32Precision::NONE) {
key.second = Float32Op::ALL;
precision = fp32_precision.find(key)->second;
}
if (precision == Float32Precision::NONE) {
key.first = Float32Backend::GENERIC;
precision = fp32_precision.find(key)->second;
}
// "cuda" does not support "bf16"
if (backend == Float32Backend::CUDA && precision == Float32Precision::BF16) {
return Float32Precision::NONE;
}
return precision;
}
void Context::setFloat32MatmulPrecision(const std::string &s) {
@ -357,18 +381,18 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", "ieee");
setFloat32Precision("mkldnn", "matmul", "ieee");
setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::IEEE);
setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::IEEE);
return true;
} else if (s_ == "high") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
setFloat32Precision("cuda", "matmul", "tf32");
setFloat32Precision("mkldnn", "matmul", "tf32");
setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::TF32);
setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::TF32);
return true;
} else if (s_ == "medium") {
float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
setFloat32Precision("cuda", "matmul", "tf32");
setFloat32Precision("mkldnn", "matmul", "bf16");
setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::TF32);
setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::BF16);
return true;
}
return false;
@ -382,25 +406,16 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
"setFloat32MatmulPrecision call has no effect.");
}
void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) {
check_fp32_prec_backend_and_op(backend, op);
if (validate_fp32_prec(backend, p)) {
fp32_precision[backend][op] = p;
} else {
std::string msg;
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
for (const auto& p : iterp->second) {
msg += p;
msg += " ";
}
TORCH_WARN(
"you have set wrong precision for backend:",
backend,
" setFloat32Precision call has no effect.",
"Please choose precision from: ",
msg);
}
void Context::setFloat32Precision(Float32Backend backend, Float32Op op, Float32Precision p) {
auto it = fp32_precision.find(std::make_pair(backend, op));
TORCH_CHECK(
it != fp32_precision.end(),
"Invalid (backend, op) pair: (", backend, ", ", op, ")");
TORCH_CHECK(
!(backend == Float32Backend::CUDA && p == Float32Precision::BF16),
"backend 'cuda' does not support precision 'bf16'");
it->second = p;
}
at::LinalgBackend Context::linalgPreferredBackend() const {

View File

@ -25,17 +25,27 @@
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <cstdint>
#include <map>
#include <mutex>
#include <unordered_map>
namespace at {
class Tensor;
enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
TORCH_API Float32Backend str2backend(const std::string& name);
TORCH_API Float32Op str2op(const std::string& name);
TORCH_API Float32Precision str2precision(const std::string& name);
TORCH_API std::string precision2str(Float32Precision prec);
class TORCH_API Context {
public:
@ -336,19 +346,17 @@ class TORCH_API Context {
void setFloat32MatmulPrecision(const std::string& s);
void setFloat32Precision(
const std::string& backend,
const std::string& op,
const std::string& s);
bool allowTF32CuDNN(const std::string& op = std::string()) const;
Float32Backend backend,
Float32Op op,
Float32Precision p);
bool allowTF32CuDNN(std::optional<Float32Op> op = std::nullopt) const;
void setAllowTF32CuDNN(bool);
bool allowTF32OneDNN() const;
void setAllowTF32OneDNN(bool);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
Float32MatmulPrecision float32MatmulPrecision() const;
std::string float32Precision(
const std::string& backend,
const std::string& op) const;
Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
bool allowFP16ReductionCuBLAS() const;
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
@ -475,21 +483,20 @@ class TORCH_API Context {
bool enable_sparse_tensor_invariant_checks = false;
bool allow_fp16_reduction_cpu = false;
std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
{"generic", {{"all", "none"}}},
{"mkldnn",
{{"matmul", "none"},
{"conv", "none"},
{"rnn", "none"},
{"all", "none"}}},
{"cuda",
{{"matmul",
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
? "none"
: "tf32"},
{"conv", "tf32"},
{"rnn", "tf32"},
{"all", "none"}}},
using Key = std::pair<Float32Backend, Float32Op>;
std::unordered_map<Key, Float32Precision, c10::hash<Key>> fp32_precision = {
{{Float32Backend::GENERIC, Float32Op::ALL}, Float32Precision::NONE},
{{Float32Backend::MKLDNN, Float32Op::ALL}, Float32Precision::NONE},
{{Float32Backend::MKLDNN, Float32Op::CONV}, Float32Precision::NONE},
{{Float32Backend::MKLDNN, Float32Op::RNN}, Float32Precision::NONE},
{{Float32Backend::MKLDNN, Float32Op::MATMUL}, Float32Precision::NONE},
{{Float32Backend::CUDA, Float32Op::ALL}, Float32Precision::NONE},
{{Float32Backend::CUDA, Float32Op::CONV}, Float32Precision::TF32},
{{Float32Backend::CUDA, Float32Op::RNN}, Float32Precision::TF32},
{{Float32Backend::CUDA, Float32Op::MATMUL},
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
? Float32Precision::NONE
: Float32Precision::TF32},
};
Allocator* prev_allocator_ptr_{nullptr};
@ -671,5 +678,4 @@ struct TORCH_API ROCmBackwardPassGuard {
~ROCmBackwardPassGuard();
static bool is_backward_pass();
};
} // namespace at

View File

@ -395,7 +395,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
} else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
@ -1559,7 +1559,7 @@ bool gemm_and_bias(
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
} else if constexpr (std::is_same_v<Dtype, at::Half>) {

View File

@ -310,7 +310,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// FP32 data type calculations based on the value of the allow_tf32 flag.
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
if (!NoTF32Guard::should_disable_tf32() &&
at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

View File

@ -162,7 +162,7 @@ inline std::string ComputeTypeFor() {
// ROCBLAS and hipBLASLt.
template <>
inline std::string ComputeTypeFor<float>() {
if (at::globalContext().float32Precision("cuda", "matmul") != "tf32") {
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) != at::Float32Precision::TF32) {
return "f32_r";
} else {
return "xf32_r";

View File

@ -506,7 +506,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
}
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
}
HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);

View File

@ -141,7 +141,7 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {
TuningStatus Call(const GemmParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
return FAIL; // no support for TF32 in rocBLAS
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
@ -209,7 +209,7 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
auto input_output_type = RocBlasDataTypeFor<T>();
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
return FAIL; // no support for TF32 in rocBLAS
auto compute_type = RocBlasComputeTypeFor<T>();
auto h_a = DoCastForHalfOrBfloat16(params->alpha);

View File

@ -1174,7 +1174,7 @@ at::Tensor convolution(
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
return at::_convolution(input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups,
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv"));
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN(at::Float32Op::CONV));
}
at::Tensor convolution_overrideable(
@ -1319,7 +1319,7 @@ ConvBackend select_conv_backend(
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN("conv");
params.allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
auto input = input_r;
auto weight = weight_r;
@ -1699,7 +1699,7 @@ at::Tensor _convolution(
c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv"));
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN(at::Float32Op::CONV));
}
std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
@ -1997,7 +1997,7 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward(
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN("conv");
params.allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
// Validate inputs.
check_shape_backward(input, weight.sizes(), params);

View File

@ -169,7 +169,10 @@ std::string repro_from_args(const ConvolutionParams& params) {
ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
ss << "import torch\n";
ss << "torch.backends.cuda.matmul.allow_tf32 = "
<< pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32")
<< pybool(
at::globalContext().float32Precision(
at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
at::Float32Precision::TF32)
<< "\n";
ss << "torch.backends.cudnn.benchmark = "
<< pybool(at::globalContext().benchmarkCuDNN()) << "\n";
@ -726,7 +729,7 @@ Tensor cudnn_convolution_relu(
auto& ctx = at::globalContext();
bool benchmark = ctx.benchmarkCuDNN();
bool allow_tf32 = ctx.allowTF32CuDNN("conv");
bool allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
auto _bias = bias_t.has_value()
? bias_t.value()
: at::zeros(
@ -784,7 +787,7 @@ Tensor cudnn_convolution_add_relu(
}
auto& ctx = at::globalContext();
bool allow_tf32 = ctx.allowTF32CuDNN("conv");
bool allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
bool benchmark = ctx.benchmarkCuDNN();
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
auto _bias = bias_t.has_value()

View File

@ -245,7 +245,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
datatype,
input_datatype,
algo,
at::globalContext().allowTF32CuDNN("rnn"));
at::globalContext().allowTF32CuDNN(at::Float32Op::RNN));
#else
rnn_desc.set(
handle,
@ -261,7 +261,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
datatype,
input_datatype,
algo,
at::globalContext().allowTF32CuDNN("rnn"));
at::globalContext().allowTF32CuDNN(at::Float32Op::RNN));
#endif
return rnn_desc;
}

View File

@ -156,12 +156,12 @@ static void check_shape_forward(const Tensor& input,
//
static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
return at::globalContext().float32Precision("mkldnn", "conv") == "bf16" &&
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::BF16 &&
mkldnn_bf16_device_check();
}
static bool mkldnn_conv_enabled_fpmath_mode_tf32(){
return at::globalContext().float32Precision("mkldnn", "conv") == "tf32" &&
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
cpuinfo_has_x86_amx_fp16();
}

View File

@ -69,12 +69,12 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
namespace at::native {
static bool use_mkldnn_bf32_linear() {
return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16" &&
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16 &&
mkldnn_bf16_device_check();
}
static bool use_mkldnn_tf32_linear() {
return at::globalContext().float32Precision("mkldnn", "matmul") == "tf32" &&
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
cpuinfo_has_x86_amx_fp16();
}

View File

@ -111,11 +111,11 @@ static bool use_mkldnn_fp16_matmul() {
}
static bool use_mkldnn_bf32_matmul() {
return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16;
}
static bool use_mkldnn_tf32_matmul() {
return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision("mkldnn", "matmul") == "tf32";
return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
}
// returns an ideep::tensor

View File

@ -2493,7 +2493,8 @@ Call this whenever a new thread is created in order to propagate values from
py_module.def(
"_get_fp32_precision_getter",
[](const std::string& backend, const std::string& op) {
return at::globalContext().float32Precision(backend, op);
return at::precision2str(at::globalContext().float32Precision(
at::str2backend(backend), at::str2op(op)));
});
py_module.def(
@ -2501,7 +2502,10 @@ Call this whenever a new thread is created in order to propagate values from
[](const std::string& backend,
const std::string& op,
const std::string& precision) {
at::globalContext().setFloat32Precision(backend, op, precision);
at::globalContext().setFloat32Precision(
at::str2backend(backend),
at::str2op(op),
at::str2precision(precision));
return precision;
});

View File

@ -634,7 +634,9 @@ struct GlobalStateGuard {
_torch_function_all_disabled = at::impl::torch_function_all_disabled();
_deterministic_algorithms = ctx.deterministicAlgorithms();
_deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
_allow_tf32 = ctx.float32Precision("cuda", "matmul") == "tf32";
_allow_tf32 =
ctx.float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
at::Float32Precision::TF32;
_allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
_allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
_num_threads = at::get_num_threads();
@ -651,7 +653,10 @@ struct GlobalStateGuard {
_deterministic_algorithms == ctx.deterministicAlgorithms() &&
_deterministic_algorithms_warn_only ==
ctx.deterministicAlgorithmsWarnOnly() &&
_allow_tf32 == (ctx.float32Precision("cuda", "matmul") == "tf32") &&
_allow_tf32 ==
(ctx.float32Precision(
at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
at::Float32Precision::TF32) &&
_allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
_allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
_num_threads == at::get_num_threads()) &&
@ -672,7 +677,10 @@ struct GlobalStateGuard {
if (_deterministic_algorithms_warn_only !=
ctx.deterministicAlgorithmsWarnOnly())
os << "deterministic_algorithms_warn_only ";
if (_allow_tf32 != (ctx.float32Precision("cuda", "matmul") == "tf32"))
if (_allow_tf32 !=
(ctx.float32Precision(
at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
at::Float32Precision::TF32))
os << "allow_tf32 ";
if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS())
os << "allow_fp16_reduce ";

View File

@ -397,7 +397,9 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
event->start_time_ = c10::getApproximateTime();
event->allow_tf32_cublas_ =
at::globalContext().float32Precision("cuda", "matmul") == "tf32";
at::globalContext().float32Precision(
at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
at::Float32Precision::TF32;
if (!config_.experimental_config.performance_events.empty()) {
const size_t n = config_.experimental_config.performance_events.size();
event->counters_ = std::make_unique<perf_counters_t>(n, 0);