From b70a86f269129ffe9206b9d8a2decd7d22dc68f1 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Fri, 10 Jul 2020 12:45:04 -0700 Subject: [PATCH] Add a test for known issues with saved model PiperOrigin-RevId: 320656309 Change-Id: I6f749752b9d29a3cb474d0e0930e8cbec293ce2d --- .../python/distribute/integration_test/BUILD | 20 + .../integration_test/saved_model_test.py | 484 ++++++++++++++++++ 2 files changed, 504 insertions(+) create mode 100644 tensorflow/python/distribute/integration_test/BUILD create mode 100644 tensorflow/python/distribute/integration_test/saved_model_test.py diff --git a/tensorflow/python/distribute/integration_test/BUILD b/tensorflow/python/distribute/integration_test/BUILD new file mode 100644 index 00000000000..d997e64be05 --- /dev/null +++ b/tensorflow/python/distribute/integration_test/BUILD @@ -0,0 +1,20 @@ +load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test") + +package( + default_visibility = ["//tensorflow:internal"], + licenses = ["notice"], # Apache 2.0 +) + +distribute_py_test( + name = "saved_model_test", + srcs = ["saved_model_test.py"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:multi_process_runner", + "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/python/distribute/integration_test/saved_model_test.py b/tensorflow/python/distribute/integration_test/saved_model_test.py new file mode 100644 index 00000000000..60de590bb48 --- /dev/null +++ b/tensorflow/python/distribute/integration_test/saved_model_test.py @@ -0,0 +1,484 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+Test cases ending with _broken are known issues. Assertions in such tests
+describe the current incorrect behavior. If you fix something, you should
  # Currently, references to components of a distributed variable are mapped to
+    # synchronization=ON_WRITE is the default variable created under
+    # The saved tf.function updates each component of the distributed variable,
+    # Models trained with MultiWorkerMirroredStrategy are affected the most,
+      # because DistributedVariable.handle and _WrapperFunction behave
+    # Reading a synchronization=ON_READ in the replica context should only read
+    # the local value, however with a loaded model, reading in the replica
+    # context triggers aggregation as well. While one may argue the behavior is
+    # desirable, note that aggregation can cause hanging if the original model
+    # is trained with MultiWorkerMirroredStrategy.
+    # It's broken because the saved model may not contain the aggregation logic. Even if
+    # Although the error is the same as models with TF2 SavedModel, the cause is
+    # different. TF1 models loaded in API contain an initializer, which is
+ v1_export_dir = self.get_temp_dir() + + with tf1.Graph().as_default(), tf1.Session() as sess: + v = tf1.Variable(1., use_resource=True) + sess.run(v.initializer) + builder = tf1.saved_model.Builder(v1_export_dir) + builder.add_meta_graph_and_variables( + sess, + tags=[tf1.saved_model.tag_constants.TRAINING], + main_op=tf1.tables_initializer(), + strip_default_attrs=True) + builder.save() + + v1_loaded = tf.saved_model.load(v1_export_dir) + v2_export_dir = self.get_temp_dir() + tf.saved_model.save(v1_loaded, v2_export_dir) + with strategy.scope(): + # TODO(b/150009657): remove after fix. + # got error, want no error. + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, + "from the cross-replica context in an in-replica context"): + tf.saved_model.load(v2_export_dir) + + +if __name__ == "__main__": + combinations.main()