mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Use PrecisionConfig to determine whether TF32 is used on NVIDIA GPUs.
If the Precision for any input is HIGHEST or if the TF32 global variable is set to false, then TF32 will not be used. Otherwise TF32 will be used. In the future (hopefully soon), XLA will no longer use the TF32 global variable and only use the PrecisionConfig. This change is the first step to removing usage of the TF32 global variable. Before, the PrecisionConfig was used for dots in some cases and never for convolutions. Now the PrecisionConfig is always used on NVIDIA GPUs, for both dots and convolutions. Previously in cases where the PrecisionConfig was used, XLA was inconsistent whether HIGH was enough to disable TF32, or if HIGHEST was required. Which PrecisionConfig was required to disable TF32 depended on the codepath taken. Now HIGHEST is always required to disable TF32. Previously, a ComputePrecision (a typedef-ed int) was passed to some StreamExecutor methods to indicate precision, with 0 being interpreted as TF32 and nonzero being interpreted as FP32. Such functions now take NumericOptions instead, which has an allow_tf32 field added. TensorFlow OpKernels are updated to pass the correct NumericOptions to StreamExecutor. The "TensorFloat-32 will be used for the matrix multiplication" log message is now removed. It was not particularly useful and there was a TODO to remove it. PiperOrigin-RevId: 540448127
This commit is contained in:
parent
704dddce91
commit
14ea9d18c3
|
|
@ -588,6 +588,7 @@ tf_xla_py_strict_test(
|
|||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/framework:config",
|
||||
"//tensorflow/python/framework:ops",
|
||||
"//tensorflow/python/framework:test_lib",
|
||||
"//tensorflow/python/ops:array_ops",
|
||||
"//tensorflow/python/ops:math_ops",
|
||||
"//tensorflow/python/ops:nn_ops",
|
||||
|
|
|
|||
|
|
@ -18,13 +18,14 @@ from tensorflow.compiler.tests import xla_test
|
|||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class TensorFloat32ConvTest(xla_test.XLATestCase):
|
||||
class TensorFloat32Test(xla_test.XLATestCase):
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
|
@ -51,6 +52,13 @@ class TensorFloat32ConvTest(xla_test.XLATestCase):
|
|||
# operand_precision is not in HLO if it's the default value.
|
||||
self.assertNotIn('operand_precision', hlo_text)
|
||||
|
||||
# On Ampere GPUs and above, which support TF32, test TF32 is used by
|
||||
# asserting outputs are not close to FP64 result
|
||||
if test_util.is_gpu_available(min_cuda_compute_capability=(8, 0)):
|
||||
out = compiled_fn(*inputs)
|
||||
f64_out = compiled_fn(*[math_ops.cast(x, 'float64') for x in inputs])
|
||||
self.assertNotAllClose(out, f64_out, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_matmul(self):
|
||||
x = array_ops.fill((1024, 1024), 1 + 2**-12)
|
||||
y = array_ops.fill((1024, 1024), 1.0)
|
||||
|
|
@ -70,8 +78,8 @@ class TensorFloat32ConvTest(xla_test.XLATestCase):
|
|||
self._test_fn(batch_matmul, [x, y])
|
||||
|
||||
def test_conv2d(self):
|
||||
x = array_ops.fill((2, 20, 20, 32), 1 + 2**-12)
|
||||
y = array_ops.fill((3, 3, 32, 32), 1.0)
|
||||
x = array_ops.fill((16, 40, 40, 64), 1 + 2**-12)
|
||||
y = array_ops.fill((3, 3, 64, 64), 1.0)
|
||||
|
||||
def conv2d(x, y):
|
||||
return nn_ops.conv2d(x, y, [1, 1, 1, 1], padding='SAME')
|
||||
|
|
@ -79,23 +87,23 @@ class TensorFloat32ConvTest(xla_test.XLATestCase):
|
|||
self._test_fn(conv2d, [x, y])
|
||||
|
||||
def test_conv2d_backprop_input(self):
|
||||
y = array_ops.fill((3, 3, 32, 32), 1 + 2**-12)
|
||||
out_backprop = array_ops.fill((2, 20, 20, 32), 1.0)
|
||||
y = array_ops.fill((3, 3, 64, 64), 1 + 2**-12)
|
||||
out_backprop = array_ops.fill((16, 40, 40, 64), 1.0)
|
||||
|
||||
def conv2d_backprop_input(y, out_backprop):
|
||||
return nn_ops.conv2d_backprop_input(
|
||||
(2, 20, 20, 32), y, out_backprop, [1, 1, 1, 1], padding='SAME'
|
||||
(16, 40, 40, 64), y, out_backprop, [1, 1, 1, 1], padding='SAME'
|
||||
)
|
||||
|
||||
self._test_fn(conv2d_backprop_input, [y, out_backprop])
|
||||
|
||||
def test_conv2d_backprop_filter(self):
|
||||
x = array_ops.fill((2, 20, 20, 32), 1 + 2**-12)
|
||||
out_backprop = array_ops.fill((2, 20, 20, 32), 1.0)
|
||||
x = array_ops.fill((16, 40, 40, 64), 1 + 2**-12)
|
||||
out_backprop = array_ops.fill((16, 40, 40, 64), 1.0)
|
||||
|
||||
def conv2d_backprop_filter(x, out_backprop):
|
||||
return nn_ops.conv2d_backprop_filter(
|
||||
x, (3, 3, 32, 32), out_backprop, [1, 1, 1, 1], padding='SAME'
|
||||
x, (3, 3, 64, 64), out_backprop, [1, 1, 1, 1], padding='SAME'
|
||||
)
|
||||
|
||||
self._test_fn(conv2d_backprop_filter, [x, out_backprop])
|
||||
|
|
@ -103,4 +111,7 @@ class TensorFloat32ConvTest(xla_test.XLATestCase):
|
|||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
# Enable determinism, since otherwise the autotuner may nondeterministically
|
||||
# chooses either a TF32 or non-TF32 algorithm
|
||||
config.enable_op_determinism()
|
||||
googletest.main()
|
||||
|
|
|
|||
|
|
@ -345,8 +345,6 @@ xla_test(
|
|||
"//tensorflow/compiler/xla/tests:test_macros_header",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/tsl/lib/core:status_test_util",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_hdr_lib",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_utils",
|
||||
"//tensorflow/tsl/platform:test",
|
||||
],
|
||||
)
|
||||
|
|
@ -558,7 +556,6 @@ xla_test(
|
|||
"//tensorflow/compiler/xla/tests:test_macros_header",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/tsl/lib/core:status_test_util",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_hdr_lib",
|
||||
"//tensorflow/tsl/platform:test",
|
||||
],
|
||||
)
|
||||
|
|
@ -611,8 +608,6 @@ xla_test(
|
|||
"//tensorflow/compiler/xla/tests:test_macros_header",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/tsl/lib/core:status_test_util",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_hdr_lib",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_utils",
|
||||
"//tensorflow/tsl/platform:test",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,16 +29,12 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/tsl/lib/core/status_test_util.h"
|
||||
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using QrTest = xla::ClientLibraryTestBase;
|
||||
|
||||
XLA_TEST_F(QrTest, Simple) {
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
|
||||
xla::Array2D<float> data({
|
||||
{4, 6, 8, 10},
|
||||
{6, 45, 54, 63},
|
||||
|
|
@ -79,8 +75,6 @@ XLA_TEST_F(QrTest, Simple) {
|
|||
}
|
||||
|
||||
XLA_TEST_F(QrTest, ZeroDiagonal) {
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
xla::XlaBuilder builder(TestName());
|
||||
|
||||
xla::Array2D<float> a_vals({
|
||||
|
|
@ -106,8 +100,6 @@ XLA_TEST_F(QrTest, ZeroDiagonal) {
|
|||
}
|
||||
|
||||
XLA_TEST_F(QrTest, SimpleBatched) {
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
xla::XlaBuilder builder(TestName());
|
||||
|
||||
xla::Array3D<float> a_vals({
|
||||
|
|
@ -136,8 +128,6 @@ XLA_TEST_F(QrTest, SimpleBatched) {
|
|||
}
|
||||
|
||||
XLA_TEST_F(QrTest, SubnormalComplex) {
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
|
||||
// Verifies that we don't get NaNs in the case that the norm of a complex
|
||||
// number would be denormal but its imaginary value is not exactly 0.
|
||||
xla::Array2D<xla::complex64> a_vals({
|
||||
|
|
|
|||
|
|
@ -35,9 +35,6 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/tsl/lib/core/status_test_util.h"
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
|
||||
#endif
|
||||
|
||||
namespace xla {
|
||||
|
||||
|
|
@ -77,18 +74,8 @@ class SelfAdjointEigTest : public ClientLibraryTestBase {
|
|||
{4, 5, 10, 11},
|
||||
{3, 9, 11, 17},
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
tf32_init_state_ = tsl::tensor_float_32_execution_enabled();
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
#endif
|
||||
}
|
||||
void TearDown() override {
|
||||
ClientLibraryTestBase::TearDown();
|
||||
#if GOOGLE_CUDA
|
||||
tsl::enable_tensor_float_32_execution(tf32_init_state_);
|
||||
#endif
|
||||
}
|
||||
void TearDown() override { ClientLibraryTestBase::TearDown(); }
|
||||
|
||||
Array3D<float> GetUnitMatrix3D(const Array3D<float>& matrix) {
|
||||
Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
|
||||
|
|
@ -123,7 +110,6 @@ class SelfAdjointEigTest : public ClientLibraryTestBase {
|
|||
Array2D<float> matrix2d_8x8_;
|
||||
Array2D<float> low_rank_4x4_;
|
||||
Array2D<int> wrong_type_4x4_;
|
||||
bool tf32_init_state_;
|
||||
};
|
||||
|
||||
XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
|
||||
|
|
|
|||
|
|
@ -35,7 +35,6 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/tsl/lib/core/status_test_util.h"
|
||||
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
|
|
@ -57,9 +56,6 @@ class SVDTest : public ClientLibraryTestBase {
|
|||
{12, 48, 6, 62, 3},
|
||||
},
|
||||
};
|
||||
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
}
|
||||
void TearDown() override { ClientLibraryTestBase::TearDown(); }
|
||||
|
||||
|
|
|
|||
|
|
@ -4602,6 +4602,9 @@ PrecisionConfig* HloInstruction::mutable_precision_config() {
|
|||
if (auto* dot = DynCast<HloDotInstruction>(this)) {
|
||||
return dot->mutable_precision_config();
|
||||
}
|
||||
if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
|
||||
return custom_call->mutable_precision_config();
|
||||
}
|
||||
LOG(FATAL) << "Unimplemented method.";
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -146,6 +146,7 @@ StatusOr<HloInstruction*> UpdateLayoutForCudnnConvolution(
|
|||
normalized_conv->set_feature_group_count(hlo->feature_group_count());
|
||||
normalized_conv->set_raw_backend_config_string(
|
||||
hlo->raw_backend_config_string());
|
||||
*normalized_conv->mutable_precision_config() = hlo->precision_config();
|
||||
normalized_conv->parent()->parent()->SetAndUniquifyInstrName(normalized_conv,
|
||||
hlo->name());
|
||||
|
||||
|
|
|
|||
|
|
@ -292,6 +292,7 @@ StatusOr<std::optional<se::blas::AlgorithmType>> DoGemmAutotune(
|
|||
const DebugOptions& debug_options =
|
||||
gemm->GetModule()->config().debug_options();
|
||||
AutotuneConfig autotune_config = GetConfig(debug_options);
|
||||
const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm));
|
||||
// Don't run autotuning concurrently on the same GPU.
|
||||
|
|
@ -395,9 +396,10 @@ StatusOr<std::optional<se::blas::AlgorithmType>> DoGemmAutotune(
|
|||
// should always return true, and the actual
|
||||
// success-ness is returned in
|
||||
// ProfileResult::is_valid.
|
||||
TF_RETURN_IF_ERROR(RunGemm(
|
||||
config, lhs_buffer, rhs_buffer, output_buffer,
|
||||
stream, algorithm, &profile_result));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunGemm(config, lhs_buffer, rhs_buffer,
|
||||
output_buffer, deterministic_ops,
|
||||
stream, algorithm, &profile_result));
|
||||
return std::move(profile_result);
|
||||
}));
|
||||
|
||||
|
|
|
|||
|
|
@ -623,12 +623,17 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
|
|||
instr,
|
||||
m::Convert(GemmOrCublasLtMatmul(&existing_gemm).WithOneUser())) &&
|
||||
existing_gemm->operands().size() == 2) {
|
||||
TF_ASSIGN_OR_RETURN(GemmBackendConfig gemm_backend_config,
|
||||
existing_gemm->backend_config<GemmBackendConfig>());
|
||||
|
||||
// check if type combination is supported here
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool types_are_supported,
|
||||
IsLegacyCublasMatmul(*existing_gemm)
|
||||
? TypesAreSupportedByLegacyCublas(*existing_gemm, instr)
|
||||
: TypesAreSupportedByCublasLt(*existing_gemm, instr));
|
||||
? TypesAreSupportedByLegacyCublas(*existing_gemm,
|
||||
gemm_backend_config, instr)
|
||||
: TypesAreSupportedByCublasLt(*existing_gemm,
|
||||
gemm_backend_config, instr));
|
||||
if (types_are_supported) {
|
||||
return FuseMatrixConvert(existing_gemm, instr);
|
||||
}
|
||||
|
|
@ -1392,7 +1397,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
|
|||
}
|
||||
|
||||
StatusOr<bool> TypesAreSupportedByLegacyCublas(
|
||||
const HloInstruction &instr, const HloInstruction *bias = nullptr) const {
|
||||
const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config,
|
||||
const HloInstruction *bias = nullptr) const {
|
||||
// Figure out the Atype/Btype.
|
||||
const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
|
||||
const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
|
||||
|
|
@ -1478,7 +1484,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
|
|||
}
|
||||
|
||||
StatusOr<bool> TypesAreSupportedByCublasLt(
|
||||
const HloInstruction &instr, const HloInstruction *bias = nullptr) const {
|
||||
const HloInstruction &instr, const GemmBackendConfig &backend_config,
|
||||
const HloInstruction *bias = nullptr) const {
|
||||
// Figure out the Atype/Btype.
|
||||
const PrimitiveType a_dtype = instr.operand(0)->shape().element_type();
|
||||
const PrimitiveType b_dtype = instr.operand(1)->shape().element_type();
|
||||
|
|
@ -1494,10 +1501,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
|
|||
// Figure out the computeType and scaleType.
|
||||
TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype,
|
||||
AsBlasDataType(output_type));
|
||||
TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type,
|
||||
GetBlasComputationType(
|
||||
a_dtype, output_type,
|
||||
stream_executor::blas::kDefaultComputePrecision));
|
||||
int max_precision = *absl::c_max_element(
|
||||
backend_config.precision_config().operand_precision());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const se::blas::ComputationType compute_type,
|
||||
GetBlasComputationType(a_dtype, instr.shape().element_type(),
|
||||
max_precision));
|
||||
se::blas::DataType scale_type =
|
||||
cublas_lt::GetScaleType(output_dtype, compute_type);
|
||||
|
||||
|
|
@ -1641,8 +1650,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
|
|||
const HloInstruction *rhs = instr.operand(1);
|
||||
const Shape &output_shape = instr.shape();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(bool types_are_supported_by_cublas_lt,
|
||||
TypesAreSupportedByCublasLt(instr));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool types_are_supported_by_cublas_lt,
|
||||
TypesAreSupportedByCublasLt(instr, gemm_backend_config));
|
||||
if (!types_are_supported_by_cublas_lt) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,19 +29,22 @@ namespace gpu {
|
|||
GemmThunk::GemmThunk(ThunkInfo thunk_info, GemmConfig config,
|
||||
const BufferAllocation::Slice& lhs_buffer,
|
||||
const BufferAllocation::Slice& rhs_buffer,
|
||||
const BufferAllocation::Slice& output_buffer)
|
||||
const BufferAllocation::Slice& output_buffer,
|
||||
bool deterministic)
|
||||
: Thunk(Kind::kGemm, thunk_info),
|
||||
config_(std::move(config)),
|
||||
lhs_buffer_(lhs_buffer),
|
||||
rhs_buffer_(rhs_buffer),
|
||||
output_buffer_(output_buffer) {}
|
||||
output_buffer_(output_buffer),
|
||||
deterministic_(deterministic) {}
|
||||
|
||||
Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
VLOG(3) << "Running GEMM thunk";
|
||||
const BufferAllocations& allocs = *params.buffer_allocations;
|
||||
return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_),
|
||||
allocs.GetDeviceAddress(rhs_buffer_),
|
||||
allocs.GetDeviceAddress(output_buffer_), params.stream);
|
||||
allocs.GetDeviceAddress(output_buffer_), deterministic_,
|
||||
params.stream);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class GemmThunk : public Thunk {
|
|||
GemmThunk(ThunkInfo thunk_info, GemmConfig config,
|
||||
const BufferAllocation::Slice& lhs_buffer,
|
||||
const BufferAllocation::Slice& rhs_buffer,
|
||||
const BufferAllocation::Slice& output_buffer);
|
||||
const BufferAllocation::Slice& output_buffer, bool deterministic);
|
||||
|
||||
GemmThunk(const GemmThunk&) = delete;
|
||||
GemmThunk& operator=(const GemmThunk&) = delete;
|
||||
|
|
@ -44,6 +44,8 @@ class GemmThunk : public Thunk {
|
|||
const BufferAllocation::Slice lhs_buffer_;
|
||||
const BufferAllocation::Slice rhs_buffer_;
|
||||
const BufferAllocation::Slice output_buffer_;
|
||||
// Whether to run deterministically.
|
||||
const bool deterministic_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ StatusOr<se::DeviceMemory<uint8_t>> ScratchAllocator::AllocateBytes(
|
|||
|
||||
StatusOr<std::vector<MaybeFusedConvRunner>> GetAlgorithms(
|
||||
const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend,
|
||||
bool use_fallback, bool deterministic_ops) {
|
||||
bool use_fallback, const se::NumericOptions& numeric_options) {
|
||||
TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
|
||||
GetDNNConvKindFromCudnnConvKind(config.kind));
|
||||
|
||||
|
|
@ -146,8 +146,7 @@ StatusOr<std::vector<MaybeFusedConvRunner>> GetAlgorithms(
|
|||
/* leakyrelu_alpha = */ 0.0, stream, config.input_descriptor,
|
||||
config.filter_descriptor, config.bias_descriptor,
|
||||
config.output_descriptor, config.conv_desc, use_fallback,
|
||||
config.fusion->mode, se::NumericOptions{deterministic_ops},
|
||||
&runners));
|
||||
config.fusion->mode, numeric_options, &runners));
|
||||
for (auto& runner : runners) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto runner_cache,
|
||||
|
|
@ -172,8 +171,7 @@ StatusOr<std::vector<MaybeFusedConvRunner>> GetAlgorithms(
|
|||
/* filter_data = */ DeviceMemoryBase(nullptr),
|
||||
config.output_descriptor,
|
||||
/* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc,
|
||||
use_fallback, nullptr, se::NumericOptions{deterministic_ops},
|
||||
&runners));
|
||||
use_fallback, nullptr, numeric_options, &runners));
|
||||
for (auto& runner : runners) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto runner_cache,
|
||||
|
|
@ -194,7 +192,7 @@ GetMIOpenAlgorithms(const HloCustomCallInstruction* instr,
|
|||
se::DeviceMemoryBase result_buffer,
|
||||
se::StreamExecutor* stream_exec,
|
||||
ScratchAllocator* scratch_allocator, se::Stream* stream,
|
||||
bool deterministic_ops) {
|
||||
const se::NumericOptions& numeric_options) {
|
||||
TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
|
||||
|
|
@ -213,7 +211,7 @@ GetMIOpenAlgorithms(const HloCustomCallInstruction* instr,
|
|||
params.config->filter_descriptor, params.filter_buf,
|
||||
params.config->output_descriptor, params.output_buf,
|
||||
params.config->conv_desc, /* use_fallback = */ false, scratch_allocator,
|
||||
se::NumericOptions{deterministic_ops}, &runners));
|
||||
numeric_options, &runners));
|
||||
|
||||
return runners;
|
||||
}
|
||||
|
|
@ -828,6 +826,15 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
|||
const bool cudnn_frontend_enabled =
|
||||
debug_options.xla_gpu_enable_cudnn_frontend();
|
||||
const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops();
|
||||
bool allow_tf32 = true;
|
||||
// TODO(b/284371623): Properly set allow_tf32 even if instr==nullptr, which is
|
||||
// the case when running an AOT compiled executable with runtime autotuning.
|
||||
if (instr) {
|
||||
allow_tf32 = absl::c_all_of(
|
||||
instr->precision_config().operand_precision(),
|
||||
[](int precision) { return precision <= PrecisionConfig::HIGH; });
|
||||
}
|
||||
const se::NumericOptions numeric_options{deterministic_ops, allow_tf32};
|
||||
|
||||
// Use the first algorithm that's supported as reference. There isn't a
|
||||
// particular reason to use it, as any algorithm suffices. It doesn't make
|
||||
|
|
@ -838,7 +845,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
|||
std::vector<MaybeFusedConvRunner> runners,
|
||||
GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
|
||||
cudnn_frontend_enabled,
|
||||
/* use_fallback = */ false, deterministic_ops));
|
||||
/* use_fallback = */ false, numeric_options));
|
||||
|
||||
std::vector<AutotuneResult> profile_results;
|
||||
for (auto& runner_cache : runners) {
|
||||
|
|
@ -862,7 +869,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
|
|||
std::vector<MaybeFusedConvRunner> fallback_runners,
|
||||
GetAlgorithms(runtime_arguments.gpu_conv_config, stream,
|
||||
cudnn_frontend_enabled,
|
||||
/* use_fallback = */ true, deterministic_ops));
|
||||
/* use_fallback = */ true, numeric_options));
|
||||
|
||||
for (auto& runner_cache : fallback_runners) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
|
@ -961,6 +968,10 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
|
|||
const DebugOptions& debug_options =
|
||||
instr->GetModule()->config().debug_options();
|
||||
const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops();
|
||||
const bool allow_tf32 = absl::c_all_of(
|
||||
instr->precision_config().operand_precision(),
|
||||
[](int precision) { return precision <= PrecisionConfig::HIGH; });
|
||||
const se::NumericOptions numeric_options{deterministic_ops, allow_tf32};
|
||||
|
||||
se::StreamExecutor* stream_exec = std::get<DeviceConfig>(config_).stream_exec;
|
||||
const auto device_ordinal = stream_exec->device_ordinal();
|
||||
|
|
@ -998,7 +1009,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
|
|||
std::vector<std::unique_ptr<const se::dnn::ConvRunner>> runners,
|
||||
GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers), result_buffer,
|
||||
stream_exec, &scratch_allocator, stream,
|
||||
deterministic_ops));
|
||||
numeric_options));
|
||||
|
||||
std::vector<AutotuneResult> profile_results;
|
||||
|
||||
|
|
|
|||
|
|
@ -560,6 +560,7 @@ HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
|
|||
const Window& window,
|
||||
const ConvolutionDimensionNumbers& dnums,
|
||||
int64_t feature_group_count,
|
||||
const PrecisionConfig& precision_config,
|
||||
const OpMetadata& metadata) {
|
||||
HloComputation* computation = lhs->parent();
|
||||
|
||||
|
|
@ -579,6 +580,7 @@ HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape,
|
|||
custom_call->set_window(window);
|
||||
custom_call->set_convolution_dimension_numbers(dnums);
|
||||
custom_call->set_feature_group_count(feature_group_count);
|
||||
*custom_call->mutable_precision_config() = precision_config;
|
||||
custom_call->set_metadata(metadata);
|
||||
|
||||
// Give the customcall a user-friendly name.
|
||||
|
|
@ -665,14 +667,16 @@ static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
|
|||
auto& [window, dnums, rhs] = *m;
|
||||
return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
|
||||
conv->mutable_operand(0), rhs, window, dnums,
|
||||
conv->feature_group_count(), conv->metadata());
|
||||
conv->feature_group_count(), conv->precision_config(),
|
||||
conv->metadata());
|
||||
}
|
||||
|
||||
if (ConvolutionMatch m = MatchBackwardFilter(conv)) {
|
||||
auto& [window, dnums, lhs] = *m;
|
||||
return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs,
|
||||
conv->mutable_operand(1), window, dnums,
|
||||
conv->batch_group_count(), conv->metadata());
|
||||
conv->batch_group_count(), conv->precision_config(),
|
||||
conv->metadata());
|
||||
}
|
||||
|
||||
// If all else fails, try a forward convolution.
|
||||
|
|
@ -684,7 +688,8 @@ static StatusOr<HloInstruction*> CreateCustomCallHelper(HloInstruction* conv) {
|
|||
return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(),
|
||||
conv->mutable_operand(0), conv->mutable_operand(1),
|
||||
conv->window(), conv->convolution_dimension_numbers(),
|
||||
conv->feature_group_count(), conv->metadata());
|
||||
conv->feature_group_count(), conv->precision_config(),
|
||||
conv->metadata());
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -1143,10 +1143,12 @@ Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) {
|
|||
TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(gemm.getA()));
|
||||
TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(gemm.getB()));
|
||||
TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(gemm.getC()));
|
||||
bool deterministic_ops =
|
||||
hlo_module_config_.debug_options().xla_gpu_deterministic_ops();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm));
|
||||
auto thunk =
|
||||
std::make_unique<GemmThunk>(GetThunkInfo(op), std::move(config), a, b, c);
|
||||
auto thunk = std::make_unique<GemmThunk>(GetThunkInfo(op), std::move(config),
|
||||
a, b, c, deterministic_ops);
|
||||
|
||||
AddThunkToThunkSequence(std::move(thunk));
|
||||
return OkStatus();
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/blas.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/numeric_options.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
|
@ -551,6 +552,7 @@ Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
|||
Scale beta, se::Stream* stream,
|
||||
se::blas::AlgorithmType algorithm,
|
||||
se::blas::ComputePrecision compute_precision,
|
||||
const se::NumericOptions& numeric_options,
|
||||
se::blas::ProfileResult* profile_result) {
|
||||
CHECK(output.transpose == se::blas::Transpose::kNoTranspose);
|
||||
PrimitiveType lhs_type = primitive_util::NativeToPrimitiveType<Input>();
|
||||
|
|
@ -566,13 +568,13 @@ Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
|||
lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
|
||||
rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
|
||||
output.leading_dim_stride, output.batch_stride, batch_size,
|
||||
computation_type, algorithm, compute_precision, profile_result);
|
||||
computation_type, algorithm, numeric_options, profile_result);
|
||||
} else {
|
||||
return stream->ThenBlasGemmWithAlgorithm(
|
||||
lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
|
||||
lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
|
||||
&output_data, output.leading_dim_stride, computation_type, algorithm,
|
||||
compute_precision, profile_result);
|
||||
numeric_options, profile_result);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -583,6 +585,7 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
|||
se::Stream* stream,
|
||||
std::optional<se::blas::AlgorithmType> algorithm,
|
||||
se::blas::ComputePrecision compute_precision,
|
||||
const se::NumericOptions& numeric_options,
|
||||
se::blas::ProfileResult* profile_result) {
|
||||
CHECK(output.transpose == se::blas::Transpose::kNoTranspose);
|
||||
se::DeviceMemory<Output> output_data(output.data);
|
||||
|
|
@ -591,7 +594,7 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
|||
if (algorithm) {
|
||||
return DoGemmWithAlgorithm<Scale, Input, Output>(
|
||||
batch_size, m, n, k, lhs, rhs, output, alpha, beta, stream, *algorithm,
|
||||
compute_precision, profile_result);
|
||||
compute_precision, numeric_options, profile_result);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -601,20 +604,21 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
|||
lhs.leading_dim_stride, lhs.batch_stride, rhs.cast<Input>(),
|
||||
rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data,
|
||||
output.leading_dim_stride, output.batch_stride, batch_size,
|
||||
compute_precision);
|
||||
numeric_options);
|
||||
}
|
||||
|
||||
return stream->ThenBlasGemm(
|
||||
lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast<Input>(),
|
||||
lhs.leading_dim_stride, rhs.cast<Input>(), rhs.leading_dim_stride, beta,
|
||||
&output_data, output.leading_dim_stride, compute_precision);
|
||||
&output_data, output.leading_dim_stride, numeric_options);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
|
||||
se::DeviceMemoryBase rhs_buffer,
|
||||
se::DeviceMemoryBase output_buffer, se::Stream* stream,
|
||||
se::DeviceMemoryBase output_buffer, bool deterministic_ops,
|
||||
se::Stream* stream,
|
||||
std::optional<se::blas::AlgorithmType> algorithm,
|
||||
se::blas::ProfileResult* profile_result) {
|
||||
VLOG(2) << "Executing a GemmThunk";
|
||||
|
|
@ -635,6 +639,9 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
|
|||
MatrixDescriptor rhs = GetMatrixDesc(rhs_layout, rhs_buffer);
|
||||
MatrixDescriptor output = GetMatrixDesc(output_layout, output_buffer);
|
||||
int64_t batch_size = output_layout.batch_size;
|
||||
se::NumericOptions numeric_options{
|
||||
deterministic_ops,
|
||||
/*allow_tf32=*/config.compute_precision <= 1};
|
||||
|
||||
if (!algorithm) algorithm = config.algorithm;
|
||||
|
||||
|
|
@ -651,7 +658,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
|
|||
batch_size, m, n, k, lhs, rhs, output, \
|
||||
static_cast<NativeScaleType>(config.alpha.real()), \
|
||||
static_cast<NativeScaleType>(config.beta), stream, algorithm, \
|
||||
config.compute_precision, profile_result); \
|
||||
config.compute_precision, numeric_options, profile_result); \
|
||||
}
|
||||
|
||||
#define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE) \
|
||||
|
|
@ -664,7 +671,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
|
|||
batch_size, m, n, k, lhs, rhs, output, \
|
||||
static_cast<NativeScaleType>(config.alpha), \
|
||||
static_cast<NativeScaleType>(config.beta), stream, algorithm, \
|
||||
config.compute_precision, profile_result); \
|
||||
config.compute_precision, numeric_options, profile_result); \
|
||||
}
|
||||
|
||||
if (output_layout.dtype == S32) {
|
||||
|
|
@ -673,7 +680,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
|
|||
batch_size, m, n, k, lhs, rhs, output,
|
||||
static_cast<int32_t>(config.alpha.real()),
|
||||
static_cast<int32_t>(config.beta), stream, *algorithm,
|
||||
se::blas::kDefaultComputePrecision, profile_result);
|
||||
se::blas::kDefaultComputePrecision, numeric_options, profile_result);
|
||||
}
|
||||
|
||||
TYPED_GEMM(F32, BF16, BF16, BF16)
|
||||
|
|
|
|||
|
|
@ -139,7 +139,8 @@ se::blas::DataType GetScaleType(se::blas::DataType c_type,
|
|||
// If `algorithm` is provided, it overrides the one specified in `config`.
|
||||
Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer,
|
||||
se::DeviceMemoryBase rhs_buffer,
|
||||
se::DeviceMemoryBase output_buffer, se::Stream* stream,
|
||||
se::DeviceMemoryBase output_buffer, bool deterministic_ops,
|
||||
se::Stream* stream,
|
||||
std::optional<se::blas::AlgorithmType> algorithm = std::nullopt,
|
||||
se::blas::ProfileResult* profile_result = nullptr);
|
||||
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config,
|
|||
VLOG(3) << "Running GEMM runtime autotuning";
|
||||
std::vector<se::blas::AlgorithmType> algorithms;
|
||||
stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms);
|
||||
const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops();
|
||||
|
||||
// Set autotune_level to 3 to disable correctness checking, which avoids
|
||||
// memory allocation during runtime.
|
||||
|
|
@ -84,8 +85,8 @@ Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config,
|
|||
// we pass a non-null ProfileResult, DoGemmWithAlgorithm should
|
||||
// always return true, and the actual success-ness is returned in
|
||||
// ProfileResult::is_valid.
|
||||
TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, stream, algorithm,
|
||||
&profile_result));
|
||||
TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, deterministic_ops,
|
||||
stream, algorithm, &profile_result));
|
||||
return std::move(profile_result);
|
||||
}));
|
||||
|
||||
|
|
@ -110,6 +111,7 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options,
|
|||
se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
|
||||
se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs);
|
||||
se::DeviceMemoryBase output_data = GetDeviceAddress(out);
|
||||
const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops();
|
||||
|
||||
VLOG(3) << "Running GEMM";
|
||||
se::Stream* stream = run_options->stream();
|
||||
|
|
@ -143,7 +145,8 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options,
|
|||
#endif
|
||||
}
|
||||
|
||||
return RunGemm(*gemm_config, lhs_data, rhs_data, output_data, stream);
|
||||
return RunGemm(*gemm_config, lhs_data, rhs_data, output_data,
|
||||
deterministic_ops, stream);
|
||||
}
|
||||
|
||||
XLA_RUNTIME_DEFINE_CUSTOM_CALL(
|
||||
|
|
|
|||
|
|
@ -220,7 +220,6 @@ xla_cc_test(
|
|||
"//tensorflow/compiler/xla/tests:filecheck",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/tsl/lib/core:status_test_util",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_hdr_lib",
|
||||
"//tensorflow/tsl/platform:test",
|
||||
"//tensorflow/tsl/platform:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
@ -892,6 +891,25 @@ xla_cc_test(
|
|||
]),
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "tensor_float_32_global_var_test",
|
||||
srcs = ["tensor_float_32_global_var_test.cc"],
|
||||
backend_tags = {"gpu": [
|
||||
"requires-gpu-nvidia",
|
||||
"requires-gpu-sm80-only",
|
||||
]},
|
||||
backends = [
|
||||
"gpu",
|
||||
],
|
||||
deps = [
|
||||
":gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla:error_spec",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_utils",
|
||||
"//tensorflow/tsl/platform:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
xla_cc_test(
|
||||
name = "gpu_fused_mha_test",
|
||||
srcs = ["gpu_fused_mha_test.cc"],
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
#include "tensorflow/tsl/lib/core/status_test_util.h"
|
||||
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/tsl/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
|
@ -52,16 +51,6 @@ class GemmRewriteTest : public GpuCodegenTest {
|
|||
->GetDeviceDescription()
|
||||
.cuda_compute_capability();
|
||||
}
|
||||
void SetUp() override {
|
||||
tf32_state_ = tsl::tensor_float_32_execution_enabled();
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
}
|
||||
void TearDown() override {
|
||||
tsl::enable_tensor_float_32_execution(tf32_state_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool tf32_state_;
|
||||
};
|
||||
|
||||
TEST_F(GemmRewriteTest, CheckCustomCallTarget) {
|
||||
|
|
@ -1125,7 +1114,7 @@ HloModule test
|
|||
ENTRY test {
|
||||
Arg_0.1 = f16[4,3]{1,0} parameter(0)
|
||||
Arg_1.2 = f16[3,6]{1,0} parameter(1)
|
||||
ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest, highest}
|
||||
}
|
||||
)";
|
||||
|
||||
|
|
@ -6308,8 +6297,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) {
|
|||
}
|
||||
|
||||
)";
|
||||
bool tf32_state_ = tsl::tensor_float_32_execution_enabled();
|
||||
tsl::enable_tensor_float_32_execution(true);
|
||||
|
||||
CheckFp8IfOnHopper(hlo_text);
|
||||
RunAndFilecheckHloRewrite(hlo_text,
|
||||
|
|
@ -6318,7 +6305,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) {
|
|||
R"(
|
||||
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
|
||||
)");
|
||||
tsl::enable_tensor_float_32_execution(tf32_state_);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/error_spec.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
namespace {
|
||||
|
||||
// Test that setting the TensorFloat-32 global variable to false causes
|
||||
// TensorFloat-32 not to be used, even when the operand precision is set to the
|
||||
// default.
|
||||
// TODO(b/280130359): Have XLA ignore the TensorFloat-32 global variable
|
||||
class TensorFloat32GlobalVarTest : public HloTestBase {
|
||||
protected:
|
||||
TensorFloat32GlobalVarTest() {
|
||||
// The error tolerances are small enough so that the use of TF32 will cause
|
||||
// the error to be greater than the tolerances.
|
||||
error_spec_ = ErrorSpec{1e-4, 1e-4};
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TensorFloat32GlobalVarTest, Dot) {
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
const char* hlo_text = R"(
|
||||
HloModule TestModule
|
||||
|
||||
ENTRY %dot_computation (x: f32[1024,1024], source: f32[1024,1024]) -> f32[1024,1024] {
|
||||
%x = f32[1024,1024] parameter(0)
|
||||
%y = f32[1024,1024] parameter(1)
|
||||
ROOT %result = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default, default}
|
||||
}
|
||||
)";
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, error_spec_));
|
||||
}
|
||||
|
||||
TEST_F(TensorFloat32GlobalVarTest, Convolution) {
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
const char* hlo_text = R"(
|
||||
HloModule TestModule
|
||||
|
||||
ENTRY %conv_computation (x: f32[16,40,40,64], source: f32[3,3,64,64]) -> f32[16,40,40,64] {
|
||||
%x = f32[16,40,40,64] parameter(0)
|
||||
%y = f32[3,3,64,64] parameter(1)
|
||||
ROOT %result = f32[16,40,40,64] convolution(x, y), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, operand_precision={default, default}
|
||||
}
|
||||
)";
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, error_spec_));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
|
@ -41,11 +41,15 @@ limitations under the License.
|
|||
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_BLAS_H_
|
||||
|
||||
#include <complex>
|
||||
#include <limits>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/data_type.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/numeric_options.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
|
||||
#include "tensorflow/tsl/platform/statusor.h"
|
||||
#include "tensorflow/tsl/protobuf/dnn.pb.h"
|
||||
|
|
@ -314,7 +318,7 @@ class BlasSupport {
|
|||
const DeviceMemoryBase &a, int lda,
|
||||
const DeviceMemoryBase &b, int ldb,
|
||||
const void *beta, DeviceMemoryBase *c, int ldc,
|
||||
ComputePrecision precision) = 0;
|
||||
const NumericOptions &numeric_options) = 0;
|
||||
|
||||
// Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
|
||||
virtual bool GetBlasGemmAlgorithms(
|
||||
|
|
@ -338,7 +342,7 @@ class BlasSupport {
|
|||
const DeviceMemoryBase &b, DataType type_b, int ldb, const void *beta,
|
||||
DeviceMemoryBase *c, DataType type_c, int ldc,
|
||||
ComputationType computation_type, AlgorithmType algorithm,
|
||||
blas::ComputePrecision precision,
|
||||
const NumericOptions &numeric_options,
|
||||
ProfileResult *output_profile_result) = 0;
|
||||
|
||||
virtual tsl::Status DoBlasGemmStridedBatchedWithAlgorithm(
|
||||
|
|
@ -348,7 +352,7 @@ class BlasSupport {
|
|||
const DeviceMemoryBase &b, DataType type_b, int ldb, int64_t stride_b,
|
||||
const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc,
|
||||
int64_t stride_c, int batch_count, ComputationType computation_type,
|
||||
AlgorithmType algorithm, blas::ComputePrecision precision,
|
||||
AlgorithmType algorithm, const NumericOptions &numeric_options,
|
||||
ProfileResult *output_profile_result) = 0;
|
||||
|
||||
// Computes a batch of matrix-matrix product with general matrices.
|
||||
|
|
@ -362,6 +366,7 @@ class BlasSupport {
|
|||
DeviceMemorySlice<Eigen::half> b, int ldb,
|
||||
float beta, DeviceMemorySlice<Eigen::half> c,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) = 0;
|
||||
virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
|
||||
blas::Transpose transb, uint64_t m, uint64_t n,
|
||||
|
|
@ -371,6 +376,7 @@ class BlasSupport {
|
|||
float beta,
|
||||
DeviceMemorySlice<Eigen::bfloat16> c, int ldc,
|
||||
int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) = 0;
|
||||
virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
|
||||
blas::Transpose transb, uint64_t m, uint64_t n,
|
||||
|
|
@ -379,6 +385,7 @@ class BlasSupport {
|
|||
DeviceMemorySlice<float> b, int ldb,
|
||||
float beta, DeviceMemorySlice<float> c,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) = 0;
|
||||
virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
|
||||
blas::Transpose transb, uint64_t m, uint64_t n,
|
||||
|
|
@ -387,6 +394,7 @@ class BlasSupport {
|
|||
DeviceMemorySlice<double> b, int ldb,
|
||||
double beta, DeviceMemorySlice<double> c,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) = 0;
|
||||
virtual bool DoBlasGemmBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb,
|
||||
|
|
@ -394,14 +402,16 @@ class BlasSupport {
|
|||
DeviceMemorySlice<std::complex<float>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<float>> b, int ldb,
|
||||
std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) = 0;
|
||||
virtual bool DoBlasGemmBatched(
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb,
|
||||
uint64_t m, uint64_t n, uint64 k, std::complex<double> alpha,
|
||||
DeviceMemorySlice<std::complex<double>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<double>> b, int ldb,
|
||||
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) = 0;
|
||||
|
||||
// Batched gemm with strides instead of pointer arrays.
|
||||
virtual tsl::Status DoBlasGemmStridedBatched(
|
||||
|
|
@ -410,7 +420,7 @@ class BlasSupport {
|
|||
const DeviceMemoryBase &a, int lda, int64_t stride_a,
|
||||
const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
|
||||
DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,
|
||||
blas::ComputePrecision precision) = 0;
|
||||
const NumericOptions &numeric_options) = 0;
|
||||
|
||||
// Solves a triangular matrix equation.
|
||||
//
|
||||
|
|
@ -560,7 +570,7 @@ class BlasSupport {
|
|||
uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \
|
||||
const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, \
|
||||
const void *beta, DeviceMemoryBase *c, int ldc, \
|
||||
blas::ComputePrecision precision) override; \
|
||||
const NumericOptions &numeric_options) override; \
|
||||
bool GetBlasGemmAlgorithms(Stream *stream, \
|
||||
std::vector<blas::AlgorithmType> *out_algorithms) \
|
||||
override; \
|
||||
|
|
@ -571,7 +581,7 @@ class BlasSupport {
|
|||
const DeviceMemoryBase &b, blas::DataType type_b, int ldb, \
|
||||
const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, \
|
||||
blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
|
||||
blas::ComputePrecision precision, \
|
||||
const NumericOptions &numeric_options, \
|
||||
blas::ProfileResult *output_profile_result) override; \
|
||||
bool DoBlasGemmBatched( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
|
|
@ -579,6 +589,7 @@ class BlasSupport {
|
|||
DeviceMemorySlice<Eigen::half> a, int lda, \
|
||||
DeviceMemorySlice<Eigen::half> b, int ldb, float beta, \
|
||||
DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count, \
|
||||
const NumericOptions &numeric_options, \
|
||||
ScratchAllocator *scratch_allocator) override; \
|
||||
bool DoBlasGemmBatched( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
|
|
@ -586,40 +597,45 @@ class BlasSupport {
|
|||
DeviceMemorySlice<Eigen::bfloat16> a, int lda, \
|
||||
DeviceMemorySlice<Eigen::bfloat16> b, int ldb, float beta, \
|
||||
DeviceMemorySlice<Eigen::bfloat16> c, int ldc, int batch_count, \
|
||||
const NumericOptions &numeric_options, \
|
||||
ScratchAllocator *scratch_allocator) override; \
|
||||
bool DoBlasGemmBatched( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
uint64_t m, uint64 n, uint64 k, float alpha, DeviceMemorySlice<float> a, \
|
||||
int lda, DeviceMemorySlice<float> b, int ldb, float beta, \
|
||||
DeviceMemorySlice<float> c, int ldc, int batch_count, \
|
||||
const NumericOptions &numeric_options, \
|
||||
ScratchAllocator *scratch_allocator) override; \
|
||||
bool DoBlasGemmBatched( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
uint64_t m, uint64 n, uint64 k, double alpha, \
|
||||
DeviceMemorySlice<double> a, int lda, DeviceMemorySlice<double> b, \
|
||||
int ldb, double beta, DeviceMemorySlice<double> c, int ldc, \
|
||||
int batch_count, ScratchAllocator *scratch_allocator) override; \
|
||||
int batch_count, const NumericOptions &numeric_options, \
|
||||
ScratchAllocator *scratch_allocator) override; \
|
||||
bool DoBlasGemmBatched( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
uint64_t m, uint64 n, uint64 k, std::complex<float> alpha, \
|
||||
DeviceMemorySlice<std::complex<float>> a, int lda, \
|
||||
DeviceMemorySlice<std::complex<float>> b, int ldb, \
|
||||
std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c, \
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options, \
|
||||
ScratchAllocator *scratch_allocator) override; \
|
||||
bool DoBlasGemmBatched( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
uint64_t m, uint64 n, uint64 k, std::complex<double> alpha, \
|
||||
DeviceMemorySlice<std::complex<double>> a, int lda, \
|
||||
DeviceMemorySlice<std::complex<double>> b, int ldb, \
|
||||
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c, \
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options, \
|
||||
ScratchAllocator *scratch_allocator) override; \
|
||||
tsl::Status DoBlasGemmStridedBatched( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \
|
||||
const DeviceMemoryBase &a, int lda, int64_t stride_a, \
|
||||
const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, \
|
||||
DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, \
|
||||
blas::ComputePrecision precision) override; \
|
||||
const NumericOptions &numeric_options) override; \
|
||||
tsl::Status DoBlasGemmStridedBatchedWithAlgorithm( \
|
||||
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
|
||||
uint64_t m, uint64 n, uint64 k, const void *alpha, \
|
||||
|
|
@ -628,7 +644,7 @@ class BlasSupport {
|
|||
int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, \
|
||||
blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, \
|
||||
blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
|
||||
blas::ComputePrecision precision, \
|
||||
const NumericOptions &numeric_options, \
|
||||
blas::ProfileResult *output_profile_result) override; \
|
||||
bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
|
||||
blas::Transpose transa, blas::Diagonal diag, uint64_t m, \
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_types.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/numeric_options.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
|
||||
|
|
@ -588,7 +589,7 @@ tsl::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
const void *alpha, const DeviceMemoryBase &a,
|
||||
int lda, const DeviceMemoryBase &b, int ldb,
|
||||
const void *beta, DeviceMemoryBase *c, int ldc,
|
||||
blas::ComputePrecision precision) {
|
||||
const NumericOptions &numeric_options) {
|
||||
cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
|
||||
|
||||
#if CUDA_VERSION < 11000
|
||||
|
|
@ -598,17 +599,7 @@ tsl::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
#else
|
||||
if (dtype == blas::DataType::kFloat) {
|
||||
math_type = CUBLAS_TF32_TENSOR_OP_MATH;
|
||||
if (stream->GetCudaComputeCapability().IsAtLeast(
|
||||
CudaComputeCapability::AMPERE)) {
|
||||
// TODO(reedwm): Remove or make this VLOG(1) once TensorFloat-32 is more
|
||||
// well tested.
|
||||
if (tsl::tensor_float_32_execution_enabled()) {
|
||||
LOG_FIRST_N(INFO, 1) << "TensorFloat-32 will be used for the matrix "
|
||||
"multiplication. This will only be logged "
|
||||
"once.";
|
||||
}
|
||||
}
|
||||
if (precision > blas::kDefaultComputePrecision) {
|
||||
if (numeric_options.allow_tf32) {
|
||||
math_type = CUBLAS_DEFAULT_MATH;
|
||||
}
|
||||
}
|
||||
|
|
@ -722,7 +713,7 @@ static bool UsesTensorOps(blas::AlgorithmType algo) {
|
|||
|
||||
static tsl::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
|
||||
Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a,
|
||||
blas::DataType type_b, blas::ComputePrecision precision) {
|
||||
blas::DataType type_b, const NumericOptions &numeric_options) {
|
||||
if (type_a != type_b) {
|
||||
return tsl::errors::Internal("Types of inputs mismatch");
|
||||
}
|
||||
|
|
@ -753,10 +744,6 @@ static tsl::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
|
|||
"Algorithm ", algorithm,
|
||||
" uses tensor ops, but tensor ops are not available in sm",
|
||||
cc.major, "X devices for float input types.");
|
||||
} else if (!tsl::tensor_float_32_execution_enabled()) {
|
||||
return tsl::errors::Internal(
|
||||
"Algorithm ", algorithm,
|
||||
" uses tensor ops, but tensor ops are disabled for fp32 inputs");
|
||||
}
|
||||
math_type = CUBLAS_TF32_TENSOR_OP_MATH;
|
||||
#endif
|
||||
|
|
@ -770,7 +757,7 @@ static tsl::StatusOr<cublasMath_t> GetMathTypeForGemmEx(
|
|||
" uses tensor ops which are not supported for input");
|
||||
}
|
||||
}
|
||||
if (precision > blas::kDefaultComputePrecision) {
|
||||
if (!numeric_options.allow_tf32) {
|
||||
math_type = CUBLAS_DEFAULT_MATH;
|
||||
}
|
||||
|
||||
|
|
@ -814,11 +801,11 @@ tsl::Status CUDABlas::DoBlasGemmWithAlgorithm(
|
|||
blas::DataType type_a, int lda, const DeviceMemoryBase &b,
|
||||
blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
|
||||
blas::DataType type_c, int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ComputePrecision precision,
|
||||
blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
|
||||
blas::ProfileResult *output_profile_result) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
cublasMath_t math_type,
|
||||
GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, precision));
|
||||
GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
|
||||
stream, parent_, output_profile_result));
|
||||
|
|
@ -846,11 +833,11 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm(
|
|||
blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
|
||||
DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
|
||||
int batch_count, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ComputePrecision precision,
|
||||
blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
|
||||
blas::ProfileResult *output_profile_result) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
cublasMath_t math_type,
|
||||
GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, precision));
|
||||
GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options));
|
||||
TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile(
|
||||
stream, parent_, output_profile_result));
|
||||
|
||||
|
|
@ -991,6 +978,7 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal(
|
|||
const DeviceMemorySlice<T> &a_ptrs_to_wrappers, int lda,
|
||||
const DeviceMemorySlice<T> &b_ptrs_to_wrappers, int ldb, Scalar beta,
|
||||
const DeviceMemorySlice<T> &c_ptrs_to_wrappers, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
|
||||
for (int i = 0; i < batch_count; ++i) {
|
||||
|
|
@ -1072,12 +1060,14 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal(
|
|||
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
||||
#if CUBLAS_VER_MAJOR >= 11
|
||||
} else if (data_type == CUDA_R_32F) {
|
||||
// DoBlassInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH
|
||||
// if TensorFloat-32 is disabled.
|
||||
math_type = CUBLAS_TF32_TENSOR_OP_MATH;
|
||||
algo = tsl::tensor_float_32_execution_enabled()
|
||||
? CUBLAS_GEMM_DFALT_TENSOR_OP
|
||||
: CUBLAS_GEMM_DFALT;
|
||||
if (numeric_options.allow_tf32 &&
|
||||
tsl::tensor_float_32_execution_enabled()) {
|
||||
math_type = CUBLAS_TENSOR_OP_MATH;
|
||||
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
|
||||
} else {
|
||||
math_type = CUBLAS_DEFAULT_MATH;
|
||||
algo = CUBLAS_GEMM_DFALT;
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
math_type = CUBLAS_DEFAULT_MATH;
|
||||
|
|
@ -1119,8 +1109,7 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal(
|
|||
DeviceMemory<T> *c_matrix = c_ptrs_to_wrappers[b];
|
||||
TF_RETURN_IF_ERROR(DoBlasGemm(
|
||||
stream, transa, transb, m, n, k, blas::ToDataType<T>::value, &alpha,
|
||||
a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc,
|
||||
blas::kDefaultComputePrecision));
|
||||
a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc, numeric_options));
|
||||
}
|
||||
return ::tsl::OkStatus();
|
||||
}
|
||||
|
|
@ -1131,12 +1120,14 @@ bool CUDABlas::DoBlasGemmBatched(
|
|||
uint64_t n, uint64 k, float alpha, DeviceMemorySlice<Eigen::half> a_array,
|
||||
int lda, DeviceMemorySlice<Eigen::half> b_array, int ldb, float beta,
|
||||
DeviceMemorySlice<Eigen::half> c_array, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
// Note: The func passed here (cublasSgemmBatched) is not actually called,
|
||||
// due to special handling of fp16 inside DoBlasGemmBatchedInternal.
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
|
||||
scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
|
|
@ -1149,12 +1140,14 @@ bool CUDABlas::DoBlasGemmBatched(
|
|||
DeviceMemorySlice<Eigen::bfloat16> a_array, int lda,
|
||||
DeviceMemorySlice<Eigen::bfloat16> b_array, int ldb, float beta,
|
||||
DeviceMemorySlice<Eigen::bfloat16> c_array, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
// Note: The func passed here (cublasSgemmBatched) is not actually called,
|
||||
// due to special handling of bf16 inside DoBlasGemmBatchedInternal.
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
|
||||
scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
|
|
@ -1168,10 +1161,12 @@ bool CUDABlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
|
|||
DeviceMemorySlice<float> b_array, int ldb,
|
||||
float beta, DeviceMemorySlice<float> c_array,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
|
||||
scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
|
|
@ -1185,10 +1180,12 @@ bool CUDABlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
|
|||
DeviceMemorySlice<double> b_array, int ldb,
|
||||
double beta, DeviceMemorySlice<double> c_array,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
|
||||
scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
|
|
@ -1201,10 +1198,12 @@ bool CUDABlas::DoBlasGemmBatched(
|
|||
DeviceMemorySlice<std::complex<float>> a_array, int lda,
|
||||
DeviceMemorySlice<std::complex<float>> b_array, int ldb,
|
||||
std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c_array,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
|
||||
scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
|
|
@ -1217,10 +1216,12 @@ bool CUDABlas::DoBlasGemmBatched(
|
|||
DeviceMemorySlice<std::complex<double>> a_array, int lda,
|
||||
DeviceMemorySlice<std::complex<double>> b_array, int ldb,
|
||||
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c_array,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda,
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
|
||||
b_array, ldb, beta, c_array, ldc, batch_count, numeric_options,
|
||||
scratch_allocator);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << status;
|
||||
}
|
||||
|
|
@ -1233,19 +1234,16 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatched(
|
|||
const DeviceMemoryBase &a, int lda, int64_t stride_a,
|
||||
const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
|
||||
DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,
|
||||
blas::ComputePrecision precision) {
|
||||
const NumericOptions &numeric_options) {
|
||||
cublasMath_t math_type = CUBLAS_DEFAULT_MATH;
|
||||
#if CUDA_VERSION < 11000
|
||||
if (dtype == dnn::kHalf) {
|
||||
math_type = CUBLAS_TENSOR_OP_MATH;
|
||||
}
|
||||
#else
|
||||
if (dtype == dnn::kFloat) {
|
||||
if (dtype == dnn::kFloat && numeric_options.allow_tf32) {
|
||||
math_type = CUBLAS_TF32_TENSOR_OP_MATH;
|
||||
}
|
||||
if (precision > blas::kDefaultComputePrecision) {
|
||||
math_type = CUBLAS_DEFAULT_MATH;
|
||||
}
|
||||
#endif
|
||||
|
||||
switch (dtype) {
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ class CUDABlas : public blas::BlasSupport {
|
|||
const DeviceMemorySlice<T> &a_array, int lda,
|
||||
const DeviceMemorySlice<T> &b_array, int ldb, Scalar beta,
|
||||
const DeviceMemorySlice<T> &c_array, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
|
||||
// Guards the cuBLAS handle for this device.
|
||||
|
|
|
|||
|
|
@ -954,7 +954,7 @@ static bool TensorOpMathAvailable(
|
|||
}
|
||||
|
||||
static bool IsTensorMathEnabled(CudaComputeCapability cuda_compute_capability,
|
||||
dnn::DataType input_type) {
|
||||
dnn::DataType input_type, bool allow_tf32) {
|
||||
if (!TensorOpMathAvailable(cuda_compute_capability)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -962,7 +962,7 @@ static bool IsTensorMathEnabled(CudaComputeCapability cuda_compute_capability,
|
|||
#if CUDNN_VERSION < 8000
|
||||
return false;
|
||||
#else
|
||||
if (!tsl::tensor_float_32_execution_enabled()) {
|
||||
if (!allow_tf32 || !tsl::tensor_float_32_execution_enabled()) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
|
@ -970,8 +970,10 @@ static bool IsTensorMathEnabled(CudaComputeCapability cuda_compute_capability,
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) {
|
||||
return IsTensorMathEnabled(stream->GetCudaComputeCapability(), input_type);
|
||||
static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type,
|
||||
bool allow_tf32) {
|
||||
return IsTensorMathEnabled(stream->GetCudaComputeCapability(), input_type,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
|
||||
|
|
@ -1314,8 +1316,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
|||
int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
|
||||
cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
|
||||
cudnnDataType_t data_type, cudnnDataType_t compute_type,
|
||||
const dnn::AlgorithmConfig& algorithm_config, float dropout,
|
||||
uint64_t seed, ScratchAllocator* state_allocator, bool use_padded_io) {
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
const NumericOptions& numeric_options, float dropout, uint64_t seed,
|
||||
ScratchAllocator* state_allocator, bool use_padded_io) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
CudnnDropoutDescriptor dropout_desc,
|
||||
CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
|
||||
|
|
@ -1337,7 +1340,8 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
|||
// TODO(csigg): Minimal support cuDNN version is 7.3, clean up.
|
||||
bool allow_tensor_ops = data_type == CUDNN_DATA_HALF;
|
||||
if (data_type == CUDNN_DATA_FLOAT)
|
||||
allow_tensor_ops = tsl::tensor_float_32_execution_enabled();
|
||||
allow_tensor_ops = numeric_options.allow_tf32 &&
|
||||
tsl::tensor_float_32_execution_enabled();
|
||||
bool use_tensor_ops =
|
||||
algorithm_config.algorithm().has_value()
|
||||
? algorithm_config.algorithm()->tensor_ops_enabled()
|
||||
|
|
@ -2471,8 +2475,8 @@ CudnnSupport::createRnnDescriptor(
|
|||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64_t seed, ScratchAllocator* state_allocator,
|
||||
bool use_padded_io) {
|
||||
const NumericOptions& numeric_options, float dropout, uint64_t seed,
|
||||
ScratchAllocator* state_allocator, bool use_padded_io) {
|
||||
// Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
|
||||
// not enqueueing anything into a stream, we pass in the null stream.
|
||||
auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
|
||||
|
|
@ -2483,7 +2487,8 @@ CudnnSupport::createRnnDescriptor(
|
|||
ToCudnnRnnInputMode(input_mode),
|
||||
ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
|
||||
ToCudnnDataType(data_type), GetRnnComputeType(data_type),
|
||||
algorithm_config, dropout, seed, state_allocator, use_padded_io));
|
||||
algorithm_config, numeric_options, dropout, seed, state_allocator,
|
||||
use_padded_io));
|
||||
return std::unique_ptr<dnn::RnnDescriptor>(
|
||||
new CudnnRnnDescriptor(std::move(rnn_desc)));
|
||||
}
|
||||
|
|
@ -3087,19 +3092,16 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
|
|||
return scratch_allocator->AllocateBytes(size_in_bytes);
|
||||
}
|
||||
|
||||
tsl::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
|
||||
std::optional<dnn::AlgorithmDesc> desc) {
|
||||
bool use_tensor_ops;
|
||||
bool UseTensorOps(dnn::DataType input_type,
|
||||
std::optional<dnn::AlgorithmDesc> desc) {
|
||||
if (desc.has_value()) {
|
||||
use_tensor_ops = desc->tensor_ops_enabled();
|
||||
if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) {
|
||||
return tsl::Status(absl::StatusCode::kInvalidArgument,
|
||||
"Algo requests disabled tensor op evaluation.");
|
||||
}
|
||||
return desc->tensor_ops_enabled();
|
||||
} else {
|
||||
use_tensor_ops = IsTensorMathEnabled(stream, type);
|
||||
// It's unknown whether the user wants to use TensorFloat-32, which is used
|
||||
// with tensor ops when the inputs are FP32. For safety, assume the user
|
||||
// does not want TensorFloat-32 on FP32 inputs.
|
||||
return input_type != dnn::DataType::kFloat;
|
||||
}
|
||||
return use_tensor_ops;
|
||||
}
|
||||
|
||||
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
|
||||
|
|
@ -3118,9 +3120,7 @@ tsl::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
|
|||
CudnnConvolutionDescriptor conv(
|
||||
convolution_descriptor,
|
||||
ToCudnnDataType(GetConvAccumulatorType(element_type)));
|
||||
bool use_tensor_ops;
|
||||
TF_ASSIGN_OR_RETURN(use_tensor_ops,
|
||||
UseTensorOps(stream, element_type, algo_desc));
|
||||
bool use_tensor_ops = UseTensorOps(element_type, algo_desc);
|
||||
conv.set_use_tensor_op_math(use_tensor_ops);
|
||||
|
||||
if (!algo_desc.has_value()) {
|
||||
|
|
@ -3159,8 +3159,7 @@ tsl::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
|
|||
"Returned status: ", scratch_or.status().ToString()));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(use_tensor_ops,
|
||||
UseTensorOps(stream, element_type, algo_desc));
|
||||
use_tensor_ops = UseTensorOps(element_type, algo_desc);
|
||||
conv.set_use_tensor_op_math(use_tensor_ops);
|
||||
TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
|
||||
stream, cudnn, input_nd, filter, conv,
|
||||
|
|
@ -3180,9 +3179,7 @@ tsl::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
|
|||
CudnnConvolutionDescriptor conv(
|
||||
convolution_descriptor,
|
||||
ToCudnnDataType(GetConvAccumulatorType(element_type)));
|
||||
bool use_tensor_ops;
|
||||
TF_ASSIGN_OR_RETURN(use_tensor_ops,
|
||||
UseTensorOps(stream, element_type, algo_desc));
|
||||
bool use_tensor_ops = UseTensorOps(element_type, algo_desc);
|
||||
conv.set_use_tensor_op_math(use_tensor_ops);
|
||||
|
||||
if (!algo_desc.has_value()) {
|
||||
|
|
@ -3220,8 +3217,7 @@ tsl::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
|
|||
"while a secondary algorithm is not provided.");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(use_tensor_ops,
|
||||
UseTensorOps(stream, element_type, algo_desc));
|
||||
use_tensor_ops = UseTensorOps(element_type, algo_desc);
|
||||
conv.set_use_tensor_op_math(use_tensor_ops);
|
||||
TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
|
||||
stream, cudnn, input_nd, filter, conv,
|
||||
|
|
@ -3241,9 +3237,7 @@ tsl::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
|
|||
CudnnConvolutionDescriptor conv(
|
||||
convolution_descriptor,
|
||||
ToCudnnDataType(GetConvAccumulatorType(element_type)));
|
||||
bool use_tensor_ops;
|
||||
TF_ASSIGN_OR_RETURN(use_tensor_ops,
|
||||
UseTensorOps(stream, element_type, algo_desc));
|
||||
bool use_tensor_ops = UseTensorOps(element_type, algo_desc);
|
||||
conv.set_use_tensor_op_math(use_tensor_ops);
|
||||
|
||||
if (!algo_desc.has_value()) {
|
||||
|
|
@ -3284,8 +3278,7 @@ tsl::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
|
|||
scratch_or.status().ToString()));
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(use_tensor_ops,
|
||||
UseTensorOps(stream, element_type, algo_desc));
|
||||
use_tensor_ops = UseTensorOps(element_type, algo_desc);
|
||||
conv.set_use_tensor_op_math(use_tensor_ops);
|
||||
TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
|
||||
stream, cudnn, input_nd, filter, conv,
|
||||
|
|
@ -4913,9 +4906,6 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner {
|
|||
DeviceMemoryBase output_data) const override {
|
||||
auto algo = MakeAlgorithmDesc();
|
||||
|
||||
// Check that the current stream supports tensor ops if they're requested.
|
||||
TF_RETURN_IF_ERROR(UseTensorOps(stream, input_type_, algo).status());
|
||||
|
||||
if (static_cast<internal::StreamExecutorInterface*>(parent_) !=
|
||||
stream->parent()->implementation()) {
|
||||
return tsl::errors::Internal(
|
||||
|
|
@ -5104,8 +5094,7 @@ tsl::Status CudnnSupport::DoConvolve(
|
|||
auto accumulator_type = GetConvAccumulatorType(element_type);
|
||||
CudnnConvolutionDescriptor conv(convolution_descriptor,
|
||||
ToCudnnDataType(accumulator_type));
|
||||
TF_ASSIGN_OR_RETURN(bool use_tensor_ops,
|
||||
UseTensorOps(stream, element_type, algorithm_desc));
|
||||
bool use_tensor_ops = UseTensorOps(element_type, algorithm_desc);
|
||||
conv.set_use_tensor_op_math(use_tensor_ops);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
|
@ -5424,13 +5413,18 @@ tsl::Status CreateOpRunners(
|
|||
std::vector<std::unique_ptr<const dnn::OpRunner<Sig>>>* out_runners,
|
||||
bool need_side_input, const NumericOptions& numeric_options) {
|
||||
cudnn_frontend::EngineConfigList filtered_configs;
|
||||
const bool disable_winograd = !CudnnEnvVar<WinogradNonfused>::IsEnabled();
|
||||
const bool disable_nondeterminism = RequireCudnnDeterminism(numeric_options);
|
||||
const bool disable_tensor_core =
|
||||
!IsTensorMathEnabled(stream, input_type, numeric_options.allow_tf32);
|
||||
auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool {
|
||||
return GenericEngineFilter(
|
||||
engine_config,
|
||||
/*disable_winograd*/ !CudnnEnvVar<WinogradNonfused>::IsEnabled(),
|
||||
/*disable_nondeterminism*/ RequireCudnnDeterminism(numeric_options),
|
||||
/*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type));
|
||||
return GenericEngineFilter(engine_config, disable_winograd,
|
||||
disable_nondeterminism, disable_tensor_core);
|
||||
};
|
||||
VLOG(4) << "Filtering engine configs with disable_winograd="
|
||||
<< disable_winograd
|
||||
<< ", disable_nondeterminism=" << disable_nondeterminism
|
||||
<< ", disable_tensor_core=" << disable_tensor_core;
|
||||
|
||||
std::array<std::string, 1> heur_mode = {use_fallback ? "heuristics_fallback"
|
||||
: "heuristics_mode_b"};
|
||||
|
|
@ -6132,8 +6126,8 @@ bool CudnnSupport::GetConvolveAlgorithms(
|
|||
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
|
||||
PreloadCudnnSubLibs(PreloadCudnnType::ConvFwd);
|
||||
|
||||
bool tensor_op_math_available =
|
||||
IsTensorMathEnabled(cuda_compute_capability, input_type);
|
||||
bool tensor_op_math_available = IsTensorMathEnabled(
|
||||
cuda_compute_capability, input_type, numeric_options.allow_tf32);
|
||||
out_algorithms->clear();
|
||||
|
||||
std::vector<dnn::AlgorithmDesc::Index> algo_types;
|
||||
|
|
@ -6353,8 +6347,8 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
|
|||
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
|
||||
PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdData);
|
||||
|
||||
bool tensor_op_math_available =
|
||||
IsTensorMathEnabled(cuda_compute_capability, input_type);
|
||||
bool tensor_op_math_available = IsTensorMathEnabled(
|
||||
cuda_compute_capability, input_type, numeric_options.allow_tf32);
|
||||
out_algorithms->clear();
|
||||
|
||||
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
|
||||
|
|
@ -6389,8 +6383,8 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
|
|||
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
|
||||
PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdFilter);
|
||||
|
||||
bool tensor_op_math_available =
|
||||
IsTensorMathEnabled(cuda_compute_capability, input_type);
|
||||
bool tensor_op_math_available = IsTensorMathEnabled(
|
||||
cuda_compute_capability, input_type, numeric_options.allow_tf32);
|
||||
out_algorithms->clear();
|
||||
|
||||
std::vector<dnn::AlgorithmDesc::Index> algo_types = {
|
||||
|
|
@ -7090,8 +7084,7 @@ bool CudnnSupport::DoMatMul(Stream* stream,
|
|||
if (!stream
|
||||
->ThenBlasGemm(blas::Transpose::kNoTranspose,
|
||||
blas::Transpose::kNoTranspose, m, n, k, weights, m,
|
||||
input_data, k, output_data, m,
|
||||
blas::kDefaultComputePrecision)
|
||||
input_data, k, output_data, m, NumericOptions{})
|
||||
.ok()) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -7174,7 +7167,7 @@ bool CudnnSupport::DoMatMul(Stream* stream,
|
|||
stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
|
||||
blas::Transpose::kNoTranspose, m, n, k, alpha,
|
||||
toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
|
||||
ldc, batch_count);
|
||||
ldc, batch_count, NumericOptions{});
|
||||
}
|
||||
|
||||
return stream->ok();
|
||||
|
|
@ -7316,7 +7309,7 @@ tsl::Status CudnnSupport::DoPoolForward(
|
|||
const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data,
|
||||
ScratchAllocator* workspace_allocator) {
|
||||
return DoPoolForward(element_type, stream, pooling_dimensions,
|
||||
NumericOptions{false}, input_dimensions, input_data,
|
||||
NumericOptions{}, input_dimensions, input_data,
|
||||
output_dimensions, output_data, workspace_allocator);
|
||||
}
|
||||
|
||||
|
|
@ -7396,7 +7389,7 @@ tsl::Status CudnnSupport::DoPoolBackward(
|
|||
DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data,
|
||||
ScratchAllocator* workspace_allocator) {
|
||||
return DoPoolBackward(element_type, stream, pooling_dimensions,
|
||||
NumericOptions{false}, input_dimensions, input_data,
|
||||
NumericOptions{}, input_dimensions, input_data,
|
||||
output_dimensions, output_data, input_diff_data,
|
||||
output_diff_data, workspace_allocator);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,8 +62,8 @@ class CudnnSupport : public dnn::DnnSupport {
|
|||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64_t seed, ScratchAllocator* state_allocator,
|
||||
bool use_padded_io) override;
|
||||
const NumericOptions& numeric_options, float dropout, uint64_t seed,
|
||||
ScratchAllocator* state_allocator, bool use_padded_io) override;
|
||||
|
||||
tsl::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||
createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
|
||||
|
|
|
|||
|
|
@ -2292,8 +2292,9 @@ class DnnSupport {
|
|||
dnn::RnnDirectionMode direction_mode,
|
||||
dnn::RnnMode rnn_mode, dnn::DataType data_type,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64_t seed,
|
||||
ScratchAllocator* state_allocator, bool use_padded_io) {
|
||||
const NumericOptions& numeric_options, float dropout,
|
||||
uint64_t seed, ScratchAllocator* state_allocator,
|
||||
bool use_padded_io) {
|
||||
return tsl::Status(absl::StatusCode::kUnimplemented,
|
||||
"createRnnDescriptor is unimplemented");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,10 +21,15 @@ namespace stream_executor {
|
|||
// Options that specify the numeric behavior of operations like matrix
|
||||
// multiplications and convolutions
|
||||
struct NumericOptions {
|
||||
explicit NumericOptions(bool require_determinism)
|
||||
: require_determinism(require_determinism) {}
|
||||
NumericOptions(bool require_determinism, bool allow_tf32)
|
||||
: require_determinism(require_determinism), allow_tf32(allow_tf32) {}
|
||||
|
||||
NumericOptions() : require_determinism(false), allow_tf32(true) {}
|
||||
|
||||
// If true, the op must be deterministic
|
||||
bool require_determinism;
|
||||
// If true, float32 inputs can be rounded to TensorFloat-32 precision
|
||||
bool allow_tf32;
|
||||
};
|
||||
|
||||
} // namespace stream_executor
|
||||
|
|
|
|||
|
|
@ -406,7 +406,7 @@ tsl::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
|
|||
const void *alpha, const DeviceMemoryBase &a,
|
||||
int lda, const DeviceMemoryBase &b, int ldb,
|
||||
const void *beta, DeviceMemoryBase *c, int ldc,
|
||||
blas::ComputePrecision precision) {
|
||||
const NumericOptions &numeric_options) {
|
||||
blas_log("DoBlasGemm");
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u "
|
||||
|
|
@ -529,7 +529,7 @@ tsl::Status ROCMBlas::DoBlasGemmWithAlgorithm(
|
|||
blas::DataType type_a, int lda, const DeviceMemoryBase &b,
|
||||
blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c,
|
||||
blas::DataType type_c, int ldc, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ComputePrecision precision,
|
||||
blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
|
||||
blas::ProfileResult *output_profile_result) {
|
||||
// ROCM TODO: properly implement the interface
|
||||
return tsl::errors::Internal("DoBlasGemmWithAlgorithm ",
|
||||
|
|
@ -543,7 +543,7 @@ tsl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm(
|
|||
blas::DataType type_b, int ldb, int64_t stride_b, const void *beta,
|
||||
DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c,
|
||||
int batch_count, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ComputePrecision precision,
|
||||
blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
|
||||
blas::ProfileResult *output_profile_result) {
|
||||
// ROCM TODO: properly implement the interface
|
||||
return tsl::errors::Internal("DoBlasGemmStridedBatchedWithAlgorithm ",
|
||||
|
|
@ -842,6 +842,7 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
|
|||
DeviceMemorySlice<Eigen::half> b, int ldb,
|
||||
float beta, DeviceMemorySlice<Eigen::half> c,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
const Eigen::half alpha_half(alpha);
|
||||
|
|
@ -864,6 +865,7 @@ bool ROCMBlas::DoBlasGemmBatched(
|
|||
DeviceMemorySlice<Eigen::bfloat16> a_array, int lda,
|
||||
DeviceMemorySlice<Eigen::bfloat16> b_array, int ldb, float beta,
|
||||
DeviceMemorySlice<Eigen::bfloat16> c_array, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
const Eigen::bfloat16 alpha_bf16(alpha);
|
||||
|
|
@ -886,6 +888,7 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
|
|||
DeviceMemorySlice<float> b_array, int ldb,
|
||||
float beta, DeviceMemorySlice<float> c_array,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
|
|
@ -905,6 +908,7 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa,
|
|||
DeviceMemorySlice<double> b_array, int ldb,
|
||||
double beta, DeviceMemorySlice<double> c_array,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
|
|
@ -923,7 +927,8 @@ bool ROCMBlas::DoBlasGemmBatched(
|
|||
DeviceMemorySlice<std::complex<float>> a_array, int lda,
|
||||
DeviceMemorySlice<std::complex<float>> b_array, int ldb,
|
||||
std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c_array,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k,
|
||||
|
|
@ -941,7 +946,8 @@ bool ROCMBlas::DoBlasGemmBatched(
|
|||
DeviceMemorySlice<std::complex<double>> a_array, int lda,
|
||||
DeviceMemorySlice<std::complex<double>> b_array, int ldb,
|
||||
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c_array,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
blas_log("DoBlasGemmBatched");
|
||||
tsl::Status status = DoBlasGemmBatchedInternal(
|
||||
wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k,
|
||||
|
|
@ -1071,7 +1077,7 @@ tsl::Status ROCMBlas::DoBlasGemmStridedBatched(
|
|||
const DeviceMemoryBase &a, int lda, int64_t stride_a,
|
||||
const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta,
|
||||
DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count,
|
||||
blas::ComputePrecision precision) {
|
||||
const NumericOptions &numeric_options) {
|
||||
VLOG(1) << absl::StreamFormat(
|
||||
"doing rocBLAS SGEMM Strided Batched<float>: at=%d bt=%d m=%u n=%u "
|
||||
"k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p "
|
||||
|
|
|
|||
|
|
@ -2691,8 +2691,8 @@ MIOpenSupport::createRnnDescriptor(
|
|||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64_t seed, ScratchAllocator* state_allocator,
|
||||
bool use_padded_io) {
|
||||
const NumericOptions& numeric_options, float dropout, uint64_t seed,
|
||||
ScratchAllocator* state_allocator, bool use_padded_io) {
|
||||
// ROCM TODO: batch_size is used in dynamic persistent RNN algorithm and is
|
||||
// not supported by MIOpen now.
|
||||
if (use_padded_io) {
|
||||
|
|
@ -3946,8 +3946,7 @@ bool MIOpenSupport::DoMatMul(Stream* stream,
|
|||
if (!stream
|
||||
->ThenBlasGemm(blas::Transpose::kNoTranspose,
|
||||
blas::Transpose::kNoTranspose, m, n, k, weights, m,
|
||||
input_data, k, output_data, m,
|
||||
blas::kDefaultComputePrecision)
|
||||
input_data, k, output_data, m, NumericOptions{})
|
||||
.ok()) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -4030,7 +4029,7 @@ bool MIOpenSupport::DoMatMul(Stream* stream,
|
|||
stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
|
||||
blas::Transpose::kNoTranspose, m, n, k, alpha,
|
||||
toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
|
||||
ldc, batch_count);
|
||||
ldc, batch_count, NumericOptions{});
|
||||
}
|
||||
|
||||
return stream->ok();
|
||||
|
|
|
|||
|
|
@ -85,8 +85,8 @@ class MIOpenSupport : public dnn::DnnSupport {
|
|||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64_t seed, ScratchAllocator* state_allocator,
|
||||
bool use_padded_io) override;
|
||||
const NumericOptions& numeric_options, float dropout, uint64_t seed,
|
||||
ScratchAllocator* state_allocator, bool use_padded_io) override;
|
||||
|
||||
tsl::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||
createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||
#include "absl/strings/str_cat.h"
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/compiler/xla/stream_executor/blas.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/numeric_options.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
|
||||
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
|
||||
|
|
@ -1588,9 +1589,11 @@ Stream &Stream::ThenBlasGemmBatched(
|
|||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
uint64_t k, float alpha, DeviceMemorySlice<Eigen::half> a, int lda,
|
||||
DeviceMemorySlice<Eigen::half> b, int ldb, float beta,
|
||||
DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count) {
|
||||
DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options) {
|
||||
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
|
||||
b, ldb, beta, c, ldc, batch_count,
|
||||
numeric_options,
|
||||
/*scratch_allocator=*/nullptr);
|
||||
}
|
||||
|
||||
|
|
@ -1599,6 +1602,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
uint64_t k, float alpha, DeviceMemorySlice<Eigen::half> a, int lda,
|
||||
DeviceMemorySlice<Eigen::half> b, int ldb, float beta,
|
||||
DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
||||
|
|
@ -1607,11 +1611,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
|
||||
float, DeviceMemorySlice<Eigen::half>, int,
|
||||
DeviceMemorySlice<Eigen::half>, int, float,
|
||||
DeviceMemorySlice<Eigen::half>, int, int, ScratchAllocator *>
|
||||
DeviceMemorySlice<Eigen::half>, int, int, const NumericOptions &,
|
||||
ScratchAllocator *>
|
||||
impl;
|
||||
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
|
||||
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
|
||||
scratch_allocator);
|
||||
numeric_options, scratch_allocator);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
||||
|
|
@ -1619,6 +1624,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
uint64_t k, float alpha, DeviceMemorySlice<Eigen::bfloat16> a, int lda,
|
||||
DeviceMemorySlice<Eigen::bfloat16> b, int ldb, float beta,
|
||||
DeviceMemorySlice<Eigen::bfloat16> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
||||
|
|
@ -1627,22 +1633,22 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
|
||||
float, DeviceMemorySlice<Eigen::bfloat16>, int,
|
||||
DeviceMemorySlice<Eigen::bfloat16>, int, float,
|
||||
DeviceMemorySlice<Eigen::bfloat16>, int, int, ScratchAllocator *>
|
||||
DeviceMemorySlice<Eigen::bfloat16>, int, int,
|
||||
const NumericOptions &, ScratchAllocator *>
|
||||
impl;
|
||||
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
|
||||
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
|
||||
scratch_allocator);
|
||||
numeric_options, scratch_allocator);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
|
||||
blas::Transpose transb, uint64_t m,
|
||||
uint64 n, uint64_t k, float alpha,
|
||||
DeviceMemorySlice<float> a, int lda,
|
||||
DeviceMemorySlice<float> b, int ldb,
|
||||
float beta, DeviceMemorySlice<float> c,
|
||||
int ldc, int batch_count) {
|
||||
Stream &Stream::ThenBlasGemmBatched(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
uint64_t k, float alpha, DeviceMemorySlice<float> a, int lda,
|
||||
DeviceMemorySlice<float> b, int ldb, float beta, DeviceMemorySlice<float> c,
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options) {
|
||||
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
|
||||
b, ldb, beta, c, ldc, batch_count,
|
||||
numeric_options,
|
||||
/*scratch_allocator=*/nullptr);
|
||||
}
|
||||
|
||||
|
|
@ -1650,7 +1656,8 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
uint64_t k, float alpha, DeviceMemorySlice<float> a, int lda,
|
||||
DeviceMemorySlice<float> b, int ldb, float beta, DeviceMemorySlice<float> c,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
||||
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
|
||||
|
|
@ -1658,11 +1665,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
|
||||
float, DeviceMemorySlice<float>, int, DeviceMemorySlice<float>,
|
||||
int, float, DeviceMemorySlice<float>, int, int,
|
||||
ScratchAllocator *>
|
||||
const NumericOptions &, ScratchAllocator *>
|
||||
impl;
|
||||
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
|
||||
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
|
||||
scratch_allocator);
|
||||
numeric_options, scratch_allocator);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
|
||||
|
|
@ -1671,9 +1678,11 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
|
|||
DeviceMemorySlice<double> a, int lda,
|
||||
DeviceMemorySlice<double> b, int ldb,
|
||||
double beta, DeviceMemorySlice<double> c,
|
||||
int ldc, int batch_count) {
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options) {
|
||||
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
|
||||
b, ldb, beta, c, ldc, batch_count,
|
||||
numeric_options,
|
||||
/*scratch_allocator=*/nullptr);
|
||||
}
|
||||
|
||||
|
|
@ -1682,6 +1691,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
uint64_t k, double alpha, DeviceMemorySlice<double> a, int lda,
|
||||
DeviceMemorySlice<double> b, int ldb, double beta,
|
||||
DeviceMemorySlice<double> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
||||
|
|
@ -1690,11 +1700,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
|
||||
double, DeviceMemorySlice<double>, int,
|
||||
DeviceMemorySlice<double>, int, double,
|
||||
DeviceMemorySlice<double>, int, int, ScratchAllocator *>
|
||||
DeviceMemorySlice<double>, int, int, const NumericOptions &,
|
||||
ScratchAllocator *>
|
||||
impl;
|
||||
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
|
||||
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
|
||||
scratch_allocator);
|
||||
numeric_options, scratch_allocator);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenBlasGemmBatched(
|
||||
|
|
@ -1702,9 +1713,11 @@ Stream &Stream::ThenBlasGemmBatched(
|
|||
uint64_t k, std::complex<float> alpha,
|
||||
DeviceMemorySlice<std::complex<float>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<float>> b, int ldb, std::complex<float> beta,
|
||||
DeviceMemorySlice<std::complex<float>> c, int ldc, int batch_count) {
|
||||
DeviceMemorySlice<std::complex<float>> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options) {
|
||||
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
|
||||
b, ldb, beta, c, ldc, batch_count,
|
||||
numeric_options,
|
||||
/*scratch_allocator=*/nullptr);
|
||||
}
|
||||
|
||||
|
|
@ -1714,6 +1727,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
DeviceMemorySlice<std::complex<float>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<float>> b, int ldb, std::complex<float> beta,
|
||||
DeviceMemorySlice<std::complex<float>> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
||||
|
|
@ -1723,11 +1737,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
std::complex<float>, DeviceMemorySlice<std::complex<float>>, int,
|
||||
DeviceMemorySlice<std::complex<float>>, int, std::complex<float>,
|
||||
DeviceMemorySlice<std::complex<float>>, int, int,
|
||||
ScratchAllocator *>
|
||||
const NumericOptions &, ScratchAllocator *>
|
||||
impl;
|
||||
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
|
||||
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
|
||||
scratch_allocator);
|
||||
numeric_options, scratch_allocator);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenBlasGemmBatched(
|
||||
|
|
@ -1736,9 +1750,10 @@ Stream &Stream::ThenBlasGemmBatched(
|
|||
DeviceMemorySlice<std::complex<double>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<double>> b, int ldb,
|
||||
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
|
||||
int ldc, int batch_count) {
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options) {
|
||||
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
|
||||
b, ldb, beta, c, ldc, batch_count,
|
||||
numeric_options,
|
||||
/*scratch_allocator=*/nullptr);
|
||||
}
|
||||
|
||||
|
|
@ -1748,7 +1763,8 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
DeviceMemorySlice<std::complex<double>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<double>> b, int ldb,
|
||||
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator) {
|
||||
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
|
||||
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
|
||||
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
|
||||
|
|
@ -1757,11 +1773,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
|
|||
std::complex<double>, DeviceMemorySlice<std::complex<double>>,
|
||||
int, DeviceMemorySlice<std::complex<double>>, int,
|
||||
std::complex<double>, DeviceMemorySlice<std::complex<double>>,
|
||||
int, int, ScratchAllocator *>
|
||||
int, int, const NumericOptions &, ScratchAllocator *>
|
||||
impl;
|
||||
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
|
||||
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
|
||||
scratch_allocator);
|
||||
numeric_options, scratch_allocator);
|
||||
}
|
||||
|
||||
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes) {
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/functional/any_invocable.h"
|
||||
|
|
@ -893,14 +894,15 @@ class Stream {
|
|||
const DeviceMemory<InputType> &a, int lda,
|
||||
const DeviceMemory<InputType> &b, int ldb,
|
||||
DeviceMemory<OutputType> *c, int ldc,
|
||||
blas::ComputePrecision precision) {
|
||||
const NumericOptions &numeric_options) {
|
||||
InputType alpha{1.0};
|
||||
InputType beta{0.0};
|
||||
return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
|
||||
ldc, precision);
|
||||
ldc, numeric_options);
|
||||
}
|
||||
|
||||
// TODO(parkers): Update all callers to pass kDefaultComputePrecision.
|
||||
// TODO(reedwm): Update all callers (if there are any) to pass correct
|
||||
// NumericOptions.
|
||||
template <typename InputType, typename OutputType>
|
||||
tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
|
||||
uint64_t m, uint64 n, uint64 k,
|
||||
|
|
@ -908,7 +910,7 @@ class Stream {
|
|||
const DeviceMemory<InputType> &b, int ldb,
|
||||
DeviceMemory<OutputType> *c, int ldc) {
|
||||
return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc,
|
||||
blas::kDefaultComputePrecision);
|
||||
NumericOptions{});
|
||||
}
|
||||
|
||||
template <typename InputType, typename OutputType, typename ConstantType>
|
||||
|
|
@ -917,7 +919,7 @@ class Stream {
|
|||
const DeviceMemory<InputType> &a, int lda,
|
||||
const DeviceMemory<InputType> &b, int ldb,
|
||||
ConstantType beta, DeviceMemory<OutputType> *c,
|
||||
int ldc, blas::ComputePrecision precision) {
|
||||
int ldc, const NumericOptions &numeric_options) {
|
||||
static_assert(
|
||||
detail::is_any_of<InputType, int8_t, Eigen::half, Eigen::bfloat16,
|
||||
float, double, std::complex<float>,
|
||||
|
|
@ -948,10 +950,10 @@ class Stream {
|
|||
|
||||
return blas->DoBlasGemm(this, transa, transb, m, n, k,
|
||||
blas::ToDataType<InputType>::value, alpha_ptr, a,
|
||||
lda, b, ldb, beta_ptr, c, ldc, precision);
|
||||
lda, b, ldb, beta_ptr, c, ldc, numeric_options);
|
||||
}
|
||||
|
||||
// TODO(parkers): Update all callers to pass kDefaultComputePrecision.
|
||||
// TODO(reedwm): Update all callers to pass correct NumericOptions.
|
||||
template <typename InputType, typename OutputType, typename ConstantType>
|
||||
tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
|
||||
uint64_t m, uint64 n, uint64 k, ConstantType alpha,
|
||||
|
|
@ -960,7 +962,7 @@ class Stream {
|
|||
ConstantType beta, DeviceMemory<OutputType> *c,
|
||||
int ldc) {
|
||||
return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
|
||||
ldc, blas::kDefaultComputePrecision);
|
||||
ldc, NumericOptions{});
|
||||
}
|
||||
|
||||
template <typename InputType, typename OutputType>
|
||||
|
|
@ -973,10 +975,9 @@ class Stream {
|
|||
blas::ProfileResult *output_profile_result) {
|
||||
OutputType alpha{1};
|
||||
OutputType beta{0};
|
||||
return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b,
|
||||
ldb, beta, c, ldc, computation_type,
|
||||
algorithm, blas::kDefaultComputePrecision,
|
||||
output_profile_result);
|
||||
return ThenBlasGemmWithAlgorithm(
|
||||
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
|
||||
computation_type, algorithm, NumericOptions{}, output_profile_result);
|
||||
}
|
||||
|
||||
template <typename InputType, typename OutputType, typename ConstantType>
|
||||
|
|
@ -986,7 +987,7 @@ class Stream {
|
|||
const DeviceMemory<InputType> &b, int ldb, ConstantType beta,
|
||||
DeviceMemory<OutputType> *c, int ldc,
|
||||
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
|
||||
blas::ComputePrecision precision,
|
||||
const NumericOptions &numeric_options,
|
||||
blas::ProfileResult *output_profile_result) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
|
||||
|
|
@ -1010,7 +1011,7 @@ class Stream {
|
|||
blas::ToDataType<InputType>::value, lda, b,
|
||||
blas::ToDataType<InputType>::value, ldb, beta_ptr, c,
|
||||
blas::ToDataType<OutputType>::value, ldc, computation_type, algorithm,
|
||||
precision, output_profile_result);
|
||||
numeric_options, output_profile_result);
|
||||
if (output_profile_result) {
|
||||
// The error is recorded in the profile.
|
||||
return ::tsl::OkStatus();
|
||||
|
|
@ -1025,7 +1026,7 @@ class Stream {
|
|||
int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
|
||||
int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
|
||||
int64_t stride_c, int batch_count, blas::ComputationType computation_type,
|
||||
blas::AlgorithmType algorithm, blas::ComputePrecision precision,
|
||||
blas::AlgorithmType algorithm, const NumericOptions &numeric_options,
|
||||
blas::ProfileResult *output_profile_result) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckTypesForExtendedBlas<InputType, OutputType, ConstantType>(
|
||||
|
|
@ -1047,7 +1048,7 @@ class Stream {
|
|||
blas::ToDataType<InputType>::value, lda, stride_a, b,
|
||||
blas::ToDataType<InputType>::value, ldb, stride_b, beta_ptr, c,
|
||||
blas::ToDataType<OutputType>::value, ldc, stride_c, batch_count,
|
||||
computation_type, algorithm, precision, output_profile_result);
|
||||
computation_type, algorithm, numeric_options, output_profile_result);
|
||||
if (output_profile_result) {
|
||||
// The error is recorded in the profile.
|
||||
return ::tsl::OkStatus();
|
||||
|
|
@ -1064,19 +1065,22 @@ class Stream {
|
|||
DeviceMemorySlice<Eigen::half> a, int lda,
|
||||
DeviceMemorySlice<Eigen::half> b, int ldb,
|
||||
float beta, DeviceMemorySlice<Eigen::half> c,
|
||||
int ldc, int batch_count);
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options);
|
||||
Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
|
||||
uint64_t m, uint64 n, uint64 k, float alpha,
|
||||
DeviceMemorySlice<float> a, int lda,
|
||||
DeviceMemorySlice<float> b, int ldb, float beta,
|
||||
DeviceMemorySlice<float> c, int ldc,
|
||||
int batch_count);
|
||||
int batch_count,
|
||||
const NumericOptions &numeric_options);
|
||||
Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
|
||||
uint64_t m, uint64 n, uint64 k, double alpha,
|
||||
DeviceMemorySlice<double> a, int lda,
|
||||
DeviceMemorySlice<double> b, int ldb, double beta,
|
||||
DeviceMemorySlice<double> c, int ldc,
|
||||
int batch_count);
|
||||
int batch_count,
|
||||
const NumericOptions &numeric_options);
|
||||
Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
|
||||
uint64_t m, uint64 n, uint64_t k,
|
||||
std::complex<float> alpha,
|
||||
|
|
@ -1084,27 +1088,28 @@ class Stream {
|
|||
DeviceMemorySlice<std::complex<float>> b, int ldb,
|
||||
std::complex<float> beta,
|
||||
DeviceMemorySlice<std::complex<float>> c, int ldc,
|
||||
int batch_count);
|
||||
Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
|
||||
uint64_t m, uint64 n, uint64_t k,
|
||||
std::complex<double> alpha,
|
||||
DeviceMemorySlice<std::complex<double>> a,
|
||||
int lda,
|
||||
DeviceMemorySlice<std::complex<double>> b,
|
||||
int ldb, std::complex<double> beta,
|
||||
DeviceMemorySlice<std::complex<double>> c,
|
||||
int ldc, int batch_count);
|
||||
int batch_count,
|
||||
const NumericOptions &numeric_options);
|
||||
Stream &ThenBlasGemmBatched(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
uint64_t k, std::complex<double> alpha,
|
||||
DeviceMemorySlice<std::complex<double>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<double>> b, int ldb,
|
||||
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options);
|
||||
Stream &ThenBlasGemmBatchedWithScratch(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
uint64_t k, float alpha, DeviceMemorySlice<Eigen::half> a, int lda,
|
||||
DeviceMemorySlice<Eigen::half> b, int ldb, float beta,
|
||||
DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
Stream &ThenBlasGemmBatchedWithScratch(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
uint64_t k, float alpha, DeviceMemorySlice<Eigen::bfloat16> a, int lda,
|
||||
DeviceMemorySlice<Eigen::bfloat16> b, int ldb, float beta,
|
||||
DeviceMemorySlice<Eigen::bfloat16> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
Stream &ThenBlasGemmBatchedWithScratch(blas::Transpose transa,
|
||||
blas::Transpose transb, uint64_t m,
|
||||
|
|
@ -1113,12 +1118,14 @@ class Stream {
|
|||
DeviceMemorySlice<float> b, int ldb,
|
||||
float beta, DeviceMemorySlice<float> c,
|
||||
int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
Stream &ThenBlasGemmBatchedWithScratch(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
uint64_t k, double alpha, DeviceMemorySlice<double> a, int lda,
|
||||
DeviceMemorySlice<double> b, int ldb, double beta,
|
||||
DeviceMemorySlice<double> c, int ldc, int batch_count,
|
||||
const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
Stream &ThenBlasGemmBatchedWithScratch(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
|
|
@ -1126,14 +1133,16 @@ class Stream {
|
|||
DeviceMemorySlice<std::complex<float>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<float>> b, int ldb,
|
||||
std::complex<float> beta, DeviceMemorySlice<std::complex<float>> c,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator);
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
Stream &ThenBlasGemmBatchedWithScratch(
|
||||
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
|
||||
uint64_t k, std::complex<double> alpha,
|
||||
DeviceMemorySlice<std::complex<double>> a, int lda,
|
||||
DeviceMemorySlice<std::complex<double>> b, int ldb,
|
||||
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
|
||||
int ldc, int batch_count, ScratchAllocator *scratch_allocator);
|
||||
int ldc, int batch_count, const NumericOptions &numeric_options,
|
||||
ScratchAllocator *scratch_allocator);
|
||||
|
||||
template <typename InputType, typename OutputType, typename ConstantType>
|
||||
tsl::Status ThenBlasGemmStridedBatched(
|
||||
|
|
@ -1141,7 +1150,8 @@ class Stream {
|
|||
uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
|
||||
int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
|
||||
int64_t stride_b, ConstantType beta, DeviceMemory<OutputType> *c, int ldc,
|
||||
int64_t stride_c, int batch_count, blas::ComputePrecision precision) {
|
||||
int64_t stride_c, int batch_count,
|
||||
const NumericOptions &numeric_options) {
|
||||
static_assert(
|
||||
detail::is_any_of<InputType, int8_t, float, Eigen::half,
|
||||
Eigen::bfloat16, double, std::complex<float>,
|
||||
|
|
@ -1168,7 +1178,7 @@ class Stream {
|
|||
return blas->DoBlasGemmStridedBatched(
|
||||
this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
|
||||
alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
|
||||
stride_c, batch_count, precision);
|
||||
stride_c, batch_count, numeric_options);
|
||||
}
|
||||
|
||||
// See BlasSupport::DoBlasTrsm.
|
||||
|
|
|
|||
|
|
@ -377,8 +377,8 @@ StreamExecutor::createRnnDescriptor(
|
|||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64_t seed, ScratchAllocator* state_allocator,
|
||||
bool use_padded_io) {
|
||||
const NumericOptions& numeric_options, float dropout, uint64_t seed,
|
||||
ScratchAllocator* state_allocator, bool use_padded_io) {
|
||||
dnn::DnnSupport* dnn_support = AsDnn();
|
||||
if (!dnn_support) {
|
||||
return tsl::Status(absl::StatusCode::kUnknown,
|
||||
|
|
@ -386,8 +386,8 @@ StreamExecutor::createRnnDescriptor(
|
|||
}
|
||||
return dnn_support->createRnnDescriptor(
|
||||
num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
|
||||
direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
|
||||
state_allocator, use_padded_io);
|
||||
direction_mode, rnn_mode, data_type, algorithm_config, numeric_options,
|
||||
dropout, seed, state_allocator, use_padded_io);
|
||||
}
|
||||
|
||||
tsl::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
|
@ -432,8 +433,8 @@ class StreamExecutor {
|
|||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64_t seed, ScratchAllocator* state_allocator,
|
||||
bool use_padded_io);
|
||||
const NumericOptions& numeric_options, float dropout, uint64_t seed,
|
||||
ScratchAllocator* state_allocator, bool use_padded_io);
|
||||
|
||||
// Create a RNN sequence descriptor that specifies either the input or output
|
||||
// sequence. The caller retains the ownership of the returned descriptor.
|
||||
|
|
|
|||
|
|
@ -1215,8 +1215,6 @@ xla_test(
|
|||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/client:padding",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_hdr_lib",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_utils",
|
||||
"//tensorflow/tsl/platform:test",
|
||||
],
|
||||
)
|
||||
|
|
@ -2770,8 +2768,6 @@ xla_test(
|
|||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/compiler/xla/client/lib:matrix",
|
||||
"//tensorflow/tsl/lib/core:status_test_util",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_hdr_lib",
|
||||
"//tensorflow/tsl/platform:tensor_float_32_utils",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,6 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/tsl/lib/core/status_test_util.h"
|
||||
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
|
@ -220,8 +219,6 @@ class RandomCholeskyTest
|
|||
public ::testing::WithParamInterface<CholeskyTestCase> {};
|
||||
|
||||
XLA_TEST_P(RandomCholeskyTest, Real) {
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
auto test_params = GetParam();
|
||||
|
|
@ -259,8 +256,6 @@ XLA_TEST_P(RandomCholeskyTest, Real) {
|
|||
}
|
||||
|
||||
XLA_TEST_P(RandomCholeskyTest, Complex) {
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
auto test_params = GetParam();
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/tsl/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
|
@ -50,16 +49,16 @@ class ConvolutionVariantsTest : public ClientLibraryTestBase {
|
|||
ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-2);
|
||||
#endif
|
||||
|
||||
void SetUp() override {
|
||||
init_tf32_status_ = tsl::tensor_float_32_execution_enabled();
|
||||
tsl::enable_tensor_float_32_execution(false);
|
||||
XlaOp ConvWithHighestPrecision(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64_t> window_strides,
|
||||
Padding padding) {
|
||||
PrecisionConfig precision_config;
|
||||
// Set the 2 operands to have the HIGHEST precision.
|
||||
precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
|
||||
precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
|
||||
return Conv(lhs, rhs, window_strides, padding, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, &precision_config);
|
||||
}
|
||||
void TearDown() override {
|
||||
tsl::enable_tensor_float_32_execution(init_tf32_status_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool init_tf32_status_;
|
||||
};
|
||||
|
||||
XLA_TEST_F(ConvolutionVariantsTest, Minimal) {
|
||||
|
|
@ -626,7 +625,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) {
|
|||
|
||||
auto input = ConstantR4FromArray4D<float>(&builder, input_array);
|
||||
auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
|
||||
Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
ConvWithHighestPrecision(input, filter, {1, 1}, Padding::kValid);
|
||||
|
||||
Array4D<float> expected(16, 16, 1, 1);
|
||||
for (int i0 = 0; i0 < 16; ++i0) {
|
||||
|
|
|
|||
|
|
@ -530,8 +530,9 @@ cc_library(
|
|||
hdrs = ["numeric_options_utils.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/stream_executor:numeric_options",
|
||||
"//tensorflow/core/platform:tensor_float_32_hdr_lib",
|
||||
"//tensorflow/core/util:determinism_for_kernels",
|
||||
],
|
||||
] + if_static(["//tensorflow/core/platform:tensor_float_32_utils"]),
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
|
|
@ -1531,6 +1532,7 @@ tf_kernel_library(
|
|||
deps = [
|
||||
":cast_op",
|
||||
":gpu_utils",
|
||||
":numeric_options_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
|
|
@ -3477,6 +3479,7 @@ tf_kernel_library(
|
|||
":conv_ops_gpu_hdrs",
|
||||
":gpu_utils",
|
||||
":conv_ops",
|
||||
":numeric_options_utils",
|
||||
"//tensorflow/core/protobuf:autotuning_proto_cc",
|
||||
"//tensorflow/core/util/autotune_maps:conv_parameters",
|
||||
"//tensorflow/core/util/proto:proto_utils",
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ limitations under the License.
|
|||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cast_op.h"
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
using stream_executor::dnn::DimIndex;
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
|
|
@ -730,11 +731,10 @@ void LaunchConvBackpropFilterOpImpl(
|
|||
auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
|
||||
filter_backprop->template flat<T>().size());
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose, n, m, k,
|
||||
a_ptr, n, b_ptr, m, &c_ptr, n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(context, stream->ThenBlasGemm(
|
||||
se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose, n, m, k, a_ptr,
|
||||
n, b_ptr, m, &c_ptr, n, GetNumericOptions()));
|
||||
return;
|
||||
} else if (!is_grouped_convolution &&
|
||||
dims.filter_size(0) == dims.input_size(0) &&
|
||||
|
|
@ -753,11 +753,10 @@ void LaunchConvBackpropFilterOpImpl(
|
|||
auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
|
||||
filter_backprop->template flat<T>().size());
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose, n, m, k,
|
||||
b_ptr, n, a_ptr, m, &c_ptr, n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(context, stream->ThenBlasGemm(
|
||||
se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose, n, m, k, b_ptr,
|
||||
n, a_ptr, m, &c_ptr, n, GetNumericOptions()));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ limitations under the License.
|
|||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cast_op.h"
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
|
||||
|
|
@ -260,11 +261,10 @@ void LaunchConv2DBackpropFilterOpImpl(
|
|||
auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
|
||||
filter_backprop->template flat<T>().size());
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose, n, m, k,
|
||||
a_ptr, n, b_ptr, m, &c_ptr, n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose, n,
|
||||
m, k, a_ptr, n, b_ptr, m, &c_ptr,
|
||||
n, GetNumericOptions()));
|
||||
return;
|
||||
} else if (dims.spatial_dims[0].filter_size ==
|
||||
dims.spatial_dims[0].input_size &&
|
||||
|
|
@ -286,11 +286,10 @@ void LaunchConv2DBackpropFilterOpImpl(
|
|||
auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
|
||||
filter_backprop->template flat<T>().size());
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose, n, m, k,
|
||||
b_ptr, n, a_ptr, m, &c_ptr, n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
|
||||
se::blas::Transpose::kTranspose, n,
|
||||
m, k, b_ptr, n, a_ptr, m, &c_ptr,
|
||||
n, GetNumericOptions()));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cast_op.h"
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
|
@ -156,9 +157,9 @@ void LaunchConv2DBackpropInputOpGpuImpl(
|
|||
auto transpose = se::blas::Transpose::kTranspose;
|
||||
auto no_transpose = se::blas::Transpose::kNoTranspose;
|
||||
|
||||
OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(
|
||||
transpose, no_transpose, n, m, k, b_ptr, k, a_ptr,
|
||||
k, &c_ptr, n, se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k,
|
||||
a_ptr, k, &c_ptr, n, GetNumericOptions()));
|
||||
return;
|
||||
} else if (dims.spatial_dims[0].filter_size ==
|
||||
dims.spatial_dims[0].input_size &&
|
||||
|
|
@ -183,9 +184,9 @@ void LaunchConv2DBackpropInputOpGpuImpl(
|
|||
auto transpose = se::blas::Transpose::kTranspose;
|
||||
auto no_transpose = se::blas::Transpose::kNoTranspose;
|
||||
|
||||
OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(
|
||||
transpose, no_transpose, n, m, k, b_ptr, k, a_ptr,
|
||||
k, &c_ptr, n, se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k,
|
||||
a_ptr, k, &c_ptr, n, GetNumericOptions()));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ limitations under the License.
|
|||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cast_op.h"
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
using stream_executor::dnn::DimIndex;
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
|
|
@ -732,10 +733,9 @@ void LaunchConvBackpropInputOpImpl(
|
|||
auto transpose = se::blas::Transpose::kTranspose;
|
||||
auto no_transpose = se::blas::Transpose::kNoTranspose;
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, a_ptr,
|
||||
k, &c_ptr, n, se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m,
|
||||
k, b_ptr, k, a_ptr, k, &c_ptr,
|
||||
n, GetNumericOptions()));
|
||||
return;
|
||||
} else if (!is_grouped_convolution &&
|
||||
dims.filter_size(0) == dims.input_size(0) &&
|
||||
|
|
@ -757,10 +757,9 @@ void LaunchConvBackpropInputOpImpl(
|
|||
auto transpose = se::blas::Transpose::kTranspose;
|
||||
auto no_transpose = se::blas::Transpose::kNoTranspose;
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, a_ptr,
|
||||
k, &c_ptr, n, se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m,
|
||||
k, b_ptr, k, a_ptr, k, &c_ptr,
|
||||
n, GetNumericOptions()));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cast_op.h"
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
|
||||
|
|
@ -289,8 +290,7 @@ void LaunchConvOpImpl(OpKernelContext* ctx, bool cudnn_use_autotune,
|
|||
auto no_transpose = se::blas::Transpose::kNoTranspose;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
|
||||
a_ptr, k, &c_ptr, n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
a_ptr, k, &c_ptr, n, GetNumericOptions()));
|
||||
return;
|
||||
} else if (!is_grouped_convolution && filter_planes == in_planes &&
|
||||
filter_rows == in_rows && filter_cols == in_cols &&
|
||||
|
|
@ -311,8 +311,7 @@ void LaunchConvOpImpl(OpKernelContext* ctx, bool cudnn_use_autotune,
|
|||
auto no_transpose = se::blas::Transpose::kNoTranspose;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
|
||||
a_ptr, k, &c_ptr, n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
a_ptr, k, &c_ptr, n, GetNumericOptions()));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ limitations under the License.
|
|||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/cast_op.h"
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
|
||||
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
|
||||
|
|
@ -487,8 +488,7 @@ void LaunchConv2DOpImpl(OpKernelContext* ctx, bool use_cudnn,
|
|||
auto no_transpose = se::blas::Transpose::kNoTranspose;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
|
||||
a_ptr, k, &c_ptr, n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
a_ptr, k, &c_ptr, n, GetNumericOptions()));
|
||||
return;
|
||||
} else if (patch_rows == in_rows && patch_cols == in_cols &&
|
||||
!is_grouped_convolution && row_dilation == 1 &&
|
||||
|
|
@ -510,8 +510,7 @@ void LaunchConv2DOpImpl(OpKernelContext* ctx, bool use_cudnn,
|
|||
auto no_transpose = se::blas::Transpose::kNoTranspose;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
|
||||
a_ptr, k, &c_ptr, n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
a_ptr, k, &c_ptr, n, GetNumericOptions()));
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/util/use_cudnn.h"
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/util/stream_executor_util.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
|
@ -1296,7 +1297,8 @@ class CudnnRNNKernelCommon : public OpKernel {
|
|||
auto rnn_desc_s = stream->parent()->createRnnDescriptor(
|
||||
num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
|
||||
/*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
|
||||
ToDataType<T>::value, algo_config, dropout(), seed(),
|
||||
ToDataType<T>::value, algo_config, GetNumericOptions(), dropout(),
|
||||
seed(),
|
||||
/* state_allocator=*/nullptr, /*use_padded_io=*/false);
|
||||
if (!rnn_desc_s.ok()) {
|
||||
return FromExecutorStatus(rnn_desc_s);
|
||||
|
|
@ -1321,8 +1323,8 @@ class CudnnRNNKernelCommon : public OpKernel {
|
|||
model_shapes.num_layers, model_shapes.num_units,
|
||||
model_shapes.input_size, model_shapes.cell_num_units,
|
||||
model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
|
||||
data_type, algo_config, dropout(), seed(), dropout_state_allocator,
|
||||
use_padded_io);
|
||||
data_type, algo_config, GetNumericOptions(), dropout(), seed(),
|
||||
dropout_state_allocator, use_padded_io);
|
||||
TF_RETURN_IF_ERROR(rnn_desc_s.status());
|
||||
|
||||
*rnn_desc = std::move(rnn_desc_s).value();
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ limitations under the License.
|
|||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/gpu_utils.h"
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#if GOOGLE_CUDA
|
||||
|
|
@ -525,7 +526,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|||
static_cast<Coefficient>(1.0), b_ptrs,
|
||||
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
|
||||
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
|
||||
&scratch_allocator)
|
||||
GetNumericOptions(), &scratch_allocator)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
|
|
@ -619,12 +620,11 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|||
}
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context,
|
||||
stream->ThenBlasGemm(
|
||||
blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
*(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
|
||||
adj_x || trans_x ? m : k, c_ptrs[0], n,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
OP_REQUIRES_OK(context, stream->ThenBlasGemm(
|
||||
blas_transpose_b, blas_transpose_a, n, m, k,
|
||||
*(b_ptrs[0]), adj_y || trans_y ? k : n,
|
||||
*(a_ptrs[0]), adj_x || trans_x ? m : k,
|
||||
c_ptrs[0], n, GetNumericOptions()));
|
||||
} else if (use_strided_batched) {
|
||||
OP_REQUIRES_OK(
|
||||
context, stream->ThenBlasGemmStridedBatched(
|
||||
|
|
@ -633,7 +633,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|||
adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
|
||||
adj_x || trans_x ? m : k, a_stride,
|
||||
static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
|
||||
batch_size, se::blas::kDefaultComputePrecision));
|
||||
batch_size, GetNumericOptions()));
|
||||
} else {
|
||||
BlasScratchAllocator scratch_allocator(context);
|
||||
bool blas_launch_status =
|
||||
|
|
@ -643,7 +643,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
|
|||
static_cast<Coefficient>(1.0), b_ptrs,
|
||||
adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
|
||||
static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
|
||||
&scratch_allocator)
|
||||
GetNumericOptions(), &scratch_allocator)
|
||||
.ok();
|
||||
if (!blas_launch_status) {
|
||||
context->SetStatus(errors::Internal(
|
||||
|
|
|
|||
|
|
@ -17,12 +17,15 @@ limitations under the License.
|
|||
#define TENSORFLOW_CORE_KERNELS_NUMERIC_OPTIONS_UTILS_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/stream_executor/numeric_options.h"
|
||||
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
|
||||
#include "tensorflow/tsl/util/determinism.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
inline stream_executor::NumericOptions GetNumericOptions() {
|
||||
return stream_executor::NumericOptions{tsl::OpDeterminismRequired()};
|
||||
return stream_executor::NumericOptions{
|
||||
/*require_determinism=*/tsl::OpDeterminismRequired(),
|
||||
/*allow_tf32=*/tsl::tensor_float_32_execution_enabled()};
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ tf_gpu_library(
|
|||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/kernels:eigen_helpers",
|
||||
"//tensorflow/core/kernels:numeric_options_utils",
|
||||
"//tensorflow/core/platform:stream_executor",
|
||||
"//tensorflow/tsl/framework/contraction:eigen_contraction_kernel",
|
||||
"//third_party/eigen3",
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#define EIGEN_USE_THREADS
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#include "tensorflow/core/kernels/numeric_options_utils.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
|
|
@ -53,7 +54,7 @@ void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx, bool transa,
|
|||
ctx, ctx->op_device_context()->stream()->ThenBlasGemm(
|
||||
trans[transa], trans[transb], m, n, k, static_cast<T>(alpha),
|
||||
a_ptr, lda, b_ptr, ldb, static_cast<T>(beta), &c_ptr, ldc,
|
||||
se::blas::kDefaultComputePrecision));
|
||||
GetNumericOptions()));
|
||||
#else
|
||||
ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user