diff --git a/caffe2/python/layers/concat.py b/caffe2/python/layers/concat.py index 69ed6467294..dc2fcf496e9 100644 --- a/caffe2/python/layers/concat.py +++ b/caffe2/python/layers/concat.py @@ -33,6 +33,13 @@ class Concat(ModelLayer): "Concat expects that limited dimensions of the input tensor" shapes.append(list(field_type.field_type().shape)) + if axis == 0: + self.output_schema = schema.from_blob_list( + input_record[0], + [model.net.NextScopedBlob(name + '_output')] + ) + return + concat_dim = 0 for shape in shapes: concat_dim += shape[axis - 1]