mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[xla:gpu] Fix our convert integer to pred in our Triton emitter
`arith.trunci` for i1 will simply take the last bit, but HLO expects convert to i1 to be value != 0. Emit this conversion a a compare not equal to 0 instead. This is already done correctly for floats. PiperOrigin-RevId: 824716165
This commit is contained in:
parent
ffc21f066a
commit
d299463d26
|
|
@ -236,6 +236,11 @@ Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
|
|||
}
|
||||
return b.create<ma::ExtSIOp>(dst_ty, value);
|
||||
}
|
||||
// int => bool is always value != 0.
|
||||
if (dst_element_ty.isInteger(1)) {
|
||||
return b.create<ma::CmpIOp>(ma::CmpIPredicate::ne, value,
|
||||
ZerosLike(b, value));
|
||||
}
|
||||
return b.create<ma::TruncIOp>(dst_ty, value);
|
||||
}
|
||||
// int => float
|
||||
|
|
|
|||
|
|
@ -259,8 +259,8 @@ ENTRY e {
|
|||
module_and_metadata.block_level_parameters,
|
||||
R"(
|
||||
CHECK: %[[LOAD:.*]] = xtile.extract {{.*}} -> tensor<16x16xi8>
|
||||
CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1>
|
||||
CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1>
|
||||
CHECK: %[[CMPI:.*]] = arith.cmpi ne, %[[LOAD]], {{.*}} : tensor<16x16xi8>
|
||||
CHECK: %{{.*}} = arith.andi %[[CMPI]], %{{.*}} : tensor<16x16xi1>
|
||||
)"));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -468,9 +468,10 @@ ENTRY e {
|
|||
)";
|
||||
TF_EXPECT_OK(CreateTritonIrAndFileCheckForDot(this, kHloText,
|
||||
"triton_gemm_computation", R"(
|
||||
CHECK: %[[CST:.*]] = arith.constant dense<0>
|
||||
CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr<tensor<16x16xi8>>
|
||||
CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1>
|
||||
CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1>
|
||||
CHECK: %[[CMPI:.*]] = arith.cmpi ne, %[[LOAD]], %[[CST]] : tensor<16x16xi8>
|
||||
CHECK: %{{.*}} = arith.andi %[[CMPI]], %{{.*}} : tensor<16x16xi1>
|
||||
)"));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -231,6 +231,32 @@ ENTRY entry {
|
|||
hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
|
||||
}
|
||||
|
||||
TEST_F(TritonEmitterTest, ConvertIntegerToPredIsEmittedCorrectly) {
|
||||
constexpr absl::string_view kHloText = R"(
|
||||
HloModule m
|
||||
|
||||
fused_convert {
|
||||
p0 = s32[3,2,2]{2,1,0} parameter(0)
|
||||
ROOT convert0 = pred[3,2,2]{2,1,0} convert(p0)
|
||||
}
|
||||
|
||||
ENTRY %main {
|
||||
p0 = s32[3,2,2]{2,1,0} parameter(0)
|
||||
ROOT input_convert_fusion = pred[3,2,2]{2,1,0} fusion(p0), kind=kCustom,
|
||||
calls=fused_convert,
|
||||
backend_config={"fusion_backend_config":{
|
||||
"kind":"__triton","block_level_fusion_config":{
|
||||
"num_warps":"1","output_tiles":[{"sizes":["1","2","2"]}],"num_ctas":1,
|
||||
"num_stages":1,"is_tma_allowed":false}}}
|
||||
}
|
||||
)";
|
||||
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fused_convert", R"(
|
||||
CHECK: %[[CST:.*]] = arith.constant dense<0>
|
||||
CHECK: arith.cmpi ne, %{{.*}}, %[[CST]]
|
||||
)"));
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch));
|
||||
}
|
||||
|
||||
TEST_F(TritonEmitterTest, PredicateAddIsEmittedCorrectly) {
|
||||
constexpr absl::string_view kHloText = R"(
|
||||
HloModule m
|
||||
|
|
@ -2503,7 +2529,7 @@ ENTRY main {
|
|||
TF_EXPECT_OK(
|
||||
CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"(
|
||||
CHECK: %[[I8_PARAM:.*]] = xtile.extract {{.*}} -> tensor<4xi8>
|
||||
CHECK: arith.trunci %[[I8_PARAM]] : tensor<4xi8> to tensor<4xi1>
|
||||
CHECK: arith.cmpi ne, %[[I8_PARAM]], {{.*}} : tensor<4xi8>
|
||||
)"));
|
||||
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch));
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user