[Autotuner] Add block level emitter backend for Triton fusion (2).
This change continues the work on the Triton block-level fusion emitter
backend, which enables autotuning of tile configurations for custom Triton
fusions in XLA. The backend implements the following core interfaces:

- GetSupportedConfigs: enumerates all supported combinations of tile sizes
  for the output tensors. The generated configs can be used during autotuning
  to explore different performance candidates.
- GetDefaultConfig: provides a default tile configuration for a given Triton
  fusion, used as a fallback when no tuning data is available (implemented in
  a previous PR, #28515).
- ApplyConfig: applies a selected block-level fusion configuration to a
  Triton fusion instruction by updating its GpuBackendConfig (will be added
  in the next PR).

PiperOrigin-RevId: 784233964
This commit is contained in:
  parent e9b7af391b
  commit f926778e7f
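
As a quick sanity check on the size of the search space GetSupportedConfigs
explores: tile sizes are enumerated in powers of two per output dimension, so
each non-zero dimension of size d contributes ceil(log2(d)) + 1 candidates.
A minimal standalone C++ sketch of that count (illustrative only; the
NumTileConfigs helper is hypothetical and not part of this commit):

    #include <cmath>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Hypothetical helper: rough count of the tile configs generated for one
    // array-shaped output. Each non-zero dimension d contributes
    // ceil(log2(d)) + 1 power-of-two tile sizes (1, 2, ..., next pow2 >= d);
    // a zero-sized dimension contributes a single fixed choice (tile size 0).
    int64_t NumTileConfigs(const std::vector<int64_t>& dims) {
      int64_t n = 1;
      for (int64_t d : dims) {
        if (d != 0) {
          n *= static_cast<int64_t>(std::ceil(std::log2(d))) + 1;
        }
      }
      return n;
    }

    int main() {
      std::cout << NumTileConfigs({64, 1, 16}) << "\n";  // 7 * 1 * 5 = 35
      std::cout << NumTileConfigs({10, 0, 8}) << "\n";   // 5 * 4 = 20
    }

The two printed counts match the expectations in the tests added below.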
@@ -80,6 +80,7 @@ xla_test(
         "//xla/tsl/platform:statusor",
+        "//xla/tsl/util/proto:proto_matchers",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -15,6 +15,7 @@ limitations under the License.

 #include "xla/backends/gpu/autotuner/block_level_emitter.h"

+#include <algorithm>
 #include <cmath>
 #include <cstdint>
 #include <memory>
@@ -76,6 +77,101 @@ constexpr int64_t GetTileSize(int64_t dim, int max_tile_size) {
   return 1LL << static_cast<int64_t>(std::ceil(std::log2(dim)));
 }

+// Helper: resets all variable dimensions after `index` to zero.
+void ResetTrailingDimensions(const std::vector<int64_t>& input,
+                             std::vector<int64_t>& current, int64_t index) {
+  int64_t dims = input.size();
+  // Iterate over all dimensions after `index`, resetting only those that
+  // are variable (input[j] >= 0).
+  for (int64_t j = index + 1; j < dims; ++j) {
+    if (input[j] >= 0) {
+      current[j] = 0;
+    }
+  }
+}
+
+// Helper: tries to advance to the next valid combination.
+//
+// Returns:
+//  - true: successfully advanced to the next combination (more combinations
+//    available)
+//  - false: no more combinations (all combinations have been generated)
+bool AdvanceToNextCombination(const std::vector<int64_t>& input,
+                              std::vector<int64_t>& current) {
+  int64_t dims = input.size();
+  // Iterate dimensions from right to left.
+  for (int64_t i = dims - 1; i >= 0; --i) {
+    // Skip fixed dimensions (negative values in input).
+    if (input[i] < 0) {
+      continue;
+    }
+    // If the current dimension can still be incremented...
+    if (current[i] < input[i]) {
+      current[i]++;                                // Increment this dimension.
+      ResetTrailingDimensions(input, current, i);  // Reset all after it.
+      return true;                                 // Next combination ready.
+    }
+  }
+  // If we reach here, all dimensions are at max and no increment is possible.
+  return false;  // All combinations generated; done.
+}
+
+// Generates all multi-dimensional integer combinations for a given shape.
+//
+// For each dimension `i` in `input`:
+//  - If input[i] >= 0 (variable dimension): the element at index `i` ranges
+//    from 0 up to `input[i]`, inclusive.
+//  - If input[i] < 0 (fixed dimension): the element at index `i` is fixed
+//    to the value of `input[i]`.
+//
+// For example, given input = {2, MIN_INT, 3}, the function returns:
+// {
+//   {0, MIN_INT, 0}, {0, MIN_INT, 1}, {0, MIN_INT, 2}, {0, MIN_INT, 3},
+//   {1, MIN_INT, 0}, {1, MIN_INT, 1}, {1, MIN_INT, 2}, {1, MIN_INT, 3},
+//   {2, MIN_INT, 0}, {2, MIN_INT, 1}, {2, MIN_INT, 2}, {2, MIN_INT, 3}
+// }
+//
+// Parameters:
+//  - input: a vector of integers giving the upper bound (inclusive) for
+//    each dimension. A negative value indicates that the dimension is fixed
+//    to that value.
+//
+// Returns:
+//  - A vector of integer vectors, where each inner vector is a unique
+//    combination.
+//
+// Notes:
+//  - The number of combinations is the product of (input[i] + 1) over all
+//    dimensions with input[i] >= 0.
+//  - Each combination has the same length as `input`.
+//  - For dimensions with input[i] < 0, that value is used directly in all
+//    outputs.
+std::vector<std::vector<int64_t>> GenerateCombinations(
+    const std::vector<int64_t>& input) {
+  std::vector<std::vector<int64_t>> result;
+  if (input.empty()) {
+    return result;
+  }
+
+  int64_t dims = input.size();
+  std::vector<int64_t> current(dims);
+
+  // Initialize each dimension: 0 if variable, input[i] if fixed.
+  for (int64_t i = 0; i < dims; ++i) {
+    current[i] = std::min(input[i], int64_t{0});
+  }
+
+  // Loop until all combinations are generated.
+  do {
+    // Add a copy of the current combination to the result,
+    // then attempt to advance to the next combination.
+    result.push_back(current);
+  } while (AdvanceToNextCombination(input, current));
+
+  return result;
+}
+
 // Recursively traverses a Shape object in depth-first order,
 // collecting the dimensions of all array shapes encountered.
 //
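
For intuition, GenerateCombinations behaves like a mixed-radix odometer: the
rightmost variable digit increments first, and every variable digit to its
right resets to zero. A self-contained sketch (illustrative only, with the
two helpers from the hunk above inlined) that reproduces the {2, MIN_INT, 3}
example from the comment:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <vector>

    // Standalone re-implementation of the odometer-style enumeration above,
    // for illustration only. Negative entries are fixed; non-negative
    // entries count from 0 to their bound, inclusive.
    std::vector<std::vector<int64_t>> GenerateCombinations(
        const std::vector<int64_t>& input) {
      std::vector<std::vector<int64_t>> result;
      if (input.empty()) return result;
      const int64_t dims = static_cast<int64_t>(input.size());
      std::vector<int64_t> current(dims);
      for (int64_t i = 0; i < dims; ++i) {
        current[i] = std::min(input[i], int64_t{0});
      }
      while (true) {
        result.push_back(current);
        // Advance right to left, skipping fixed (negative) dimensions.
        int64_t i = dims - 1;
        for (; i >= 0; --i) {
          if (input[i] < 0) continue;
          if (current[i] < input[i]) {
            ++current[i];
            for (int64_t j = i + 1; j < dims; ++j) {
              if (input[j] >= 0) current[j] = 0;  // Reset trailing digits.
            }
            break;
          }
        }
        if (i < 0) break;  // Odometer overflowed: all combinations emitted.
      }
      return result;
    }

    int main() {
      const int64_t kMin = std::numeric_limits<int64_t>::min();
      // Expect (2 + 1) * (3 + 1) = 12 rows; the middle value stays fixed.
      for (const auto& combo : GenerateCombinations({2, kMin, 3})) {
        for (int64_t v : combo) {
          if (v == kMin) std::cout << "MIN ";
          else std::cout << v << ' ';
        }
        std::cout << '\n';
      }
    }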
@@ -123,7 +219,67 @@ std::vector<absl::Span<const int64_t>> FlatListOfShapes(

 absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
 BlockLevelEmitterBackend::GetSupportedConfigs(const HloInstruction& instr) {
-  return absl::UnimplementedError("GetSupportedConfigs is not implemented yet");
+  if (!IsSupported(instr)) {
+    return std::vector<std::unique_ptr<BackendConfig>>();
+  }
+
+  // This backend only supports array shapes (not tuples, etc.).
+  if (!instr.shape().IsArray()) {
+    return absl::InvalidArgumentError(
+        "Only array shapes are supported in block-level emitter "
+        "GetSupportedConfigs.");
+  }
+
+  // Compute the base-2 logarithm (rounded up) of each dimension size.
+  // This determines the range of tile sizes to explore in log2 space.
+  std::vector<int64_t> log2_dims;
+  for (const int64_t dim : instr.shape().dimensions()) {
+    if (dim == 0) {
+      // Exclude zero-sized dimensions from the tiling configuration: use
+      // INT64_MIN as a sentinel to mark them. These are handled specially
+      // later.
+      log2_dims.push_back(INT64_MIN);
+    } else {
+      // ceil(log2(dim))
+      log2_dims.push_back(static_cast<int64_t>(std::ceil(std::log2(dim))));
+    }
+  }
+
+  std::vector<std::unique_ptr<BackendConfig>> configs;
+
+  // Generate all possible combinations of tile sizes across dimensions by
+  // iterating over the space of log2(tile size) values.
+  //
+  // For example, if one dimension has log2 = 2 (i.e., dim = 4), this
+  // generates tile sizes of 1, 2, and 4 for that dimension.
+  std::vector<std::vector<int64_t>> tile_log2_combinations =
+      GenerateCombinations(log2_dims);
+
+  // For each valid tile size combination, construct a corresponding config.
+  for (const std::vector<int64_t>& tile_log2_dims : tile_log2_combinations) {
+    BlockLevelFusionConfig config;
+    Tile* output_tile = config.add_output_tiles();
+
+    for (const int64_t log2_dim : tile_log2_dims) {
+      if (log2_dim == INT64_MIN) {
+        // Preserve zero-sized dimensions in the tile configuration.
+        output_tile->add_sizes(0);
+      } else {
+        // Convert the log2 value back to an actual tile size (1 << log2).
+        output_tile->add_sizes(1LL << log2_dim);
+      }
+    }
+
+    // Set default kernel execution parameters.
+    config.set_num_warps(1);   // Number of warps per block.
+    config.set_num_ctas(1);    // Number of thread blocks (CTAs).
+    config.set_num_stages(1);  // Number of pipeline stages.
+
+    // Store the config (as a polymorphic BackendConfig).
+    configs.push_back(
+        std::make_unique<BlockLevelFusionConfig>(std::move(config)));
+  }
+  return configs;
+}

 absl::StatusOr<std::unique_ptr<BackendConfig>>
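
Concretely, for an output shape of f32[64,1,16] the loop above yields 35
configs. A representative one (tile 32 on d0, 8 on d2; shown in the same
textproto form the tests below assert against, purely for illustration):

    output_tiles { sizes: 32 sizes: 1 sizes: 8 }
    num_warps: 1
    num_ctas: 1
    num_stages: 1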
@@ -16,10 +16,12 @@ limitations under the License.

 #include "xla/backends/gpu/autotuner/block_level_emitter.h"

 #include <memory>
+#include <vector>

 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "absl/status/statusor.h"
+#include "absl/strings/substitute.h"
 #include "xla/autotuning.pb.h"
 #include "xla/backends/autotuner/codegen_backend.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -263,5 +265,148 @@ ENTRY %main {
 )pb"));
 }

+// Tests that `GetSupportedConfigs` returns a correct list of valid backend
+// configurations for a fusion instruction.
+// The fusion has output shape [64,1,16].
+// The backend should generate the full set of tile configurations with
+// different tile sizes for d0 and d2 while keeping the middle dimension d1
+// fixed at 1.
+TEST_F(TritonBlockLevelFusionEmitterBackendTest, GetSupportedConfigs) {
+  // Build and verify an HLO module containing a fusion with a 3D transpose.
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+%wrapped_transpose_computation {
+  %param_0 = f32[16,1,64]{2,1,0} parameter(0)
+  ROOT %transpose.3.1 = f32[64,1,16]{2,1,0} transpose(%param_0), dimensions={2,1,0}
+}
+
+ENTRY %main {
+  %p0 = f32[16,1,64]{2,1,0} parameter(0), metadata={op_name="a"}
+  ROOT %wrapped_transpose = f32[64,1,16]{2,1,0} fusion(%p0), kind=kCustom,
+    calls=%wrapped_transpose_computation,
+    metadata={op_name="a"},
+    backend_config={"fusion_backend_config":{"kind":"__triton"}}
+}
+)"));
+
+  // Call GetSupportedConfigs on the root instruction (the fusion op).
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<std::unique_ptr<BackendConfig>> configs,
+      backend_.GetSupportedConfigs(
+          *(module->entry_computation()->root_instruction())));
+
+  // Expect 35 (= 7 x 5) total configurations:
+  //  - 7 choices for d0 (output dim 0 = 64): 1, 2, 4, 8, 16, 32, 64
+  //  - 5 choices for d2 (output dim 2 = 16): 1, 2, 4, 8, 16
+  // The middle dimension (d1 = 1) must always have tile size 1.
+  ASSERT_EQ(configs.size(), 35);
+
+  int config_idx = 0;
+
+  // Iterate over all expected tile size combinations for d0 and d2
+  // (d1 is fixed at 1, per the output shape [64,1,16]).
+  for (int d0 : {1, 2, 4, 8, 16, 32, 64}) {
+    for (int d2 : {1, 2, 4, 8, 16}) {
+      ASSERT_EQ(configs[config_idx]->GetDescriptor(),
+                BlockLevelFusionConfig::GetDescriptor())
+          << "Config is not a BlockLevelFusionConfig";
+      const BlockLevelFusionConfig* block_level_fusion_config =
+          dynamic_cast<const BlockLevelFusionConfig*>(
+              configs[config_idx].get());
+      ASSERT_NE(block_level_fusion_config, nullptr);
+
+      // Verify that the config matches the expected proto representation
+      // for the current d0 and d2 tile sizes (d1 is fixed at 1), and that
+      // the default tuning parameters are 1 warp, 1 CTA, and 1 stage.
+      EXPECT_THAT(*block_level_fusion_config,
+                  EqualsProto(absl::Substitute(
+                      R"pb(
+                        output_tiles { sizes: $0 sizes: 1 sizes: $1 }
+                        num_warps: 1
+                        num_ctas: 1
+                        num_stages: 1
+                      )pb",
+                      d0, d2)));
+
+      ++config_idx;
+    }
+  }
+}
+
+// Tests that `GetSupportedConfigs` returns the correct subset of tile
+// configurations for fusion operations involving non-power-of-two tensor
+// dimensions, and that it correctly handles zero-sized dimensions.
+//
+// The fusion has output shape [10,0,8].
+// The tile size for the zero-sized dimension must be 0.
+TEST_F(TritonBlockLevelFusionEmitterBackendTest,
+       GetSupportedConfigs_Zero_NonPow2Dim) {
+  // Build and verify an HLO module containing a fusion with a 3D transpose.
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(R"(
+HloModule m
+
+%wrapped_transpose_computation {
+  %param_0 = f32[8,0,10]{2,1,0} parameter(0)
+  ROOT %transpose.3.1 = f32[10,0,8]{2,1,0} transpose(%param_0), dimensions={2,1,0}
+}
+
+ENTRY %main {
+  %p0 = f32[8,0,10]{2,1,0} parameter(0), metadata={op_name="a"}
+  ROOT %wrapped_transpose = f32[10,0,8]{2,1,0} fusion(%p0), kind=kCustom,
+    calls=%wrapped_transpose_computation,
+    metadata={op_name="a"},
+    backend_config={"fusion_backend_config":{"kind":"__triton"}}
+}
+)"));
+
+  // Call GetSupportedConfigs on the root instruction (the fusion op).
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<std::unique_ptr<BackendConfig>> configs,
+      backend_.GetSupportedConfigs(
+          *(module->entry_computation()->root_instruction())));
+
+  // Expect 20 (= 5 x 4) total configurations:
+  //  - 5 choices for d0 (output dim 0 = 10): 1, 2, 4, 8, 16
+  //  - 4 choices for d2 (output dim 2 = 8): 1, 2, 4, 8
+  // The middle dimension (d1 = 0) must always have tile size 0.
+  ASSERT_EQ(configs.size(), 20);
+
+  int i = 0;
+
+  // Iterate over tile size combinations for dimensions 0 and 2.
+  // Dimension 1 (middle) is zero-sized, so its tile size is fixed to 0.
+  for (int d0 : {1, 2, 4, 8, 16}) {
+    for (int d2 : {1, 2, 4, 8}) {
+      ASSERT_EQ(configs[i]->GetDescriptor(),
+                BlockLevelFusionConfig::GetDescriptor())
+          << "Config is not a BlockLevelFusionConfig";
+      const BlockLevelFusionConfig* block_level_fusion_config =
+          dynamic_cast<const BlockLevelFusionConfig*>(configs[i].get());
+      ASSERT_NE(block_level_fusion_config, nullptr);
+
+      // Validate that the tile shape matches expectations:
+      //  - d0: 10 → tile sizes {1, 2, 4, 8, 16}
+      //  - d1: 0  → tile size must be 0
+      //  - d2: 8  → tile sizes {1, 2, 4, 8}
+      EXPECT_THAT(*block_level_fusion_config,
+                  EqualsProto(absl::Substitute(
+                      R"pb(
+                        output_tiles { sizes: $0 sizes: 0 sizes: $1 }
+                        num_warps: 1
+                        num_ctas: 1
+                        num_stages: 1
+                      )pb",
+                      d0, d2)));
+
+      ++i;
+    }
+  }
+}
+
 }  // namespace gpu
 }  // namespace xla
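
To run the new tests locally, something like the following should work
(assuming the xla_test target in the BUILD hunk above is named
block_level_emitter_test; the target name itself is not visible in this
diff):

    bazel test //xla/backends/gpu/autotuner:block_level_emitter_test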