mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Concat axis=0
Summary: Previously, the code below would go out of bound. Reviewed By: xianjiec Differential Revision: D4968037 fbshipit-source-id: 3760e2cddc919c45d85ac644ac3fabf72dbaf666
This commit is contained in:
parent
1040b5f91c
commit
ffc6bad116
|
|
@ -33,6 +33,13 @@ class Concat(ModelLayer):
|
||||||
"Concat expects that limited dimensions of the input tensor"
|
"Concat expects that limited dimensions of the input tensor"
|
||||||
shapes.append(list(field_type.field_type().shape))
|
shapes.append(list(field_type.field_type().shape))
|
||||||
|
|
||||||
|
if axis == 0:
|
||||||
|
self.output_schema = schema.from_blob_list(
|
||||||
|
input_record[0],
|
||||||
|
[model.net.NextScopedBlob(name + '_output')]
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
concat_dim = 0
|
concat_dim = 0
|
||||||
for shape in shapes:
|
for shape in shapes:
|
||||||
concat_dim += shape[axis - 1]
|
concat_dim += shape[axis - 1]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user