mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Expands integration tests in dnn_test.
PiperOrigin-RevId: 157476608
This commit is contained in:
parent
21461213dd
commit
6c3b15915d
|
|
@ -119,6 +119,7 @@ py_test(
|
|||
":metric_keys",
|
||||
":model_fn",
|
||||
":numpy_io",
|
||||
":pandas_io",
|
||||
":prediction_keys",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
|
@ -126,9 +127,11 @@ py_test(
|
|||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:summary",
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ import tempfile
|
|||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.core.framework import summary_pb2
|
||||
from tensorflow.python.client import session as tf_session
|
||||
from tensorflow.python.estimator import model_fn
|
||||
|
|
@ -34,25 +36,40 @@ from tensorflow.python.estimator.canned import metric_keys
|
|||
from tensorflow.python.estimator.canned import prediction_keys
|
||||
from tensorflow.python.estimator.export import export
|
||||
from tensorflow.python.estimator.inputs import numpy_io
|
||||
from tensorflow.python.estimator.inputs import pandas_io
|
||||
from tensorflow.python.feature_column import feature_column
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.summary import summary as summary_lib
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
from tensorflow.python.training import input as input_lib
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import optimizer
|
||||
from tensorflow.python.training import queue_runner
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
try:
|
||||
# pylint: disable=g-import-not-at-top
|
||||
import pandas as pd
|
||||
HAS_PANDAS = True
|
||||
except IOError:
|
||||
# Pandas writes a temporary file during import. If it fails, don't use pandas.
|
||||
HAS_PANDAS = False
|
||||
except ImportError:
|
||||
HAS_PANDAS = False
|
||||
|
||||
# Names of variables created by model.
|
||||
_LEARNING_RATE_NAME = 'dnn/regression_head/dnn/learning_rate'
|
||||
_HIDDEN_WEIGHTS_NAME_PATTERN = 'dnn/hiddenlayer_%d/kernel'
|
||||
|
|
@ -503,6 +520,22 @@ class DNNRegressorPredictTest(test.TestCase):
|
|||
}, next(dnn_regressor.predict(input_fn=input_fn)))
|
||||
|
||||
|
||||
def _queue_parsed_features(feature_map):
|
||||
tensors_to_enqueue = []
|
||||
keys = []
|
||||
for key, tensor in six.iteritems(feature_map):
|
||||
keys.append(key)
|
||||
tensors_to_enqueue.append(tensor)
|
||||
queue_dtypes = [x.dtype for x in tensors_to_enqueue]
|
||||
input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
|
||||
queue_runner.add_queue_runner(
|
||||
queue_runner.QueueRunner(
|
||||
input_queue,
|
||||
[input_queue.enqueue(tensors_to_enqueue)]))
|
||||
dequeued_tensors = input_queue.dequeue()
|
||||
return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
|
||||
|
||||
|
||||
class DNNRegressorIntegrationTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
|
@ -512,44 +545,27 @@ class DNNRegressorIntegrationTest(test.TestCase):
|
|||
if self._model_dir:
|
||||
shutil.rmtree(self._model_dir)
|
||||
|
||||
def test_complete_flow(self):
|
||||
label_dimension = 2
|
||||
batch_size = 10
|
||||
feature_columns = [feature_column.numeric_column('x', shape=(2,))]
|
||||
def _test_complete_flow(
|
||||
self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
|
||||
label_dimension, batch_size):
|
||||
feature_columns = [
|
||||
feature_column.numeric_column('x', shape=(input_dimension,))]
|
||||
est = dnn.DNNRegressor(
|
||||
hidden_units=(2, 2),
|
||||
feature_columns=feature_columns,
|
||||
label_dimension=label_dimension,
|
||||
model_dir=self._model_dir)
|
||||
data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
|
||||
data = data.reshape(batch_size, label_dimension)
|
||||
|
||||
# TRAIN
|
||||
# learn y = x
|
||||
train_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
y=data,
|
||||
batch_size=batch_size,
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
num_steps = 200
|
||||
num_steps = 10
|
||||
est.train(train_input_fn, steps=num_steps)
|
||||
|
||||
# EVALUTE
|
||||
eval_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
y=data,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
scores = est.evaluate(eval_input_fn)
|
||||
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
|
||||
self.assertIn('loss', six.iterkeys(scores))
|
||||
|
||||
# PREDICT
|
||||
predict_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
predictions = np.array([
|
||||
x[prediction_keys.PredictionKeys.PREDICTIONS]
|
||||
for x in est.predict(predict_input_fn)
|
||||
|
|
@ -564,6 +580,120 @@ class DNNRegressorIntegrationTest(test.TestCase):
|
|||
serving_input_receiver_fn)
|
||||
self.assertTrue(gfile.Exists(export_dir))
|
||||
|
||||
def test_numpy_input_fn(self):
|
||||
"""Tests complete flow with numpy_input_fn."""
|
||||
label_dimension = 2
|
||||
batch_size = 10
|
||||
data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
|
||||
data = data.reshape(batch_size, label_dimension)
|
||||
# learn y = x
|
||||
train_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
y=data,
|
||||
batch_size=batch_size,
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
eval_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
y=data,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
predict_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
self._test_complete_flow(
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
predict_input_fn=predict_input_fn,
|
||||
input_dimension=label_dimension,
|
||||
label_dimension=label_dimension,
|
||||
batch_size=batch_size)
|
||||
|
||||
def test_pandas_input_fn(self):
|
||||
"""Tests complete flow with pandas_input_fn."""
|
||||
if not HAS_PANDAS:
|
||||
return
|
||||
label_dimension = 1
|
||||
batch_size = 10
|
||||
data = np.linspace(0., 2., batch_size, dtype=np.float32)
|
||||
x = pd.DataFrame({'x': data})
|
||||
y = pd.Series(data)
|
||||
train_input_fn = pandas_io.pandas_input_fn(
|
||||
x=x,
|
||||
y=y,
|
||||
batch_size=batch_size,
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
eval_input_fn = pandas_io.pandas_input_fn(
|
||||
x=x,
|
||||
y=y,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
predict_input_fn = pandas_io.pandas_input_fn(
|
||||
x=x,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
self._test_complete_flow(
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
predict_input_fn=predict_input_fn,
|
||||
input_dimension=label_dimension,
|
||||
label_dimension=label_dimension,
|
||||
batch_size=batch_size)
|
||||
|
||||
def test_input_fn_from_parse_example(self):
|
||||
"""Tests complete flow with input_fn constructed from parse_example."""
|
||||
label_dimension = 2
|
||||
batch_size = 10
|
||||
data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
|
||||
data = data.reshape(batch_size, label_dimension)
|
||||
|
||||
serialized_examples = []
|
||||
for datum in data:
|
||||
example = example_pb2.Example(features=feature_pb2.Features(
|
||||
feature={
|
||||
'x': feature_pb2.Feature(
|
||||
float_list=feature_pb2.FloatList(value=datum)),
|
||||
'y': feature_pb2.Feature(
|
||||
float_list=feature_pb2.FloatList(value=datum)),
|
||||
}))
|
||||
serialized_examples.append(example.SerializeToString())
|
||||
|
||||
feature_spec = {
|
||||
'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
|
||||
'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
|
||||
}
|
||||
def _train_input_fn():
|
||||
feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
|
||||
features = _queue_parsed_features(feature_map)
|
||||
labels = features.pop('y')
|
||||
return features, labels
|
||||
def _eval_input_fn():
|
||||
feature_map = parsing_ops.parse_example(
|
||||
input_lib.limit_epochs(serialized_examples, num_epochs=1),
|
||||
feature_spec)
|
||||
features = _queue_parsed_features(feature_map)
|
||||
labels = features.pop('y')
|
||||
return features, labels
|
||||
def _predict_input_fn():
|
||||
feature_map = parsing_ops.parse_example(
|
||||
input_lib.limit_epochs(serialized_examples, num_epochs=1),
|
||||
feature_spec)
|
||||
features = _queue_parsed_features(feature_map)
|
||||
features.pop('y')
|
||||
return features, None
|
||||
|
||||
self._test_complete_flow(
|
||||
train_input_fn=_train_input_fn,
|
||||
eval_input_fn=_eval_input_fn,
|
||||
predict_input_fn=_predict_input_fn,
|
||||
input_dimension=label_dimension,
|
||||
label_dimension=label_dimension,
|
||||
batch_size=batch_size)
|
||||
|
||||
|
||||
def _full_var_name(var_name):
|
||||
return '%s/part_0:0' % var_name
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user