Add training test for multi classes (n>2) linear classifier.

PiperOrigin-RevId: 157896002
This commit is contained in:
Jianwei Xie 2017-06-02 17:45:41 -07:00 committed by TensorFlower Gardener
parent 675d36be0d
commit 1c70fb6869

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import shutil
import tempfile
@ -648,7 +649,7 @@ class LinearRegressorTrainingTest(test.TestCase):
if self._model_dir:
shutil.rmtree(self._model_dir)
def _mockOptimizer(self, expected_loss=None):
def _mock_optimizer(self, expected_loss=None):
expected_var_names = [
'%s/part_0:0' % _AGE_WEIGHT_NAME,
'%s/part_0:0' % _BIAS_NAME
@ -680,7 +681,7 @@ class LinearRegressorTrainingTest(test.TestCase):
mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
return mock_optimizer
def _assertCheckpoint(
def _assert_checkpoint(
self, expected_global_step, expected_age_weight=None, expected_bias=None):
shapes = {
name: shape for (name, shape) in
@ -717,7 +718,7 @@ class LinearRegressorTrainingTest(test.TestCase):
num_steps = 10
linear_regressor.train(
input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
self._assertCheckpoint(num_steps)
self._assert_checkpoint(num_steps)
def testTrainWithOneDimLabel(self):
label_dimension = 1
@ -736,7 +737,7 @@ class LinearRegressorTrainingTest(test.TestCase):
batch_size=batch_size, num_epochs=None,
shuffle=True)
est.train(train_input_fn, steps=200)
self._assertCheckpoint(200)
self._assert_checkpoint(200)
def testTrainWithOneDimWeight(self):
label_dimension = 1
@ -757,14 +758,14 @@ class LinearRegressorTrainingTest(test.TestCase):
batch_size=batch_size, num_epochs=None,
shuffle=True)
est.train(train_input_fn, steps=200)
self._assertCheckpoint(200)
self._assert_checkpoint(200)
def testFromScratch(self):
# Create LinearRegressor.
label = 5.
age = 17
# loss = (logits - label)^2 = (0 - 5.)^2 = 25.
mock_optimizer = self._mockOptimizer(expected_loss=25.)
mock_optimizer = self._mock_optimizer(expected_loss=25.)
linear_regressor = linear.LinearRegressor(
feature_columns=(feature_column_lib.numeric_column('age'),),
model_dir=self._model_dir, optimizer=mock_optimizer)
@ -775,7 +776,7 @@ class LinearRegressorTrainingTest(test.TestCase):
linear_regressor.train(
input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
self.assertEqual(1, mock_optimizer.minimize.call_count)
self._assertCheckpoint(
self._assert_checkpoint(
expected_global_step=num_steps,
expected_age_weight=0.,
expected_bias=0.)
@ -795,7 +796,7 @@ class LinearRegressorTrainingTest(test.TestCase):
# logits = age * age_weight + bias = 17 * 10. + 5. = 175
# loss = (logits - label)^2 = (175 - 5)^2 = 28900
mock_optimizer = self._mockOptimizer(expected_loss=28900.)
mock_optimizer = self._mock_optimizer(expected_loss=28900.)
linear_regressor = linear.LinearRegressor(
feature_columns=(feature_column_lib.numeric_column('age'),),
model_dir=self._model_dir, optimizer=mock_optimizer)
@ -806,7 +807,7 @@ class LinearRegressorTrainingTest(test.TestCase):
linear_regressor.train(
input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps)
self.assertEqual(1, mock_optimizer.minimize.call_count)
self._assertCheckpoint(
self._assert_checkpoint(
expected_global_step=initial_global_step + num_steps,
expected_age_weight=age_weight,
expected_bias=bias)
@ -828,7 +829,7 @@ class LinearRegressorTrainingTest(test.TestCase):
# logits[0] = 17 * 10. + 5. = 175
# logits[1] = 15 * 10. + 5. = 155
# loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004
mock_optimizer = self._mockOptimizer(expected_loss=52004.)
mock_optimizer = self._mock_optimizer(expected_loss=52004.)
linear_regressor = linear.LinearRegressor(
feature_columns=(feature_column_lib.numeric_column('age'),),
model_dir=self._model_dir, optimizer=mock_optimizer)
@ -840,13 +841,18 @@ class LinearRegressorTrainingTest(test.TestCase):
input_fn=lambda: ({'age': ((17,), (15,))}, ((5.,), (3.,))),
steps=num_steps)
self.assertEqual(1, mock_optimizer.minimize.call_count)
self._assertCheckpoint(
self._assert_checkpoint(
expected_global_step=initial_global_step + num_steps,
expected_age_weight=age_weight,
expected_bias=bias)
class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
class _BaseLinearClassiferTrainingTest(object):
def __init__(self, n_classes):
self._n_classes = n_classes
self._logits_dimensions = (
self._n_classes if self._n_classes > 2 else 1)
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@ -855,7 +861,7 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
if self._model_dir:
shutil.rmtree(self._model_dir)
def _mockOptimizer(self, expected_loss=None):
def _mock_optimizer(self, expected_loss=None):
expected_var_names = [
'%s/part_0:0' % _AGE_WEIGHT_NAME,
'%s/part_0:0' % _BIAS_NAME
@ -887,8 +893,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
return mock_optimizer
def _assertCheckpoint(
def _assert_checkpoint(
self, expected_global_step, expected_age_weight=None, expected_bias=None):
logits_dimension = self._logits_dimensions
shapes = {
name: shape for (name, shape) in
checkpoint_utils.list_variables(self._model_dir)
@ -900,20 +908,20 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
checkpoint_utils.load_variable(
self._model_dir, ops.GraphKeys.GLOBAL_STEP))
self.assertEqual([1, 1], shapes[_AGE_WEIGHT_NAME])
self.assertEqual([1, logits_dimension], shapes[_AGE_WEIGHT_NAME])
if expected_age_weight is not None:
self.assertEqual(
self.assertAllEqual(
expected_age_weight,
checkpoint_utils.load_variable(self._model_dir, _AGE_WEIGHT_NAME))
self.assertEqual([1], shapes[_BIAS_NAME])
self.assertEqual([logits_dimension], shapes[_BIAS_NAME])
if expected_bias is not None:
self.assertEqual(
self.assertAllEqual(
expected_bias,
checkpoint_utils.load_variable(self._model_dir, _BIAS_NAME))
def testFromScratchWithDefaultOptimizer(self):
n_classes = 2
n_classes = self._n_classes
label = 0
age = 17
est = linear.LinearClassifier(
@ -925,10 +933,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
num_steps = 10
est.train(
input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
self._assertCheckpoint(num_steps)
self._assert_checkpoint(num_steps)
def testTrainWithTwoDimsLabel(self):
n_classes = 2
n_classes = self._n_classes
batch_size = 20
est = linear.LinearClassifier(
@ -947,10 +955,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
num_epochs=None,
shuffle=True)
est.train(train_input_fn, steps=200)
self._assertCheckpoint(200)
self._assert_checkpoint(200)
def testTrainWithOneDimLabel(self):
n_classes = 2
n_classes = self._n_classes
batch_size = 20
est = linear.LinearClassifier(
@ -967,10 +975,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
num_epochs=None,
shuffle=True)
est.train(train_input_fn, steps=200)
self._assertCheckpoint(200)
self._assert_checkpoint(200)
def testTrainWithTwoDimsWeight(self):
n_classes = 2
n_classes = self._n_classes
batch_size = 20
est = linear.LinearClassifier(
@ -988,10 +996,10 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
batch_size=batch_size, num_epochs=None,
shuffle=True)
est.train(train_input_fn, steps=200)
self._assertCheckpoint(200)
self._assert_checkpoint(200)
def testTrainWithOneDimWeight(self):
n_classes = 2
n_classes = self._n_classes
batch_size = 20
est = linear.LinearClassifier(
@ -1007,16 +1015,24 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
batch_size=batch_size, num_epochs=None,
shuffle=True)
est.train(train_input_fn, steps=200)
self._assertCheckpoint(200)
self._assert_checkpoint(200)
def testFromScratch(self):
n_classes = 2
n_classes = self._n_classes
label = 1
age = 17
# loss = sigmoid_cross_entropy(logits, label) where logits = 0 (weights are
# For binary classifer:
# loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are
# all zero initially) and label = 1 so,
# loss = 1 * -log ( sigmoid(logits) ) = 0.69315
mock_optimizer = self._mockOptimizer(expected_loss=0.69315)
# For multi class classifer:
# loss = cross_entropy(logits, label) where logits are all 0s (weights are
# all zero initially) and label = 1 so,
# loss = 1 * -log ( 1.0 / n_classes )
# For this particular test case, as logits are same, the formular
# 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases.
mock_optimizer = self._mock_optimizer(
expected_loss=-1 * math.log(1.0/n_classes))
est = linear.LinearClassifier(
feature_columns=(feature_column_lib.numeric_column('age'),),
@ -1030,31 +1046,49 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
est.train(
input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
self.assertEqual(1, mock_optimizer.minimize.call_count)
self._assertCheckpoint(
self._assert_checkpoint(
expected_global_step=num_steps,
expected_age_weight=0.,
expected_bias=0.)
expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes],
expected_bias=[0.] if n_classes == 2 else [.0] * n_classes)
def testFromCheckpoint(self):
# Create initial checkpoint.
n_classes = 2
n_classes = self._n_classes
label = 1
age = 17
age_weight = 2.0
bias = -35.0
# For binary case, the expected weight has shape (1,1). For multi class
# case, the shape is (1, n_classes). In order to test the weights, set
# weights as 2.0 * range(n_classes).
age_weight = [[2.0]] if n_classes == 2 else (
np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32),
(1, n_classes)))
bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes
initial_global_step = 100
with ops.Graph().as_default():
variables.Variable([[age_weight]], name=_AGE_WEIGHT_NAME)
variables.Variable([bias], name=_BIAS_NAME)
variables.Variable(age_weight, name=_AGE_WEIGHT_NAME)
variables.Variable(bias, name=_BIAS_NAME)
variables.Variable(
initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
dtype=dtypes.int64)
_save_variables_to_ckpt(self._model_dir)
# For binary classifer:
# logits = age * age_weight + bias = 17 * 2. - 35. = -1.
# loss = sigmoid_cross_entropy(logits, label)
# so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133
mock_optimizer = self._mockOptimizer(expected_loss=1.3133)
# For multi class classifer:
# loss = cross_entropy(logits, label)
# where logits = 17 * age_weight + bias and label = 1
# so, loss = 1 * -log ( soft_max(logits)[1] )
if n_classes == 2:
expected_loss = 1.3133
else:
logits = age_weight * age + bias
logits_exp = np.exp(logits)
softmax = logits_exp / logits_exp.sum()
expected_loss = -1 * math.log(softmax[0, label])
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
feature_columns=(feature_column_lib.numeric_column('age'),),
@ -1068,34 +1102,55 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
est.train(
input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
self.assertEqual(1, mock_optimizer.minimize.call_count)
self._assertCheckpoint(
self._assert_checkpoint(
expected_global_step=initial_global_step + num_steps,
expected_age_weight=age_weight,
expected_bias=bias)
def testFromCheckpointMultiBatch(self):
# Create initial checkpoint.
n_classes = 2
n_classes = self._n_classes
label = [1, 0]
age = [17, 18.5]
age_weight = 2.0
bias = -35.0
# For binary case, the expected weight has shape (1,1). For multi class
# case, the shape is (1, n_classes). In order to test the weights, set
# weights as 2.0 * range(n_classes).
age_weight = [[2.0]] if n_classes == 2 else (
np.reshape(2.0 * np.array(list(range(n_classes)), dtype=np.float32),
(1, n_classes)))
bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes
initial_global_step = 100
with ops.Graph().as_default():
variables.Variable([[age_weight]], name=_AGE_WEIGHT_NAME)
variables.Variable([bias], name=_BIAS_NAME)
variables.Variable(age_weight, name=_AGE_WEIGHT_NAME)
variables.Variable(bias, name=_BIAS_NAME)
variables.Variable(
initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
dtype=dtypes.int64)
_save_variables_to_ckpt(self._model_dir)
# For binary classifer:
# logits = age * age_weight + bias
# logits[0] = 17 * 2. - 35. = -1.
# logits[1] = 18.5 * 2. - 35. = 2.
# loss = sigmoid_cross_entropy(logits, label)
# so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133
# loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269
mock_optimizer = self._mockOptimizer(expected_loss=1.3133 + 2.1269)
# For multi class classifer:
# loss = cross_entropy(logits, label)
# where logits = [17, 18.5] * age_weight + bias and label = [1, 0]
# so, loss = 1 * -log ( soft_max(logits)[label] )
if n_classes == 2:
expected_loss = (1.3133 + 2.1269)
else:
logits = age_weight * np.reshape(age, (2, 1)) + bias
logits_exp = np.exp(logits)
softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
expected_loss = expected_loss_0 + expected_loss_1
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
feature_columns=(feature_column_lib.numeric_column('age'),),
@ -1110,10 +1165,27 @@ class LinearClassiferWithBinaryClassesTrainingTest(test.TestCase):
input_fn=lambda: ({'age': (age)}, (label)),
steps=num_steps)
self.assertEqual(1, mock_optimizer.minimize.call_count)
self._assertCheckpoint(
self._assert_checkpoint(
expected_global_step=initial_global_step + num_steps,
expected_age_weight=age_weight,
expected_bias=bias)
class LinearClassiferWithBinaryClassesTrainingTest(
_BaseLinearClassiferTrainingTest, test.TestCase):
def __init__(self, methodName='runTest'):
test.TestCase.__init__(self, methodName)
_BaseLinearClassiferTrainingTest.__init__(self, n_classes=2)
class LinearClassiferWithMultiClassesTrainingTest(
_BaseLinearClassiferTrainingTest, test.TestCase):
def __init__(self, methodName='runTest'):
test.TestCase.__init__(self, methodName)
_BaseLinearClassiferTrainingTest.__init__(self, n_classes=4)
if __name__ == '__main__':
test.main()