jacrev : Support chunked computation (#89376)

Ref: https://github.com/pytorch/functorch/issues/680

We introduce a `chunk_size` kwarg in `jacrev` that controls whether the Jacobian computation is chunked. When it is set, `chunk_size` is the maximum number of Jacobian rows computed per chunk.
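
As a rough illustration of how a maximum chunk size partitions the Jacobian rows, here is a minimal pure-Python sketch. The helper name `split_into_chunks` is hypothetical and not part of this PR; the sketch also skips an empty remainder chunk, which is a choice made here for illustration.

```python
def split_into_chunks(total_rows, chunk_size=None):
    """Return the number of Jacobian rows handled by each chunk.

    Hypothetical helper for illustration only; it is not the PR's code.
    """
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError("chunk_size should be greater than 0")
    if chunk_size is None or chunk_size >= total_rows:
        # A single chunk covers every row (the unchunked behaviour).
        return [total_rows]
    n_full, remainder = divmod(total_rows, chunk_size)
    sizes = [chunk_size] * n_full
    if remainder:
        # The last chunk holds whatever rows are left over.
        sizes.append(remainder)
    return sizes


# 61 rows with chunk_size=7: eight full chunks and one chunk of 5 rows.
print(split_into_chunks(61, 7))
```

No chunk exceeds `chunk_size`, and the chunk sizes always sum to the total number of rows.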

We try two approaches:
* Stacked Approach: Append the intermediate computation to a list and then stack those results.
* Pre-allocation Approach: Pre-allocate a zeros tensor and copy chunked computation into it.
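
The two approaches can be sketched without tensors. The following is a simplified pure-Python analogue, not the PR's implementation: it assumes a linear map `f(x) = A x` (whose Jacobian is `A`), uses plain lists, and replaces `vmap(vjp_fn)` with a hypothetical `vjp_rows` helper. All function names here are made up for illustration.

```python
def vjp_rows(matrix, basis_chunk):
    """Stand-in for vmap(vjp_fn): one Jacobian row (v^T A) per basis vector v."""
    n_cols = len(matrix[0])
    return [
        [sum(v[i] * matrix[i][j] for i in range(len(matrix))) for j in range(n_cols)]
        for v in basis_chunk
    ]


def basis_chunks(n_rows, chunk_size):
    """Yield rows of the n_rows x n_rows identity, at most chunk_size at a time."""
    for start in range(0, n_rows, chunk_size):
        stop = min(start + chunk_size, n_rows)
        yield [[1.0 if col == row else 0.0 for col in range(n_rows)]
               for row in range(start, stop)]


def jacobian_stacked(matrix, chunk_size):
    # Stacked approach: collect every chunk's result, then concatenate.
    chunked_results = []
    for chunk in basis_chunks(len(matrix), chunk_size):
        chunked_results.append(vjp_rows(matrix, chunk))
    return [row for chunk in chunked_results for row in chunk]


def jacobian_preallocated(matrix, chunk_size):
    # Pre-allocation approach: allocate the zero-filled result once,
    # then write each chunk's rows into its slot.
    n_rows, n_cols = len(matrix), len(matrix[0])
    out = [[0.0] * n_cols for _ in range(n_rows)]
    for idx, chunk in enumerate(basis_chunks(n_rows, chunk_size)):
        rows = vjp_rows(matrix, chunk)
        start = idx * chunk_size
        out[start:start + len(rows)] = rows
    return out


A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
assert jacobian_stacked(A, chunk_size=2) == A
assert jacobian_preallocated(A, chunk_size=2) == A
```

Both variants produce the same Jacobian; they differ only in whether the per-chunk results are kept in a list and joined at the end or written directly into a buffer allocated up front.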

For the memory benchmark, see https://github.com/pytorch/pytorch/pull/89376#issuecomment-1348479098
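
To give a feel for why chunking can help with memory, here is a back-of-the-envelope estimate of the size of the standard basis alone. It assumes float32 elements (4 bytes) and counts only the basis rows materialised at once, not the intermediates or the final Jacobian, so it is an estimate and not a measurement from the linked benchmark. The helper `basis_bytes` is hypothetical.

```python
def basis_bytes(total_numel, chunk_size=None, bytes_per_element=4):
    """Approximate bytes held by the basis rows that exist at one time.

    Assumption: the basis for one chunk has `rows * total_numel` elements
    in total across all outputs, where `rows` is at most `chunk_size`.
    """
    rows = total_numel if chunk_size is None else min(chunk_size, total_numel)
    return rows * total_numel * bytes_per_element


# Benchmark function fn(x, y) = (x + y, x.sum(0)) with x of shape (144, 144):
# the outputs have 144 * 144 + 144 elements in total.
shape = (144, 144)
total = shape[0] * shape[1] + shape[1]

print(total)                                   # total output elements
print(basis_bytes(total) / 2**20)              # full basis, in MiB
print(basis_bytes(total, chunk_size=208) / 2**20)  # one chunk of 208 rows, in MiB
```

Under these assumptions the per-chunk basis shrinks in proportion to `chunk_size / total`, while the final Jacobian has the same size either way.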

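Benchmark CPU: performs better with more chunks (smaller `chunk_size`). For the three larger shapes, both chunked variants are slower than the unchunked baseline at 2 to 4 chunks and faster at 100 to 300 chunks.

NOTE: The timings for shape `(64, 64)` appear to be noisy.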

<details>

```
[----------------------------------------------- jacrev : device cpu : chunks 2 -----------------------------------------------]
                                     |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ---------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080     |               76.2            |          50.9        |                  80.1
      (128, 128) : chunk_size 8256   |             1172.8            |         783.3        |                1225.5
      (128, 144) : chunk_size 9288   |             1475.1            |         990.4        |                1548.3
      (144, 144) : chunk_size 10440  |             1871.3            |        1254.4        |                1971.2

Times are in milliseconds (ms).

[----------------------------------------------- jacrev : device cpu : chunks 3 ----------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386    |               39.9            |          25.8        |                  58.8
      (128, 128) : chunk_size 5504  |             1182.6            |         782.2        |                1229.7
      (128, 144) : chunk_size 6192  |             1483.6            |         995.4        |                1550.6
      (144, 144) : chunk_size 6960  |             1879.1            |        1257.7        |                1960.5

Times are in milliseconds (ms).

[----------------------------------------------- jacrev : device cpu : chunks 4 ----------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040    |               41.7            |          50.6        |                  29.1
      (128, 128) : chunk_size 4128  |             1171.6            |         782.3        |                1226.7
      (128, 144) : chunk_size 4644  |             1482.2            |         994.6        |                1550.9
      (144, 144) : chunk_size 5220  |             1870.2            |        1254.5        |                1961.4

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 100 ---------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41     |               46.8            |          50.5        |                  46.4
      (128, 128) : chunk_size 165  |              622.2            |         775.2        |                 656.0
      (128, 144) : chunk_size 185  |              803.9            |         987.3        |                 866.9
      (144, 144) : chunk_size 208  |             1021.1            |        1251.2        |                1088.2

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 200 ---------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 20     |               60.9            |          50.2        |                  62.3
      (128, 128) : chunk_size 82   |              583.1            |         779.4        |                 634.3
      (128, 144) : chunk_size 92   |              834.1            |        1005.8        |                 472.3
      (144, 144) : chunk_size 104  |             1053.6            |        1277.0        |                1033.9

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 300 --------------------------------------------]
                                  |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 13    |              77.7             |          50.4        |                  79.6
      (128, 128) : chunk_size 55  |             578.9             |         782.3        |                 626.9
      (128, 144) : chunk_size 61  |             718.2             |        1024.9        |                 800.4
      (144, 144) : chunk_size 69  |             919.7             |        1313.7        |                1023.0

Times are in milliseconds (ms).
```

</details>

Benchmark CUDA: performs better with fewer chunks (larger `chunk_size`).

<details>

```
[--------------------------------------------- jacrev : device cuda:1 : chunks 2 ----------------------------------------------]
                                     |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ---------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080     |             1485.7            |         923.8        |                1632.3
      (128, 128) : chunk_size 8256   |            25390.2            |       14103.2        |               33557.4
      (128, 144) : chunk_size 9288   |              801.7            |       16854.1        |               42894.6
      (144, 144) : chunk_size 10440  |             1003.5            |       21386.5        |               59648.5

Times are in microseconds (us).

[--------------------------------------------- jacrev : device cuda:1 : chunks 3 ---------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386    |             1474.5            |         924.5        |                1655.5
      (128, 128) : chunk_size 5504  |            25368.9            |       10156.0        |               34022.1
      (128, 144) : chunk_size 6192  |            25223.0            |       12933.7        |               56418.5
      (144, 144) : chunk_size 6960  |            24729.3            |       16367.4        |               68744.7

Times are in microseconds (us).

[--------------------------------------------- jacrev : device cuda:1 : chunks 4 ---------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040    |             1489.2            |         924.4        |                 1679.6
      (128, 128) : chunk_size 4128  |            25370.4            |        8987.4        |                57201.3
      (128, 144) : chunk_size 4644  |            32239.1            |       10136.2        |                72406.5
      (144, 144) : chunk_size 5220  |            40994.3            |       12867.8        |               108653.4

Times are in microseconds (us).

[------------------------------------------- jacrev : device cuda:1 : chunks 100 --------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41     |            21121.8            |         924.2        |               22753.5
      (128, 128) : chunk_size 165  |            23679.7            |       14284.4        |               26758.2
      (128, 144) : chunk_size 185  |            30082.3            |       18063.3        |               33553.5
      (144, 144) : chunk_size 208  |            38175.6            |       22839.5        |               42030.0

Times are in microseconds (us).
```

</details>

Benchmark Script

<details>

```python
import sys

import functorch
import torch
from torch.utils.benchmark import Compare, Timer


def prod(dims):
    # Product of all dimensions in a shape tuple (1 for an empty tuple).
    result = 1
    for dim in dims:
        result *= dim

    return result

def fn(x, y):
    return x + y, x.sum(0)

shapes = ((64, 64), (128, 128), (128, 144), (144, 144))

for device in ('cpu', 'cuda:1'):
    if device == 'cuda:1':
        chunks = (2, 3, 4, 100,)
    else:
        chunks = (2, 3, 4, 100, 200, 300)
    for chunk in chunks:
        results = []
        for shape in shapes:
            x = torch.zeros(*shape, dtype=torch.float, device=device)
            y = x.sum()
            # Total output elements (x + y plus x.sum(0)) divided by the number of chunks.
            chunk_size = (prod(shape) + prod(shape[1:])) // chunk
            jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size)
            jacrev_fn_chunked_pre = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size, _preallocate_and_copy=True)
            jacrev_fn = functorch.jacrev(fn, (0, 1), chunk_size=None)

            tasks = [("jacrev_fn_chunked(x, y)", "with chunk_size and stacked"),
                     ("jacrev_fn(x, y)", "without chunk_size"),
                     ("jacrev_fn_chunked_pre(x, y)", "with chunk_size and pre-allocated"),]
            timers = [Timer(stmt=stmt, label=f"jacrev : device {device} : chunks {chunk}", sub_label=f"{(shape)} : chunk_size {chunk_size}", description=desc, globals=globals()) for stmt, desc in tasks]

            for i, timer in enumerate(timers):
                results.append(
                    timer.blocked_autorange(min_run_time=2.)
                )
                print(f"\r{i + 1} / {len(timers)} : Shape {shape} : Device {device} : chunks: {chunk}", end="")
                sys.stdout.flush()

        print()
        comparison = Compare(results)
        comparison.print()
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89376
Approved by: https://github.com/zou3519
Kshiteej K 2022-12-19 20:04:18 +00:00 committed by PyTorch MergeBot
parent e2dc60c6cb
commit f02e93b584
2 changed files with 134 additions and 16 deletions

@@ -1872,6 +1872,41 @@ class TestJac(TestCase):
         out_val = out(x, y, z)
         self.assertEqual(out_val, expected_out)
+    @parametrize('_preallocate_and_copy', (True, False))
+    def test_chunk_jacrev(self, device, _preallocate_and_copy):
+        x = torch.randn(10, 2, device=device)
+        y = torch.randn(1, 2, device=device)
+        def f(x, y):
+            return (x.sin(), x + y), (x + 2, x.sum())
+        for chunk_size in (1, 2, 3, 4, 7, 10, 1000):
+            expected = jacrev(f, argnums=(0, 1))(x, y)
+            actual = jacrev(f, argnums=(0, 1),
+                            chunk_size=chunk_size,
+                            _preallocate_and_copy=_preallocate_and_copy)(x, y)
+            self.assertEqual(actual, expected)
+        err_msg = "jacrev: `chunk_size` should be greater than 0."
+        with self.assertRaisesRegex(ValueError, err_msg):
+            jacrev(f, argnums=(0, ), chunk_size=0)(x, y)
+        with self.assertRaisesRegex(ValueError, err_msg):
+            jacrev(f, argnums=(0, ), chunk_size=-2)(x, y)
+    @parametrize('_preallocate_and_copy', (True, False))
+    def test_chunk_jacrev_composition(self, device, _preallocate_and_copy):
+        x = torch.randn(10, 2, device=device)
+        chunk_size = 3
+        def f(x):
+            return (x.sin(), x), (x + 2, x.sum())
+        expected = vmap(jacrev(jacrev(f)))(x)
+        actual = vmap(jacrev(jacrev(f, chunk_size=chunk_size,
+                                    _preallocate_and_copy=_preallocate_and_copy), chunk_size=chunk_size))(x)
+        self.assertEqual(actual, expected)
 class TestHessian(TestCase):
     def _test_against_reference(self, f, inputs):

@@ -336,7 +336,9 @@ def _safe_zero_index(x):
     return x[0]
-def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
+def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
+           chunk_size: Optional[int] = None,
+           _preallocate_and_copy=False):
     """
     Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
     :attr:`argnum` using reverse mode autodiff
@@ -352,6 +354,13 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
             the function to be differentiated and the second element is
             auxiliary objects that will not be differentiated.
             Default: False.
+        chunk_size (None or int): If None (default), use the maximum chunk size
+            (equivalent to doing a single vmap over vjp to compute the jacobian).
+            If not None, then compute the jacobian :attr:`chunk_size` rows at a time
+            (equivalent to doing multiple vmap over vjp).
+            Note that :attr:`chunk_size=1` is equivalent to computing the jacobian
+            row-by-row with a for-loop. If you run into memory issues computing
+            the jacobian, please try to specify a non-None chunk_size.
     Returns:
         Returns a function that takes in the same inputs as :attr:`func` and
@@ -452,6 +461,9 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
     outer one. This is because ``jacrev`` is a "function transform": its result
     should not depend on the result of a context manager outside of ``f``.
     """
+    if not (chunk_size is None or chunk_size > 0):
+        raise ValueError("jacrev: `chunk_size` should be greater than 0.")
     @wraps(func)
     def wrapper_fn(*args):
         vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
@@ -466,22 +478,73 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
         # NB: vjp already checks that all outputs are tensors
         # Step 1: Construct grad_outputs by splitting the standard basis
         flat_output_numels = tuple(out.numel() for out in flat_output)
-        flat_basis = _construct_standard_basis_for(flat_output, flat_output_numels)
-        basis = tree_unflatten(flat_basis, output_spec)
-        results = vmap(vjp_fn)(basis)
         primals = _slice_argnums(args, argnums)
         flat_primals, primals_spec = tree_flatten(primals)
-        flat_results, results_spec = tree_flatten(results)
+        def compute_jacobian_stacked():
+            # Helper function to compute chunked Jacobian
+            # The intermediate chunked calculation are only
+            # scoped at this function level.
+            chunked_results = []
+            for flat_basis_chunk in _chunked_standard_basis_for_(flat_output,
+                                                                 flat_output_numels,
+                                                                 chunk_size=chunk_size):
+                basis = tree_unflatten(flat_basis_chunk, output_spec)
+                chunked_result = vmap(vjp_fn)(basis)
+                flat_results, _ = tree_flatten(chunked_result)
+                chunked_results.append(flat_results)
+            if len(chunked_results) == 1:
+                # Short-circuit if we used a single chunk
+                return chunked_results[0]
+            # Concatenate chunks.
+            flat_results = []
+            # Iterate and concat the jacobians of different
+            # inputs.
+            for idx in range(len(flat_primals)):
+                r = tuple(map(lambda r_: r_[idx], chunked_results))
+                flat_results.append(torch.cat(r, 0))
+            return flat_results
+        def compute_jacobian_preallocate_and_copy():
+            # Helper function to compute chunked Jacobian
+            # The intermediate chunked calculation are only
+            # scoped at this function level.
+            out_vec_size = sum(flat_output_numels)
+            # Don't pre-allocate if we have a single chunk.
+            if not (chunk_size is None or chunk_size >= out_vec_size):
+                stacked_results = [primal.new_zeros(out_vec_size, *primal.shape) for primal in flat_primals]
+            for idx, flat_basis_chunk in enumerate(_chunked_standard_basis_for_(flat_output,
+                                                                                flat_output_numels,
+                                                                                chunk_size=chunk_size)):
+                basis = tree_unflatten(flat_basis_chunk, output_spec)
+                chunked_result = vmap(vjp_fn)(basis)
+                flat_results, _ = tree_flatten(chunked_result)
+                if chunk_size is None or chunk_size >= out_vec_size:
+                    # Short-circuit if we have a single chunk.
+                    return flat_results
+                for r, sr in zip(flat_results, stacked_results):
+                    sr[idx * chunk_size: (idx + 1) * chunk_size].copy_(r)
+            return stacked_results
+        if _preallocate_and_copy:
+            flat_jacobians_per_input = compute_jacobian_preallocate_and_copy()
+        else:
+            flat_jacobians_per_input = compute_jacobian_stacked()
         # Step 2: The returned jacobian is one big tensor per input. In this step,
         # we split each Tensor by output.
-        flat_results = [result.split(flat_output_numels, dim=0) for result in flat_results]
+        flat_jacobians_per_input = [result.split(flat_output_numels, dim=0) for result in flat_jacobians_per_input]
         flat_input_flat_output = [
             tuple(split.view(out.shape + primal.shape)
                   for split, out in zip(splits, flat_output))
-            for splits, primal in zip(flat_results, flat_primals)
+            for splits, primal in zip(flat_jacobians_per_input, flat_primals)
         ]
         # Step 3: Right now, `jacobian` is a List[List[Tensor]].
@@ -547,7 +610,7 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
 # - one of shape [ 3] for the second output
-def _construct_standard_basis_for(tensors, tensor_numels):
+def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
     # This function:
     # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
     # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
@@ -565,17 +628,37 @@ def _construct_standard_basis_for(tensors, tensor_numels):
     #
     # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
     # for context behind this function.
+    # NOTE: Argument `chunk_size` is used to generate chunked basis instead of
+    # one huge basis matrix. `chunk_size` dictates the maximum size of the
+    # basis matrix along dim=0.
     assert len(tensors) == len(tensor_numels)
     assert len(tensors) > 0
+    assert chunk_size is None or chunk_size > 0
     total_numel = sum(tensor_numels)
+    if chunk_size and chunk_size < total_numel:
+        n_chunks = total_numel // chunk_size
+        chunk_numels = [chunk_size] * n_chunks
+        # remainder chunk
+        chunk_numels.append(total_numel % chunk_size)
+    else: # chunk_size is None or chunk_size >= total_numel
+        chunk_size = total_numel
+        chunk_numels = [total_numel]
     diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())
-    chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
-                   for tensor, tensor_numel in zip(tensors, tensor_numels))
-    for chunk, diag_start_idx in zip(chunks, diag_start_indices):
-        chunk.diagonal(diag_start_idx).fill_(1)
-    chunks = tuple(chunk.view(total_numel, *tensor.shape)
-                   for chunk, tensor in zip(chunks, tensors))
-    return chunks
+    for chunk_idx, total_numel in enumerate(chunk_numels):
+        chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
+                       for tensor, tensor_numel in zip(tensors, tensor_numels))
+        for chunk, diag_start_idx in zip(chunks, diag_start_indices):
+            chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1)
+        chunks = tuple(chunk.view(total_numel, *tensor.shape)
+                       for chunk, tensor in zip(chunks, tensors))
+        yield chunks
+def _construct_standard_basis_for(tensors, tensor_numels):
+    for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
+        return basis
 def _validate_and_wrap_argnum(argnum, num_args):