[XLA:GPU] add NanCount thunk to thunk_buffer_debug_pass

We insert the NaN-count check for F32 and BF16 output buffers.

PiperOrigin-RevId: 826808271
Ilya Tikhonovskiy 2025-11-01 03:05:12 -07:00 committed by TensorFlower Gardener
parent 459ba30568
commit 4f3f2c9444
15 changed files with 274 additions and 44 deletions
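Usage sketch (editorial note, not part of the commit): the NaN counter is gated behind the new xla_gpu_experimental_enable_nan_counter_on_thunks debug option registered below. One way to turn it on programmatically might look like the following; the GetDebugOptionsFromFlags() helper and the include paths are assumptions based on the usual XLA DebugOptions plumbing, not part of this change.

// Hypothetical sketch: enable the NaN-counter thunk instrumentation.
#include "xla/debug_options_flags.h"  // assumed location of GetDebugOptionsFromFlags()
#include "xla/xla.pb.h"               // DebugOptions proto

xla::DebugOptions EnableNanCounterDebugging() {
  xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
  // Setter generated from the proto field added in this commit.
  opts.set_xla_gpu_experimental_enable_nan_counter_on_thunks(true);
  return opts;
}

The same switch is exposed on the command line through the flag registered in the debug_options_flags.cc hunk below, e.g. XLA_FLAGS=--xla_gpu_experimental_enable_nan_counter_on_thunks=true.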

View File

@ -2902,6 +2902,7 @@ cc_library(
hdrs = ["thunk_buffer_debug_pass.h"],
deps = [
":buffers_checksum_thunk",
":buffers_nan_count_thunk",
":custom_call_thunk",
":sequential_thunk",
":thunk",
@ -2934,6 +2935,7 @@ xla_cc_test(
srcs = ["thunk_buffer_debug_pass_test.cc"],
deps = [
":buffers_checksum_thunk",
":buffers_nan_count_thunk",
":custom_call_thunk",
":sequential_thunk",
":thunk",
@ -3096,6 +3098,7 @@ xla_test(
":thunk",
":thunk_buffer_id",
":thunk_id",
"//xla:types",
"//xla/service:buffer_assignment",
"//xla/service:executable",
"//xla/service/gpu:buffer_allocations",

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/service/buffer_assignment.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_memory.h"
@ -88,18 +87,21 @@ absl::Status BuffersDebugNanCountThunk::ExecuteOnStream(
se::gpu::BufferDebugLog buffer_debug_log =
se::gpu::BufferDebugLog::FromDeviceMemoryUnchecked(log_ptr);
- for (const auto& [entry_id, buffer_slice_pair] : buffers_) {
- BufferAllocation::Slice buffer = buffer_slice_pair.buffer;
- PrimitiveType buffer_type = buffer_slice_pair.element_type;
+ for (const auto& [entry_id, buffer] : buffers_) {
+ PrimitiveType buffer_type = buffer.element_type();
se::DeviceMemoryBase device_buffer =
params.buffer_allocations->GetDeviceAddress(buffer);
if (buffer_type == PrimitiveType::F32) {
VLOG(1) << "F32 buffer detected with id: " << entry_id
<< " and size: " << device_buffer.size();
se::DeviceMemory<float> f32_buffer(device_buffer);
TF_RETURN_IF_ERROR(kernel_f32_->Launch(
thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id,
f32_buffer, f32_buffer.size(), buffer_debug_log.GetDeviceHeader(),
buffer_debug_log.GetDeviceEntries()));
} else if (buffer_type == PrimitiveType::BF16) {
VLOG(1) << "BF16 buffer detected with id: " << entry_id
<< " and size: " << device_buffer.size();
se::DeviceMemory<Eigen::bfloat16> bf16_buffer(device_buffer);
TF_RETURN_IF_ERROR(kernel_bf16_->Launch(
thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id,
@ -117,10 +119,9 @@ absl::Status BuffersDebugNanCountThunk::ExecuteOnStream(
std::string BuffersDebugNanCountThunk::ToString(int indent) const {
std::string result;
absl::StrAppend(&result, ", buffers = ", buffers_.size());
- for (const auto& [buffer_id, buffer_slice_pair] : buffers_) {
+ for (const auto& [buffer_id, buffer] : buffers_) {
absl::StrAppend(&result, "\n", std::string(indent + 2, ' '),
- "buffer_id: ", buffer_id,
- ", buffer: ", buffer_slice_pair.buffer.ToString());
+ "buffer_id: ", buffer_id, ", buffer: ", buffer.ToString());
}
return result;
}

View File

@ -31,14 +31,9 @@ namespace xla::gpu {
class BuffersDebugNanCountThunk : public Thunk {
public:
- struct BufferToCount {
- BufferAllocation::Slice buffer;
- PrimitiveType element_type;
- };
explicit BuffersDebugNanCountThunk(
ThunkInfo info, BufferAllocation::Slice log_slice,
- absl::flat_hash_map<ThunkBufferId, BufferToCount> buffers)
+ absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> buffers)
: Thunk(Thunk::Kind::kBuffersDebugNanCount, std::move(info)),
log_slice_(log_slice),
buffers_(std::move(buffers)) {}
@ -53,6 +48,11 @@ class BuffersDebugNanCountThunk : public Thunk {
return {};
}
const absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice>&
buffer_slices() const {
return buffers_;
}
private:
// Loaded in Initialize.
std::optional<stream_executor::gpu::BufferDebugNanCountF32Kernel::KernelType>
@ -60,7 +60,7 @@ class BuffersDebugNanCountThunk : public Thunk {
std::optional<stream_executor::gpu::BufferDebugNanCountBf16Kernel::KernelType>
kernel_bf16_;
BufferAllocation::Slice log_slice_;
- absl::flat_hash_map<ThunkBufferId, BufferToCount> buffers_;
+ absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> buffers_;
};
} // namespace xla::gpu

View File

@ -40,6 +40,7 @@ limitations under the License.
#include "xla/stream_executor/stream_executor_memory_allocator.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/types.h"
namespace xla::gpu {
namespace {
@ -85,12 +86,19 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
BufferAllocation alloc(/*index=*/0,
/*size=*/kTotalDeviceMemoryBytes,
/*color=*/0);
int64_t input_offset = kLogSize;
BufferAllocation::Slice log_slice(&alloc, /*offset=*/0, kLogSize);
input_offset += kLogSize;
BufferAllocation::Slice inputs[2];
- for (int i = 0; i < 2; ++i) {
- inputs[i] = BufferAllocation::Slice(
- &alloc, /*offset=*/kLogSize + i * kInputSizeInBytes, kInputSizeInBytes);
- }
+ int64_t input_size_bf16 = kInputElems * sizeof(Eigen::bfloat16);
+ inputs[0] = BufferAllocation::Slice(&alloc, input_offset, input_size_bf16,
+ PrimitiveType::BF16);
+ input_offset += input_size_bf16;
+ inputs[1] = BufferAllocation::Slice(
+ &alloc, input_offset, kInputElems * sizeof(float), PrimitiveType::F32);
BufferAllocations allocations(
{executor_->AllocateArray<uint8_t>(kTotalDeviceMemoryBytes)},
executor_->device_ordinal(), allocator_.get());
@ -102,13 +110,18 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
BufferDebugLog::CreateOnDevice(
*stream_, se::DeviceMemory<uint8_t>(log_mem)));
// Fill inputs with some data
- std::vector<float> data(kInputElems, 0);
- data[123] = std::numeric_limits<float>::quiet_NaN();
- TF_ASSERT_OK(stream_->Memcpy(&inputs0_mem, data.data(), kInputSizeInBytes));
- data[123] = 0;
- data[456] = std::numeric_limits<float>::quiet_NaN();
- data[789] = std::numeric_limits<float>::quiet_NaN();
- TF_ASSERT_OK(stream_->Memcpy(&inputs1_mem, data.data(), kInputSizeInBytes));
+ {
+ std::vector<Eigen::bfloat16> data(kInputElems, Eigen::bfloat16(0));
+ data[123] = std::numeric_limits<Eigen::bfloat16>::quiet_NaN();
+ TF_ASSERT_OK(stream_->Memcpy(&inputs0_mem, data.data(), kInputSizeInBytes));
+ }
+ {
+ std::vector<float> data(kInputElems, 0);
+ data[456] = std::numeric_limits<float>::quiet_NaN();
+ data[789] = std::numeric_limits<float>::quiet_NaN();
+ TF_ASSERT_OK(stream_->Memcpy(&inputs1_mem, data.data(), kInputSizeInBytes));
+ }
// Setup parameters for Initialize/Prepare/ExecuteOnStream
Thunk::InitializeParams init_params;
init_params.executor = executor_;
@ -121,10 +134,8 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
BuffersDebugNanCountThunk thunk(
Thunk::ThunkInfo(), log_slice,
- {{ThunkBufferId::Create(ThunkId(123), 4).value(),
- {inputs[0], PrimitiveType::F32}},
- {ThunkBufferId::Create(ThunkId(456), 8).value(),
- {inputs[1], PrimitiveType::F32}}});
+ {{ThunkBufferId::Create(ThunkId(123), 4).value(), inputs[0]},
+ {ThunkBufferId::Create(ThunkId(456), 8).value(), inputs[1]}});
TF_ASSERT_OK(thunk.Initialize(init_params));
TF_ASSERT_OK(thunk.Prepare(Thunk::PrepareParams{}, resource_requests));
TF_ASSERT_OK(thunk.ExecuteOnStream(execute_params));

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/backends/gpu/runtime/buffers_checksum_thunk.h"
#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h"
#include "xla/backends/gpu/runtime/custom_call_thunk.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
@ -65,7 +66,7 @@ namespace {
// If the thunk got wrapped, the data dependencies between the thunks will be
// configured to ensure `predecessor_thunk` executes before the wrapped thunk
// and `successor_thunk` executes after.
- absl::StatusOr<std::unique_ptr<Thunk>> WrapThunk(
+ absl::StatusOr<std::unique_ptr<Thunk>> WrapWithChecksumThunk(
std::unique_ptr<Thunk> thunk, BufferAllocation::Slice log_slice,
const Thunk& predecessor_thunk, Thunk& successor_thunk) {
const auto& thunk_buffers = thunk->buffer_uses();
@ -127,6 +128,79 @@ absl::StatusOr<std::unique_ptr<Thunk>> WrapThunk(
return wrapped_thunk;
}
absl::StatusOr<std::unique_ptr<Thunk>> WrapWithNanCounterThunk(
std::unique_ptr<Thunk> thunk, BufferAllocation::Slice log_slice,
const Thunk& predecessor_thunk, Thunk& successor_thunk) {
const auto& thunk_buffers = thunk->buffer_uses();
if (thunk_buffers.empty()) {
VLOG(1) << "No buffers in thunk " << thunk->thunk_info().thunk_id
<< ", skipping";
return thunk;
}
absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> buffers_to_check;
for (size_t buffer_idx = 0; buffer_idx < thunk_buffers.size(); ++buffer_idx) {
VLOG(1) << "Buffer " << buffer_idx << " in thunk "
<< thunk->thunk_info().thunk_id;
const BufferUse& use = thunk_buffers[buffer_idx];
const BufferAllocation::Slice& slice = use.slice();
if (slice.allocation() == nullptr) {
VLOG(1) << "Buffer " << buffer_idx << " in thunk "
<< thunk->thunk_info().thunk_id
<< " has null allocation, skipping";
continue;
}
auto buffer_id =
ThunkBufferId::Create(thunk->thunk_info().thunk_id, buffer_idx);
if (!buffer_id.ok()) {
LOG(WARNING) << "ThunkBufferId::Create failed: Skipping buffer "
<< buffer_idx << " in thunk " << thunk->thunk_info().thunk_id
<< ": " << buffer_id.status();
continue;
}
if (slice.element_type() != PrimitiveType::F32 &&
slice.element_type() != PrimitiveType::BF16) {
VLOG(1) << "Buffer " << buffer_idx << " in thunk "
<< thunk->thunk_info().thunk_id
<< " has unsupported element type "
<< PrimitiveType_Name(slice.element_type()) << ", skipping";
continue;
}
if (!use.HasDefinedContentsOnOutput()) {
VLOG(1) << "Buffer " << buffer_idx << " in thunk "
<< thunk->thunk_info().thunk_id
<< " has no defined contents on output, skipping";
continue;
}
buffers_to_check.emplace(buffer_id.value(), use.slice());
VLOG(1) << "Found buffer " << buffer_idx << " in thunk "
<< thunk->thunk_info().thunk_id << " with element type "
<< PrimitiveType_Name(slice.element_type()) << " and size "
<< slice.size();
}
if (buffers_to_check.empty()) {
return thunk;
}
VLOG(1) << "Wrapping thunk " << thunk->thunk_info().thunk_id
<< " with nan counter thunk due to presence of buffers: "
<< buffers_to_check.size();
std::vector<std::unique_ptr<Thunk>> thunk_and_checks;
Thunk* thunk_ptr = thunk.get();
thunk_and_checks.push_back(std::move(thunk));
auto buffer_debug_nan_counter_thunk =
std::make_unique<BuffersDebugNanCountThunk>(Thunk::ThunkInfo(), log_slice,
std::move(buffers_to_check));
buffer_debug_nan_counter_thunk->add_control_predecessor(thunk_ptr);
thunk_and_checks.push_back(std::move(buffer_debug_nan_counter_thunk));
auto wrapped_thunk = std::make_unique<SequentialThunk>(
Thunk::ThunkInfo(), std::move(thunk_and_checks));
wrapped_thunk->add_control_predecessor(&predecessor_thunk);
successor_thunk.add_control_predecessor(wrapped_thunk.get());
return wrapped_thunk;
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
kDebugLogInitHandler,
[](se::Stream* absl_nonnull stream, xla::ffi::Buffer<U8> log_buffer) {
@ -207,10 +281,21 @@ absl::StatusOr<bool> ThunkBufferDebugPass::Run(
ThunkSequence& thunks = root_thunk->thunks();
for (auto& thunk : thunks) {
- TF_ASSIGN_OR_RETURN(
- thunk, WrapThunk(std::move(thunk), log_slice,
- /*predecessor_thunk=*/*buffer_debug_init_thunk.get(),
- /*successor_thunk=*/*buffer_debug_dump_thunk.get()));
+ if (mode_ == Mode::kChecksum) {
+ VLOG(1) << "Wrapping with checksum thunk";
+ TF_ASSIGN_OR_RETURN(
+ thunk, WrapWithChecksumThunk(
+ std::move(thunk), log_slice,
+ /*predecessor_thunk=*/*buffer_debug_init_thunk.get(),
+ /*successor_thunk=*/*buffer_debug_dump_thunk.get()));
+ } else if (mode_ == Mode::kNanCounter) {
+ VLOG(1) << "Wrapping with nan counter thunk";
+ TF_ASSIGN_OR_RETURN(
+ thunk, WrapWithNanCounterThunk(
+ std::move(thunk), log_slice,
+ /*predecessor_thunk=*/*buffer_debug_init_thunk.get(),
+ /*successor_thunk=*/*buffer_debug_dump_thunk.get()));
+ }
}
thunks.reserve(thunks.size() + 2);

View File

@ -30,7 +30,12 @@ namespace gpu {
// Adds buffer debug tracing to thunks.
class ThunkBufferDebugPass : public ThunkPassInterface {
public:
- ThunkBufferDebugPass() = default;
+ enum class Mode {
+ kChecksum,
+ kNanCounter,
+ };
+ explicit ThunkBufferDebugPass(Mode mode) : mode_(mode) {}
absl::string_view name() const override { return "thunk-buffer-debug"; }
@ -39,6 +44,9 @@ class ThunkBufferDebugPass : public ThunkPassInterface {
const HloModule* absl_nullable hlo_module,
const se::DeviceDescription& device_info,
ThunkPassBufferAllocator& allocator) override;
private:
Mode mode_;
};
} // namespace gpu

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/backends/gpu/runtime/buffers_checksum_thunk.h"
#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h"
#include "xla/backends/gpu/runtime/custom_call_thunk.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
@ -101,7 +102,7 @@ TEST(ThunkBufferDebugPassTest, IsNoOpWhenHloModuleIsNull) {
auto root_thunk =
std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
- ThunkBufferDebugPass pass;
+ ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kChecksum);
TF_ASSERT_OK_AND_ASSIGN(
bool changed, pass.Run(root_thunk.get(), debug_options,
/*hlo_module=*/nullptr, device_info, allocator));
@ -152,7 +153,7 @@ TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) {
auto root_thunk =
std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
- ThunkBufferDebugPass pass;
+ ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kChecksum);
TF_ASSERT_OK_AND_ASSIGN(bool changed,
pass.Run(root_thunk.get(), debug_options, &hlo_module,
device_info, allocator));
@ -205,6 +206,91 @@ TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) {
Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io)));
}
TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugNanCounterThunks) {
static constexpr ThunkId kTestThunkId = ThunkId(123);
DebugOptions debug_options;
debug_options.set_xla_gpu_experimental_enable_nan_counter_on_thunks(true);
se::DeviceDescription device_info;
FakeThunkPassBufferAllocator allocator;
// The callbacks created by ThunkBufferDebugPass require a HloModule with
// a non-null entry computation.
auto builder = HloComputation::Builder("entry");
HloInstruction* root = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f)));
std::unique_ptr<HloComputation> entry_computation = builder.Build(root);
HloModule hlo_module("test_module", HloModuleConfig());
hlo_module.AddEntryComputation(std::move(entry_computation));
// Create a fake thunk with a few different buffer uses.
BufferAllocation alloc(0, 1024, 0);
BufferAllocation::Slice slice_i(&alloc, 0, 1, PrimitiveType::F32);
BufferAllocation::Slice slice_o(&alloc, 1, 1, PrimitiveType::F32);
BufferAllocation::Slice slice_io(&alloc, 2, 1, PrimitiveType::F32);
BufferAllocation::Slice slice_scratch(&alloc, 3, 1, PrimitiveType::F32);
Thunk::ThunkInfo fake_thunk_info;
fake_thunk_info.thunk_id = ThunkId(kTestThunkId);
auto fake_thunk = std::make_unique<FakeThunk>(
fake_thunk_info,
Thunk::BufferUses{
// Consume means the thunk can reuse the buffer for scratch space, so
// only check it on input.
BufferUse::Consume(slice_i),
// Write is undefined on input, but defined on output.
BufferUse::Write(slice_o),
// Unlike Consume, Read is supposed to preserve the contents of the
// buffer, so we check it on input *and* output.
BufferUse::Read(slice_io),
// Scratch buffers are not checked at all.
BufferUse::Scratch(slice_scratch),
});
Thunk* fake_thunk_ptr = fake_thunk.get();
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(fake_thunk));
auto root_thunk =
std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kNanCounter);
TF_ASSERT_OK_AND_ASSIGN(bool changed,
pass.Run(root_thunk.get(), debug_options, &hlo_module,
device_info, allocator));
EXPECT_TRUE(changed);
// Expected thunk structure after the pass:
// 1. CustomCallThunk (buffer debug log init)
// 2. SequentialThunk
// 1. FakeThunk
// 2. BuffersDebugNanCountThunk (nan counter output buffers)
// 3. CustomCallThunk (buffer debug log dump)
const std::vector<std::unique_ptr<Thunk>>& new_thunks = root_thunk->thunks();
EXPECT_THAT(new_thunks, SizeIs(3));
EXPECT_EQ(new_thunks[0]->kind(), Thunk::Kind::kCustomCall);
EXPECT_EQ(new_thunks[1]->kind(), Thunk::Kind::kSequential);
EXPECT_EQ(new_thunks[2]->kind(), Thunk::Kind::kCustomCall);
const CustomCallThunk& buffer_debug_init_thunk =
static_cast<const CustomCallThunk&>(*new_thunks[0]);
EXPECT_EQ(buffer_debug_init_thunk.target_name(),
"xla_gpu_buffer_debug_log_init");
const CustomCallThunk& buffer_debug_dump_thunk =
static_cast<const CustomCallThunk&>(*new_thunks[2]);
EXPECT_EQ(buffer_debug_dump_thunk.target_name(),
"xla_gpu_buffer_debug_log_dump");
const std::vector<std::unique_ptr<Thunk>>& sub_thunks =
static_cast<const SequentialThunk&>(*new_thunks[1]).thunks();
EXPECT_THAT(sub_thunks, SizeIs(2));
EXPECT_THAT(sub_thunks[0], Pointer(fake_thunk_ptr));
EXPECT_EQ(sub_thunks[1]->kind(), Thunk::Kind::kBuffersDebugNanCount);
const BuffersDebugNanCountThunk& buffer_debug_after_fake_thunk =
static_cast<const BuffersDebugNanCountThunk&>(*sub_thunks[1]);
EXPECT_THAT(
buffer_debug_after_fake_thunk.buffer_slices(),
UnorderedElementsAre(
Pair(ThunkBufferId::Create(kTestThunkId, 1).value(), slice_o),
Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io)));
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -471,6 +471,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_cpu_collective_call_terminate_timeout_seconds(40);
opts.set_xla_keep_shardings_after_spmd(false);
opts.set_xla_gpu_experimental_enable_checksum_tracing_on_thunks(false);
opts.set_xla_gpu_experimental_enable_nan_counter_on_thunks(false);
return opts;
}
@ -2654,6 +2656,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_experimental_enable_checksum_tracing_on_thunks(),
"Enables an experimental feature to record checksums of selected thunk "
"inputs/outputs."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_nan_counter_on_thunks",
bool_setter_for(
&DebugOptions::set_xla_gpu_experimental_enable_nan_counter_on_thunks),
debug_options->xla_gpu_experimental_enable_nan_counter_on_thunks(),
"Enables an experimental feature to record the number of nans in thunk "
"outputs."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_thunk_buffer_debug_filter_by_thunk_id_ranges",
setter_for_thunk_buffer_debug_filter_by_thunk_id, "(none)",

View File

@ -5866,6 +5866,9 @@ tf_proto_library(
name = "buffer_assignment_proto",
srcs = ["buffer_assignment.proto"],
make_default_target_header_only = True,
protodeps = [
"//xla:xla_data_proto",
],
)
cc_library(

View File

@ -245,6 +245,7 @@ BufferAllocation::Slice::ToProto() const {
proto.set_offset(offset());
proto.set_size(size());
proto.set_buffer_allocation_index(allocation() == nullptr ? -1 : index());
proto.set_element_type(element_type());
return proto;
}
@ -259,13 +260,14 @@ absl::StatusOr<BufferAllocation::Slice> BufferAllocation::Slice::FromProto(
}
const BufferAllocation& allocation =
buffer_allocations[proto.buffer_allocation_index()];
- return BufferAllocation::Slice(&allocation, proto.offset(), proto.size());
+ return BufferAllocation::Slice(&allocation, proto.offset(), proto.size(),
+ proto.element_type());
}
BufferAllocation::Slice BufferAllocation::GetSlice(
const HloValue& buffer) const {
const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
- return Slice(this, os.offset, os.size);
+ return Slice(this, os.offset, os.size, buffer.shape().element_type());
}
absl::Status BufferAllocation::AddAssignment(const HloValue& buffer,
@ -331,6 +333,8 @@ BufferAllocationProto BufferAllocation::ToProto() const {
proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
proto_assigned->set_offset(buffer_offset_size.second.offset);
proto_assigned->set_size(buffer_offset_size.second.size);
proto_assigned->set_element_type(
buffer_offset_size.first->shape().element_type());
}
absl::c_sort(*proto.mutable_assigned(),
[](const BufferAllocationProto::Assigned& assign1,

View File

@ -190,15 +190,22 @@ class BufferAllocation {
class Slice {
public:
Slice() = default;
- Slice(const BufferAllocation* allocation, int64_t offset, int64_t size)
- : allocation_(allocation), offset_(offset), size_(size) {}
+ Slice(const BufferAllocation* allocation, int64_t offset, int64_t size,
+ PrimitiveType element_type = PrimitiveType::PRIMITIVE_TYPE_INVALID)
+ : allocation_(allocation),
+ offset_(offset),
+ size_(size),
+ element_type_(element_type) {}
const BufferAllocation* allocation() const { return allocation_; }
Index index() const { return allocation_->index(); }
int64_t offset() const { return offset_; }
int64_t size() const { return size_; }
PrimitiveType element_type() const { return element_type_; }
bool operator==(const Slice& other) const {
// We don't compare element_type_ because it's not always set, and it's
// not relevant for the comparison here.
return index() == other.index() && offset_ == other.offset_ &&
size_ == other.size_;
}
@ -252,6 +259,7 @@ class BufferAllocation {
const BufferAllocation* allocation_ = nullptr;
int64_t offset_ = 0;
int64_t size_ = 0;
PrimitiveType element_type_ = PrimitiveType::PRIMITIVE_TYPE_INVALID;
};
// GetSlice returns the Slice of contiguous memory that holds the value
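Aside (editorial note, not part of the diff): as the comment in this hunk says, Slice equality deliberately ignores the new element_type_ field, so two slices over the same bytes compare equal even when only one carries a type. A small sketch of that behavior, using the constructor shown above; the include paths and the wrapper function are assumptions for illustration.

// Hypothetical sketch of the Slice equality behavior described above.
#include "xla/service/buffer_assignment.h"
#include "xla/xla_data.pb.h"

bool TypedAndUntypedSlicesCompareEqual() {
  xla::BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0);
  xla::BufferAllocation::Slice typed(&alloc, /*offset=*/0, /*size=*/4,
                                     xla::PrimitiveType::F32);
  xla::BufferAllocation::Slice untyped(&alloc, /*offset=*/0, /*size=*/4);
  // element_type() still differs: F32 vs PRIMITIVE_TYPE_INVALID.
  return typed == untyped;  // true
}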

View File

@ -15,6 +15,8 @@ limitations under the License.
syntax = "proto3";
import "xla/xla_data.proto";
package xla.buffer_assignment;
// This defines the buffer isolation configuration, which is a debugging tool to
@ -108,4 +110,5 @@ message BufferAllocationSliceProto {
int64 offset = 1;
int64 size = 2;
int64 buffer_allocation_index = 3;
xla.PrimitiveType element_type = 4;
}

View File

@ -179,7 +179,12 @@ static absl::Status RunThunkPasses(const DebugOptions& debug_options,
ThunkPassBufferAllocator& allocator) {
ThunkPassPipeline pipeline("thunk-passes");
if (debug_options.xla_gpu_experimental_enable_checksum_tracing_on_thunks()) {
- pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>());
+ pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>(
+ ThunkBufferDebugPass::Mode::kChecksum));
}
if (debug_options.xla_gpu_experimental_enable_nan_counter_on_thunks()) {
pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>(
ThunkBufferDebugPass::Mode::kNanCounter));
}
if (debug_options.xla_gpu_experimental_enable_command_buffer_on_thunks()) {
pipeline.AddPass(std::make_unique<CommandBufferConversionPass>(

View File

@ -667,6 +667,7 @@ message BufferAllocationProto {
int64 logical_buffer_id = 1;
int64 offset = 2;
int64 size = 3;
xla.PrimitiveType element_type = 4;
}
int64 index = 1;

View File

@ -348,7 +348,7 @@ message DebugOptions {
// Limits the thunk buffer debug instrumentation to specific thunks.
optional ThunkBufferDebugFilter
- xla_gpu_experimental_thunk_buffer_debug_filter = 423;
+ xla_gpu_experimental_thunk_buffer_debug_filter = 424;
// If true, every time an HLO module is run, we will dump an
HloUnoptimizedSnapshot (essentially, a serialized unoptimized module plus
@ -657,6 +657,9 @@ message DebugOptions {
optional bool xla_gpu_experimental_enable_heuristic_collective_combining =
366;
// If true, enable buffer nan counter on thunks.
optional bool xla_gpu_experimental_enable_nan_counter_on_thunks = 423;
// Enable NCCL symmetric buffers.
optional bool xla_gpu_experimental_enable_nccl_symmetric_buffers = 406;
@ -1388,7 +1391,7 @@ message DebugOptions {
// Note: when adding a new flag, please add it to one of the hardware-specific
// or hardware-agnostic sections at the top of this proto message.
- // Next id: 424
+ // Next id: 425
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.