Add a call_logit_fn utility for logit_fn's, similar to Estimator's _call_model_fn.

PiperOrigin-RevId: 165649388
This commit is contained in:
A. Unique TensorFlower 2017-08-17 17:03:26 -07:00 committed by TensorFlower Gardener
parent a92bd5d5cb
commit 9b9e5989d2
4 changed files with 111 additions and 1 deletions

View File

@ -119,6 +119,18 @@ py_test(
],
)
py_test(
name = "logit_fns_test",
size = "small",
srcs = ["python/learn/estimators/logit_fns_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow/python:client_testlib",
"//tensorflow/python/estimator:model_fn",
],
)
py_test(
name = "estimators_test",
size = "small",

View File

@ -321,6 +321,7 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassi
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import call_logit_fn
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import dnn_logit_fn_builder
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import linear_logit_fn_builder
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey

View File

@ -21,7 +21,7 @@ should follow the following signature:
Args:
`features`: This is the first item returned from the `input_fn` passed to
`train`, `evaluate`, and `predict`. This should be a single
`Tensor` or `dict` of same.
`Tensor` or `dict` of same, and is the only required argument.
`mode`: Optional. Specifies if this training, evaluation or prediction. See
`ModeKeys`.
`params`: Optional `dict` of hyperparameters. Will receive what is passed to
@ -39,10 +39,47 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import dnn as dnn_core
from tensorflow.python.estimator.canned import linear as linear_core
from tensorflow.python.framework import ops
# pylint: disable=protected-access
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
linear_logit_fn_builder = linear_core._linear_logit_fn_builder
# pylint: enable=protected-access
def call_logit_fn(logit_fn, features, mode, params, config):
"""Calls logit_fn.
A utility function that calls the provided logit_fn with the relevant subset
of provided arguments. Similar to tf.estimator._call_model_fn().
Args:
logit_fn: A logit_fn as defined above.
features: The features dict.
mode: TRAIN / EVAL / PREDICT ModeKeys.
params: The hyperparameter dict.
config: The configuration object.
Returns:
A logit Tensor, the output of logit_fn.
Raises:
ValueError: if logit_fn does not return a Tensor.
"""
logit_fn_args = util.fn_args(logit_fn)
kwargs = {}
if 'mode' in logit_fn_args:
kwargs['mode'] = mode
if 'params' in logit_fn_args:
kwargs['params'] = params
if 'config' in logit_fn_args:
kwargs['config'] = config
logit_fn_results = logit_fn(features=features, **kwargs)
if not isinstance(logit_fn_results, ops.Tensor):
raise ValueError('model_fn should return a Tensor.')
return logit_fn_results

View File

@ -0,0 +1,60 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""logit_fn tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.estimators import logit_fns
from tensorflow.python.client import session
from tensorflow.python.estimator import model_fn
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
class LogitFnTest(test.TestCase):
def test_simple_call_logit_fn(self):
def dummy_logit_fn(features, mode):
if mode == model_fn.ModeKeys.TRAIN:
return features['f1']
else:
return features['f2']
features = {
'f1': constant_op.constant([2., 3.]),
'f2': constant_op.constant([4., 5.])
}
logit_fn_result = logit_fns.call_logit_fn(
dummy_logit_fn, features, model_fn.ModeKeys.EVAL, 'fake_params',
'fake_config')
with session.Session():
self.assertAllClose([[4., 5.]], logit_fn_result.eval())
def test_should_return_tensor(self):
def invalid_logit_fn(features, params):
return {
'tensor1': features['f1'] * params['input_multiplier'],
'tensor2': features['f2'] * params['input_multiplier']
}
features = {
'f1': constant_op.constant([2., 3.]),
'f2': constant_op.constant([4., 5.])
}
params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
with self.assertRaisesRegexp(ValueError, 'model_fn should return a Tensor'):
logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
'fake_config')