Update and Move DNNLinearCombinedRegressor to estimator/canned.
PiperOrigin-RevId: 157744087
This commit is contained in:
parent
d29bbeca3d
commit
15a740ebbb
373
tensorflow/python/estimator/canned/dnn_linear_combined.py
Normal file
@@ -0,0 +1,373 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow estimators for Linear and DNN joined training models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import six

from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.summary import summary
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util


# The default learning rates are a historical artifact of the initial
# implementation, but seem a reasonable choice.
_DNN_LEARNING_RATE = 0.05
_LINEAR_LEARNING_RATE = 0.2


def _check_no_sync_replicas_optimizer(optimizer):
  if isinstance(optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
    raise ValueError(
        'SyncReplicasOptimizer does not support multi optimizers case. '
        'Therefore, it is not supported in DNNLinearCombined model. '
        'If you want to use this optimizer, please use either DNN or Linear '
        'model.')


def _linear_learning_rate(num_linear_feature_columns):
  """Returns the default learning rate of the linear model.

  The calculation is a historical artifact of this initial implementation, but
  has proven a reasonable choice.

  Args:
    num_linear_feature_columns: The number of feature columns of the linear
      model.

  Returns:
    A float.
  """
  default_learning_rate = 1. / math.sqrt(num_linear_feature_columns)
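  # For example, 100 linear feature columns yield
  # min(0.2, 1/sqrt(100)) = min(0.2, 0.1) = 0.1.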
  return min(_LINEAR_LEARNING_RATE, default_learning_rate)


def _add_layer_summary(value, tag):
  summary.scalar('%s/fraction_of_zero_values' % tag, nn.zero_fraction(value))
  summary.histogram('%s/activation' % tag, value)


def _dnn_linear_combined_model_fn(
    features, labels, mode, head,
    linear_feature_columns=None, linear_optimizer='Ftrl',
    dnn_feature_columns=None, dnn_optimizer='Adagrad', dnn_hidden_units=None,
    dnn_activation_fn=nn.relu, dnn_dropout=None,
    input_layer_partitioner=None, config=None):
  """Deep Neural Net and Linear combined model_fn.

  Args:
    features: `Tensor` or dict of `Tensor` (depends on data passed to `train`).
    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype
      `int32` or `int64` in the range `[0, n_classes)`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    head: A `Head` instance.
    linear_feature_columns: An iterable containing all the feature columns used
      by the Linear model.
    linear_optimizer: string, `Optimizer` object, or callable that defines the
      optimizer to use for training the Linear model. Defaults to the Ftrl
      optimizer.
    dnn_feature_columns: An iterable containing all the feature columns used by
      the DNN model.
    dnn_optimizer: string, `Optimizer` object, or callable that defines the
      optimizer to use for training the DNN model. Defaults to the Adagrad
      optimizer.
    dnn_hidden_units: List of hidden units per DNN layer.
    dnn_activation_fn: Activation function applied to each DNN layer. If
      `None`, will use `tf.nn.relu`.
    dnn_dropout: When not `None`, the probability we will drop out a given DNN
      coordinate.
    input_layer_partitioner: Partitioner for input layer.
    config: `RunConfig` object to configure the runtime settings.

  Returns:
    An `EstimatorSpec` instance.

  Raises:
    ValueError: If both `linear_feature_columns` and `dnn_feature_columns`
      are empty at the same time.
  """
  if not linear_feature_columns and not dnn_feature_columns:
    raise ValueError(
        'Either linear_feature_columns or dnn_feature_columns must be defined.')
  num_ps_replicas = config.num_ps_replicas if config else 0
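  # When no partitioner is given, fall back to splitting large input-layer
  # variables across up to num_ps_replicas parameter servers, with a minimum
  # slice of 64 << 20 bytes (64MB).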
  input_layer_partitioner = input_layer_partitioner or (
      partitioned_variables.min_max_variable_partitioner(
          max_partitions=num_ps_replicas,
          min_slice_size=64 << 20))

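  # Resolve the optimizer specs: a string name such as 'Ftrl' or 'Adagrad' is
  # instantiated with the learning rate given here, while an Optimizer
  # instance or callable is used as provided.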
  linear_optimizer = optimizers.get_optimizer_instance(
      linear_optimizer,
      learning_rate=_linear_learning_rate(len(linear_feature_columns)))
  _check_no_sync_replicas_optimizer(linear_optimizer)

  dnn_optimizer = optimizers.get_optimizer_instance(
      dnn_optimizer,
      learning_rate=_DNN_LEARNING_RATE)
  _check_no_sync_replicas_optimizer(dnn_optimizer)

  # Build DNN Logits.
  dnn_parent_scope = 'dnn'

  if not dnn_feature_columns:
    dnn_logits = None
  else:
    if not dnn_hidden_units:
      raise ValueError(
          'dnn_hidden_units must be defined when dnn_feature_columns is '
          'specified.')
    dnn_partitioner = (
        partitioned_variables.min_max_variable_partitioner(
            max_partitions=num_ps_replicas))
    with variable_scope.variable_scope(
        dnn_parent_scope,
        values=tuple(six.itervalues(features)),
        partitioner=dnn_partitioner):
      with variable_scope.variable_scope('input',
                                         partitioner=input_layer_partitioner):
        net = feature_column_lib.input_layer(
            features=features,
            feature_columns=dnn_feature_columns)

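      # Stack one fully connected layer per entry in dnn_hidden_units, each
      # followed by optional dropout (training only) and a layer summary.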
      for layer_id, num_hidden_units in enumerate(dnn_hidden_units):
        with variable_scope.variable_scope(
            'hiddenlayer_%d' % layer_id,
            values=(net,)) as dnn_hidden_layer_scope:
          net = core_layers.dense(
              net,
              units=num_hidden_units,
              activation=dnn_activation_fn,
              kernel_initializer=init_ops.glorot_uniform_initializer(),
              name=dnn_hidden_layer_scope)
          if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN:
            net = core_layers.dropout(net, rate=dnn_dropout, training=True)
        _add_layer_summary(net, dnn_hidden_layer_scope.name)

      with variable_scope.variable_scope(
          'logits',
          values=(net,)) as dnn_logits_scope:
        # Assign to dnn_logits (not logits) so the layer summary and the
        # combine step below see the DNN tower's output.
        dnn_logits = core_layers.dense(
            net,
            units=head.logits_dimension,
            activation=None,
            kernel_initializer=init_ops.glorot_uniform_initializer(),
            name=dnn_logits_scope)
      _add_layer_summary(dnn_logits, dnn_logits_scope.name)

  linear_parent_scope = 'linear'

  if not linear_feature_columns:
    linear_logits = None
  else:
    with variable_scope.variable_scope(
        linear_parent_scope,
        values=tuple(six.itervalues(features)),
        partitioner=input_layer_partitioner) as scope:
      linear_logits = feature_column_lib.linear_model(
          features=features,
          feature_columns=linear_feature_columns,
          units=head.logits_dimension)
      _add_layer_summary(linear_logits, scope.name)

  # Combine logits and build full model.
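  # Each tower emits logits of shape [batch_size, head.logits_dimension];
  # when both are present they are summed, so the wide and deep parts are
  # trained jointly against a single loss.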
  if dnn_logits is not None and linear_logits is not None:
    logits = dnn_logits + linear_logits
  elif dnn_logits is not None:
    logits = dnn_logits
  else:
    logits = linear_logits

  def _train_op_fn(loss):
    """Returns the op to optimize the loss."""
    train_ops = []
    global_step = training_util.get_global_step()
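    # Each optimizer minimizes the shared loss, but only over the variables in
    # its own scope; global_step is incremented once after both updates run.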
    if dnn_logits is not None:
      train_ops.append(
          dnn_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=dnn_parent_scope)))
    if linear_logits is not None:
      train_ops.append(
          linear_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=linear_parent_scope)))

    train_op = control_flow_ops.group(*train_ops)
    with ops.control_dependencies([train_op]):
      with ops.colocate_with(global_step):
        return state_ops.assign_add(global_step, 1)

  return head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)


class DNNLinearCombinedRegressor(estimator.Estimator):
  """An estimator for TensorFlow Linear and DNN joined models for regression.

  Note: This estimator is also known as wide-n-deep.

  Example:

  ```python
  numeric_feature = numeric_column(...)
  sparse_feature_a = categorical_column_with_hash_bucket(...)
  sparse_feature_b = categorical_column_with_hash_bucket(...)

  sparse_feature_a_x_sparse_feature_b = crossed_column(...)
  sparse_feature_a_emb = embedding_column(categorical_column=sparse_feature_a,
                                          ...)
  sparse_feature_b_emb = embedding_column(categorical_column=sparse_feature_b,
                                          ...)

  estimator = DNNLinearCombinedRegressor(
      # wide settings
      linear_feature_columns=[sparse_feature_a_x_sparse_feature_b],
      linear_optimizer=tf.train.FtrlOptimizer(...),
      # deep settings
      dnn_feature_columns=[
          sparse_feature_a_emb, sparse_feature_b_emb, numeric_feature],
      dnn_hidden_units=[1000, 500, 100],
      dnn_optimizer=tf.train.ProximalAdagradOptimizer(...))

  # To apply L1 and L2 regularization, you can set the optimizers as follows:
  tf.train.ProximalAdagradOptimizer(
      learning_rate=0.1,
      l1_regularization_strength=0.001,
      l2_regularization_strength=0.001)
  # The same regularization arguments work for FtrlOptimizer.

  # Input builders
  def input_fn_train():  # returns x, y
    pass
  estimator.train(input_fn=input_fn_train, steps=100)

  def input_fn_eval():  # returns x, y
    pass
  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)

  def input_fn_predict():  # returns x, None
    pass
  predictions = estimator.predict(input_fn=input_fn_predict)
  ```

  Input of `train` and `evaluate` should have the following features,
  otherwise there will be a `KeyError`:

  * for each `column` in `dnn_feature_columns` + `linear_feature_columns`:
    - if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
      whose `value` is a `SparseTensor`.
    - if `column` is a `_WeightedCategoricalColumn`, two features: the first
      with `key` the id column name, the second with `key` the weight column
      name. Both features' `value` must be a `SparseTensor`.
    - if `column` is a `_DenseColumn`, a feature with `key=column.name`
      whose `value` is a `Tensor`.
  """

  def __init__(self,
               model_dir=None,
               linear_feature_columns=None,
               linear_optimizer='Ftrl',
               dnn_feature_columns=None,
               dnn_optimizer='Adagrad',
               dnn_hidden_units=None,
               dnn_activation_fn=nn.relu,
               dnn_dropout=None,
               label_dimension=1,
               input_layer_partitioner=None,
               config=None):
    """Initializes a DNNLinearCombinedRegressor instance.

    Args:
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      linear_feature_columns: An iterable containing all the feature columns
        used by the linear part of the model. All items in the set must be
        instances of classes derived from `FeatureColumn`.
      linear_optimizer: string, `tf.Optimizer` object, or callable that defines
        the optimizer to use for training the linear part of the model.
        Defaults to the Ftrl optimizer.
      dnn_feature_columns: An iterable containing all the feature columns used
        by the deep part of the model. All items in the set must be instances
        of classes derived from `FeatureColumn`.
      dnn_optimizer: string, `tf.Optimizer` object, or callable that defines
        the optimizer to use for training the deep part of the model.
        Defaults to the Adagrad optimizer.
      dnn_hidden_units: List of hidden units per layer. All layers are fully
        connected.
      dnn_activation_fn: Activation function applied to each layer. If `None`,
        will use `tf.nn.relu`.
      dnn_dropout: When not `None`, the probability we will drop out a given
        coordinate.
      label_dimension: Number of regression targets per example. This is the
        size of the last dimension of the labels and logits `Tensor` objects
        (typically, these have shape `[batch_size, label_dimension]`).
      input_layer_partitioner: Partitioner for input layer. Defaults to
        `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
      config: `RunConfig` object to configure the runtime settings.

    Raises:
      ValueError: If both `linear_feature_columns` and `dnn_feature_columns`
        are empty at the same time.
    """
    linear_feature_columns = linear_feature_columns or []
    dnn_feature_columns = dnn_feature_columns or []
    self._feature_columns = linear_feature_columns + dnn_feature_columns
    if not self._feature_columns:
      raise ValueError('Either linear_feature_columns or dnn_feature_columns '
                       'must be defined.')

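    # The regressor binds a mean squared error regression head into the
    # shared DNN-linear model_fn.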
    def _model_fn(features, labels, mode, config):
      return _dnn_linear_combined_model_fn(
          features=features,
          labels=labels,
          mode=mode,
          head=head_lib._regression_head_with_mean_squared_error_loss(  # pylint: disable=protected-access
              label_dimension=label_dimension),
          linear_feature_columns=linear_feature_columns,
          linear_optimizer=linear_optimizer,
          dnn_feature_columns=dnn_feature_columns,
          dnn_optimizer=dnn_optimizer,
          dnn_hidden_units=dnn_hidden_units,
          dnn_activation_fn=dnn_activation_fn,
          dnn_dropout=dnn_dropout,
          input_layer_partitioner=input_layer_partitioner,
          config=config)

    super(DNNLinearCombinedRegressor, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config)