mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA:GPU] Test that DotDecomposer canonicalizes batch dims.
Existing tests do not cover this. PiperOrigin-RevId: 826006833
This commit is contained in:
parent
0a1309a2e5
commit
fbaeea227b
|
|
@ -239,6 +239,30 @@ TEST_F(DotDecomposerTest, AddRhsNonContractingDimIfZero) {
|
|||
op::Shape("f32[64,0]"))));
|
||||
}
|
||||
|
||||
TEST_F(DotDecomposerTest, CanonicalizeBatchDims) {
|
||||
absl::string_view module_string = R"(
|
||||
ENTRY main {
|
||||
p0 = f32[64,4,32,8] parameter(0)
|
||||
p1 = f32[128,4,8,32] parameter(1)
|
||||
ROOT dot = f32[32,8,64,128] dot(p0, p1), lhs_batch_dims={2,3},
|
||||
lhs_contracting_dims={1},
|
||||
rhs_batch_dims={3,2},
|
||||
rhs_contracting_dims={1}
|
||||
})";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
|
||||
DotDecomposer().Run(module.get()));
|
||||
EXPECT_TRUE(canonicalized);
|
||||
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
|
||||
/*lhs_contracting_dim=*/3,
|
||||
/*rhs_contracting_dim=*/2),
|
||||
op::Shape("f32[32,8,64,128]"))));
|
||||
}
|
||||
|
||||
template <typename Arg0, typename Arg1, typename Arg2>
|
||||
auto SparseDotMatcher(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) {
|
||||
return match::Op()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user