[Autotuner] Add block-level emitter backend for Triton fusion (2).

This change continues the work on the Triton block-level fusion emitter backend, which enables autotuning of tile configurations for custom Triton fusions in XLA.

This backend implements the following core interfaces (sketched in code after the list):

- GetSupportedConfigs: Enumerates all supported combinations of tile sizes for the output tensors. The generated configs are explored as candidates during autotuning.

- GetDefaultConfig: Provides a default tile configuration for a given Triton fusion, used as a fallback when no tuning data is available. (Implemented in the previous PR, #28515.)

- ApplyConfig: Applies a selected block-level fusion configuration to a Triton fusion instruction by updating its GpuBackendConfig. (To be added in a follow-up PR.)
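
For orientation, a minimal sketch of how these entry points fit together, assuming the backend derives from CodegenBackend (the test below depends on codegen_backend.h). The first two signatures follow the diff in this change; ApplyConfig's exact signature is a guess until the follow-up PR lands:

class BlockLevelEmitterBackend : public CodegenBackend {
 public:
  // Enumerates candidate tile configurations for autotuning (this change).
  absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
  GetSupportedConfigs(const HloInstruction& instr) override;

  // Fallback tile configuration when no tuning data is available (#28515).
  absl::StatusOr<std::unique_ptr<BackendConfig>> GetDefaultConfig(
      const HloInstruction& instr) override;

  // Writes the chosen config into the fusion's GpuBackendConfig
  // (assumed signature; lands in the follow-up PR).
  absl::Status ApplyConfig(HloInstruction& instr,
                           const BackendConfig& config) override;
};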

PiperOrigin-RevId: 784233964
Author: Alex Pivovarov (2025-07-17 10:35:00 -07:00), committed by TensorFlower Gardener
parent e9b7af391b
commit f926778e7f
3 changed files with 303 additions and 1 deletion

xla/backends/gpu/autotuner/BUILD

@@ -80,6 +80,7 @@ xla_test(
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_googletest//:gtest_main",
    ],
)

xla/backends/gpu/autotuner/block_level_emitter.cc

@@ -15,6 +15,7 @@ limitations under the License.
#include "xla/backends/gpu/autotuner/block_level_emitter.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <memory>
@@ -76,6 +77,101 @@ constexpr int64_t GetTileSize(int64_t dim, int max_tile_size) {
  return 1LL << static_cast<int64_t>(std::ceil(std::log2(dim)));
}
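// (GetTileSize example: dim = 10 gives 1 << ceil(log2(10)) = 16, the next
// power of two >= dim; see the non-power-of-two test below.)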
// Helper: resets all variable dimensions after `index` to zero.
void ResetTrailingDimensions(const std::vector<int64_t>& input,
                             std::vector<int64_t>& current, int64_t index) {
  int64_t dims = input.size();
  // Iterate over all dimensions after `index`, resetting only those that
  // are variable (input[j] >= 0).
  for (int64_t j = index + 1; j < dims; ++j) {
    if (input[j] >= 0) {
      current[j] = 0;
    }
  }
}
// Helper: tries to advance to the next valid combination.
//
// Returns:
//  - true: successfully advanced to the next combination (more combinations
//    available)
//  - false: no more combinations (all combinations have been generated)
bool AdvanceToNextCombination(const std::vector<int64_t>& input,
                              std::vector<int64_t>& current) {
  int64_t dims = input.size();
  // Iterate dimensions from right to left.
  for (int64_t i = dims - 1; i >= 0; --i) {
    // Skip fixed dimensions (negative values in input).
    if (input[i] < 0) {
      continue;
    }
    // If the current dimension can still be incremented:
    if (current[i] < input[i]) {
      current[i]++;                                // Increment this dimension.
      ResetTrailingDimensions(input, current, i);  // Reset all after it.
      return true;  // Not done yet; the next combination is ready.
    }
  }
  // All dimensions are at their maximum, so no increment is possible.
  return false;  // All combinations generated; done.
}
// Generates all multi-dimensional integer combinations for a given shape.
//
// For each dimension `i` in `input`:
//  - If input[i] >= 0 (variable dimension): the element at index `i` ranges
//    from 0 up to `input[i]`, inclusive.
//  - If input[i] < 0 (fixed dimension): the element at index `i` is fixed to
//    the value of `input[i]`.
//
// For example, given input = {2, INT64_MIN, 3}, the function returns:
// {
//   {0, INT64_MIN, 0}, {0, INT64_MIN, 1}, {0, INT64_MIN, 2}, {0, INT64_MIN, 3},
//   {1, INT64_MIN, 0}, {1, INT64_MIN, 1}, {1, INT64_MIN, 2}, {1, INT64_MIN, 3},
//   {2, INT64_MIN, 0}, {2, INT64_MIN, 1}, {2, INT64_MIN, 2}, {2, INT64_MIN, 3}
// }
//
// Parameters:
//  - input: a vector of integers giving the upper bound (inclusive) for each
//    dimension. A negative value indicates that the dimension is fixed to
//    that value.
//
// Returns:
//  - A vector of integer vectors, where each inner vector is a unique
//    combination.
//
// Notes:
//  - The number of combinations is the product of (input[i] + 1) over all
//    dimensions with input[i] >= 0.
//  - Each combination has the same length as `input`.
//  - For dimensions with input[i] < 0, that value is used directly in all
//    outputs.
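//  - Example: for input = {6, 0, 4} (the log2 bounds derived below from an
//    output shape of [64,1,16]), this yields (6+1) * (0+1) * (4+1) = 35
//    combinations, matching the GetSupportedConfigs test in this change.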
std::vector<std::vector<int64_t>> GenerateCombinations(
    const std::vector<int64_t>& input) {
  std::vector<std::vector<int64_t>> result;
  if (input.empty()) {
    return result;
  }
  int64_t dims = input.size();
  std::vector<int64_t> current(dims);
  // Initialize each dimension: 0 if variable, input[i] if fixed.
  for (int64_t i = 0; i < dims; ++i) {
    current[i] = std::min(input[i], int64_t{0});
  }
  // Loop until all combinations are generated.
  do {
    // Add a copy of the current combination to the result.
    result.push_back(current);
    // Attempt to advance to the next combination.
  } while (AdvanceToNextCombination(input, current));
  return result;
}
// Recursively traverses a Shape object in depth-first order,
// collecting the dimensions of all array shapes encountered.
//
@@ -123,7 +219,67 @@ std::vector<absl::Span<const int64_t>> FlatListOfShapes(
absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
BlockLevelEmitterBackend::GetSupportedConfigs(const HloInstruction& instr) {
  if (!IsSupported(instr)) {
    return std::vector<std::unique_ptr<BackendConfig>>();
  }
  // This backend only supports array shapes (not tuples, etc.).
  if (!instr.shape().IsArray()) {
    return absl::InvalidArgumentError(
        "Only array shapes are supported in block-level emitter "
        "GetSupportedConfigs.");
  }
  // Compute the base-2 logarithm (rounded up) of each dimension size.
  // This determines the range of tile sizes to explore in log2 space.
  std::vector<int64_t> log2_dims;
  for (const int64_t dim : instr.shape().dimensions()) {
    // Exclude zero-sized dimensions from the tiling configuration.
    if (dim == 0) {
      // Use INT64_MIN as a sentinel to mark zero-sized dimensions.
      // These are handled specially below.
      log2_dims.push_back(INT64_MIN);
    } else {
      // ceil(log2(dim))
      log2_dims.push_back(static_cast<int64_t>(std::ceil(std::log2(dim))));
    }
  }
  std::vector<std::unique_ptr<BackendConfig>> configs;
  // Generate all possible combinations of tile sizes across dimensions by
  // iterating over the space of log2(tile size) values.
  //
  // For example, if one dimension has log2 = 2 (i.e., dim = 4), this
  // generates tile sizes of 1, 2, and 4 for that dimension.
  std::vector<std::vector<int64_t>> tile_log2_combinations =
      GenerateCombinations(log2_dims);
  // For each valid tile size combination, construct a corresponding config.
  for (const std::vector<int64_t>& tile_log2_dims : tile_log2_combinations) {
    BlockLevelFusionConfig config;
    Tile* output_tile = config.add_output_tiles();
    for (const int64_t log2_dim : tile_log2_dims) {
      if (log2_dim == INT64_MIN) {
        // Preserve 0-sized dimensions in the tile configuration.
        output_tile->add_sizes(0);
      } else {
        // Convert log2 size back to the actual tile size (1 << log2).
        output_tile->add_sizes(1LL << log2_dim);
      }
    }
    // Set default kernel execution parameters.
    config.set_num_warps(1);   // Number of warps per block.
    config.set_num_ctas(1);    // Number of thread blocks (CTAs).
    config.set_num_stages(1);  // Number of pipeline stages.
    // Store the config (as a polymorphic BackendConfig).
    configs.push_back(
        std::make_unique<BlockLevelFusionConfig>(std::move(config)));
  }
  return configs;
}
absl::StatusOr<std::unique_ptr<BackendConfig>>
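
Taken together, ResetTrailingDimensions, AdvanceToNextCombination, and GenerateCombinations amount to counting in a mixed radix over per-dimension log2(tile size) digits. Below is a minimal standalone sketch of that enumeration, standard library only; the Advance/main names and the printing driver are illustrative, not part of this change:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Odometer-style advance: each variable digit i counts 0..bounds[i];
// a negative bound marks a fixed digit (cf. the INT64_MIN sentinel above).
bool Advance(const std::vector<int64_t>& bounds, std::vector<int64_t>& cur) {
  for (int64_t i = bounds.size() - 1; i >= 0; --i) {
    if (bounds[i] < 0) continue;  // Fixed digit: skip.
    if (cur[i] < bounds[i]) {
      ++cur[i];  // Increment this digit, then reset variable digits after it.
      for (size_t j = i + 1; j < bounds.size(); ++j) {
        if (bounds[j] >= 0) cur[j] = 0;
      }
      return true;
    }
  }
  return false;  // All digits at their maximum: enumeration complete.
}

int main() {
  // log2 upper bounds for output shape [64,1,16]: {6, 0, 4}.
  const std::vector<int64_t> bounds = {6, 0, 4};
  std::vector<int64_t> cur(bounds.size());
  for (size_t i = 0; i < bounds.size(); ++i) {
    cur[i] = std::min<int64_t>(bounds[i], 0);  // 0 if variable, bound if fixed.
  }
  int count = 0;
  do {
    for (int64_t d : cur) std::cout << (1LL << d) << ' ';  // Tile size 1 << d.
    std::cout << '\n';
    ++count;
  } while (Advance(bounds, cur));
  std::cout << count << " combinations\n";  // Prints 35 (7 * 1 * 5).
}

Run against the bounds {6, 0, 4}, this prints the 35 tile-size triples that the GetSupportedConfigs test below asserts on.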

xla/backends/gpu/autotuner/block_level_emitter_test.cc

@@ -16,10 +16,12 @@ limitations under the License.
#include "xla/backends/gpu/autotuner/block_level_emitter.h"
#include <memory>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "xla/autotuning.pb.h"
#include "xla/backends/autotuner/codegen_backend.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -263,5 +265,148 @@ ENTRY %main {
  )pb"));
}
// Tests that `GetSupportedConfigs` returns a correct list of valid backend
// configurations for a fusion instruction.
// The fusion has output shape [64,1,16].
// The backend should generate the full set of tile configurations over
// different tile sizes for d0 and d2 while keeping the middle dimension d1
// fixed at 1.
TEST_F(TritonBlockLevelFusionEmitterBackendTest, GetSupportedConfigs) {
  // Build and verify an HLO module containing a fusion with a 3D transpose.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(R"(
HloModule m
%wrapped_transpose_computation {
  %param_0 = f32[16,1,64]{2,1,0} parameter(0)
  ROOT %transpose.3.1 = f32[64,1,16]{2,1,0} transpose(%param_0), dimensions={2,1,0}
}
ENTRY %main {
  %p0 = f32[16,1,64]{2,1,0} parameter(0), metadata={op_name="a"}
  ROOT %wrapped_transpose = f32[64,1,16]{2,1,0} fusion(%p0), kind=kCustom,
    calls=%wrapped_transpose_computation,
    metadata={op_name="a"},
    backend_config={"fusion_backend_config":{"kind":"__triton"}}
}
)"));
  // Call GetSupportedConfigs on the root instruction (the fusion op).
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<std::unique_ptr<BackendConfig>> configs,
      backend_.GetSupportedConfigs(
          *(module->entry_computation()->root_instruction())));
  // Expect 35 (7 x 5) total configurations:
  //  - 7 choices for d0 (output dim 0 = 64): 1, 2, 4, 8, 16, 32, 64
  //  - 5 choices for d2 (output dim 2 = 16): 1, 2, 4, 8, 16
  // The middle dimension (d1 = 1) must always have tile size 1.
  ASSERT_EQ(configs.size(), 35);
  int config_idx = 0;
  // Iterate over all expected tile size combinations for d0 and d2.
  // (d1 is fixed at 1, matching the output shape [64,1,16].)
  for (int d0 : {1, 2, 4, 8, 16, 32, 64}) {
    for (int d2 : {1, 2, 4, 8, 16}) {
      ASSERT_EQ(configs[config_idx]->GetDescriptor(),
                BlockLevelFusionConfig::GetDescriptor())
          << "Config is not a BlockLevelFusionConfig";
      const BlockLevelFusionConfig* block_level_fusion_config =
          dynamic_cast<const BlockLevelFusionConfig*>(
              configs[config_idx].get());
      ASSERT_NE(block_level_fusion_config, nullptr);
      // Verify that the config matches the expected proto representation
      // based on the current d0 and d2 tile size values (d1 is fixed at 1).
      // Also verify default tuning parameters: 1 warp, 1 CTA, 1 stage.
      EXPECT_THAT(*block_level_fusion_config,
                  EqualsProto(absl::Substitute(
                      R"pb(
                        output_tiles { sizes: $0 sizes: 1 sizes: $1 }
                        num_warps: 1
                        num_ctas: 1
                        num_stages: 1
                      )pb",
                      d0, d2)));
      ++config_idx;
    }
  }
}
// Tests that `GetSupportedConfigs` returns the correct subset of tile
// configurations for fusion operations involving non-power-of-two tensor
// dimensions, and that it correctly handles zero-sized dimensions.
//
// The fusion has output shape [10,0,8].
// The tile size for the zero-sized dimension must be 0.
TEST_F(TritonBlockLevelFusionEmitterBackendTest,
       GetSupportedConfigs_Zero_NonPow2Dim) {
  // Build and verify an HLO module containing a fusion with a 3D transpose.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(R"(
HloModule m
%wrapped_transpose_computation {
  %param_0 = f32[8,0,10]{2,1,0} parameter(0)
  ROOT %transpose.3.1 = f32[10,0,8]{2,1,0} transpose(%param_0), dimensions={2,1,0}
}
ENTRY %main {
  %p0 = f32[8,0,10]{2,1,0} parameter(0), metadata={op_name="a"}
  ROOT %wrapped_transpose = f32[10,0,8]{2,1,0} fusion(%p0), kind=kCustom,
    calls=%wrapped_transpose_computation,
    metadata={op_name="a"},
    backend_config={"fusion_backend_config":{"kind":"__triton"}}
}
)"));
  // Call GetSupportedConfigs on the root instruction (the fusion op).
  TF_ASSERT_OK_AND_ASSIGN(
      std::vector<std::unique_ptr<BackendConfig>> configs,
      backend_.GetSupportedConfigs(
          *(module->entry_computation()->root_instruction())));
  // Expect 20 total configurations:
  //  - 5 choices for d0 (output dim 0 = 10): 1, 2, 4, 8, 16
  //  - 4 choices for d2 (output dim 2 = 8): 1, 2, 4, 8
  // The middle dimension (d1 = 0) must always have tile size 0.
  ASSERT_EQ(configs.size(), 20);
  int i = 0;
  // Iterate over tile size combinations for dimensions 0 and 2.
  // Dimension 1 (middle) is zero-sized, so its tile size is fixed to 0.
  for (int d0 : {1, 2, 4, 8, 16}) {
    for (int d2 : {1, 2, 4, 8}) {
      ASSERT_EQ(configs[i]->GetDescriptor(),
                BlockLevelFusionConfig::GetDescriptor())
          << "Config is not a BlockLevelFusionConfig";
      const BlockLevelFusionConfig* block_level_fusion_config =
          dynamic_cast<const BlockLevelFusionConfig*>(configs[i].get());
      ASSERT_NE(block_level_fusion_config, nullptr);
      // Validate that the tile shape matches expectations:
      //  - d0: 10 → tile sizes {1, 2, 4, 8, 16}
      //  - d1: 0 → tile size must be 0
      //  - d2: 8 → tile sizes {1, 2, 4, 8}
      EXPECT_THAT(*block_level_fusion_config,
                  EqualsProto(absl::Substitute(
                      R"pb(
                        output_tiles { sizes: $0 sizes: 0 sizes: $1 }
                        num_warps: 1
                        num_ctas: 1
                        num_stages: 1
                      )pb",
                      d0, d2)));
      ++i;
    }
  }
}
} // namespace gpu
} // namespace xla