mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Remove duplicated functions (#9601)
Summary: found by linter, duplication was likely introduced in previous code sync Pull Request resolved: https://github.com/pytorch/pytorch/pull/9601 Differential Revision: D8922379 Pulled By: bddppq fbshipit-source-id: 1f61bd7f539d823e62920615674a532ec0149623
This commit is contained in:
parent
adda789770
commit
a7afba7308
|
|
@ -641,35 +641,6 @@ class Caffe2Backend(Backend):
|
|||
cls._visit_and_substitute_raw_values(model.graph.node, raw_values_dict)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _substitute_raw_value(cls, tp, raw_values_dict):
|
||||
if tp.HasField('raw_data') and tp.raw_data == bytes(b'__EXTERNAL'):
|
||||
if tp.name not in raw_values_dict:
|
||||
raise RuntimeError('TensorProto for value {} referenced raw data but it was not found!'.format(tp.name))
|
||||
else:
|
||||
tp.raw_data = raw_values_dict[tp.name]
|
||||
|
||||
@classmethod
|
||||
def _visit_and_substitute_raw_values(cls, nodes, raw_values_dict):
|
||||
for node in nodes:
|
||||
for attr in node.attribute:
|
||||
if attr.HasField('t'):
|
||||
cls._substitute_raw_value(attr.t, raw_values_dict)
|
||||
for t in attr.tensors:
|
||||
cls._substitute_raw_value(t, raw_values_dict)
|
||||
if attr.HasField('g'):
|
||||
cls._visit_and_substitute_raw_values(attr.g.node, raw_values_dict)
|
||||
for g in attr.graphs:
|
||||
cls._visit_and_substitute_raw_values(g.node, raw_values_dict)
|
||||
|
||||
@classmethod
|
||||
def _external_value_resolution_pass(cls, model, raw_values_dict):
|
||||
for init in model.graph.initializer:
|
||||
cls._substitute_raw_value(init, raw_values_dict)
|
||||
|
||||
cls._visit_and_substitute_raw_values(model.graph.node, raw_values_dict)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _direct_initialize_parameters(cls, initializer, ws, device_option):
|
||||
for tp in initializer:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user