Fix TensorForest's saveable object names so loading a savedmodel works.

PiperOrigin-RevId: 163332598
This commit is contained in:
A. Unique TensorFlower 2017-07-27 06:24:46 -07:00 committed by TensorFlower Gardener
parent cda80a7850
commit 722f6f3611
2 changed files with 4 additions and 4 deletions

View File

@ -103,7 +103,7 @@ def tree_variable(params, tree_config, stats_handle, name, container=None):
""" """
with ops.name_scope(name, "TreeVariable") as name: with ops.name_scope(name, "TreeVariable") as name:
resource_handle = gen_model_ops.decision_tree_resource_handle_op( resource_handle = gen_model_ops.decision_tree_resource_handle_op(
container, name, name=name) container, shared_name=name, name=name)
create_op = gen_model_ops.create_tree_variable( create_op = gen_model_ops.create_tree_variable(
resource_handle, resource_handle,
@ -113,7 +113,7 @@ def tree_variable(params, tree_config, stats_handle, name, container=None):
# Adds the variable to the savable list. # Adds the variable to the savable list.
saveable = TreeVariableSavable(params, resource_handle, stats_handle, saveable = TreeVariableSavable(params, resource_handle, stats_handle,
create_op, create_op,
"tree_checkpoint_{0}".format(name)) resource_handle.name)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
resources.register_resource(resource_handle, create_op, is_initialized_op) resources.register_resource(resource_handle, create_op, is_initialized_op)
return resource_handle return resource_handle

View File

@ -99,7 +99,7 @@ def fertile_stats_variable(params, stats_config, name,
""" """
with ops.name_scope(name, "FertileStatsVariable") as name: with ops.name_scope(name, "FertileStatsVariable") as name:
resource_handle = gen_stats_ops.fertile_stats_resource_handle_op( resource_handle = gen_stats_ops.fertile_stats_resource_handle_op(
container, name, name=name) container, shared_name=name, name=name)
create_op = gen_stats_ops.create_fertile_stats_variable( create_op = gen_stats_ops.create_fertile_stats_variable(
resource_handle, stats_config, resource_handle, stats_config,
@ -108,7 +108,7 @@ def fertile_stats_variable(params, stats_config, name,
resource_handle) resource_handle)
# Adds the variable to the savable list. # Adds the variable to the savable list.
saveable = FertileStatsVariableSavable(params, resource_handle, create_op, saveable = FertileStatsVariableSavable(params, resource_handle, create_op,
"stats_checkpoint_{0}".format(name)) resource_handle.name)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
resources.register_resource(resource_handle, create_op, is_initialized_op) resources.register_resource(resource_handle, create_op, is_initialized_op)
return resource_handle return resource_handle