[XLA:GPU] Add check_thunk_result_consistency tool for verifying checksum consistency

When implementing this it turned out that the log is currently missing some information needed to reliably distinguish input/output checksums and different thunk executions. This adds the needed fields to the proto, but emitting them in the log will be a separate change.

With the extra data missing, the tool assumes all checksums refer to outputs, and each thunk execution is going to give the same results each time. The tests include the extra data, so once that's implement it should(TM) just work.

PiperOrigin-RevId: 825040798
This commit is contained in:
Marcin Radomski 2025-10-28 07:59:00 -07:00 committed by TensorFlower Gardener
parent 542ffe0410
commit 7334d07917
8 changed files with 670 additions and 6 deletions

View File

@ -1,6 +1,6 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("//xla:xla.default.bzl", "xla_cc_test", "xla_internal")
load("//xla:xla.default.bzl", "xla_cc_test", "xla_internal", "xla_py_proto_library")
load("//xla/tests:build_defs.bzl", "xla_test")
load("//xla/tsl:tsl.bzl", "internal_visibility", "nvtx_headers")
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")
@ -2445,6 +2445,11 @@ tf_proto_library(
],
)
xla_py_proto_library(
name = "thunk_proto_py",
deps = [":thunk_proto"],
)
tf_proto_library(
name = "dynamic_slice_thunk_proto",
srcs = [
@ -3093,6 +3098,11 @@ tf_proto_library(
srcs = ["buffer_debug_log.proto"],
)
xla_py_proto_library(
name = "buffer_debug_log_proto_py",
deps = [":buffer_debug_log_proto"],
)
cc_library(
name = "thunk_buffer_id",
hdrs = ["thunk_buffer_id.h"],

View File

@ -26,8 +26,20 @@ message BufferDebugLogEntryProto {
// Thunk::buffer_uses().
uint64 buffer_idx = 2;
// The value of the buffer.
uint32 value = 3;
// The checksum of the buffer.
uint32 checksum = 3;
// If true, the entry refers to a thunk input buffer, and the checksum is
// calculated based on the buffer value before the thunk execution.
//
// If false, it refers to thunk output, and the checksum is calculated based
// on the buffer value after the thunk execution.
bool is_input_buffer = 4;
// ID of the thunk execution that produced this entry. Entries with the same
// (thunk_id, execution_id) describe buffers used by a single execution of a
// thunk.
uint32 execution_id = 5;
}
// A dump of a `BufferDebugLog` contents.

View File

@ -103,7 +103,7 @@ absl::StatusOr<xla::gpu::BufferDebugLogProto> BufferDebugLog::ReadProto(
buffer_debug_log_proto.add_entries();
entry_proto->set_thunk_id(entry.entry_id.thunk_id().value());
entry_proto->set_buffer_idx(entry.entry_id.buffer_idx());
entry_proto->set_value(entry.value);
entry_proto->set_checksum(entry.value);
}
return buffer_debug_log_proto;

View File

@ -144,8 +144,8 @@ TEST_F(BufferDebugLogTest, ReadAsProto) {
device_log.ReadProto(*stream_));
EXPECT_THAT(log_proto, EqualsProto(R"pb(
entries { thunk_id: 123 buffer_idx: 4 value: 12341234 }
entries { thunk_id: 567 buffer_idx: 8 value: 56785678 }
entries { thunk_id: 123 buffer_idx: 4 checksum: 12341234 }
entries { thunk_id: 567 buffer_idx: 8 checksum: 56785678 }
)pb"));
}

View File

@ -0,0 +1,47 @@
# Tools and utilities for analyzing the BufferDebugLogProto dumps.
load("//xla:py_strict.bzl", "py_strict_test")
load("//xla:pytype.bzl", "pytype_strict_binary", "pytype_strict_library")
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = ["//xla:internal"],
licenses = ["notice"],
)
pytype_strict_library(
name = "checksum_mismatch_report",
srcs = ["checksum_mismatch_report.py"],
deps = [
"//xla/backends/gpu/runtime:buffer_debug_log_proto_py",
"//xla/backends/gpu/runtime:thunk_proto_py",
],
)
py_strict_test(
name = "checksum_mismatch_report_test",
srcs = ["checksum_mismatch_report_test.py"],
deps = [
":checksum_mismatch_report",
"//xla/backends/gpu/runtime:buffer_debug_log_proto_py",
"//xla/backends/gpu/runtime:thunk_proto_py",
"@absl_py//absl/testing:absltest",
"@com_google_protobuf//:protobuf_python",
],
)
pytype_strict_binary(
name = "check_thunk_output_consistency",
srcs = [
"check_thunk_output_consistency.py",
],
main = "check_thunk_output_consistency.py",
deps = [
":checksum_mismatch_report",
"//xla/backends/gpu/runtime:buffer_debug_log_proto_py",
"//xla/backends/gpu/runtime:thunk_proto_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
"@com_google_protobuf//:protobuf_python",
],
)

View File

@ -0,0 +1,117 @@
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tool to analyze buffer debug logs.
To generate the log files, run the HLO with
--xla_gpu_experimental_enable_checksum_tracing_on_thunks.
"""
from collections.abc import Sequence
from absl import app
from absl import flags
from google.protobuf import message
from google.protobuf import text_format
from xla.backends.gpu.runtime import buffer_debug_log_pb2
from xla.backends.gpu.runtime import thunk_pb2
from xla.tools.buffer_debug_log import checksum_mismatch_report
def parse_binary_or_text_proto(data: bytes, proto_type):
"""Parses a binary or text proto."""
try:
return proto_type.FromString(data)
except message.DecodeError:
pass
return text_format.Parse(data, proto_type())
_METADATA_FILE = flags.DEFINE_string(
"metadata-file", None, "Path to the thunk metadata proto file."
)
def _print_formatted_report(
report: checksum_mismatch_report.ChecksumMismatchReport,
):
"""Prints a ChecksumMismatchReport to stdout in a human-readable format."""
if not report.mismatches:
print("\N{WHITE HEAVY CHECK MARK} All results are perfectly consistent.")
return
print(
"\N{OCTAGONAL SIGN} Different outputs detected among identical"
" thunk executions:"
)
for thunk_id, mismatches_by_inputs in report.mismatches.items():
if not mismatches_by_inputs:
continue
def describe_thunk(thunk_id: checksum_mismatch_report.ThunkId):
result = f"In outputs of thunk {thunk_id}"
metadata = " (metadata missing)"
if report.thunk_metadata:
thunk_metadata = report.thunk_metadata.get(thunk_id)
if thunk_metadata:
metadata = f" (kind: {thunk_metadata.thunk_kind}, profile_annotation:"
metadata += f" {thunk_metadata.profile_annotation})"
return result + metadata
print(describe_thunk(thunk_id))
for _, mismatches_by_buffer_idx in sorted(mismatches_by_inputs.items()):
for buffer_idx, checksums in mismatches_by_buffer_idx.items():
print(f" buffer {buffer_idx}: checksums={checksums}")
def main(argv: Sequence[str]) -> None:
if len(argv) < 2:
raise app.UsageError(
"Usage: buffer-debug.py [--metadata-file METADATA_PROTO_PATH]"
" LOG_PROTO_PATHS..."
)
log_protos = {}
for module_id, arg in enumerate(argv[1:]):
try:
with open(arg, "rb") as f:
log_protos[module_id] = parse_binary_or_text_proto(
f.read(), buffer_debug_log_pb2.BufferDebugLogProto
)
except Exception as e:
e.add_note(f"when reading {arg}")
raise
if _METADATA_FILE.value:
try:
with open(_METADATA_FILE.value, "rb") as f:
metadata_proto = parse_binary_or_text_proto(
f.read(), thunk_pb2.ThunkMetadataListProto
)
except Exception as e:
e.add_note(f"when reading {_METADATA_FILE.value}")
raise
else:
metadata_proto = thunk_pb2.ThunkMetadataListProto()
report = checksum_mismatch_report.ChecksumMismatchReport.from_protos(
log_protos, metadata_proto
)
_print_formatted_report(report)
if __name__ == "__main__":
app.run(main)

View File

@ -0,0 +1,250 @@
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for detecting checksum mismatches in buffer debug logs.
The log is generated by running with
--xla_gpu_experimental_enable_checksum_tracing_on_thunks.
"""
import collections
import dataclasses
import itertools
from typing import Callable, Iterable, NewType, Optional, Self, TypeVar
from xla.backends.gpu.runtime import buffer_debug_log_pb2
from xla.backends.gpu.runtime import thunk_pb2
ModuleExecutionId = NewType("ModuleExecutionId", int)
ThunkId = NewType("ThunkId", int)
BufferIdx = NewType("BufferIdx", int)
Checksum = NewType("Checksum", int)
@dataclasses.dataclass(frozen=True)
class BufferChecksums:
"""A set of buffer checksums with order-independent hashing."""
checksums: dict[BufferIdx, Checksum]
def __hash__(self):
return hash(tuple(sorted(self.checksums.items())))
@dataclasses.dataclass(frozen=True)
class ThunkMetadata:
"""Thunk metadata, read from ThunkMetadataListProto.
Stored in a separate type to enable type checking.
"""
thunk_id: ThunkId
thunk_kind: str
profile_annotation: Optional[str]
@dataclasses.dataclass(frozen=True)
class ThunkExecution:
"""The details of a single execution of a thunk."""
# An ID of the HLO module execution that produced this thunk execution.
module_execution_id: int
# An ID of the thunk execution within a HLO module execution. If a thunk
# executes in a loop, there will create multiple entries with same thunk_id
# but different execution IDs.
thunk_execution_id: int
# The ID of the thunk that was executed. Details about the thunk can be found
# in ThunkMetadata.
thunk_id: ThunkId
# Checksums of buffers with defined contents before thunk execution.
# These are used to identify repeats of the same computation that are expected
# to produce the same results.
input_checksums: BufferChecksums
# Checksums of buffers with defined contents after thunk execution.
# These are the values we want to verify are consistent across executions.
output_checksums: BufferChecksums
@dataclasses.dataclass(frozen=True)
class ChecksumMismatchReport:
"""A report of checksum mismatches for a thunk."""
thunk_metadata: dict[ThunkId, ThunkMetadata]
# Thunks for which different executions produced different results. The value
# is a input checksums => output checksum sets dict containing the info about
# inconsistent outptus, and the checksums of inputs that caused them.
mismatches: dict[
ThunkId, dict[BufferChecksums, dict[BufferIdx, set[Checksum]]]
]
@classmethod
def from_protos(
cls,
log_protos: dict[
ModuleExecutionId, buffer_debug_log_pb2.BufferDebugLogProto
],
metadata_proto: thunk_pb2.ThunkMetadataListProto,
) -> Self:
"""Creates a ChecksumMismatchReport from protobufs.
Args:
log_protos: A dict of BufferDebugLogProto keyed by module execution ID.
metadata_proto: A ThunkMetadataListProto.
Preconditions:
- All log protos must refer to the same HLO module.
- metadata proto must describe the same HLO module as the log protos or be
an empty proto.
"""
metadata = _parse_metadata(metadata_proto)
executions = itertools.chain.from_iterable(
_parse_log(module_execution_id, log_proto)
for module_execution_id, log_proto in log_protos.items()
)
mismatches = _find_inconsistent_thunks(executions)
return cls(metadata, mismatches)
K = TypeVar("K")
T = TypeVar("T")
def group_by(
values: Iterable[T], key_getter: Callable[[T], K]
) -> dict[K, list[T]]:
"""Groups a sequence by a key function."""
result = collections.defaultdict(list)
for item in values:
result[key_getter(item)].append(item)
return result
def _parse_metadata(
metadata_proto: thunk_pb2.ThunkMetadataListProto,
) -> dict[ThunkId, ThunkMetadata]:
"""Parses a ThunkMetadataListProto into a dict of ThunkMetadata."""
metadata_by_thunk_id: dict[ThunkId, ThunkMetadata] = {}
for metadata in metadata_proto.thunk_metadata:
thunk_id = ThunkId(metadata.thunk_info.thunk_id)
metadata_by_thunk_id[thunk_id] = ThunkMetadata(
thunk_id=thunk_id,
thunk_kind=metadata.thunk_kind,
profile_annotation=metadata.thunk_info.profile_annotation,
)
return metadata_by_thunk_id
def _parse_log(
module_execution: int,
log_proto: buffer_debug_log_pb2.BufferDebugLogProto,
) -> list[ThunkExecution]:
"""Parses a BufferDebugLogProto and ThunkMetadataListProto into a list of ThunkExecutions."""
entries_by_execution = group_by(
log_proto.entries, lambda entry: (entry.thunk_id, entry.execution_id)
)
executions = [
ThunkExecution(
module_execution_id=module_execution,
thunk_execution_id=execution_id,
thunk_id=thunk_id,
input_checksums=BufferChecksums({
entry.buffer_idx: entry.checksum
for entry in entries
if entry.is_input_buffer
}),
output_checksums=BufferChecksums({
entry.buffer_idx: entry.checksum
for entry in entries
if not entry.is_input_buffer
}),
)
for (thunk_id, execution_id), entries in entries_by_execution.items()
]
return executions
def _find_inconsistent_output_checksums(
executions: list[ThunkExecution],
) -> dict[BufferIdx, set[Checksum]]:
"""Finds mismatches in output checksums for a list of identical executions.
Args:
executions: A list of executions of the same thunk on the same input
arguments.
Returns:
A dict of buffers whose contents were not consistent across executions with
the same inputs, based on the checksum value. The value is a set of
checksums observed for that buffer.
"""
checksums_by_buffer_idx: dict[BufferIdx, set[Checksum]] = (
collections.defaultdict(set)
)
for execution in executions:
for buffer_idx, checksum in execution.output_checksums.checksums.items():
checksums_by_buffer_idx[buffer_idx].add(checksum)
return {
buffer_idx: checksums
for buffer_idx, checksums in checksums_by_buffer_idx.items()
if len(checksums) > 1
}
def _find_inconsistent_thunks(
executions: Iterable[ThunkExecution],
) -> dict[ThunkId, dict[BufferChecksums, dict[BufferIdx, set[Checksum]]]]:
"""Finds thunks with inconsistent output checksums across identical executions.
Args:
executions: A arbitrary list of thunk executions.
Returns:
A dict of thunks whose outputs were inconsistent across identical
executions.
The value is a dict keyed by the set of input checksums, with values
identifying the output buffers with inconsistent checksums, along with the
set of observed checksums for each.
"""
executions_by_thunk_id: dict[ThunkId, list[ThunkExecution]] = group_by(
executions,
lambda e: e.thunk_id,
)
mismatches: dict[
ThunkId, dict[BufferChecksums, dict[BufferIdx, set[Checksum]]]
] = {}
for thunk_id, executions in executions_by_thunk_id.items():
executions_by_inputs: dict[BufferChecksums, list[ThunkExecution]] = (
group_by(executions, lambda e: e.input_checksums)
)
mismatches_by_inputs: dict[
BufferChecksums, dict[BufferIdx, set[Checksum]]
] = {}
for input_checksums, executions in executions_by_inputs.items():
m = _find_inconsistent_output_checksums(executions)
if m:
mismatches_by_inputs[input_checksums] = m
if mismatches_by_inputs:
mismatches[thunk_id] = mismatches_by_inputs
return mismatches

View File

@ -0,0 +1,228 @@
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from absl.testing import absltest
from google.protobuf import text_format
from xla.backends.gpu.runtime import buffer_debug_log_pb2
from xla.backends.gpu.runtime import thunk_pb2
from xla.tools.buffer_debug_log import checksum_mismatch_report
class ChecksumMismatchReportTest(absltest.TestCase):
def test_from_protos_loads_metadata(self):
test_log = ""
test_metadata = """
thunk_metadata {
thunk_info {
thunk_id: 100
profile_annotation: "thunk1"
}
thunk_kind: "kGemm"
}
thunk_metadata {
thunk_info {
thunk_id: 101
profile_annotation: "thunk2"
}
thunk_kind: "kConv"
}
"""
log_proto = text_format.Parse(
test_log, buffer_debug_log_pb2.BufferDebugLogProto()
)
metadata_proto = text_format.Parse(
test_metadata,
thunk_pb2.ThunkMetadataListProto(),
)
report = checksum_mismatch_report.ChecksumMismatchReport.from_protos(
{0: log_proto}, metadata_proto
)
self.assertEqual(
report.thunk_metadata,
{
100: checksum_mismatch_report.ThunkMetadata(
thunk_id=100,
thunk_kind="kGemm",
profile_annotation="thunk1",
),
101: checksum_mismatch_report.ThunkMetadata(
thunk_id=101,
thunk_kind="kConv",
profile_annotation="thunk2",
),
},
)
def test_from_protos_finds_mismatches_in_single_proto(self):
test_log = """
entries {
thunk_id: 100
execution_id: 10
buffer_idx: 0
is_input_buffer: true
checksum: 11111111
}
entries {
thunk_id: 100
execution_id: 10
buffer_idx: 1
is_input_buffer: false
checksum: 22222222
}
entries {
thunk_id: 100
execution_id: 11
buffer_idx: 0
is_input_buffer: true
checksum: 11111111
}
entries {
thunk_id: 100
execution_id: 11
buffer_idx: 1
is_input_buffer: false
checksum: 33333333
}
"""
test_metadata = ""
log_proto = text_format.Parse(
test_log, buffer_debug_log_pb2.BufferDebugLogProto()
)
metadata_proto = text_format.Parse(
test_metadata,
thunk_pb2.ThunkMetadataListProto(),
)
report = checksum_mismatch_report.ChecksumMismatchReport.from_protos(
{0: log_proto}, metadata_proto
)
self.assertEqual(
report.mismatches,
{
# thunk ID
100: {
# input checksums
checksum_mismatch_report.BufferChecksums({0: 11111111}): {
# output buffer index => checksums
1: {22222222, 33333333},
},
},
},
)
def test_from_protos_finds_mismatches_in_multiple_protos(self):
test_log_template = """
entries {{
thunk_id: 100
execution_id: 10
buffer_idx: 0
is_input_buffer: true
checksum: 11111111
}}
entries {{
thunk_id: 100
execution_id: 10
buffer_idx: 1
is_input_buffer: false
checksum: {output_checksum}
}}
"""
test_logs = [
test_log_template.format(output_checksum=checksum)
for checksum in [22222222, 33333333]
]
test_metadata = ""
log_protos = {
module_id: text_format.Parse(
test_log, buffer_debug_log_pb2.BufferDebugLogProto()
)
for module_id, test_log in enumerate(test_logs)
}
metadata_proto = text_format.Parse(
test_metadata,
thunk_pb2.ThunkMetadataListProto(),
)
report = checksum_mismatch_report.ChecksumMismatchReport.from_protos(
log_protos, metadata_proto
)
self.assertEqual(
report.mismatches,
{
# thunk ID
100: {
# input checksums
checksum_mismatch_report.BufferChecksums({0: 11111111}): {
# output buffer index => checksums
1: {22222222, 33333333},
},
},
},
)
def test_from_protos_does_not_include_consistent_executions(self):
test_log = """
entries {
thunk_id: 100
execution_id: 10
buffer_idx: 0
is_input_buffer: true
checksum: 11111111
}
entries {
thunk_id: 100
execution_id: 10
buffer_idx: 1
is_input_buffer: false
checksum: 22222222
}
entries {
thunk_id: 100
execution_id: 11
buffer_idx: 0
is_input_buffer: true
checksum: 11111111
}
entries {
thunk_id: 100
execution_id: 11
buffer_idx: 1
is_input_buffer: false
checksum: 22222222
}
"""
test_metadata = ""
log_proto = text_format.Parse(
test_log, buffer_debug_log_pb2.BufferDebugLogProto()
)
metadata_proto = text_format.Parse(
test_metadata,
thunk_pb2.ThunkMetadataListProto(),
)
report = checksum_mismatch_report.ChecksumMismatchReport.from_protos(
{0: log_proto}, metadata_proto
)
self.assertEmpty(report.mismatches)
if __name__ == "__main__":
absltest.main()