[XLA:GPU] Test that DotDecomposer canonicalizes batch dims.

Existing tests do not cover this.

PiperOrigin-RevId: 826006833
This commit is contained in:
Thomas Joerg 2025-10-30 06:35:22 -07:00 committed by TensorFlower Gardener
parent 0a1309a2e5
commit fbaeea227b

View File

@ -239,6 +239,30 @@ TEST_F(DotDecomposerTest, AddRhsNonContractingDimIfZero) {
op::Shape("f32[64,0]")))); op::Shape("f32[64,0]"))));
} }
TEST_F(DotDecomposerTest, CanonicalizeBatchDims) {
absl::string_view module_string = R"(
ENTRY main {
p0 = f32[64,4,32,8] parameter(0)
p1 = f32[128,4,8,32] parameter(1)
ROOT dot = f32[32,8,64,128] dot(p0, p1), lhs_batch_dims={2,3},
lhs_contracting_dims={1},
rhs_batch_dims={3,2},
rhs_contracting_dims={1}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
DotDecomposer().Run(module.get()));
EXPECT_TRUE(canonicalized);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
/*lhs_contracting_dim=*/3,
/*rhs_contracting_dim=*/2),
op::Shape("f32[32,8,64,128]"))));
}
template <typename Arg0, typename Arg1, typename Arg2> template <typename Arg0, typename Arg1, typename Arg2>
auto SparseDotMatcher(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { auto SparseDotMatcher(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) {
return match::Op() return match::Op()