Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Speed up FP precision lookup (#164044)

This commit simplifies the precision lookup and setting logic by reducing the number of branches and using a custom hash function. Fixes #161822. The issue described in #163709 still persists; this change is meant as a short-term fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164044
Approved by: https://github.com/ngimel, https://github.com/eqy

Parent: 8d53d788fe
Commit: f006aee601
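To make the shape of the change easier to follow, here is a minimal, self-contained sketch (not the PyTorch code itself) of the data structure the diff introduces: precisions stored in a flat hash map keyed on (backend, op) pairs instead of nested string maps, with unset entries falling back to (backend, ALL) and then to (GENERIC, ALL). The names Backend, Op, Precision, PairHash, and lookup are illustrative stand-ins for Float32Backend, Float32Op, Float32Precision, c10::hash, and Context::float32Precision.

// Sketch only: flat hash map keyed on an enum pair, with a two-step fallback.
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

enum class Backend { GENERIC, CUDA, MKLDNN };
enum class Op { ALL, CONV, RNN, MATMUL };
enum class Precision { NONE, IEEE, TF32, BF16 };

// Stand-in for c10::hash<std::pair<...>>: combine the two enum values.
struct PairHash {
  std::size_t operator()(const std::pair<Backend, Op>& k) const {
    return std::hash<int>()(static_cast<int>(k.first)) * 31 ^
           std::hash<int>()(static_cast<int>(k.second));
  }
};

using Key = std::pair<Backend, Op>;

// One hash lookup plus fallback: (backend, op) -> (backend, ALL) -> (GENERIC, ALL).
// Assumes the table is pre-populated with the ALL entries, as in the diff below.
Precision lookup(const std::unordered_map<Key, Precision, PairHash>& table,
                 Backend backend, Op op) {
  auto it = table.find({backend, op});
  Precision p = (it != table.end()) ? it->second : Precision::NONE;
  if (p == Precision::NONE) {
    p = table.at({backend, Op::ALL});
  }
  if (p == Precision::NONE) {
    p = table.at({Backend::GENERIC, Op::ALL});
  }
  return p;
}

For example, lookup(table, Backend::CUDA, Op::RNN) returns the CUDA-wide or generic default when no RNN-specific precision has been set, which replaces the string comparisons and vector scans of the old code. The diff below applies this idea to the Context class and updates each call site from the string-based API to the enum-based one.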
@@ -40,41 +40,6 @@ namespace {
   ->conv
   ->rnn
 */
-const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
-    {"generic", {{"ieee", "tf32", "bf16", "none"}}},
-    {"mkldnn", {{"ieee", "tf32", "bf16", "none"}}},
-    {"cuda", {{"ieee", "tf32", "none"}}}};
-
-// Check whether the backend and op are legal
-void check_fp32_prec_backend_and_op(
-    const std::string& backend,
-    const std::string& op) {
-  static std::vector<std::string> backends = {"generic", "mkldnn", "cuda"};
-  static std::vector<std::string> operators = {"conv", "matmul", "rnn", "all"};
-  TORCH_CHECK(
-      std::find(backends.begin(), backends.end(), backend) != backends.end(),
-      "Invalid backend: ",
-      backend);
-  TORCH_CHECK(
-      std::find(operators.begin(), operators.end(), op) != operators.end(),
-      "Invalid operator: ",
-      op);
-  if (backend == "generic") {
-    TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op);
-  }
-}
-
-// Return whether the precision is supported by backends
-bool validate_fp32_prec(
-    const std::string& backend,
-    const std::string& precision) {
-  auto iterp = _fp32_precisions.find(backend);
-  TORCH_CHECK(iterp != _fp32_precisions.end());
-  auto precisions = iterp->second;
-  bool valid = std::find(precisions.begin(), precisions.end(), precision) !=
-      precisions.end();
-  return valid;
-}
-
 C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
   TORCH_WARN_ONCE(
@@ -86,6 +51,54 @@ void check_fp32_prec_backend_and_op(
 }
 } // namespace

+Float32Backend str2backend(const std::string& name) {
+  if (name == "generic")
+    return Float32Backend::GENERIC;
+  else if (name == "cuda")
+    return Float32Backend::CUDA;
+  else if (name == "mkldnn")
+    return Float32Backend::MKLDNN;
+  TORCH_CHECK(false, "Unknown backend: ", name);
+}
+
+Float32Op str2op(const std::string& name) {
+  if (name == "all")
+    return Float32Op::ALL;
+  else if (name == "conv")
+    return Float32Op::CONV;
+  else if (name == "rnn")
+    return Float32Op::RNN;
+  else if (name == "matmul")
+    return Float32Op::MATMUL;
+  TORCH_CHECK(false, "Unknown op: ", name);
+}
+
+Float32Precision str2precision(const std::string& name) {
+  if (name == "none")
+    return Float32Precision::NONE;
+  else if (name == "ieee")
+    return Float32Precision::IEEE;
+  else if (name == "tf32")
+    return Float32Precision::TF32;
+  else if (name == "bf16")
+    return Float32Precision::BF16;
+  TORCH_CHECK(false, "Unknown precision: ", name);
+}
+
+std::string precision2str(Float32Precision prec) {
+  switch (prec) {
+    case Float32Precision::NONE:
+      return "none";
+    case Float32Precision::IEEE:
+      return "ieee";
+    case Float32Precision::TF32:
+      return "tf32";
+    case Float32Precision::BF16:
+      return "bf16";
+  }
+  TORCH_CHECK(false, "Invalid enum Float32Precision(", static_cast<int>(prec), ")");
+}
+
 Context::Context() = default;

 // TODO: This could be bad juju if someone calls globalContext() in the
@@ -179,10 +192,10 @@ void Context::setUserEnabledNNPACK(bool e) {
   enabled_nnpack = e;
 }

-bool Context::allowTF32CuDNN(const std::string& op) const {
-  if (op.empty()){
-    bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
-    bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
+bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
+  if (!op.has_value()) {
+    bool allow_tf32_rnn = float32Precision(Float32Backend::CUDA, Float32Op::RNN) == Float32Precision::TF32;
+    bool allow_tf32_conv = float32Precision(Float32Backend::CUDA, Float32Op::CONV) == Float32Precision::TF32;
     TORCH_CHECK(
         allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn,
         "PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,",
@@ -191,15 +204,15 @@ bool Context::allowTF32CuDNN(const std::string& op) const {
         "We suggest only using the new API to set the TF32 flag(s). See also: ",
         "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
   } else {
-    return float32Precision("cuda", op) == "tf32";
+    return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
   }
   warn_deprecated_fp32_precision_api();
   return allow_tf32_cudnn;
 }

 void Context::setAllowTF32CuDNN(bool b) {
-  setFloat32Precision("cuda", "rnn", b ? "tf32" : "none");
-  setFloat32Precision("cuda", "conv", b ? "tf32" : "none");
+  setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
+  setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
   allow_tf32_cudnn = b;
   warn_deprecated_fp32_precision_api();
 }
@@ -305,7 +318,7 @@ void Context::setImmediateMiopen(bool b) {

 bool Context::allowTF32CuBLAS() const {
   bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
-  bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
+  bool allow_tf32_new = float32Precision(Float32Backend::CUDA, Float32Op::MATMUL) == Float32Precision::TF32;
   TORCH_CHECK(
       legacy_allow_tf32 == allow_tf32_new,
       "PyTorch is checking whether allow_tf32_new is enabled for cuBlas matmul,",
@@ -318,17 +331,17 @@ bool Context::allowTF32CuBLAS() const {

 void Context::setAllowTF32CuBLAS(bool b) {
   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
-  setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
+  setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, b ? Float32Precision::TF32 : Float32Precision::IEEE);
 }

 Float32MatmulPrecision Context::float32MatmulPrecision() const {
-  bool invalid = float32Precision("cuda", "matmul") == "tf32" &&
+  bool invalid = float32Precision(Float32Backend::CUDA, Float32Op::MATMUL) == Float32Precision::TF32 &&
       float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST;
   invalid = invalid ||
-      (float32Precision("mkldnn", "matmul") == "bf16" &&
+      (float32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL) == Float32Precision::BF16 &&
        float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
   invalid = invalid ||
-      (float32Precision("mkldnn", "matmul") == "tf32" &&
+      (float32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL) == Float32Precision::TF32 &&
        float32_matmul_precision != at::Float32MatmulPrecision::HIGH);
   TORCH_CHECK(
       !invalid,
@@ -340,15 +353,26 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
   return float32_matmul_precision;
 }

-std::string Context::float32Precision(const std::string& backend, const std::string& op) const {
-  check_fp32_prec_backend_and_op(backend, op);
-  auto precision = fp32_precision.find(backend)->second.find(op)->second;
-  if (precision == "none")
-    precision = fp32_precision.find(backend)->second.find("all")->second;
-  if (precision == "none")
-    precision = fp32_precision.find("generic")->second.find("all")->second;
-  bool valid_prec = validate_fp32_prec(backend, precision);
-  return valid_prec ? precision : "none";
+Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op) const {
+  std::pair<Float32Backend, Float32Op> key{backend, op};
+  auto it = fp32_precision.find(key);
+  TORCH_CHECK(it != fp32_precision.end(), "Invalid (backend, op) pair: (", backend, ", ", op, ")");
+
+  Float32Precision precision = it->second;
+  if (precision == Float32Precision::NONE) {
+    key.second = Float32Op::ALL;
+    precision = fp32_precision.find(key)->second;
+  }
+  if (precision == Float32Precision::NONE) {
+    key.first = Float32Backend::GENERIC;
+    precision = fp32_precision.find(key)->second;
+  }
+
+  // "cuda" does not support "bf16"
+  if (backend == Float32Backend::CUDA && precision == Float32Precision::BF16) {
+    return Float32Precision::NONE;
+  }
+  return precision;
 }

 void Context::setFloat32MatmulPrecision(const std::string &s) {
@@ -357,18 +381,18 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
     // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
     if (s_ == "highest") {
       float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
-      setFloat32Precision("cuda", "matmul", "ieee");
-      setFloat32Precision("mkldnn", "matmul", "ieee");
+      setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::IEEE);
+      setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::IEEE);
       return true;
     } else if (s_ == "high") {
       float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
-      setFloat32Precision("cuda", "matmul", "tf32");
-      setFloat32Precision("mkldnn", "matmul", "tf32");
+      setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::TF32);
+      setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::TF32);
       return true;
     } else if (s_ == "medium") {
       float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
-      setFloat32Precision("cuda", "matmul", "tf32");
-      setFloat32Precision("mkldnn", "matmul", "bf16");
+      setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::TF32);
+      setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::BF16);
       return true;
     }
     return false;
@@ -382,25 +406,16 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
       "setFloat32MatmulPrecision call has no effect.");
 }

-void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) {
-  check_fp32_prec_backend_and_op(backend, op);
-  if (validate_fp32_prec(backend, p)) {
-    fp32_precision[backend][op] = p;
-  } else {
-    std::string msg;
-    auto iterp = _fp32_precisions.find(backend);
-    TORCH_CHECK(iterp != _fp32_precisions.end());
-    for (const auto& p : iterp->second) {
-      msg += p;
-      msg += " ";
-    }
-    TORCH_WARN(
-        "you have set wrong precision for backend:",
-        backend,
-        " setFloat32Precision call has no effect.",
-        "Please choose precision from: ",
-        msg);
-  }
+void Context::setFloat32Precision(Float32Backend backend, Float32Op op, Float32Precision p) {
+  auto it = fp32_precision.find(std::make_pair(backend, op));
+  TORCH_CHECK(
+      it != fp32_precision.end(),
+      "Invalid (backend, op) pair: (", backend, ", ", op, ")");
+  TORCH_CHECK(
+      !(backend == Float32Backend::CUDA && p == Float32Precision::BF16),
+      "backend 'cuda' does not support precision 'bf16'");
+
+  it->second = p;
 }

 at::LinalgBackend Context::linalgPreferredBackend() const {
@@ -25,17 +25,27 @@
 #include <c10/util/CallOnce.h>
 #include <c10/util/Exception.h>
 #include <c10/util/env.h>
+#include <c10/util/hash.h>
 #include <c10/util/irange.h>

 #include <cstdint>
 #include <map>
 #include <mutex>
+#include <unordered_map>

 namespace at {

 class Tensor;

 enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
+enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
+enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
+enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
+
+TORCH_API Float32Backend str2backend(const std::string& name);
+TORCH_API Float32Op str2op(const std::string& name);
+TORCH_API Float32Precision str2precision(const std::string& name);
+TORCH_API std::string precision2str(Float32Precision prec);

 class TORCH_API Context {
  public:
@@ -336,19 +346,17 @@ class TORCH_API Context {

   void setFloat32MatmulPrecision(const std::string& s);
   void setFloat32Precision(
-      const std::string& backend,
-      const std::string& op,
-      const std::string& s);
-  bool allowTF32CuDNN(const std::string& op = std::string()) const;
+      Float32Backend backend,
+      Float32Op op,
+      Float32Precision p);
+  bool allowTF32CuDNN(std::optional<Float32Op> op = std::nullopt) const;
   void setAllowTF32CuDNN(bool);
   bool allowTF32OneDNN() const;
   void setAllowTF32OneDNN(bool);
   bool allowTF32CuBLAS() const;
   void setAllowTF32CuBLAS(bool);
   Float32MatmulPrecision float32MatmulPrecision() const;
-  std::string float32Precision(
-      const std::string& backend,
-      const std::string& op) const;
+  Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
   bool allowFP16ReductionCuBLAS() const;
   void setAllowFP16ReductionCuBLAS(bool);
   bool allowBF16ReductionCuBLAS() const;
@@ -475,21 +483,20 @@ class TORCH_API Context {
   bool enable_sparse_tensor_invariant_checks = false;
   bool allow_fp16_reduction_cpu = false;

-  std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
-      {"generic", {{"all", "none"}}},
-      {"mkldnn",
-       {{"matmul", "none"},
-        {"conv", "none"},
-        {"rnn", "none"},
-        {"all", "none"}}},
-      {"cuda",
-       {{"matmul",
-         float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
-             ? "none"
-             : "tf32"},
-        {"conv", "tf32"},
-        {"rnn", "tf32"},
-        {"all", "none"}}},
+  using Key = std::pair<Float32Backend, Float32Op>;
+  std::unordered_map<Key, Float32Precision, c10::hash<Key>> fp32_precision = {
+      {{Float32Backend::GENERIC, Float32Op::ALL}, Float32Precision::NONE},
+      {{Float32Backend::MKLDNN, Float32Op::ALL}, Float32Precision::NONE},
+      {{Float32Backend::MKLDNN, Float32Op::CONV}, Float32Precision::NONE},
+      {{Float32Backend::MKLDNN, Float32Op::RNN}, Float32Precision::NONE},
+      {{Float32Backend::MKLDNN, Float32Op::MATMUL}, Float32Precision::NONE},
+      {{Float32Backend::CUDA, Float32Op::ALL}, Float32Precision::NONE},
+      {{Float32Backend::CUDA, Float32Op::CONV}, Float32Precision::TF32},
+      {{Float32Backend::CUDA, Float32Op::RNN}, Float32Precision::TF32},
+      {{Float32Backend::CUDA, Float32Op::MATMUL},
+       float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
+           ? Float32Precision::NONE
+           : Float32Precision::TF32},
   };

   Allocator* prev_allocator_ptr_{nullptr};
@@ -671,5 +678,4 @@ struct TORCH_API ROCmBackwardPassGuard {
   ~ROCmBackwardPassGuard();
   static bool is_backward_pass();
 };
-
 } // namespace at
@@ -395,7 +395,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
+    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
   } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
@@ -1559,7 +1559,7 @@ bool gemm_and_bias(
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
+    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
   } else if constexpr (std::is_same_v<Dtype, at::Half>) {
@@ -310,7 +310,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   // FP32 data type calculations based on the value of the allow_tf32 flag.
   // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
   if (!NoTF32Guard::should_disable_tf32() &&
-      at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
+      at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
     TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
   } else {
     TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -162,7 +162,7 @@ inline std::string ComputeTypeFor() {
 // ROCBLAS and hipBLASLt.
 template <>
 inline std::string ComputeTypeFor<float>() {
-  if (at::globalContext().float32Precision("cuda", "matmul") != "tf32") {
+  if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) != at::Float32Precision::TF32) {
     return "f32_r";
   } else {
     return "xf32_r";
@@ -506,7 +506,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
     }

     hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
-    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
+    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
       computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
     }
     HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
@@ -141,7 +141,7 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {

   TuningStatus Call(const GemmParams<T>* params) override {
     auto input_output_type = RocBlasDataTypeFor<T>();
-    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
+    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
       return FAIL; // no support for TF32 in rocBLAS
     auto compute_type = RocBlasComputeTypeFor<T>();
     auto h_a = DoCastForHalfOrBfloat16(params->alpha);
@@ -209,7 +209,7 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>

   TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
     auto input_output_type = RocBlasDataTypeFor<T>();
-    if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
+    if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
       return FAIL; // no support for TF32 in rocBLAS
     auto compute_type = RocBlasComputeTypeFor<T>();
     auto h_a = DoCastForHalfOrBfloat16(params->alpha);
@@ -1174,7 +1174,7 @@ at::Tensor convolution(
   bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
   return at::_convolution(input, weight, bias, stride, padding, dilation,
                           transposed, output_padding, groups,
-                          ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv"));
+                          ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN(at::Float32Op::CONV));
 }

 at::Tensor convolution_overrideable(
@@ -1319,7 +1319,7 @@ ConvBackend select_conv_backend(
   params.benchmark = ctx.benchmarkCuDNN();
   params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
   params.cudnn_enabled = ctx.userEnabledCuDNN();
-  params.allow_tf32 = ctx.allowTF32CuDNN("conv");
+  params.allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);

   auto input = input_r;
   auto weight = weight_r;
@@ -1699,7 +1699,7 @@ at::Tensor _convolution(
   c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
   const Tensor& bias_r = *bias_r_maybe_owned;

-  return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv"));
+  return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN(at::Float32Op::CONV));
 }

 std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
@@ -1997,7 +1997,7 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward(
   params.benchmark = ctx.benchmarkCuDNN();
   params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
   params.cudnn_enabled = ctx.userEnabledCuDNN();
-  params.allow_tf32 = ctx.allowTF32CuDNN("conv");
+  params.allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);

   // Validate inputs.
   check_shape_backward(input, weight.sizes(), params);
@@ -169,7 +169,10 @@ std::string repro_from_args(const ConvolutionParams& params) {
   ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
   ss << "import torch\n";
   ss << "torch.backends.cuda.matmul.allow_tf32 = "
-     << pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32")
+     << pybool(
+            at::globalContext().float32Precision(
+                at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
+            at::Float32Precision::TF32)
      << "\n";
   ss << "torch.backends.cudnn.benchmark = "
      << pybool(at::globalContext().benchmarkCuDNN()) << "\n";
@@ -726,7 +729,7 @@ Tensor cudnn_convolution_relu(

   auto& ctx = at::globalContext();
   bool benchmark = ctx.benchmarkCuDNN();
-  bool allow_tf32 = ctx.allowTF32CuDNN("conv");
+  bool allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
   auto _bias = bias_t.has_value()
       ? bias_t.value()
       : at::zeros(
@@ -784,7 +787,7 @@ Tensor cudnn_convolution_add_relu(
   }

   auto& ctx = at::globalContext();
-  bool allow_tf32 = ctx.allowTF32CuDNN("conv");
+  bool allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
   bool benchmark = ctx.benchmarkCuDNN();
   auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
   auto _bias = bias_t.has_value()
@@ -245,7 +245,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
         datatype,
         input_datatype,
         algo,
-        at::globalContext().allowTF32CuDNN("rnn"));
+        at::globalContext().allowTF32CuDNN(at::Float32Op::RNN));
 #else
     rnn_desc.set(
         handle,
@@ -261,7 +261,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
         datatype,
         input_datatype,
         algo,
-        at::globalContext().allowTF32CuDNN("rnn"));
+        at::globalContext().allowTF32CuDNN(at::Float32Op::RNN));
 #endif
     return rnn_desc;
   }
@@ -156,12 +156,12 @@ static void check_shape_forward(const Tensor& input,
 //

 static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
-  return at::globalContext().float32Precision("mkldnn", "conv") == "bf16" &&
+  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::BF16 &&
       mkldnn_bf16_device_check();
 }

 static bool mkldnn_conv_enabled_fpmath_mode_tf32(){
-  return at::globalContext().float32Precision("mkldnn", "conv") == "tf32" &&
+  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
       cpuinfo_has_x86_amx_fp16();
 }

@@ -69,12 +69,12 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
 namespace at::native {

 static bool use_mkldnn_bf32_linear() {
-  return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16" &&
+  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16 &&
       mkldnn_bf16_device_check();
 }

 static bool use_mkldnn_tf32_linear() {
-  return at::globalContext().float32Precision("mkldnn", "matmul") == "tf32" &&
+  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
       cpuinfo_has_x86_amx_fp16();
 }

@@ -111,11 +111,11 @@ static bool use_mkldnn_fp16_matmul() {
 }

 static bool use_mkldnn_bf32_matmul() {
-  return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
+  return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16;
 }

 static bool use_mkldnn_tf32_matmul() {
-  return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision("mkldnn", "matmul") == "tf32";
+  return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
 }

 // returns an ideep::tensor
@@ -2493,7 +2493,8 @@ Call this whenever a new thread is created in order to propagate values from
   py_module.def(
       "_get_fp32_precision_getter",
       [](const std::string& backend, const std::string& op) {
-        return at::globalContext().float32Precision(backend, op);
+        return at::precision2str(at::globalContext().float32Precision(
+            at::str2backend(backend), at::str2op(op)));
       });

   py_module.def(
@@ -2501,7 +2502,10 @@ Call this whenever a new thread is created in order to propagate values from
       [](const std::string& backend,
          const std::string& op,
          const std::string& precision) {
-        at::globalContext().setFloat32Precision(backend, op, precision);
+        at::globalContext().setFloat32Precision(
+            at::str2backend(backend),
+            at::str2op(op),
+            at::str2precision(precision));
         return precision;
       });

@@ -634,7 +634,9 @@ struct GlobalStateGuard {
     _torch_function_all_disabled = at::impl::torch_function_all_disabled();
     _deterministic_algorithms = ctx.deterministicAlgorithms();
     _deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
-    _allow_tf32 = ctx.float32Precision("cuda", "matmul") == "tf32";
+    _allow_tf32 =
+        ctx.float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
+        at::Float32Precision::TF32;
     _allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
     _allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
     _num_threads = at::get_num_threads();
@@ -651,7 +653,10 @@ struct GlobalStateGuard {
            _deterministic_algorithms == ctx.deterministicAlgorithms() &&
            _deterministic_algorithms_warn_only ==
                ctx.deterministicAlgorithmsWarnOnly() &&
-           _allow_tf32 == (ctx.float32Precision("cuda", "matmul") == "tf32") &&
+           _allow_tf32 ==
+               (ctx.float32Precision(
+                    at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
+                at::Float32Precision::TF32) &&
            _allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
            _allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
            _num_threads == at::get_num_threads()) &&
@@ -672,7 +677,10 @@ struct GlobalStateGuard {
     if (_deterministic_algorithms_warn_only !=
         ctx.deterministicAlgorithmsWarnOnly())
       os << "deterministic_algorithms_warn_only ";
-    if (_allow_tf32 != (ctx.float32Precision("cuda", "matmul") == "tf32"))
+    if (_allow_tf32 !=
+        (ctx.float32Precision(
+             at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
+         at::Float32Precision::TF32))
       os << "allow_tf32 ";
     if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS())
       os << "allow_fp16_reduce ";
@@ -397,7 +397,9 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(

   event->start_time_ = c10::getApproximateTime();
   event->allow_tf32_cublas_ =
-      at::globalContext().float32Precision("cuda", "matmul") == "tf32";
+      at::globalContext().float32Precision(
+          at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
+      at::Float32Precision::TF32;
   if (!config_.experimental_config.performance_events.empty()) {
     const size_t n = config_.experimental_config.performance_events.size();
     event->counters_ = std::make_unique<perf_counters_t>(n, 0);