mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #52358 from Intel-tensorflow:gyshi/test_replace_add_with_biasadd
PiperOrigin-RevId: 403163981 Change-Id: If883ac2cba2f91224226eb1057198c00324b6561
This commit is contained in:
commit
90eeb59514
|
|
@ -449,80 +449,204 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
|
|||
TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
for (const string& add_op : {"BiasAdd", "AddV2", "Add"}) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = ops::Placeholder::Shape({4, 32});
|
||||
auto input_shape_add = ops::Placeholder::Shape({4, 8});
|
||||
auto filter_shape = ops::Placeholder::Shape({32, 8});
|
||||
auto bias_shape = ops::Placeholder::Shape({8});
|
||||
auto input_shape = ops::Placeholder::Shape({4, 32});
|
||||
auto input_shape_add = ops::Placeholder::Shape({4, 8});
|
||||
auto filter_shape = ops::Placeholder::Shape({32, 8});
|
||||
auto bias_shape = ops::Placeholder::Shape({8});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||
auto input_add =
|
||||
Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
|
||||
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
|
||||
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||
auto input_add =
|
||||
Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
|
||||
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
|
||||
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
|
||||
|
||||
auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
|
||||
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
|
||||
auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
|
||||
Output bias_add;
|
||||
if (add_op == "BiasAdd")
|
||||
bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
|
||||
else if (add_op == "AddV2")
|
||||
bias_add = ops::AddV2(s.WithOpName("bias_add"), matmul, bias);
|
||||
else if (add_op == "Add")
|
||||
bias_add = ops::Add(s.WithOpName("bias_add"), bias, matmul);
|
||||
|
||||
auto fetch = s.WithOpName("fetch");
|
||||
auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
|
||||
auto fetch = s.WithOpName("fetch");
|
||||
auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
|
||||
|
||||
ops::Identity(fetch, add);
|
||||
ops::Identity(fetch, add);
|
||||
|
||||
auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(input_shape.shape_.dim_sizes()));
|
||||
auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(input_shape_add.shape_.dim_sizes()));
|
||||
auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(filter_shape.shape_.dim_sizes()));
|
||||
auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(bias_shape.shape_.dim_sizes()));
|
||||
auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(input_shape.shape_.dim_sizes()));
|
||||
auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(input_shape_add.shape_.dim_sizes()));
|
||||
auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(filter_shape.shape_.dim_sizes()));
|
||||
auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
|
||||
TensorShape(bias_shape.shape_.dim_sizes()));
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_tensor},
|
||||
{"filter", filter_tensor},
|
||||
{"bias", bias_tensor},
|
||||
{"input_add", input_add_tensor}};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_tensor},
|
||||
{"filter", filter_tensor},
|
||||
{"bias", bias_tensor},
|
||||
{"input_add", input_add_tensor}};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on CPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:CPU:0");
|
||||
// Place all nodes on CPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:CPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE);
|
||||
GraphDef output;
|
||||
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
auto fetch_node_name = "add";
|
||||
if (node.name() == fetch_node_name) {
|
||||
EXPECT_EQ("_FusedMatMul", node.op());
|
||||
EXPECT_EQ("input", node.input(0));
|
||||
EXPECT_EQ("filter", node.input(1));
|
||||
EXPECT_EQ(2, node.attr().at("num_args").i());
|
||||
EXPECT_EQ("bias", node.input(2));
|
||||
EXPECT_EQ("input_add", node.input(3));
|
||||
|
||||
const auto fused_ops = node.attr().at("fused_ops").list().s();
|
||||
EXPECT_EQ(2, fused_ops.size());
|
||||
EXPECT_EQ("BiasAdd", fused_ops[0]);
|
||||
EXPECT_EQ("Add", fused_ops[1]);
|
||||
found++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(1, found);
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
EXPECT_EQ(1, tensors_expected.size());
|
||||
EXPECT_EQ(1, tensors.size());
|
||||
test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE);
|
||||
GraphDef output;
|
||||
TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
class RelpaceAddWithBiasAddTest : public GrapplerTest {
|
||||
public:
|
||||
const string kAddOp = "Add";
|
||||
const string kAddV2Op = "AddV2";
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
auto fetch_node_name = "add";
|
||||
if (node.name() == fetch_node_name) {
|
||||
EXPECT_EQ("_FusedMatMul", node.op());
|
||||
EXPECT_EQ("input", node.input(0));
|
||||
EXPECT_EQ("filter", node.input(1));
|
||||
protected:
|
||||
template <DataType DTYPE>
|
||||
void RelpaceAddWithBiasAddDepthwiseConv2D(const string& add_op) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
EXPECT_EQ(2, node.attr().at("num_args").i());
|
||||
EXPECT_EQ("bias", node.input(2));
|
||||
EXPECT_EQ("input_add", node.input(3));
|
||||
for (const string& activation : {"None", "Relu", "Relu6", "Elu"}) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
const auto fused_ops = node.attr().at("fused_ops").list().s();
|
||||
EXPECT_EQ(2, fused_ops.size());
|
||||
EXPECT_EQ("BiasAdd", fused_ops[0]);
|
||||
EXPECT_EQ("Add", fused_ops[1]);
|
||||
found++;
|
||||
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
|
||||
auto filter_shape = Placeholder::Shape({1, 1, 3, 128});
|
||||
auto bias_shape = Placeholder::Shape({128 * 3});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape);
|
||||
auto filter = Placeholder(s.WithOpName("filter"), DTYPE, filter_shape);
|
||||
auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);
|
||||
|
||||
std::vector<int> strides = {1, 1, 1, 1};
|
||||
auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
|
||||
input, filter, strides, "SAME");
|
||||
|
||||
Output bias_add;
|
||||
if (add_op == kAddV2Op) {
|
||||
bias_add = ops::AddV2(s.WithOpName(add_op), conv, bias);
|
||||
} else {
|
||||
bias_add = ops::Add(s.WithOpName(add_op), bias, conv);
|
||||
}
|
||||
|
||||
ops::Identity fetch = [&]() -> ops::Identity {
|
||||
auto activate = s.WithOpName("activation");
|
||||
auto fetch = s.WithOpName("fetch");
|
||||
|
||||
if (activation == "Relu") {
|
||||
return ops::Identity(fetch, ops::Relu(activate, bias_add));
|
||||
} else if (activation == "Relu6") {
|
||||
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
|
||||
} else if (activation == "Elu") {
|
||||
return ops::Identity(fetch, ops::Elu(activate, bias_add));
|
||||
}
|
||||
|
||||
return ops::Identity(fetch, bias_add);
|
||||
}();
|
||||
|
||||
auto input_t = GenerateRandomTensor<DTYPE>({8, 32, 32, 3});
|
||||
auto filter_t = GenerateRandomTensor<DTYPE>({1, 1, 3, 128});
|
||||
auto bias_t = GenerateRandomTensor<DTYPE>({128 * 3});
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
|
||||
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on CPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:CPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE);
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "activation") {
|
||||
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
|
||||
ASSERT_GE(node.input_size(), 3);
|
||||
EXPECT_EQ(node.input(0), "input");
|
||||
EXPECT_EQ(node.input(1), "filter");
|
||||
EXPECT_EQ(node.attr().at("num_args").i(), 1);
|
||||
EXPECT_EQ(node.input(2), "bias");
|
||||
|
||||
const auto fused_ops = node.attr().at("fused_ops").list().s();
|
||||
ASSERT_EQ(fused_ops.size(), 2);
|
||||
EXPECT_EQ(fused_ops[0], "BiasAdd");
|
||||
EXPECT_EQ(fused_ops[1], activation);
|
||||
|
||||
found++;
|
||||
} else if (node.name() == add_op) {
|
||||
EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
|
||||
ASSERT_GE(node.input_size(), 3);
|
||||
EXPECT_EQ(node.input(0), "input");
|
||||
EXPECT_EQ(node.input(1), "filter");
|
||||
EXPECT_EQ(node.attr().at("num_args").i(), 1);
|
||||
EXPECT_EQ(node.input(2), "bias");
|
||||
|
||||
const auto fused_ops = node.attr().at("fused_ops").list().s();
|
||||
ASSERT_EQ(fused_ops.size(), 1);
|
||||
EXPECT_EQ(fused_ops[0], "BiasAdd");
|
||||
found++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(found, 1);
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
|
||||
if (DTYPE == DT_BFLOAT16)
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
|
||||
else
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(1, found);
|
||||
};
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
EXPECT_EQ(1, tensors_expected.size());
|
||||
EXPECT_EQ(1, tensors.size());
|
||||
test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
|
||||
}
|
||||
#define CREATE_REPLACEADDWITHBIASADD_TEST_1(ops, addop, dtype) \
|
||||
TEST_F(RelpaceAddWithBiasAddTest, RelpaceAddWithBiasAdd##ops##_##addop) { \
|
||||
RelpaceAddWithBiasAddDepthwiseConv2D<dtype>(#addop); \
|
||||
}
|
||||
CREATE_REPLACEADDWITHBIASADD_TEST_1(DepthConv2D, AddV2, DT_FLOAT);
|
||||
CREATE_REPLACEADDWITHBIASADD_TEST_1(DepthConv2D, Add, DT_FLOAT);
|
||||
|
||||
class FusedMatMulBiasAddAndGeluTest : public GrapplerTest {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -143,23 +143,28 @@ struct TensorToHashBucket {
|
|||
// Contraction node followed by a BiasAdd.
|
||||
struct ContractionWithBiasAdd {
|
||||
ContractionWithBiasAdd() = default;
|
||||
ContractionWithBiasAdd(int contraction, int bias_add)
|
||||
: contraction(contraction), bias_add(bias_add) {}
|
||||
ContractionWithBiasAdd(int contraction, int bias_add, int bias_port)
|
||||
: contraction(contraction), bias_add(bias_add), bias_port(bias_port) {}
|
||||
|
||||
int contraction = kMissingIndex;
|
||||
int bias_add = kMissingIndex;
|
||||
int bias_port = 1;
|
||||
};
|
||||
|
||||
// Contraction node followed by a BiasAdd and Activation.
|
||||
struct ContractionWithBiasAddAndActivation {
|
||||
ContractionWithBiasAddAndActivation() = default;
|
||||
ContractionWithBiasAddAndActivation(int contraction, int bias_add,
|
||||
int activation)
|
||||
: contraction(contraction), bias_add(bias_add), activation(activation) {}
|
||||
int activation, int bias_port)
|
||||
: contraction(contraction),
|
||||
bias_add(bias_add),
|
||||
activation(activation),
|
||||
bias_port(bias_port) {}
|
||||
|
||||
int contraction = kMissingIndex;
|
||||
int bias_add = kMissingIndex;
|
||||
int activation = kMissingIndex;
|
||||
int bias_port = 1;
|
||||
};
|
||||
|
||||
// Contraction node followed by a Squeeze and BiasAdd.
|
||||
|
|
@ -207,16 +212,18 @@ struct ContractionWithBatchNormAndActivation {
|
|||
struct ContractionWithBiasAddAndAdd {
|
||||
ContractionWithBiasAddAndAdd() = default;
|
||||
ContractionWithBiasAddAndAdd(int contraction, int bias_add, int add,
|
||||
int port_id)
|
||||
int port_id, int bias_port)
|
||||
: contraction(contraction),
|
||||
bias_add(bias_add),
|
||||
add(add),
|
||||
port_id(port_id) {}
|
||||
port_id(port_id),
|
||||
bias_port(bias_port) {}
|
||||
|
||||
int contraction = kMissingIndex;
|
||||
int bias_add = kMissingIndex;
|
||||
int add = kMissingIndex;
|
||||
int port_id = 0;
|
||||
int bias_port = 1;
|
||||
};
|
||||
|
||||
// Contraction node followed by a BiasAdd, Add and Relu.
|
||||
|
|
@ -224,18 +231,21 @@ struct ContractionWithBiasAddAndAdd {
|
|||
struct ContractionWithBiasAndAddActivation {
|
||||
ContractionWithBiasAndAddActivation() = default;
|
||||
ContractionWithBiasAndAddActivation(int contraction, int bias_add, int add,
|
||||
int port_id, int activation)
|
||||
int port_id, int activation,
|
||||
int bias_port)
|
||||
: contraction(contraction),
|
||||
bias_add(bias_add),
|
||||
add(add),
|
||||
port_id(port_id),
|
||||
activation(activation) {}
|
||||
activation(activation),
|
||||
bias_port(bias_port) {}
|
||||
|
||||
int contraction = kMissingIndex;
|
||||
int bias_add = kMissingIndex;
|
||||
int add = kMissingIndex;
|
||||
int port_id = 0;
|
||||
int activation = kMissingIndex;
|
||||
int bias_port = 1;
|
||||
};
|
||||
|
||||
bool IsInPreserveSet(const RemapperContext& ctx, const NodeDef* node) {
|
||||
|
|
@ -436,6 +446,77 @@ inline bool HasAtMostOneDataFanoutAtPort0(
|
|||
return absl::c_count_if(node_view.GetRegularFanout(0), predicate) <= 1;
|
||||
}
|
||||
|
||||
bool IsConvOrMatMul(const NodeDef& node) {
|
||||
return IsConv2D(node) || IsDepthwiseConv2dNative(node) || IsMatMul(node);
|
||||
}
|
||||
|
||||
// Returns true if one input to Add is Conv2D or DepthwiseConv2dNative or
|
||||
// MatMul, and the other input is semantically equivalent to BiasAdd.
|
||||
bool IsBiasSemanticAdd(const RemapperContext& ctx,
|
||||
const utils::MutableNodeView& node_view,
|
||||
int& bias_port) {
|
||||
if (!IsMKLEnabled()) return false;
|
||||
|
||||
const auto* node_def = node_view.node();
|
||||
if (!IsAdd(*node_def) || node_view.NumRegularFanins() != 2) return false;
|
||||
|
||||
const auto& props = ctx.graph_properties.GetInputProperties(node_def->name());
|
||||
if (props.size() < 2) return false;
|
||||
|
||||
const auto& regular_fanin_0 = node_view.GetRegularFanin(0);
|
||||
const auto* node_view_0 = regular_fanin_0.node_view();
|
||||
const auto* node_def_0 = node_view_0->node();
|
||||
const auto& regular_fanin_1 = node_view.GetRegularFanin(1);
|
||||
const auto* node_view_1 = regular_fanin_1.node_view();
|
||||
const auto* node_def_1 = node_view_1->node();
|
||||
|
||||
auto is_channel_last_format = [](const NodeDef& node) -> bool {
|
||||
if (node.attr().contains("data_format")) {
|
||||
const string data_format = node.attr().at("data_format").s();
|
||||
return (data_format == "NHWC");
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
if (!IsConvOrMatMul(*node_def_0) && !IsConvOrMatMul(*node_def_1))
|
||||
return false;
|
||||
|
||||
if (!is_channel_last_format(*node_def_0) ||
|
||||
!is_channel_last_format(*node_def_1))
|
||||
return false;
|
||||
|
||||
const TensorShapeProto& prot0_shape = props[0].shape();
|
||||
const TensorShapeProto& prot1_shape = props[1].shape();
|
||||
if (prot0_shape.unknown_rank() || prot1_shape.unknown_rank() ||
|
||||
prot0_shape.dim_size() < 1 || prot1_shape.dim_size() < 1 ||
|
||||
!IsKnown(prot0_shape.dim(prot0_shape.dim_size() - 1)) ||
|
||||
!IsKnown(prot1_shape.dim(prot1_shape.dim_size() - 1)))
|
||||
return false;
|
||||
|
||||
// Helper function to check Add/AddV2 could be replaced with BiasAdd.
|
||||
const auto is_supported_shape =
|
||||
[](const TensorShapeProto& shape,
|
||||
const TensorShapeProto& bcast_shape) -> bool {
|
||||
if (shape.dim_size() < 2 || bcast_shape.dim_size() != 1) return false;
|
||||
int channel_dim = shape.dim(shape.dim_size() - 1).size();
|
||||
return (channel_dim == bcast_shape.dim(0).size());
|
||||
};
|
||||
|
||||
if (ShapesSymbolicallyEqual(prot0_shape, prot1_shape) ||
|
||||
!ShapesBroadcastable(prot0_shape, prot1_shape))
|
||||
return false;
|
||||
|
||||
if (is_supported_shape(prot0_shape, prot1_shape)) {
|
||||
bias_port = 1;
|
||||
return true;
|
||||
}
|
||||
if (is_supported_shape(prot1_shape, prot0_shape)) {
|
||||
bias_port = 0;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
|
||||
ContractionWithBiasAdd* matched,
|
||||
bool check_device_compatible = true) {
|
||||
|
|
@ -445,11 +526,13 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
|
|||
if (HasControlFaninOrFanout(*node_view)) return false;
|
||||
|
||||
const auto* node_def = node_view->node();
|
||||
if (!IsBiasAdd(*node_def)) return false;
|
||||
int bias_port = 1;
|
||||
if (!IsBiasAdd(*node_def) && !IsBiasSemanticAdd(ctx, *node_view, bias_port))
|
||||
return false;
|
||||
|
||||
// Input to the BiasAdd must be a Conv2D/3D or a MatMul.
|
||||
if (node_view->NumRegularFanins() < 1) return false;
|
||||
const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
|
||||
const auto& regular_fanin_0 = node_view->GetRegularFanin(1 - bias_port);
|
||||
const auto* contraction_node_view = regular_fanin_0.node_view();
|
||||
const auto* contraction_node_def = contraction_node_view->node();
|
||||
|
||||
|
|
@ -467,7 +550,7 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
|
|||
|
||||
// Check that data type and data format are supported on assigned device.
|
||||
const ContractionWithBiasAdd pattern{contraction_node_view->node_index(),
|
||||
node_index};
|
||||
node_index, bias_port};
|
||||
if (check_device_compatible && !IsDeviceCompatible(ctx, pattern))
|
||||
return false;
|
||||
|
||||
|
|
@ -504,7 +587,7 @@ bool FindContractionWithBiasAndActivation(
|
|||
|
||||
// Get the contraction node
|
||||
const auto* contraction_node_view =
|
||||
bias_add_node_view->GetRegularFanin(0).node_view();
|
||||
bias_add_node_view->GetRegularFanin(1 - base.bias_port).node_view();
|
||||
const auto* contraction_node_def = contraction_node_view->node();
|
||||
|
||||
// Currently, only matmul + bias + (tanh or Sigmoid) is enabled
|
||||
|
|
@ -519,8 +602,8 @@ bool FindContractionWithBiasAndActivation(
|
|||
return false;
|
||||
|
||||
// Check that data type and data format are supported on assigned device.
|
||||
const ContractionWithBiasAddAndActivation pattern{base.contraction,
|
||||
base.bias_add, node_index};
|
||||
const ContractionWithBiasAddAndActivation pattern{
|
||||
base.contraction, base.bias_add, node_index, base.bias_port};
|
||||
if (!IsDeviceCompatible(ctx, pattern)) return false;
|
||||
|
||||
// We successfully found a {Conv2D, MatMul}+BiasAdd+Activation pattern.
|
||||
|
|
@ -745,6 +828,7 @@ bool FindContractionWithBiasAddAndAdd(const RemapperContext& ctx,
|
|||
matched->contraction = base.contraction;
|
||||
matched->bias_add = base.bias_add;
|
||||
matched->add = node_view.node_index();
|
||||
matched->bias_port = base.bias_port;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
@ -807,7 +891,8 @@ bool FindContractionWithBiasAndAddActivation(
|
|||
// We successfully found a Conv2D+BiasAdd+AddN+activation pattern
|
||||
// or Conv3D+BiasAdd+AddN+activation pattern
|
||||
const ContractionWithBiasAndAddActivation pattern{
|
||||
base.contraction, base.bias_add, base.add, base.port_id, node_index};
|
||||
base.contraction, base.bias_add, base.add,
|
||||
base.port_id, node_index, base.bias_port};
|
||||
*matched = pattern;
|
||||
|
||||
return true;
|
||||
|
|
@ -1538,10 +1623,9 @@ Status AddFusedContractionNode(RemapperContext* ctx,
|
|||
NodeDef fused_op;
|
||||
fused_op.set_name(bias_add.name());
|
||||
fused_op.set_device(contraction.device());
|
||||
fused_op.add_input(contraction.input(0)); // 0: input
|
||||
fused_op.add_input(contraction.input(1)); // 1: filter
|
||||
fused_op.add_input(bias_add.input(1)); // 2: bias
|
||||
|
||||
fused_op.add_input(contraction.input(0)); // 0: input
|
||||
fused_op.add_input(contraction.input(1)); // 1: filter
|
||||
fused_op.add_input(bias_add.input(matched.bias_port)); // 2: bias
|
||||
if (IsConv2D(contraction)) {
|
||||
fused_op.set_op(kFusedConv2D);
|
||||
CopyConv2DAttributes(contraction, &fused_op);
|
||||
|
|
@ -1557,7 +1641,6 @@ Status AddFusedContractionNode(RemapperContext* ctx,
|
|||
}
|
||||
|
||||
SetFusedOpAttributes(&fused_op, {"BiasAdd"});
|
||||
|
||||
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
|
||||
Status status;
|
||||
mutation->AddNode(std::move(fused_op), &status);
|
||||
|
|
@ -1589,9 +1672,9 @@ Status AddFusedContractionNode(
|
|||
NodeDef fused_op;
|
||||
fused_op.set_name(activation.name());
|
||||
fused_op.set_device(contraction.device());
|
||||
fused_op.add_input(contraction.input(0)); // 0: input
|
||||
fused_op.add_input(contraction.input(1)); // 1: filter
|
||||
fused_op.add_input(bias_add.input(1)); // 2: bias
|
||||
fused_op.add_input(contraction.input(0)); // 0: input
|
||||
fused_op.add_input(contraction.input(1)); // 1: filter
|
||||
fused_op.add_input(bias_add.input(matched.bias_port)); // 2: bias
|
||||
|
||||
if (IsConv2D(contraction)) {
|
||||
fused_op.set_op(kFusedConv2D);
|
||||
|
|
@ -1778,7 +1861,7 @@ Status AddFusedContractionNode(RemapperContext* ctx,
|
|||
contraction.input(0)); // 0: input(conv) / a (matmul)
|
||||
contraction_node.add_input(
|
||||
contraction.input(1)); // 1: filter(conv) / b (matmul)
|
||||
contraction_node.add_input(bias_add.input(1)); // 2: bias
|
||||
contraction_node.add_input(bias_add.input(matched.bias_port)); // 2: bias
|
||||
|
||||
// Add OP has two inputs, one is conv+bias/matmul+bias pattern matched
|
||||
// previously, the other input to add is fused here.
|
||||
|
|
@ -1823,7 +1906,7 @@ Status AddFusedContractionNode(
|
|||
fused_conv2d.add_input(contraction.input(0)); // 0: input
|
||||
fused_conv2d.add_input(contraction.input(1)); // 1: filter
|
||||
const NodeDef& bias_add = graph->node(matched.bias_add);
|
||||
fused_conv2d.add_input(bias_add.input(1)); // 2: bias
|
||||
fused_conv2d.add_input(bias_add.input(matched.bias_port)); // 2: bias
|
||||
|
||||
// Add OP has two inputs, one is conv+bias pattern matched previously,
|
||||
// the other input to add is fused here.
|
||||
|
|
@ -2266,24 +2349,19 @@ Status AddTensorToHashBucketNode(RemapperContext* ctx,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsConv2DOrMatMul(const NodeDef& node) {
|
||||
return IsConv2D(node) || IsMatMul(node);
|
||||
}
|
||||
|
||||
bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
|
||||
const auto* node_view = ctx.graph_view.GetNode(node_index);
|
||||
|
||||
// Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
|
||||
// MatMul + Add or MatMul + BiasAdd + Add fusion.
|
||||
// Candidate for Contraction + Add or Contraction + BiasAdd + Add fusion.
|
||||
// Contraction candidate: MatMul, Conv2D, DepthwiseConv2dNative
|
||||
auto is_supported_add_input = [](const auto* node_view) -> bool {
|
||||
// Currently only support Conv2D and MatMul
|
||||
if (IsConv2DOrMatMul(*node_view->node())) return true;
|
||||
if (IsBiasAdd(*node_view->node())) {
|
||||
if (IsConvOrMatMul(*node_view->node())) return true;
|
||||
if (IsBiasAdd(*node_view->node()) || IsAdd(*node_view->node())) {
|
||||
if (node_view->NumRegularFanins() < 2) return false;
|
||||
const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
|
||||
const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
|
||||
return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
|
||||
IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
|
||||
return IsConvOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
|
||||
IsConvOrMatMul(*bias_add_fanin_1.node_view()->node());
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
|
@ -2341,7 +2419,8 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
|||
const auto* relu_fanin_0_node_view = relu_fanin_0.node_view();
|
||||
const auto* relu_fanin_0_node_def = relu_fanin_0_node_view->node();
|
||||
|
||||
if (!IsBiasAdd(*relu_fanin_0_node_def)) return false;
|
||||
if (!IsBiasAdd(*relu_fanin_0_node_def) && !IsAdd(*relu_fanin_0_node_def))
|
||||
return false;
|
||||
if (GetDataTypeFromAttr(*relu_fanin_0_node_def, "T") != DT_FLOAT)
|
||||
return false;
|
||||
|
||||
|
|
@ -2349,11 +2428,9 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
|||
|
||||
const auto& biasadd_fanin_0 = relu_fanin_0_node_view->GetRegularFanin(0);
|
||||
const auto* biasadd_fanin_0_node_def = biasadd_fanin_0.node_view()->node();
|
||||
|
||||
if (!IsConv2D(*biasadd_fanin_0_node_def)) return false;
|
||||
if (GetDataTypeFromAttr(*biasadd_fanin_0_node_def, "T") != DT_FLOAT)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
|
|
@ -2412,7 +2489,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
|||
is_batch_norm_fusion_candidate() ||
|
||||
is_batch_norm_grad_fusion_candidate();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
|
|
|||
|
|
@ -107,11 +107,21 @@ class CostAnalysisTest(test.TestCase):
|
|||
|
||||
self.assertTrue(b"MatMul" in report)
|
||||
self.assertTrue(b"ApplyAdam" in report)
|
||||
self.assertTrue(b"Conv2D" in report)
|
||||
self.assertTrue(b"Conv2DBackpropFilter" in report)
|
||||
self.assertTrue(b"Softmax" in report)
|
||||
|
||||
for op_type in [b"MatMul", b"Conv2D", b"Conv2DBackpropFilter"]:
|
||||
# When mkl is enabled, Conv2D and MatMul op followed by
|
||||
# 1-dimension Add in this graph will be fused, but not
|
||||
# in the mkl disabled case.
|
||||
expected_matmul_count = 2
|
||||
op_types = [b"MatMul", b"Conv2DBackpropFilter"]
|
||||
|
||||
if not test_util.IsMklEnabled():
|
||||
self.assertTrue(b"Conv2D" in report)
|
||||
expected_matmul_count = 3
|
||||
op_types.append(b"Conv2D")
|
||||
|
||||
for op_type in op_types:
|
||||
matcher = re.compile(
|
||||
br"\s+" + op_type + br",\s*(\d+),\s*(\d+),\s*([\d\.eE+-]+)%,\s*" +
|
||||
br"([\d\.eE+-]+)%,\s*(-?\d+),\s*(\d+),", re.MULTILINE)
|
||||
|
|
@ -121,7 +131,7 @@ class CostAnalysisTest(test.TestCase):
|
|||
# upper = int(m.group(5))
|
||||
lower = int(m.group(6))
|
||||
if op_type == b"MatMul":
|
||||
self.assertEqual(3, op_count)
|
||||
self.assertEqual(expected_matmul_count, op_count)
|
||||
else:
|
||||
self.assertEqual(1, op_count)
|
||||
self.assertTrue(0 <= lower)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user