mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Scope the scalar cache in the context.
PiperOrigin-RevId: 168065417
This commit is contained in:
parent
48deb206ba
commit
0753b0c790
|
|
@ -53,6 +53,7 @@ class _EagerContext(threading.local):
|
||||||
self.mode = _default_mode
|
self.mode = _default_mode
|
||||||
self.scope_name = ""
|
self.scope_name = ""
|
||||||
self.recording_summaries = False
|
self.recording_summaries = False
|
||||||
|
self.scalar_cache = {}
|
||||||
|
|
||||||
|
|
||||||
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
|
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
|
||||||
|
|
@ -157,6 +158,10 @@ class Context(object):
|
||||||
"""Returns True if current thread is in EAGER mode."""
|
"""Returns True if current thread is in EAGER mode."""
|
||||||
return self._eager_context.mode == EAGER_MODE
|
return self._eager_context.mode == EAGER_MODE
|
||||||
|
|
||||||
|
def scalar_cache(self):
|
||||||
|
"""Per-device cache for scalars."""
|
||||||
|
return self._eager_context.scalar_cache
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scope_name(self):
|
def scope_name(self):
|
||||||
"""Returns scope name for the current thread."""
|
"""Returns scope name for the current thread."""
|
||||||
|
|
|
||||||
|
|
@ -74,10 +74,6 @@ def _eager_fill(dims, value):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# Rely on the GIL for thread-safety.
|
|
||||||
_scalar_cache = {}
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_eager_tensor(t, dtype=None):
|
def convert_to_eager_tensor(t, dtype=None):
|
||||||
"""Converts the given `value` to an `EagerTensor`."""
|
"""Converts the given `value` to an `EagerTensor`."""
|
||||||
if isinstance(ag_core.getval(t), ops.EagerTensor):
|
if isinstance(ag_core.getval(t), ops.EagerTensor):
|
||||||
|
|
@ -88,13 +84,15 @@ def convert_to_eager_tensor(t, dtype=None):
|
||||||
# Use a scalar cache. This will put each scalar of each type only once on
|
# Use a scalar cache. This will put each scalar of each type only once on
|
||||||
# each device. Scalars don't use much device memory but copying scalars can
|
# each device. Scalars don't use much device memory but copying scalars can
|
||||||
# trigger memcpys which are slow.
|
# trigger memcpys which are slow.
|
||||||
device = context.context().device_name
|
ctx = context.context()
|
||||||
|
device = ctx.device_name
|
||||||
cache_key = device, t, dtype, type(t)
|
cache_key = device, t, dtype, type(t)
|
||||||
tensor = _scalar_cache.get(cache_key, None)
|
scalar_cache = ctx.scalar_cache()
|
||||||
|
tensor = scalar_cache.get(cache_key, None)
|
||||||
if tensor is not None:
|
if tensor is not None:
|
||||||
return tensor
|
return tensor
|
||||||
value = ops.EagerTensor(t, dtype=dtype)
|
value = ops.EagerTensor(t, dtype=dtype)
|
||||||
_scalar_cache[cache_key] = value
|
scalar_cache[cache_key] = value
|
||||||
return value
|
return value
|
||||||
return ops.EagerTensor(t, dtype=dtype)
|
return ops.EagerTensor(t, dtype=dtype)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user