mirror of
https://github.com/zebrajr/opencv.git
synced 2025-12-06 12:19:50 +01:00
Update Mask-RCNN networks generator
This commit is contained in:
parent
cae2992af1
commit
1a27ff4518
|
|
@ -38,6 +38,8 @@ aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
|
|||
width_stride = float(grid_anchor_generator['width_stride'][0])
|
||||
height_stride = float(grid_anchor_generator['height_stride'][0])
|
||||
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
|
||||
first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
|
||||
first_stage_max_proposals = int(config['first_stage_max_proposals'][0])
|
||||
|
||||
print('Number of classes: %d' % num_classes)
|
||||
print('Scales: %s' % str(scales))
|
||||
|
|
@ -53,7 +55,8 @@ graph_def = parseTextGraph(args.output)
|
|||
removeIdentity(graph_def)
|
||||
|
||||
def to_remove(name, op):
|
||||
return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep)
|
||||
return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
|
||||
(name.startswith('CropAndResize') and op != 'CropAndResize')
|
||||
|
||||
removeUnusedNodesAndAttrs(to_remove, graph_def)
|
||||
|
||||
|
|
@ -123,20 +126,22 @@ detectionOut.input.append('proposals')
|
|||
detectionOut.addAttr('num_classes', 2)
|
||||
detectionOut.addAttr('share_location', True)
|
||||
detectionOut.addAttr('background_label_id', 0)
|
||||
detectionOut.addAttr('nms_threshold', 0.7)
|
||||
detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold)
|
||||
detectionOut.addAttr('top_k', 6000)
|
||||
detectionOut.addAttr('code_type', "CENTER_SIZE")
|
||||
detectionOut.addAttr('keep_top_k', 100)
|
||||
detectionOut.addAttr('keep_top_k', first_stage_max_proposals)
|
||||
detectionOut.addAttr('clip', True)
|
||||
|
||||
graph_def.node.extend([detectionOut])
|
||||
|
||||
# Save as text.
|
||||
cropAndResizeNodesNames = []
|
||||
for node in reversed(topNodes):
|
||||
if node.op != 'CropAndResize':
|
||||
graph_def.node.extend([node])
|
||||
topNodes.pop()
|
||||
else:
|
||||
cropAndResizeNodesNames.append(node.name)
|
||||
if numCropAndResize == 1:
|
||||
break
|
||||
else:
|
||||
|
|
@ -166,11 +171,15 @@ for i in reversed(range(len(graph_def.node))):
|
|||
|
||||
if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
|
||||
'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
|
||||
'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
|
||||
'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape',
|
||||
'SecondStageBoxPredictor/Flatten_1/flatten/Shape',
|
||||
'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice',
|
||||
'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']:
|
||||
del graph_def.node[i]
|
||||
|
||||
for node in graph_def.node:
|
||||
if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
|
||||
if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \
|
||||
node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape':
|
||||
node.op = 'Flatten'
|
||||
node.input.pop()
|
||||
|
||||
|
|
@ -178,6 +187,12 @@ for node in graph_def.node:
|
|||
'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
|
||||
node.addAttr('loc_pred_transposed', True)
|
||||
|
||||
if node.name.startswith('MaxPool2D'):
|
||||
assert(node.op == 'MaxPool')
|
||||
assert(len(cropAndResizeNodesNames) == 2)
|
||||
node.input = [cropAndResizeNodesNames[0]]
|
||||
del cropAndResizeNodesNames[0]
|
||||
|
||||
################################################################################
|
||||
### Postprocessing
|
||||
################################################################################
|
||||
|
|
@ -223,6 +238,11 @@ graph_def.node.extend([detectionOut])
|
|||
for node in reversed(topNodes):
|
||||
graph_def.node.extend([node])
|
||||
|
||||
if node.name.startswith('MaxPool2D'):
|
||||
assert(node.op == 'MaxPool')
|
||||
assert(len(cropAndResizeNodesNames) == 1)
|
||||
node.input = [cropAndResizeNodesNames[0]]
|
||||
|
||||
for i in reversed(range(len(graph_def.node))):
|
||||
if graph_def.node[i].op == 'CropAndResize':
|
||||
graph_def.node[i].input.insert(1, 'detection_out_final')
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user