mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Add regression variance over individual trees to TensorForest inference.
PiperOrigin-RevId: 163695881
This commit is contained in:
parent
15e928d51e
commit
1560c55d2d
|
|
@ -44,6 +44,7 @@ from tensorflow.python.training import session_run_hook
|
|||
KEYS_NAME = 'keys'
|
||||
LOSS_NAME = 'rf_training_loss'
|
||||
TREE_PATHS_PREDICTION_KEY = 'tree_paths'
|
||||
VARIANCE_PREDICTION_KEY = 'regression_variance'
|
||||
|
||||
EPSILON = 0.000001
|
||||
|
||||
|
|
@ -195,7 +196,8 @@ def get_model_fn(params,
|
|||
graph_builder = graph_builder_class(params,
|
||||
device_assigner=dev_assn)
|
||||
|
||||
logits, tree_paths = graph_builder.inference_graph(features)
|
||||
logits, tree_paths, regression_variance = graph_builder.inference_graph(
|
||||
features)
|
||||
|
||||
summary.scalar('average_tree_size', graph_builder.average_size())
|
||||
# For binary classification problems, convert probabilities to logits.
|
||||
|
|
@ -265,6 +267,9 @@ def get_model_fn(params,
|
|||
if params.inference_tree_paths:
|
||||
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
|
||||
|
||||
if params.regression:
|
||||
model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
|
||||
|
||||
return model_ops
|
||||
|
||||
return _model_fn
|
||||
|
|
|
|||
|
|
@ -474,7 +474,8 @@ class RandomForestGraphs(object):
|
|||
**inference_args: Keyword arguments to pass through to each tree.
|
||||
|
||||
Returns:
|
||||
A tuple of (probabilities, tree_paths).
|
||||
A tuple of (probabilities, tree_paths, variance), where variance
|
||||
is the variance over all the trees for regression problems only.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If trying to use feature bagging with sparse
|
||||
|
|
@ -501,12 +502,20 @@ class RandomForestGraphs(object):
|
|||
probabilities.append(probs)
|
||||
paths.append(path)
|
||||
with ops.device(self.variables.device_dummies[0].device):
|
||||
all_predict = array_ops.stack(probabilities)
|
||||
return math_ops.div(
|
||||
math_ops.reduce_sum(all_predict, 0),
|
||||
# shape of all_predict should be [batch_size, num_trees, num_outputs]
|
||||
all_predict = array_ops.stack(probabilities, axis=1)
|
||||
average_values = math_ops.div(
|
||||
math_ops.reduce_sum(all_predict, 1),
|
||||
self.params.num_trees,
|
||||
name='probabilities'), array_ops.stack(
|
||||
paths, axis=1)
|
||||
name='probabilities')
|
||||
tree_paths = array_ops.stack(paths, axis=1)
|
||||
regression_variance = None
|
||||
if self.params.regression:
|
||||
expected_squares = math_ops.div(
|
||||
math_ops.reduce_sum(all_predict * all_predict, 1),
|
||||
self.params.num_trees)
|
||||
regression_variance = expected_squares - average_values * average_values
|
||||
return average_values, tree_paths, regression_variance
|
||||
|
||||
def average_size(self):
|
||||
"""Constructs a TF graph for evaluating the average size of a forest.
|
||||
|
|
|
|||
|
|
@ -105,9 +105,10 @@ class TensorForestTest(test_util.TensorFlowTestCase):
|
|||
split_after_samples=25).fill()
|
||||
|
||||
graph_builder = tensor_forest.RandomForestGraphs(params)
|
||||
probs, paths = graph_builder.inference_graph(input_data)
|
||||
probs, paths, var = graph_builder.inference_graph(input_data)
|
||||
self.assertTrue(isinstance(probs, ops.Tensor))
|
||||
self.assertTrue(isinstance(paths, ops.Tensor))
|
||||
self.assertIsNone(var)
|
||||
|
||||
def testTrainingConstructionClassificationSparse(self):
|
||||
input_data = sparse_tensor.SparseTensor(
|
||||
|
|
@ -144,12 +145,14 @@ class TensorForestTest(test_util.TensorFlowTestCase):
|
|||
num_features=10,
|
||||
num_trees=10,
|
||||
max_nodes=1000,
|
||||
regression=True,
|
||||
split_after_samples=25).fill()
|
||||
|
||||
graph_builder = tensor_forest.RandomForestGraphs(params)
|
||||
probs, paths = graph_builder.inference_graph(input_data)
|
||||
probs, paths, var = graph_builder.inference_graph(input_data)
|
||||
self.assertTrue(isinstance(probs, ops.Tensor))
|
||||
self.assertTrue(isinstance(paths, ops.Tensor))
|
||||
self.assertTrue(isinstance(var, ops.Tensor))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user