diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 88f5618a8e4..2028362b0ae 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -588,6 +588,7 @@ tf_xla_py_strict_test( "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:config", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:nn_ops", diff --git a/tensorflow/compiler/tests/tensor_float_32_test.py b/tensorflow/compiler/tests/tensor_float_32_test.py index f02b69948f4..68016280c7b 100644 --- a/tensorflow/compiler/tests/tensor_float_32_test.py +++ b/tensorflow/compiler/tests/tensor_float_32_test.py @@ -18,13 +18,14 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.eager import def_function from tensorflow.python.framework import config from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -class TensorFloat32ConvTest(xla_test.XLATestCase): +class TensorFloat32Test(xla_test.XLATestCase): def tearDown(self): super().tearDown() @@ -51,6 +52,13 @@ class TensorFloat32ConvTest(xla_test.XLATestCase): # operand_precision is not in HLO if it's the default value. self.assertNotIn('operand_precision', hlo_text) + # On Ampere GPUs and above, which support TF32, test TF32 is used by + # asserting outputs are not close to FP64 result + if test_util.is_gpu_available(min_cuda_compute_capability=(8, 0)): + out = compiled_fn(*inputs) + f64_out = compiled_fn(*[math_ops.cast(x, 'float64') for x in inputs]) + self.assertNotAllClose(out, f64_out, rtol=1e-5, atol=1e-5) + def test_matmul(self): x = array_ops.fill((1024, 1024), 1 + 2**-12) y = array_ops.fill((1024, 1024), 1.0) @@ -70,8 +78,8 @@ class TensorFloat32ConvTest(xla_test.XLATestCase): self._test_fn(batch_matmul, [x, y]) def test_conv2d(self): - x = array_ops.fill((2, 20, 20, 32), 1 + 2**-12) - y = array_ops.fill((3, 3, 32, 32), 1.0) + x = array_ops.fill((16, 40, 40, 64), 1 + 2**-12) + y = array_ops.fill((3, 3, 64, 64), 1.0) def conv2d(x, y): return nn_ops.conv2d(x, y, [1, 1, 1, 1], padding='SAME') @@ -79,23 +87,23 @@ class TensorFloat32ConvTest(xla_test.XLATestCase): self._test_fn(conv2d, [x, y]) def test_conv2d_backprop_input(self): - y = array_ops.fill((3, 3, 32, 32), 1 + 2**-12) - out_backprop = array_ops.fill((2, 20, 20, 32), 1.0) + y = array_ops.fill((3, 3, 64, 64), 1 + 2**-12) + out_backprop = array_ops.fill((16, 40, 40, 64), 1.0) def conv2d_backprop_input(y, out_backprop): return nn_ops.conv2d_backprop_input( - (2, 20, 20, 32), y, out_backprop, [1, 1, 1, 1], padding='SAME' + (16, 40, 40, 64), y, out_backprop, [1, 1, 1, 1], padding='SAME' ) self._test_fn(conv2d_backprop_input, [y, out_backprop]) def test_conv2d_backprop_filter(self): - x = array_ops.fill((2, 20, 20, 32), 1 + 2**-12) - out_backprop = array_ops.fill((2, 20, 20, 32), 1.0) + x = array_ops.fill((16, 40, 40, 64), 1 + 2**-12) + out_backprop = array_ops.fill((16, 40, 40, 64), 1.0) def conv2d_backprop_filter(x, out_backprop): return nn_ops.conv2d_backprop_filter( - x, (3, 3, 32, 32), out_backprop, [1, 1, 1, 1], padding='SAME' + x, (3, 3, 64, 64), out_backprop, [1, 1, 1, 1], padding='SAME' ) self._test_fn(conv2d_backprop_filter, [x, out_backprop]) @@ -103,4 +111,7 @@ class 
TensorFloat32ConvTest(xla_test.XLATestCase): if __name__ == '__main__': ops.enable_eager_execution() + # Enable determinism, since otherwise the autotuner may nondeterministically + # choose either a TF32 or non-TF32 algorithm + config.enable_op_determinism() googletest.main() diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index bc7ba74322a..d5c708e4b76 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -345,8 +345,6 @@ xla_test( "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", - "//tensorflow/tsl/platform:tensor_float_32_utils", "//tensorflow/tsl/platform:test", ], ) @@ -558,7 +556,6 @@ xla_test( "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "//tensorflow/tsl/platform:test", ], ) @@ -611,8 +608,6 @@ xla_test( "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", - "//tensorflow/tsl/platform:tensor_float_32_utils", "//tensorflow/tsl/platform:test", ], ) diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc index 485d3d4a9e6..ba4a8e9a883 100644 --- a/tensorflow/compiler/xla/client/lib/qr_test.cc +++ b/tensorflow/compiler/xla/client/lib/qr_test.cc @@ -29,16 +29,12 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/lib/core/status_test_util.h" -#include "tensorflow/tsl/platform/tensor_float_32_utils.h" namespace { using QrTest = xla::ClientLibraryTestBase; XLA_TEST_F(QrTest, Simple) { - // Test fails with TensorFloat-32 enabled - tsl::enable_tensor_float_32_execution(false); - xla::Array2D data({ {4, 6, 8, 10}, {6, 45, 54, 63}, @@ -79,8 +75,6 @@ XLA_TEST_F(QrTest, Simple) { } XLA_TEST_F(QrTest, ZeroDiagonal) { - // Test fails with TensorFloat-32 enabled - tsl::enable_tensor_float_32_execution(false); xla::XlaBuilder builder(TestName()); xla::Array2D a_vals({ @@ -106,8 +100,6 @@ XLA_TEST_F(QrTest, ZeroDiagonal) { } XLA_TEST_F(QrTest, SimpleBatched) { - // Test fails with TensorFloat-32 enabled - tsl::enable_tensor_float_32_execution(false); xla::XlaBuilder builder(TestName()); xla::Array3D a_vals({ @@ -136,8 +128,6 @@ XLA_TEST_F(QrTest, SimpleBatched) { } XLA_TEST_F(QrTest, SubnormalComplex) { - tsl::enable_tensor_float_32_execution(false); - // Verifies that we don't get NaNs in the case that the norm of a complex // number would be denormal but its imaginary value is not exactly 0. xla::Array2D a_vals({ diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc index 6e48ae35cfa..95ff6c0781a 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc @@ -35,9 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/lib/core/status_test_util.h" -#if GOOGLE_CUDA -#include "tensorflow/tsl/platform/tensor_float_32_utils.h" -#endif namespace xla { @@ -77,18 +74,8 @@ class SelfAdjointEigTest : public ClientLibraryTestBase { {4, 5, 10, 11}, {3, 9, 11, 17}, }; - -#if GOOGLE_CUDA - tf32_init_state_ = tsl::tensor_float_32_execution_enabled(); - tsl::enable_tensor_float_32_execution(false); -#endif - } - void TearDown() override { - ClientLibraryTestBase::TearDown(); -#if GOOGLE_CUDA - tsl::enable_tensor_float_32_execution(tf32_init_state_); -#endif } + void TearDown() override { ClientLibraryTestBase::TearDown(); } Array3D GetUnitMatrix3D(const Array3D& matrix) { Array3D result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0); @@ -123,7 +110,6 @@ class SelfAdjointEigTest : public ClientLibraryTestBase { Array2D matrix2d_8x8_; Array2D low_rank_4x4_; Array2D wrong_type_4x4_; - bool tf32_init_state_; }; XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) { diff --git a/tensorflow/compiler/xla/client/lib/svd_test.cc b/tensorflow/compiler/xla/client/lib/svd_test.cc index 034771d2fb6..d3be1f907bc 100644 --- a/tensorflow/compiler/xla/client/lib/svd_test.cc +++ b/tensorflow/compiler/xla/client/lib/svd_test.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/lib/core/status_test_util.h" -#include "tensorflow/tsl/platform/tensor_float_32_utils.h" namespace xla { @@ -57,9 +56,6 @@ class SVDTest : public ClientLibraryTestBase { {12, 48, 6, 62, 3}, }, }; - - // Test fails with TensorFloat-32 enabled - tsl::enable_tensor_float_32_execution(false); } void TearDown() override { ClientLibraryTestBase::TearDown(); } diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc index 28a9550b1ae..8753c0c7bdb 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc @@ -4602,6 +4602,9 @@ PrecisionConfig* HloInstruction::mutable_precision_config() { if (auto* dot = DynCast(this)) { return dot->mutable_precision_config(); } + if (auto* custom_call = DynCast(this)) { + return custom_call->mutable_precision_config(); + } LOG(FATAL) << "Unimplemented method."; } diff --git a/tensorflow/compiler/xla/service/gpu/conv_layout_normalization.cc b/tensorflow/compiler/xla/service/gpu/conv_layout_normalization.cc index 1fea4134e4b..90b3ef13447 100644 --- a/tensorflow/compiler/xla/service/gpu/conv_layout_normalization.cc +++ b/tensorflow/compiler/xla/service/gpu/conv_layout_normalization.cc @@ -146,6 +146,7 @@ StatusOr UpdateLayoutForCudnnConvolution( normalized_conv->set_feature_group_count(hlo->feature_group_count()); normalized_conv->set_raw_backend_config_string( hlo->raw_backend_config_string()); + *normalized_conv->mutable_precision_config() = hlo->precision_config(); normalized_conv->parent()->parent()->SetAndUniquifyInstrName(normalized_conv, hlo->name()); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index d6936053cf1..8ca8042d77a 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -292,6 +292,7 @@ StatusOr> DoGemmAutotune( const DebugOptions& debug_options = 
gemm->GetModule()->config().debug_options(); AutotuneConfig autotune_config = GetConfig(debug_options); + const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops(); TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm)); // Don't run autotuning concurrently on the same GPU. @@ -395,9 +396,10 @@ StatusOr> DoGemmAutotune( // should always return true, and the actual // success-ness is returned in // ProfileResult::is_valid. - TF_RETURN_IF_ERROR(RunGemm( - config, lhs_buffer, rhs_buffer, output_buffer, - stream, algorithm, &profile_result)); + TF_RETURN_IF_ERROR( + RunGemm(config, lhs_buffer, rhs_buffer, + output_buffer, deterministic_ops, + stream, algorithm, &profile_result)); return std::move(profile_result); })); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index 025ba88f23b..bd9a905ad8b 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -623,12 +623,17 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { instr, m::Convert(GemmOrCublasLtMatmul(&existing_gemm).WithOneUser())) && existing_gemm->operands().size() == 2) { + TF_ASSIGN_OR_RETURN(GemmBackendConfig gemm_backend_config, + existing_gemm->backend_config()); + // check if type combination is supported here TF_ASSIGN_OR_RETURN( bool types_are_supported, IsLegacyCublasMatmul(*existing_gemm) - ? TypesAreSupportedByLegacyCublas(*existing_gemm, instr) - : TypesAreSupportedByCublasLt(*existing_gemm, instr)); + ? TypesAreSupportedByLegacyCublas(*existing_gemm, + gemm_backend_config, instr) + : TypesAreSupportedByCublasLt(*existing_gemm, + gemm_backend_config, instr)); if (types_are_supported) { return FuseMatrixConvert(existing_gemm, instr); } @@ -1392,7 +1397,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } StatusOr TypesAreSupportedByLegacyCublas( - const HloInstruction &instr, const HloInstruction *bias = nullptr) const { + const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config, + const HloInstruction *bias = nullptr) const { // Figure out the Atype/Btype. const PrimitiveType a_dtype = instr.operand(0)->shape().element_type(); const PrimitiveType b_dtype = instr.operand(1)->shape().element_type(); @@ -1478,7 +1484,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } StatusOr TypesAreSupportedByCublasLt( - const HloInstruction &instr, const HloInstruction *bias = nullptr) const { + const HloInstruction &instr, const GemmBackendConfig &backend_config, + const HloInstruction *bias = nullptr) const { // Figure out the Atype/Btype. const PrimitiveType a_dtype = instr.operand(0)->shape().element_type(); const PrimitiveType b_dtype = instr.operand(1)->shape().element_type(); @@ -1494,10 +1501,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // Figure out the computeType and scaleType. 
TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, AsBlasDataType(output_type)); - TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type, - GetBlasComputationType( - a_dtype, output_type, - stream_executor::blas::kDefaultComputePrecision)); + int max_precision = *absl::c_max_element( + backend_config.precision_config().operand_precision()); + TF_ASSIGN_OR_RETURN( + const se::blas::ComputationType compute_type, + GetBlasComputationType(a_dtype, instr.shape().element_type(), + max_precision)); se::blas::DataType scale_type = cublas_lt::GetScaleType(output_dtype, compute_type); @@ -1641,8 +1650,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { const HloInstruction *rhs = instr.operand(1); const Shape &output_shape = instr.shape(); - TF_ASSIGN_OR_RETURN(bool types_are_supported_by_cublas_lt, - TypesAreSupportedByCublasLt(instr)); + TF_ASSIGN_OR_RETURN( + bool types_are_supported_by_cublas_lt, + TypesAreSupportedByCublasLt(instr, gemm_backend_config)); if (!types_are_supported_by_cublas_lt) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index b5d35bf1ca2..4218dad8b47 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -29,19 +29,22 @@ namespace gpu { GemmThunk::GemmThunk(ThunkInfo thunk_info, GemmConfig config, const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, - const BufferAllocation::Slice& output_buffer) + const BufferAllocation::Slice& output_buffer, + bool deterministic) : Thunk(Kind::kGemm, thunk_info), config_(std::move(config)), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), - output_buffer_(output_buffer) {} + output_buffer_(output_buffer), + deterministic_(deterministic) {} Status GemmThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(3) << "Running GEMM thunk"; const BufferAllocations& allocs = *params.buffer_allocations; return RunGemm(config_, allocs.GetDeviceAddress(lhs_buffer_), allocs.GetDeviceAddress(rhs_buffer_), - allocs.GetDeviceAddress(output_buffer_), params.stream); + allocs.GetDeviceAddress(output_buffer_), deterministic_, + params.stream); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 500a02559be..bf0bc40d349 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -32,7 +32,7 @@ class GemmThunk : public Thunk { GemmThunk(ThunkInfo thunk_info, GemmConfig config, const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, - const BufferAllocation::Slice& output_buffer); + const BufferAllocation::Slice& output_buffer, bool deterministic); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -44,6 +44,8 @@ class GemmThunk : public Thunk { const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; const BufferAllocation::Slice output_buffer_; + // Whether to run deterministically. 
+ const bool deterministic_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 3df71da9cff..f2bc98370a1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -112,7 +112,7 @@ StatusOr> ScratchAllocator::AllocateBytes( StatusOr> GetAlgorithms( const GpuConvConfig& config, se::Stream* stream, bool use_cudnn_frontend, - bool use_fallback, bool deterministic_ops) { + bool use_fallback, const se::NumericOptions& numeric_options) { TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, GetDNNConvKindFromCudnnConvKind(config.kind)); @@ -146,8 +146,7 @@ StatusOr> GetAlgorithms( /* leakyrelu_alpha = */ 0.0, stream, config.input_descriptor, config.filter_descriptor, config.bias_descriptor, config.output_descriptor, config.conv_desc, use_fallback, - config.fusion->mode, se::NumericOptions{deterministic_ops}, - &runners)); + config.fusion->mode, numeric_options, &runners)); for (auto& runner : runners) { TF_ASSIGN_OR_RETURN( auto runner_cache, @@ -172,8 +171,7 @@ StatusOr> GetAlgorithms( /* filter_data = */ DeviceMemoryBase(nullptr), config.output_descriptor, /* output_data = */ DeviceMemoryBase(nullptr), config.conv_desc, - use_fallback, nullptr, se::NumericOptions{deterministic_ops}, - &runners)); + use_fallback, nullptr, numeric_options, &runners)); for (auto& runner : runners) { TF_ASSIGN_OR_RETURN( auto runner_cache, @@ -194,7 +192,7 @@ GetMIOpenAlgorithms(const HloCustomCallInstruction* instr, se::DeviceMemoryBase result_buffer, se::StreamExecutor* stream_exec, ScratchAllocator* scratch_allocator, se::Stream* stream, - bool deterministic_ops) { + const se::NumericOptions& numeric_options) { TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr)); TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind, @@ -213,7 +211,7 @@ GetMIOpenAlgorithms(const HloCustomCallInstruction* instr, params.config->filter_descriptor, params.filter_buf, params.config->output_descriptor, params.output_buf, params.config->conv_desc, /* use_fallback = */ false, scratch_allocator, - se::NumericOptions{deterministic_ops}, &runners)); + numeric_options, &runners)); return runners; } @@ -828,6 +826,15 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( const bool cudnn_frontend_enabled = debug_options.xla_gpu_enable_cudnn_frontend(); const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops(); + bool allow_tf32 = true; + // TODO(b/284371623): Properly set allow_tf32 even if instr==nullptr, which is + // the case when running an AOT compiled executable with runtime autotuning. + if (instr) { + allow_tf32 = absl::c_all_of( + instr->precision_config().operand_precision(), + [](int precision) { return precision <= PrecisionConfig::HIGH; }); + } + const se::NumericOptions numeric_options{deterministic_ops, allow_tf32}; // Use the first algorithm that's supported as reference. There isn't a // particular reason to use it, as any algorithm suffices. 
It doesn't make @@ -838,7 +845,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( std::vector runners, GetAlgorithms(runtime_arguments.gpu_conv_config, stream, cudnn_frontend_enabled, - /* use_fallback = */ false, deterministic_ops)); + /* use_fallback = */ false, numeric_options)); std::vector profile_results; for (auto& runner_cache : runners) { @@ -862,7 +869,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( std::vector fallback_runners, GetAlgorithms(runtime_arguments.gpu_conv_config, stream, cudnn_frontend_enabled, - /* use_fallback = */ true, deterministic_ops)); + /* use_fallback = */ true, numeric_options)); for (auto& runner_cache : fallback_runners) { TF_ASSIGN_OR_RETURN( @@ -961,6 +968,10 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( const DebugOptions& debug_options = instr->GetModule()->config().debug_options(); const bool deterministic_ops = debug_options.xla_gpu_deterministic_ops(); + const bool allow_tf32 = absl::c_all_of( + instr->precision_config().operand_precision(), + [](int precision) { return precision <= PrecisionConfig::HIGH; }); + const se::NumericOptions numeric_options{deterministic_ops, allow_tf32}; se::StreamExecutor* stream_exec = std::get(config_).stream_exec; const auto device_ordinal = stream_exec->device_ordinal(); @@ -998,7 +1009,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( std::vector> runners, GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers), result_buffer, stream_exec, &scratch_allocator, stream, - deterministic_ops)); + numeric_options)); std::vector profile_results; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc index cf9aec202da..bee7667cdeb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc @@ -560,6 +560,7 @@ HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape, const Window& window, const ConvolutionDimensionNumbers& dnums, int64_t feature_group_count, + const PrecisionConfig& precision_config, const OpMetadata& metadata) { HloComputation* computation = lhs->parent(); @@ -579,6 +580,7 @@ HloInstruction* CreateGpuConv(absl::string_view call_target, const Shape& shape, custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); custom_call->set_feature_group_count(feature_group_count); + *custom_call->mutable_precision_config() = precision_config; custom_call->set_metadata(metadata); // Give the customcall a user-friendly name. @@ -665,14 +667,16 @@ static StatusOr CreateCustomCallHelper(HloInstruction* conv) { auto& [window, dnums, rhs] = *m; return CreateGpuConv(kCudnnConvBackwardInputCallTarget, conv->shape(), conv->mutable_operand(0), rhs, window, dnums, - conv->feature_group_count(), conv->metadata()); + conv->feature_group_count(), conv->precision_config(), + conv->metadata()); } if (ConvolutionMatch m = MatchBackwardFilter(conv)) { auto& [window, dnums, lhs] = *m; return CreateGpuConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), lhs, conv->mutable_operand(1), window, dnums, - conv->batch_group_count(), conv->metadata()); + conv->batch_group_count(), conv->precision_config(), + conv->metadata()); } // If all else fails, try a forward convolution. 
@@ -684,7 +688,8 @@ static StatusOr CreateCustomCallHelper(HloInstruction* conv) { return CreateGpuConv(kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), conv->window(), conv->convolution_dimension_numbers(), - conv->feature_group_count(), conv->metadata()); + conv->feature_group_count(), conv->precision_config(), + conv->metadata()); } return nullptr; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 720f682ca2b..f06591cdc0e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1143,10 +1143,12 @@ Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) { TF_ASSIGN_OR_RETURN(auto a, GetAllocationSlice(gemm.getA())); TF_ASSIGN_OR_RETURN(auto b, GetAllocationSlice(gemm.getB())); TF_ASSIGN_OR_RETURN(auto c, GetAllocationSlice(gemm.getC())); + bool deterministic_ops = + hlo_module_config_.debug_options().xla_gpu_deterministic_ops(); TF_ASSIGN_OR_RETURN(GemmConfig config, GemmConfig::For(gemm)); - auto thunk = - std::make_unique(GetThunkInfo(op), std::move(config), a, b, c); + auto thunk = std::make_unique(GetThunkInfo(op), std::move(config), + a, b, c, deterministic_ops); AddThunkToThunkSequence(std::move(thunk)); return OkStatus(); diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc index f09262bf9f7..99d31cf1f5e 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/stream_executor/blas.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/numeric_options.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -551,6 +552,7 @@ Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k, Scale beta, se::Stream* stream, se::blas::AlgorithmType algorithm, se::blas::ComputePrecision compute_precision, + const se::NumericOptions& numeric_options, se::blas::ProfileResult* profile_result) { CHECK(output.transpose == se::blas::Transpose::kNoTranspose); PrimitiveType lhs_type = primitive_util::NativeToPrimitiveType(); @@ -566,13 +568,13 @@ Status DoGemmWithAlgorithm(int64_t batch_size, int64_t m, int64_t n, int64_t k, lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, output.leading_dim_stride, output.batch_stride, batch_size, - computation_type, algorithm, compute_precision, profile_result); + computation_type, algorithm, numeric_options, profile_result); } else { return stream->ThenBlasGemmWithAlgorithm( lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, &output_data, output.leading_dim_stride, computation_type, algorithm, - compute_precision, profile_result); + numeric_options, profile_result); } } @@ -583,6 +585,7 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, se::Stream* stream, std::optional algorithm, se::blas::ComputePrecision compute_precision, + const se::NumericOptions& numeric_options, se::blas::ProfileResult* profile_result) { CHECK(output.transpose == 
se::blas::Transpose::kNoTranspose); se::DeviceMemory output_data(output.data); @@ -591,7 +594,7 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, if (algorithm) { return DoGemmWithAlgorithm( batch_size, m, n, k, lhs, rhs, output, alpha, beta, stream, *algorithm, - compute_precision, profile_result); + compute_precision, numeric_options, profile_result); } #endif @@ -601,20 +604,21 @@ Status DoGemm(int64_t batch_size, int64_t m, int64_t n, int64_t k, lhs.leading_dim_stride, lhs.batch_stride, rhs.cast(), rhs.leading_dim_stride, rhs.batch_stride, beta, &output_data, output.leading_dim_stride, output.batch_stride, batch_size, - compute_precision); + numeric_options); } return stream->ThenBlasGemm( lhs.transpose, rhs.transpose, m, n, k, alpha, lhs.cast(), lhs.leading_dim_stride, rhs.cast(), rhs.leading_dim_stride, beta, - &output_data, output.leading_dim_stride, compute_precision); + &output_data, output.leading_dim_stride, numeric_options); } } // namespace Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, se::Stream* stream, + se::DeviceMemoryBase output_buffer, bool deterministic_ops, + se::Stream* stream, std::optional algorithm, se::blas::ProfileResult* profile_result) { VLOG(2) << "Executing a GemmThunk"; @@ -635,6 +639,9 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, MatrixDescriptor rhs = GetMatrixDesc(rhs_layout, rhs_buffer); MatrixDescriptor output = GetMatrixDesc(output_layout, output_buffer); int64_t batch_size = output_layout.batch_size; + se::NumericOptions numeric_options{ + deterministic_ops, + /*allow_tf32=*/config.compute_precision <= 1}; if (!algorithm) algorithm = config.algorithm; @@ -651,7 +658,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, batch_size, m, n, k, lhs, rhs, output, \ static_cast(config.alpha.real()), \ static_cast(config.beta), stream, algorithm, \ - config.compute_precision, profile_result); \ + config.compute_precision, numeric_options, profile_result); \ } #define TYPED_GEMM_COMPLEX(SCALENTYPE, ATYPE, BTYPE, CTYPE) \ @@ -664,7 +671,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, batch_size, m, n, k, lhs, rhs, output, \ static_cast(config.alpha), \ static_cast(config.beta), stream, algorithm, \ - config.compute_precision, profile_result); \ + config.compute_precision, numeric_options, profile_result); \ } if (output_layout.dtype == S32) { @@ -673,7 +680,7 @@ Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, batch_size, m, n, k, lhs, rhs, output, static_cast(config.alpha.real()), static_cast(config.beta), stream, *algorithm, - se::blas::kDefaultComputePrecision, profile_result); + se::blas::kDefaultComputePrecision, numeric_options, profile_result); } TYPED_GEMM(F32, BF16, BF16, BF16) diff --git a/tensorflow/compiler/xla/service/gpu/matmul_utils.h b/tensorflow/compiler/xla/service/gpu/matmul_utils.h index 6fb7d252a25..3afe513fc74 100644 --- a/tensorflow/compiler/xla/service/gpu/matmul_utils.h +++ b/tensorflow/compiler/xla/service/gpu/matmul_utils.h @@ -139,7 +139,8 @@ se::blas::DataType GetScaleType(se::blas::DataType c_type, // If `algorithm` is provided, it overrides the one specified in `config`. 
Status RunGemm(const GemmConfig& config, se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, - se::DeviceMemoryBase output_buffer, se::Stream* stream, + se::DeviceMemoryBase output_buffer, bool deterministic_ops, + se::Stream* stream, std::optional algorithm = std::nullopt, se::blas::ProfileResult* profile_result = nullptr); diff --git a/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc b/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc index 268149ba534..0cf4c4dfa38 100644 --- a/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc +++ b/tensorflow/compiler/xla/service/gpu/runtime/gemm.cc @@ -53,6 +53,7 @@ Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config, VLOG(3) << "Running GEMM runtime autotuning"; std::vector algorithms; stream->parent()->GetBlasGemmAlgorithms(stream, &algorithms); + const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops(); // Set autotune_level to 3 to disable correctness checking, which avoids // memory allocation during runtime. @@ -84,8 +85,8 @@ Status DoRuntimeAutotuning(se::Stream* stream, GemmConfig& config, // we pass a non-null ProfileResult, DoGemmWithAlgorithm should // always return true, and the actual success-ness is returned in // ProfileResult::is_valid. - TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, stream, algorithm, - &profile_result)); + TF_RETURN_IF_ERROR(RunGemm(config, lhs, rhs, out, deterministic_ops, + stream, algorithm, &profile_result)); return std::move(profile_result); })); @@ -110,6 +111,7 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs); se::DeviceMemoryBase rhs_data = GetDeviceAddress(rhs); se::DeviceMemoryBase output_data = GetDeviceAddress(out); + const bool deterministic_ops = debug_options->xla_gpu_deterministic_ops(); VLOG(3) << "Running GEMM"; se::Stream* stream = run_options->stream(); @@ -143,7 +145,8 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options, #endif } - return RunGemm(*gemm_config, lhs_data, rhs_data, output_data, stream); + return RunGemm(*gemm_config, lhs_data, rhs_data, output_data, + deterministic_ops, stream); } XLA_RUNTIME_DEFINE_CUSTOM_CALL( diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 7fc54c3e76d..0ef2bdd0744 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -220,7 +220,6 @@ xla_cc_test( "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", "@com_google_absl//absl/strings", @@ -892,6 +891,25 @@ xla_cc_test( ]), ) +xla_test( + name = "tensor_float_32_global_var_test", + srcs = ["tensor_float_32_global_var_test.cc"], + backend_tags = {"gpu": [ + "requires-gpu-nvidia", + "requires-gpu-sm80-only", + ]}, + backends = [ + "gpu", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/tsl/platform:tensor_float_32_utils", + "//tensorflow/tsl/platform:test_main", + ], +) + xla_cc_test( name = "gpu_fused_mha_test", srcs = ["gpu_fused_mha_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 6afb1d2ee96..23ff37f4ee7 
100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/tsl/lib/core/status_test_util.h" -#include "tensorflow/tsl/platform/tensor_float_32_utils.h" #include "tensorflow/tsl/platform/test.h" namespace xla { @@ -52,16 +51,6 @@ class GemmRewriteTest : public GpuCodegenTest { ->GetDeviceDescription() .cuda_compute_capability(); } - void SetUp() override { - tf32_state_ = tsl::tensor_float_32_execution_enabled(); - tsl::enable_tensor_float_32_execution(false); - } - void TearDown() override { - tsl::enable_tensor_float_32_execution(tf32_state_); - } - - private: - bool tf32_state_; }; TEST_F(GemmRewriteTest, CheckCustomCallTarget) { @@ -1125,7 +1114,7 @@ HloModule test ENTRY test { Arg_0.1 = f16[4,3]{1,0} parameter(0) Arg_1.2 = f16[3,6]{1,0} parameter(1) - ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT dot.3 = f32[4,6]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest, highest} } )"; @@ -6308,8 +6297,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { } )"; - bool tf32_state_ = tsl::tensor_float_32_execution_enabled(); - tsl::enable_tensor_float_32_execution(true); CheckFp8IfOnHopper(hlo_text); RunAndFilecheckHloRewrite(hlo_text, @@ -6318,7 +6305,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); - tsl::enable_tensor_float_32_execution(tf32_state_); } INSTANTIATE_TEST_SUITE_P(Fp8CublasTestsBothLegacyAndLt, diff --git a/tensorflow/compiler/xla/service/gpu/tests/tensor_float_32_global_var_test.cc b/tensorflow/compiler/xla/service/gpu/tests/tensor_float_32_global_var_test.cc new file mode 100644 index 00000000000..b56571db89d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/tensor_float_32_global_var_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/tsl/platform/tensor_float_32_utils.h" + +namespace xla { +namespace gpu { +namespace { + +// Test that setting the TensorFloat-32 global variable to false causes +// TensorFloat-32 not to be used, even when the operand precision is set to the +// default. +// TODO(b/280130359): Have XLA ignore the TensorFloat-32 global variable +class TensorFloat32GlobalVarTest : public HloTestBase { + protected: + TensorFloat32GlobalVarTest() { + // The error tolerances are small enough so that the use of TF32 will cause + // the error to be greater than the tolerances. 
+ error_spec_ = ErrorSpec{1e-4, 1e-4}; + } +}; + +TEST_F(TensorFloat32GlobalVarTest, Dot) { + tsl::enable_tensor_float_32_execution(false); + const char* hlo_text = R"( +HloModule TestModule + +ENTRY %dot_computation (x: f32[1024,1024], source: f32[1024,1024]) -> f32[1024,1024] { + %x = f32[1024,1024] parameter(0) + %y = f32[1024,1024] parameter(1) + ROOT %result = f32[1024,1024] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default, default} +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, error_spec_)); +} + +TEST_F(TensorFloat32GlobalVarTest, Convolution) { + tsl::enable_tensor_float_32_execution(false); + const char* hlo_text = R"( +HloModule TestModule + +ENTRY %conv_computation (x: f32[16,40,40,64], source: f32[3,3,64,64]) -> f32[16,40,40,64] { + %x = f32[16,40,40,64] parameter(0) + %y = f32[3,3,64,64] parameter(1) + ROOT %result = f32[16,40,40,64] convolution(x, y), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, operand_precision={default, default} +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, error_spec_)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/stream_executor/blas.h b/tensorflow/compiler/xla/stream_executor/blas.h index 9774a3f6d72..7fd5e05edaa 100644 --- a/tensorflow/compiler/xla/stream_executor/blas.h +++ b/tensorflow/compiler/xla/stream_executor/blas.h @@ -41,11 +41,15 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_BLAS_H_ #include +#include +#include +#include #include #include "absl/types/span.h" #include "tensorflow/compiler/xla/stream_executor/data_type.h" #include "tensorflow/compiler/xla/stream_executor/device_memory.h" +#include "tensorflow/compiler/xla/stream_executor/numeric_options.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" #include "tensorflow/tsl/platform/statusor.h" #include "tensorflow/tsl/protobuf/dnn.pb.h" @@ -314,7 +318,7 @@ class BlasSupport { const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, - ComputePrecision precision) = 0; + const NumericOptions &numeric_options) = 0; // Gets a list of supported algorithms for DoBlasGemmWithAlgorithm. virtual bool GetBlasGemmAlgorithms( @@ -338,7 +342,7 @@ class BlasSupport { const DeviceMemoryBase &b, DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc, ComputationType computation_type, AlgorithmType algorithm, - blas::ComputePrecision precision, + const NumericOptions &numeric_options, ProfileResult *output_profile_result) = 0; virtual tsl::Status DoBlasGemmStridedBatchedWithAlgorithm( @@ -348,7 +352,7 @@ class BlasSupport { const DeviceMemoryBase &b, DataType type_b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc, int64_t stride_c, int batch_count, ComputationType computation_type, - AlgorithmType algorithm, blas::ComputePrecision precision, + AlgorithmType algorithm, const NumericOptions &numeric_options, ProfileResult *output_profile_result) = 0; // Computes a batch of matrix-matrix product with general matrices. 
@@ -362,6 +366,7 @@ class BlasSupport { DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, @@ -371,6 +376,7 @@ class BlasSupport { float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, @@ -379,6 +385,7 @@ class BlasSupport { DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, @@ -387,6 +394,7 @@ class BlasSupport { DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, @@ -394,14 +402,16 @@ class BlasSupport { DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, uint64 k, std::complex alpha, DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0; + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator) = 0; // Batched gemm with strides instead of pointer arrays. virtual tsl::Status DoBlasGemmStridedBatched( @@ -410,7 +420,7 @@ class BlasSupport { const DeviceMemoryBase &a, int lda, int64_t stride_a, const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, - blas::ComputePrecision precision) = 0; + const NumericOptions &numeric_options) = 0; // Solves a triangular matrix equation. 
// @@ -560,7 +570,7 @@ class BlasSupport { uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, \ const void *beta, DeviceMemoryBase *c, int ldc, \ - blas::ComputePrecision precision) override; \ + const NumericOptions &numeric_options) override; \ bool GetBlasGemmAlgorithms(Stream *stream, \ std::vector *out_algorithms) \ override; \ @@ -571,7 +581,7 @@ class BlasSupport { const DeviceMemoryBase &b, blas::DataType type_b, int ldb, \ const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, \ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ - blas::ComputePrecision precision, \ + const NumericOptions &numeric_options, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ @@ -579,6 +589,7 @@ class BlasSupport { DeviceMemorySlice a, int lda, \ DeviceMemorySlice b, int ldb, float beta, \ DeviceMemorySlice c, int ldc, int batch_count, \ + const NumericOptions &numeric_options, \ ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ @@ -586,40 +597,45 @@ class BlasSupport { DeviceMemorySlice a, int lda, \ DeviceMemorySlice b, int ldb, float beta, \ DeviceMemorySlice c, int ldc, int batch_count, \ + const NumericOptions &numeric_options, \ ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, float alpha, DeviceMemorySlice a, \ int lda, DeviceMemorySlice b, int ldb, float beta, \ DeviceMemorySlice c, int ldc, int batch_count, \ + const NumericOptions &numeric_options, \ ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, double alpha, \ DeviceMemorySlice a, int lda, DeviceMemorySlice b, \ int ldb, double beta, DeviceMemorySlice c, int ldc, \ - int batch_count, ScratchAllocator *scratch_allocator) override; \ + int batch_count, const NumericOptions &numeric_options, \ + ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, std::complex alpha, \ DeviceMemorySlice> a, int lda, \ DeviceMemorySlice> b, int ldb, \ std::complex beta, DeviceMemorySlice> c, \ - int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ + int ldc, int batch_count, const NumericOptions &numeric_options, \ + ScratchAllocator *scratch_allocator) override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, std::complex alpha, \ DeviceMemorySlice> a, int lda, \ DeviceMemorySlice> b, int ldb, \ std::complex beta, DeviceMemorySlice> c, \ - int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ + int ldc, int batch_count, const NumericOptions &numeric_options, \ + ScratchAllocator *scratch_allocator) override; \ tsl::Status DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ const DeviceMemoryBase &a, int lda, int64_t stride_a, \ const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, \ DeviceMemoryBase *c, int ldc, int64_t 
stride_c, int batch_count, \ - blas::ComputePrecision precision) override; \ + const NumericOptions &numeric_options) override; \ tsl::Status DoBlasGemmStridedBatchedWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ uint64_t m, uint64 n, uint64 k, const void *alpha, \ @@ -628,7 +644,7 @@ class BlasSupport { int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, \ blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, \ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ - blas::ComputePrecision precision, \ + const NumericOptions &numeric_options, \ blas::ProfileResult *output_profile_result) override; \ bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ blas::Transpose transa, blas::Diagonal diag, uint64_t m, \ diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc index 238a6d634d5..3344fca4094 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_timer.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_types.h" +#include "tensorflow/compiler/xla/stream_executor/numeric_options.h" #include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" @@ -588,7 +589,7 @@ tsl::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, const void *alpha, const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, - blas::ComputePrecision precision) { + const NumericOptions &numeric_options) { cublasMath_t math_type = CUBLAS_DEFAULT_MATH; #if CUDA_VERSION < 11000 @@ -598,17 +599,7 @@ tsl::Status CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa, #else if (dtype == blas::DataType::kFloat) { math_type = CUBLAS_TF32_TENSOR_OP_MATH; - if (stream->GetCudaComputeCapability().IsAtLeast( - CudaComputeCapability::AMPERE)) { - // TODO(reedwm): Remove or make this VLOG(1) once TensorFloat-32 is more - // well tested. - if (tsl::tensor_float_32_execution_enabled()) { - LOG_FIRST_N(INFO, 1) << "TensorFloat-32 will be used for the matrix " - "multiplication. 
This will only be logged " - "once."; - } - } - if (precision > blas::kDefaultComputePrecision) { + if (!numeric_options.allow_tf32) { math_type = CUBLAS_DEFAULT_MATH; } } @@ -722,7 +713,7 @@ static bool UsesTensorOps(blas::AlgorithmType algo) { static tsl::StatusOr GetMathTypeForGemmEx( Stream *stream, blas::AlgorithmType algorithm, blas::DataType type_a, - blas::DataType type_b, blas::ComputePrecision precision) { + blas::DataType type_b, const NumericOptions &numeric_options) { if (type_a != type_b) { return tsl::errors::Internal("Types of inputs mismatch"); } @@ -753,10 +744,6 @@ static tsl::StatusOr GetMathTypeForGemmEx( "Algorithm ", algorithm, " uses tensor ops, but tensor ops are not available in sm", cc.major, "X devices for float input types."); - } else if (!tsl::tensor_float_32_execution_enabled()) { - return tsl::errors::Internal( - "Algorithm ", algorithm, - " uses tensor ops, but tensor ops are disabled for fp32 inputs"); } math_type = CUBLAS_TF32_TENSOR_OP_MATH; #endif @@ -770,7 +757,7 @@ static tsl::StatusOr GetMathTypeForGemmEx( " uses tensor ops which are not supported for input"); } } - if (precision > blas::kDefaultComputePrecision) { + if (!numeric_options.allow_tf32) { math_type = CUBLAS_DEFAULT_MATH; } @@ -814,11 +801,11 @@ tsl::Status CUDABlas::DoBlasGemmWithAlgorithm( blas::DataType type_a, int lda, const DeviceMemoryBase &b, blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, blas::ComputePrecision precision, + blas::AlgorithmType algorithm, const NumericOptions &numeric_options, blas::ProfileResult *output_profile_result) { TF_ASSIGN_OR_RETURN( cublasMath_t math_type, - GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, precision)); + GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options)); TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile( stream, parent_, output_profile_result)); @@ -846,11 +833,11 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( blas::DataType type_b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, blas::ComputePrecision precision, + blas::AlgorithmType algorithm, const NumericOptions &numeric_options, blas::ProfileResult *output_profile_result) { TF_ASSIGN_OR_RETURN( cublasMath_t math_type, - GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, precision)); + GetMathTypeForGemmEx(stream, algorithm, type_a, type_b, numeric_options)); TF_ASSIGN_OR_RETURN(auto timer, StartGpuTimerForProfile( stream, parent_, output_profile_result)); @@ -991,6 +978,7 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal( const DeviceMemorySlice &a_ptrs_to_wrappers, int lda, const DeviceMemorySlice &b_ptrs_to_wrappers, int ldb, Scalar beta, const DeviceMemorySlice &c_ptrs_to_wrappers, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { std::vector a_raw_ptrs, b_raw_ptrs, c_raw_ptrs; for (int i = 0; i < batch_count; ++i) { @@ -1072,12 +1060,14 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal( algo = CUBLAS_GEMM_DFALT_TENSOR_OP; #if CUBLAS_VER_MAJOR >= 11 } else if (data_type == CUDA_R_32F) { - // DoBlassInternalImpl will switch math_type back to CUBLAS_DEFAULT_MATH - // if TensorFloat-32 is disabled.
- math_type = CUBLAS_TF32_TENSOR_OP_MATH; - algo = tsl::tensor_float_32_execution_enabled() - ? CUBLAS_GEMM_DFALT_TENSOR_OP - : CUBLAS_GEMM_DFALT; + if (numeric_options.allow_tf32 && + tsl::tensor_float_32_execution_enabled()) { + math_type = CUBLAS_TENSOR_OP_MATH; + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } else { + math_type = CUBLAS_DEFAULT_MATH; + algo = CUBLAS_GEMM_DFALT; + } #endif } else { math_type = CUBLAS_DEFAULT_MATH; @@ -1119,8 +1109,7 @@ tsl::Status CUDABlas::DoBlasGemmBatchedInternal( DeviceMemory *c_matrix = c_ptrs_to_wrappers[b]; TF_RETURN_IF_ERROR(DoBlasGemm( stream, transa, transb, m, n, k, blas::ToDataType::value, &alpha, - a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc, - blas::kDefaultComputePrecision)); + a_matrix, lda, b_matrix, ldb, &beta, c_matrix, ldc, numeric_options)); } return ::tsl::OkStatus(); } @@ -1131,12 +1120,14 @@ bool CUDABlas::DoBlasGemmBatched( uint64_t n, uint64 k, float alpha, DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { // Note: The func passed here (cublasSgemmBatched) is not actually called, // due to special handling of fp16 inside DoBlasGemmBatchedInternal. tsl::Status status = DoBlasGemmBatchedInternal( cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, - b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); + b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, + scratch_allocator); if (!status.ok()) { LOG(ERROR) << status; } @@ -1149,12 +1140,14 @@ bool CUDABlas::DoBlasGemmBatched( DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { // Note: The func passed here (cublasSgemmBatched) is not actually called, // due to special handling of bf16 inside DoBlasGemmBatchedInternal. 
tsl::Status status = DoBlasGemmBatchedInternal( cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, - b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); + b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, + scratch_allocator); if (!status.ok()) { LOG(ERROR) << status; } @@ -1168,10 +1161,12 @@ bool CUDABlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { tsl::Status status = DoBlasGemmBatchedInternal( cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, - b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); + b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, + scratch_allocator); if (!status.ok()) { LOG(ERROR) << status; } @@ -1185,10 +1180,12 @@ bool CUDABlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, DeviceMemorySlice b_array, int ldb, double beta, DeviceMemorySlice c_array, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { tsl::Status status = DoBlasGemmBatchedInternal( cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, - b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); + b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, + scratch_allocator); if (!status.ok()) { LOG(ERROR) << status; } @@ -1201,10 +1198,12 @@ bool CUDABlas::DoBlasGemmBatched( DeviceMemorySlice> a_array, int lda, DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, - int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator) { tsl::Status status = DoBlasGemmBatchedInternal( cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, - b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); + b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, + scratch_allocator); if (!status.ok()) { LOG(ERROR) << status; } @@ -1217,10 +1216,12 @@ bool CUDABlas::DoBlasGemmBatched( DeviceMemorySlice> a_array, int lda, DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, - int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator) { tsl::Status status = DoBlasGemmBatchedInternal( cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, lda, - b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); + b_array, ldb, beta, c_array, ldc, batch_count, numeric_options, + scratch_allocator); if (!status.ok()) { LOG(ERROR) << status; } @@ -1233,19 +1234,16 @@ tsl::Status CUDABlas::DoBlasGemmStridedBatched( const DeviceMemoryBase &a, int lda, int64_t stride_a, const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, - blas::ComputePrecision precision) { + const NumericOptions &numeric_options) { cublasMath_t math_type = CUBLAS_DEFAULT_MATH; #if CUDA_VERSION < 11000 if (dtype == dnn::kHalf) { math_type = CUBLAS_TENSOR_OP_MATH; } #else - if (dtype == dnn::kFloat) { + if (dtype == dnn::kFloat && numeric_options.allow_tf32) { math_type = CUBLAS_TF32_TENSOR_OP_MATH; } - if (precision > blas::kDefaultComputePrecision) { - 
math_type = CUBLAS_DEFAULT_MATH; - } #endif switch (dtype) { diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h index 70aaa58aaa1..cd33cb953a4 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.h @@ -109,6 +109,7 @@ class CUDABlas : public blas::BlasSupport { const DeviceMemorySlice &a_array, int lda, const DeviceMemorySlice &b_array, int ldb, Scalar beta, const DeviceMemorySlice &c_array, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator); // Guards the cuBLAS handle for this device. diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc index e86fefe7a49..1cf08b27749 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc @@ -954,7 +954,7 @@ static bool TensorOpMathAvailable( } static bool IsTensorMathEnabled(CudaComputeCapability cuda_compute_capability, - dnn::DataType input_type) { + dnn::DataType input_type, bool allow_tf32) { if (!TensorOpMathAvailable(cuda_compute_capability)) { return false; } @@ -962,7 +962,7 @@ static bool IsTensorMathEnabled(CudaComputeCapability cuda_compute_capability, #if CUDNN_VERSION < 8000 return false; #else - if (!tsl::tensor_float_32_execution_enabled()) { + if (!allow_tf32 || !tsl::tensor_float_32_execution_enabled()) { return false; } #endif @@ -970,8 +970,10 @@ static bool IsTensorMathEnabled(CudaComputeCapability cuda_compute_capability, return true; } -static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) { - return IsTensorMathEnabled(stream->GetCudaComputeCapability(), input_type); +static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type, + bool allow_tf32) { + return IsTensorMathEnabled(stream->GetCudaComputeCapability(), input_type, + allow_tf32); } // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle @@ -1314,8 +1316,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { int cell_size, int batch_size, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, cudnnDataType_t compute_type, - const dnn::AlgorithmConfig& algorithm_config, float dropout, - uint64_t seed, ScratchAllocator* state_allocator, bool use_padded_io) { + const dnn::AlgorithmConfig& algorithm_config, + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io) { TF_ASSIGN_OR_RETURN( CudnnDropoutDescriptor dropout_desc, CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator)); @@ -1337,7 +1340,8 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { // TODO(csigg): Minimal support cuDNN version is 7.3, clean up. bool allow_tensor_ops = data_type == CUDNN_DATA_HALF; if (data_type == CUDNN_DATA_FLOAT) - allow_tensor_ops = tsl::tensor_float_32_execution_enabled(); + allow_tensor_ops = numeric_options.allow_tf32 && + tsl::tensor_float_32_execution_enabled(); bool use_tensor_ops = algorithm_config.algorithm().has_value() ? 
algorithm_config.algorithm()->tensor_ops_enabled() @@ -2471,8 +2475,8 @@ CudnnSupport::createRnnDescriptor( int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64_t seed, ScratchAllocator* state_allocator, - bool use_padded_io) { + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io) { // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's // not enqueueing anything into a stream, we pass in the null stream. auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr); @@ -2483,7 +2487,8 @@ CudnnSupport::createRnnDescriptor( ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), GetRnnComputeType(data_type), - algorithm_config, dropout, seed, state_allocator, use_padded_io)); + algorithm_config, numeric_options, dropout, seed, state_allocator, + use_padded_io)); return std::unique_ptr( new CudnnRnnDescriptor(std::move(rnn_desc))); } @@ -3087,19 +3092,16 @@ AllocateCudnnConvolutionBackwardFilterWorkspace( return scratch_allocator->AllocateBytes(size_in_bytes); } -tsl::StatusOr UseTensorOps(Stream* stream, dnn::DataType type, - std::optional desc) { - bool use_tensor_ops; +bool UseTensorOps(dnn::DataType input_type, + std::optional desc) { if (desc.has_value()) { - use_tensor_ops = desc->tensor_ops_enabled(); - if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) { - return tsl::Status(absl::StatusCode::kInvalidArgument, - "Algo requests disabled tensor op evaluation."); - } + return desc->tensor_ops_enabled(); } else { - use_tensor_ops = IsTensorMathEnabled(stream, type); + // It's unknown whether the user wants to use TensorFloat-32, which is used + // with tensor ops when the inputs are FP32. For safety, assume the user + // does not want TensorFloat-32 on FP32 inputs. 
+ return input_type != dnn::DataType::kFloat; } - return use_tensor_ops; } cudnnDataType_t GetRnnComputeType(dnn::DataType data_type); @@ -3118,9 +3120,7 @@ tsl::StatusOr GetCudnnConvolutionForwardAlgorithm( CudnnConvolutionDescriptor conv( convolution_descriptor, ToCudnnDataType(GetConvAccumulatorType(element_type))); - bool use_tensor_ops; - TF_ASSIGN_OR_RETURN(use_tensor_ops, - UseTensorOps(stream, element_type, algo_desc)); + bool use_tensor_ops = UseTensorOps(element_type, algo_desc); conv.set_use_tensor_op_math(use_tensor_ops); if (!algo_desc.has_value()) { @@ -3159,8 +3159,7 @@ tsl::StatusOr GetCudnnConvolutionForwardAlgorithm( "Returned status: ", scratch_or.status().ToString())); } - TF_ASSIGN_OR_RETURN(use_tensor_ops, - UseTensorOps(stream, element_type, algo_desc)); + use_tensor_ops = UseTensorOps(element_type, algo_desc); conv.set_use_tensor_op_math(use_tensor_ops); TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace( stream, cudnn, input_nd, filter, conv, @@ -3180,9 +3179,7 @@ tsl::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( CudnnConvolutionDescriptor conv( convolution_descriptor, ToCudnnDataType(GetConvAccumulatorType(element_type))); - bool use_tensor_ops; - TF_ASSIGN_OR_RETURN(use_tensor_ops, - UseTensorOps(stream, element_type, algo_desc)); + bool use_tensor_ops = UseTensorOps(element_type, algo_desc); conv.set_use_tensor_op_math(use_tensor_ops); if (!algo_desc.has_value()) { @@ -3220,8 +3217,7 @@ tsl::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( "while a secondary algorithm is not provided."); } - TF_ASSIGN_OR_RETURN(use_tensor_ops, - UseTensorOps(stream, element_type, algo_desc)); + use_tensor_ops = UseTensorOps(element_type, algo_desc); conv.set_use_tensor_op_math(use_tensor_ops); TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace( stream, cudnn, input_nd, filter, conv, @@ -3241,9 +3237,7 @@ tsl::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( CudnnConvolutionDescriptor conv( convolution_descriptor, ToCudnnDataType(GetConvAccumulatorType(element_type))); - bool use_tensor_ops; - TF_ASSIGN_OR_RETURN(use_tensor_ops, - UseTensorOps(stream, element_type, algo_desc)); + bool use_tensor_ops = UseTensorOps(element_type, algo_desc); conv.set_use_tensor_op_math(use_tensor_ops); if (!algo_desc.has_value()) { @@ -3284,8 +3278,7 @@ tsl::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( scratch_or.status().ToString())); } - TF_ASSIGN_OR_RETURN(use_tensor_ops, - UseTensorOps(stream, element_type, algo_desc)); + use_tensor_ops = UseTensorOps(element_type, algo_desc); conv.set_use_tensor_op_math(use_tensor_ops); TF_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace( stream, cudnn, input_nd, filter, conv, @@ -4913,9 +4906,6 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { DeviceMemoryBase output_data) const override { auto algo = MakeAlgorithmDesc(); - // Check that the current stream supports tensor ops if they're requested. 
- TF_RETURN_IF_ERROR(UseTensorOps(stream, input_type_, algo).status()); - if (static_cast(parent_) != stream->parent()->implementation()) { return tsl::errors::Internal( @@ -5104,8 +5094,7 @@ tsl::Status CudnnSupport::DoConvolve( auto accumulator_type = GetConvAccumulatorType(element_type); CudnnConvolutionDescriptor conv(convolution_descriptor, ToCudnnDataType(accumulator_type)); - TF_ASSIGN_OR_RETURN(bool use_tensor_ops, - UseTensorOps(stream, element_type, algorithm_desc)); + bool use_tensor_ops = UseTensorOps(element_type, algorithm_desc); conv.set_use_tensor_op_math(use_tensor_ops); TF_ASSIGN_OR_RETURN( @@ -5424,13 +5413,18 @@ tsl::Status CreateOpRunners( std::vector>>* out_runners, bool need_side_input, const NumericOptions& numeric_options) { cudnn_frontend::EngineConfigList filtered_configs; + const bool disable_winograd = !CudnnEnvVar::IsEnabled(); + const bool disable_nondeterminism = RequireCudnnDeterminism(numeric_options); + const bool disable_tensor_core = + !IsTensorMathEnabled(stream, input_type, numeric_options.allow_tf32); auto generic_filter_fn = [=](cudnnBackendDescriptor_t engine_config) -> bool { - return GenericEngineFilter( - engine_config, - /*disable_winograd*/ !CudnnEnvVar::IsEnabled(), - /*disable_nondeterminism*/ RequireCudnnDeterminism(numeric_options), - /*disable_tensor_core*/ !IsTensorMathEnabled(stream, input_type)); + return GenericEngineFilter(engine_config, disable_winograd, + disable_nondeterminism, disable_tensor_core); }; + VLOG(4) << "Filtering engine configs with disable_winograd=" + << disable_winograd + << ", disable_nondeterminism=" << disable_nondeterminism + << ", disable_tensor_core=" << disable_tensor_core; std::array heur_mode = {use_fallback ? "heuristics_fallback" : "heuristics_mode_b"}; @@ -6132,8 +6126,8 @@ bool CudnnSupport::GetConvolveAlgorithms( std::vector* out_algorithms) { PreloadCudnnSubLibs(PreloadCudnnType::ConvFwd); - bool tensor_op_math_available = - IsTensorMathEnabled(cuda_compute_capability, input_type); + bool tensor_op_math_available = IsTensorMathEnabled( + cuda_compute_capability, input_type, numeric_options.allow_tf32); out_algorithms->clear(); std::vector algo_types; @@ -6353,8 +6347,8 @@ bool CudnnSupport::GetConvolveBackwardDataAlgorithms( std::vector* out_algorithms) { PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdData); - bool tensor_op_math_available = - IsTensorMathEnabled(cuda_compute_capability, input_type); + bool tensor_op_math_available = IsTensorMathEnabled( + cuda_compute_capability, input_type, numeric_options.allow_tf32); out_algorithms->clear(); std::vector algo_types = { @@ -6389,8 +6383,8 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms( std::vector* out_algorithms) { PreloadCudnnSubLibs(PreloadCudnnType::ConvBwdFilter); - bool tensor_op_math_available = - IsTensorMathEnabled(cuda_compute_capability, input_type); + bool tensor_op_math_available = IsTensorMathEnabled( + cuda_compute_capability, input_type, numeric_options.allow_tf32); out_algorithms->clear(); std::vector algo_types = { @@ -7090,8 +7084,7 @@ bool CudnnSupport::DoMatMul(Stream* stream, if (!stream ->ThenBlasGemm(blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, weights, m, - input_data, k, output_data, m, - blas::kDefaultComputePrecision) + input_data, k, output_data, m, NumericOptions{}) .ok()) { return false; } @@ -7174,7 +7167,7 @@ bool CudnnSupport::DoMatMul(Stream* stream, stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, alpha, toPtrs(a), lda, 
toPtrs(b), ldb, beta, toPtrs(c), - ldc, batch_count); + ldc, batch_count, NumericOptions{}); } return stream->ok(); @@ -7316,7 +7309,7 @@ tsl::Status CudnnSupport::DoPoolForward( const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data, ScratchAllocator* workspace_allocator) { return DoPoolForward(element_type, stream, pooling_dimensions, - NumericOptions{false}, input_dimensions, input_data, + NumericOptions{}, input_dimensions, input_data, output_dimensions, output_data, workspace_allocator); } @@ -7396,7 +7389,7 @@ tsl::Status CudnnSupport::DoPoolBackward( DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data, ScratchAllocator* workspace_allocator) { return DoPoolBackward(element_type, stream, pooling_dimensions, - NumericOptions{false}, input_dimensions, input_data, + NumericOptions{}, input_dimensions, input_data, output_dimensions, output_data, input_diff_data, output_diff_data, workspace_allocator); } diff --git a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h index 9005c17ae8d..2b703e1ca4e 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.h @@ -62,8 +62,8 @@ class CudnnSupport : public dnn::DnnSupport { int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64_t seed, ScratchAllocator* state_allocator, - bool use_padded_io) override; + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io) override; tsl::StatusOr> createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size, diff --git a/tensorflow/compiler/xla/stream_executor/dnn.h b/tensorflow/compiler/xla/stream_executor/dnn.h index e4edfedbe71..138720d3fe7 100644 --- a/tensorflow/compiler/xla/stream_executor/dnn.h +++ b/tensorflow/compiler/xla/stream_executor/dnn.h @@ -2292,8 +2292,9 @@ class DnnSupport { dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64_t seed, - ScratchAllocator* state_allocator, bool use_padded_io) { + const NumericOptions& numeric_options, float dropout, + uint64_t seed, ScratchAllocator* state_allocator, + bool use_padded_io) { return tsl::Status(absl::StatusCode::kUnimplemented, "createRnnDescriptor is unimplemented"); } diff --git a/tensorflow/compiler/xla/stream_executor/numeric_options.h b/tensorflow/compiler/xla/stream_executor/numeric_options.h index 2f76f0364d6..891ed68f06e 100644 --- a/tensorflow/compiler/xla/stream_executor/numeric_options.h +++ b/tensorflow/compiler/xla/stream_executor/numeric_options.h @@ -21,10 +21,15 @@ namespace stream_executor { // Options that specify the numeric behavior of operations like matrix // multiplications and convolutions struct NumericOptions { - explicit NumericOptions(bool require_determinism) - : require_determinism(require_determinism) {} + NumericOptions(bool require_determinism, bool allow_tf32) + : require_determinism(require_determinism), allow_tf32(allow_tf32) {} + NumericOptions() : require_determinism(false), allow_tf32(true) {} + + // If true, the op must be deterministic bool require_determinism; + // If true, float32 inputs can be rounded to TensorFloat-32 precision + bool allow_tf32; }; } // namespace stream_executor diff --git 
a/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc index 9949199258a..6e0b24a4fe1 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_blas.cc @@ -406,7 +406,7 @@ tsl::Status ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, const void *alpha, const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, - blas::ComputePrecision precision) { + const NumericOptions &numeric_options) { blas_log("DoBlasGemm"); VLOG(1) << absl::StreamFormat( "doing rocBLAS GEMM: at=%d bt=%d m=%u n=%u " @@ -529,7 +529,7 @@ tsl::Status ROCMBlas::DoBlasGemmWithAlgorithm( blas::DataType type_a, int lda, const DeviceMemoryBase &b, blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, blas::ComputePrecision precision, + blas::AlgorithmType algorithm, const NumericOptions &numeric_options, blas::ProfileResult *output_profile_result) { // ROCM TODO: properly implement the interface return tsl::errors::Internal("DoBlasGemmWithAlgorithm ", @@ -543,7 +543,7 @@ tsl::Status ROCMBlas::DoBlasGemmStridedBatchedWithAlgorithm( blas::DataType type_b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c, int batch_count, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, blas::ComputePrecision precision, + blas::AlgorithmType algorithm, const NumericOptions &numeric_options, blas::ProfileResult *output_profile_result) { // ROCM TODO: properly implement the interface return tsl::errors::Internal("DoBlasGemmStridedBatchedWithAlgorithm ", @@ -842,6 +842,7 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { blas_log("DoBlasGemmBatched"); const Eigen::half alpha_half(alpha); @@ -864,6 +865,7 @@ bool ROCMBlas::DoBlasGemmBatched( DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { blas_log("DoBlasGemmBatched"); const Eigen::bfloat16 alpha_bf16(alpha); @@ -886,6 +888,7 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { blas_log("DoBlasGemmBatched"); tsl::Status status = DoBlasGemmBatchedInternal( @@ -905,6 +908,7 @@ bool ROCMBlas::DoBlasGemmBatched(Stream *stream, blas::Transpose transa, DeviceMemorySlice b_array, int ldb, double beta, DeviceMemorySlice c_array, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { blas_log("DoBlasGemmBatched"); tsl::Status status = DoBlasGemmBatchedInternal( @@ -923,7 +927,8 @@ bool ROCMBlas::DoBlasGemmBatched( DeviceMemorySlice> a_array, int lda, DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, - int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator 
*scratch_allocator) { blas_log("DoBlasGemmBatched"); tsl::Status status = DoBlasGemmBatchedInternal( wrap::rocblas_cgemm_strided_batched, stream, transa, transb, m, n, k, @@ -941,7 +946,8 @@ bool ROCMBlas::DoBlasGemmBatched( DeviceMemorySlice> a_array, int lda, DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, - int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator) { blas_log("DoBlasGemmBatched"); tsl::Status status = DoBlasGemmBatchedInternal( wrap::rocblas_zgemm_strided_batched, stream, transa, transb, m, n, k, @@ -1071,7 +1077,7 @@ tsl::Status ROCMBlas::DoBlasGemmStridedBatched( const DeviceMemoryBase &a, int lda, int64_t stride_a, const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, - blas::ComputePrecision precision) { + const NumericOptions &numeric_options) { VLOG(1) << absl::StreamFormat( "doing rocBLAS SGEMM Strided Batched: at=%d bt=%d m=%u n=%u " "k=%llu alpha=%p a=%p lda=%d b=%p ldb=%d beta=%p " diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc index a31d37a3517..df92df03c71 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc @@ -2691,8 +2691,8 @@ MIOpenSupport::createRnnDescriptor( int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64_t seed, ScratchAllocator* state_allocator, - bool use_padded_io) { + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io) { // ROCM TODO: batch_size is used in dynamic persistent RNN algorithm and is // not supported by MIOpen now. 
if (use_padded_io) { @@ -3946,8 +3946,7 @@ bool MIOpenSupport::DoMatMul(Stream* stream, if (!stream ->ThenBlasGemm(blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, weights, m, - input_data, k, output_data, m, - blas::kDefaultComputePrecision) + input_data, k, output_data, m, NumericOptions{}) .ok()) { return false; } @@ -4030,7 +4029,7 @@ bool MIOpenSupport::DoMatMul(Stream* stream, stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose, blas::Transpose::kNoTranspose, m, n, k, alpha, toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c), - ldc, batch_count); + ldc, batch_count, NumericOptions{}); } return stream->ok(); diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h index b0714cad4e5..9abed5d085e 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h @@ -85,8 +85,8 @@ class MIOpenSupport : public dnn::DnnSupport { int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64_t seed, ScratchAllocator* state_allocator, - bool use_padded_io) override; + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io) override; tsl::StatusOr> createRnnSequenceTensorDescriptor(int seq_length, int batch_size, diff --git a/tensorflow/compiler/xla/stream_executor/stream.cc b/tensorflow/compiler/xla/stream_executor/stream.cc index 28dffedffba..107126d0980 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.cc +++ b/tensorflow/compiler/xla/stream_executor/stream.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "third_party/eigen3/Eigen/Core" #include "tensorflow/compiler/xla/stream_executor/blas.h" +#include "tensorflow/compiler/xla/stream_executor/numeric_options.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" #include "tensorflow/compiler/xla/stream_executor/platform/logging.h" #include "tensorflow/compiler/xla/stream_executor/platform/port.h" @@ -1588,9 +1589,11 @@ Stream &Stream::ThenBlasGemmBatched( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count) { + DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, + numeric_options, /*scratch_allocator=*/nullptr); } @@ -1599,6 +1602,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), @@ -1607,11 +1611,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( ThenBlasImpl, int, DeviceMemorySlice, int, float, - DeviceMemorySlice, int, int, ScratchAllocator *> + DeviceMemorySlice, int, int, const NumericOptions &, + ScratchAllocator *> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator); + numeric_options, scratch_allocator); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1619,6 +1624,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), @@ -1627,22 +1633,22 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( ThenBlasImpl, int, DeviceMemorySlice, int, float, - DeviceMemorySlice, int, int, ScratchAllocator *> + DeviceMemorySlice, int, int, + const NumericOptions &, ScratchAllocator *> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator); + numeric_options, scratch_allocator); } -Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, - blas::Transpose transb, uint64_t m, - uint64 n, uint64_t k, float alpha, - DeviceMemorySlice a, int lda, - DeviceMemorySlice b, int ldb, - float beta, DeviceMemorySlice c, - int ldc, int batch_count) { +Stream &Stream::ThenBlasGemmBatched( + blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, + uint64_t k, float alpha, DeviceMemorySlice a, int lda, + DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, + int ldc, int batch_count, const NumericOptions &numeric_options) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, + numeric_options, /*scratch_allocator=*/nullptr); } @@ -1650,7 +1656,8 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( blas::Transpose 
transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, - int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); @@ -1658,11 +1665,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( ThenBlasImpl, int, DeviceMemorySlice, int, float, DeviceMemorySlice, int, int, - ScratchAllocator *> + const NumericOptions &, ScratchAllocator *> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator); + numeric_options, scratch_allocator); } Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, @@ -1671,9 +1678,11 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, - int ldc, int batch_count) { + int ldc, int batch_count, + const NumericOptions &numeric_options) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, + numeric_options, /*scratch_allocator=*/nullptr); } @@ -1682,6 +1691,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( uint64_t k, double alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), @@ -1690,11 +1700,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( ThenBlasImpl, int, DeviceMemorySlice, int, double, - DeviceMemorySlice, int, int, ScratchAllocator *> + DeviceMemorySlice, int, int, const NumericOptions &, + ScratchAllocator *> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator); + numeric_options, scratch_allocator); } Stream &Stream::ThenBlasGemmBatched( @@ -1702,9 +1713,11 @@ Stream &Stream::ThenBlasGemmBatched( uint64_t k, std::complex alpha, DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, - DeviceMemorySlice> c, int ldc, int batch_count) { + DeviceMemorySlice> c, int ldc, int batch_count, + const NumericOptions &numeric_options) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, + numeric_options, /*scratch_allocator=*/nullptr); } @@ -1714,6 +1727,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), @@ -1723,11 +1737,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( std::complex, DeviceMemorySlice>, int, DeviceMemorySlice>, int, std::complex, DeviceMemorySlice>, int, int, - ScratchAllocator *> + const NumericOptions &, ScratchAllocator *> impl; return impl(this, 
&blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator); + numeric_options, scratch_allocator); } Stream &Stream::ThenBlasGemmBatched( @@ -1736,9 +1750,10 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count) { + int ldc, int batch_count, const NumericOptions &numeric_options) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, + numeric_options, /*scratch_allocator=*/nullptr); } @@ -1748,7 +1763,8 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, ScratchAllocator *scratch_allocator) { + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); @@ -1757,11 +1773,11 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( std::complex, DeviceMemorySlice>, int, DeviceMemorySlice>, int, std::complex, DeviceMemorySlice>, - int, int, ScratchAllocator *> + int, int, const NumericOptions &, ScratchAllocator *> impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator); + numeric_options, scratch_allocator); } Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes) { diff --git a/tensorflow/compiler/xla/stream_executor/stream.h b/tensorflow/compiler/xla/stream_executor/stream.h index 8b0591087a6..79a75f1431f 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.h +++ b/tensorflow/compiler/xla/stream_executor/stream.h @@ -28,6 +28,7 @@ limitations under the License. #include #include #include +#include #include "absl/base/thread_annotations.h" #include "absl/functional/any_invocable.h" @@ -893,14 +894,15 @@ class Stream { const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, DeviceMemory *c, int ldc, - blas::ComputePrecision precision) { + const NumericOptions &numeric_options) { InputType alpha{1.0}; InputType beta{0.0}; return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, precision); + ldc, numeric_options); } - // TODO(parkers): Update all callers to pass kDefaultComputePrecision. + // TODO(reedwm): Update all callers (if there are any) to pass correct + // NumericOptions. 
template tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64 k, @@ -908,7 +910,7 @@ class Stream { const DeviceMemory &b, int ldb, DeviceMemory *c, int ldc) { return ThenBlasGemm(transa, transb, m, n, k, a, lda, b, ldb, c, ldc, - blas::kDefaultComputePrecision); + NumericOptions{}); } template @@ -917,7 +919,7 @@ class Stream { const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, - int ldc, blas::ComputePrecision precision) { + int ldc, const NumericOptions &numeric_options) { static_assert( detail::is_any_of, @@ -948,10 +950,10 @@ class Stream { return blas->DoBlasGemm(this, transa, transb, m, n, k, blas::ToDataType::value, alpha_ptr, a, - lda, b, ldb, beta_ptr, c, ldc, precision); + lda, b, ldb, beta_ptr, c, ldc, numeric_options); } - // TODO(parkers): Update all callers to pass kDefaultComputePrecision. + // TODO(reedwm): Update all callers to pass correct NumericOptions. template tsl::Status ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64 k, ConstantType alpha, @@ -960,7 +962,7 @@ class Stream { ConstantType beta, DeviceMemory *c, int ldc) { return ThenBlasGemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, - ldc, blas::kDefaultComputePrecision); + ldc, NumericOptions{}); } template @@ -973,10 +975,9 @@ class Stream { blas::ProfileResult *output_profile_result) { OutputType alpha{1}; OutputType beta{0}; - return ThenBlasGemmWithAlgorithm(transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, computation_type, - algorithm, blas::kDefaultComputePrecision, - output_profile_result); + return ThenBlasGemmWithAlgorithm( + transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + computation_type, algorithm, NumericOptions{}, output_profile_result); } template @@ -986,7 +987,7 @@ class Stream { const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, - blas::ComputePrecision precision, + const NumericOptions &numeric_options, blas::ProfileResult *output_profile_result) { TF_RETURN_IF_ERROR( CheckTypesForExtendedBlas( @@ -1010,7 +1011,7 @@ class Stream { blas::ToDataType::value, lda, b, blas::ToDataType::value, ldb, beta_ptr, c, blas::ToDataType::value, ldc, computation_type, algorithm, - precision, output_profile_result); + numeric_options, output_profile_result); if (output_profile_result) { // The error is recorded in the profile. return ::tsl::OkStatus(); @@ -1025,7 +1026,7 @@ class Stream { int64_t stride_a, const DeviceMemory &b, int ldb, int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, int batch_count, blas::ComputationType computation_type, - blas::AlgorithmType algorithm, blas::ComputePrecision precision, + blas::AlgorithmType algorithm, const NumericOptions &numeric_options, blas::ProfileResult *output_profile_result) { TF_RETURN_IF_ERROR( CheckTypesForExtendedBlas( @@ -1047,7 +1048,7 @@ class Stream { blas::ToDataType::value, lda, stride_a, b, blas::ToDataType::value, ldb, stride_b, beta_ptr, c, blas::ToDataType::value, ldc, stride_c, batch_count, - computation_type, algorithm, precision, output_profile_result); + computation_type, algorithm, numeric_options, output_profile_result); if (output_profile_result) { // The error is recorded in the profile. 
return ::tsl::OkStatus(); @@ -1064,19 +1065,22 @@ class Stream { DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, - int ldc, int batch_count); + int ldc, int batch_count, + const NumericOptions &numeric_options); Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64 k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, - int batch_count); + int batch_count, + const NumericOptions &numeric_options); Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64 k, double alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, - int batch_count); + int batch_count, + const NumericOptions &numeric_options); Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, std::complex alpha, @@ -1084,27 +1088,28 @@ class Stream { DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, - int batch_count); - Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64_t k, - std::complex alpha, - DeviceMemorySlice> a, - int lda, - DeviceMemorySlice> b, - int ldb, std::complex beta, - DeviceMemorySlice> c, - int ldc, int batch_count); + int batch_count, + const NumericOptions &numeric_options); + Stream &ThenBlasGemmBatched( + blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, + uint64_t k, std::complex alpha, + DeviceMemorySlice> a, int lda, + DeviceMemorySlice> b, int ldb, + std::complex beta, DeviceMemorySlice> c, + int ldc, int batch_count, const NumericOptions &numeric_options); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator); Stream &ThenBlasGemmBatchedWithScratch(blas::Transpose transa, blas::Transpose transb, uint64_t m, @@ -1113,12 +1118,14 @@ class Stream { DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, double alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, + const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator); Stream &ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, @@ -1126,14 +1133,16 @@ class Stream { DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, ScratchAllocator *scratch_allocator); + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator); Stream 
&ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, std::complex alpha, DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, ScratchAllocator *scratch_allocator); + int ldc, int batch_count, const NumericOptions &numeric_options, + ScratchAllocator *scratch_allocator); template tsl::Status ThenBlasGemmStridedBatched( @@ -1141,7 +1150,8 @@ class Stream { uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, int64_t stride_a, const DeviceMemory &b, int ldb, int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, - int64_t stride_c, int batch_count, blas::ComputePrecision precision) { + int64_t stride_c, int batch_count, + const NumericOptions &numeric_options) { static_assert( detail::is_any_of, @@ -1168,7 +1178,7 @@ class Stream { return blas->DoBlasGemmStridedBatched( this, transa, transb, m, n, k, blas::ToDataType::value, alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc, - stride_c, batch_count, precision); + stride_c, batch_count, numeric_options); } // See BlasSupport::DoBlasTrsm. diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc index b619bb5016a..6b092898eae 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.cc @@ -377,8 +377,8 @@ StreamExecutor::createRnnDescriptor( int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64_t seed, ScratchAllocator* state_allocator, - bool use_padded_io) { + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io) { dnn::DnnSupport* dnn_support = AsDnn(); if (!dnn_support) { return tsl::Status(absl::StatusCode::kUnknown, @@ -386,8 +386,8 @@ StreamExecutor::createRnnDescriptor( } return dnn_support->createRnnDescriptor( num_layers, hidden_size, input_size, cell_size, batch_size, input_mode, - direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed, - state_allocator, use_padded_io); + direction_mode, rnn_mode, data_type, algorithm_config, numeric_options, + dropout, seed, state_allocator, use_padded_io); } tsl::StatusOr> diff --git a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h index f40e943a20c..4b60b306137 100644 --- a/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -432,8 +433,8 @@ class StreamExecutor { int batch_size, dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode, dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64_t seed, ScratchAllocator* state_allocator, - bool use_padded_io); + const NumericOptions& numeric_options, float dropout, uint64_t seed, + ScratchAllocator* state_allocator, bool use_padded_io); // Create a RNN sequence descriptor that specifies either the input or output // sequence. The caller retains the ownership of the returned descriptor. 
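Before the diff moves on to the test and kernel changes, note what the plumbing above defaults to. Wherever a bare NumericOptions{} is passed (the DoMatMul fallbacks and the DoPoolForward/DoPoolBackward compatibility overloads), the default constructor from numeric_options.h applies: determinism is not required and TF32 is allowed, matching pre-patch behavior at those call sites. A minimal sketch of the two construction patterns used in the patch, assuming only the NumericOptions definition shown earlier (the function name below is hypothetical):

#include "tensorflow/compiler/xla/stream_executor/numeric_options.h"

// Illustration only; not code from the patch.
void NumericOptionsConstructionExamples() {
  // Legacy call sites with no user preference available pass the defaults:
  // require_determinism = false, allow_tf32 = true.
  stream_executor::NumericOptions legacy_defaults{};

  // Call sites that know the user's settings pass them explicitly, as the
  // TensorFlow kernels do via GetNumericOptions() later in this diff.
  stream_executor::NumericOptions explicit_options(
      /*require_determinism=*/true, /*allow_tf32=*/false);

  (void)legacy_defaults;
  (void)explicit_options;
}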
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e7c0c0108e1..e6b8cf435b9 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1215,8 +1215,6 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", - "//tensorflow/tsl/platform:tensor_float_32_utils", "//tensorflow/tsl/platform:test", ], ) @@ -2770,8 +2768,6 @@ xla_test( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/tsl/lib/core:status_test_util", - "//tensorflow/tsl/platform:tensor_float_32_hdr_lib", - "//tensorflow/tsl/platform:tensor_float_32_utils", ], ) diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc index bd421b8eeeb..4dd9e8dac9c 100644 --- a/tensorflow/compiler/xla/tests/cholesky_test.cc +++ b/tensorflow/compiler/xla/tests/cholesky_test.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/tsl/lib/core/status_test_util.h" -#include "tensorflow/tsl/platform/tensor_float_32_utils.h" namespace xla { namespace { @@ -220,8 +219,6 @@ class RandomCholeskyTest public ::testing::WithParamInterface {}; XLA_TEST_P(RandomCholeskyTest, Real) { - // Test fails with TensorFloat-32 enabled - tsl::enable_tensor_float_32_execution(false); XlaBuilder builder(TestName()); auto test_params = GetParam(); @@ -259,8 +256,6 @@ XLA_TEST_P(RandomCholeskyTest, Real) { } XLA_TEST_P(RandomCholeskyTest, Complex) { - // Test fails with TensorFloat-32 enabled - tsl::enable_tensor_float_32_execution(false); XlaBuilder builder(TestName()); auto test_params = GetParam(); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 4a574979cfb..7f5756b954f 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/tsl/platform/tensor_float_32_utils.h" #include "tensorflow/tsl/platform/test.h" namespace xla { @@ -50,16 +49,16 @@ class ConvolutionVariantsTest : public ClientLibraryTestBase { ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-2); #endif - void SetUp() override { - init_tf32_status_ = tsl::tensor_float_32_execution_enabled(); - tsl::enable_tensor_float_32_execution(false); + XlaOp ConvWithHighestPrecision(const XlaOp lhs, const XlaOp rhs, + absl::Span window_strides, + Padding padding) { + PrecisionConfig precision_config; + // Set the 2 operands to have the HIGHEST precision. 
+ precision_config.add_operand_precision(PrecisionConfig::HIGHEST); + precision_config.add_operand_precision(PrecisionConfig::HIGHEST); + return Conv(lhs, rhs, window_strides, padding, /*feature_group_count=*/1, + /*batch_group_count=*/1, &precision_config); } - void TearDown() override { - tsl::enable_tensor_float_32_execution(init_tf32_status_); - } - - private: - bool init_tf32_status_; }; XLA_TEST_F(ConvolutionVariantsTest, Minimal) { @@ -626,7 +625,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { auto input = ConstantR4FromArray4D(&builder, input_array); auto filter = ConstantR4FromArray4D(&builder, filter_array); - Conv(input, filter, {1, 1}, Padding::kValid); + ConvWithHighestPrecision(input, filter, {1, 1}, Padding::kValid); Array4D expected(16, 16, 1, 1); for (int i0 = 0; i0 < 16; ++i0) { diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 521a27b4b28..43908caa8f7 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -530,8 +530,9 @@ cc_library( hdrs = ["numeric_options_utils.h"], deps = [ "//tensorflow/compiler/xla/stream_executor:numeric_options", + "//tensorflow/core/platform:tensor_float_32_hdr_lib", "//tensorflow/core/util:determinism_for_kernels", - ], + ] + if_static(["//tensorflow/core/platform:tensor_float_32_utils"]), ) tf_cuda_library( @@ -1531,6 +1532,7 @@ tf_kernel_library( deps = [ ":cast_op", ":gpu_utils", + ":numeric_options_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -3477,6 +3479,7 @@ tf_kernel_library( ":conv_ops_gpu_hdrs", ":gpu_utils", ":conv_ops", + ":numeric_options_utils", "//tensorflow/core/protobuf:autotuning_proto_cc", "//tensorflow/core/util/autotune_maps:conv_parameters", "//tensorflow/core/util/proto:proto_utils", diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc index acfb5954203..0cb52b8419d 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc @@ -45,6 +45,7 @@ limitations under the License. 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cast_op.h" +#include "tensorflow/core/kernels/numeric_options_utils.h" #include "tensorflow/core/platform/stream_executor.h" using stream_executor::dnn::DimIndex; #include "tensorflow/core/protobuf/autotuning.pb.h" @@ -730,11 +731,10 @@ void LaunchConvBackpropFilterOpImpl( auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), filter_backprop->template flat().size()); - OP_REQUIRES_OK( - context, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, - se::blas::Transpose::kTranspose, n, m, k, - a_ptr, n, b_ptr, m, &c_ptr, n, - se::blas::kDefaultComputePrecision)); + OP_REQUIRES_OK(context, stream->ThenBlasGemm( + se::blas::Transpose::kNoTranspose, + se::blas::Transpose::kTranspose, n, m, k, a_ptr, + n, b_ptr, m, &c_ptr, n, GetNumericOptions())); return; } else if (!is_grouped_convolution && dims.filter_size(0) == dims.input_size(0) && @@ -753,11 +753,10 @@ void LaunchConvBackpropFilterOpImpl( auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), filter_backprop->template flat().size()); - OP_REQUIRES_OK( - context, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, - se::blas::Transpose::kTranspose, n, m, k, - b_ptr, n, a_ptr, m, &c_ptr, n, - se::blas::kDefaultComputePrecision)); + OP_REQUIRES_OK(context, stream->ThenBlasGemm( + se::blas::Transpose::kNoTranspose, + se::blas::Transpose::kTranspose, n, m, k, b_ptr, + n, a_ptr, m, &c_ptr, n, GetNumericOptions())); return; } diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc index e04f9949978..1738c9413c9 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc @@ -48,6 +48,7 @@ limitations under the License. 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/kernels/numeric_options_utils.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.h" @@ -260,11 +261,10 @@ void LaunchConv2DBackpropFilterOpImpl( auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), filter_backprop->template flat().size()); - OP_REQUIRES_OK( - ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, - se::blas::Transpose::kTranspose, n, m, k, - a_ptr, n, b_ptr, m, &c_ptr, n, - se::blas::kDefaultComputePrecision)); + OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, + se::blas::Transpose::kTranspose, n, + m, k, a_ptr, n, b_ptr, m, &c_ptr, + n, GetNumericOptions())); return; } else if (dims.spatial_dims[0].filter_size == dims.spatial_dims[0].input_size && @@ -286,11 +286,10 @@ void LaunchConv2DBackpropFilterOpImpl( auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), filter_backprop->template flat().size()); - OP_REQUIRES_OK( - ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, - se::blas::Transpose::kTranspose, n, m, k, - b_ptr, n, a_ptr, m, &c_ptr, n, - se::blas::kDefaultComputePrecision)); + OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(se::blas::Transpose::kNoTranspose, + se::blas::Transpose::kTranspose, n, + m, k, b_ptr, n, a_ptr, m, &c_ptr, + n, GetNumericOptions())); return; } diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 85d4c73153f..a222fd3e89c 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -23,6 +23,7 @@ limitations under the License. #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cast_op.h" +#include "tensorflow/core/kernels/numeric_options_utils.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -156,9 +157,9 @@ void LaunchConv2DBackpropInputOpGpuImpl( auto transpose = se::blas::Transpose::kTranspose; auto no_transpose = se::blas::Transpose::kNoTranspose; - OP_REQUIRES_OK(ctx, stream->ThenBlasGemm( - transpose, no_transpose, n, m, k, b_ptr, k, a_ptr, - k, &c_ptr, n, se::blas::kDefaultComputePrecision)); + OP_REQUIRES_OK( + ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, + a_ptr, k, &c_ptr, n, GetNumericOptions())); return; } else if (dims.spatial_dims[0].filter_size == dims.spatial_dims[0].input_size && @@ -183,9 +184,9 @@ void LaunchConv2DBackpropInputOpGpuImpl( auto transpose = se::blas::Transpose::kTranspose; auto no_transpose = se::blas::Transpose::kNoTranspose; - OP_REQUIRES_OK(ctx, stream->ThenBlasGemm( - transpose, no_transpose, n, m, k, b_ptr, k, a_ptr, - k, &c_ptr, n, se::blas::kDefaultComputePrecision)); + OP_REQUIRES_OK( + ctx, stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, + a_ptr, k, &c_ptr, n, GetNumericOptions())); return; } diff --git a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc index d04217251e2..3e35e82bb9a 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc @@ -45,6 +45,7 @@ limitations under the License. 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cast_op.h" +#include "tensorflow/core/kernels/numeric_options_utils.h" #include "tensorflow/core/platform/stream_executor.h" using stream_executor::dnn::DimIndex; #include "tensorflow/core/protobuf/autotuning.pb.h" @@ -732,10 +733,9 @@ void LaunchConvBackpropInputOpImpl( auto transpose = se::blas::Transpose::kTranspose; auto no_transpose = se::blas::Transpose::kNoTranspose; - OP_REQUIRES_OK( - context, - stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, a_ptr, - k, &c_ptr, n, se::blas::kDefaultComputePrecision)); + OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m, + k, b_ptr, k, a_ptr, k, &c_ptr, + n, GetNumericOptions())); return; } else if (!is_grouped_convolution && dims.filter_size(0) == dims.input_size(0) && @@ -757,10 +757,9 @@ void LaunchConvBackpropInputOpImpl( auto transpose = se::blas::Transpose::kTranspose; auto no_transpose = se::blas::Transpose::kNoTranspose; - OP_REQUIRES_OK( - context, - stream->ThenBlasGemm(transpose, no_transpose, n, m, k, b_ptr, k, a_ptr, - k, &c_ptr, n, se::blas::kDefaultComputePrecision)); + OP_REQUIRES_OK(context, stream->ThenBlasGemm(transpose, no_transpose, n, m, + k, b_ptr, k, a_ptr, k, &c_ptr, + n, GetNumericOptions())); return; } diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index 60eb3afe596..2aed137404a 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -37,6 +37,7 @@ limitations under the License. #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/cast_op.h" +#include "tensorflow/core/kernels/numeric_options_utils.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.h" @@ -289,8 +290,7 @@ void LaunchConvOpImpl(OpKernelContext* ctx, bool cudnn_use_autotune, auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK( ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n, - a_ptr, k, &c_ptr, n, - se::blas::kDefaultComputePrecision)); + a_ptr, k, &c_ptr, n, GetNumericOptions())); return; } else if (!is_grouped_convolution && filter_planes == in_planes && filter_rows == in_rows && filter_cols == in_cols && @@ -311,8 +311,7 @@ void LaunchConvOpImpl(OpKernelContext* ctx, bool cudnn_use_autotune, auto no_transpose = se::blas::Transpose::kNoTranspose; OP_REQUIRES_OK( ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n, - a_ptr, k, &c_ptr, n, - se::blas::kDefaultComputePrecision)); + a_ptr, k, &c_ptr, n, GetNumericOptions())); return; } diff --git a/tensorflow/core/kernels/conv_ops_impl.h b/tensorflow/core/kernels/conv_ops_impl.h index fafa617bd88..8d22ff01a97 100644 --- a/tensorflow/core/kernels/conv_ops_impl.h +++ b/tensorflow/core/kernels/conv_ops_impl.h @@ -64,6 +64,7 @@ limitations under the License. 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/cast_op.h"
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/kernels/numeric_options_utils.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h"
 #include "tensorflow/core/util/autotune_maps/conv_parameters.h"
@@ -487,8 +488,7 @@ void LaunchConv2DOpImpl(OpKernelContext* ctx, bool use_cudnn,
     auto no_transpose = se::blas::Transpose::kNoTranspose;
     OP_REQUIRES_OK(
         ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
-                                  a_ptr, k, &c_ptr, n,
-                                  se::blas::kDefaultComputePrecision));
+                                  a_ptr, k, &c_ptr, n, GetNumericOptions()));
     return;
   } else if (patch_rows == in_rows && patch_cols == in_cols &&
              !is_grouped_convolution && row_dilation == 1 &&
@@ -510,8 +510,7 @@ void LaunchConv2DOpImpl(OpKernelContext* ctx, bool use_cudnn,
     auto no_transpose = se::blas::Transpose::kNoTranspose;
     OP_REQUIRES_OK(
         ctx, stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, b_ptr, n,
-                                  a_ptr, k, &c_ptr, n,
-                                  se::blas::kDefaultComputePrecision));
+                                  a_ptr, k, &c_ptr, n, GetNumericOptions()));
     return;
   }
 
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index 73b9bd0f828..2cd10ca8917 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -51,6 +51,7 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#include "tensorflow/core/kernels/numeric_options_utils.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/util/stream_executor_util.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -1296,7 +1297,8 @@ class CudnnRNNKernelCommon : public OpKernel {
     auto rnn_desc_s = stream->parent()->createRnnDescriptor(
         num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
         /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
-        ToDataType<T>::value, algo_config, dropout(), seed(),
+        ToDataType<T>::value, algo_config, GetNumericOptions(), dropout(),
+        seed(),
         /* state_allocator=*/nullptr, /*use_padded_io=*/false);
     if (!rnn_desc_s.ok()) {
       return FromExecutorStatus(rnn_desc_s);
@@ -1321,8 +1323,8 @@ class CudnnRNNKernelCommon : public OpKernel {
         model_shapes.num_layers, model_shapes.num_units,
         model_shapes.input_size, model_shapes.cell_num_units,
         model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
-        data_type, algo_config, dropout(), seed(), dropout_state_allocator,
-        use_padded_io);
+        data_type, algo_config, GetNumericOptions(), dropout(), seed(),
+        dropout_state_allocator, use_padded_io);
     TF_RETURN_IF_ERROR(rnn_desc_s.status());
 
     *rnn_desc = std::move(rnn_desc_s).value();
diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h
index 2c5993383ed..9425a2d3b1b 100644
--- a/tensorflow/core/kernels/matmul_op_impl.h
+++ b/tensorflow/core/kernels/matmul_op_impl.h
@@ -46,6 +46,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/gpu_utils.h"
+#include "tensorflow/core/kernels/numeric_options_utils.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #if GOOGLE_CUDA
@@ -525,7 +526,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
                   static_cast<Coefficient>(1.0), b_ptrs,
                   adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                   static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
-                  &scratch_allocator)
+                  GetNumericOptions(), &scratch_allocator)
               .ok();
       if (!blas_launch_status) {
         context->SetStatus(errors::Internal(
@@ -619,12 +620,11 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
        }
      }
 
-      OP_REQUIRES_OK(context,
-                     stream->ThenBlasGemm(
-                         blas_transpose_b, blas_transpose_a, n, m, k,
-                         *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
-                         adj_x || trans_x ? m : k, c_ptrs[0], n,
-                         se::blas::kDefaultComputePrecision));
+      OP_REQUIRES_OK(context, stream->ThenBlasGemm(
+                                  blas_transpose_b, blas_transpose_a, n, m, k,
+                                  *(b_ptrs[0]), adj_y || trans_y ? k : n,
+                                  *(a_ptrs[0]), adj_x || trans_x ? m : k,
+                                  c_ptrs[0], n, GetNumericOptions()));
     } else if (use_strided_batched) {
       OP_REQUIRES_OK(
           context, stream->ThenBlasGemmStridedBatched(
@@ -633,7 +633,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
                        adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
                        adj_x || trans_x ? m : k, a_stride,
                        static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
-                       batch_size, se::blas::kDefaultComputePrecision));
+                       batch_size, GetNumericOptions()));
     } else {
       BlasScratchAllocator scratch_allocator(context);
       bool blas_launch_status =
@@ -643,7 +643,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
                   static_cast<Coefficient>(1.0), b_ptrs,
                   adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                   static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
-                  &scratch_allocator)
+                  GetNumericOptions(), &scratch_allocator)
               .ok();
       if (!blas_launch_status) {
         context->SetStatus(errors::Internal(
diff --git a/tensorflow/core/kernels/numeric_options_utils.h b/tensorflow/core/kernels/numeric_options_utils.h
index beb2847c3d5..8cfe8e9870d 100644
--- a/tensorflow/core/kernels/numeric_options_utils.h
+++ b/tensorflow/core/kernels/numeric_options_utils.h
@@ -17,12 +17,15 @@ limitations under the License.
 #define TENSORFLOW_CORE_KERNELS_NUMERIC_OPTIONS_UTILS_H_
 
 #include "tensorflow/compiler/xla/stream_executor/numeric_options.h"
+#include "tensorflow/tsl/platform/tensor_float_32_utils.h"
 #include "tensorflow/tsl/util/determinism.h"
 
 namespace tensorflow {
 
 inline stream_executor::NumericOptions GetNumericOptions() {
-  return stream_executor::NumericOptions{tsl::OpDeterminismRequired()};
+  return stream_executor::NumericOptions{
+      /*require_determinism=*/tsl::OpDeterminismRequired(),
+      /*allow_tf32=*/tsl::tensor_float_32_execution_enabled()};
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/rnn/BUILD b/tensorflow/core/kernels/rnn/BUILD
index 24f8eee0570..2c28f54e761 100644
--- a/tensorflow/core/kernels/rnn/BUILD
+++ b/tensorflow/core/kernels/rnn/BUILD
@@ -34,6 +34,7 @@ tf_gpu_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/kernels:eigen_helpers",
+        "//tensorflow/core/kernels:numeric_options_utils",
         "//tensorflow/core/platform:stream_executor",
         "//tensorflow/tsl/framework/contraction:eigen_contraction_kernel",
         "//third_party/eigen3",
diff --git a/tensorflow/core/kernels/rnn/blas_gemm.cc b/tensorflow/core/kernels/rnn/blas_gemm.cc
index 8966f2244c4..b83de9f7520 100644
--- a/tensorflow/core/kernels/rnn/blas_gemm.cc
+++ b/tensorflow/core/kernels/rnn/blas_gemm.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#include "tensorflow/core/kernels/numeric_options_utils.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
@@ -53,7 +54,7 @@ void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx, bool transa,
       ctx, ctx->op_device_context()->stream()->ThenBlasGemm(
               trans[transa], trans[transb], m, n, k, static_cast<T>(alpha),
               a_ptr, lda, b_ptr, ldb, static_cast<T>(beta), &c_ptr, ldc,
-              se::blas::kDefaultComputePrecision));
+              GetNumericOptions()));
 #else
   ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
 #endif
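
Note (editorial addition, not part of the patch): the kernel changes above all replace the fixed se::blas::kDefaultComputePrecision argument with a stream_executor::NumericOptions built by GetNumericOptions(), so the process-wide TF32 and determinism toggles now reach cuBLAS/cuDNN. A minimal sketch of that behavior follows; it assumes the NumericOptions fields are named after the /*require_determinism=*/ and /*allow_tf32=*/ annotations in numeric_options_utils.h, and uses only the TSL setters already referenced elsewhere in this change.

// Sketch only: shows GetNumericOptions() tracking the global TF32 toggle.
#include "tensorflow/core/kernels/numeric_options_utils.h"
#include "tensorflow/tsl/platform/tensor_float_32_utils.h"

void SketchNumericOptionsFollowTf32Toggle() {
  // Disabling TF32 globally now propagates to every ThenBlasGemm /
  // createRnnDescriptor call that is built from GetNumericOptions().
  tsl::enable_tensor_float_32_execution(false);
  stream_executor::NumericOptions opts = tensorflow::GetNumericOptions();
  // opts.allow_tf32 is now false, so the GPU GEMM/conv/RNN paths request full
  // float32 precision; opts.require_determinism still mirrors
  // tsl::OpDeterminismRequired().

  tsl::enable_tensor_float_32_execution(true);  // restore the default
}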