[XLA:GPU] add NanCount thunk to thunk_buffer_debug_pass

In NaN-counter mode, the pass appends a NanCount thunk after each wrapped thunk to count NaNs in its f32 and bf16 output buffers.

PiperOrigin-RevId: 826808271
This commit is contained in:
Ilya Tikhonovskiy 2025-11-01 03:05:12 -07:00 committed by TensorFlower Gardener
parent 459ba30568
commit 4f3f2c9444
15 changed files with 274 additions and 44 deletions

View File

@ -2902,6 +2902,7 @@ cc_library(
hdrs = ["thunk_buffer_debug_pass.h"], hdrs = ["thunk_buffer_debug_pass.h"],
deps = [ deps = [
":buffers_checksum_thunk", ":buffers_checksum_thunk",
":buffers_nan_count_thunk",
":custom_call_thunk", ":custom_call_thunk",
":sequential_thunk", ":sequential_thunk",
":thunk", ":thunk",
@ -2934,6 +2935,7 @@ xla_cc_test(
srcs = ["thunk_buffer_debug_pass_test.cc"], srcs = ["thunk_buffer_debug_pass_test.cc"],
deps = [ deps = [
":buffers_checksum_thunk", ":buffers_checksum_thunk",
":buffers_nan_count_thunk",
":custom_call_thunk", ":custom_call_thunk",
":sequential_thunk", ":sequential_thunk",
":thunk", ":thunk",
@ -3096,6 +3098,7 @@ xla_test(
":thunk", ":thunk",
":thunk_buffer_id", ":thunk_buffer_id",
":thunk_id", ":thunk_id",
"//xla:types",
"//xla/service:buffer_assignment", "//xla/service:buffer_assignment",
"//xla/service:executable", "//xla/service:executable",
"//xla/service/gpu:buffer_allocations", "//xla/service/gpu:buffer_allocations",

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.h"
#include "xla/service/buffer_assignment.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory.h"
@ -88,18 +87,21 @@ absl::Status BuffersDebugNanCountThunk::ExecuteOnStream(
se::gpu::BufferDebugLog buffer_debug_log = se::gpu::BufferDebugLog buffer_debug_log =
se::gpu::BufferDebugLog::FromDeviceMemoryUnchecked(log_ptr); se::gpu::BufferDebugLog::FromDeviceMemoryUnchecked(log_ptr);
for (const auto& [entry_id, buffer_slice_pair] : buffers_) { for (const auto& [entry_id, buffer] : buffers_) {
BufferAllocation::Slice buffer = buffer_slice_pair.buffer; PrimitiveType buffer_type = buffer.element_type();
PrimitiveType buffer_type = buffer_slice_pair.element_type;
se::DeviceMemoryBase device_buffer = se::DeviceMemoryBase device_buffer =
params.buffer_allocations->GetDeviceAddress(buffer); params.buffer_allocations->GetDeviceAddress(buffer);
if (buffer_type == PrimitiveType::F32) { if (buffer_type == PrimitiveType::F32) {
VLOG(1) << "F32 buffer detected with id: " << entry_id
<< " and size: " << device_buffer.size();
se::DeviceMemory<float> f32_buffer(device_buffer); se::DeviceMemory<float> f32_buffer(device_buffer);
TF_RETURN_IF_ERROR(kernel_f32_->Launch( TF_RETURN_IF_ERROR(kernel_f32_->Launch(
thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id, thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id,
f32_buffer, f32_buffer.size(), buffer_debug_log.GetDeviceHeader(), f32_buffer, f32_buffer.size(), buffer_debug_log.GetDeviceHeader(),
buffer_debug_log.GetDeviceEntries())); buffer_debug_log.GetDeviceEntries()));
} else if (buffer_type == PrimitiveType::BF16) { } else if (buffer_type == PrimitiveType::BF16) {
VLOG(1) << "BF16 buffer detected with id: " << entry_id
<< " and size: " << device_buffer.size();
se::DeviceMemory<Eigen::bfloat16> bf16_buffer(device_buffer); se::DeviceMemory<Eigen::bfloat16> bf16_buffer(device_buffer);
TF_RETURN_IF_ERROR(kernel_bf16_->Launch( TF_RETURN_IF_ERROR(kernel_bf16_->Launch(
thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id, thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id,
@ -117,10 +119,9 @@ absl::Status BuffersDebugNanCountThunk::ExecuteOnStream(
std::string BuffersDebugNanCountThunk::ToString(int indent) const { std::string BuffersDebugNanCountThunk::ToString(int indent) const {
std::string result; std::string result;
absl::StrAppend(&result, ", buffers = ", buffers_.size()); absl::StrAppend(&result, ", buffers = ", buffers_.size());
for (const auto& [buffer_id, buffer_slice_pair] : buffers_) { for (const auto& [buffer_id, buffer] : buffers_) {
absl::StrAppend(&result, "\n", std::string(indent + 2, ' '), absl::StrAppend(&result, "\n", std::string(indent + 2, ' '),
"buffer_id: ", buffer_id, "buffer_id: ", buffer_id, ", buffer: ", buffer.ToString());
", buffer: ", buffer_slice_pair.buffer.ToString());
} }
return result; return result;
} }

View File

@ -31,14 +31,9 @@ namespace xla::gpu {
class BuffersDebugNanCountThunk : public Thunk { class BuffersDebugNanCountThunk : public Thunk {
public: public:
struct BufferToCount {
BufferAllocation::Slice buffer;
PrimitiveType element_type;
};
explicit BuffersDebugNanCountThunk( explicit BuffersDebugNanCountThunk(
ThunkInfo info, BufferAllocation::Slice log_slice, ThunkInfo info, BufferAllocation::Slice log_slice,
absl::flat_hash_map<ThunkBufferId, BufferToCount> buffers) absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> buffers)
: Thunk(Thunk::Kind::kBuffersDebugNanCount, std::move(info)), : Thunk(Thunk::Kind::kBuffersDebugNanCount, std::move(info)),
log_slice_(log_slice), log_slice_(log_slice),
buffers_(std::move(buffers)) {} buffers_(std::move(buffers)) {}
@ -53,6 +48,11 @@ class BuffersDebugNanCountThunk : public Thunk {
return {}; return {};
} }
const absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice>&
buffer_slices() const {
return buffers_;
}
private: private:
// Loaded in Initialize. // Loaded in Initialize.
std::optional<stream_executor::gpu::BufferDebugNanCountF32Kernel::KernelType> std::optional<stream_executor::gpu::BufferDebugNanCountF32Kernel::KernelType>
@ -60,7 +60,7 @@ class BuffersDebugNanCountThunk : public Thunk {
std::optional<stream_executor::gpu::BufferDebugNanCountBf16Kernel::KernelType> std::optional<stream_executor::gpu::BufferDebugNanCountBf16Kernel::KernelType>
kernel_bf16_; kernel_bf16_;
BufferAllocation::Slice log_slice_; BufferAllocation::Slice log_slice_;
absl::flat_hash_map<ThunkBufferId, BufferToCount> buffers_; absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> buffers_;
}; };
} // namespace xla::gpu } // namespace xla::gpu

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/statusor.h" #include "xla/tsl/platform/statusor.h"
#include "xla/types.h"
namespace xla::gpu { namespace xla::gpu {
namespace { namespace {
@ -85,12 +86,19 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
BufferAllocation alloc(/*index=*/0, BufferAllocation alloc(/*index=*/0,
/*size=*/kTotalDeviceMemoryBytes, /*size=*/kTotalDeviceMemoryBytes,
/*color=*/0); /*color=*/0);
int64_t input_offset = kLogSize;
BufferAllocation::Slice log_slice(&alloc, /*offset=*/0, kLogSize); BufferAllocation::Slice log_slice(&alloc, /*offset=*/0, kLogSize);
input_offset += kLogSize;
BufferAllocation::Slice inputs[2]; BufferAllocation::Slice inputs[2];
for (int i = 0; i < 2; ++i) { int64_t input_size_bf16 = kInputElems * sizeof(Eigen::bfloat16);
inputs[i] = BufferAllocation::Slice( inputs[0] = BufferAllocation::Slice(&alloc, input_offset, input_size_bf16,
&alloc, /*offset=*/kLogSize + i * kInputSizeInBytes, kInputSizeInBytes); PrimitiveType::BF16);
} input_offset += input_size_bf16;
inputs[1] = BufferAllocation::Slice(
&alloc, input_offset, kInputElems * sizeof(float), PrimitiveType::F32);
BufferAllocations allocations( BufferAllocations allocations(
{executor_->AllocateArray<uint8_t>(kTotalDeviceMemoryBytes)}, {executor_->AllocateArray<uint8_t>(kTotalDeviceMemoryBytes)},
executor_->device_ordinal(), allocator_.get()); executor_->device_ordinal(), allocator_.get());
@ -102,13 +110,18 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
BufferDebugLog::CreateOnDevice( BufferDebugLog::CreateOnDevice(
*stream_, se::DeviceMemory<uint8_t>(log_mem))); *stream_, se::DeviceMemory<uint8_t>(log_mem)));
// Fill inputs with some data // Fill inputs with some data
std::vector<float> data(kInputElems, 0); {
data[123] = std::numeric_limits<float>::quiet_NaN(); std::vector<Eigen::bfloat16> data(kInputElems, Eigen::bfloat16(0));
TF_ASSERT_OK(stream_->Memcpy(&inputs0_mem, data.data(), kInputSizeInBytes)); data[123] = std::numeric_limits<Eigen::bfloat16>::quiet_NaN();
data[123] = 0; TF_ASSERT_OK(stream_->Memcpy(&inputs0_mem, data.data(), kInputSizeInBytes));
data[456] = std::numeric_limits<float>::quiet_NaN(); }
data[789] = std::numeric_limits<float>::quiet_NaN(); {
TF_ASSERT_OK(stream_->Memcpy(&inputs1_mem, data.data(), kInputSizeInBytes)); std::vector<float> data(kInputElems, 0);
data[456] = std::numeric_limits<float>::quiet_NaN();
data[789] = std::numeric_limits<float>::quiet_NaN();
TF_ASSERT_OK(stream_->Memcpy(&inputs1_mem, data.data(), kInputSizeInBytes));
}
// Setup parameters for Initialize/Prepare/ExecuteOnStream // Setup parameters for Initialize/Prepare/ExecuteOnStream
Thunk::InitializeParams init_params; Thunk::InitializeParams init_params;
init_params.executor = executor_; init_params.executor = executor_;
@ -121,10 +134,8 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
BuffersDebugNanCountThunk thunk( BuffersDebugNanCountThunk thunk(
Thunk::ThunkInfo(), log_slice, Thunk::ThunkInfo(), log_slice,
{{ThunkBufferId::Create(ThunkId(123), 4).value(), {{ThunkBufferId::Create(ThunkId(123), 4).value(), inputs[0]},
{inputs[0], PrimitiveType::F32}}, {ThunkBufferId::Create(ThunkId(456), 8).value(), inputs[1]}});
{ThunkBufferId::Create(ThunkId(456), 8).value(),
{inputs[1], PrimitiveType::F32}}});
TF_ASSERT_OK(thunk.Initialize(init_params)); TF_ASSERT_OK(thunk.Initialize(init_params));
TF_ASSERT_OK(thunk.Prepare(Thunk::PrepareParams{}, resource_requests)); TF_ASSERT_OK(thunk.Prepare(Thunk::PrepareParams{}, resource_requests));
TF_ASSERT_OK(thunk.ExecuteOnStream(execute_params)); TF_ASSERT_OK(thunk.ExecuteOnStream(execute_params));

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "xla/backends/gpu/runtime/buffers_checksum_thunk.h" #include "xla/backends/gpu/runtime/buffers_checksum_thunk.h"
#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h"
#include "xla/backends/gpu/runtime/custom_call_thunk.h" #include "xla/backends/gpu/runtime/custom_call_thunk.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.h"
@ -65,7 +66,7 @@ namespace {
// If the thunk got wrapped, the data dependencies between the thunks will be // If the thunk got wrapped, the data dependencies between the thunks will be
// configured to ensure `predecessor_thunk` executes before the wrapped thunk // configured to ensure `predecessor_thunk` executes before the wrapped thunk
// and `successor_thunk` executes after. // and `successor_thunk` executes after.
absl::StatusOr<std::unique_ptr<Thunk>> WrapThunk( absl::StatusOr<std::unique_ptr<Thunk>> WrapWithChecksumThunk(
std::unique_ptr<Thunk> thunk, BufferAllocation::Slice log_slice, std::unique_ptr<Thunk> thunk, BufferAllocation::Slice log_slice,
const Thunk& predecessor_thunk, Thunk& successor_thunk) { const Thunk& predecessor_thunk, Thunk& successor_thunk) {
const auto& thunk_buffers = thunk->buffer_uses(); const auto& thunk_buffers = thunk->buffer_uses();
@ -127,6 +128,79 @@ absl::StatusOr<std::unique_ptr<Thunk>> WrapThunk(
return wrapped_thunk; return wrapped_thunk;
} }
// Wraps `thunk` so that the NaNs in its output buffers are counted right after
// it runs.
//
// A buffer use of `thunk` is selected for counting when all of these hold:
//   - its slice has a non-null allocation,
//   - a ThunkBufferId can be created for (thunk id, buffer index),
//   - its slice element type is F32 or BF16,
//   - the use reports defined contents on output.
//
// If `thunk` has no buffer uses, or none of them is selected, `thunk` is
// returned as is. Otherwise the result is a SequentialThunk holding `thunk`
// followed by a BuffersDebugNanCountThunk that writes into `log_slice`.
// `predecessor_thunk` becomes a control predecessor of the returned wrapper,
// and the wrapper becomes a control predecessor of `successor_thunk`.
absl::StatusOr<std::unique_ptr<Thunk>> WrapWithNanCounterThunk(
    std::unique_ptr<Thunk> thunk, BufferAllocation::Slice log_slice,
    const Thunk& predecessor_thunk, Thunk& successor_thunk) {
  // Copied up front so logging keeps working after `thunk` is moved from.
  const auto wrapped_id = thunk->thunk_info().thunk_id;

  const auto& uses = thunk->buffer_uses();
  if (uses.empty()) {
    VLOG(1) << "No buffers in thunk " << wrapped_id << ", skipping";
    return thunk;
  }

  absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> selected;
  size_t next_index = 0;
  for (const BufferUse& use : uses) {
    // The position of the use within the thunk identifies the buffer.
    const size_t buffer_idx = next_index++;
    VLOG(1) << "Buffer " << buffer_idx << " in thunk " << wrapped_id;

    const BufferAllocation::Slice& slice = use.slice();
    if (slice.allocation() == nullptr) {
      VLOG(1) << "Buffer " << buffer_idx << " in thunk " << wrapped_id
              << " has null allocation, skipping";
      continue;
    }

    auto buffer_id = ThunkBufferId::Create(wrapped_id, buffer_idx);
    if (!buffer_id.ok()) {
      LOG(WARNING) << "ThunkBufferId::Create failed: Skipping buffer "
                   << buffer_idx << " in thunk " << wrapped_id << ": "
                   << buffer_id.status();
      continue;
    }

    // Only F32 and BF16 slices are selected for NaN counting.
    const PrimitiveType element_type = slice.element_type();
    const bool is_supported_type = element_type == PrimitiveType::F32 ||
                                   element_type == PrimitiveType::BF16;
    if (!is_supported_type) {
      VLOG(1) << "Buffer " << buffer_idx << " in thunk " << wrapped_id
              << " has unsupported element type "
              << PrimitiveType_Name(element_type) << ", skipping";
      continue;
    }

    // The counter runs after the thunk, so only uses with defined contents on
    // output are selected.
    if (!use.HasDefinedContentsOnOutput()) {
      VLOG(1) << "Buffer " << buffer_idx << " in thunk " << wrapped_id
              << " has no defined contents on output, skipping";
      continue;
    }

    selected.emplace(*buffer_id, slice);
    VLOG(1) << "Found buffer " << buffer_idx << " in thunk " << wrapped_id
            << " with element type " << PrimitiveType_Name(element_type)
            << " and size " << slice.size();
  }

  if (selected.empty()) {
    return thunk;
  }

  VLOG(1) << "Wrapping thunk " << wrapped_id
          << " with nan counter thunk due to presence of buffers: "
          << selected.size();

  // The counter must observe the thunk's outputs, so it runs after the thunk.
  Thunk* const original = thunk.get();
  auto nan_counter = std::make_unique<BuffersDebugNanCountThunk>(
      Thunk::ThunkInfo(), log_slice, std::move(selected));
  nan_counter->add_control_predecessor(original);

  std::vector<std::unique_ptr<Thunk>> sequence;
  sequence.push_back(std::move(thunk));
  sequence.push_back(std::move(nan_counter));

  auto wrapper = std::make_unique<SequentialThunk>(Thunk::ThunkInfo(),
                                                   std::move(sequence));
  wrapper->add_control_predecessor(&predecessor_thunk);
  successor_thunk.add_control_predecessor(wrapper.get());
  return wrapper;
}
XLA_FFI_DEFINE_HANDLER_SYMBOL( XLA_FFI_DEFINE_HANDLER_SYMBOL(
kDebugLogInitHandler, kDebugLogInitHandler,
[](se::Stream* absl_nonnull stream, xla::ffi::Buffer<U8> log_buffer) { [](se::Stream* absl_nonnull stream, xla::ffi::Buffer<U8> log_buffer) {
@ -207,10 +281,21 @@ absl::StatusOr<bool> ThunkBufferDebugPass::Run(
ThunkSequence& thunks = root_thunk->thunks(); ThunkSequence& thunks = root_thunk->thunks();
for (auto& thunk : thunks) { for (auto& thunk : thunks) {
TF_ASSIGN_OR_RETURN( if (mode_ == Mode::kChecksum) {
thunk, WrapThunk(std::move(thunk), log_slice, VLOG(1) << "Wrapping with checksum thunk";
/*predecessor_thunk=*/*buffer_debug_init_thunk.get(), TF_ASSIGN_OR_RETURN(
/*successor_thunk=*/*buffer_debug_dump_thunk.get())); thunk, WrapWithChecksumThunk(
std::move(thunk), log_slice,
/*predecessor_thunk=*/*buffer_debug_init_thunk.get(),
/*successor_thunk=*/*buffer_debug_dump_thunk.get()));
} else if (mode_ == Mode::kNanCounter) {
VLOG(1) << "Wrapping with nan counter thunk";
TF_ASSIGN_OR_RETURN(
thunk, WrapWithNanCounterThunk(
std::move(thunk), log_slice,
/*predecessor_thunk=*/*buffer_debug_init_thunk.get(),
/*successor_thunk=*/*buffer_debug_dump_thunk.get()));
}
} }
thunks.reserve(thunks.size() + 2); thunks.reserve(thunks.size() + 2);

View File

@ -30,7 +30,12 @@ namespace gpu {
// Adds buffer debug tracing to thunks. // Adds buffer debug tracing to thunks.
class ThunkBufferDebugPass : public ThunkPassInterface { class ThunkBufferDebugPass : public ThunkPassInterface {
public: public:
ThunkBufferDebugPass() = default; enum class Mode {
kChecksum,
kNanCounter,
};
explicit ThunkBufferDebugPass(Mode mode) : mode_(mode) {}
absl::string_view name() const override { return "thunk-buffer-debug"; } absl::string_view name() const override { return "thunk-buffer-debug"; }
@ -39,6 +44,9 @@ class ThunkBufferDebugPass : public ThunkPassInterface {
const HloModule* absl_nullable hlo_module, const HloModule* absl_nullable hlo_module,
const se::DeviceDescription& device_info, const se::DeviceDescription& device_info,
ThunkPassBufferAllocator& allocator) override; ThunkPassBufferAllocator& allocator) override;
private:
Mode mode_;
}; };
} // namespace gpu } // namespace gpu

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "xla/backends/gpu/runtime/buffers_checksum_thunk.h" #include "xla/backends/gpu/runtime/buffers_checksum_thunk.h"
#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h"
#include "xla/backends/gpu/runtime/custom_call_thunk.h" #include "xla/backends/gpu/runtime/custom_call_thunk.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/thunk.h"
@ -101,7 +102,7 @@ TEST(ThunkBufferDebugPassTest, IsNoOpWhenHloModuleIsNull) {
auto root_thunk = auto root_thunk =
std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks)); std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
ThunkBufferDebugPass pass; ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kChecksum);
TF_ASSERT_OK_AND_ASSIGN( TF_ASSERT_OK_AND_ASSIGN(
bool changed, pass.Run(root_thunk.get(), debug_options, bool changed, pass.Run(root_thunk.get(), debug_options,
/*hlo_module=*/nullptr, device_info, allocator)); /*hlo_module=*/nullptr, device_info, allocator));
@ -152,7 +153,7 @@ TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) {
auto root_thunk = auto root_thunk =
std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks)); std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
ThunkBufferDebugPass pass; ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kChecksum);
TF_ASSERT_OK_AND_ASSIGN(bool changed, TF_ASSERT_OK_AND_ASSIGN(bool changed,
pass.Run(root_thunk.get(), debug_options, &hlo_module, pass.Run(root_thunk.get(), debug_options, &hlo_module,
device_info, allocator)); device_info, allocator));
@ -205,6 +206,91 @@ TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) {
Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io))); Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io)));
} }
// Runs ThunkBufferDebugPass in NaN-counter mode over a single fake thunk and
// checks the resulting thunk structure and the set of buffers handed to the
// inserted BuffersDebugNanCountThunk.
TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugNanCounterThunks) {
  static constexpr ThunkId kTestThunkId = ThunkId(123);

  DebugOptions debug_options;
  debug_options.set_xla_gpu_experimental_enable_nan_counter_on_thunks(true);
  se::DeviceDescription device_info;
  FakeThunkPassBufferAllocator allocator;

  // The callbacks created by ThunkBufferDebugPass require a HloModule with
  // a non-null entry computation, so build a trivial one.
  auto entry_builder = HloComputation::Builder("entry");
  HloInstruction* constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f)));
  std::unique_ptr<HloComputation> entry = entry_builder.Build(constant);
  HloModule hlo_module("test_module", HloModuleConfig());
  hlo_module.AddEntryComputation(std::move(entry));

  // Four one-byte F32 slices in one allocation, one per kind of buffer use.
  BufferAllocation alloc(0, 1024, 0);
  BufferAllocation::Slice slice_i(&alloc, 0, 1, PrimitiveType::F32);
  BufferAllocation::Slice slice_o(&alloc, 1, 1, PrimitiveType::F32);
  BufferAllocation::Slice slice_io(&alloc, 2, 1, PrimitiveType::F32);
  BufferAllocation::Slice slice_scratch(&alloc, 3, 1, PrimitiveType::F32);

  Thunk::ThunkInfo fake_info;
  fake_info.thunk_id = ThunkId(kTestThunkId);
  auto fake = std::make_unique<FakeThunk>(
      fake_info,
      Thunk::BufferUses{
          // Index 0: consumed buffer; expected to be left out of the NaN
          // counter (see the final expectation).
          BufferUse::Consume(slice_i),
          // Index 1: written buffer; expected to be NaN-counted.
          BufferUse::Write(slice_o),
          // Index 2: read buffer; expected to be NaN-counted.
          BufferUse::Read(slice_io),
          // Index 3: scratch buffer; expected to be left out.
          BufferUse::Scratch(slice_scratch),
      });
  Thunk* fake_raw = fake.get();

  std::vector<std::unique_ptr<Thunk>> initial_thunks;
  initial_thunks.push_back(std::move(fake));
  auto root_thunk = std::make_unique<SequentialThunk>(
      Thunk::ThunkInfo(), std::move(initial_thunks));

  ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kNanCounter);
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          pass.Run(root_thunk.get(), debug_options, &hlo_module,
                                   device_info, allocator));
  EXPECT_TRUE(changed);

  // Expected thunk structure after the pass:
  // 1. CustomCallThunk (buffer debug log init)
  // 2. SequentialThunk
  //    1. FakeThunk
  //    2. BuffersDebugNanCountThunk (NaN count of output buffers)
  // 3. CustomCallThunk (buffer debug log dump)
  const std::vector<std::unique_ptr<Thunk>>& top_level = root_thunk->thunks();
  EXPECT_THAT(top_level, SizeIs(3));
  EXPECT_EQ(top_level[0]->kind(), Thunk::Kind::kCustomCall);
  EXPECT_EQ(top_level[1]->kind(), Thunk::Kind::kSequential);
  EXPECT_EQ(top_level[2]->kind(), Thunk::Kind::kCustomCall);

  const CustomCallThunk& init_thunk =
      static_cast<const CustomCallThunk&>(*top_level[0]);
  EXPECT_EQ(init_thunk.target_name(), "xla_gpu_buffer_debug_log_init");

  const CustomCallThunk& dump_thunk =
      static_cast<const CustomCallThunk&>(*top_level[2]);
  EXPECT_EQ(dump_thunk.target_name(), "xla_gpu_buffer_debug_log_dump");

  const std::vector<std::unique_ptr<Thunk>>& wrapped =
      static_cast<const SequentialThunk&>(*top_level[1]).thunks();
  EXPECT_THAT(wrapped, SizeIs(2));
  EXPECT_THAT(wrapped[0], Pointer(fake_raw));
  EXPECT_EQ(wrapped[1]->kind(), Thunk::Kind::kBuffersDebugNanCount);

  // Only the Write (index 1) and Read (index 2) buffers are expected.
  const BuffersDebugNanCountThunk& nan_count_thunk =
      static_cast<const BuffersDebugNanCountThunk&>(*wrapped[1]);
  EXPECT_THAT(
      nan_count_thunk.buffer_slices(),
      UnorderedElementsAre(
          Pair(ThunkBufferId::Create(kTestThunkId, 1).value(), slice_o),
          Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io)));
}
} // namespace } // namespace
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla

View File

@ -471,6 +471,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_cpu_collective_call_terminate_timeout_seconds(40); opts.set_xla_cpu_collective_call_terminate_timeout_seconds(40);
opts.set_xla_keep_shardings_after_spmd(false); opts.set_xla_keep_shardings_after_spmd(false);
opts.set_xla_gpu_experimental_enable_checksum_tracing_on_thunks(false);
opts.set_xla_gpu_experimental_enable_nan_counter_on_thunks(false);
return opts; return opts;
} }
@ -2654,6 +2656,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_experimental_enable_checksum_tracing_on_thunks(), debug_options->xla_gpu_experimental_enable_checksum_tracing_on_thunks(),
"Enables an experimental feature to record checksums of selected thunk " "Enables an experimental feature to record checksums of selected thunk "
"inputs/outputs.")); "inputs/outputs."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_nan_counter_on_thunks",
bool_setter_for(
&DebugOptions::set_xla_gpu_experimental_enable_nan_counter_on_thunks),
debug_options->xla_gpu_experimental_enable_nan_counter_on_thunks(),
"Enables an experimental feature to record the number of nans in thunk "
"outputs."));
flag_list->push_back(tsl::Flag( flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_thunk_buffer_debug_filter_by_thunk_id_ranges", "xla_gpu_experimental_thunk_buffer_debug_filter_by_thunk_id_ranges",
setter_for_thunk_buffer_debug_filter_by_thunk_id, "(none)", setter_for_thunk_buffer_debug_filter_by_thunk_id, "(none)",

View File

@ -5866,6 +5866,9 @@ tf_proto_library(
name = "buffer_assignment_proto", name = "buffer_assignment_proto",
srcs = ["buffer_assignment.proto"], srcs = ["buffer_assignment.proto"],
make_default_target_header_only = True, make_default_target_header_only = True,
protodeps = [
"//xla:xla_data_proto",
],
) )
cc_library( cc_library(

View File

@ -245,6 +245,7 @@ BufferAllocation::Slice::ToProto() const {
proto.set_offset(offset()); proto.set_offset(offset());
proto.set_size(size()); proto.set_size(size());
proto.set_buffer_allocation_index(allocation() == nullptr ? -1 : index()); proto.set_buffer_allocation_index(allocation() == nullptr ? -1 : index());
proto.set_element_type(element_type());
return proto; return proto;
} }
@ -259,13 +260,14 @@ absl::StatusOr<BufferAllocation::Slice> BufferAllocation::Slice::FromProto(
} }
const BufferAllocation& allocation = const BufferAllocation& allocation =
buffer_allocations[proto.buffer_allocation_index()]; buffer_allocations[proto.buffer_allocation_index()];
return BufferAllocation::Slice(&allocation, proto.offset(), proto.size()); return BufferAllocation::Slice(&allocation, proto.offset(), proto.size(),
proto.element_type());
} }
BufferAllocation::Slice BufferAllocation::GetSlice( BufferAllocation::Slice BufferAllocation::GetSlice(
const HloValue& buffer) const { const HloValue& buffer) const {
const OffsetSize os = FindOrDie(assigned_buffers_, &buffer); const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
return Slice(this, os.offset, os.size); return Slice(this, os.offset, os.size, buffer.shape().element_type());
} }
absl::Status BufferAllocation::AddAssignment(const HloValue& buffer, absl::Status BufferAllocation::AddAssignment(const HloValue& buffer,
@ -331,6 +333,8 @@ BufferAllocationProto BufferAllocation::ToProto() const {
proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id()); proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_offset(buffer_offset_size.second.offset);
proto_assigned->set_size(buffer_offset_size.second.size); proto_assigned->set_size(buffer_offset_size.second.size);
proto_assigned->set_element_type(
buffer_offset_size.first->shape().element_type());
} }
absl::c_sort(*proto.mutable_assigned(), absl::c_sort(*proto.mutable_assigned(),
[](const BufferAllocationProto::Assigned& assign1, [](const BufferAllocationProto::Assigned& assign1,

View File

@ -190,15 +190,22 @@ class BufferAllocation {
class Slice { class Slice {
public: public:
Slice() = default; Slice() = default;
Slice(const BufferAllocation* allocation, int64_t offset, int64_t size) Slice(const BufferAllocation* allocation, int64_t offset, int64_t size,
: allocation_(allocation), offset_(offset), size_(size) {} PrimitiveType element_type = PrimitiveType::PRIMITIVE_TYPE_INVALID)
: allocation_(allocation),
offset_(offset),
size_(size),
element_type_(element_type) {}
const BufferAllocation* allocation() const { return allocation_; } const BufferAllocation* allocation() const { return allocation_; }
Index index() const { return allocation_->index(); } Index index() const { return allocation_->index(); }
int64_t offset() const { return offset_; } int64_t offset() const { return offset_; }
int64_t size() const { return size_; } int64_t size() const { return size_; }
PrimitiveType element_type() const { return element_type_; }
bool operator==(const Slice& other) const { bool operator==(const Slice& other) const {
// We don't compare element_type_ because it's not always set, and it's
// not relevant for the comparison here.
return index() == other.index() && offset_ == other.offset_ && return index() == other.index() && offset_ == other.offset_ &&
size_ == other.size_; size_ == other.size_;
} }
@ -252,6 +259,7 @@ class BufferAllocation {
const BufferAllocation* allocation_ = nullptr; const BufferAllocation* allocation_ = nullptr;
int64_t offset_ = 0; int64_t offset_ = 0;
int64_t size_ = 0; int64_t size_ = 0;
PrimitiveType element_type_ = PrimitiveType::PRIMITIVE_TYPE_INVALID;
}; };
// GetSlice returns the Slice of contiguous memory that holds the value // GetSlice returns the Slice of contiguous memory that holds the value

View File

@ -15,6 +15,8 @@ limitations under the License.
syntax = "proto3"; syntax = "proto3";
import "xla/xla_data.proto";
package xla.buffer_assignment; package xla.buffer_assignment;
// This defines the buffer isolation configuration, which is a debugging tool to // This defines the buffer isolation configuration, which is a debugging tool to
@ -108,4 +110,5 @@ message BufferAllocationSliceProto {
int64 offset = 1; int64 offset = 1;
int64 size = 2; int64 size = 2;
int64 buffer_allocation_index = 3; int64 buffer_allocation_index = 3;
xla.PrimitiveType element_type = 4;
} }

View File

@ -179,7 +179,12 @@ static absl::Status RunThunkPasses(const DebugOptions& debug_options,
ThunkPassBufferAllocator& allocator) { ThunkPassBufferAllocator& allocator) {
ThunkPassPipeline pipeline("thunk-passes"); ThunkPassPipeline pipeline("thunk-passes");
if (debug_options.xla_gpu_experimental_enable_checksum_tracing_on_thunks()) { if (debug_options.xla_gpu_experimental_enable_checksum_tracing_on_thunks()) {
pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>()); pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>(
ThunkBufferDebugPass::Mode::kChecksum));
}
if (debug_options.xla_gpu_experimental_enable_nan_counter_on_thunks()) {
pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>(
ThunkBufferDebugPass::Mode::kNanCounter));
} }
if (debug_options.xla_gpu_experimental_enable_command_buffer_on_thunks()) { if (debug_options.xla_gpu_experimental_enable_command_buffer_on_thunks()) {
pipeline.AddPass(std::make_unique<CommandBufferConversionPass>( pipeline.AddPass(std::make_unique<CommandBufferConversionPass>(

View File

@ -667,6 +667,7 @@ message BufferAllocationProto {
int64 logical_buffer_id = 1; int64 logical_buffer_id = 1;
int64 offset = 2; int64 offset = 2;
int64 size = 3; int64 size = 3;
xla.PrimitiveType element_type = 4;
} }
int64 index = 1; int64 index = 1;

View File

@ -348,7 +348,7 @@ message DebugOptions {
// Limits the thunk buffer debug instrumentation to specific thunks. // Limits the thunk buffer debug instrumentation to specific thunks.
optional ThunkBufferDebugFilter optional ThunkBufferDebugFilter
xla_gpu_experimental_thunk_buffer_debug_filter = 423; xla_gpu_experimental_thunk_buffer_debug_filter = 424;
// If true, every time an HLO module is run, we will dump an // If true, every time an HLO module is run, we will dump an
// HloUnoptimizedSnapshot (essentially, a serialized unoptimizedmodule plus // HloUnoptimizedSnapshot (essentially, a serialized unoptimizedmodule plus
@ -657,6 +657,9 @@ message DebugOptions {
optional bool xla_gpu_experimental_enable_heuristic_collective_combining = optional bool xla_gpu_experimental_enable_heuristic_collective_combining =
366; 366;
// If true, enable buffer nan counter on thunks.
optional bool xla_gpu_experimental_enable_nan_counter_on_thunks = 423;
// Enable NCCL symmetric buffers. // Enable NCCL symmetric buffers.
optional bool xla_gpu_experimental_enable_nccl_symmetric_buffers = 406; optional bool xla_gpu_experimental_enable_nccl_symmetric_buffers = 406;
@ -1388,7 +1391,7 @@ message DebugOptions {
// Note: when adding a new flag, please add it to one of the hardware-specific // Note: when adding a new flag, please add it to one of the hardware-specific
// or hardware-agnostic sections at the top of this proto message. // or hardware-agnostic sections at the top of this proto message.
// Next id: 424 // Next id: 425
// Extra options to pass to the compilation backend (e.g. LLVM); specific // Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend. // interpretation of these values is left to the backend.