mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[XLA] Fix a bug in cloning outfeeds, carried the wrong shape.
PiperOrigin-RevId: 163265592
This commit is contained in:
parent
1bad826d6f
commit
08790e73d1
|
|
@ -912,7 +912,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||||
return CreateInfeed(shape, infeed_config());
|
return CreateInfeed(shape, infeed_config());
|
||||||
case HloOpcode::kOutfeed:
|
case HloOpcode::kOutfeed:
|
||||||
CHECK_EQ(new_operands.size(), 1);
|
CHECK_EQ(new_operands.size(), 1);
|
||||||
return CreateOutfeed(shape, new_operands[0], outfeed_config());
|
return CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
|
||||||
case HloOpcode::kBatchNormGrad:
|
case HloOpcode::kBatchNormGrad:
|
||||||
CHECK_EQ(new_operands.size(), 5);
|
CHECK_EQ(new_operands.size(), 5);
|
||||||
return CreateBatchNormGrad(shape, new_operands[0], new_operands[1],
|
return CreateBatchNormGrad(shape, new_operands[0], new_operands[1],
|
||||||
|
|
|
||||||
|
|
@ -638,6 +638,27 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
|
||||||
metadata, fusion->fused_expression_root()->operand(0)->metadata()));
|
metadata, fusion->fused_expression_root()->operand(0)->metadata()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
|
||||||
|
HloComputation::Builder builder(TestName());
|
||||||
|
auto constant = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(Literal::CreateR2<float>({
|
||||||
|
{1, 2},
|
||||||
|
{3, 4},
|
||||||
|
})));
|
||||||
|
auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
|
||||||
|
auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1});
|
||||||
|
auto outfeed10 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateOutfeed(shape10, constant, ""));
|
||||||
|
auto outfeed01 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateOutfeed(shape01, constant, ""));
|
||||||
|
|
||||||
|
auto clone01 = builder.AddInstruction(outfeed01->Clone());
|
||||||
|
auto clone10 = builder.AddInstruction(outfeed10->Clone());
|
||||||
|
|
||||||
|
EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), shape01));
|
||||||
|
EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
|
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
// Create a fusion instruction containing a single unary operation.
|
// Create a fusion instruction containing a single unary operation.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user