Revert "refine fp32 precision api (#125888)"

This reverts commit 4c11b26158.

Reverted https://github.com/pytorch/pytorch/pull/125888 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to cause some failures on ROCm ([comment](https://github.com/pytorch/pytorch/pull/125888#issuecomment-2869274791))
PyTorch MergeBot 2025-05-11 00:35:46 +00:00
parent e4f22822cb
commit fdc387ec7c
22 changed files with 34 additions and 589 deletions

View File

@ -19,69 +19,9 @@
#if defined(__aarch64__) && !defined(C10_MOBILE)
#include <cpuinfo.h>
#endif
namespace at {
namespace {
/*
These const variables define the fp32 precisions for the different backends.
We currently have the "generic", "cuda", and "mkldnn" backends, and the fp32
precision can be chosen from "ieee", "tf32", "bf16", and "none". "ieee" means
the IEEE standard floating point format; "tf32" and "bf16" mean that we are
allowed to use TF32 or BF16 as the internal computation data type for fp32
computations; "none" means the setting is overridable by its parent node:
generic->mkldnn->matmul
               ->conv
               ->rnn
       ->cuda  ->matmul
               ->conv
               ->rnn
*/
const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
{"generic", {{"ieee", "tf32", "bf16", "none"}}},
{"mkldnn", {{"ieee", "bf16", "none"}}},
{"cuda", {{"ieee", "tf32", "none"}}}};
// Check whether the backend and op are legal
void check_fp32_prec_backend_and_op(
const std::string& backend,
const std::string& op) {
static std::vector<std::string> backends = {"generic", "mkldnn", "cuda"};
static std::vector<std::string> operators = {"conv", "matmul", "rnn", "all"};
TORCH_CHECK(
std::find(backends.begin(), backends.end(), backend) != backends.end(),
"Invalid backend: ",
backend);
TORCH_CHECK(
std::find(operators.begin(), operators.end(), op) != operators.end(),
"Invalid operator: ",
op);
if (backend == "generic") {
TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op);
}
}
// Return whether the precision is supported by the given backend
bool validate_fp32_prec(
const std::string& backend,
const std::string& precision) {
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
auto precisions = iterp->second;
bool valid = std::find(precisions.begin(), precisions.end(), precision) !=
precisions.end();
return valid;
}
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"This API is going to be deprecated, please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
Context::Context() = default;
// TODO: This could be bad juju if someone calls globalContext() in the
@ -175,29 +115,12 @@ void Context::setUserEnabledNNPACK(bool e) {
enabled_nnpack = e;
}
bool Context::allowTF32CuDNN(const std::string& op) const {
if (op.size() == 0){
bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
TORCH_CHECK(
allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn,
"PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,",
"but the current flag(s) indicate that cuDNN conv and cuDNN RNN have different TF32 flags.",
"This combination indicates that you have used a mix of the legacy and new APIs to set the TF32 flags. ",
"We suggest only using the new API to set the TF32 flag(s). See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
} else {
return float32Precision("cuda", op) == "tf32";
}
warn_deprecated_fp32_precision_api();
bool Context::allowTF32CuDNN() const {
return allow_tf32_cudnn;
}
void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision("cuda", "rnn", b ? "tf32" : "none");
setFloat32Precision("cuda", "conv", b ? "tf32" : "none");
allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
}
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@ -336,16 +259,7 @@ bool Context::allowTF32CuBLAS() const {
return false;
}
#endif
bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
TORCH_CHECK(
legacy_allow_tf32 == allow_tf32_new,
"PyTorch is checking whether allow_tf32_new is enabled for cuBlas matmul,",
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new;
return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
}
void Context::setAllowTF32CuBLAS(bool b) {
@ -358,54 +272,27 @@ void Context::setAllowTF32CuBLAS(bool b) {
}
#endif
float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
}
Float32MatmulPrecision Context::float32MatmulPrecision() const {
bool invalid = float32Precision("cuda", "matmul") == "tf32" &&
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST;
invalid = invalid ||
(float32Precision("mkldnn", "matmul") == "bf16" &&
float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
TORCH_CHECK(
!invalid,
"PyTorch is checking the matmul precision without a specific backend name,",
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision;
}
std::string Context::float32Precision(const std::string& backend, const std::string& op) const {
check_fp32_prec_backend_and_op(backend, op);
auto precision = fp32_precision.find(backend)->second.find(op)->second;
if (precision == "none")
precision = fp32_precision.find(backend)->second.find("all")->second;
if (precision == "none")
precision = fp32_precision.find("generic")->second.find("all")->second;
bool valid_prec = validate_fp32_prec(backend, precision);
return valid_prec ? precision : "none";
void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
float32_matmul_precision = p;
}
void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", "ieee");
setFloat32Precision("mkldnn", "matmul", "ieee");
return true;
} else if (s_ == "high") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
setFloat32Precision("cuda", "matmul", "tf32");
setFloat32Precision("mkldnn", "matmul", "ieee");
return true;
} else if (s_ == "medium") {
float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
setFloat32Precision("cuda", "matmul", "tf32");
setFloat32Precision("mkldnn", "matmul", "bf16");
return true;
}
return false;
@ -419,27 +306,6 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
"setFloat32MatmulPrecision call has no effect.");
}
void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) {
check_fp32_prec_backend_and_op(backend, op);
if (validate_fp32_prec(backend, p)) {
fp32_precision[backend][op] = p;
} else {
std::string msg;
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
for (auto p : iterp->second) {
msg += p;
msg += " ";
}
TORCH_WARN(
"you have set wrong precision for backend:",
backend,
" setFloat32Precision call has no effect.",
"Please choose precision from: ",
msg);
}
}
at::LinalgBackend Context::linalgPreferredBackend() const {
return linalg_preferred_backend;
}
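
For orientation (editor's sketch, not part of this commit), the fallback order implemented by `Context::float32Precision` above can be mimicked in a few lines of Python: an op-level `none` defers to the backend's `all` entry, which in turn defers to the generic `all` entry. The table literal below is illustrative only.

.. code:: python

    fp32_precision = {
        "generic": {"all": "none"},
        "cuda": {"matmul": "none", "conv": "tf32", "rnn": "tf32", "all": "none"},
    }

    def resolve(backend: str, op: str) -> str:
        # Mirror the lookup chain: op -> backend "all" -> generic "all".
        precision = fp32_precision[backend][op]
        if precision == "none":
            precision = fp32_precision[backend]["all"]
        if precision == "none":
            precision = fp32_precision["generic"]["all"]
        return precision

    print(resolve("cuda", "conv"))    # "tf32": set explicitly at the op level
    print(resolve("cuda", "matmul"))  # "none": falls through both "all" entries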

View File

@ -28,7 +28,6 @@
#include <c10/util/irange.h>
#include <cstdint>
#include <map>
#include <mutex>
namespace at {
@ -337,20 +336,14 @@ class TORCH_API Context {
void alertCuBLASConfigNotDeterministic() const;
void setFloat32MatmulPrecision(const std::string& s);
void setFloat32Precision(
const std::string& backend,
const std::string& op,
const std::string& s);
bool allowTF32CuDNN(const std::string& op = std::string()) const;
bool allowTF32CuDNN() const;
void setAllowTF32CuDNN(bool);
bool allowTF32OneDNN() const;
void setAllowTF32OneDNN(bool);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
Float32MatmulPrecision float32MatmulPrecision() const;
std::string float32Precision(
const std::string& backend,
const std::string& op) const;
void setFloat32MatmulPrecision(Float32MatmulPrecision p);
bool allowFP16ReductionCuBLAS() const;
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
@ -476,23 +469,6 @@ class TORCH_API Context {
bool enable_sparse_tensor_invariant_checks = false;
bool allow_fp16_reduction_cpu = false;
std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
{"generic", {{"all", "none"}}},
{"mkldnn",
{{"matmul", "none"},
{"conv", "none"},
{"rnn", "none"},
{"all", "none"}}},
{"cuda",
{{"matmul",
float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
? "none"
: "tf32"},
{"conv", "tf32"},
{"rnn", "tf32"},
{"all", "none"}}},
};
Allocator* prev_allocator_ptr_{nullptr};
};

View File

@ -403,7 +403,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
} else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
@ -1574,7 +1574,7 @@ bool gemm_and_bias(
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
} else if constexpr (std::is_same_v<Dtype, at::Half>) {

View File

@ -218,8 +218,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speed up
// FP32 data type calculations based on the value of the allow_tf32 flag.
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
if (!NoTF32Guard::should_disable_tf32() &&
at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

View File

@ -1187,7 +1187,7 @@ at::Tensor convolution(
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
return at::_convolution(input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups,
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv"));
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN());
}
at::Tensor convolution_overrideable(
@ -1332,7 +1332,7 @@ ConvBackend select_conv_backend(
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN("conv");
params.allow_tf32 = ctx.allowTF32CuDNN();
auto input = input_r;
auto weight = weight_r;
@ -1719,7 +1719,7 @@ at::Tensor _convolution(
c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv"));
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN());
}
std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
@ -2017,7 +2017,7 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward(
params.benchmark = ctx.benchmarkCuDNN();
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
params.cudnn_enabled = ctx.userEnabledCuDNN();
params.allow_tf32 = ctx.allowTF32CuDNN("conv");
params.allow_tf32 = ctx.allowTF32CuDNN();
// Validate inputs.
check_shape_backward(input, weight.sizes(), params);

View File

@ -169,8 +169,7 @@ std::string repro_from_args(const ConvolutionParams& params) {
ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
ss << "import torch\n";
ss << "torch.backends.cuda.matmul.allow_tf32 = "
<< pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32")
<< "\n";
<< pybool(at::globalContext().allowTF32CuBLAS()) << "\n";
ss << "torch.backends.cudnn.benchmark = "
<< pybool(at::globalContext().benchmarkCuDNN()) << "\n";
ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)
@ -726,7 +725,7 @@ Tensor cudnn_convolution_relu(
auto& ctx = at::globalContext();
bool benchmark = ctx.benchmarkCuDNN();
bool allow_tf32 = ctx.allowTF32CuDNN("conv");
bool allow_tf32 = ctx.allowTF32CuDNN();
auto _bias = bias_t.has_value()
? bias_t.value()
: at::zeros(
@ -784,7 +783,7 @@ Tensor cudnn_convolution_add_relu(
}
auto& ctx = at::globalContext();
bool allow_tf32 = ctx.allowTF32CuDNN("conv");
bool allow_tf32 = ctx.allowTF32CuDNN();
bool benchmark = ctx.benchmarkCuDNN();
auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
auto _bias = bias_t.has_value()

View File

@ -245,7 +245,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
datatype,
input_datatype,
algo,
at::globalContext().allowTF32CuDNN("rnn"));
at::globalContext().allowTF32CuDNN());
#else
rnn_desc.set(
handle,
@ -261,7 +261,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
datatype,
input_datatype,
algo,
at::globalContext().allowTF32CuDNN("rnn"));
at::globalContext().allowTF32CuDNN());
#endif
return rnn_desc;
}

View File

@ -104,7 +104,7 @@ static bool use_mkldnn_fp16_matmul() {
}
static bool use_mkldnn_bf32_matmul() {
return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
return use_mkldnn_bf16_matmul() && at::globalContext().float32MatmulPrecision() == at::Float32MatmulPrecision::MEDIUM;
}

View File

@ -133,44 +133,6 @@ To toggle the TF32 flags off in C++, you can do
at::globalContext().setAllowTF32CuBLAS(false);
at::globalContext().setAllowTF32CuDNN(false);
Starting with PyTorch 2.7, we provide a new set of APIs to control the TF32 behavior in a more fine-grained way.
The float32 precision can be set per backend and per operator, and the global setting can be overridden for a specific operator.
.. code:: python
torch.backends.fp32_precision = "ieee"
torch.backends.cuda.matmul.fp32_precision = "ieee"
torch.backends.cudnn.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cudnn.rnn.fp32_precision = "tf32"
For `cuda`/`cudnn`, `fp32_precision` can be set to `ieee` or `tf32`.
`ieee` means that `FP32` is used as the internal computation precision.
`tf32` means that `TF32` may be used as the internal computation precision.
A backend-level setting can be overridden for a specific operator by giving that operator a concrete precision such as `ieee`:
.. code:: python
torch.backends.cudnn.fp32_precision = "tf32"
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.backends.cudnn.rnn.fp32_precision = "ieee"
Likewise, a generic setting can be overridden for a specific backend:
.. code:: python
torch.backends.fp32_precision = "tf32"
torch.backends.cudnn.fp32_precision = "ieee"
torch.backends.cudnn.conv.fp32_precision = "ieee"
torch.backends.cudnn.rnn.fp32_precision = "ieee"
In both cases above, `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision`
are overridden to `ieee`.
The old settings are still supported, but we suggest using the new settings for better control. Mixing the
old and new settings is not supported; the sketch below shows what happens when the two disagree.
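
The following hedged sketch (an editor's illustration, condensed from the removed `test_invalid_status_for_legacy_api` test) mirrors the consistency check in `Context::allowTF32CuDNN`: if only the new per-operator flags are changed and they disagree, reading the legacy flag raises an error.

.. code:: python

    torch.backends.cudnn.conv.fp32_precision = "none"
    torch.backends.cudnn.rnn.fp32_precision = "tf32"
    # Reading the legacy flag now fails, because conv and rnn disagree:
    # RuntimeError: ... mix of the legacy and new APIs ...
    print(torch.backends.cudnn.allow_tf32)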
For more information about TF32, see:
- `TensorFloat-32`_

View File

@ -1,102 +0,0 @@
.. meta::
:description: A guide to torch.backends.mkldnn, a PyTorch backend to run MKLDNN operations
:keywords: optimize PyTorch, MKLDNN
.. _mkldnn_backend:
MKLDNN backend
---------------------------------------------------
MKLDNN is an open-source cross-platform performance library of basic building blocks
for deep learning applications.
.. code:: python
# The flag below controls whether to enable the MKLDNN backend in PyTorch.
torch.backends.mkldnn.enabled = True
Users can disable the MKLDNN backend with:
.. code:: python
torch.backends.mkldnn.enabled = False
.. _bf16_on_mkldnn:
Bfloat16 (BF16) on MKLDNN backend
---------------------------------------------------
Starting in PyTorch 2.4, there is a set of APIs to control the internal computation precision
for `float32` operators.
.. code:: python
# The flag below controls the internal computation precision for mkldnn matmul. The default, "ieee", computes in plain float32.
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
# The flag below controls the internal computation precision for mkldnn conv. The default, "ieee", computes in plain float32.
torch.backends.mkldnn.conv.fp32_precision = "ieee"
# The flag below controls the internal computation precision for mkldnn rnn. The default, "ieee", computes in plain float32.
torch.backends.mkldnn.rnn.fp32_precision = "ieee"
Note that besides matmuls and convolutions themselves, functions and nn modules that internally use
matmuls or convolutions are also affected. These include :class:`torch.nn.Linear`, :class:`torch.nn._ConvNd`, :func:`torch.cdist`,
:func:`torch.tensordot`, :func:`torch.nn.functional.affine_grid` and :func:`torch.nn.functional.grid_sample`,
:class:`torch.nn.AdaptiveLogSoftmaxWithLoss`, :class:`torch.nn.GRU` and :class:`torch.nn.LSTM`.
To get an idea of the precision and speed, see the example code and benchmark data (on SPR) below:
.. code:: python
torch.manual_seed(0)
a_full = torch.randn(10240, 10240, dtype=torch.double)
b_full = torch.randn(10240, 10240, dtype=torch.double)
ab_full = a_full @ b_full
mean = ab_full.abs().mean() # 80.7451
a = a_full.float()
b = b_full.float()
# Do the matmul in BF16 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'bf16'
ab_bf16 = a @ b # expected speedup with BF16 dot-product acceleration
error = (ab_bf16 - ab_full).abs().max() # 1.3704
relative_error = error / mean # 0.0170
print(error, relative_error)
# Do the matmul in FP32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
ab_fp32 = a @ b
error = (ab_fp32 - ab_full).abs().max() # 0.0003
relative_error = error / mean # 0.00000317
print(error, relative_error)
From the above example, we can see that with BF16 the matmul is ~7x faster on SPR, while the
relative error compared to double precision is approximately two orders of magnitude larger.
If full FP32 precision is needed, users can disable BF16 by:
.. code:: python
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
torch.backends.mkldnn.conv.fp32_precision = 'ieee'
torch.backends.mkldnn.rnn.fp32_precision = 'ieee'
To toggle the BF16 flags off in C++, you can do
.. code:: C++
at::globalContext().setFloat32Precision("mkldnn", "matmul", "ieee");
at::globalContext().setFloat32Precision("mkldnn", "conv", "ieee");
at::globalContext().setFloat32Precision("mkldnn", "rnn", "ieee");
A generic setting can be overridden for a specific backend or operator by giving it a concrete precision such as `ieee`:
.. code:: python
torch.backends.fp32_precision = "bf16"
torch.backends.mkldnn.fp32_precision = "ieee"
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
In this case, both `torch.backends.mkldnn.fp32_precision` and `torch.backends.mkldnn.matmul.fp32_precision`
resolve to `ieee` rather than being overridden by the generic `bf16` setting.
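
A complementary hedged sketch (an editor's illustration, condensed from the removed `test_default_use_parent` test in `test_mkldnn.py`): when the more specific levels are left at `none`, the effective precision falls back to the parent setting.

.. code:: python

    torch.backends.mkldnn.matmul.fp32_precision = "none"
    with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"):
        with torch.backends.flags(fp32_precision="bf16"):
            # Both mkldnn levels are "none", so matmul falls back to the generic setting.
            print(torch.backends.mkldnn.matmul.fp32_precision)  # expected "bf16"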

View File

@ -63,7 +63,6 @@ from torch.testing._internal.common_utils import (
load_tests,
MI300_ARCH,
parametrize,
recover_orig_fp32_precision,
run_tests,
serialTest,
setBlasBackendsToDefaultFinally,
@ -829,55 +828,6 @@ class TestCuda(TestCase):
):
self.assertTrue(torch.backends.cudnn.allow_tf32)
@recover_orig_fp32_precision
def test_fp32_precision_with_tf32(self):
with torch.backends.cudnn.flags(
enabled=None,
benchmark=None,
benchmark_limit=None,
deterministic=None,
allow_tf32=True,
fp32_precision="none",
):
self.assertEqual(torch.backends.cudnn.conv.fp32_precision, "tf32")
self.assertEqual(torch.backends.cudnn.rnn.fp32_precision, "tf32")
with torch.backends.cudnn.flags(
enabled=None,
benchmark=None,
benchmark_limit=None,
deterministic=None,
allow_tf32=False,
fp32_precision="none",
):
self.assertEqual(torch.backends.cudnn.conv.fp32_precision, "none")
self.assertEqual(torch.backends.cudnn.rnn.fp32_precision, "none")
@recover_orig_fp32_precision
def test_fp32_precision_with_float32_matmul_precision(self):
torch.set_float32_matmul_precision("highest")
self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "ieee")
torch.set_float32_matmul_precision("high")
self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "tf32")
torch.set_float32_matmul_precision("medium")
self.assertEqual(torch.backends.cuda.matmul.fp32_precision, "tf32")
@recover_orig_fp32_precision
def test_invalid_status_for_legacy_api(self):
torch.backends.cudnn.conv.fp32_precision = "none"
torch.backends.cudnn.rnn.fp32_precision = "tf32"
with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
print(torch.backends.cudnn.allow_tf32)
torch.set_float32_matmul_precision("highest")
torch.backends.cuda.matmul.fp32_precision = "tf32"
with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
print(torch.get_float32_matmul_precision())
if not TEST_WITH_ROCM:
with self.assertRaisesRegex(RuntimeError, "mix of the legacy and new APIs"):
print(torch.backends.cuda.matmul.allow_tf32)
def test_type_conversions(self):
x = torch.randn(5, 5)
self.assertIsInstance(x.float(), torch.FloatTensor)

View File

@ -22,7 +22,7 @@ import torch.backends.mkldnn
from torch.utils import mkldnn as mkldnn_utils
from torch.testing._internal.common_utils import TestCase, \
run_tests, TemporaryFileName, gradcheck, gradgradcheck, IS_WINDOWS, \
skipIfTorchDynamo, xfailIfTorchDynamo, recover_orig_fp32_precision
skipIfTorchDynamo, xfailIfTorchDynamo
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
dtypes,
@ -1629,53 +1629,6 @@ class TestMkldnn(TestCase):
with self.assertRaises(ValueError):
torch.mkldnn_max_pool2d(x, kernel_size=3, stride=0)
@recover_orig_fp32_precision
def test_mlkdnn_get_set(self):
# get/set mkldnn ops
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"):
self.assertEqual(torch.backends.mkldnn.fp32_precision, "bf16")
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"):
self.assertEqual(torch.backends.mkldnn.fp32_precision, "none")
# get/set matmul
torch.backends.mkldnn.matmul.fp32_precision = "bf16"
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
torch.backends.mkldnn.matmul.fp32_precision = "none"
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")
# get/set conv
torch.backends.mkldnn.conv.fp32_precision = "bf16"
self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "bf16")
torch.backends.mkldnn.conv.fp32_precision = "none"
self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "none")
# get/set rnn
torch.backends.mkldnn.rnn.fp32_precision = "bf16"
self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "bf16")
torch.backends.mkldnn.rnn.fp32_precision = "none"
self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "none")
@recover_orig_fp32_precision
def test_generic_precision(self):
with torch.backends.flags(fp32_precision="none"):
self.assertEqual(torch.backends.fp32_precision, "none")
with torch.backends.flags(fp32_precision="tf32"):
self.assertEqual(torch.backends.fp32_precision, "tf32")
@recover_orig_fp32_precision
def test_default_use_parent(self):
torch.backends.mkldnn.matmul.fp32_precision = "none"
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"):
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"):
with torch.backends.flags(fp32_precision="bf16"):
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
with torch.backends.flags(fp32_precision="tf32"):
# when parent is a not supported precision, use default
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")
@recover_orig_fp32_precision
def test_invalid(self):
# use default if user set a not supported precision
torch.backends.mkldnn.matmul.fp32_precision = "tf32"
self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none")
instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',))

View File

@ -1316,8 +1316,6 @@ def _disabled_torch_dispatch_impl(
) -> Any: ... # THPModule_disable_dispatch_function
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
def _get_fp32_precision_getter(backend: str, op: str) -> str: ...
def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ...
class _LinalgBackend:
Default: _LinalgBackend

View File

@ -251,9 +251,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
cuda_rng_state = None
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state()
cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter(
"cuda", "matmul"
)
allow_tf32 = torch._C._get_cublas_allow_tf32()
prior_fwd_from_src = torch.fx.graph_module._forward_from_src
torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
cleanup = setup_compile_debug()
@ -285,9 +283,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
torch._C._unset_default_mobile_cpu_allocator()
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
torch._C._set_fp32_precision_setter(
"cuda", "matmul", cuda_matmul_fp32_prec
)
torch._C._set_cublas_allow_tf32(allow_tf32)
torch.fx.graph_module._forward_from_src = prior_fwd_from_src
assert guards.check(), (
f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"

View File

@ -1,10 +1,7 @@
# mypy: allow-untyped-defs
import sys
import types
from contextlib import contextmanager
import torch
# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
@ -60,70 +57,6 @@ class PropModule(types.ModuleType):
return self.m.__getattribute__(attr)
class _FP32Precision:
def __init__(self, backend, op):
self.backend = backend
self.op = op
def __setattr__(self, name, value):
if name == "fp32_precision":
torch._C._set_fp32_precision_setter(self.backend, self.op, value)
elif name in ("backend", "op"):
super().__setattr__(name, value)
else:
raise AttributeError("Unknown attribute " + name)
def __getattr__(self, name):
if name == "fp32_precision":
return torch._C._get_fp32_precision_getter(self.backend, self.op)
else:
raise AttributeError("Unknown attribute " + name)
def set_flags(_fp32_precision="none"):
orig_flags = (torch._C._get_fp32_precision_getter("generic", "all"),)
if _fp32_precision is not None:
torch._C._set_fp32_precision_setter("generic", "all", _fp32_precision)
return orig_flags
@contextmanager
def flags(fp32_precision="none"):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(fp32_precision)
try:
yield
finally:
with __allow_nonbracketed_mutation():
set_flags(*orig_flags)
def _get_fp32_precision_getter(backend, op):
def inner():
return torch._C._get_fp32_precision_getter(backend, op)
return inner
def _set_fp32_precision_setter(backend, op):
def inner(precision):
return torch._C._set_fp32_precision_setter(backend, op, precision)
return inner
class GenericModule(PropModule):
def __init__(self, m, name):
super().__init__(m, name)
fp32_precision = ContextProp(
_get_fp32_precision_getter("generic", "all"),
_set_fp32_precision_setter("generic", "all"),
)
sys.modules[__name__] = GenericModule(sys.modules[__name__], __name__)
from torch.backends import (
cpu as cpu,
cuda as cuda,

View File

@ -135,8 +135,6 @@ class cuBLASModule:
return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
elif name == "allow_fp16_accumulation":
return torch._C._get_cublas_allow_fp16_accumulation()
elif name == "fp32_precision":
return torch._C._get_fp32_precision_getter("cuda", "matmul")
raise AttributeError("Unknown attribute " + name)
def __setattr__(self, name, value):
@ -148,8 +146,6 @@ class cuBLASModule:
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
elif name == "allow_fp16_accumulation":
return torch._C._set_cublas_allow_fp16_accumulation(value)
elif name == "fp32_precision":
return torch._C._set_fp32_precision_setter("cuda", "matmul", value)
raise AttributeError("Unknown attribute " + name)
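
For reference (editor's sketch, not part of this diff), the removed `fp32_precision` attribute routed through the bindings above is kept in sync with the legacy TF32 flag by `Context::setAllowTF32CuBLAS`; the expected read-back values below assume a CUDA build where TF32 is available.

.. code:: python

    import torch

    # Setting the legacy flag updates both views of the matmul precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    print(torch.backends.cuda.matmul.fp32_precision)  # expected "tf32"
    torch.backends.cuda.matmul.allow_tf32 = False
    print(torch.backends.cuda.matmul.fp32_precision)  # expected "ieee"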

View File

@ -6,14 +6,7 @@ from contextlib import contextmanager
from typing import Optional
import torch
from torch.backends import (
__allow_nonbracketed_mutation,
_FP32Precision,
_get_fp32_precision_getter,
_set_fp32_precision_setter,
ContextProp,
PropModule,
)
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
try:
@ -135,7 +128,6 @@ def set_flags(
_benchmark_limit=None,
_deterministic=None,
_allow_tf32=None,
_fp32_precision="none",
):
orig_flags = (
torch._C._get_cudnn_enabled(),
@ -143,7 +135,6 @@ def set_flags(
None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
torch._C._get_cudnn_deterministic(),
torch._C._get_cudnn_allow_tf32(),
torch._C._get_fp32_precision_getter("cuda", "all"),
)
if _enabled is not None:
torch._C._set_cudnn_enabled(_enabled)
@ -155,8 +146,6 @@ def set_flags(
torch._C._set_cudnn_deterministic(_deterministic)
if _allow_tf32 is not None:
torch._C._set_cudnn_allow_tf32(_allow_tf32)
if _fp32_precision is not None:
torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision)
return orig_flags
@ -167,16 +156,10 @@ def flags(
benchmark_limit=10,
deterministic=False,
allow_tf32=True,
fp32_precision="none",
):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(
enabled,
benchmark,
benchmark_limit,
deterministic,
allow_tf32,
fp32_precision,
enabled, benchmark, benchmark_limit, deterministic, allow_tf32
)
try:
yield
@ -211,12 +194,6 @@ class CudnnModule(PropModule):
allow_tf32 = ContextProp(
torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32
)
conv = _FP32Precision("cuda", "conv")
rnn = _FP32Precision("cuda", "rnn")
fp32_precision = ContextProp(
_get_fp32_precision_getter("cuda", "all"),
_set_fp32_precision_setter("cuda", "all"),
)
# This is the sys.modules replacement trick, see

View File

@ -4,14 +4,7 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING
import torch
from torch.backends import (
__allow_nonbracketed_mutation,
_FP32Precision,
_get_fp32_precision_getter,
_set_fp32_precision_setter,
ContextProp,
PropModule,
)
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
def is_available():
@ -71,14 +64,11 @@ class verbose:
return False
def set_flags(
_enabled=None, _deterministic=None, _allow_tf32=None, _fp32_precision="none"
):
def set_flags(_enabled=None, _deterministic=None, _allow_tf32=None):
orig_flags = (
torch._C._get_mkldnn_enabled(),
torch._C._get_mkldnn_deterministic(),
torch._C._get_onednn_allow_tf32(),
torch._C._get_fp32_precision_getter("mkldnn", "all"),
)
if _enabled is not None:
torch._C._set_mkldnn_enabled(_enabled)
@ -86,15 +76,13 @@ def set_flags(
torch._C._set_mkldnn_deterministic(_deterministic)
if _allow_tf32 is not None:
torch._C._set_onednn_allow_tf32(_allow_tf32)
if _fp32_precision is not None:
torch._C._set_fp32_precision_setter("mkldnn", "all", _fp32_precision)
return orig_flags
@contextmanager
def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="none"):
def flags(enabled=False, deterministic=False, allow_tf32=True):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(enabled, deterministic, allow_tf32, fp32_precision)
orig_flags = set_flags(enabled, deterministic, allow_tf32)
try:
yield
finally:
@ -116,13 +104,6 @@ class MkldnnModule(PropModule):
allow_tf32 = ContextProp(
torch._C._get_onednn_allow_tf32, torch._C._set_onednn_allow_tf32
)
matmul = _FP32Precision("mkldnn", "matmul")
conv = _FP32Precision("mkldnn", "conv")
rnn = _FP32Precision("mkldnn", "rnn")
fp32_precision = ContextProp(
_get_fp32_precision_getter("mkldnn", "all"),
_set_fp32_precision_setter("generic", "all"),
)
if TYPE_CHECKING:

View File

@ -667,12 +667,10 @@ static PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
}
static PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
if (at::globalContext().allowTF32CuDNN())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_setFloat32MatmulPrecision(
@ -693,7 +691,6 @@ static PyObject* THPModule_setFloat32MatmulPrecision(
static PyObject* THPModule_float32MatmulPrecision(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
std::string s = "highest";
auto p = at::globalContext().float32MatmulPrecision();
if (p == at::Float32MatmulPrecision::HIGH) {
@ -702,7 +699,6 @@ static PyObject* THPModule_float32MatmulPrecision(
s = "medium";
}
return THPUtils_packString(s);
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_setSDPPriorityOrder(
PyObject* _unused,
@ -1117,12 +1113,10 @@ static PyObject* THPModule_setAllowTF32CuBLAS(
static PyObject* THPModule_allowTF32CuBLAS(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
if (at::globalContext().allowTF32CuBLAS()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
END_HANDLE_TH_ERRORS
}
static PyObject* THPModule_setAllowFP16ReductionCuBLAS(
@ -2293,18 +2287,6 @@ Call this whenever a new thread is created in order to propagate values from
at::DataPtr(reinterpret_cast<void*>(data_ptr), device));
});
py_module.def(
"_get_fp32_precision_getter", [](std::string backend, std::string op) {
return at::globalContext().float32Precision(backend, op);
});
py_module.def(
"_set_fp32_precision_setter",
[](std::string backend, std::string op, std::string precision) {
at::globalContext().setFloat32Precision(backend, op, precision);
return precision;
});
py_module.def(
"_stash_obj_in_tls", [](const std::string& key, py::handle arg) {
at::impl::ThreadLocalPythonObjects::get_state().set(

View File

@ -590,7 +590,7 @@ struct GlobalStateGuard {
_torch_function_all_disabled = at::impl::torch_function_all_disabled();
_deterministic_algorithms = ctx.deterministicAlgorithms();
_deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
_allow_tf32 = ctx.float32Precision("cuda", "matmul") == "tf32";
_allow_tf32 = ctx.allowTF32CuBLAS();
_allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
_allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
_num_threads = at::get_num_threads();
@ -607,7 +607,7 @@ struct GlobalStateGuard {
_deterministic_algorithms == ctx.deterministicAlgorithms() &&
_deterministic_algorithms_warn_only ==
ctx.deterministicAlgorithmsWarnOnly() &&
_allow_tf32 == (ctx.float32Precision("cuda", "matmul") == "tf32") &&
_allow_tf32 == ctx.allowTF32CuBLAS() &&
_allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
_allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
_num_threads == at::get_num_threads()) &&
@ -628,7 +628,7 @@ struct GlobalStateGuard {
if (_deterministic_algorithms_warn_only !=
ctx.deterministicAlgorithmsWarnOnly())
os << "deterministic_algorithms_warn_only ";
if (_allow_tf32 != (ctx.float32Precision("cuda", "matmul") == "tf32"))
if (_allow_tf32 != ctx.allowTF32CuBLAS())
os << "allow_tf32 ";
if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS())
os << "allow_fp16_reduce ";

View File

@ -396,8 +396,7 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
}
event->start_time_ = c10::getApproximateTime();
event->allow_tf32_cublas_ =
at::globalContext().float32Precision("cuda", "matmul") == "tf32";
event->allow_tf32_cublas_ = at::globalContext().allowTF32CuBLAS();
if (!config_.experimental_config.performance_events.empty()) {
const size_t n = config_.experimental_config.performance_events.size();
event->counters_ = std::make_unique<perf_counters_t>(n, 0);

View File

@ -5635,25 +5635,5 @@ def scoped_load_inline(func):
return cpp_extension.load_inline(*args, **kwargs)
return func(*args, load_inline=load_inline, **kwargs)
return wrapper
def recover_orig_fp32_precision(fn):
@contextlib.contextmanager
def recover():
old_mkldnn_conv_p = torch.backends.mkldnn.conv.fp32_precision # type: ignore[attr-defined]
old_mkldnn_rnn_p = torch.backends.mkldnn.rnn.fp32_precision # type: ignore[attr-defined]
old_mkldnn_matmul_p = torch.backends.mkldnn.matmul.fp32_precision # type: ignore[attr-defined]
old_cudnn_conv_p = torch.backends.cudnn.conv.fp32_precision # type: ignore[attr-defined]
old_cudnn_rnn_p = torch.backends.cudnn.rnn.fp32_precision # type: ignore[attr-defined]
old_cuda_matmul_p = torch.backends.cuda.matmul.fp32_precision
try:
yield
finally:
torch.backends.mkldnn.conv.fp32_precision = old_mkldnn_conv_p # type: ignore[attr-defined]
torch.backends.mkldnn.rnn.fp32_precision = old_mkldnn_rnn_p # type: ignore[attr-defined]
torch.backends.mkldnn.matmul.fp32_precision = old_mkldnn_matmul_p # type: ignore[attr-defined]
torch.backends.cudnn.conv.fp32_precision = old_cudnn_conv_p # type: ignore[attr-defined]
torch.backends.cudnn.rnn.fp32_precision = old_cudnn_rnn_p # type: ignore[attr-defined]
torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p
return recover()(fn)
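
A hedged usage sketch (not part of this diff) of the `recover_orig_fp32_precision` decorator defined above: it snapshots the mkldnn, cudnn, and cuBLAS `fp32_precision` settings before the wrapped test runs and restores them afterwards. The test class and method names below are hypothetical.

.. code:: python

    import torch
    from torch.testing._internal.common_utils import TestCase, recover_orig_fp32_precision

    class MyPrecisionTest(TestCase):  # hypothetical test case
        @recover_orig_fp32_precision
        def test_bf16_matmul(self):
            torch.backends.mkldnn.matmul.fp32_precision = "bf16"
            self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16")
            # The original precision settings are restored when the test returns.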