mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Fix TensorForest's saveable object names so loading a savedmodel works.
PiperOrigin-RevId: 163332598
This commit is contained in:
parent
cda80a7850
commit
722f6f3611
|
|
@ -103,7 +103,7 @@ def tree_variable(params, tree_config, stats_handle, name, container=None):
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "TreeVariable") as name:
|
with ops.name_scope(name, "TreeVariable") as name:
|
||||||
resource_handle = gen_model_ops.decision_tree_resource_handle_op(
|
resource_handle = gen_model_ops.decision_tree_resource_handle_op(
|
||||||
container, name, name=name)
|
container, shared_name=name, name=name)
|
||||||
|
|
||||||
create_op = gen_model_ops.create_tree_variable(
|
create_op = gen_model_ops.create_tree_variable(
|
||||||
resource_handle,
|
resource_handle,
|
||||||
|
|
@ -113,7 +113,7 @@ def tree_variable(params, tree_config, stats_handle, name, container=None):
|
||||||
# Adds the variable to the savable list.
|
# Adds the variable to the savable list.
|
||||||
saveable = TreeVariableSavable(params, resource_handle, stats_handle,
|
saveable = TreeVariableSavable(params, resource_handle, stats_handle,
|
||||||
create_op,
|
create_op,
|
||||||
"tree_checkpoint_{0}".format(name))
|
resource_handle.name)
|
||||||
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
|
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
|
||||||
resources.register_resource(resource_handle, create_op, is_initialized_op)
|
resources.register_resource(resource_handle, create_op, is_initialized_op)
|
||||||
return resource_handle
|
return resource_handle
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,7 @@ def fertile_stats_variable(params, stats_config, name,
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "FertileStatsVariable") as name:
|
with ops.name_scope(name, "FertileStatsVariable") as name:
|
||||||
resource_handle = gen_stats_ops.fertile_stats_resource_handle_op(
|
resource_handle = gen_stats_ops.fertile_stats_resource_handle_op(
|
||||||
container, name, name=name)
|
container, shared_name=name, name=name)
|
||||||
|
|
||||||
create_op = gen_stats_ops.create_fertile_stats_variable(
|
create_op = gen_stats_ops.create_fertile_stats_variable(
|
||||||
resource_handle, stats_config,
|
resource_handle, stats_config,
|
||||||
|
|
@ -108,7 +108,7 @@ def fertile_stats_variable(params, stats_config, name,
|
||||||
resource_handle)
|
resource_handle)
|
||||||
# Adds the variable to the savable list.
|
# Adds the variable to the savable list.
|
||||||
saveable = FertileStatsVariableSavable(params, resource_handle, create_op,
|
saveable = FertileStatsVariableSavable(params, resource_handle, create_op,
|
||||||
"stats_checkpoint_{0}".format(name))
|
resource_handle.name)
|
||||||
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
|
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
|
||||||
resources.register_resource(resource_handle, create_op, is_initialized_op)
|
resources.register_resource(resource_handle, create_op, is_initialized_op)
|
||||||
return resource_handle
|
return resource_handle
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user