mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Closes https://github.com/caffe2/caffe2/pull/1260 Differential Revision: D5906739 Pulled By: Yangqing fbshipit-source-id: e482ba9ba60b5337d9165f28f7ec68d4518a0902
86 lines
3.0 KiB
Python
86 lines
3.0 KiB
Python
# Copyright (c) 2016-present, Facebook, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
##############################################################################
|
|
|
|
## @package hsm_util
|
|
# Module caffe2.python.hsm_util
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.proto import hsm_pb2
|
|
|
|
'''
|
|
Hierarchical softmax utility methods that can be used to:
|
|
1) create TreeProto structure given list of word_ids or NodeProtos
|
|
2) create HierarchyProto structure using the user-inputted TreeProto
|
|
'''
|
|
|
|
|
|
def create_node_with_words(words, name='node'):
|
|
node = hsm_pb2.NodeProto()
|
|
node.name = name
|
|
for word in words:
|
|
node.word_ids.append(word)
|
|
return node
|
|
|
|
|
|
def create_node_with_nodes(nodes, name='node'):
|
|
node = hsm_pb2.NodeProto()
|
|
node.name = name
|
|
for child_node in nodes:
|
|
new_child_node = node.children.add()
|
|
new_child_node.MergeFrom(child_node)
|
|
return node
|
|
|
|
|
|
def create_hierarchy(tree_proto):
|
|
max_index = 0
|
|
|
|
def create_path(path, word):
|
|
path_proto = hsm_pb2.PathProto()
|
|
path_proto.word_id = word
|
|
for entry in path:
|
|
new_path_node = path_proto.path_nodes.add()
|
|
new_path_node.index = entry[0]
|
|
new_path_node.length = entry[1]
|
|
new_path_node.target = entry[2]
|
|
return path_proto
|
|
|
|
def recursive_path_builder(node_proto, path, hierarchy_proto, max_index):
|
|
node_proto.offset = max_index
|
|
path.append([max_index,
|
|
len(node_proto.word_ids) + len(node_proto.children), 0])
|
|
max_index += len(node_proto.word_ids) + len(node_proto.children)
|
|
if hierarchy_proto.size < max_index:
|
|
hierarchy_proto.size = max_index
|
|
for target, node in enumerate(node_proto.children):
|
|
path[-1][2] = target
|
|
max_index = recursive_path_builder(node, path, hierarchy_proto,
|
|
max_index)
|
|
for target, word in enumerate(node_proto.word_ids):
|
|
path[-1][2] = target + len(node_proto.children)
|
|
path_entry = create_path(path, word)
|
|
new_path_entry = hierarchy_proto.paths.add()
|
|
new_path_entry.MergeFrom(path_entry)
|
|
del path[-1]
|
|
return max_index
|
|
|
|
node = tree_proto.root_node
|
|
hierarchy_proto = hsm_pb2.HierarchyProto()
|
|
path = []
|
|
max_index = recursive_path_builder(node, path, hierarchy_proto, max_index)
|
|
return hierarchy_proto
|