Add a .with_cache() method to distributions.Transform objects (#36882)
Summary: This resolves an issue observed by stefanwebb where the composition of multiple transforms is cached only if all component transforms are cached. This PR adds a new method `.with_cache()` so that, e.g., you can compose a normalizing flow (that needs to be cached) with a `SigmoidTransform` (that wasn't already cached) by calling `.with_cache()` on the latter. The issue also comes up when composing the non-cached constraint transforms returned by `transform_to()` and `biject_to()`: after this PR you can call `transform_to(constraints.positive).with_cache()` to get a cached `ExpTransform`.

## Tested
- [x] added a unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/36882
Differential Revision: D21155914
Pulled By: ezyang
fbshipit-source-id: 3c06e63785ca2503e08a5cd7532aff81882835e9
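For illustration, a minimal sketch of the usage described above (the variable names are ours; the calls are the ones this PR adds):

```python
import torch
from torch.distributions import constraints, transform_to
from torch.distributions.transforms import SigmoidTransform

# transform_to() returns an uncached ExpTransform for the positive
# constraint; .with_cache() produces a cached copy of it.
t = transform_to(constraints.positive).with_cache()
x = torch.randn(3, requires_grad=True)
y = t(x)
assert t.inv(y) is x  # the cached input is returned rather than recomputed

# An uncached component can likewise be upgraded before composing it
# with a normalizing flow that requires caching.
s = SigmoidTransform().with_cache()
```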
parent 01100cb477
commit 00b7d84eb7
```diff
@@ -4268,6 +4268,20 @@ class TestTransforms(TestCase):
         self.assertTrue(identity_transform == identity_transform.inv)
         self.assertFalse(identity_transform != identity_transform.inv)
 
+    def test_with_cache(self):
+        for transform in self.transforms:
+            if transform._cache_size == 0:
+                transform = transform.with_cache(1)
+            self.assertTrue(transform._cache_size == 1)
+
+            x = self._generate_data(transform).requires_grad_()
+            try:
+                y = transform(x)
+            except NotImplementedError:
+                continue
+            y2 = transform(x)
+            self.assertTrue(y2 is y)
+
     def test_forward_inverse_cache(self):
         for transform in self.transforms:
             x = self._generate_data(transform).requires_grad_()
```
```diff
@@ -112,6 +112,13 @@ class Transform(object):
         """
         raise NotImplementedError
 
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        if type(self).__init__ is Transform.__init__:
+            return type(self)(cache_size=cache_size)
+        raise NotImplementedError("{}.with_cache is not implemented".format(type(self)))
+
     def __eq__(self, other):
         return self is other
 
```
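The base-class fallback `type(self)(cache_size=cache_size)` is only taken when the subclass defines no `__init__` of its own, since rebuilding a parameterized transform that way would drop its constructor arguments; such subclasses must override `with_cache()`, as the hunks below do. A small sketch of the two paths:

```python
from torch.distributions.transforms import AffineTransform, ExpTransform

# ExpTransform defines no __init__ of its own, so the base-class
# fallback rebuilds it as ExpTransform(cache_size=1).
cached_exp = ExpTransform().with_cache()
assert cached_exp._cache_size == 1

# AffineTransform has constructor arguments, so it relies on the
# with_cache() override added below rather than the fallback.
cached_affine = AffineTransform(0., 2.).with_cache()
assert cached_affine._cache_size == 1
```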
```diff
@@ -173,7 +180,7 @@ class _InverseTransform(Transform):
     This class is private; please instead use the ``Transform.inv`` property.
     """
     def __init__(self, transform):
-        super(_InverseTransform, self).__init__()
+        super(_InverseTransform, self).__init__(cache_size=transform._cache_size)
         self._inv = transform
 
     @constraints.dependent_property
```
```diff
@@ -200,6 +207,9 @@ class _InverseTransform(Transform):
     def inv(self):
         return self._inv
 
+    def with_cache(self, cache_size=1):
+        return self.inv.with_cache(cache_size).inv
+
     def __eq__(self, other):
         if not isinstance(other, _InverseTransform):
             return False
```
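`_InverseTransform` now inherits the cache size of the transform it wraps, and calling `with_cache()` on an inverse delegates to the forward transform, so both directions share one cached pair. A minimal sketch:

```python
import torch
from torch.distributions.transforms import ExpTransform

log = ExpTransform().inv            # an uncached _InverseTransform
cached_log = log.with_cache()       # delegates: ExpTransform().with_cache().inv
assert cached_log._cache_size == 1  # inherited from the cached forward transform

y = torch.randn(3).exp().requires_grad_()
x = cached_log(y)
assert cached_log(y) is x           # repeated calls hit the shared cache
```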
```diff
@@ -219,9 +229,13 @@ class ComposeTransform(Transform):
 
     Args:
         parts (list of :class:`Transform`): A list of transforms to compose.
+        cache_size (int): Size of cache. If zero, no caching is done. If one,
+            the latest single value is cached. Only 0 and 1 are supported.
     """
-    def __init__(self, parts):
-        super(ComposeTransform, self).__init__()
+    def __init__(self, parts, cache_size=0):
+        if cache_size:
+            parts = [part.with_cache(cache_size) for part in parts]
+        super(ComposeTransform, self).__init__(cache_size=cache_size)
         self.parts = parts
 
     def __eq__(self, other):
```
```diff
@@ -267,6 +281,11 @@ class ComposeTransform(Transform):
         inv._inv = weakref.ref(self)
         return inv
 
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return ComposeTransform(self.parts, cache_size=cache_size)
+
     def __call__(self, x):
         for part in self.parts:
             x = part(x)
```
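`ComposeTransform(parts, cache_size=1)` pushes caching down into every part via `with_cache()`, which is what makes the flow-plus-`SigmoidTransform` composition from the summary work. A minimal sketch:

```python
from torch.distributions.transforms import (AffineTransform, ComposeTransform,
                                            SigmoidTransform)

# Each part is rebuilt as part.with_cache(1) inside __init__.
flow = ComposeTransform([AffineTransform(0., 1.), SigmoidTransform()],
                        cache_size=1)
assert all(part._cache_size == 1 for part in flow.parts)

# Asking again for the same cache size short-circuits and returns self.
assert flow.with_cache(1) is flow
```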
```diff
@@ -331,6 +350,11 @@ class PowerTransform(Transform):
         super(PowerTransform, self).__init__(cache_size=cache_size)
         self.exponent, = broadcast_all(exponent)
 
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return PowerTransform(self.exponent, cache_size=cache_size)
+
     def __eq__(self, other):
         if not isinstance(other, PowerTransform):
             return False
```
```diff
@@ -453,6 +477,11 @@ class AffineTransform(Transform):
         self.scale = scale
         self.event_dim = event_dim
 
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return AffineTransform(self.loc, self.scale, self.event_dim, cache_size=cache_size)
+
     def __eq__(self, other):
         if not isinstance(other, AffineTransform):
             return False
```
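`PowerTransform` and `AffineTransform` show the pattern any parameterized transform should follow: short-circuit when the cache size already matches, otherwise rebuild with the same parameters. A hypothetical `ShiftTransform` (not part of PyTorch, shown only to illustrate the override) would do the same:

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class ShiftTransform(Transform):
    """Hypothetical y = x + shift, illustrating the with_cache() override."""
    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def __init__(self, shift, cache_size=0):
        super(ShiftTransform, self).__init__(cache_size=cache_size)
        self.shift = shift

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        # Rebuild with identical parameters; only the cache size changes.
        return ShiftTransform(self.shift, cache_size=cache_size)

    def _call(self, x):
        return x + self.shift

    def _inverse(self, y):
        return y - self.shift

    def log_abs_det_jacobian(self, x, y):
        return torch.zeros_like(x)  # a shift is volume-preserving
```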
```diff
@@ -606,9 +635,11 @@ class CatTransform(Transform):
        t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
        y = t(x)
     """
-    def __init__(self, tseq, dim=0, lengths=None):
+    def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
         assert all(isinstance(t, Transform) for t in tseq)
-        super(CatTransform, self).__init__()
+        if cache_size:
+            tseq = [t.with_cache(cache_size) for t in tseq]
+        super(CatTransform, self).__init__(cache_size=cache_size)
         self.transforms = list(tseq)
         if lengths is None:
             lengths = [1] * len(self.transforms)
```
```diff
@@ -620,6 +651,11 @@ class CatTransform(Transform):
     def length(self):
         return sum(self.lengths)
 
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return CatTransform(self.tseq, self.dim, self.lengths, cache_size)
+
     def _call(self, x):
         assert -x.dim() <= self.dim < x.dim()
         assert x.size(self.dim) == self.length
```
```diff
@@ -682,12 +718,19 @@ class StackTransform(Transform):
        t = StackTransform([ExpTransform(), identity_transform], dim=1)
        y = t(x)
     """
-    def __init__(self, tseq, dim=0):
+    def __init__(self, tseq, dim=0, cache_size=0):
         assert all(isinstance(t, Transform) for t in tseq)
-        super(StackTransform, self).__init__()
+        if cache_size:
+            tseq = [t.with_cache(cache_size) for t in tseq]
+        super(StackTransform, self).__init__(cache_size=cache_size)
         self.transforms = list(tseq)
         self.dim = dim
 
+    def with_cache(self, cache_size=1):
+        if self._cache_size == cache_size:
+            return self
+        return StackTransform(self.transforms, self.dim, cache_size)
+
     def _slice(self, z):
         return [z.select(self.dim, i) for i in range(z.size(self.dim))]
 
```
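`CatTransform` and `StackTransform` get the same treatment, so caching also propagates through transforms applied along a dimension. A minimal sketch:

```python
import torch
from torch.distributions.transforms import (ExpTransform, StackTransform,
                                            identity_transform)

t = StackTransform([ExpTransform(), identity_transform], dim=0, cache_size=1)
assert all(part._cache_size == 1 for part in t.transforms)

x = torch.randn(2, 5, requires_grad=True)
y = t(x)
assert t(x) is y  # the composite itself caches the latest (x, y) pair
```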