diff --git a/caffe2/operators/batch_sparse_to_dense_op.cc b/caffe2/operators/batch_sparse_to_dense_op.cc index 355a21ccf8f..e4b3642b35c 100644 --- a/caffe2/operators/batch_sparse_to_dense_op.cc +++ b/caffe2/operators/batch_sparse_to_dense_op.cc @@ -113,6 +113,32 @@ after running this operator. "2-D dense tensor, with 1st dim = len(lengths), 2nd dim = dense_last_dim" "in the arg list, the tensor is of the same data type as `values`." "Missing values are filled with default_value") + .TensorInferenceFunction([](const OperatorDef& def, + const vector& in) { + ArgumentHelper helper(def); + vector output_dims; + if (in.size() == 4) { + const auto& inference_dims = GetDimsVector(in[3]); + output_dims.insert(output_dims.end(), inference_dims.begin(), inference_dims.end()); + const int dense_last_dim = helper.GetSingleArgument("dense_last_dim", 0); + if(dense_last_dim > 0) { + CAFFE_ENFORCE( + output_dims.back() == dense_last_dim, + "The last dim of output_shape_inference should be consistent with dense_last_dim"); + } + } else { + const int dense_last_dim = helper.GetSingleArgument("dense_last_dim", 0); + CAFFE_ENFORCE( + dense_last_dim > 0, + "dense_last_dim must be set when output shape inference is unavailable"); + const auto& lens_dims = GetDimsVector(in[0]); + output_dims.insert(output_dims.end(), lens_dims[0]); + output_dims.insert(output_dims.end(), dense_last_dim); + } + vector out(1); + out[0] = CreateTensorShape(output_dims, in[2].data_type()); + return out; + }) .Arg( "dense_last_dim", "Optional, output dense last dimension. "