Scope the scalar cache in the context.

PiperOrigin-RevId: 168065417
This commit is contained in:
Alexandre Passos 2017-09-08 16:52:18 -07:00 committed by TensorFlower Gardener
parent 48deb206ba
commit 0753b0c790
2 changed files with 10 additions and 7 deletions

View File

@ -53,6 +53,7 @@ class _EagerContext(threading.local):
self.mode = _default_mode self.mode = _default_mode
self.scope_name = "" self.scope_name = ""
self.recording_summaries = False self.recording_summaries = False
self.scalar_cache = {}
# TODO(agarwal): rename to EagerContext / EagerRuntime ? # TODO(agarwal): rename to EagerContext / EagerRuntime ?
@ -157,6 +158,10 @@ class Context(object):
"""Returns True if current thread is in EAGER mode.""" """Returns True if current thread is in EAGER mode."""
return self._eager_context.mode == EAGER_MODE return self._eager_context.mode == EAGER_MODE
def scalar_cache(self):
"""Per-device cache for scalars."""
return self._eager_context.scalar_cache
@property @property
def scope_name(self): def scope_name(self):
"""Returns scope name for the current thread.""" """Returns scope name for the current thread."""

View File

@ -74,10 +74,6 @@ def _eager_fill(dims, value):
return result return result
# Rely on the GIL for thread-safety.
_scalar_cache = {}
def convert_to_eager_tensor(t, dtype=None): def convert_to_eager_tensor(t, dtype=None):
"""Converts the given `value` to an `EagerTensor`.""" """Converts the given `value` to an `EagerTensor`."""
if isinstance(ag_core.getval(t), ops.EagerTensor): if isinstance(ag_core.getval(t), ops.EagerTensor):
@ -88,13 +84,15 @@ def convert_to_eager_tensor(t, dtype=None):
# Use a scalar cache. This will put each scalar of each type only once on # Use a scalar cache. This will put each scalar of each type only once on
# each device. Scalars don't use much device memory but copying scalars can # each device. Scalars don't use much device memory but copying scalars can
# trigger memcpys which are slow. # trigger memcpys which are slow.
device = context.context().device_name ctx = context.context()
device = ctx.device_name
cache_key = device, t, dtype, type(t) cache_key = device, t, dtype, type(t)
tensor = _scalar_cache.get(cache_key, None) scalar_cache = ctx.scalar_cache()
tensor = scalar_cache.get(cache_key, None)
if tensor is not None: if tensor is not None:
return tensor return tensor
value = ops.EagerTensor(t, dtype=dtype) value = ops.EagerTensor(t, dtype=dtype)
_scalar_cache[cache_key] = value scalar_cache[cache_key] = value
return value return value
return ops.EagerTensor(t, dtype=dtype) return ops.EagerTensor(t, dtype=dtype)