mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add a call_logit_fn utility for logit_fn's, similar to Estimator's _call_model_fn.
PiperOrigin-RevId: 165649388
This commit is contained in:
parent
a92bd5d5cb
commit
9b9e5989d2
|
|
@ -119,6 +119,18 @@ py_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "logit_fns_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["python/learn/estimators/logit_fns_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":learn",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/estimator:model_fn",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "estimators_test",
|
name = "estimators_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
|
|
||||||
|
|
@ -321,6 +321,7 @@ from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassi
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
|
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import call_logit_fn
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import dnn_logit_fn_builder
|
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import dnn_logit_fn_builder
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import linear_logit_fn_builder
|
from tensorflow.contrib.learn.python.learn.estimators.logit_fns import linear_logit_fn_builder
|
||||||
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
|
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ should follow the following signature:
|
||||||
Args:
|
Args:
|
||||||
`features`: This is the first item returned from the `input_fn` passed to
|
`features`: This is the first item returned from the `input_fn` passed to
|
||||||
`train`, `evaluate`, and `predict`. This should be a single
|
`train`, `evaluate`, and `predict`. This should be a single
|
||||||
`Tensor` or `dict` of same.
|
`Tensor` or `dict` of same, and is the only required argument.
|
||||||
`mode`: Optional. Specifies if this training, evaluation or prediction. See
|
`mode`: Optional. Specifies if this training, evaluation or prediction. See
|
||||||
`ModeKeys`.
|
`ModeKeys`.
|
||||||
`params`: Optional `dict` of hyperparameters. Will receive what is passed to
|
`params`: Optional `dict` of hyperparameters. Will receive what is passed to
|
||||||
|
|
@ -39,10 +39,47 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.estimator import util
|
||||||
from tensorflow.python.estimator.canned import dnn as dnn_core
|
from tensorflow.python.estimator.canned import dnn as dnn_core
|
||||||
from tensorflow.python.estimator.canned import linear as linear_core
|
from tensorflow.python.estimator.canned import linear as linear_core
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
|
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
|
||||||
linear_logit_fn_builder = linear_core._linear_logit_fn_builder
|
linear_logit_fn_builder = linear_core._linear_logit_fn_builder
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
def call_logit_fn(logit_fn, features, mode, params, config):
|
||||||
|
"""Calls logit_fn.
|
||||||
|
|
||||||
|
A utility function that calls the provided logit_fn with the relevant subset
|
||||||
|
of provided arguments. Similar to tf.estimator._call_model_fn().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logit_fn: A logit_fn as defined above.
|
||||||
|
features: The features dict.
|
||||||
|
mode: TRAIN / EVAL / PREDICT ModeKeys.
|
||||||
|
params: The hyperparameter dict.
|
||||||
|
config: The configuration object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A logit Tensor, the output of logit_fn.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if logit_fn does not return a Tensor.
|
||||||
|
"""
|
||||||
|
logit_fn_args = util.fn_args(logit_fn)
|
||||||
|
kwargs = {}
|
||||||
|
if 'mode' in logit_fn_args:
|
||||||
|
kwargs['mode'] = mode
|
||||||
|
if 'params' in logit_fn_args:
|
||||||
|
kwargs['params'] = params
|
||||||
|
if 'config' in logit_fn_args:
|
||||||
|
kwargs['config'] = config
|
||||||
|
logit_fn_results = logit_fn(features=features, **kwargs)
|
||||||
|
|
||||||
|
if not isinstance(logit_fn_results, ops.Tensor):
|
||||||
|
raise ValueError('model_fn should return a Tensor.')
|
||||||
|
|
||||||
|
return logit_fn_results
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""logit_fn tests."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.learn.python.learn.estimators import logit_fns
|
||||||
|
from tensorflow.python.client import session
|
||||||
|
from tensorflow.python.estimator import model_fn
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class LogitFnTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_simple_call_logit_fn(self):
|
||||||
|
def dummy_logit_fn(features, mode):
|
||||||
|
if mode == model_fn.ModeKeys.TRAIN:
|
||||||
|
return features['f1']
|
||||||
|
else:
|
||||||
|
return features['f2']
|
||||||
|
features = {
|
||||||
|
'f1': constant_op.constant([2., 3.]),
|
||||||
|
'f2': constant_op.constant([4., 5.])
|
||||||
|
}
|
||||||
|
logit_fn_result = logit_fns.call_logit_fn(
|
||||||
|
dummy_logit_fn, features, model_fn.ModeKeys.EVAL, 'fake_params',
|
||||||
|
'fake_config')
|
||||||
|
with session.Session():
|
||||||
|
self.assertAllClose([[4., 5.]], logit_fn_result.eval())
|
||||||
|
|
||||||
|
def test_should_return_tensor(self):
|
||||||
|
|
||||||
|
def invalid_logit_fn(features, params):
|
||||||
|
return {
|
||||||
|
'tensor1': features['f1'] * params['input_multiplier'],
|
||||||
|
'tensor2': features['f2'] * params['input_multiplier']
|
||||||
|
}
|
||||||
|
features = {
|
||||||
|
'f1': constant_op.constant([2., 3.]),
|
||||||
|
'f2': constant_op.constant([4., 5.])
|
||||||
|
}
|
||||||
|
params = {'learning_rate': 0.001, 'input_multiplier': 2.0}
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'model_fn should return a Tensor'):
|
||||||
|
logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,
|
||||||
|
'fake_config')
|
||||||
Loading…
Reference in New Issue
Block a user