add tensor and cost inference functions (#17684)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17684

Adds tensor (shape) and cost inference functions to more int8 operators: Int8Sum, Int8SumRelu, Int8Concat, Int8FC, Int8Flatten, Int8GivenTensorFill, and Int8GivenIntTensorFill. To avoid duplicating logic, the existing Concat, Flatten, and Sum inference functions are hoisted out of lambdas and anonymous namespaces into headers so the int8 schemas can reuse them.
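For context, the two hooks attach to an operator schema roughly as sketched below. The Dummy operator is hypothetical; the registration calls mirror the real ones in this diff, and the cost helper follows the same PointwiseCostInference pattern used for Sum (all of these live inside namespace caffe2, with caffe2/core/operator_schema.h available).

// Shape inference: map input TensorShapes to output TensorShapes.
vector<TensorShape> TensorInferenceForDummy(
    const OperatorDef& /* def */,
    const vector<TensorShape>& in) {
  // Hypothetical rule: output matches the first input exactly.
  return vector<TensorShape>{in[0]};
}

// Cost inference: estimate flops and bytes moved, usable for
// scheduling and profiling.
OpSchema::Cost CostInferenceForDummy(
    const OperatorDef& def,
    const vector<TensorShape>& in) {
  struct OpSchema::Cost cost = PointwiseCostInference<1>(def, in);
  cost.params_bytes = 0; // no learned parameters are read
  return cost;
}

OPERATOR_SCHEMA(Dummy)
    .NumInputs(1)
    .NumOutputs(1)
    .TensorInferenceFunction(TensorInferenceForDummy)
    .CostInferenceFunction(CostInferenceForDummy);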

Reviewed By: yinghai

Differential Revision: D14174746

fbshipit-source-id: dfad975fa75899565c8fb61f1b7747a9206ebd22
Jongsoo Park 2019-03-06 23:26:27 -08:00 committed by Facebook Github Bot
parent 3dba1285ab
commit 39423fbdd4
11 changed files with 144 additions and 113 deletions

caffe2/operators/concat_split_op.cc

@@ -106,7 +106,6 @@ Split a tensor into a list of tensors, given a lengths input, along the specifie
 The `input` will be split into `K` parts. Each part of length
 `sum(lengths[i*k:i*k+k))`)DOC");
-namespace {
 OpSchema::Cost CostInferenceForConcat(
     const OperatorDef& def,
     const vector<TensorShape>& in) {
@@ -143,6 +142,7 @@ OpSchema::Cost CostInferenceForConcat(
   return cost;
 }
+namespace {
 std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
 concatOpDevInfer(const OperatorDef& def) {
   auto op_device =
@@ -157,22 +157,9 @@ concatOpDevInfer(const OperatorDef& def) {
 }
 } // namespace
-REGISTER_CPU_OPERATOR(Concat, ConcatOp<CPUContext>);
-OPERATOR_SCHEMA(Concat)
-    .NumInputs(1, INT_MAX)
-    .NumOutputs(2)
-    .Arg("axis", "*(type: int; default: -1)* Axis to concatenate on.")
-    .Arg(
-        "order",
-        "*(type: string; default='NCHW')* Order of blob dimensions. Concats on the C dimension.")
-    .Arg(
-        "add_axis",
-        "*(type: int)* Pass non-zero integer to add the axis specified in `axis` to all input tensors.")
-    .TensorInferenceFunction(OpSchema::NeedsAllInputShapes(
-        [](const OperatorDef& def, const vector<TensorShape>& in) {
+vector<TensorShape> TensorInferenceForConcat(
+    const OperatorDef& def,
+    const vector<TensorShape>& in) {
   ArgumentHelper helper(def);
   const int axis = helper.HasArgument("axis")
       ? helper.GetSingleArgument<int>("axis", -1)
@@ -181,8 +168,7 @@ OPERATOR_SCHEMA(Concat)
   bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
   int adj_size = in[0].dims_size() + (add_axis ? 1 : 0);
   const int canonical_axis = canonical_axis_index_(axis, adj_size);
-  CAFFE_ENFORCE_LT(
-      canonical_axis, adj_size, "Axis not in input ndim range.");
+  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
   CAFFE_ENFORCE_GT(in.size(), 0);
   vector<int> split_shape(1, in.size());
   vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
@@ -238,13 +224,26 @@ OPERATOR_SCHEMA(Concat)
     }
   }
   if (def.output_size() == 1) {
-    return vector<TensorShape>{
-        CreateTensorShape(out_shape, in[0].data_type())};
+    return vector<TensorShape>{CreateTensorShape(out_shape, in[0].data_type())};
   }
   return vector<TensorShape>{
       CreateTensorShape(out_shape, in[0].data_type()),
      CreateTensorShape(split_shape, TensorProto::INT32)};
-}))
+}
+
+REGISTER_CPU_OPERATOR(Concat, ConcatOp<CPUContext>);
+OPERATOR_SCHEMA(Concat)
+    .NumInputs(1, INT_MAX)
+    .NumOutputs(2)
+    .Arg("axis", "*(type: int; default: -1)* Axis to concatenate on.")
+    .Arg(
+        "order",
+        "*(type: string; default='NCHW')* Order of blob dimensions. Concats on the C dimension.")
+    .Arg(
+        "add_axis",
+        "*(type: int)* Pass non-zero integer to add the axis specified in `axis` to all input tensors.")
+    .TensorInferenceFunction(
+        OpSchema::NeedsAllInputShapes(TensorInferenceForConcat))
     .CostInferenceFunction(CostInferenceForConcat)
     .DeviceInferenceFunction(concatOpDevInfer)
     .SetDoc(R"DOC(

caffe2/operators/concat_split_op.h

@@ -335,6 +335,14 @@ bool ConcatOp<Context>::RunOnDevice() {
   return true;
 }
+
+OpSchema::Cost CostInferenceForConcat(
+    const OperatorDef& def,
+    const std::vector<TensorShape>& in);
+
+std::vector<TensorShape> TensorInferenceForConcat(
+    const OperatorDef& def,
+    const std::vector<TensorShape>& in);
+
 } // namespace caffe2
 #endif // CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_

caffe2/operators/elementwise_sum_op.cc

@@ -2,17 +2,6 @@
 namespace caffe2 {
-namespace {
-OpSchema::Cost CostInferenceForSum(
-    const OperatorDef& def,
-    const vector<TensorShape>& in) {
-  struct OpSchema::Cost cost = PointwiseCostInference<1>(def, in);
-  cost.flops *= (in.size() - 1);
-  cost.params_bytes = 0;
-  return cost;
-}
-} // namespace
 REGISTER_CPU_OPERATOR(Sum, SumOp<CPUContext>);
 OPERATOR_SCHEMA(Sum)

caffe2/operators/flatten_op.cc

@@ -7,27 +7,7 @@ REGISTER_CPU_OPERATOR(Flatten, FlattenOp<CPUContext>);
 OPERATOR_SCHEMA(Flatten)
     .NumInputs(1)
     .NumOutputs(1)
-    .TensorInferenceFunction([](const OperatorDef& def,
-                                const vector<TensorShape>& in) {
-      ArgumentHelper helper(def);
-      const int axis = helper.GetSingleArgument<int>("axis", 1);
-      vector<TensorShape> out(1);
-      int64_t outer = 1;
-      int64_t inner = 1;
-      std::size_t index = 0;
-      for (auto d : in[0].dims()) {
-        if (index < axis) {
-          outer *= d;
-        } else {
-          inner *= d;
-        }
-        ++index;
-      }
-      out[0].set_data_type(in[0].data_type());
-      out[0].add_dims(outer);
-      out[0].add_dims(inner);
-      return out;
-    })
+    .TensorInferenceFunction(TensorInferenceForFlatten)
     .SetDoc(R"DOC(
 Flattens the input tensor into a 2D matrix. If input tensor has shape
 $(d_0, d_1, ..., d_n)$ then the output will have shape

caffe2/operators/flatten_op.h

@@ -33,6 +33,29 @@ class FlattenOp : public Operator<Context> {
   int axis_;
 };
+
+inline std::vector<TensorShape> TensorInferenceForFlatten(
+    const OperatorDef& def,
+    const std::vector<TensorShape>& in) {
+  ArgumentHelper helper(def);
+  const int axis = helper.GetSingleArgument<int>("axis", 1);
+  std::vector<TensorShape> out(1);
+  int64_t outer = 1;
+  int64_t inner = 1;
+  std::size_t index = 0;
+  for (auto d : in[0].dims()) {
+    if (index < axis) {
+      outer *= d;
+    } else {
+      inner *= d;
+    }
+    ++index;
+  }
+  out[0].set_data_type(in[0].data_type());
+  out[0].add_dims(outer);
+  out[0].add_dims(inner);
+  return out;
+}
+
 } // namespace caffe2
 #endif // CAFFE2_OPERATORS_FLATTEN_OP_H_
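A quick standalone check of the outer/inner product logic in TensorInferenceForFlatten (plain C++, hypothetical shape values):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Flattening shape (2, 3, 4, 5) at axis = 1: outer = 2, inner = 3*4*5 = 60,
  // so the inferred output shape is the 2D matrix (2, 60).
  const std::vector<int64_t> dims = {2, 3, 4, 5};
  const std::size_t axis = 1;
  int64_t outer = 1;
  int64_t inner = 1;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    (i < axis ? outer : inner) *= dims[i];
  }
  assert(outer == 2 && inner == 60);
}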

caffe2/operators/quantized/int8_add_op.cc

@@ -1,6 +1,8 @@
+#include "caffe2/operators/quantized/int8_add_op.h"
+
 #include <climits>
-#include "caffe2/operators/quantized/int8_add_op.h"
+
+#include "caffe2/operators/utility_ops.h"
 namespace caffe2 {
@@ -55,6 +57,8 @@ OPERATOR_SCHEMA(Int8Sum)
     .NumInputs(1, std::numeric_limits<int>::max())
     .NumOutputs(1)
     .AllowInplace({{0, 0}, {1, 0}})
+    .CostInferenceFunction(CostInferenceForSum)
+    .IdenticalTypeAndShapeOfInput(0)
     .Arg("Y_scale", "Output tensor quantization scale")
     .Arg("Y_zero_point", "Output tensor quantization offset");
@@ -62,6 +66,8 @@ OPERATOR_SCHEMA(Int8SumRelu)
     .NumInputs(1, std::numeric_limits<int>::max())
     .NumOutputs(1)
     .AllowInplace({{0, 0}, {1, 0}})
+    .CostInferenceFunction(CostInferenceForSum)
+    .IdenticalTypeAndShapeOfInput(0)
     .Arg("Y_scale", "Output tensor quantization scale")
     .Arg("Y_zero_point", "Output tensor quantization offset");

caffe2/operators/quantized/int8_concat_op.cc

@@ -1,5 +1,7 @@
 #include "caffe2/operators/quantized/int8_concat_op.h"
+
+#include "caffe2/operators/concat_split_op.h"
 namespace caffe2 {
 REGISTER_CPU_OPERATOR(Int8Concat, int8::Int8ConcatOp);
@@ -14,6 +16,9 @@ OPERATOR_SCHEMA(Int8Concat)
         "add_axis",
         "Pass 1 to add the axis specified in arg 'axis' to all "
         "input tensors")
+    .TensorInferenceFunction(
+        OpSchema::NeedsAllInputShapes(TensorInferenceForConcat))
+    .CostInferenceFunction(CostInferenceForConcat)
     .SetDoc("Concatenate a list of tensors into a single tensor")
     .Output(0, "concat_result", "Concatenated tensor")
     .Output(1, "split_info", "The dimensions of the inputs.")

caffe2/operators/quantized/int8_fc_op.cc

@@ -1,12 +1,19 @@
 #include "caffe2/operators/quantized/int8_fc_op.h"
+
+#include <functional>
+
+#include "caffe2/operators/fc_inference.h"
 namespace caffe2 {
 REGISTER_CPU_OPERATOR(Int8FC, int8::Int8FCOp);
+using namespace std::placeholders;
 OPERATOR_SCHEMA(Int8FC)
     .NumInputs(3)
     .NumOutputs(1)
+    .TensorInferenceFunction(std::bind(FCShapeInference, _1, _2, false))
+    .CostInferenceFunction(std::bind(CostInferenceForFC, _1, _2, false))
     .SetDoc(R"DOC(
 Computes the result of passing an input vector X into a fully
 connected layer with 2D weight matrix W and 1D bias vector b. That is,
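The std::bind calls above adapt three-argument helpers to the two-argument inference signature by pinning the trailing bool (presumably whether the FC weight is pre-transposed) to false. A generic illustration of the same adapter pattern in plain C++:

#include <cassert>
#include <functional>

int Scale(int x, int y, bool twice) {
  return x * y * (twice ? 2 : 1);
}

int main() {
  using namespace std::placeholders;
  // Pin the trailing flag to false, yielding a binary callable, just as
  // std::bind(FCShapeInference, _1, _2, false) does in the schema above.
  std::function<int(int, int)> f = std::bind(Scale, _1, _2, false);
  assert(f(3, 4) == 12);
}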

caffe2/operators/quantized/int8_flatten_op.cc

@@ -1,5 +1,7 @@
 #include "caffe2/operators/quantized/int8_flatten_op.h"
+
+#include "caffe2/operators/flatten_op.h"
 namespace caffe2 {
 REGISTER_CPU_OPERATOR(Int8Flatten, int8::Int8FlattenOp);
@@ -7,6 +9,7 @@ REGISTER_CPU_OPERATOR(Int8Flatten, int8::Int8FlattenOp);
 OPERATOR_SCHEMA(Int8Flatten)
     .NumInputs(1)
     .NumOutputs(1)
+    .TensorInferenceFunction(TensorInferenceForFlatten)
     .SetDoc(R"DOC(
 Flattens the input tensor into a 2D matrix. If input tensor has shape
 (d_0, d_1, ... d_n) then the output will have shape

caffe2/operators/quantized/int8_given_tensor_fill_op.cc

@@ -12,7 +12,8 @@ OPERATOR_SCHEMA(Int8GivenTensorFill)
     .SetDoc(R"DOC(
 Creates quantized tensor of type char(byte) with scale and zero point info.
 )DOC")
-    .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info");
+    .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info")
+    .TensorInferenceFunction(FillerTensorInference<>);
 OPERATOR_SCHEMA(Int8GivenIntTensorFill)
     .NumInputs(0)
@@ -24,7 +25,8 @@ OPERATOR_SCHEMA(Int8GivenIntTensorFill)
     .SetDoc(R"DOC(
 Creates quantized tensor of type int32 with scale and zero point info.
 )DOC")
-    .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info");
+    .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info")
+    .TensorInferenceFunction(FillerTensorInference<>);
 REGISTER_CPU_OPERATOR(Int8GivenTensorFill, int8::Int8GivenTensorFillOp);
 REGISTER_CPU_OPERATOR(Int8GivenIntTensorFill, int8::Int8GivenIntTensorFillOp);
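FillerTensorInference<> comes from the existing filler-op schemas; it presumably derives the output shape from the operator's `shape` argument and supplies a default data type. A rough standalone sketch of that behavior (an assumption about the helper, not its actual code):

vector<TensorShape> RoughFillerShapeInference(
    const OperatorDef& def,
    const vector<TensorShape>& /* in */) {
  ArgumentHelper helper(def);
  vector<TensorShape> out(1);
  // Assumed simplification: read dims straight from the "shape" argument.
  for (auto d : helper.GetRepeatedArgument<int64_t>("shape")) {
    out[0].add_dims(d);
  }
  out[0].set_data_type(TensorProto::FLOAT); // assumed default
  return out;
}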

caffe2/operators/utility_ops.h

@@ -317,6 +317,15 @@ class SumOp : public Operator<Context> {
   }
 };
+
+inline OpSchema::Cost CostInferenceForSum(
+    const OperatorDef& def,
+    const std::vector<TensorShape>& in) {
+  struct OpSchema::Cost cost = PointwiseCostInference<1>(def, in);
+  cost.flops *= (in.size() - 1);
+  cost.params_bytes = 0;
+  return cost;
+}
+
 // WeightedSumOp computes the weighted sum of several tensors. The input should
 // be in the form X_0, weight_0, X_1, weight_1, ... where X_i all have the same
 // shape, and weight_i are size 1 tensors that specifies the weight of each
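To sanity-check CostInferenceForSum's arithmetic: PointwiseCostInference<1> charges one flop per element, and summing N same-shaped tensors takes N - 1 additions per element, hence the (in.size() - 1) scale factor. A standalone re-derivation with hypothetical sizes:

#include <cassert>
#include <cstdint>

int main() {
  // Sum of 3 tensors, each holding 2 * 3 * 4 = 24 elements:
  // 24 flops per pairwise add, times (3 - 1) adds = 48 flops.
  const std::uint64_t elements = 2 * 3 * 4;
  const std::uint64_t num_inputs = 3;
  const std::uint64_t flops = elements * (num_inputs - 1);
  assert(flops == 48);
}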