mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA:GPU] add NanCount thunk to thunk_buffer_debug_pass
We call the pass for f32 and bf16 output buffers. PiperOrigin-RevId: 826808271
This commit is contained in:
parent
459ba30568
commit
4f3f2c9444
|
|
@ -2902,6 +2902,7 @@ cc_library(
|
|||
hdrs = ["thunk_buffer_debug_pass.h"],
|
||||
deps = [
|
||||
":buffers_checksum_thunk",
|
||||
":buffers_nan_count_thunk",
|
||||
":custom_call_thunk",
|
||||
":sequential_thunk",
|
||||
":thunk",
|
||||
|
|
@ -2934,6 +2935,7 @@ xla_cc_test(
|
|||
srcs = ["thunk_buffer_debug_pass_test.cc"],
|
||||
deps = [
|
||||
":buffers_checksum_thunk",
|
||||
":buffers_nan_count_thunk",
|
||||
":custom_call_thunk",
|
||||
":sequential_thunk",
|
||||
":thunk",
|
||||
|
|
@ -3096,6 +3098,7 @@ xla_test(
|
|||
":thunk",
|
||||
":thunk_buffer_id",
|
||||
":thunk_id",
|
||||
"//xla:types",
|
||||
"//xla/service:buffer_assignment",
|
||||
"//xla/service:executable",
|
||||
"//xla/service/gpu:buffer_allocations",
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "xla/backends/gpu/runtime/thunk.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
|
||||
#include "xla/stream_executor/cuda/cuda_platform_id.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
|
|
@ -88,18 +87,21 @@ absl::Status BuffersDebugNanCountThunk::ExecuteOnStream(
|
|||
se::gpu::BufferDebugLog buffer_debug_log =
|
||||
se::gpu::BufferDebugLog::FromDeviceMemoryUnchecked(log_ptr);
|
||||
|
||||
for (const auto& [entry_id, buffer_slice_pair] : buffers_) {
|
||||
BufferAllocation::Slice buffer = buffer_slice_pair.buffer;
|
||||
PrimitiveType buffer_type = buffer_slice_pair.element_type;
|
||||
for (const auto& [entry_id, buffer] : buffers_) {
|
||||
PrimitiveType buffer_type = buffer.element_type();
|
||||
se::DeviceMemoryBase device_buffer =
|
||||
params.buffer_allocations->GetDeviceAddress(buffer);
|
||||
if (buffer_type == PrimitiveType::F32) {
|
||||
VLOG(1) << "F32 buffer detected with id: " << entry_id
|
||||
<< " and size: " << device_buffer.size();
|
||||
se::DeviceMemory<float> f32_buffer(device_buffer);
|
||||
TF_RETURN_IF_ERROR(kernel_f32_->Launch(
|
||||
thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id,
|
||||
f32_buffer, f32_buffer.size(), buffer_debug_log.GetDeviceHeader(),
|
||||
buffer_debug_log.GetDeviceEntries()));
|
||||
} else if (buffer_type == PrimitiveType::BF16) {
|
||||
VLOG(1) << "BF16 buffer detected with id: " << entry_id
|
||||
<< " and size: " << device_buffer.size();
|
||||
se::DeviceMemory<Eigen::bfloat16> bf16_buffer(device_buffer);
|
||||
TF_RETURN_IF_ERROR(kernel_bf16_->Launch(
|
||||
thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id,
|
||||
|
|
@ -117,10 +119,9 @@ absl::Status BuffersDebugNanCountThunk::ExecuteOnStream(
|
|||
std::string BuffersDebugNanCountThunk::ToString(int indent) const {
|
||||
std::string result;
|
||||
absl::StrAppend(&result, ", buffers = ", buffers_.size());
|
||||
for (const auto& [buffer_id, buffer_slice_pair] : buffers_) {
|
||||
for (const auto& [buffer_id, buffer] : buffers_) {
|
||||
absl::StrAppend(&result, "\n", std::string(indent + 2, ' '),
|
||||
"buffer_id: ", buffer_id,
|
||||
", buffer: ", buffer_slice_pair.buffer.ToString());
|
||||
"buffer_id: ", buffer_id, ", buffer: ", buffer.ToString());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,14 +31,9 @@ namespace xla::gpu {
|
|||
|
||||
class BuffersDebugNanCountThunk : public Thunk {
|
||||
public:
|
||||
struct BufferToCount {
|
||||
BufferAllocation::Slice buffer;
|
||||
PrimitiveType element_type;
|
||||
};
|
||||
|
||||
explicit BuffersDebugNanCountThunk(
|
||||
ThunkInfo info, BufferAllocation::Slice log_slice,
|
||||
absl::flat_hash_map<ThunkBufferId, BufferToCount> buffers)
|
||||
absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> buffers)
|
||||
: Thunk(Thunk::Kind::kBuffersDebugNanCount, std::move(info)),
|
||||
log_slice_(log_slice),
|
||||
buffers_(std::move(buffers)) {}
|
||||
|
|
@ -53,6 +48,11 @@ class BuffersDebugNanCountThunk : public Thunk {
|
|||
return {};
|
||||
}
|
||||
|
||||
const absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice>&
|
||||
buffer_slices() const {
|
||||
return buffers_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Loaded in Initialize.
|
||||
std::optional<stream_executor::gpu::BufferDebugNanCountF32Kernel::KernelType>
|
||||
|
|
@ -60,7 +60,7 @@ class BuffersDebugNanCountThunk : public Thunk {
|
|||
std::optional<stream_executor::gpu::BufferDebugNanCountBf16Kernel::KernelType>
|
||||
kernel_bf16_;
|
||||
BufferAllocation::Slice log_slice_;
|
||||
absl::flat_hash_map<ThunkBufferId, BufferToCount> buffers_;
|
||||
absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> buffers_;
|
||||
};
|
||||
|
||||
} // namespace xla::gpu
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ limitations under the License.
|
|||
#include "xla/stream_executor/stream_executor_memory_allocator.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
#include "xla/types.h"
|
||||
|
||||
namespace xla::gpu {
|
||||
namespace {
|
||||
|
|
@ -85,12 +86,19 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
|
|||
BufferAllocation alloc(/*index=*/0,
|
||||
/*size=*/kTotalDeviceMemoryBytes,
|
||||
/*color=*/0);
|
||||
int64_t input_offset = kLogSize;
|
||||
BufferAllocation::Slice log_slice(&alloc, /*offset=*/0, kLogSize);
|
||||
input_offset += kLogSize;
|
||||
|
||||
BufferAllocation::Slice inputs[2];
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
inputs[i] = BufferAllocation::Slice(
|
||||
&alloc, /*offset=*/kLogSize + i * kInputSizeInBytes, kInputSizeInBytes);
|
||||
}
|
||||
int64_t input_size_bf16 = kInputElems * sizeof(Eigen::bfloat16);
|
||||
inputs[0] = BufferAllocation::Slice(&alloc, input_offset, input_size_bf16,
|
||||
PrimitiveType::BF16);
|
||||
input_offset += input_size_bf16;
|
||||
|
||||
inputs[1] = BufferAllocation::Slice(
|
||||
&alloc, input_offset, kInputElems * sizeof(float), PrimitiveType::F32);
|
||||
|
||||
BufferAllocations allocations(
|
||||
{executor_->AllocateArray<uint8_t>(kTotalDeviceMemoryBytes)},
|
||||
executor_->device_ordinal(), allocator_.get());
|
||||
|
|
@ -102,13 +110,18 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
|
|||
BufferDebugLog::CreateOnDevice(
|
||||
*stream_, se::DeviceMemory<uint8_t>(log_mem)));
|
||||
// Fill inputs with some data
|
||||
std::vector<float> data(kInputElems, 0);
|
||||
data[123] = std::numeric_limits<float>::quiet_NaN();
|
||||
TF_ASSERT_OK(stream_->Memcpy(&inputs0_mem, data.data(), kInputSizeInBytes));
|
||||
data[123] = 0;
|
||||
data[456] = std::numeric_limits<float>::quiet_NaN();
|
||||
data[789] = std::numeric_limits<float>::quiet_NaN();
|
||||
TF_ASSERT_OK(stream_->Memcpy(&inputs1_mem, data.data(), kInputSizeInBytes));
|
||||
{
|
||||
std::vector<Eigen::bfloat16> data(kInputElems, Eigen::bfloat16(0));
|
||||
data[123] = std::numeric_limits<Eigen::bfloat16>::quiet_NaN();
|
||||
TF_ASSERT_OK(stream_->Memcpy(&inputs0_mem, data.data(), kInputSizeInBytes));
|
||||
}
|
||||
{
|
||||
std::vector<float> data(kInputElems, 0);
|
||||
data[456] = std::numeric_limits<float>::quiet_NaN();
|
||||
data[789] = std::numeric_limits<float>::quiet_NaN();
|
||||
TF_ASSERT_OK(stream_->Memcpy(&inputs1_mem, data.data(), kInputSizeInBytes));
|
||||
}
|
||||
|
||||
// Setup parameters for Initialize/Prepare/ExecuteOnStream
|
||||
Thunk::InitializeParams init_params;
|
||||
init_params.executor = executor_;
|
||||
|
|
@ -121,10 +134,8 @@ TEST_F(BuffersDebugNanCountThunkTest, CalculatesNanCounts) {
|
|||
|
||||
BuffersDebugNanCountThunk thunk(
|
||||
Thunk::ThunkInfo(), log_slice,
|
||||
{{ThunkBufferId::Create(ThunkId(123), 4).value(),
|
||||
{inputs[0], PrimitiveType::F32}},
|
||||
{ThunkBufferId::Create(ThunkId(456), 8).value(),
|
||||
{inputs[1], PrimitiveType::F32}}});
|
||||
{{ThunkBufferId::Create(ThunkId(123), 4).value(), inputs[0]},
|
||||
{ThunkBufferId::Create(ThunkId(456), 8).value(), inputs[1]}});
|
||||
TF_ASSERT_OK(thunk.Initialize(init_params));
|
||||
TF_ASSERT_OK(thunk.Prepare(Thunk::PrepareParams{}, resource_requests));
|
||||
TF_ASSERT_OK(thunk.ExecuteOnStream(execute_params));
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/backends/gpu/runtime/buffers_checksum_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/custom_call_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/sequential_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/thunk.h"
|
||||
|
|
@ -65,7 +66,7 @@ namespace {
|
|||
// If the thunk got wrapped, the data dependencies between the thunks will be
|
||||
// configured to ensure `predecessor_thunk` executes before the wrapped thunk
|
||||
// and `successor_thunk` executes after.
|
||||
absl::StatusOr<std::unique_ptr<Thunk>> WrapThunk(
|
||||
absl::StatusOr<std::unique_ptr<Thunk>> WrapWithChecksumThunk(
|
||||
std::unique_ptr<Thunk> thunk, BufferAllocation::Slice log_slice,
|
||||
const Thunk& predecessor_thunk, Thunk& successor_thunk) {
|
||||
const auto& thunk_buffers = thunk->buffer_uses();
|
||||
|
|
@ -127,6 +128,79 @@ absl::StatusOr<std::unique_ptr<Thunk>> WrapThunk(
|
|||
return wrapped_thunk;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<Thunk>> WrapWithNanCounterThunk(
|
||||
std::unique_ptr<Thunk> thunk, BufferAllocation::Slice log_slice,
|
||||
const Thunk& predecessor_thunk, Thunk& successor_thunk) {
|
||||
const auto& thunk_buffers = thunk->buffer_uses();
|
||||
if (thunk_buffers.empty()) {
|
||||
VLOG(1) << "No buffers in thunk " << thunk->thunk_info().thunk_id
|
||||
<< ", skipping";
|
||||
return thunk;
|
||||
}
|
||||
|
||||
absl::flat_hash_map<ThunkBufferId, BufferAllocation::Slice> buffers_to_check;
|
||||
for (size_t buffer_idx = 0; buffer_idx < thunk_buffers.size(); ++buffer_idx) {
|
||||
VLOG(1) << "Buffer " << buffer_idx << " in thunk "
|
||||
<< thunk->thunk_info().thunk_id;
|
||||
const BufferUse& use = thunk_buffers[buffer_idx];
|
||||
const BufferAllocation::Slice& slice = use.slice();
|
||||
if (slice.allocation() == nullptr) {
|
||||
VLOG(1) << "Buffer " << buffer_idx << " in thunk "
|
||||
<< thunk->thunk_info().thunk_id
|
||||
<< " has null allocation, skipping";
|
||||
continue;
|
||||
}
|
||||
auto buffer_id =
|
||||
ThunkBufferId::Create(thunk->thunk_info().thunk_id, buffer_idx);
|
||||
if (!buffer_id.ok()) {
|
||||
LOG(WARNING) << "ThunkBufferId::Create failed: Skipping buffer "
|
||||
<< buffer_idx << " in thunk " << thunk->thunk_info().thunk_id
|
||||
<< ": " << buffer_id.status();
|
||||
continue;
|
||||
}
|
||||
if (slice.element_type() != PrimitiveType::F32 &&
|
||||
slice.element_type() != PrimitiveType::BF16) {
|
||||
VLOG(1) << "Buffer " << buffer_idx << " in thunk "
|
||||
<< thunk->thunk_info().thunk_id
|
||||
<< " has unsupported element type "
|
||||
<< PrimitiveType_Name(slice.element_type()) << ", skipping";
|
||||
continue;
|
||||
}
|
||||
if (!use.HasDefinedContentsOnOutput()) {
|
||||
VLOG(1) << "Buffer " << buffer_idx << " in thunk "
|
||||
<< thunk->thunk_info().thunk_id
|
||||
<< " has no defined contents on output, skipping";
|
||||
continue;
|
||||
}
|
||||
buffers_to_check.emplace(buffer_id.value(), use.slice());
|
||||
VLOG(1) << "Found buffer " << buffer_idx << " in thunk "
|
||||
<< thunk->thunk_info().thunk_id << " with element type "
|
||||
<< PrimitiveType_Name(slice.element_type()) << " and size "
|
||||
<< slice.size();
|
||||
}
|
||||
|
||||
if (buffers_to_check.empty()) {
|
||||
return thunk;
|
||||
}
|
||||
|
||||
VLOG(1) << "Wrapping thunk " << thunk->thunk_info().thunk_id
|
||||
<< " with nan counter thunk due to presence of buffers: "
|
||||
<< buffers_to_check.size();
|
||||
std::vector<std::unique_ptr<Thunk>> thunk_and_checks;
|
||||
Thunk* thunk_ptr = thunk.get();
|
||||
thunk_and_checks.push_back(std::move(thunk));
|
||||
auto buffer_debug_nan_counter_thunk =
|
||||
std::make_unique<BuffersDebugNanCountThunk>(Thunk::ThunkInfo(), log_slice,
|
||||
std::move(buffers_to_check));
|
||||
buffer_debug_nan_counter_thunk->add_control_predecessor(thunk_ptr);
|
||||
thunk_and_checks.push_back(std::move(buffer_debug_nan_counter_thunk));
|
||||
auto wrapped_thunk = std::make_unique<SequentialThunk>(
|
||||
Thunk::ThunkInfo(), std::move(thunk_and_checks));
|
||||
wrapped_thunk->add_control_predecessor(&predecessor_thunk);
|
||||
successor_thunk.add_control_predecessor(wrapped_thunk.get());
|
||||
return wrapped_thunk;
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
kDebugLogInitHandler,
|
||||
[](se::Stream* absl_nonnull stream, xla::ffi::Buffer<U8> log_buffer) {
|
||||
|
|
@ -207,10 +281,21 @@ absl::StatusOr<bool> ThunkBufferDebugPass::Run(
|
|||
|
||||
ThunkSequence& thunks = root_thunk->thunks();
|
||||
for (auto& thunk : thunks) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
thunk, WrapThunk(std::move(thunk), log_slice,
|
||||
/*predecessor_thunk=*/*buffer_debug_init_thunk.get(),
|
||||
/*successor_thunk=*/*buffer_debug_dump_thunk.get()));
|
||||
if (mode_ == Mode::kChecksum) {
|
||||
VLOG(1) << "Wrapping with checksum thunk";
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
thunk, WrapWithChecksumThunk(
|
||||
std::move(thunk), log_slice,
|
||||
/*predecessor_thunk=*/*buffer_debug_init_thunk.get(),
|
||||
/*successor_thunk=*/*buffer_debug_dump_thunk.get()));
|
||||
} else if (mode_ == Mode::kNanCounter) {
|
||||
VLOG(1) << "Wrapping with nan counter thunk";
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
thunk, WrapWithNanCounterThunk(
|
||||
std::move(thunk), log_slice,
|
||||
/*predecessor_thunk=*/*buffer_debug_init_thunk.get(),
|
||||
/*successor_thunk=*/*buffer_debug_dump_thunk.get()));
|
||||
}
|
||||
}
|
||||
|
||||
thunks.reserve(thunks.size() + 2);
|
||||
|
|
|
|||
|
|
@ -30,7 +30,12 @@ namespace gpu {
|
|||
// Adds buffer debug tracing to thunks.
|
||||
class ThunkBufferDebugPass : public ThunkPassInterface {
|
||||
public:
|
||||
ThunkBufferDebugPass() = default;
|
||||
enum class Mode {
|
||||
kChecksum,
|
||||
kNanCounter,
|
||||
};
|
||||
|
||||
explicit ThunkBufferDebugPass(Mode mode) : mode_(mode) {}
|
||||
|
||||
absl::string_view name() const override { return "thunk-buffer-debug"; }
|
||||
|
||||
|
|
@ -39,6 +44,9 @@ class ThunkBufferDebugPass : public ThunkPassInterface {
|
|||
const HloModule* absl_nullable hlo_module,
|
||||
const se::DeviceDescription& device_info,
|
||||
ThunkPassBufferAllocator& allocator) override;
|
||||
|
||||
private:
|
||||
Mode mode_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/backends/gpu/runtime/buffers_checksum_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/custom_call_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/sequential_thunk.h"
|
||||
#include "xla/backends/gpu/runtime/thunk.h"
|
||||
|
|
@ -101,7 +102,7 @@ TEST(ThunkBufferDebugPassTest, IsNoOpWhenHloModuleIsNull) {
|
|||
auto root_thunk =
|
||||
std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
|
||||
|
||||
ThunkBufferDebugPass pass;
|
||||
ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kChecksum);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool changed, pass.Run(root_thunk.get(), debug_options,
|
||||
/*hlo_module=*/nullptr, device_info, allocator));
|
||||
|
|
@ -152,7 +153,7 @@ TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) {
|
|||
auto root_thunk =
|
||||
std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
|
||||
|
||||
ThunkBufferDebugPass pass;
|
||||
ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kChecksum);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
pass.Run(root_thunk.get(), debug_options, &hlo_module,
|
||||
device_info, allocator));
|
||||
|
|
@ -205,6 +206,91 @@ TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) {
|
|||
Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io)));
|
||||
}
|
||||
|
||||
TEST(ThunkBufferDebugPassTest, InsertsBuffersDebugNanCounterThunks) {
|
||||
static constexpr ThunkId kTestThunkId = ThunkId(123);
|
||||
DebugOptions debug_options;
|
||||
debug_options.set_xla_gpu_experimental_enable_nan_counter_on_thunks(true);
|
||||
se::DeviceDescription device_info;
|
||||
FakeThunkPassBufferAllocator allocator;
|
||||
// The callbacks created by ThunkBufferDebugPass require a HloModule with
|
||||
// a non-null entry computation.
|
||||
auto builder = HloComputation::Builder("entry");
|
||||
HloInstruction* root = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f)));
|
||||
std::unique_ptr<HloComputation> entry_computation = builder.Build(root);
|
||||
HloModule hlo_module("test_module", HloModuleConfig());
|
||||
hlo_module.AddEntryComputation(std::move(entry_computation));
|
||||
// Create a fake thunk with a few different buffer uses.
|
||||
BufferAllocation alloc(0, 1024, 0);
|
||||
BufferAllocation::Slice slice_i(&alloc, 0, 1, PrimitiveType::F32);
|
||||
BufferAllocation::Slice slice_o(&alloc, 1, 1, PrimitiveType::F32);
|
||||
BufferAllocation::Slice slice_io(&alloc, 2, 1, PrimitiveType::F32);
|
||||
BufferAllocation::Slice slice_scratch(&alloc, 3, 1, PrimitiveType::F32);
|
||||
Thunk::ThunkInfo fake_thunk_info;
|
||||
fake_thunk_info.thunk_id = ThunkId(kTestThunkId);
|
||||
auto fake_thunk = std::make_unique<FakeThunk>(
|
||||
fake_thunk_info,
|
||||
Thunk::BufferUses{
|
||||
// Consume means the thunk can reuse the buffer for scratch space, so
|
||||
// only check it on input.
|
||||
BufferUse::Consume(slice_i),
|
||||
// Write is undefined on input, but defined on output.
|
||||
BufferUse::Write(slice_o),
|
||||
// Unlike Consume, Read is supposed to preserve the contents of the
|
||||
// buffer, so we check it on input *and* output.
|
||||
BufferUse::Read(slice_io),
|
||||
// Scratch buffers are not checked at all.
|
||||
BufferUse::Scratch(slice_scratch),
|
||||
});
|
||||
Thunk* fake_thunk_ptr = fake_thunk.get();
|
||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||
thunks.push_back(std::move(fake_thunk));
|
||||
auto root_thunk =
|
||||
std::make_unique<SequentialThunk>(Thunk::ThunkInfo(), std::move(thunks));
|
||||
|
||||
ThunkBufferDebugPass pass(ThunkBufferDebugPass::Mode::kNanCounter);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
pass.Run(root_thunk.get(), debug_options, &hlo_module,
|
||||
device_info, allocator));
|
||||
EXPECT_TRUE(changed);
|
||||
|
||||
// Expected thunk structure after the pass:
|
||||
// 1. CustomCallThunk (buffer debug log init)
|
||||
// 2. SequentialThunk
|
||||
// 1. FakeThunk
|
||||
// 2. BuffersDebugNanCountThunk (nan counter output buffers)
|
||||
// 3. CustomCallThunk (buffer debug log dump)
|
||||
const std::vector<std::unique_ptr<Thunk>>& new_thunks = root_thunk->thunks();
|
||||
EXPECT_THAT(new_thunks, SizeIs(3));
|
||||
EXPECT_EQ(new_thunks[0]->kind(), Thunk::Kind::kCustomCall);
|
||||
EXPECT_EQ(new_thunks[1]->kind(), Thunk::Kind::kSequential);
|
||||
EXPECT_EQ(new_thunks[2]->kind(), Thunk::Kind::kCustomCall);
|
||||
|
||||
const CustomCallThunk& buffer_debug_init_thunk =
|
||||
static_cast<const CustomCallThunk&>(*new_thunks[0]);
|
||||
EXPECT_EQ(buffer_debug_init_thunk.target_name(),
|
||||
"xla_gpu_buffer_debug_log_init");
|
||||
|
||||
const CustomCallThunk& buffer_debug_dump_thunk =
|
||||
static_cast<const CustomCallThunk&>(*new_thunks[2]);
|
||||
EXPECT_EQ(buffer_debug_dump_thunk.target_name(),
|
||||
"xla_gpu_buffer_debug_log_dump");
|
||||
|
||||
const std::vector<std::unique_ptr<Thunk>>& sub_thunks =
|
||||
static_cast<const SequentialThunk&>(*new_thunks[1]).thunks();
|
||||
EXPECT_THAT(sub_thunks, SizeIs(2));
|
||||
EXPECT_THAT(sub_thunks[0], Pointer(fake_thunk_ptr));
|
||||
EXPECT_EQ(sub_thunks[1]->kind(), Thunk::Kind::kBuffersDebugNanCount);
|
||||
|
||||
const BuffersDebugNanCountThunk& buffer_debug_after_fake_thunk =
|
||||
static_cast<const BuffersDebugNanCountThunk&>(*sub_thunks[1]);
|
||||
EXPECT_THAT(
|
||||
buffer_debug_after_fake_thunk.buffer_slices(),
|
||||
UnorderedElementsAre(
|
||||
Pair(ThunkBufferId::Create(kTestThunkId, 1).value(), slice_o),
|
||||
Pair(ThunkBufferId::Create(kTestThunkId, 2).value(), slice_io)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
|
|
|||
9
third_party/xla/xla/debug_options_flags.cc
vendored
9
third_party/xla/xla/debug_options_flags.cc
vendored
|
|
@ -471,6 +471,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
|
|||
opts.set_xla_cpu_collective_call_terminate_timeout_seconds(40);
|
||||
|
||||
opts.set_xla_keep_shardings_after_spmd(false);
|
||||
opts.set_xla_gpu_experimental_enable_checksum_tracing_on_thunks(false);
|
||||
opts.set_xla_gpu_experimental_enable_nan_counter_on_thunks(false);
|
||||
return opts;
|
||||
}
|
||||
|
||||
|
|
@ -2654,6 +2656,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
|
|||
debug_options->xla_gpu_experimental_enable_checksum_tracing_on_thunks(),
|
||||
"Enables an experimental feature to record checksums of selected thunk "
|
||||
"inputs/outputs."));
|
||||
flag_list->push_back(tsl::Flag(
|
||||
"xla_gpu_experimental_enable_nan_counter_on_thunks",
|
||||
bool_setter_for(
|
||||
&DebugOptions::set_xla_gpu_experimental_enable_nan_counter_on_thunks),
|
||||
debug_options->xla_gpu_experimental_enable_nan_counter_on_thunks(),
|
||||
"Enables an experimental feature to record the number of nans in thunk "
|
||||
"outputs."));
|
||||
flag_list->push_back(tsl::Flag(
|
||||
"xla_gpu_experimental_thunk_buffer_debug_filter_by_thunk_id_ranges",
|
||||
setter_for_thunk_buffer_debug_filter_by_thunk_id, "(none)",
|
||||
|
|
|
|||
3
third_party/xla/xla/service/BUILD
vendored
3
third_party/xla/xla/service/BUILD
vendored
|
|
@ -5866,6 +5866,9 @@ tf_proto_library(
|
|||
name = "buffer_assignment_proto",
|
||||
srcs = ["buffer_assignment.proto"],
|
||||
make_default_target_header_only = True,
|
||||
protodeps = [
|
||||
"//xla:xla_data_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
|||
|
|
@ -245,6 +245,7 @@ BufferAllocation::Slice::ToProto() const {
|
|||
proto.set_offset(offset());
|
||||
proto.set_size(size());
|
||||
proto.set_buffer_allocation_index(allocation() == nullptr ? -1 : index());
|
||||
proto.set_element_type(element_type());
|
||||
return proto;
|
||||
}
|
||||
|
||||
|
|
@ -259,13 +260,14 @@ absl::StatusOr<BufferAllocation::Slice> BufferAllocation::Slice::FromProto(
|
|||
}
|
||||
const BufferAllocation& allocation =
|
||||
buffer_allocations[proto.buffer_allocation_index()];
|
||||
return BufferAllocation::Slice(&allocation, proto.offset(), proto.size());
|
||||
return BufferAllocation::Slice(&allocation, proto.offset(), proto.size(),
|
||||
proto.element_type());
|
||||
}
|
||||
|
||||
BufferAllocation::Slice BufferAllocation::GetSlice(
|
||||
const HloValue& buffer) const {
|
||||
const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
|
||||
return Slice(this, os.offset, os.size);
|
||||
return Slice(this, os.offset, os.size, buffer.shape().element_type());
|
||||
}
|
||||
|
||||
absl::Status BufferAllocation::AddAssignment(const HloValue& buffer,
|
||||
|
|
@ -331,6 +333,8 @@ BufferAllocationProto BufferAllocation::ToProto() const {
|
|||
proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
|
||||
proto_assigned->set_offset(buffer_offset_size.second.offset);
|
||||
proto_assigned->set_size(buffer_offset_size.second.size);
|
||||
proto_assigned->set_element_type(
|
||||
buffer_offset_size.first->shape().element_type());
|
||||
}
|
||||
absl::c_sort(*proto.mutable_assigned(),
|
||||
[](const BufferAllocationProto::Assigned& assign1,
|
||||
|
|
|
|||
12
third_party/xla/xla/service/buffer_assignment.h
vendored
12
third_party/xla/xla/service/buffer_assignment.h
vendored
|
|
@ -190,15 +190,22 @@ class BufferAllocation {
|
|||
class Slice {
|
||||
public:
|
||||
Slice() = default;
|
||||
Slice(const BufferAllocation* allocation, int64_t offset, int64_t size)
|
||||
: allocation_(allocation), offset_(offset), size_(size) {}
|
||||
Slice(const BufferAllocation* allocation, int64_t offset, int64_t size,
|
||||
PrimitiveType element_type = PrimitiveType::PRIMITIVE_TYPE_INVALID)
|
||||
: allocation_(allocation),
|
||||
offset_(offset),
|
||||
size_(size),
|
||||
element_type_(element_type) {}
|
||||
|
||||
const BufferAllocation* allocation() const { return allocation_; }
|
||||
Index index() const { return allocation_->index(); }
|
||||
int64_t offset() const { return offset_; }
|
||||
int64_t size() const { return size_; }
|
||||
PrimitiveType element_type() const { return element_type_; }
|
||||
|
||||
bool operator==(const Slice& other) const {
|
||||
// We don't compare element_type_ because it's not always set, and it's
|
||||
// not relevant for the comparison here.
|
||||
return index() == other.index() && offset_ == other.offset_ &&
|
||||
size_ == other.size_;
|
||||
}
|
||||
|
|
@ -252,6 +259,7 @@ class BufferAllocation {
|
|||
const BufferAllocation* allocation_ = nullptr;
|
||||
int64_t offset_ = 0;
|
||||
int64_t size_ = 0;
|
||||
PrimitiveType element_type_ = PrimitiveType::PRIMITIVE_TYPE_INVALID;
|
||||
};
|
||||
|
||||
// GetSlice returns the Slice of contiguous memory that holds the value
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||
|
||||
syntax = "proto3";
|
||||
|
||||
import "xla/xla_data.proto";
|
||||
|
||||
package xla.buffer_assignment;
|
||||
|
||||
// This defines the buffer isolation configuration, which is a debugging tool to
|
||||
|
|
@ -108,4 +110,5 @@ message BufferAllocationSliceProto {
|
|||
int64 offset = 1;
|
||||
int64 size = 2;
|
||||
int64 buffer_allocation_index = 3;
|
||||
xla.PrimitiveType element_type = 4;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -179,7 +179,12 @@ static absl::Status RunThunkPasses(const DebugOptions& debug_options,
|
|||
ThunkPassBufferAllocator& allocator) {
|
||||
ThunkPassPipeline pipeline("thunk-passes");
|
||||
if (debug_options.xla_gpu_experimental_enable_checksum_tracing_on_thunks()) {
|
||||
pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>());
|
||||
pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>(
|
||||
ThunkBufferDebugPass::Mode::kChecksum));
|
||||
}
|
||||
if (debug_options.xla_gpu_experimental_enable_nan_counter_on_thunks()) {
|
||||
pipeline.AddPass(std::make_unique<ThunkBufferDebugPass>(
|
||||
ThunkBufferDebugPass::Mode::kNanCounter));
|
||||
}
|
||||
if (debug_options.xla_gpu_experimental_enable_command_buffer_on_thunks()) {
|
||||
pipeline.AddPass(std::make_unique<CommandBufferConversionPass>(
|
||||
|
|
|
|||
1
third_party/xla/xla/service/hlo.proto
vendored
1
third_party/xla/xla/service/hlo.proto
vendored
|
|
@ -667,6 +667,7 @@ message BufferAllocationProto {
|
|||
int64 logical_buffer_id = 1;
|
||||
int64 offset = 2;
|
||||
int64 size = 3;
|
||||
xla.PrimitiveType element_type = 4;
|
||||
}
|
||||
|
||||
int64 index = 1;
|
||||
|
|
|
|||
7
third_party/xla/xla/xla.proto
vendored
7
third_party/xla/xla/xla.proto
vendored
|
|
@ -348,7 +348,7 @@ message DebugOptions {
|
|||
|
||||
// Limits the thunk buffer debug instrumentation to specific thunks.
|
||||
optional ThunkBufferDebugFilter
|
||||
xla_gpu_experimental_thunk_buffer_debug_filter = 423;
|
||||
xla_gpu_experimental_thunk_buffer_debug_filter = 424;
|
||||
|
||||
// If true, every time an HLO module is run, we will dump an
|
||||
// HloUnoptimizedSnapshot (essentially, a serialized unoptimizedmodule plus
|
||||
|
|
@ -657,6 +657,9 @@ message DebugOptions {
|
|||
optional bool xla_gpu_experimental_enable_heuristic_collective_combining =
|
||||
366;
|
||||
|
||||
// If true, enable buffer nan counter on thunks.
|
||||
optional bool xla_gpu_experimental_enable_nan_counter_on_thunks = 423;
|
||||
|
||||
// Enable NCCL symmetric buffers.
|
||||
optional bool xla_gpu_experimental_enable_nccl_symmetric_buffers = 406;
|
||||
|
||||
|
|
@ -1388,7 +1391,7 @@ message DebugOptions {
|
|||
// Note: when adding a new flag, please add it to one of the hardware-specific
|
||||
// or hardware-agnostic sections at the top of this proto message.
|
||||
|
||||
// Next id: 424
|
||||
// Next id: 425
|
||||
|
||||
// Extra options to pass to the compilation backend (e.g. LLVM); specific
|
||||
// interpretation of these values is left to the backend.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user