mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:GPU] Test that DotDecomposer canonicalizes batch dims.
Existing tests do not cover this. PiperOrigin-RevId: 826006833
This commit is contained in:
parent
0a1309a2e5
commit
fbaeea227b
|
|
@ -239,6 +239,30 @@ TEST_F(DotDecomposerTest, AddRhsNonContractingDimIfZero) {
|
||||||
op::Shape("f32[64,0]"))));
|
op::Shape("f32[64,0]"))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DotDecomposerTest, CanonicalizeBatchDims) {
|
||||||
|
absl::string_view module_string = R"(
|
||||||
|
ENTRY main {
|
||||||
|
p0 = f32[64,4,32,8] parameter(0)
|
||||||
|
p1 = f32[128,4,8,32] parameter(1)
|
||||||
|
ROOT dot = f32[32,8,64,128] dot(p0, p1), lhs_batch_dims={2,3},
|
||||||
|
lhs_contracting_dims={1},
|
||||||
|
rhs_batch_dims={3,2},
|
||||||
|
rhs_contracting_dims={1}
|
||||||
|
})";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseAndReturnVerifiedModule(module_string));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
|
||||||
|
DotDecomposer().Run(module.get()));
|
||||||
|
EXPECT_TRUE(canonicalized);
|
||||||
|
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
|
||||||
|
/*lhs_contracting_dim=*/3,
|
||||||
|
/*rhs_contracting_dim=*/2),
|
||||||
|
op::Shape("f32[32,8,64,128]"))));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Arg0, typename Arg1, typename Arg2>
|
template <typename Arg0, typename Arg1, typename Arg2>
|
||||||
auto SparseDotMatcher(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) {
|
auto SparseDotMatcher(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) {
|
||||||
return match::Op()
|
return match::Op()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user