mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Expands integration tests in dnn_test.
PiperOrigin-RevId: 157476608
This commit is contained in:
parent
21461213dd
commit
6c3b15915d
|
|
@ -119,6 +119,7 @@ py_test(
|
||||||
":metric_keys",
|
":metric_keys",
|
||||||
":model_fn",
|
":model_fn",
|
||||||
":numpy_io",
|
":numpy_io",
|
||||||
|
":pandas_io",
|
||||||
":prediction_keys",
|
":prediction_keys",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
|
|
@ -126,9 +127,11 @@ py_test(
|
||||||
"//tensorflow/python:client",
|
"//tensorflow/python:client",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:constant_op",
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:data_flow_ops",
|
||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:parsing_ops",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:state_ops",
|
"//tensorflow/python:state_ops",
|
||||||
"//tensorflow/python:summary",
|
"//tensorflow/python:summary",
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,8 @@ import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from tensorflow.core.example import example_pb2
|
||||||
|
from tensorflow.core.example import feature_pb2
|
||||||
from tensorflow.core.framework import summary_pb2
|
from tensorflow.core.framework import summary_pb2
|
||||||
from tensorflow.python.client import session as tf_session
|
from tensorflow.python.client import session as tf_session
|
||||||
from tensorflow.python.estimator import model_fn
|
from tensorflow.python.estimator import model_fn
|
||||||
|
|
@ -34,25 +36,40 @@ from tensorflow.python.estimator.canned import metric_keys
|
||||||
from tensorflow.python.estimator.canned import prediction_keys
|
from tensorflow.python.estimator.canned import prediction_keys
|
||||||
from tensorflow.python.estimator.export import export
|
from tensorflow.python.estimator.export import export
|
||||||
from tensorflow.python.estimator.inputs import numpy_io
|
from tensorflow.python.estimator.inputs import numpy_io
|
||||||
|
from tensorflow.python.estimator.inputs import pandas_io
|
||||||
from tensorflow.python.feature_column import feature_column
|
from tensorflow.python.feature_column import feature_column
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
|
from tensorflow.python.ops import data_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import parsing_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variables as variables_lib
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
from tensorflow.python.platform import gfile
|
from tensorflow.python.platform import gfile
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.summary import summary as summary_lib
|
from tensorflow.python.summary import summary as summary_lib
|
||||||
from tensorflow.python.training import checkpoint_utils
|
from tensorflow.python.training import checkpoint_utils
|
||||||
|
from tensorflow.python.training import input as input_lib
|
||||||
from tensorflow.python.training import monitored_session
|
from tensorflow.python.training import monitored_session
|
||||||
from tensorflow.python.training import optimizer
|
from tensorflow.python.training import optimizer
|
||||||
|
from tensorflow.python.training import queue_runner
|
||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training import saver
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
|
|
||||||
|
try:
|
||||||
|
# pylint: disable=g-import-not-at-top
|
||||||
|
import pandas as pd
|
||||||
|
HAS_PANDAS = True
|
||||||
|
except IOError:
|
||||||
|
# Pandas writes a temporary file during import. If it fails, don't use pandas.
|
||||||
|
HAS_PANDAS = False
|
||||||
|
except ImportError:
|
||||||
|
HAS_PANDAS = False
|
||||||
|
|
||||||
# Names of variables created by model.
|
# Names of variables created by model.
|
||||||
_LEARNING_RATE_NAME = 'dnn/regression_head/dnn/learning_rate'
|
_LEARNING_RATE_NAME = 'dnn/regression_head/dnn/learning_rate'
|
||||||
_HIDDEN_WEIGHTS_NAME_PATTERN = 'dnn/hiddenlayer_%d/kernel'
|
_HIDDEN_WEIGHTS_NAME_PATTERN = 'dnn/hiddenlayer_%d/kernel'
|
||||||
|
|
@ -503,6 +520,22 @@ class DNNRegressorPredictTest(test.TestCase):
|
||||||
}, next(dnn_regressor.predict(input_fn=input_fn)))
|
}, next(dnn_regressor.predict(input_fn=input_fn)))
|
||||||
|
|
||||||
|
|
||||||
|
def _queue_parsed_features(feature_map):
|
||||||
|
tensors_to_enqueue = []
|
||||||
|
keys = []
|
||||||
|
for key, tensor in six.iteritems(feature_map):
|
||||||
|
keys.append(key)
|
||||||
|
tensors_to_enqueue.append(tensor)
|
||||||
|
queue_dtypes = [x.dtype for x in tensors_to_enqueue]
|
||||||
|
input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
|
||||||
|
queue_runner.add_queue_runner(
|
||||||
|
queue_runner.QueueRunner(
|
||||||
|
input_queue,
|
||||||
|
[input_queue.enqueue(tensors_to_enqueue)]))
|
||||||
|
dequeued_tensors = input_queue.dequeue()
|
||||||
|
return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
|
||||||
|
|
||||||
|
|
||||||
class DNNRegressorIntegrationTest(test.TestCase):
|
class DNNRegressorIntegrationTest(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
@ -512,44 +545,27 @@ class DNNRegressorIntegrationTest(test.TestCase):
|
||||||
if self._model_dir:
|
if self._model_dir:
|
||||||
shutil.rmtree(self._model_dir)
|
shutil.rmtree(self._model_dir)
|
||||||
|
|
||||||
def test_complete_flow(self):
|
def _test_complete_flow(
|
||||||
label_dimension = 2
|
self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
|
||||||
batch_size = 10
|
label_dimension, batch_size):
|
||||||
feature_columns = [feature_column.numeric_column('x', shape=(2,))]
|
feature_columns = [
|
||||||
|
feature_column.numeric_column('x', shape=(input_dimension,))]
|
||||||
est = dnn.DNNRegressor(
|
est = dnn.DNNRegressor(
|
||||||
hidden_units=(2, 2),
|
hidden_units=(2, 2),
|
||||||
feature_columns=feature_columns,
|
feature_columns=feature_columns,
|
||||||
label_dimension=label_dimension,
|
label_dimension=label_dimension,
|
||||||
model_dir=self._model_dir)
|
model_dir=self._model_dir)
|
||||||
data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
|
|
||||||
data = data.reshape(batch_size, label_dimension)
|
|
||||||
|
|
||||||
# TRAIN
|
# TRAIN
|
||||||
# learn y = x
|
num_steps = 10
|
||||||
train_input_fn = numpy_io.numpy_input_fn(
|
|
||||||
x={'x': data},
|
|
||||||
y=data,
|
|
||||||
batch_size=batch_size,
|
|
||||||
num_epochs=None,
|
|
||||||
shuffle=True)
|
|
||||||
num_steps = 200
|
|
||||||
est.train(train_input_fn, steps=num_steps)
|
est.train(train_input_fn, steps=num_steps)
|
||||||
|
|
||||||
# EVALUTE
|
# EVALUTE
|
||||||
eval_input_fn = numpy_io.numpy_input_fn(
|
|
||||||
x={'x': data},
|
|
||||||
y=data,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=False)
|
|
||||||
scores = est.evaluate(eval_input_fn)
|
scores = est.evaluate(eval_input_fn)
|
||||||
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
|
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
|
||||||
self.assertIn('loss', six.iterkeys(scores))
|
self.assertIn('loss', six.iterkeys(scores))
|
||||||
|
|
||||||
# PREDICT
|
# PREDICT
|
||||||
predict_input_fn = numpy_io.numpy_input_fn(
|
|
||||||
x={'x': data},
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=False)
|
|
||||||
predictions = np.array([
|
predictions = np.array([
|
||||||
x[prediction_keys.PredictionKeys.PREDICTIONS]
|
x[prediction_keys.PredictionKeys.PREDICTIONS]
|
||||||
for x in est.predict(predict_input_fn)
|
for x in est.predict(predict_input_fn)
|
||||||
|
|
@ -564,6 +580,120 @@ class DNNRegressorIntegrationTest(test.TestCase):
|
||||||
serving_input_receiver_fn)
|
serving_input_receiver_fn)
|
||||||
self.assertTrue(gfile.Exists(export_dir))
|
self.assertTrue(gfile.Exists(export_dir))
|
||||||
|
|
||||||
|
def test_numpy_input_fn(self):
|
||||||
|
"""Tests complete flow with numpy_input_fn."""
|
||||||
|
label_dimension = 2
|
||||||
|
batch_size = 10
|
||||||
|
data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
|
||||||
|
data = data.reshape(batch_size, label_dimension)
|
||||||
|
# learn y = x
|
||||||
|
train_input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x={'x': data},
|
||||||
|
y=data,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_epochs=None,
|
||||||
|
shuffle=True)
|
||||||
|
eval_input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x={'x': data},
|
||||||
|
y=data,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False)
|
||||||
|
predict_input_fn = numpy_io.numpy_input_fn(
|
||||||
|
x={'x': data},
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False)
|
||||||
|
|
||||||
|
self._test_complete_flow(
|
||||||
|
train_input_fn=train_input_fn,
|
||||||
|
eval_input_fn=eval_input_fn,
|
||||||
|
predict_input_fn=predict_input_fn,
|
||||||
|
input_dimension=label_dimension,
|
||||||
|
label_dimension=label_dimension,
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
def test_pandas_input_fn(self):
|
||||||
|
"""Tests complete flow with pandas_input_fn."""
|
||||||
|
if not HAS_PANDAS:
|
||||||
|
return
|
||||||
|
label_dimension = 1
|
||||||
|
batch_size = 10
|
||||||
|
data = np.linspace(0., 2., batch_size, dtype=np.float32)
|
||||||
|
x = pd.DataFrame({'x': data})
|
||||||
|
y = pd.Series(data)
|
||||||
|
train_input_fn = pandas_io.pandas_input_fn(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_epochs=None,
|
||||||
|
shuffle=True)
|
||||||
|
eval_input_fn = pandas_io.pandas_input_fn(
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False)
|
||||||
|
predict_input_fn = pandas_io.pandas_input_fn(
|
||||||
|
x=x,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False)
|
||||||
|
|
||||||
|
self._test_complete_flow(
|
||||||
|
train_input_fn=train_input_fn,
|
||||||
|
eval_input_fn=eval_input_fn,
|
||||||
|
predict_input_fn=predict_input_fn,
|
||||||
|
input_dimension=label_dimension,
|
||||||
|
label_dimension=label_dimension,
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
def test_input_fn_from_parse_example(self):
|
||||||
|
"""Tests complete flow with input_fn constructed from parse_example."""
|
||||||
|
label_dimension = 2
|
||||||
|
batch_size = 10
|
||||||
|
data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
|
||||||
|
data = data.reshape(batch_size, label_dimension)
|
||||||
|
|
||||||
|
serialized_examples = []
|
||||||
|
for datum in data:
|
||||||
|
example = example_pb2.Example(features=feature_pb2.Features(
|
||||||
|
feature={
|
||||||
|
'x': feature_pb2.Feature(
|
||||||
|
float_list=feature_pb2.FloatList(value=datum)),
|
||||||
|
'y': feature_pb2.Feature(
|
||||||
|
float_list=feature_pb2.FloatList(value=datum)),
|
||||||
|
}))
|
||||||
|
serialized_examples.append(example.SerializeToString())
|
||||||
|
|
||||||
|
feature_spec = {
|
||||||
|
'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
|
||||||
|
'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
|
||||||
|
}
|
||||||
|
def _train_input_fn():
|
||||||
|
feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
|
||||||
|
features = _queue_parsed_features(feature_map)
|
||||||
|
labels = features.pop('y')
|
||||||
|
return features, labels
|
||||||
|
def _eval_input_fn():
|
||||||
|
feature_map = parsing_ops.parse_example(
|
||||||
|
input_lib.limit_epochs(serialized_examples, num_epochs=1),
|
||||||
|
feature_spec)
|
||||||
|
features = _queue_parsed_features(feature_map)
|
||||||
|
labels = features.pop('y')
|
||||||
|
return features, labels
|
||||||
|
def _predict_input_fn():
|
||||||
|
feature_map = parsing_ops.parse_example(
|
||||||
|
input_lib.limit_epochs(serialized_examples, num_epochs=1),
|
||||||
|
feature_spec)
|
||||||
|
features = _queue_parsed_features(feature_map)
|
||||||
|
features.pop('y')
|
||||||
|
return features, None
|
||||||
|
|
||||||
|
self._test_complete_flow(
|
||||||
|
train_input_fn=_train_input_fn,
|
||||||
|
eval_input_fn=_eval_input_fn,
|
||||||
|
predict_input_fn=_predict_input_fn,
|
||||||
|
input_dimension=label_dimension,
|
||||||
|
label_dimension=label_dimension,
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
|
||||||
def _full_var_name(var_name):
|
def _full_var_name(var_name):
|
||||||
return '%s/part_0:0' % var_name
|
return '%s/part_0:0' % var_name
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user