Implement Batchnorm Inference by expanding it into smaller ops.

1. Add batch norm inference support in batchnorm_rewriter
2. Connect xla's batchnorm inference to tf's FusedBatchNorm

RELNOTES: n/a
PiperOrigin-RevId: 165655351
This commit is contained in:
parent
f0da8bf56b
commit
7359fec792
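For context, the expansion implemented by this change is the standard batch-norm inference formula, y = scale * (x - mean) / sqrt(variance + epsilon) + offset, applied per feature. A minimal NumPy sketch of that formula follows (illustrative only; the names below are not taken from the XLA code in this commit):

    import numpy as np

    def batch_norm_inference(x, scale, offset, mean, variance, epsilon=0.001):
      # x has NHWC layout; scale/offset/mean/variance are per-feature vectors
      # that broadcast over the trailing (feature) axis.
      return scale * (x - mean) / np.sqrt(variance + epsilon) + offset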
@@ -63,6 +63,39 @@ class FusedBatchNormTest(XLATestCase):
     grad_offset = np.sum(grad_y, axis=(0, 1, 2))
     return grad_x, grad_scale, grad_offset

+  def testInference(self):
+    x_shape = [2, 2, 6, 2]
+    scale_shape = [2]
+    x_val = np.random.random_sample(x_shape).astype(np.float32)
+    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
+    data_format = "NHWC"
+    with self.test_session() as sess, self.test_scope():
+      # To avoid constant folding
+      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
+      scale = array_ops.placeholder(np.float32, shape=[2], name="scale")
+      offset = array_ops.placeholder(np.float32, shape=[2], name="offset")
+      epsilon = 0.001
+      y_ref, mean_ref, var_ref = self._reference_training(
+          x_val, scale_val, offset_val, epsilon, data_format)
+      y, mean, variance = nn.fused_batch_norm(
+          t_val,
+          scale,
+          offset,
+          mean=mean_ref,
+          variance=var_ref,
+          epsilon=epsilon,
+          data_format=data_format,
+          is_training=False)
+
+      y_val, _, _ = sess.run(
+          [y, mean,
+           variance], {t_val: x_val,
+                       scale: scale_val,
+                       offset: offset_val})
+      self.assertAllClose(y_val, y_ref, atol=1e-3)
+
   def _testLearning(self, use_gradient_checker):
     x_shape = [2, 2, 6, 2]
     scale_shape = [2]
@@ -39,28 +39,36 @@ class FusedBatchNormOp : public XlaOpKernel {
                 errors::InvalidArgument("Not supported format"));
     feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format);
+  }
-    // TODO(b/62843645): Implement BatchNormInference.
-    OP_REQUIRES(
-        ctx, is_training_,
-        errors::InvalidArgument("Fused batch normalization for inference is "
-                                "not supported yet on XLA backend."));
-  }

   void Compile(XlaOpKernelContext* ctx) override {
-    xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
-        ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, feature_index_);
+    if (is_training_) {
+      xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
+          ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_,
+          feature_index_);

-    // In training mode, outputs the normalized value as well as the calculated
-    // mean and variance.
-    for (int i = 0; i < 3; i++) {
-      ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
+      // In training mode, outputs the normalized value as well as the
+      // calculated mean and variance.
+      for (int i = 0; i < 3; i++) {
+        ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
+      }
+      // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
+      // space 1 & 2". They are used to pass the per-batch mean and
+      // variance to the gradient. Here we maintain the same behavior by setting
+      // them to the mean and variance calculated by BatchNormTraining.
+      ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
+      ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
+    } else {
+      xla::ComputationDataHandle output = ctx->builder()->BatchNormInference(
+          ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3),
+          ctx->Input(4), epsilon_, feature_index_);
+      ctx->SetOutput(0, output);
+      // Directly send input to output as mean and variance in inference mode.
+      ctx->SetOutput(1, ctx->Input(3));
+      ctx->SetOutput(2, ctx->Input(4));
+      ctx->SetOutput(3, ctx->Input(3));
+      ctx->SetOutput(4, ctx->Input(4));
     }
-    // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
-    // space 1 & 2". They are used to pass the per-batch mean and
-    // variance to the gradient. Here we maintain the same behavior by setting
-    // them to the mean and variance calculated by BatchNormTraining.
-    ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
-    ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
   }

  private:
@@ -1477,9 +1477,29 @@ ComputationDataHandle ComputationBuilder::BatchNormInference(
     const ComputationDataHandle& operand, const ComputationDataHandle& scale,
     const ComputationDataHandle& offset, const ComputationDataHandle& mean,
     const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
-  // TODO(b/62843645): Implement BatchNormInference.
-  NoteError(Unimplemented("BatchNormInference is not implemented yet."));
-  return ComputationDataHandle();
+  if (!first_error_.ok() || !PrepareComputation().ok()) {
+    return ComputationDataHandle();
+  }
+  BatchNormInferenceRequest request;
+  *request.mutable_operand() = operand;
+  *request.mutable_scale() = scale;
+  *request.mutable_offset() = offset;
+  *request.mutable_mean() = mean;
+  *request.mutable_variance() = variance;
+  request.set_epsilon(epsilon);
+  request.set_feature_index(feature_index);
+
+  OpRequest op_request;
+  *op_request.mutable_batch_norm_inference_request() = request;
+  *op_request.mutable_computation() = computation_.handle();
+  AddOpMetadata(&op_request);
+
+  OpResponse response;
+
+  VLOG(2) << "making BatchNormInference request";
+
+  Status s = client_->stub()->Op(&op_request, &response);
+  return ParseOpResponse(s, &response);
 }

 ComputationDataHandle ComputationBuilder::BatchNormGrad(
@@ -56,11 +56,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {

   Status HandleBatchNormTraining(HloInstruction* batch_norm) override;

+  Status HandleBatchNormInference(HloInstruction* batch_norm) override;
+
   Status HandleBatchNormGrad(HloInstruction* batch_norm) override;

   // Runs the visitor on a computation.
   static bool Run(HloComputation* computation, bool rewrite_training_op,
-                  bool rewrite_grad_op, bool use_fusion);
+                  bool rewrite_inference_op, bool rewrite_grad_op,
+                  bool use_fusion);

   // Returns whether any batch norm ops were rewritten.
   const bool changed() const { return changed_; }
@@ -70,9 +73,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
  private:
   explicit BatchNormRewriterVisitor(HloComputation* computation,
                                     bool rewrite_training_op,
+                                    bool rewrite_inference_op,
                                     bool rewrite_grad_op, bool use_fusion)
       : computation_(computation),
         rewrite_training_op_(rewrite_training_op),
+        rewrite_inference_op_(rewrite_inference_op),
         rewrite_grad_op_(rewrite_grad_op),
         use_fusion_(use_fusion) {}

@@ -94,6 +99,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
   HloComputation* computation_;

   bool rewrite_training_op_;
+  bool rewrite_inference_op_;
   bool rewrite_grad_op_;
   bool use_fusion_;

@@ -126,11 +132,14 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {

 bool BatchNormRewriterVisitor::Run(HloComputation* computation,
                                    bool rewrite_training_op,
+                                   bool rewrite_inference_op,
                                    bool rewrite_grad_op, bool use_fusion) {
-  BatchNormRewriterVisitor visitor(computation,
-                                   /*rewrite_training_op=*/rewrite_training_op,
-                                   /*rewrite_grad_op=*/rewrite_grad_op,
-                                   /*use_fusion=*/use_fusion);
+  BatchNormRewriterVisitor visitor(
+      computation,
+      /*rewrite_training_op=*/rewrite_training_op,
+      /*rewrite_inference_op=*/rewrite_inference_op,
+      /*rewrite_grad_op=*/rewrite_grad_op,
+      /*use_fusion=*/use_fusion);
   TF_CHECK_OK(computation->Accept(&visitor));
   return visitor.changed_;
 }

@@ -268,6 +277,82 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
   return Status::OK();
 }

+Status BatchNormRewriterVisitor::HandleBatchNormInference(
+    HloInstruction* batch_norm) {
+  if (!rewrite_inference_op_) {
+    return Status::OK();
+  }
+  // Expand batch norm inference into smaller HLO ops.
+  HloInstruction* operand = batch_norm->mutable_operand(0);
+  const Shape operand_shape = operand->shape();
+  int64 feature_index = batch_norm->feature_index();
+
+  HloInstruction* scale = batch_norm->mutable_operand(1);
+  HloInstruction* offset = batch_norm->mutable_operand(2);
+  HloInstruction* mean = batch_norm->mutable_operand(3);
+  HloInstruction* var = batch_norm->mutable_operand(4);
+  const Shape feature_shape = scale->shape();
+
+  auto epsilon = computation_->AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+
+  std::vector<int64> dimensions_without_feature;
+
+  for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
+    if (i != feature_index) {
+      dimensions_without_feature.push_back(i);
+    }
+  }
+
+  auto scale_broadcasted = computation_->AddInstruction(
+      HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
+
+  auto offset_broadcasted = computation_->AddInstruction(
+      HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
+
+  auto mean_broadcasted = computation_->AddInstruction(
+      HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
+
+  auto var_broadcasted = computation_->AddInstruction(
+      HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
+
+  // Var[X] + epsilon.
+  auto var_add_epsilon =
+      computation_->AddInstruction(HloInstruction::CreateBinary(
+          operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
+
+  auto neg_half = computation_->AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
+
+  // 1 / Sqrt[Var[X] + epsilon].
+  auto rsqrt_var_add_epsilon =
+      computation_->AddInstruction(HloInstruction::CreateBinary(
+          operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
+
+  // X - E[X].
+  auto operand_minus_mean =
+      computation_->AddInstruction(HloInstruction::CreateBinary(
+          operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
+
+  // (X - E[X]) / Sqrt[Var[X] + epsilon].
+  auto normalized = computation_->AddInstruction(
+      HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
+                                   operand_minus_mean, rsqrt_var_add_epsilon));
+
+  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
+  auto scaled_normalized =
+      computation_->AddInstruction(HloInstruction::CreateBinary(
+          operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
+
+  // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
+  auto shifted_normalized = HloInstruction::CreateBinary(
+      operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);
+
+  TF_CHECK_OK(
+      ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
+  return Status::OK();
+}
+
 Status BatchNormRewriterVisitor::HandleBatchNormGrad(
     HloInstruction* batch_norm) {
   // Use the following formulas to calculate gradients:
@@ -457,7 +542,8 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
   }
   for (auto& comp : computations) {
     if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
-                                      rewrite_grad_op_, use_fusion_)) {
+                                      rewrite_inference_op_, rewrite_grad_op_,
+                                      use_fusion_)) {
       changed = true;
     }
   }

@@ -30,8 +30,10 @@ class BatchNormRewriter : public HloPassInterface {
  public:
   // When use_fusion is set, a multi-output fusion node is created.
   BatchNormRewriter(bool rewrite_training_op = false,
+                    bool rewrite_inference_op = false,
                     bool rewrite_grad_op = false, bool use_fusion = true)
       : rewrite_training_op_(rewrite_training_op),
+        rewrite_inference_op_(rewrite_inference_op),
         rewrite_grad_op_(rewrite_grad_op),
         use_fusion_(use_fusion) {}
   ~BatchNormRewriter() = default;

@@ -43,6 +45,7 @@ class BatchNormRewriter : public HloPassInterface {

  private:
   bool rewrite_training_op_;
+  bool rewrite_inference_op_;
   bool rewrite_grad_op_;
   bool use_fusion_;
 };

@@ -64,6 +64,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) {
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining);
   BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
+                             /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
   ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
   root = computation->root_instruction();
@@ -105,6 +106,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) {
   HloInstruction* root = computation->root_instruction();
   EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad);
   BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
+                             /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
   ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
   root = computation->root_instruction();
@@ -260,6 +260,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
         pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
     pass.AddPass<BatchNormRewriter>(
         /*rewrite_training_op=*/true,
+        /*rewrite_inference_op=*/true,
         /*rewrite_grad_op=*/true,
        /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
@@ -228,6 +228,9 @@ class DfsHloVisitor {

   virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0;

+  virtual Status HandleBatchNormInference(
+      HloInstruction* batchNormInference) = 0;
+
   virtual Status HandleBatchNormGrad(HloInstruction* batchNormGrad) = 0;

   // Invoked to inform the visitor that the traversal has completed, and that
@@ -54,6 +54,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor {
     return DefaultAction(hlo);
   }

+  Status HandleBatchNormInference(HloInstruction* hlo) override {
+    return DefaultAction(hlo);
+  }
+
   Status HandleBatchNormGrad(HloInstruction* hlo) override {
     return DefaultAction(hlo);
   }

@@ -135,6 +135,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
     // instead.
     pass.AddPass<BatchNormRewriter>(
         /*rewrite_training_op=*/true,
+        /*rewrite_inference_op=*/true,
         /*rewrite_grad_op=*/true,
         /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
@@ -374,6 +374,12 @@ Status HloCostAnalysis::HandleBatchNormTraining(
   return Status::OK();
 }

+Status HloCostAnalysis::HandleBatchNormInference(
+    HloInstruction* batchNormInference) {
+  // TODO(b/62294698): Implement cost analysis for batch-norm-inference.
+  return Status::OK();
+}
+
 Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batchNormGrad) {
   // TODO(b/62294698): Implement cost analysis for batch-norm-grad.
   return Status::OK();

@@ -89,6 +89,7 @@ class HloCostAnalysis : public DfsHloVisitor {
                       tensorflow::gtl::ArraySlice<int64> dimensions,
                       HloComputation* function_handle) override;
   Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override;
+  Status HandleBatchNormInference(HloInstruction* batchNormInference) override;
   Status HandleBatchNormGrad(HloInstruction* batchNormGrad) override;
   Status HandleFusion(HloInstruction* fusion) override;
   Status HandleCall(HloInstruction* call) override;

@@ -742,6 +742,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
     case HloOpcode::kParameter:
       return kOrange;
     case HloOpcode::kBatchNormTraining:
+    case HloOpcode::kBatchNormInference:
     case HloOpcode::kBatchNormGrad:
     case HloOpcode::kReduce:
     case HloOpcode::kSelectAndScatter:

@@ -406,6 +406,23 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
   return instruction;
 }

+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateBatchNormInference(
+    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+    float epsilon, int64 feature_index) {
+  auto instruction =
+      WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
+  instruction->AppendOperand(operand);
+  instruction->AppendOperand(scale);
+  instruction->AppendOperand(offset);
+  instruction->AppendOperand(mean);
+  instruction->AppendOperand(variance);
+  instruction->epsilon_ = epsilon;
+  instruction->feature_index_ = feature_index;
+  return instruction;
+}
+
 /* static */ std::unique_ptr<HloInstruction>
 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
                                     HloInstruction* scale, HloInstruction* mean,
@@ -1065,6 +1082,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
       return CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
                                      new_operands[2], epsilon(),
                                      feature_index());
+    case HloOpcode::kBatchNormInference:
+      CHECK_EQ(new_operands.size(), 5);
+      return CreateBatchNormInference(
+          shape, new_operands[0], new_operands[1], new_operands[2],
+          new_operands[3], new_operands[4], epsilon(), feature_index());
     case HloOpcode::kInfeed:
       CHECK_EQ(new_operands.size(), 0);
       return CreateInfeed(shape, infeed_config());

@@ -1355,6 +1378,7 @@ bool HloInstruction::IdenticalSlowPath(
              ShapeUtil::Compatible(shape(), other.shape());

     case HloOpcode::kBatchNormTraining:
+    case HloOpcode::kBatchNormInference:
     case HloOpcode::kBatchNormGrad:
       return feature_index() == other.feature_index() &&
              epsilon() == other.epsilon();

@@ -1952,6 +1976,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
       return visitor->HandleAbs(this, operands_[0]);
     case HloOpcode::kBatchNormTraining:
       return visitor->HandleBatchNormTraining(this);
+    case HloOpcode::kBatchNormInference:
+      return visitor->HandleBatchNormInference(this);
     case HloOpcode::kBatchNormGrad:
       return visitor->HandleBatchNormGrad(this);
     case HloOpcode::kSign:

@@ -224,6 +224,12 @@ class HloInstruction {
       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
       HloInstruction* offset, float epsilon, int64 feature_index);

+  // Creates a batch-norm-inference instruction.
+  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
+      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+      float epsilon, int64 feature_index);
+
   // Creates a batch-norm-grad instruction.
   static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
       const Shape& shape, HloInstruction* operand, HloInstruction* scale,

@@ -33,6 +33,8 @@ string HloOpcodeString(HloOpcode opcode) {
       return "add";
     case HloOpcode::kBatchNormTraining:
       return "batch-norm-training";
+    case HloOpcode::kBatchNormInference:
+      return "batch-norm-inference";
     case HloOpcode::kBatchNormGrad:
       return "batch-norm-grad";
     case HloOpcode::kBitcast:

@@ -31,6 +31,7 @@ enum class HloOpcode {
   kAbs,
   kAdd,
   kBatchNormTraining,
+  kBatchNormInference,
   kBatchNormGrad,
   kBitcast,
   kBroadcast,

@@ -78,6 +78,7 @@ namespace xla {

     // Expensive instructions.
     case HloOpcode::kBatchNormTraining:
+    case HloOpcode::kBatchNormInference:
     case HloOpcode::kBatchNormGrad:
     case HloOpcode::kCall:
     case HloOpcode::kConvolution:

@@ -1211,6 +1211,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
       handle_status = computation->AddBatchNormTrainingInstruction(
           arg->batch_norm_training_request());
       break;
+    case OpRequest::kBatchNormInferenceRequest:
+      handle_status = computation->AddBatchNormInferenceInstruction(
+          arg->batch_norm_inference_request());
+      break;
     case OpRequest::kBatchNormGradRequest:
       handle_status = computation->AddBatchNormGradInstruction(
           arg->batch_norm_grad_request());

@@ -885,6 +885,150 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
                                   output_shape_for_mean_and_var});
 }

+/* static */ StatusOr<Shape> ShapeInference::InferBatchNormInferenceShape(
+    const Shape& operand_shape, const Shape& offset_shape,
+    const Shape& scale_shape, const Shape& mean_shape,
+    const Shape& variance_shape, int64 feature_index) {
+  TF_RETURN_IF_ERROR(
+      ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+      offset_shape, "offset input of batch norm inference"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+      scale_shape, "scale input of batch norm inference"));
+
+  TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) ==
+               tensorflow::Status::OK());
+  TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) ==
+               tensorflow::Status::OK());
+  TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) ==
+               tensorflow::Status::OK());
+  TF_RET_CHECK(ShapeUtil::ValidateShape(mean_shape) ==
+               tensorflow::Status::OK());
+  TF_RET_CHECK(ShapeUtil::ValidateShape(variance_shape) ==
+               tensorflow::Status::OK());
+
+  if (feature_index >= ShapeUtil::Rank(operand_shape)) {
+    return InvalidArgument(
+        "Expected feature_index of batch-norm-inference to be "
+        "smaller than the rank of operand_shape; "
+        "got feature_index %lld, and rank %lld",
+        feature_index, ShapeUtil::Rank(operand_shape));
+  }
+
+  if (feature_index < 0) {
+    return InvalidArgument(
+        "Expected feature_index of batch-norm-inference to "
+        "be a non-negative number, got %lld",
+        feature_index);
+  }
+
+  if (ShapeUtil::Rank(operand_shape) < 1) {
+    return InvalidArgument(
+        "Expected the rank of operand to "
+        "batch-norm-inference to be at least 1; got %lld",
+        ShapeUtil::Rank(operand_shape));
+  }
+
+  if (ShapeUtil::Rank(offset_shape) != 1) {
+    return InvalidArgument(
+        "Offset input of batch-norm-inference must have"
+        " rank 1, but has rank %lld.",
+        ShapeUtil::Rank(offset_shape));
+  }
+
+  if (ShapeUtil::Rank(scale_shape) != 1) {
+    return InvalidArgument(
+        "Scale input of batch-norm-inference must have"
+        " rank 1, but has rank %lld.",
+        ShapeUtil::Rank(scale_shape));
+  }
+
+  if (!ShapeUtil::ElementIsFloating(operand_shape)) {
+    return InvalidArgument(
+        "The operand to batch-norm-inference must have a floating point "
+        "element type, but the shape is %s",
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of offset factor is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(offset_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of scale factor is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(scale_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of mean is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(mean_shape.element_type()).c_str(),
+        PrimitiveType_Name(operand_shape.element_type()).c_str());
+  }
+
+  if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
+    return InvalidArgument(
+        "The inputs should have the same element type for "
+        "batch-norm-inference, "
+        "but the shape of variance is %s "
+        "and the shape of operand is %s",
+        PrimitiveType_Name(mean_shape.element_type()).c_str(),
+        PrimitiveType_Name(variance_shape.element_type()).c_str());
+  }
+
+  const int64 feature_count = operand_shape.dimensions(feature_index);
+  Shape output_shape_for_mean_and_var =
+      ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count});
+
+  if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of offset factor should be the same as feature count,"
+        "but the size of offset factor is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(offset_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of scale factor should be the same as feature count,"
+        "but the size of scale factor is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(scale_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of mean should be the same as feature count,"
+        "but the size of mean is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(mean_shape, 0), feature_count);
+  }
+
+  if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) {
+    return InvalidArgument(
+        "The size of variance should be the same as feature count,"
+        "but the size of variance is %lld "
+        "and the feature count is %lld",
+        ShapeUtil::GetDimension(variance_shape, 0), feature_count);
+  }
+
+  return operand_shape;
+}
+
 /* static */ StatusOr<Shape> ShapeInference::InferBatchNormGradShape(
     const Shape& operand_shape, const Shape& scale_shape,
     const Shape& mean_shape, const Shape& var_shape,
@@ -71,6 +71,13 @@ class ShapeInference {
                                                  const Shape& scale_shape,
                                                  int64 feature_index);

+  // Infers the shape produced by InferBatchNormInference with the given
+  // operands.
+  static StatusOr<Shape> InferBatchNormInferenceShape(
+      const Shape& operand_shape, const Shape& offset_shape,
+      const Shape& scale_shape, const Shape& mean_shape,
+      const Shape& variance_shape, int64 feature_index);
+
   // Infers the shape produced by InferBatchNormGrad with the given operands.
   static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
                                                  const Shape& scale_shape,

@@ -507,6 +507,53 @@ UserComputation::AddBatchNormTrainingInstruction(
   return handle;
 }

+StatusOr<ComputationDataHandle>
+UserComputation::AddBatchNormInferenceInstruction(
+    const BatchNormInferenceRequest& batch_norm_inference_request) {
+  tensorflow::mutex_lock lock(mutex_);
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+                      LookUpRequest(batch_norm_inference_request.operand()));
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
+                      LookUpRequest(batch_norm_inference_request.scale()));
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
+                      LookUpRequest(batch_norm_inference_request.offset()));
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
+                      LookUpRequest(batch_norm_inference_request.mean()));
+
+  TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
+                      LookUpRequest(batch_norm_inference_request.variance()));
+
+  ComputationDataHandle handle = CreateComputationDataHandle();
+
+  OperationRequest& request =
+      (*session_computation_.mutable_requests())[handle.handle()];
+
+  TF_ASSIGN_OR_RETURN(Shape inferred_shape,
+                      ShapeInference::InferBatchNormInferenceShape(
+                          operand->output_shape(), scale->output_shape(),
+                          offset->output_shape(), mean->output_shape(),
+                          variance->output_shape(),
+                          batch_norm_inference_request.feature_index()));
+
+  *request.mutable_output_shape() = inferred_shape;
+
+  *request.mutable_output_handle() = handle;
+
+  *request.mutable_request()->mutable_batch_norm_inference_request() =
+      batch_norm_inference_request;
+
+  VLOG(1) << "AddBatchNormInferenceInstruction ("
+          << GetVersionedHandleInternal() << "), data handle "
+          << handle.handle() << ": "
+          << batch_norm_inference_request.ShortDebugString();
+
+  return handle;
+}
+
 StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
     const BatchNormGradRequest& batch_norm_grad_request) {
   tensorflow::mutex_lock lock(mutex_);

@@ -1678,6 +1725,25 @@ void ConstantVisitor(const SessionComputation& session_computation,
       break;
     }

+    case OpRequest::kBatchNormInferenceRequest: {
+      const BatchNormInferenceRequest& batch_norm_inference_request =
+          request.request().batch_norm_inference_request();
+      ConstantVisitor(session_computation,
+                      batch_norm_inference_request.operand(), visited,
+                      is_constant);
+      ConstantVisitor(session_computation, batch_norm_inference_request.scale(),
+                      visited, is_constant);
+      ConstantVisitor(session_computation,
+                      batch_norm_inference_request.offset(), visited,
+                      is_constant);
+      ConstantVisitor(session_computation, batch_norm_inference_request.mean(),
+                      visited, is_constant);
+      ConstantVisitor(session_computation,
+                      batch_norm_inference_request.variance(), visited,
+                      is_constant);
+      break;
+    }
+
     case OpRequest::kBatchNormGradRequest: {
       const BatchNormGradRequest& batch_norm_grad_request =
           request.request().batch_norm_grad_request();

@@ -2119,6 +2185,18 @@ static void ForEachOperand(
       break;
     }

+    case OpRequest::kBatchNormInferenceRequest: {
+      const BatchNormInferenceRequest& batch_norm_inference_request =
+          request.request().batch_norm_inference_request();
+
+      apply(batch_norm_inference_request.operand());
+      apply(batch_norm_inference_request.scale());
+      apply(batch_norm_inference_request.offset());
+      apply(batch_norm_inference_request.mean());
+      apply(batch_norm_inference_request.variance());
+      break;
+    }
+
     case OpRequest::kBatchNormGradRequest: {
       const BatchNormGradRequest& batch_norm_grad_request =
           request.request().batch_norm_grad_request();

@@ -2647,6 +2725,28 @@ void ComputationLowerer::Visit(
       break;
     }

+    case OpRequest::kBatchNormInferenceRequest: {
+      const BatchNormInferenceRequest& batch_norm_inference_request =
+          request.request().batch_norm_inference_request();
+      HloInstruction* operand =
+          lookup_instruction(batch_norm_inference_request.operand());
+      HloInstruction* scale =
+          lookup_instruction(batch_norm_inference_request.scale());
+      HloInstruction* offset =
+          lookup_instruction(batch_norm_inference_request.offset());
+      HloInstruction* mean =
+          lookup_instruction(batch_norm_inference_request.mean());
+      HloInstruction* variance =
+          lookup_instruction(batch_norm_inference_request.variance());
+
+      hlo_instruction =
+          add_instruction(HloInstruction::CreateBatchNormInference(
+              request.output_shape(), operand, scale, offset, mean, variance,
+              batch_norm_inference_request.epsilon(),
+              batch_norm_inference_request.feature_index()));
+      break;
+    }
+
     case OpRequest::kBatchNormGradRequest: {
       const BatchNormGradRequest& batch_norm_grad_request =
           request.request().batch_norm_grad_request();

@@ -89,6 +89,10 @@ class UserComputation {
   StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
       const BatchNormTrainingRequest& batch_norm_training_request);

+  // Enqueues a batch norm inference instruction onto this user computation.
+  StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
+      const BatchNormInferenceRequest& batch_norm_inference_request);
+
   // Enqueues a batch norm grad instruction onto this user computation.
   StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
       const BatchNormGradRequest& batch_norm_grad_request);

@@ -306,6 +306,109 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) {
                      ErrorSpec(0.01, 1));
 }

+XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) {
+  float epsilon = 0.001;
+  ComputationBuilder builder(client_, TestName());
+  const std::vector<int64>& bounds = GetParam().bounds;
+  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
+  input_array.FillRandom(GetParam().random_value_var,
+                         GetParam().random_value_mean);
+
+  const int64 feature_index = GetParam().feature_index;
+  const int64 num_elements_per_feature =
+      Product(bounds) / bounds[feature_index];
+  const int64 feature_bound = bounds[feature_index];
+  std::vector<float> offset(feature_bound, 1);
+  std::vector<float> scale(feature_bound, 2);
+
+  auto input_squared =
+      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
+  std::vector<int64> reduce_dims;
+  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
+    if (i != feature_index) {
+      reduce_dims.push_back(i);
+    }
+  }
+
+  auto sum =
+      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
+                                  [](float a, float b) { return a + b; });
+
+  auto sum_squared =
+      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
+                                  [](float a, float b) { return a + b; });
+
+  std::vector<float> mean(feature_bound);
+
+  for (int64 i = 0; i < feature_bound; ++i) {
+    mean[i] = sum[i] / num_elements_per_feature;
+  }
+
+  std::vector<float> mean_square(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    mean_square[i] = mean[i] * mean[i];
+  }
+
+  std::vector<float> square_mean(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    square_mean[i] = sum_squared[i] / num_elements_per_feature;
+  }
+
+  std::vector<float> var(feature_bound);
+  for (int64 i = 0; i < feature_bound; ++i) {
+    var[i] = square_mean[i] - mean_square[i];
+  }
+
+  Array4D<float> mean4D =
+      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto offset4D =
+      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
+
+  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
+                                                scale4D, offset4D, epsilon);
+
+  auto offset_literal = Literal::CreateR1<float>(offset);
+  auto scale_literal = Literal::CreateR1<float>(scale);
+  auto mean_literal = Literal::CreateR1<float>(mean);
+  auto var_literal = Literal::CreateR1<float>(var);
+  auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+
+  auto input_activations =
+      builder.Parameter(0, input_literal->shape(), "input");
+  auto scale_activations =
+      builder.Parameter(1, scale_literal->shape(), "offset");
+  auto offset_activations =
+      builder.Parameter(2, offset_literal->shape(), "scale");
+  auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
+  auto variance_activations =
+      builder.Parameter(4, var_literal->shape(), "variance");
+
+  Array4D<float> expected = normalized;
+
+  std::unique_ptr<GlobalData> input_data =
+      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> scale_data =
+      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> offset_data =
+      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> mean_data =
+      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+  std::unique_ptr<GlobalData> variance_data =
+      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+
+  builder.BatchNormInference(input_activations, scale_activations,
+                             offset_activations, mean_activations,
+                             variance_activations, epsilon, feature_index);
+
+  ComputeAndCompareR4<float>(
+      &builder, expected,
+      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
+       variance_data.get()},
+      ErrorSpec(0.01, 1));
+}
+
 XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());

@@ -491,6 +491,16 @@ message BatchNormTrainingRequest {
   int64 feature_index = 5;
 }

+message BatchNormInferenceRequest {
+  ComputationDataHandle operand = 1;
+  ComputationDataHandle scale = 2;
+  ComputationDataHandle offset = 3;
+  ComputationDataHandle mean = 4;
+  ComputationDataHandle variance = 5;
+  float epsilon = 6;
+  int64 feature_index = 7;
+}
+
 message BatchNormGradRequest {
   ComputationDataHandle operand = 1;
   ComputationDataHandle scale = 2;

@@ -813,7 +823,8 @@ message OpRequest {
     OutfeedRequest outfeed_request = 32;
     BatchNormTrainingRequest batch_norm_training_request = 35;
     BatchNormGradRequest batch_norm_grad_request = 37;
-    // Next: 38
+    BatchNormInferenceRequest batch_norm_inference_request = 38;
+    // Next: 39
   }
 }