mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Concat axis=0
Summary: Previously, the code below would go out of bound. Reviewed By: xianjiec Differential Revision: D4968037 fbshipit-source-id: 3760e2cddc919c45d85ac644ac3fabf72dbaf666
This commit is contained in:
parent
1040b5f91c
commit
ffc6bad116
|
|
@ -33,6 +33,13 @@ class Concat(ModelLayer):
|
|||
"Concat expects that limited dimensions of the input tensor"
|
||||
shapes.append(list(field_type.field_type().shape))
|
||||
|
||||
if axis == 0:
|
||||
self.output_schema = schema.from_blob_list(
|
||||
input_record[0],
|
||||
[model.net.NextScopedBlob(name + '_output')]
|
||||
)
|
||||
return
|
||||
|
||||
concat_dim = 0
|
||||
for shape in shapes:
|
||||
concat_dim += shape[axis - 1]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user