diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 5d08e43aa7c..3f81398cb24 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -2902,6 +2902,7 @@ cc_library( hdrs = ["thunk_buffer_debug_pass.h"], deps = [ ":buffers_checksum_thunk", + ":buffers_nan_count_thunk", ":custom_call_thunk", ":sequential_thunk", ":thunk", @@ -2934,6 +2935,7 @@ xla_cc_test( srcs = ["thunk_buffer_debug_pass_test.cc"], deps = [ ":buffers_checksum_thunk", + ":buffers_nan_count_thunk", ":custom_call_thunk", ":sequential_thunk", ":thunk", @@ -3096,6 +3098,7 @@ xla_test( ":thunk", ":thunk_buffer_id", ":thunk_id", + "//xla:types", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service/gpu:buffer_allocations", diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk.cc index 30bc069b5ac..20feebfd07b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "xla/backends/gpu/runtime/thunk.h" -#include "xla/service/buffer_assignment.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_memory.h" @@ -88,18 +87,21 @@ absl::Status BuffersDebugNanCountThunk::ExecuteOnStream( se::gpu::BufferDebugLog buffer_debug_log = se::gpu::BufferDebugLog::FromDeviceMemoryUnchecked(log_ptr); - for (const auto& [entry_id, buffer_slice_pair] : buffers_) { - BufferAllocation::Slice buffer = buffer_slice_pair.buffer; - PrimitiveType buffer_type = buffer_slice_pair.element_type; + for (const auto& [entry_id, buffer] : buffers_) { + PrimitiveType buffer_type = buffer.element_type(); se::DeviceMemoryBase device_buffer = params.buffer_allocations->GetDeviceAddress(buffer); if (buffer_type == PrimitiveType::F32) { + VLOG(1) << "F32 buffer detected with id: " << entry_id + << " and size: " << device_buffer.size(); se::DeviceMemory f32_buffer(device_buffer); TF_RETURN_IF_ERROR(kernel_f32_->Launch( thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id, f32_buffer, f32_buffer.size(), buffer_debug_log.GetDeviceHeader(), buffer_debug_log.GetDeviceEntries())); } else if (buffer_type == PrimitiveType::BF16) { + VLOG(1) << "BF16 buffer detected with id: " << entry_id + << " and size: " << device_buffer.size(); se::DeviceMemory bf16_buffer(device_buffer); TF_RETURN_IF_ERROR(kernel_bf16_->Launch( thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id, @@ -117,10 +119,9 @@ absl::Status BuffersDebugNanCountThunk::ExecuteOnStream( std::string BuffersDebugNanCountThunk::ToString(int indent) const { std::string result; absl::StrAppend(&result, ", buffers = ", buffers_.size()); - for (const auto& [buffer_id, buffer_slice_pair] : buffers_) { + for (const auto& [buffer_id, buffer] : buffers_) { absl::StrAppend(&result, "\n", std::string(indent + 2, ' '), - "buffer_id: ", buffer_id, 
- ", buffer: ", buffer_slice_pair.buffer.ToString()); + "buffer_id: ", buffer_id, ", buffer: ", buffer.ToString()); } return result; } diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk.h b/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk.h index b9d9f0679b4..e4ee0f896a9 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk.h @@ -31,14 +31,9 @@ namespace xla::gpu { class BuffersDebugNanCountThunk : public Thunk { public: - struct BufferToCount { - BufferAllocation::Slice buffer; - PrimitiveType element_type; - }; - explicit BuffersDebugNanCountThunk( ThunkInfo info, BufferAllocation::Slice log_slice, - absl::flat_hash_map buffers) + absl::flat_hash_map buffers) : Thunk(Thunk::Kind::kBuffersDebugNanCount, std::move(info)), log_slice_(log_slice), buffers_(std::move(buffers)) {} @@ -53,6 +48,11 @@ class BuffersDebugNanCountThunk : public Thunk { return {}; } + const absl::flat_hash_map& + buffer_slices() const { + return buffers_; + } + private: // Loaded in Initialize. std::optional @@ -60,7 +60,7 @@ class BuffersDebugNanCountThunk : public Thunk { std::optional kernel_bf16_; BufferAllocation::Slice log_slice_; - absl::flat_hash_map buffers_; + absl::flat_hash_map buffers_; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk_test.cc index aef4777ab8a..3aceee5b652 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_nan_count_thunk_test.cc @@ -40,6 +40,7 @@ limitations under the License. 
#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/statusor.h" +#include "xla/types.h" namespace xla::gpu { namespace { @@ -85,12 +86,19 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) { BufferAllocation alloc(/*index=*/0, /*size=*/kTotalDeviceMemoryBytes, /*color=*/0); + int64_t input_offset = kLogSize; BufferAllocation::Slice log_slice(&alloc, /*offset=*/0, kLogSize); + input_offset += kLogSize; + BufferAllocation::Slice inputs[2]; - for (int i = 0; i < 2; ++i) { - inputs[i] = BufferAllocation::Slice( - &alloc, /*offset=*/kLogSize + i * kInputSizeInBytes, kInputSizeInBytes); - } + int64_t input_size_bf16 = kInputElems * sizeof(Eigen::bfloat16); + inputs[0] = BufferAllocation::Slice(&alloc, input_offset, input_size_bf16, + PrimitiveType::BF16); + input_offset += input_size_bf16; + + inputs[1] = BufferAllocation::Slice( + &alloc, input_offset, kInputElems * sizeof(float), PrimitiveType::F32); + BufferAllocations allocations( {executor_->AllocateArray(kTotalDeviceMemoryBytes)}, executor_->device_ordinal(), allocator_.get()); @@ -102,13 +110,18 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) { BufferDebugLog::CreateOnDevice( *stream_, se::DeviceMemory(log_mem))); // Fill inputs with some data - std::vector data(kInputElems, 0); - data[123] = std::numeric_limits::quiet_NaN(); - TF_ASSERT_OK(stream_->Memcpy(&inputs0_mem, data.data(), kInputSizeInBytes)); - data[123] = 0; - data[456] = std::numeric_limits::quiet_NaN(); - data[789] = std::numeric_limits::quiet_NaN(); - TF_ASSERT_OK(stream_->Memcpy(&inputs1_mem, data.data(), kInputSizeInBytes)); + { + std::vector data(kInputElems, Eigen::bfloat16(0)); + data[123] = std::numeric_limits::quiet_NaN(); + TF_ASSERT_OK(stream_->Memcpy(&inputs0_mem, data.data(), kInputSizeInBytes)); + } + { + std::vector data(kInputElems, 0); + data[456] = std::numeric_limits::quiet_NaN(); + data[789] = 
std::numeric_limits::quiet_NaN(); + TF_ASSERT_OK(stream_->Memcpy(&inputs1_mem, data.data(), kInputSizeInBytes)); + } + // Setup parameters for Initialize/Prepare/ExecuteOnStream Thunk::InitializeParams init_params; init_params.executor = executor_; @@ -121,10 +134,8 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) { BuffersDebugNanCountThunk thunk( Thunk::ThunkInfo(), log_slice, - {{ThunkBufferId::Create(ThunkId(123), 4).value(), - {inputs[0], PrimitiveType::F32}}, - {ThunkBufferId::Create(ThunkId(456), 8).value(), - {inputs[1], PrimitiveType::F32}}}); + {{ThunkBufferId::Create(ThunkId(123), 4).value(), inputs[0]}, + {ThunkBufferId::Create(ThunkId(456), 8).value(), inputs[1]}}); TF_ASSERT_OK(thunk.Initialize(init_params)); TF_ASSERT_OK(thunk.Prepare(Thunk::PrepareParams{}, resource_requests)); TF_ASSERT_OK(thunk.ExecuteOnStream(execute_params)); diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc index 52cd22d032e..48bddfd291c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/backends/gpu/runtime/buffers_checksum_thunk.h" +#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h" #include "xla/backends/gpu/runtime/custom_call_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -65,7 +66,7 @@ namespace { // If the thunk got wrapped, the data dependencies between the thunks will be // configured to ensure `predecessor_thunk` executes before the wrapped thunk // and `successor_thunk` executes after. 
-absl::StatusOr> WrapThunk( +absl::StatusOr> WrapWithChecksumThunk( std::unique_ptr thunk, BufferAllocation::Slice log_slice, const Thunk& predecessor_thunk, Thunk& successor_thunk) { const auto& thunk_buffers = thunk->buffer_uses(); @@ -127,6 +128,79 @@ absl::StatusOr> WrapThunk( return wrapped_thunk; } +absl::StatusOr> WrapWithNanCounterThunk( + std::unique_ptr thunk, BufferAllocation::Slice log_slice, + const Thunk& predecessor_thunk, Thunk& successor_thunk) { + const auto& thunk_buffers = thunk->buffer_uses(); + if (thunk_buffers.empty()) { + VLOG(1) << "No buffers in thunk " << thunk->thunk_info().thunk_id + << ", skipping"; + return thunk; + } + + absl::flat_hash_map buffers_to_check; + for (size_t buffer_idx = 0; buffer_idx < thunk_buffers.size(); ++buffer_idx) { + VLOG(1) << "Buffer " << buffer_idx << " in thunk " + << thunk->thunk_info().thunk_id; + const BufferUse& use = thunk_buffers[buffer_idx]; + const BufferAllocation::Slice& slice = use.slice(); + if (slice.allocation() == nullptr) { + VLOG(1) << "Buffer " << buffer_idx << " in thunk " + << thunk->thunk_info().thunk_id + << " has null allocation, skipping"; + continue; + } + auto buffer_id = + ThunkBufferId::Create(thunk->thunk_info().thunk_id, buffer_idx); + if (!buffer_id.ok()) { + LOG(WARNING) << "ThunkBufferId::Create failed: Skipping buffer " + << buffer_idx << " in thunk " << thunk->thunk_info().thunk_id + << ": " << buffer_id.status(); + continue; + } + if (slice.element_type() != PrimitiveType::F32 && + slice.element_type() != PrimitiveType::BF16) { + VLOG(1) << "Buffer " << buffer_idx << " in thunk " + << thunk->thunk_info().thunk_id + << " has unsupported element type " + << PrimitiveType_Name(slice.element_type()) << ", skipping"; + continue; + } + if (!use.HasDefinedContentsOnOutput()) { + VLOG(1) << "Buffer " << buffer_idx << " in thunk " + << thunk->thunk_info().thunk_id + << " has no defined contents on output, skipping"; + continue; + } + buffers_to_check.emplace(buffer_id.value(), 
use.slice()); + VLOG(1) << "Found buffer " << buffer_idx << " in thunk " + << thunk->thunk_info().thunk_id << " with element type " + << PrimitiveType_Name(slice.element_type()) << " and size " + << slice.size(); + } + + if (buffers_to_check.empty()) { + return thunk; + } + + VLOG(1) << "Wrapping thunk " << thunk->thunk_info().thunk_id + << " with nan counter thunk due to presence of buffers: " + << buffers_to_check.size(); + std::vector> thunk_and_checks; + Thunk* thunk_ptr = thunk.get(); + thunk_and_checks.push_back(std::move(thunk)); + auto buffer_debug_nan_counter_thunk = + std::make_unique(Thunk::ThunkInfo(), log_slice, + std::move(buffers_to_check)); + buffer_debug_nan_counter_thunk->add_control_predecessor(thunk_ptr); + thunk_and_checks.push_back(std::move(buffer_debug_nan_counter_thunk)); + auto wrapped_thunk = std::make_unique( + Thunk::ThunkInfo(), std::move(thunk_and_checks)); + wrapped_thunk->add_control_predecessor(&predecessor_thunk); + successor_thunk.add_control_predecessor(wrapped_thunk.get()); + return wrapped_thunk; +} + XLA_FFI_DEFINE_HANDLER_SYMBOL( kDebugLogInitHandler, [](se::Stream* absl_nonnull stream, xla::ffi::Buffer log_buffer) { @@ -207,10 +281,21 @@ absl::StatusOr ThunkBufferDebugPass::Run( ThunkSequence& thunks = root_thunk->thunks(); for (auto& thunk : thunks) { - TF_ASSIGN_OR_RETURN( - thunk, WrapThunk(std::move(thunk), log_slice, - /*predecessor_thunk=*/*buffer_debug_init_thunk.get(), - /*successor_thunk=*/*buffer_debug_dump_thunk.get())); + if (mode_ == Mode::kChecksum) { + VLOG(1) << "Wrapping with checksum thunk"; + TF_ASSIGN_OR_RETURN( + thunk, WrapWithChecksumThunk( + std::move(thunk), log_slice, + /*predecessor_thunk=*/*buffer_debug_init_thunk.get(), + /*successor_thunk=*/*buffer_debug_dump_thunk.get())); + } else if (mode_ == Mode::kNanCounter) { + VLOG(1) << "Wrapping with nan counter thunk"; + TF_ASSIGN_OR_RETURN( + thunk, WrapWithNanCounterThunk( + std::move(thunk), log_slice, + 
/*predecessor_thunk=*/*buffer_debug_init_thunk.get(), + /*successor_thunk=*/*buffer_debug_dump_thunk.get())); + } } thunks.reserve(thunks.size() + 2); diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.h b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.h index f219f5d3845..3b2219d89c5 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.h +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass.h @@ -30,7 +30,12 @@ namespace gpu { // Adds buffer debug tracing to thunks. class ThunkBufferDebugPass : public ThunkPassInterface { public: - ThunkBufferDebugPass() = default; + enum class Mode { + kChecksum, + kNanCounter, + }; + + explicit ThunkBufferDebugPass(Mode mode) : mode_(mode) {} absl::string_view name() const override { return "thunk-buffer-debug"; } @@ -39,6 +44,9 @@ class ThunkBufferDebugPass : public ThunkPassInterface { const HloModule* absl_nullable hlo_module, const se::DeviceDescription& device_info, ThunkPassBufferAllocator& allocator) override; + + private: + Mode mode_; }; } // namespace gpu diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass_test.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass_test.cc index bd585130701..6bf72669756 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/backends/gpu/runtime/buffers_checksum_thunk.h" +#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h" #include "xla/backends/gpu/runtime/custom_call_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/thunk.h" @@ -101,7 +102,7 @@ TEST(ThunkBufferDebugPassTest, IsNoOpWhenHloModuleIsNull) { auto root_thunk = std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); - ThunkBufferDebugPass pass; + ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kChecksum); TF_ASSERT_OK_AND_ASSIGN( bool changed, pass.Run(root_thunk.get(), debug_options, /*hlo_module=*/nullptr, device_info, allocator)); @@ -152,7 +153,7 @@ TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) { auto root_thunk = std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); - ThunkBufferDebugPass pass; + ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kChecksum); TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(root_thunk.get(), debug_options, &hlo_module, device_info, allocator)); @@ -205,6 +206,91 @@ TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) { Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io))); } +TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugNanCounterThunks) { + static constexpr ThunkId kTestThunkId = ThunkId(123); + DebugOptions debug_options; + debug_options.set_xla_gpu_experimental_enable_nan_counter_on_thunks(true); + se::DeviceDescription device_info; + FakeThunkPassBufferAllocator allocator; + // The callbacks created by ThunkBufferDebugPass require a HloModule with + // a non-null entry computation. 
+ auto builder = HloComputation::Builder("entry"); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + std::unique_ptr entry_computation = builder.Build(root); + HloModule hlo_module("test_module", HloModuleConfig()); + hlo_module.AddEntryComputation(std::move(entry_computation)); + // Create a fake thunk with a few different buffer uses. + BufferAllocation alloc(0, 1024, 0); + BufferAllocation::Slice slice_i(&alloc, 0, 1, PrimitiveType::F32); + BufferAllocation::Slice slice_o(&alloc, 1, 1, PrimitiveType::F32); + BufferAllocation::Slice slice_io(&alloc, 2, 1, PrimitiveType::F32); + BufferAllocation::Slice slice_scratch(&alloc, 3, 1, PrimitiveType::F32); + Thunk::ThunkInfo fake_thunk_info; + fake_thunk_info.thunk_id = ThunkId(kTestThunkId); + auto fake_thunk = std::make_unique( + fake_thunk_info, + Thunk::BufferUses{ + // Consume means the thunk can reuse the buffer for scratch space, so + // only check it on input. + BufferUse::Consume(slice_i), + // Write is undefined on input, but defined on output. + BufferUse::Write(slice_o), + // Unlike Consume, Read is supposed to preserve the contents of the + // buffer, so we check it on input *and* output. + BufferUse::Read(slice_io), + // Scratch buffers are not checked at all. + BufferUse::Scratch(slice_scratch), + }); + Thunk* fake_thunk_ptr = fake_thunk.get(); + std::vector> thunks; + thunks.push_back(std::move(fake_thunk)); + auto root_thunk = + std::make_unique(Thunk::ThunkInfo(), std::move(thunks)); + + ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kNanCounter); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + pass.Run(root_thunk.get(), debug_options, &hlo_module, + device_info, allocator)); + EXPECT_TRUE(changed); + + // Expected thunk structure after the pass: + // 1. CustomCallThunk (buffer debug log init) + // 2. SequentialThunk + // 1. FakeThunk + // 2. BuffersDebugNanCountThunk (nan counter output buffers) + // 3. 
CustomCallThunk (buffer debug log dump) + const std::vector>& new_thunks = root_thunk->thunks(); + EXPECT_THAT(new_thunks, SizeIs(3)); + EXPECT_EQ(new_thunks[0]->kind(), Thunk::Kind::kCustomCall); + EXPECT_EQ(new_thunks[1]->kind(), Thunk::Kind::kSequential); + EXPECT_EQ(new_thunks[2]->kind(), Thunk::Kind::kCustomCall); + + const CustomCallThunk& buffer_debug_init_thunk = + static_cast(*new_thunks[0]); + EXPECT_EQ(buffer_debug_init_thunk.target_name(), + "xla_gpu_buffer_debug_log_init"); + + const CustomCallThunk& buffer_debug_dump_thunk = + static_cast(*new_thunks[2]); + EXPECT_EQ(buffer_debug_dump_thunk.target_name(), + "xla_gpu_buffer_debug_log_dump"); + + const std::vector>& sub_thunks = + static_cast(*new_thunks[1]).thunks(); + EXPECT_THAT(sub_thunks, SizeIs(2)); + EXPECT_THAT(sub_thunks[0], Pointer(fake_thunk_ptr)); + EXPECT_EQ(sub_thunks[1]->kind(), Thunk::Kind::kBuffersDebugNanCount); + + const BuffersDebugNanCountThunk& buffer_debug_after_fake_thunk = + static_cast(*sub_thunks[1]); + EXPECT_THAT( + buffer_debug_after_fake_thunk.buffer_slices(), + UnorderedElementsAre( + Pair(ThunkBufferId::Create(kTestThunkId, 1).value(), slice_o), + Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 35cf6a5d6dc..c4ab2c7252f 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -471,6 +471,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_collective_call_terminate_timeout_seconds(40); opts.set_xla_keep_shardings_after_spmd(false); + opts.set_xla_gpu_experimental_enable_checksum_tracing_on_thunks(false); + opts.set_xla_gpu_experimental_enable_nan_counter_on_thunks(false); return opts; } @@ -2654,6 +2656,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, 
debug_options->xla_gpu_experimental_enable_checksum_tracing_on_thunks(), "Enables an experimental feature to record checksums of selected thunk " "inputs/outputs.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_experimental_enable_nan_counter_on_thunks", + bool_setter_for( + &DebugOptions::set_xla_gpu_experimental_enable_nan_counter_on_thunks), + debug_options->xla_gpu_experimental_enable_nan_counter_on_thunks(), + "Enables an experimental feature to record the number of nans in thunk " + "outputs.")); flag_list->push_back(tsl::Flag( "xla_gpu_experimental_thunk_buffer_debug_filter_by_thunk_id_ranges", setter_for_thunk_buffer_debug_filter_by_thunk_id, "(none)", diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index d520201d541..87232351aff 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5866,6 +5866,9 @@ tf_proto_library( name = "buffer_assignment_proto", srcs = ["buffer_assignment.proto"], make_default_target_header_only = True, + protodeps = [ + "//xla:xla_data_proto", + ], ) cc_library( diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index 00bbf5f40a6..fef83711d76 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -245,6 +245,7 @@ BufferAllocation::Slice::ToProto() const { proto.set_offset(offset()); proto.set_size(size()); proto.set_buffer_allocation_index(allocation() == nullptr ? 
-1 : index()); + proto.set_element_type(element_type()); return proto; } @@ -259,13 +260,14 @@ absl::StatusOr BufferAllocation::Slice::FromProto( } const BufferAllocation& allocation = buffer_allocations[proto.buffer_allocation_index()]; - return BufferAllocation::Slice(&allocation, proto.offset(), proto.size()); + return BufferAllocation::Slice(&allocation, proto.offset(), proto.size(), + proto.element_type()); } BufferAllocation::Slice BufferAllocation::GetSlice( const HloValue& buffer) const { const OffsetSize os = FindOrDie(assigned_buffers_, &buffer); - return Slice(this, os.offset, os.size); + return Slice(this, os.offset, os.size, buffer.shape().element_type()); } absl::Status BufferAllocation::AddAssignment(const HloValue& buffer, @@ -331,6 +333,8 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id()); proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); + proto_assigned->set_element_type( + buffer_offset_size.first->shape().element_type()); } absl::c_sort(*proto.mutable_assigned(), [](const BufferAllocationProto::Assigned& assign1, diff --git a/third_party/xla/xla/service/buffer_assignment.h b/third_party/xla/xla/service/buffer_assignment.h index c33ca009589..0ba61f183b4 100644 --- a/third_party/xla/xla/service/buffer_assignment.h +++ b/third_party/xla/xla/service/buffer_assignment.h @@ -190,15 +190,22 @@ class BufferAllocation { class Slice { public: Slice() = default; - Slice(const BufferAllocation* allocation, int64_t offset, int64_t size) - : allocation_(allocation), offset_(offset), size_(size) {} + Slice(const BufferAllocation* allocation, int64_t offset, int64_t size, + PrimitiveType element_type = PrimitiveType::PRIMITIVE_TYPE_INVALID) + : allocation_(allocation), + offset_(offset), + size_(size), + element_type_(element_type) {} const BufferAllocation* allocation() const { return allocation_; } Index 
index() const { return allocation_->index(); } int64_t offset() const { return offset_; } int64_t size() const { return size_; } + PrimitiveType element_type() const { return element_type_; } bool operator==(const Slice& other) const { + // We don't compare element_type_ because it's not always set, and it's + // not relevant for the comparison here. return index() == other.index() && offset_ == other.offset_ && size_ == other.size_; } @@ -252,6 +259,7 @@ class BufferAllocation { const BufferAllocation* allocation_ = nullptr; int64_t offset_ = 0; int64_t size_ = 0; + PrimitiveType element_type_ = PrimitiveType::PRIMITIVE_TYPE_INVALID; }; // GetSlice returns the Slice of contiguous memory that holds the value diff --git a/third_party/xla/xla/service/buffer_assignment.proto b/third_party/xla/xla/service/buffer_assignment.proto index 6f6b2ac35ae..df6714b0faa 100644 --- a/third_party/xla/xla/service/buffer_assignment.proto +++ b/third_party/xla/xla/service/buffer_assignment.proto @@ -15,6 +15,8 @@ limitations under the License. 
syntax = "proto3"; +import "xla/xla_data.proto"; + package xla.buffer_assignment; // This defines the buffer isolation configuration, which is a debugging tool to @@ -108,4 +110,5 @@ message BufferAllocationSliceProto { int64 offset = 1; int64 size = 2; int64 buffer_allocation_index = 3; + xla.PrimitiveType element_type = 4; } diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index ca4c1f304db..d7476020579 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -179,7 +179,12 @@ static absl::Status RunThunkPasses(const DebugOptions& debug_options, ThunkPassBufferAllocator& allocator) { ThunkPassPipeline pipeline("thunk-passes"); if (debug_options.xla_gpu_experimental_enable_checksum_tracing_on_thunks()) { - pipeline.AddPass(std::make_unique()); + pipeline.AddPass(std::make_unique( + ThunkBufferDebugPass::Mode::kChecksum)); + } + if (debug_options.xla_gpu_experimental_enable_nan_counter_on_thunks()) { + pipeline.AddPass(std::make_unique( + ThunkBufferDebugPass::Mode::kNanCounter)); } if (debug_options.xla_gpu_experimental_enable_command_buffer_on_thunks()) { pipeline.AddPass(std::make_unique( diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto index 6a8bf2aadec..6ea3da98fd7 100644 --- a/third_party/xla/xla/service/hlo.proto +++ b/third_party/xla/xla/service/hlo.proto @@ -667,6 +667,7 @@ message BufferAllocationProto { int64 logical_buffer_id = 1; int64 offset = 2; int64 size = 3; + xla.PrimitiveType element_type = 4; } int64 index = 1; diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index bb7c7c3784d..3d889f63a1b 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -348,7 +348,7 @@ message DebugOptions { // Limits the thunk buffer debug instrumentation to specific thunks. 
optional ThunkBufferDebugFilter xla_gpu_experimental_thunk_buffer_debug_filter = 423; // If true, every time an HLO module is run, we will dump an // HloUnoptimizedSnapshot (essentially, a serialized unoptimizedmodule plus @@ -657,6 +657,9 @@ message DebugOptions { optional bool xla_gpu_experimental_enable_heuristic_collective_combining = 366; + // If true, record the number of NaNs in F32/BF16 thunk output buffers. + optional bool xla_gpu_experimental_enable_nan_counter_on_thunks = 424; + // Enable NCCL symmetric buffers. optional bool xla_gpu_experimental_enable_nccl_symmetric_buffers = 406; @@ -1388,7 +1391,7 @@ message DebugOptions { // Note: when adding a new flag, please add it to one of the hardware-specific // or hardware-agnostic sections at the top of this proto message. - // Next id: 424 + // Next id: 425 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend.