[xla:gpu] Fix our convert integer to pred in our Triton emitter

`arith.trunci` to i1 simply keeps the lowest bit, but HLO expects a convert to i1 to mean value != 0. Emit this conversion as a compare not-equal to 0 instead. This is already done correctly for floats.

PiperOrigin-RevId: 824716165
This commit is contained in:
Tori Baker 2025-10-27 16:03:05 -07:00 committed by TensorFlower Gardener
parent ffc21f066a
commit d299463d26
4 changed files with 37 additions and 5 deletions

View File

@ -236,6 +236,11 @@ Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
}
return b.create<ma::ExtSIOp>(dst_ty, value);
}
// int => bool is always value != 0.
if (dst_element_ty.isInteger(1)) {
return b.create<ma::CmpIOp>(ma::CmpIPredicate::ne, value,
ZerosLike(b, value));
}
return b.create<ma::TruncIOp>(dst_ty, value);
}
// int => float

View File

@ -259,8 +259,8 @@ ENTRY e {
module_and_metadata.block_level_parameters,
R"(
CHECK: %[[LOAD:.*]] = xtile.extract {{.*}} -> tensor<16x16xi8>
CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1>
CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1>
CHECK: %[[CMPI:.*]] = arith.cmpi ne, %[[LOAD]], {{.*}} : tensor<16x16xi8>
CHECK: %{{.*}} = arith.andi %[[CMPI]], %{{.*}} : tensor<16x16xi1>
)"));
}

View File

@ -468,9 +468,10 @@ ENTRY e {
)";
TF_EXPECT_OK(CreateTritonIrAndFileCheckForDot(this, kHloText,
"triton_gemm_computation", R"(
CHECK: %[[CST:.*]] = arith.constant dense<0>
CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr<tensor<16x16xi8>>
CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1>
CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1>
CHECK: %[[CMPI:.*]] = arith.cmpi ne, %[[LOAD]], %[[CST]] : tensor<16x16xi8>
CHECK: %{{.*}} = arith.andi %[[CMPI]], %{{.*}} : tensor<16x16xi1>
)"));
}

View File

@ -231,6 +231,32 @@ ENTRY entry {
hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
}
// Checks that an integer (s32) to pred convert inside a Triton fusion is
// emitted as a "not equal to zero" integer comparison, and that the module
// passes RunAndCompareNoHloPasses with kExactMatch.
TEST_F(TritonEmitterTest, ConvertIntegerToPredIsEmittedCorrectly) {
constexpr absl::string_view kHloText = R"(
HloModule m
fused_convert {
p0 = s32[3,2,2]{2,1,0} parameter(0)
ROOT convert0 = pred[3,2,2]{2,1,0} convert(p0)
}
ENTRY %main {
p0 = s32[3,2,2]{2,1,0} parameter(0)
ROOT input_convert_fusion = pred[3,2,2]{2,1,0} fusion(p0), kind=kCustom,
calls=fused_convert,
backend_config={"fusion_backend_config":{
"kind":"__triton","block_level_fusion_config":{
"num_warps":"1","output_tiles":[{"sizes":["1","2","2"]}],"num_ctas":1,
"num_stages":1,"is_tma_allowed":false}}}
}
)";
// The IR generated for `fused_convert` must contain a dense<0> constant
// followed by an `arith.cmpi ne` whose second operand is that constant.
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fused_convert", R"(
CHECK: %[[CST:.*]] = arith.constant dense<0>
CHECK: arith.cmpi ne, %{{.*}}, %[[CST]]
)"));
// Beyond the IR shape, also check the numerical result; kExactMatch is used
// since the output is pred. NOTE(review): presumably this compares against a
// reference execution - confirm against RunAndCompareNoHloPasses.
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch));
}
TEST_F(TritonEmitterTest, PredicateAddIsEmittedCorrectly) {
constexpr absl::string_view kHloText = R"(
HloModule m
@ -2503,7 +2529,7 @@ ENTRY main {
TF_EXPECT_OK(
CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"(
CHECK: %[[I8_PARAM:.*]] = xtile.extract {{.*}} -> tensor<4xi8>
CHECK: arith.trunci %[[I8_PARAM]] : tensor<4xi8> to tensor<4xi1>
CHECK: arith.cmpi ne, %[[I8_PARAM]], {{.*}} : tensor<4xi8>
)"));
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch));