Merge pull request #52358 from Intel-tensorflow:gyshi/test_replace_add_with_biasadd

PiperOrigin-RevId: 403163981
Change-Id: If883ac2cba2f91224226eb1057198c00324b6561
This commit is contained in:
TensorFlower Gardener 2021-10-14 12:58:15 -07:00
commit 90eeb59514
3 changed files with 311 additions and 101 deletions

View File

@ -449,80 +449,204 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
using ::tensorflow::ops::Placeholder;
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
for (const string& add_op : {"BiasAdd", "AddV2", "Add"}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto input_shape = ops::Placeholder::Shape({4, 32});
auto input_shape_add = ops::Placeholder::Shape({4, 8});
auto filter_shape = ops::Placeholder::Shape({32, 8});
auto bias_shape = ops::Placeholder::Shape({8});
auto input_shape = ops::Placeholder::Shape({4, 32});
auto input_shape_add = ops::Placeholder::Shape({4, 8});
auto filter_shape = ops::Placeholder::Shape({32, 8});
auto bias_shape = ops::Placeholder::Shape({8});
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto input_add =
Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
auto input_add =
Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
Output bias_add;
if (add_op == "BiasAdd")
bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
else if (add_op == "AddV2")
bias_add = ops::AddV2(s.WithOpName("bias_add"), matmul, bias);
else if (add_op == "Add")
bias_add = ops::Add(s.WithOpName("bias_add"), bias, matmul);
auto fetch = s.WithOpName("fetch");
auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
auto fetch = s.WithOpName("fetch");
auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
ops::Identity(fetch, add);
ops::Identity(fetch, add);
auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape.shape_.dim_sizes()));
auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape_add.shape_.dim_sizes()));
auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(filter_shape.shape_.dim_sizes()));
auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(bias_shape.shape_.dim_sizes()));
auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape.shape_.dim_sizes()));
auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(input_shape_add.shape_.dim_sizes()));
auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(filter_shape.shape_.dim_sizes()));
auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
TensorShape(bias_shape.shape_.dim_sizes()));
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_tensor},
{"filter", filter_tensor},
{"bias", bias_tensor},
{"input_add", input_add_tensor}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_tensor},
{"filter", filter_tensor},
{"bias", bias_tensor},
{"input_add", input_add_tensor}};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
auto fetch_node_name = "add";
if (node.name() == fetch_node_name) {
EXPECT_EQ("_FusedMatMul", node.op());
EXPECT_EQ("input", node.input(0));
EXPECT_EQ("filter", node.input(1));
EXPECT_EQ(2, node.attr().at("num_args").i());
EXPECT_EQ("bias", node.input(2));
EXPECT_EQ("input_add", node.input(3));
const auto fused_ops = node.attr().at("fused_ops").list().s();
EXPECT_EQ(2, fused_ops.size());
EXPECT_EQ("BiasAdd", fused_ops[0]);
EXPECT_EQ("Add", fused_ops[1]);
found++;
}
}
EXPECT_EQ(1, found);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
}
}
Remapper optimizer(RewriterConfig::AGGRESSIVE);
GraphDef output;
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
class RelpaceAddWithBiasAddTest : public GrapplerTest {
public:
const string kAddOp = "Add";
const string kAddV2Op = "AddV2";
int found = 0;
for (const NodeDef& node : output.node()) {
auto fetch_node_name = "add";
if (node.name() == fetch_node_name) {
EXPECT_EQ("_FusedMatMul", node.op());
EXPECT_EQ("input", node.input(0));
EXPECT_EQ("filter", node.input(1));
protected:
template <DataType DTYPE>
void RelpaceAddWithBiasAddDepthwiseConv2D(const string& add_op) {
using ::tensorflow::ops::Placeholder;
EXPECT_EQ(2, node.attr().at("num_args").i());
EXPECT_EQ("bias", node.input(2));
EXPECT_EQ("input_add", node.input(3));
for (const string& activation : {"None", "Relu", "Relu6", "Elu"}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
const auto fused_ops = node.attr().at("fused_ops").list().s();
EXPECT_EQ(2, fused_ops.size());
EXPECT_EQ("BiasAdd", fused_ops[0]);
EXPECT_EQ("Add", fused_ops[1]);
found++;
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
auto filter_shape = Placeholder::Shape({1, 1, 3, 128});
auto bias_shape = Placeholder::Shape({128 * 3});
auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape);
auto filter = Placeholder(s.WithOpName("filter"), DTYPE, filter_shape);
auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);
std::vector<int> strides = {1, 1, 1, 1};
auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
input, filter, strides, "SAME");
Output bias_add;
if (add_op == kAddV2Op) {
bias_add = ops::AddV2(s.WithOpName(add_op), conv, bias);
} else {
bias_add = ops::Add(s.WithOpName(add_op), bias, conv);
}
ops::Identity fetch = [&]() -> ops::Identity {
auto activate = s.WithOpName("activation");
auto fetch = s.WithOpName("fetch");
if (activation == "Relu") {
return ops::Identity(fetch, ops::Relu(activate, bias_add));
} else if (activation == "Relu6") {
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
} else if (activation == "Elu") {
return ops::Identity(fetch, ops::Elu(activate, bias_add));
}
return ops::Identity(fetch, bias_add);
}();
auto input_t = GenerateRandomTensor<DTYPE>({8, 32, 32, 3});
auto filter_t = GenerateRandomTensor<DTYPE>({1, 1, 3, 128});
auto bias_t = GenerateRandomTensor<DTYPE>({128 * 3});
GrapplerItem item;
item.fetch = {"fetch"};
item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
// Place all nodes on CPU.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device("/device:CPU:0");
}
Remapper optimizer(RewriterConfig::AGGRESSIVE);
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "activation") {
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
ASSERT_GE(node.input_size(), 3);
EXPECT_EQ(node.input(0), "input");
EXPECT_EQ(node.input(1), "filter");
EXPECT_EQ(node.attr().at("num_args").i(), 1);
EXPECT_EQ(node.input(2), "bias");
const auto fused_ops = node.attr().at("fused_ops").list().s();
ASSERT_EQ(fused_ops.size(), 2);
EXPECT_EQ(fused_ops[0], "BiasAdd");
EXPECT_EQ(fused_ops[1], activation);
found++;
} else if (node.name() == add_op) {
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
ASSERT_GE(node.input_size(), 3);
EXPECT_EQ(node.input(0), "input");
EXPECT_EQ(node.input(1), "filter");
EXPECT_EQ(node.attr().at("num_args").i(), 1);
EXPECT_EQ(node.input(2), "bias");
const auto fused_ops = node.attr().at("fused_ops").list().s();
ASSERT_EQ(fused_ops.size(), 1);
EXPECT_EQ(fused_ops[0], "BiasAdd");
found++;
}
}
EXPECT_EQ(found, 1);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
ASSERT_EQ(tensors_expected.size(), 1);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
ASSERT_EQ(tensors.size(), 1);
if (DTYPE == DT_BFLOAT16)
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
else
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
}
}
EXPECT_EQ(1, found);
};
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
}
// Declares one TEST_F on the RelpaceAddWithBiasAddTest fixture.
//   ops   - token used only to build the generated test name.
//   addop - token appended to the test name and also stringized (#addop) and
//           passed to the helper as the name of the add op to exercise.
//   dtype - DataType value used as the helper's template argument.
// The generated test body simply calls
// RelpaceAddWithBiasAddDepthwiseConv2D<dtype>("<addop>").
#define CREATE_REPLACEADDWITHBIASADD_TEST_1(ops, addop, dtype) \
TEST_F(RelpaceAddWithBiasAddTest, RelpaceAddWithBiasAdd##ops##_##addop) { \
RelpaceAddWithBiasAddDepthwiseConv2D<dtype>(#addop); \
}
CREATE_REPLACEADDWITHBIASADD_TEST_1(DepthConv2D, AddV2, DT_FLOAT);
CREATE_REPLACEADDWITHBIASADD_TEST_1(DepthConv2D, Add, DT_FLOAT);
class FusedMatMulBiasAddAndGeluTest : public GrapplerTest {
public:

View File

@ -143,23 +143,28 @@ struct TensorToHashBucket {
// Contraction node followed by a BiasAdd.
// Plain record describing one matched pattern; the matcher fills it in and
// the rewriter reads it back when building the fused node.
struct ContractionWithBiasAdd {
ContractionWithBiasAdd() = default;
ContractionWithBiasAdd(int contraction, int bias_add)
: contraction(contraction), bias_add(bias_add) {}
ContractionWithBiasAdd(int contraction, int bias_add, int bias_port)
: contraction(contraction), bias_add(bias_add), bias_port(bias_port) {}
// Graph node index of the contraction; kMissingIndex until a match is stored.
int contraction = kMissingIndex;
// Graph node index of the bias-adding node; kMissingIndex until a match is
// stored.
int bias_add = kMissingIndex;
// Which input of the bias-adding node carries the bias tensor. The matcher
// reads the contraction from input (1 - bias_port), so only 0 and 1 are
// meaningful; 1 is the default.
int bias_port = 1;
};
// Contraction node followed by a BiasAdd and Activation.
struct ContractionWithBiasAddAndActivation {
ContractionWithBiasAddAndActivation() = default;
ContractionWithBiasAddAndActivation(int contraction, int bias_add,
int activation)
: contraction(contraction), bias_add(bias_add), activation(activation) {}
int activation, int bias_port)
: contraction(contraction),
bias_add(bias_add),
activation(activation),
bias_port(bias_port) {}
int contraction = kMissingIndex;
int bias_add = kMissingIndex;
int activation = kMissingIndex;
int bias_port = 1;
};
// Contraction node followed by a Squeeze and BiasAdd.
@ -207,16 +212,18 @@ struct ContractionWithBatchNormAndActivation {
struct ContractionWithBiasAddAndAdd {
ContractionWithBiasAddAndAdd() = default;
ContractionWithBiasAddAndAdd(int contraction, int bias_add, int add,
int port_id)
int port_id, int bias_port)
: contraction(contraction),
bias_add(bias_add),
add(add),
port_id(port_id) {}
port_id(port_id),
bias_port(bias_port) {}
int contraction = kMissingIndex;
int bias_add = kMissingIndex;
int add = kMissingIndex;
int port_id = 0;
int bias_port = 1;
};
// Contraction node followed by a BiasAdd, Add and Relu.
@ -224,18 +231,21 @@ struct ContractionWithBiasAddAndAdd {
struct ContractionWithBiasAndAddActivation {
ContractionWithBiasAndAddActivation() = default;
ContractionWithBiasAndAddActivation(int contraction, int bias_add, int add,
int port_id, int activation)
int port_id, int activation,
int bias_port)
: contraction(contraction),
bias_add(bias_add),
add(add),
port_id(port_id),
activation(activation) {}
activation(activation),
bias_port(bias_port) {}
int contraction = kMissingIndex;
int bias_add = kMissingIndex;
int add = kMissingIndex;
int port_id = 0;
int activation = kMissingIndex;
int bias_port = 1;
};
bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
@ -436,6 +446,77 @@ inline bool HasAtMostOneDataFanoutAtPort0(
return absl::c_count_if(node_view.GetRegularFanout(0), predicate) <= 1;
}
// Reports whether `node` is one of the contraction ops handled here:
// Conv2D, DepthwiseConv2dNative or MatMul.
bool IsConvOrMatMul(const NodeDef& node) {
  if (IsConv2D(node)) return true;
  if (IsDepthwiseConv2dNative(node)) return true;
  return IsMatMul(node);
}
// Returns true if one input to Add is Conv2D or DepthwiseConv2dNative or
// MatMul, and the other input is semantically equivalent to BiasAdd.
//
// The check only succeeds when MKL is enabled, the node has exactly two
// regular fanins, neither fanin declares a data_format other than "NHWC", and
// the inferred input shapes are a rank >= 2 tensor plus a rank-1 tensor whose
// single dimension equals the last dimension of the other input.
//
// Args:
//   ctx:       remapper context; its graph properties supply the input shapes.
//   node_view: the candidate Add node.
//   bias_port: out-parameter. Set to the input index (0 or 1) of the rank-1
//              operand when the function returns true; left untouched when it
//              returns false.
bool IsBiasSemanticAdd(const RemapperContext& ctx,
                       const utils::MutableNodeView& node_view,
                       int& bias_port) {
  if (!IsMKLEnabled()) return false;

  const auto* node_def = node_view.node();
  if (!IsAdd(*node_def) || node_view.NumRegularFanins() != 2) return false;

  const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());
  if (props.size() < 2) return false;

  const auto& regular_fanin_0 = node_view.GetRegularFanin(0);
  const auto* node_view_0 = regular_fanin_0.node_view();
  const auto* node_def_0 = node_view_0->node();
  const auto& regular_fanin_1 = node_view.GetRegularFanin(1);
  const auto* node_view_1 = regular_fanin_1.node_view();
  const auto* node_def_1 = node_view_1->node();

  // A node without a "data_format" attribute is accepted; a node that has
  // one must be "NHWC".
  auto is_channel_last_format = [](const NodeDef& node) -> bool {
    if (node.attr().contains("data_format")) {
      // Compare in place instead of copying the attribute string.
      return node.attr().at("data_format").s() == "NHWC";
    }
    return true;
  };

  // At least one of the two producers has to be a contraction.
  if (!IsConvOrMatMul(*node_def_0) && !IsConvOrMatMul(*node_def_1))
    return false;

  if (!is_channel_last_format(*node_def_0) ||
      !is_channel_last_format(*node_def_1))
    return false;

  const TensorShapeProto& prot0_shape = props[0].shape();
  const TensorShapeProto& prot1_shape = props[1].shape();

  // Both ranks and both innermost dimensions must be known, otherwise the
  // size comparison below would be meaningless.
  if (prot0_shape.unknown_rank() || prot1_shape.unknown_rank() ||
      prot0_shape.dim_size() < 1 || prot1_shape.dim_size() < 1 ||
      !IsKnown(prot0_shape.dim(prot0_shape.dim_size() - 1)) ||
      !IsKnown(prot1_shape.dim(prot1_shape.dim_size() - 1)))
    return false;

  // Helper function to check Add/AddV2 could be replaced with BiasAdd.
  const auto is_supported_shape =
      [](const TensorShapeProto& shape,
         const TensorShapeProto& bcast_shape) -> bool {
    if (shape.dim_size() < 2 || bcast_shape.dim_size() != 1) return false;
    // Keep the dimension in its native (64-bit) type: narrowing it to `int`
    // could make a truncated value compare equal to the bias length.
    const auto channel_dim = shape.dim(shape.dim_size() - 1).size();
    return channel_dim == bcast_shape.dim(0).size();
  };

  // Identical shapes are an element-wise add, not a bias; non-broadcastable
  // shapes cannot be a bias either.
  if (ShapesSymbolicallyEqual(prot0_shape, prot1_shape) ||
      !ShapesBroadcastable(prot0_shape, prot1_shape))
    return false;

  if (is_supported_shape(prot0_shape, prot1_shape)) {
    bias_port = 1;
    return true;
  }
  if (is_supported_shape(prot1_shape, prot0_shape)) {
    bias_port = 0;
    return true;
  }
  return false;
}
bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
ContractionWithBiasAdd* matched,
bool check_device_compatible = true) {
@ -445,11 +526,13 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
if (HasControlFaninOrFanout(*node_view)) return false;
const auto* node_def = node_view->node();
if (!IsBiasAdd(*node_def)) return false;
int bias_port = 1;
if (!IsBiasAdd(*node_def) && !IsBiasSemanticAdd(ctx, *node_view, bias_port))
return false;
// Input to the BiasAdd must be a Conv2D/3D or a MatMul.
if (node_view->NumRegularFanins() < 1) return false;
const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
const auto& regular_fanin_0 = node_view->GetRegularFanin(1 - bias_port);
const auto* contraction_node_view = regular_fanin_0.node_view();
const auto* contraction_node_def = contraction_node_view->node();
@ -467,7 +550,7 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
// Check that data type and data format are supported on assigned device.
const ContractionWithBiasAdd pattern{contraction_node_view->node_index(),
node_index};
node_index, bias_port};
if (check_device_compatible && !IsDeviceCompatible(ctx, pattern))
return false;
@ -504,7 +587,7 @@ bool FindContractionWithBiasAndActivation(
// Get the contraction node
const auto* contraction_node_view =
bias_add_node_view->GetRegularFanin(0).node_view();
bias_add_node_view->GetRegularFanin(1 - base.bias_port).node_view();
const auto* contraction_node_def = contraction_node_view->node();
// Currently, only matmul + bias + (tanh or Sigmoid) is enabled
@ -519,8 +602,8 @@ bool FindContractionWithBiasAndActivation(
return false;
// Check that data type and data format are supported on assigned device.
const ContractionWithBiasAddAndActivation pattern{base.contraction,
base.bias_add, node_index};
const ContractionWithBiasAddAndActivation pattern{
base.contraction, base.bias_add, node_index, base.bias_port};
if (!IsDeviceCompatible(ctx, pattern)) return false;
// We successfully found a {Conv2D, MatMul}+BiasAdd+Activation pattern.
@ -745,6 +828,7 @@ bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
matched->contraction = base.contraction;
matched->bias_add = base.bias_add;
matched->add = node_view.node_index();
matched->bias_port = base.bias_port;
return true;
}
@ -807,7 +891,8 @@ bool FindContractionWithBiasAndAddActivation(
// We successfully found a Conv2D+BiasAdd+AddN+activation pattern
// or Conv3D+BiasAdd+AddN+activation pattern
const ContractionWithBiasAndAddActivation pattern{
base.contraction, base.bias_add, base.add, base.port_id, node_index};
base.contraction, base.bias_add, base.add,
base.port_id, node_index, base.bias_port};
*matched = pattern;
return true;
@ -1538,10 +1623,9 @@ Status AddFusedContractionNode(RemapperContext* ctx,
NodeDef fused_op;
fused_op.set_name(bias_add.name());
fused_op.set_device(contraction.device());
fused_op.add_input(contraction.input(0)); // 0: input
fused_op.add_input(contraction.input(1)); // 1: filter
fused_op.add_input(bias_add.input(1)); // 2: bias
fused_op.add_input(contraction.input(0)); // 0: input
fused_op.add_input(contraction.input(1)); // 1: filter
fused_op.add_input(bias_add.input(matched.bias_port)); // 2: bias
if (IsConv2D(contraction)) {
fused_op.set_op(kFusedConv2D);
CopyConv2DAttributes(contraction, &fused_op);
@ -1557,7 +1641,6 @@ Status AddFusedContractionNode(RemapperContext* ctx,
}
SetFusedOpAttributes(&fused_op, {"BiasAdd"});
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status;
mutation->AddNode(std::move(fused_op), &status);
@ -1589,9 +1672,9 @@ Status AddFusedContractionNode(
NodeDef fused_op;
fused_op.set_name(activation.name());
fused_op.set_device(contraction.device());
fused_op.add_input(contraction.input(0)); // 0: input
fused_op.add_input(contraction.input(1)); // 1: filter
fused_op.add_input(bias_add.input(1)); // 2: bias
fused_op.add_input(contraction.input(0)); // 0: input
fused_op.add_input(contraction.input(1)); // 1: filter
fused_op.add_input(bias_add.input(matched.bias_port)); // 2: bias
if (IsConv2D(contraction)) {
fused_op.set_op(kFusedConv2D);
@ -1778,7 +1861,7 @@ Status AddFusedContractionNode(RemapperContext* ctx,
contraction.input(0)); // 0: input(conv) / a (matmul)
contraction_node.add_input(
contraction.input(1)); // 1: filter(conv) / b (matmul)
contraction_node.add_input(bias_add.input(1)); // 2: bias
contraction_node.add_input(bias_add.input(matched.bias_port)); // 2: bias
// Add OP has two inputs, one is conv+bias/matmul+bias pattern matched
// previously, the other input to add is fused here.
@ -1823,7 +1906,7 @@ Status AddFusedContractionNode(
fused_conv2d.add_input(contraction.input(0)); // 0: input
fused_conv2d.add_input(contraction.input(1)); // 1: filter
const NodeDef& bias_add = graph->node(matched.bias_add);
fused_conv2d.add_input(bias_add.input(1)); // 2: bias
fused_conv2d.add_input(bias_add.input(matched.bias_port)); // 2: bias
// Add OP has two inputs, one is conv+bias pattern matched previously,
// the other input to add is fused here.
@ -2266,24 +2349,19 @@ Status AddTensorToHashBucketNode(RemapperContext* ctx,
return Status::OK();
}
bool IsConv2DOrMatMul(const NodeDef& node) {
return IsConv2D(node) || IsMatMul(node);
}
bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
const auto* node_view = ctx.graph_view.GetNode(node_index);
// Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
// MatMul + Add or MatMul + BiasAdd + Add fusion.
// Candidate for Contraction + Add or Contraction + BiasAdd + Add fusion.
// Contraction candidate: MatMul, Conv2D, DepthwiseConv2dNative
auto is_supported_add_input = [](const auto* node_view) -> bool {
// Currently only support Conv2D and MatMul
if (IsConv2DOrMatMul(*node_view->node())) return true;
if (IsBiasAdd(*node_view->node())) {
if (IsConvOrMatMul(*node_view->node())) return true;
if (IsBiasAdd(*node_view->node()) || IsAdd(*node_view->node())) {
if (node_view->NumRegularFanins() < 2) return false;
const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
return IsConvOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
IsConvOrMatMul(*bias_add_fanin_1.node_view()->node());
}
return false;
};
@ -2341,7 +2419,8 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
if (!IsBiasAdd(*relu_fanin_0_node_def)) return false;
if (!IsBiasAdd(*relu_fanin_0_node_def) && !IsAdd(*relu_fanin_0_node_def))
return false;
if (GetDataTypeFromAttr(*relu_fanin_0_node_def, "T") != DT_FLOAT)
return false;
@ -2349,11 +2428,9 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
const auto& biasadd_fanin_0 = relu_fanin_0_node_view->GetRegularFanin(0);
const auto* biasadd_fanin_0_node_def = biasadd_fanin_0.node_view()->node();
if (!IsConv2D(*biasadd_fanin_0_node_def)) return false;
if (GetDataTypeFromAttr(*biasadd_fanin_0_node_def, "T") != DT_FLOAT)
return false;
return true;
};
@ -2412,7 +2489,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
is_batch_norm_fusion_candidate() ||
is_batch_norm_grad_fusion_candidate();
}
} // namespace
Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,

View File

@ -107,11 +107,21 @@ class CostAnalysisTest(test.TestCase):
self.assertTrue(b"MatMul" in report)
self.assertTrue(b"ApplyAdam" in report)
self.assertTrue(b"Conv2D" in report)
self.assertTrue(b"Conv2DBackpropFilter" in report)
self.assertTrue(b"Softmax" in report)
for op_type in [b"MatMul", b"Conv2D", b"Conv2DBackpropFilter"]:
# When mkl is enabled, Conv2D and MatMul op followed by
# 1-dimension Add in this graph will be fused, but not
# in the mkl disabled case.
expected_matmul_count = 2
op_types = [b"MatMul", b"Conv2DBackpropFilter"]
if not test_util.IsMklEnabled():
self.assertTrue(b"Conv2D" in report)
expected_matmul_count = 3
op_types.append(b"Conv2D")
for op_type in op_types:
matcher = re.compile(
br"\s+" + op_type + br",\s*(\d+),\s*(\d+),\s*([\d\.eE+-]+)%,\s*" +
br"([\d\.eE+-]+)%,\s*(-?\d+),\s*(\d+),", re.MULTILINE)
@ -121,7 +131,7 @@ class CostAnalysisTest(test.TestCase):
# upper = int(m.group(5))
lower = int(m.group(6))
if op_type == b"MatMul":
self.assertEqual(3, op_count)
self.assertEqual(expected_matmul_count, op_count)
else:
self.assertEqual(1, op_count)
self.assertTrue(0 <= lower)