mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add a call_logit_fn utility for logit_fn's, similar to Estimator's _call_model_fn.
PiperOrigin-RevId: 165649388
This commit is contained in:
parent
a92bd5d5cb
commit
9b9e5989d2
|
|
@ -119,6 +119,18 @@ py_test(
|
|||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "logit_fns_test",
|
||||
size = "small",
|
||||
srcs = ["python/learn/estimators/logit_fns_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "estimators_test",
|
||||
size = "small",
|
||||
|
|
|
|||
|
|
@ -321,6 +321,7 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassi
|
|||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import call_logit_fn
|
||||
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import dnn_logit_fn_builder
|
||||
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import linear_logit_fn_builder
|
||||
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ should follow the following signature:
|
|||
Args:
|
||||
`features`: This is the first item returned from the `input_fn` passed to
|
||||
`train`, `evaluate`, and `predict`. This should be a single
|
||||
`Tensor` or `dict` of same.
|
||||
`Tensor` or `dict` of same, and is the only required argument.
|
||||
`mode`: Optional. Specifies if this training, evaluation or prediction. See
|
||||
`ModeKeys`.
|
||||
`params`: Optional `dict` of hyperparameters. Will receive what is passed to
|
||||
|
|
@ -39,10 +39,47 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.estimator import util
|
||||
from tensorflow.python.estimator.canned import dnn as dnn_core
|
||||
from tensorflow.python.estimator.canned import linear as linear_core
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
# pylint: disable=protected-access
|
||||
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
|
||||
linear_logit_fn_builder = linear_core._linear_logit_fn_builder
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def call_logit_fn(logit_fn, features, mode, params, config):
|
||||
"""Calls logit_fn.
|
||||
|
||||
A utility function that calls the provided logit_fn with the relevant subset
|
||||
of provided arguments. Similar to tf.estimator._call_model_fn().
|
||||
|
||||
Args:
|
||||
logit_fn: A logit_fn as defined above.
|
||||
features: The features dict.
|
||||
mode: TRAIN / EVAL / PREDICT ModeKeys.
|
||||
params: The hyperparameter dict.
|
||||
config: The configuration object.
|
||||
|
||||
Returns:
|
||||
A logit Tensor, the output of logit_fn.
|
||||
|
||||
Raises:
|
||||
ValueError: if logit_fn does not return a Tensor.
|
||||
"""
|
||||
logit_fn_args = util.fn_args(logit_fn)
|
||||
kwargs = {}
|
||||
if 'mode' in logit_fn_args:
|
||||
kwargs['mode'] = mode
|
||||
if 'params' in logit_fn_args:
|
||||
kwargs['params'] = params
|
||||
if 'config' in logit_fn_args:
|
||||
kwargs['config'] = config
|
||||
logit_fn_results = logit_fn(features=features, **kwargs)
|
||||
|
||||
if not isinstance(logit_fn_results, ops.Tensor):
|
||||
raise ValueError('model_fn should return a Tensor.')
|
||||
|
||||
return logit_fn_results
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""logit_fn tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.estimators import logit_fns
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.estimator import model_fn
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class LogitFnTest(test.TestCase):
|
||||
|
||||
def test_simple_call_logit_fn(self):
|
||||
def dummy_logit_fn(features, mode):
|
||||
if mode == model_fn.ModeKeys.TRAIN:
|
||||
return features['f1']
|
||||
else:
|
||||
return features['f2']
|
||||
features = {
|
||||
'f1': constant_op.constant([2., 3.]),
|
||||
'f2': constant_op.constant([4., 5.])
|
||||
}
|
||||
logit_fn_result = logit_fns.call_logit_fn(
|
||||
dummy_logit_fn, features, model_fn.ModeKeys.EVAL, 'fake_params',
|
||||
'fake_config')
|
||||
with session.Session():
|
||||
self.assertAllClose([[4., 5.]], logit_fn_result.eval())
|
||||
|
||||
def test_should_return_tensor(self):
|
||||
|
||||
def invalid_logit_fn(features, params):
|
||||
return {
|
||||
'tensor1': features['f1'] * params['input_multiplier'],
|
||||
'tensor2': features['f2'] * params['input_multiplier']
|
||||
}
|
||||
features = {
|
||||
'f1': constant_op.constant([2., 3.]),
|
||||
'f2': constant_op.constant([4., 5.])
|
||||
}
|
||||
params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
|
||||
with self.assertRaisesRegexp(ValueError, 'model_fn should return a Tensor'):
|
||||
logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
|
||||
'fake_config')
|
||||
Loading…
Reference in New Issue
Block a user