Implement BatchNorm inference by expanding it into smaller HLO ops.

1. Add batch norm inference support in batchnorm_rewriter (a reference sketch of the expansion follows the change summary below).
2. Connect XLA's BatchNormInference to TF's FusedBatchNorm when is_training=False.

RELNOTES: n/a
PiperOrigin-RevId: 165655351
A. Unique TensorFlower 2017-08-17 18:04:58 -07:00 committed by TensorFlower Gardener
parent f0da8bf56b
commit 7359fec792
25 changed files with 605 additions and 27 deletions
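As a reference for item 1 above: the rewriter expands batch-norm inference into elementwise HLO ops computing y = scale * (x - mean) / sqrt(variance + epsilon) + offset, with the per-feature vectors broadcast along the feature dimension. A minimal NumPy sketch of that expansion (illustrative only; the function name is not part of this change):

    import numpy as np

    def batch_norm_inference_reference(x, scale, offset, mean, variance,
                                       epsilon=0.001, feature_index=3):
        # Reshape the per-feature vectors so they broadcast along the
        # feature dimension (default: the last dimension, as in NHWC).
        bshape = [1] * x.ndim
        bshape[feature_index] = -1
        scale, offset = scale.reshape(bshape), offset.reshape(bshape)
        mean, variance = mean.reshape(bshape), variance.reshape(bshape)
        return (x - mean) / np.sqrt(variance + epsilon) * scale + offset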

View File

@ -63,6 +63,39 @@ class FusedBatchNormTest(XLATestCase):
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
def testInference(self):
x_shape = [2, 2, 6, 2]
scale_shape = [2]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
offset_val = np.random.random_sample(scale_shape).astype(np.float32)
data_format = "NHWC"
with self.test_session() as sess, self.test_scope():
# To avoid constant folding
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
epsilon = 0.001
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format)
y, mean, variance = nn.fused_batch_norm(
t_val,
scale,
offset,
mean=mean_ref,
variance=var_ref,
epsilon=epsilon,
data_format=data_format,
is_training=False)
y_val, _, _ = sess.run(
[y, mean,
variance], {t_val: x_val,
scale: scale_val,
offset: offset_val})
self.assertAllClose(y_val, y_ref, atol=1e-3)
def _testLearning(self, use_gradient_checker):
x_shape = [2, 2, 6, 2]
scale_shape = [2]

View File

@ -39,28 +39,36 @@ class FusedBatchNormOp : public XlaOpKernel {
errors::InvalidArgument("Not supported format"));
feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format);
}
-    // TODO(b/62843645): Implement BatchNormInference.
-    OP_REQUIRES(
-        ctx, is_training_,
-        errors::InvalidArgument("Fused batch normalization for inference is "
-                                "not supported yet on XLA backend."));
  }
  void Compile(XlaOpKernelContext* ctx) override {
-    xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
-        ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, feature_index_);
+    if (is_training_) {
+      xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
+          ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_,
+          feature_index_);
-    // In training mode, outputs the normalized value as well as the calculated
-    // mean and variance.
-    for (int i = 0; i < 3; i++) {
-      ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
+      // In training mode, outputs the normalized value as well as the
+      // calculated mean and variance.
+      for (int i = 0; i < 3; i++) {
+        ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
+      }
+      // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
+      // space 1 & 2". They are used to pass the per-batch mean and
+      // variance to the gradient. Here we maintain the same behavior by setting
+      // them to the mean and variance calculated by BatchNormTraining.
+      ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
+      ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
+    } else {
+      xla::ComputationDataHandle output = ctx->builder()->BatchNormInference(
+          ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3),
+          ctx->Input(4), epsilon_, feature_index_);
+      ctx->SetOutput(0, output);
+      // Directly send input to output as mean and variance in inference mode.
+      ctx->SetOutput(1, ctx->Input(3));
+      ctx->SetOutput(2, ctx->Input(4));
+      ctx->SetOutput(3, ctx->Input(3));
+      ctx->SetOutput(4, ctx->Input(4));
    }
-    // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
-    // space 1 & 2". They are used to pass the per-batch mean and
-    // variance to the gradient. Here we maintain the same behavior by setting
-    // them to the mean and variance calculated by BatchNormTraining.
-    ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
-    ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
  }
private:
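For reference, the inference branch of Compile above still produces the five outputs that FusedBatchNorm expects: output 0 is the normalized tensor from BatchNormInference, and outputs 1-4 (batch mean, batch variance, and the two reserve spaces) simply forward the mean and variance inputs. A hedged NumPy sketch of that output contract, assuming NHWC layout (the helper name is illustrative):

    import numpy as np

    def fused_batch_norm_inference_outputs(x, scale, offset, mean, variance,
                                           epsilon=0.001):
        # Output 0: normalized values, matching BatchNormInference.
        y = (x - mean) / np.sqrt(variance + epsilon) * scale + offset
        # Outputs 1-2 (batch mean/variance) and 3-4 (reserve spaces) just
        # pass the supplied moments through in inference mode.
        return y, mean, variance, mean, variance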

View File

@ -1477,9 +1477,29 @@ ComputationDataHandle ComputationBuilder::BatchNormInference(
const ComputationDataHandle& operand, const ComputationDataHandle& scale,
const ComputationDataHandle& offset, const ComputationDataHandle& mean,
const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
-  // TODO(b/62843645): Implement BatchNormInference.
-  NoteError(Unimplemented("BatchNormInference is not implemented yet."));
-  return ComputationDataHandle();
+  if (!first_error_.ok() || !PrepareComputation().ok()) {
+    return ComputationDataHandle();
+  }
+  BatchNormInferenceRequest request;
+  *request.mutable_operand() = operand;
+  *request.mutable_scale() = scale;
+  *request.mutable_offset() = offset;
+  *request.mutable_mean() = mean;
+  *request.mutable_variance() = variance;
+  request.set_epsilon(epsilon);
+  request.set_feature_index(feature_index);
+  OpRequest op_request;
+  *op_request.mutable_batch_norm_inference_request() = request;
+  *op_request.mutable_computation() = computation_.handle();
+  AddOpMetadata(&op_request);
+  OpResponse response;
+  VLOG(2) << "making BatchNormInference request";
+  Status s = client_->stub()->Op(&op_request, &response);
+  return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::BatchNormGrad(

View File

@ -56,11 +56,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
// Runs the visitor on a computation.
static bool Run(HloComputation* computation, bool rewrite_training_op,
-                  bool rewrite_grad_op, bool use_fusion);
+                  bool rewrite_inference_op, bool rewrite_grad_op,
+                  bool use_fusion);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
@ -70,9 +73,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
private:
explicit BatchNormRewriterVisitor(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
bool rewrite_grad_op, bool use_fusion)
: computation_(computation),
rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
@ -94,6 +99,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
HloComputation* computation_;
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
bool use_fusion_;
@ -126,11 +132,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
bool BatchNormRewriterVisitor::Run(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
bool rewrite_grad_op, bool use_fusion) {
-  BatchNormRewriterVisitor visitor(computation,
-                                   /*rewrite_training_op=*/rewrite_training_op,
-                                   /*rewrite_grad_op=*/rewrite_grad_op,
-                                   /*use_fusion=*/use_fusion);
+  BatchNormRewriterVisitor visitor(
+      computation,
+      /*rewrite_training_op=*/rewrite_training_op,
+      /*rewrite_inference_op=*/rewrite_inference_op,
+      /*rewrite_grad_op=*/rewrite_grad_op,
+      /*use_fusion=*/use_fusion);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@ -268,6 +277,82 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
return Status::OK();
}
Status BatchNormRewriterVisitor::HandleBatchNormInference(
HloInstruction* batch_norm) {
if (!rewrite_inference_op_) {
return Status::OK();
}
// Expand batch norm inference into smaller HLO ops.
HloInstruction* operand = batch_norm->mutable_operand(0);
const Shape operand_shape = operand->shape();
int64 feature_index = batch_norm->feature_index();
HloInstruction* scale = batch_norm->mutable_operand(1);
HloInstruction* offset = batch_norm->mutable_operand(2);
HloInstruction* mean = batch_norm->mutable_operand(3);
HloInstruction* var = batch_norm->mutable_operand(4);
const Shape feature_shape = scale->shape();
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
std::vector<int64> dimensions_without_feature;
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
if (i != feature_index) {
dimensions_without_feature.push_back(i);
}
}
auto scale_broadcasted = computation_->AddInstruction(
HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
auto offset_broadcasted = computation_->AddInstruction(
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
auto mean_broadcasted = computation_->AddInstruction(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
auto var_broadcasted = computation_->AddInstruction(
HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
auto var_add_epsilon =
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
auto neg_half = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon =
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
// X - E[X].
auto operand_minus_mean =
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
// (X - E[X]) / Sqrt[Var[X] + epsilon].
auto normalized = computation_->AddInstruction(
HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon));
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
auto scaled_normalized =
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
auto shifted_normalized = HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);
TF_CHECK_OK(
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
return Status::OK();
}
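The chain of HLO instructions created in HandleBatchNormInference above corresponds step by step to the following NumPy sketch (a hedged illustration of the dataflow, not generated code); note that the reciprocal square root is expressed as a kPower with exponent -0.5:

    import numpy as np

    def expand_batch_norm_inference(operand, scale, offset, mean, var,
                                    epsilon, feature_index):
        # Broadcast each per-feature input to the operand's shape.
        bshape = [1] * operand.ndim
        bshape[feature_index] = -1
        scale_b, offset_b = scale.reshape(bshape), offset.reshape(bshape)
        mean_b, var_b = mean.reshape(bshape), var.reshape(bshape)
        var_add_epsilon = var_b + epsilon              # Var[X] + epsilon
        rsqrt = np.power(var_add_epsilon, -0.5)        # kPower(..., -0.5)
        operand_minus_mean = operand - mean_b          # X - E[X]
        normalized = operand_minus_mean * rsqrt        # (X - E[X]) * rsqrt
        scaled_normalized = normalized * scale_b       # ... * scale
        return scaled_normalized + offset_b            # ... + offset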
Status BatchNormRewriterVisitor::HandleBatchNormGrad(
HloInstruction* batch_norm) {
// Use the following formulas to calculate gradients:
@ -457,7 +542,8 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
}
for (auto& comp : computations) {
if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
-                                      rewrite_grad_op_, use_fusion_)) {
+                                      rewrite_inference_op_, rewrite_grad_op_,
+                                      use_fusion_)) {
changed = true;
}
}

View File

@ -30,8 +30,10 @@ class BatchNormRewriter : public HloPassInterface {
public:
// When use_fusion is set, a multi-output fusion node is created.
BatchNormRewriter(bool rewrite_training_op = false,
bool rewrite_inference_op = false,
bool rewrite_grad_op = false, bool use_fusion = true)
: rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
~BatchNormRewriter() = default;
@ -43,6 +45,7 @@ class BatchNormRewriter : public HloPassInterface {
private:
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
bool use_fusion_;
};

View File

@ -64,6 +64,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) {
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
@ -105,6 +106,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) {
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
root = computation->root_instruction();

View File

@ -260,6 +260,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
pass.AddPass<BatchNormRewriter>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
/*use_fusion=*/false);
pass.AddPass<AlgebraicSimplifier>(

View File

@ -228,6 +228,9 @@ class DfsHloVisitor {
virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0;
virtual Status HandleBatchNormInference(
HloInstruction* batchNormInference) = 0;
virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0;
// Invoked to inform the visitor that the traversal has completed, and that

View File

@ -54,6 +54,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
return DefaultAction(hlo);
}
Status HandleBatchNormInference(HloInstruction* hlo) override {
return DefaultAction(hlo);
}
Status HandleBatchNormGrad(HloInstruction* hlo) override {
return DefaultAction(hlo);
}

View File

@ -135,6 +135,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
// instead.
pass.AddPass<BatchNormRewriter>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
/*use_fusion=*/false);
pass.AddPass<AlgebraicSimplifier>(

View File

@ -374,6 +374,12 @@ Status HloCostAnalysis::HandleBatchNormTraining(
return Status::OK();
}
Status HloCostAnalysis::HandleBatchNormInference(
HloInstruction* batchNormInference) {
// TODO(b/62294698): Implement cost analysis for batch-norm-inference.
return Status::OK();
}
Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batchNormGrad) {
// TODO(b/62294698): Implement cost analysis for batch-norm-grad.
return Status::OK();

View File

@ -89,6 +89,7 @@ class HloCostAnalysis : public DfsHloVisitor {
tensorflow::gtl::ArraySlice<int64> dimensions,
HloComputation* function_handle) override;
Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override;
Status HandleBatchNormInference(HloInstruction* batchNormInference) override;
Status HandleBatchNormGrad(HloInstruction* batchNormGrad) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;

View File

@ -742,6 +742,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kParameter:
return kOrange;
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kReduce:
case HloOpcode::kSelectAndScatter:

View File

@ -406,6 +406,23 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
return instruction;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index) {
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
instruction->AppendOperand(operand);
instruction->AppendOperand(scale);
instruction->AppendOperand(offset);
instruction->AppendOperand(mean);
instruction->AppendOperand(variance);
instruction->epsilon_ = epsilon;
instruction->feature_index_ = feature_index;
return instruction;
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* scale, HloInstruction* mean,
@ -1065,6 +1082,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
new_operands[2], epsilon(),
feature_index());
case HloOpcode::kBatchNormInference:
CHECK_EQ(new_operands.size(), 5);
return CreateBatchNormInference(
shape, new_operands[0], new_operands[1], new_operands[2],
new_operands[3], new_operands[4], epsilon(), feature_index());
case HloOpcode::kInfeed:
CHECK_EQ(new_operands.size(), 0);
return CreateInfeed(shape, infeed_config());
@ -1355,6 +1378,7 @@ bool HloInstruction::IdenticalSlowPath(
ShapeUtil::Compatible(shape(), other.shape());
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
return feature_index() == other.feature_index() &&
epsilon() == other.epsilon();
@ -1952,6 +1976,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleAbs(this, operands_[0]);
case HloOpcode::kBatchNormTraining:
return visitor->HandleBatchNormTraining(this);
case HloOpcode::kBatchNormInference:
return visitor->HandleBatchNormInference(this);
case HloOpcode::kBatchNormGrad:
return visitor->HandleBatchNormGrad(this);
case HloOpcode::kSign:

View File

@ -224,6 +224,12 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, float epsilon, int64 feature_index);
// Creates a batch-norm-inference instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index);
// Creates a batch-norm-grad instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,

View File

@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "add";
case HloOpcode::kBatchNormTraining:
return "batch-norm-training";
case HloOpcode::kBatchNormInference:
return "batch-norm-inference";
case HloOpcode::kBatchNormGrad:
return "batch-norm-grad";
case HloOpcode::kBitcast:

View File

@ -31,6 +31,7 @@ enum class HloOpcode {
kAbs,
kAdd,
kBatchNormTraining,
kBatchNormInference,
kBatchNormGrad,
kBitcast,
kBroadcast,

View File

@ -78,6 +78,7 @@ namespace xla {
// Expensive instructions.
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kCall:
case HloOpcode::kConvolution:

View File

@ -1211,6 +1211,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status = computation->AddBatchNormTrainingInstruction(
arg->batch_norm_training_request());
break;
case OpRequest::kBatchNormInferenceRequest:
handle_status = computation->AddBatchNormInferenceInstruction(
arg->batch_norm_inference_request());
break;
case OpRequest::kBatchNormGradRequest:
handle_status = computation->AddBatchNormGradInstruction(
arg->batch_norm_grad_request());

View File

@ -885,6 +885,150 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
output_shape_for_mean_and_var});
}
/* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
const Shape& operand_shape, const Shape& offset_shape,
const Shape& scale_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
offset_shape, "offset input of batch norm inference"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
scale_shape, "scale input of batch norm inference"));
TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(mean_shape) ==
tensorflow::Status::OK());
TF_RET_CHECK(ShapeUtil::ValidateShape(variance_shape) ==
tensorflow::Status::OK());
if (feature_index >= ShapeUtil::Rank(operand_shape)) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to be "
"smaller than the rank of operand_shape; "
"got feature_index %lld, and rank %lld",
feature_index, ShapeUtil::Rank(operand_shape));
}
if (feature_index < 0) {
return InvalidArgument(
"Expected feature_index of batch-norm-inference to "
"be a non-negative number, got %lld",
feature_index);
}
if (ShapeUtil::Rank(operand_shape) < 1) {
return InvalidArgument(
"Expected the rank of operand to "
"batch-norm-inference to be at least 1; got %lld",
ShapeUtil::Rank(operand_shape));
}
if (ShapeUtil::Rank(offset_shape) != 1) {
return InvalidArgument(
"Offset input of batch-norm-inference must have"
" rank 1, but has rank %lld.",
ShapeUtil::Rank(offset_shape));
}
if (ShapeUtil::Rank(scale_shape) != 1) {
return InvalidArgument(
"Scale input of batch-norm-inference must have"
" rank 1, but has rank %lld.",
ShapeUtil::Rank(scale_shape));
}
if (!ShapeUtil::ElementIsFloating(operand_shape)) {
return InvalidArgument(
"The operand to batch-norm-inference must have a floating point "
"element type, but the shape is %s",
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of offset factor is %s "
"and the shape of operand is %s",
PrimitiveType_Name(offset_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of scale factor is %s "
"and the shape of operand is %s",
PrimitiveType_Name(scale_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of mean is %s "
"and the shape of operand is %s",
PrimitiveType_Name(mean_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
"but the shape of variance is %s "
"and the shape of operand is %s",
PrimitiveType_Name(variance_shape.element_type()).c_str(),
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
const int64 feature_count = operand_shape.dimensions(feature_index);
Shape output_shape_for_mean_and_var =
ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
return InvalidArgument(
"The size of offset factor should be the same as feature count,"
"but the size of offset factor is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(offset_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
return InvalidArgument(
"The size of scale factor should be the same as feature count,"
"but the size of scale factor is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(scale_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
return InvalidArgument(
"The size of mean should be the same as feature count,"
"but the size of mean is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(mean_shape, 0), feature_count);
}
if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
return InvalidArgument(
"The size of variance should be the same as feature count,"
"but the size of variance is %lld "
"and the feature count is %lld",
ShapeUtil::GetDimension(variance_shape, 0), feature_count);
}
return operand_shape;
}
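A compact summary of the validation performed by InferBatchNormInferenceShape above, as a hedged Python sketch (the function and exception are illustrative, not XLA's API): the result shape is simply the operand shape, and every per-feature input must be rank 1 with length equal to the size of the feature dimension (the element-type checks are elided here).

    def infer_batch_norm_inference_shape(operand_shape, scale_shape,
                                         offset_shape, mean_shape,
                                         variance_shape, feature_index):
        # feature_index must index a dimension of a rank >= 1 operand.
        if not 0 <= feature_index < len(operand_shape):
            raise ValueError("feature_index out of range")
        feature_count = operand_shape[feature_index]
        # scale, offset, mean and variance are rank-1 vectors whose length
        # equals the feature count.
        for name, shape in [("scale", scale_shape), ("offset", offset_shape),
                            ("mean", mean_shape), ("variance", variance_shape)]:
            if len(shape) != 1 or shape[0] != feature_count:
                raise ValueError("%s must have shape [%d]" % (name, feature_count))
        # batch-norm-inference produces a result with the operand's shape.
        return operand_shape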
/* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
const Shape& operand_shape, const Shape& scale_shape,
const Shape& mean_shape, const Shape& var_shape,

View File

@ -71,6 +71,13 @@ class ShapeInference {
const Shape& scale_shape,
int64 feature_index);
// Infers the shape produced by InferBatchNormInference with the given
// operands.
static StatusOr<Shape> InferBatchNormInferenceShape(
const Shape& operand_shape, const Shape& offset_shape,
const Shape& scale_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index);
// Infers the shape produced by InferBatchNormGrad with the given operands.
static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
const Shape& scale_shape,

View File

@ -507,6 +507,53 @@ UserComputation::AddBatchNormTrainingInstruction(
return handle;
}
StatusOr<ComputationDataHandle>
UserComputation::AddBatchNormInferenceInstruction(
const BatchNormInferenceRequest& batch_norm_inference_request) {
tensorflow::mutex_lock lock(mutex_);
TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
LookUpRequest(batch_norm_inference_request.operand()));
TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
LookUpRequest(batch_norm_inference_request.scale()));
TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
LookUpRequest(batch_norm_inference_request.offset()));
TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
LookUpRequest(batch_norm_inference_request.mean()));
TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
LookUpRequest(batch_norm_inference_request.variance()));
ComputationDataHandle handle = CreateComputationDataHandle();
OperationRequest& request =
(*session_computation_.mutable_requests())[handle.handle()];
TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferBatchNormInferenceShape(
operand->output_shape(), scale->output_shape(),
offset->output_shape(), mean->output_shape(),
variance->output_shape(),
batch_norm_inference_request.feature_index()));
*request.mutable_output_shape() = inferred_shape;
*request.mutable_output_handle() = handle;
*request.mutable_request()->mutable_batch_norm_inference_request() =
batch_norm_inference_request;
VLOG(1) << "AddBatchNormInferenceInstruction ("
<< GetVersionedHandleInternal() << "), data handle "
<< handle.handle() << ": "
<< batch_norm_inference_request.ShortDebugString();
return handle;
}
StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
const BatchNormGradRequest& batch_norm_grad_request) {
tensorflow::mutex_lock lock(mutex_);
@ -1678,6 +1725,25 @@ void ConstantVisitor(const SessionComputation& session_computation,
break;
}
case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();
ConstantVisitor(session_computation,
batch_norm_inference_request.operand(), visited,
is_constant);
ConstantVisitor(session_computation, batch_norm_inference_request.scale(),
visited, is_constant);
ConstantVisitor(session_computation,
batch_norm_inference_request.offset(), visited,
is_constant);
ConstantVisitor(session_computation, batch_norm_inference_request.mean(),
visited, is_constant);
ConstantVisitor(session_computation,
batch_norm_inference_request.variance(), visited,
is_constant);
break;
}
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
@ -2119,6 +2185,18 @@ static void ForEachOperand(
break;
}
case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();
apply(batch_norm_inference_request.operand());
apply(batch_norm_inference_request.scale());
apply(batch_norm_inference_request.offset());
apply(batch_norm_inference_request.mean());
apply(batch_norm_inference_request.variance());
break;
}
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();
@ -2647,6 +2725,28 @@ void ComputationLowerer::Visit(
break;
}
case OpRequest::kBatchNormInferenceRequest: {
const BatchNormInferenceRequest& batch_norm_inference_request =
request.request().batch_norm_inference_request();
HloInstruction* operand =
lookup_instruction(batch_norm_inference_request.operand());
HloInstruction* scale =
lookup_instruction(batch_norm_inference_request.scale());
HloInstruction* offset =
lookup_instruction(batch_norm_inference_request.offset());
HloInstruction* mean =
lookup_instruction(batch_norm_inference_request.mean());
HloInstruction* variance =
lookup_instruction(batch_norm_inference_request.variance());
hlo_instruction =
add_instruction(HloInstruction::CreateBatchNormInference(
request.output_shape(), operand, scale, offset, mean, variance,
batch_norm_inference_request.epsilon(),
batch_norm_inference_request.feature_index()));
break;
}
case OpRequest::kBatchNormGradRequest: {
const BatchNormGradRequest& batch_norm_grad_request =
request.request().batch_norm_grad_request();

View File

@ -89,6 +89,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
const BatchNormTrainingRequest& batch_norm_training_request);
// Enqueues a batch norm inference instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
const BatchNormInferenceRequest& batch_norm_inference_request);
// Enqueues a batch norm grad instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
const BatchNormGradRequest& batch_norm_grad_request);

View File

@ -306,6 +306,109 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) {
ErrorSpec(0.01, 1));
}
XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
float epsilon = 0.001;
ComputationBuilder builder(client_, TestName());
const std::vector<int64>& bounds = GetParam().bounds;
Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
input_array.FillRandom(GetParam().random_value_var,
GetParam().random_value_mean);
const int64 feature_index = GetParam().feature_index;
const int64 num_elements_per_feature =
Product(bounds) / bounds[feature_index];
const int64 feature_bound = bounds[feature_index];
std::vector<float> offset(feature_bound, 1);
std::vector<float> scale(feature_bound, 2);
auto input_squared =
ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
std::vector<int64> reduce_dims;
for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
if (i != feature_index) {
reduce_dims.push_back(i);
}
}
auto sum =
ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
[](float a, float b) { return a + b; });
auto sum_squared =
ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
[](float a, float b) { return a + b; });
std::vector<float> mean(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
mean[i] = sum[i] / num_elements_per_feature;
}
std::vector<float> mean_square(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
mean_square[i] = mean[i] * mean[i];
}
std::vector<float> square_mean(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
square_mean[i] = sum_squared[i] / num_elements_per_feature;
}
std::vector<float> var(feature_bound);
for (int64 i = 0; i < feature_bound; ++i) {
var[i] = square_mean[i] - mean_square[i];
}
Array4D<float> mean4D =
*ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
auto offset4D =
*ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);
auto offset_literal = Literal::CreateR1<float>(offset);
auto scale_literal = Literal::CreateR1<float>(scale);
auto mean_literal = Literal::CreateR1<float>(mean);
auto var_literal = Literal::CreateR1<float>(var);
auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
auto input_activations =
builder.Parameter(0, input_literal->shape(), "input");
auto scale_activations =
builder.Parameter(1, scale_literal->shape(), "scale");
auto offset_activations =
builder.Parameter(2, offset_literal->shape(), "offset");
auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
auto variance_activations =
builder.Parameter(4, var_literal->shape(), "variance");
Array4D<float> expected = normalized;
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> scale_data =
client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> mean_data =
client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> variance_data =
client_->TransferToServer(*var_literal).ConsumeValueOrDie();
builder.BatchNormInference(input_activations, scale_activations,
offset_activations, mean_activations,
variance_activations, epsilon, feature_index);
ComputeAndCompareR4<float>(
&builder, expected,
{input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
variance_data.get()},
ErrorSpec(0.01, 1));
}
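The reference statistics computed in the test above follow the usual moment identities (mean = sum / n, Var[X] = E[X^2] - E[X]^2), reduced over every dimension except the feature dimension. A short NumPy sketch of the same bookkeeping (illustrative only):

    import numpy as np

    def reference_moments(input_array, feature_index):
        # Reduce over all dimensions except the feature dimension.
        reduce_dims = tuple(i for i in range(input_array.ndim)
                            if i != feature_index)
        mean = input_array.mean(axis=reduce_dims)
        # Var[X] = E[X^2] - E[X]^2, i.e. square_mean - mean_square above.
        var = (input_array ** 2).mean(axis=reduce_dims) - mean ** 2
        return mean, var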
XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
float epsilon = 0.001;
ComputationBuilder builder(client_, TestName());

View File

@ -491,6 +491,16 @@ message BatchNormTrainingRequest {
int64 feature_index = 5;
}
message BatchNormInferenceRequest {
ComputationDataHandle operand = 1;
ComputationDataHandle scale = 2;
ComputationDataHandle offset = 3;
ComputationDataHandle mean = 4;
ComputationDataHandle variance = 5;
float epsilon = 6;
int64 feature_index = 7;
}
message BatchNormGradRequest {
ComputationDataHandle operand = 1;
ComputationDataHandle scale = 2;
@ -813,7 +823,8 @@ message OpRequest {
OutfeedRequest outfeed_request = 32;
BatchNormTrainingRequest batch_norm_training_request = 35;
BatchNormGradRequest batch_norm_grad_request = 37;
-    // Next: 38
+    BatchNormInferenceRequest batch_norm_inference_request = 38;
+    // Next: 39
}
}