Add a .with_cache() method to distributions.Transform objects (#36882)

Summary:
This resolves an issue observed by stefanwebb where the composition of multiple transforms is cached only if all components are cached.

This PR adds a new method `.with_cache()` so that e.g. you can compose a normalizing flow (that needs to be cached) with a `SigmoidTransform` (that wasn't already cached) by calling `.with_cache()` on the latter. This issue also comes up when composing non-cached constraint transforms as returned by `transform_to()` and `biject_to()`: after this PR you can call `transform_to(constraints.positive).with_cache()` to get a cached `ExpTransform`.

## Tested
- [x] added a unit test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36882

Differential Revision: D21155914

Pulled By: ezyang

fbshipit-source-id: 3c06e63785ca2503e08a5cd7532aff81882835e9
This commit is contained in:
Fritz Obermeyer 2020-04-21 10:47:32 -07:00 committed by Facebook GitHub Bot
parent 01100cb477
commit 00b7d84eb7
2 changed files with 64 additions and 7 deletions

View File

@ -4268,6 +4268,20 @@ class TestTransforms(TestCase):
self.assertTrue(identity_transform == identity_transform.inv)
self.assertFalse(identity_transform != identity_transform.inv)
def test_with_cache(self):
    """with_cache() must yield a transform with _cache_size == 1 that
    returns the identical output object when called twice on the same
    input (i.e. the second call is served from the cache)."""
    for transform in self.transforms:
        if transform._cache_size == 0:
            transform = transform.with_cache(1)
        # assertEqual/assertIs give informative failure messages,
        # unlike assertTrue on a boolean expression.
        self.assertEqual(transform._cache_size, 1)
        x = self._generate_data(transform).requires_grad_()
        try:
            y = transform(x)
        except NotImplementedError:
            # Some transforms have no forward implementation; skip them.
            continue
        y2 = transform(x)
        # Cached call must return the very same tensor object.
        self.assertIs(y2, y)
def test_forward_inverse_cache(self):
for transform in self.transforms:
x = self._generate_data(transform).requires_grad_()

View File

@ -112,6 +112,13 @@ class Transform(object):
"""
raise NotImplementedError
def with_cache(self, cache_size=1):
    """
    Return an equivalent transform that caches ``cache_size`` values.
    Returns ``self`` when the cache size already matches.  Subclasses
    with custom constructors must override this method.
    """
    if cache_size == self._cache_size:
        return self
    if type(self).__init__ is not Transform.__init__:
        # A custom __init__ means we cannot know how to rebuild the
        # instance; the subclass has to provide its own with_cache.
        raise NotImplementedError("{}.with_cache is not implemented".format(type(self)))
    # The inherited constructor accepts cache_size, so rebuild directly.
    return type(self)(cache_size=cache_size)
def __eq__(self, other):
    # Transforms compare by identity by default; subclasses that carry
    # parameters override this with structural equality.
    return self is other
@ -173,7 +180,7 @@ class _InverseTransform(Transform):
This class is private; please instead use the ``Transform.inv`` property.
"""
def __init__(self, transform):
    # Inherit the wrapped transform's cache size so that inverting a
    # cached transform yields a cached inverse.
    super(_InverseTransform, self).__init__(cache_size=transform._cache_size)
    self._inv = transform
@constraints.dependent_property
@ -200,6 +207,9 @@ class _InverseTransform(Transform):
def inv(self):
    # The inverse of this inverse is the original wrapped transform.
    return self._inv
def with_cache(self, cache_size=1):
    """Cache by delegating to the wrapped forward transform, then
    taking the inverse of that cached transform."""
    cached_forward = self.inv.with_cache(cache_size)
    return cached_forward.inv
def __eq__(self, other):
if not isinstance(other, _InverseTransform):
return False
@ -219,9 +229,13 @@ class ComposeTransform(Transform):
Args:
parts (list of :class:`Transform`): A list of transforms to compose.
cache_size (int): Size of cache. If zero, no caching is done. If one,
the latest single value is cached. Only 0 and 1 are supported.
"""
def __init__(self, parts):
super(ComposeTransform, self).__init__()
def __init__(self, parts, cache_size=0):
    """
    Args:
        parts (list of Transform): transforms to compose, applied in order.
        cache_size (int): cache size propagated to every component
            (0 disables caching, 1 caches the latest value).
    """
    if cache_size:
        # A composition is only effectively cached when every
        # component caches, so propagate the cache size down.
        parts = [component.with_cache(cache_size) for component in parts]
    super(ComposeTransform, self).__init__(cache_size=cache_size)
    self.parts = parts
def __eq__(self, other):
@ -267,6 +281,11 @@ class ComposeTransform(Transform):
inv._inv = weakref.ref(self)
return inv
def with_cache(self, cache_size=1):
    """Return a cached composition; no-op when the size already matches."""
    if cache_size == self._cache_size:
        return self
    return ComposeTransform(self.parts, cache_size=cache_size)
def __call__(self, x):
for part in self.parts:
x = part(x)
@ -331,6 +350,11 @@ class PowerTransform(Transform):
super(PowerTransform, self).__init__(cache_size=cache_size)
self.exponent, = broadcast_all(exponent)
def with_cache(self, cache_size=1):
    """Return a PowerTransform with the requested cache size, reusing
    ``self`` when nothing needs to change."""
    if cache_size == self._cache_size:
        return self
    return PowerTransform(self.exponent, cache_size=cache_size)
def __eq__(self, other):
if not isinstance(other, PowerTransform):
return False
@ -453,6 +477,11 @@ class AffineTransform(Transform):
self.scale = scale
self.event_dim = event_dim
def with_cache(self, cache_size=1):
    """Return an AffineTransform with the requested cache size, reusing
    ``self`` when nothing needs to change."""
    if cache_size == self._cache_size:
        return self
    return AffineTransform(self.loc, self.scale, self.event_dim, cache_size=cache_size)
def __eq__(self, other):
if not isinstance(other, AffineTransform):
return False
@ -606,9 +635,11 @@ class CatTransform(Transform):
t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
y = t(x)
"""
def __init__(self, tseq, dim=0, lengths=None):
def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
assert all(isinstance(t, Transform) for t in tseq)
super(CatTransform, self).__init__()
if cache_size:
tseq = [t.with_cache(cache_size) for t in tseq]
super(CatTransform, self).__init__(cache_size=cache_size)
self.transforms = list(tseq)
if lengths is None:
lengths = [1] * len(self.transforms)
@ -620,6 +651,11 @@ class CatTransform(Transform):
def length(self):
    # Total size along the concatenation dim covered by all components.
    return sum(self.lengths)
def with_cache(self, cache_size=1):
    """Return a CatTransform whose component transforms cache
    ``cache_size`` values; returns ``self`` if already at that size."""
    if self._cache_size == cache_size:
        return self
    # Bug fix: __init__ stores the components as ``self.transforms``;
    # there is no ``self.tseq`` attribute, so the previous code raised
    # AttributeError whenever a new cache size was requested.
    return CatTransform(self.transforms, self.dim, self.lengths, cache_size)
def _call(self, x):
assert -x.dim() <= self.dim < x.dim()
assert x.size(self.dim) == self.length
@ -682,12 +718,19 @@ class StackTransform(Transform):
t = StackTransform([ExpTransform(), identity_transform], dim=1)
y = t(x)
"""
def __init__(self, tseq, dim=0):
def __init__(self, tseq, dim=0, cache_size=0):
    """
    Args:
        tseq (sequence of Transform): one transform per slice along ``dim``.
        dim (int): dimension along which inputs are stacked.
        cache_size (int): cache size propagated to every component
            (0 disables caching, 1 caches the latest value).
    """
    assert all(isinstance(t, Transform) for t in tseq)
    if cache_size:
        # A stack is only effectively cached when every component
        # caches, so propagate the cache size down.
        tseq = [component.with_cache(cache_size) for component in tseq]
    super(StackTransform, self).__init__(cache_size=cache_size)
    self.transforms = list(tseq)
    self.dim = dim
def with_cache(self, cache_size=1):
    """Return a cached StackTransform; no-op when the size already matches."""
    if cache_size == self._cache_size:
        return self
    return StackTransform(self.transforms, self.dim, cache_size)
def _slice(self, z):
    # Split ``z`` into one sub-tensor per stacked component along self.dim.
    return [z.select(self.dim, i) for i in range(z.size(self.dim))]