mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[xla:gpu] Fix integer-to-pred convert in our Triton emitter
`arith.trunci` to i1 simply keeps the least-significant bit, but HLO expects convert to i1 to be value != 0. Emit this conversion as a compare-not-equal to 0 instead. This is already done correctly for floats. PiperOrigin-RevId: 824716165
This commit is contained in:
parent
ffc21f066a
commit
d299463d26
|
|
@ -236,6 +236,11 @@ Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
|
||||||
}
|
}
|
||||||
return b.create<ma::ExtSIOp>(dst_ty, value);
|
return b.create<ma::ExtSIOp>(dst_ty, value);
|
||||||
}
|
}
|
||||||
|
// int => bool is always value != 0.
|
||||||
|
if (dst_element_ty.isInteger(1)) {
|
||||||
|
return b.create<ma::CmpIOp>(ma::CmpIPredicate::ne, value,
|
||||||
|
ZerosLike(b, value));
|
||||||
|
}
|
||||||
return b.create<ma::TruncIOp>(dst_ty, value);
|
return b.create<ma::TruncIOp>(dst_ty, value);
|
||||||
}
|
}
|
||||||
// int => float
|
// int => float
|
||||||
|
|
|
||||||
|
|
@ -259,8 +259,8 @@ ENTRY e {
|
||||||
module_and_metadata.block_level_parameters,
|
module_and_metadata.block_level_parameters,
|
||||||
R"(
|
R"(
|
||||||
CHECK: %[[LOAD:.*]] = xtile.extract {{.*}} -> tensor<16x16xi8>
|
CHECK: %[[LOAD:.*]] = xtile.extract {{.*}} -> tensor<16x16xi8>
|
||||||
CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1>
|
CHECK: %[[CMPI:.*]] = arith.cmpi ne, %[[LOAD]], {{.*}} : tensor<16x16xi8>
|
||||||
CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1>
|
CHECK: %{{.*}} = arith.andi %[[CMPI]], %{{.*}} : tensor<16x16xi1>
|
||||||
)"));
|
)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -468,9 +468,10 @@ ENTRY e {
|
||||||
)";
|
)";
|
||||||
TF_EXPECT_OK(CreateTritonIrAndFileCheckForDot(this, kHloText,
|
TF_EXPECT_OK(CreateTritonIrAndFileCheckForDot(this, kHloText,
|
||||||
"triton_gemm_computation", R"(
|
"triton_gemm_computation", R"(
|
||||||
|
CHECK: %[[CST:.*]] = arith.constant dense<0>
|
||||||
CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr<tensor<16x16xi8>>
|
CHECK: %[[LOAD:.*]] = tt.load %{{.*}} {{.*}} : !tt.ptr<tensor<16x16xi8>>
|
||||||
CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LOAD]] : tensor<16x16xi8> to tensor<16x16xi1>
|
CHECK: %[[CMPI:.*]] = arith.cmpi ne, %[[LOAD]], %[[CST]] : tensor<16x16xi8>
|
||||||
CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1>
|
CHECK: %{{.*}} = arith.andi %[[CMPI]], %{{.*}} : tensor<16x16xi1>
|
||||||
)"));
|
)"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,32 @@ ENTRY entry {
|
||||||
hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
|
hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TritonEmitterTest, ConvertIntegerToPredIsEmittedCorrectly) {
|
||||||
|
constexpr absl::string_view kHloText = R"(
|
||||||
|
HloModule m
|
||||||
|
|
||||||
|
fused_convert {
|
||||||
|
p0 = s32[3,2,2]{2,1,0} parameter(0)
|
||||||
|
ROOT convert0 = pred[3,2,2]{2,1,0} convert(p0)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY %main {
|
||||||
|
p0 = s32[3,2,2]{2,1,0} parameter(0)
|
||||||
|
ROOT input_convert_fusion = pred[3,2,2]{2,1,0} fusion(p0), kind=kCustom,
|
||||||
|
calls=fused_convert,
|
||||||
|
backend_config={"fusion_backend_config":{
|
||||||
|
"kind":"__triton","block_level_fusion_config":{
|
||||||
|
"num_warps":"1","output_tiles":[{"sizes":["1","2","2"]}],"num_ctas":1,
|
||||||
|
"num_stages":1,"is_tma_allowed":false}}}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fused_convert", R"(
|
||||||
|
CHECK: %[[CST:.*]] = arith.constant dense<0>
|
||||||
|
CHECK: arith.cmpi ne, %{{.*}}, %[[CST]]
|
||||||
|
)"));
|
||||||
|
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TritonEmitterTest, PredicateAddIsEmittedCorrectly) {
|
TEST_F(TritonEmitterTest, PredicateAddIsEmittedCorrectly) {
|
||||||
constexpr absl::string_view kHloText = R"(
|
constexpr absl::string_view kHloText = R"(
|
||||||
HloModule m
|
HloModule m
|
||||||
|
|
@ -2503,7 +2529,7 @@ ENTRY main {
|
||||||
TF_EXPECT_OK(
|
TF_EXPECT_OK(
|
||||||
CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"(
|
CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"(
|
||||||
CHECK: %[[I8_PARAM:.*]] = xtile.extract {{.*}} -> tensor<4xi8>
|
CHECK: %[[I8_PARAM:.*]] = xtile.extract {{.*}} -> tensor<4xi8>
|
||||||
CHECK: arith.trunci %[[I8_PARAM]] : tensor<4xi8> to tensor<4xi1>
|
CHECK: arith.cmpi ne, %[[I8_PARAM]], {{.*}} : tensor<4xi8>
|
||||||
)"));
|
)"));
|
||||||
|
|
||||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch));
|
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user