Fix forward test failure in multi-GPU tests

PiperOrigin-RevId: 381309867
Change-Id: I4a58cca3e84c6816d5a501a98ed3850940ca5f6d
This commit is contained in:
Isha Arkatkar 2021-06-24 12:16:18 -07:00 committed by TensorFlower Gardener
parent 88dc57fcc5
commit fa84a1e83c

View File

@ -200,6 +200,8 @@ class AggregatingVariable(variables_lib.Variable, core.Tensor):
# TODO(josh11b): Test saving & restoring.
def _gather_saveables_for_checkpoint(self):
if isinstance(self._v, CachingVariable):
return self._v._gather_saveables_for_checkpoint() # pylint:disable=protected-access
return {trackable.VARIABLE_VALUE_KEY: self._v}
def _map_resources(self, save_options):