jacrev : Support chunked computation (#89376)
Ref: https://github.com/pytorch/functorch/issues/680

We introduce a kwarg `chunk_size` in `jacrev` to control whether the Jacobian computation should be chunked and, if so, the maximum size of the chunks used. We try two approaches:

* Stacked approach: append the intermediate computations to a list, then stack those results.
* Pre-allocation approach: pre-allocate a zeros tensor and copy each chunked computation into it.

For the memory benchmark, see https://github.com/pytorch/pytorch/pull/89376#issuecomment-1348479098

Benchmark CPU: performs better with more chunks / smaller `chunk_size`. NOTE: There seems to be a lot of noise for shape `(64, 64)`.

<details>

```
[----------------------------------------------- jacrev : device cpu : chunks 2 -----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080      |            76.2               |         50.9         |              80.1
      (128, 128) : chunk_size 8256    |          1172.8               |        783.3         |            1225.5
      (128, 144) : chunk_size 9288    |          1475.1               |        990.4         |            1548.3
      (144, 144) : chunk_size 10440   |          1871.3               |       1254.4         |            1971.2

Times are in milliseconds (ms).

[----------------------------------------------- jacrev : device cpu : chunks 3 -----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386      |            39.9               |         25.8         |              58.8
      (128, 128) : chunk_size 5504    |          1182.6               |        782.2         |            1229.7
      (128, 144) : chunk_size 6192    |          1483.6               |        995.4         |            1550.6
      (144, 144) : chunk_size 6960    |          1879.1               |       1257.7         |            1960.5

Times are in milliseconds (ms).

[----------------------------------------------- jacrev : device cpu : chunks 4 -----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040      |            41.7               |         50.6         |              29.1
      (128, 128) : chunk_size 4128    |          1171.6               |        782.3         |            1226.7
      (128, 144) : chunk_size 4644    |          1482.2               |        994.6         |            1550.9
      (144, 144) : chunk_size 5220    |          1870.2               |       1254.5         |            1961.4

Times are in milliseconds (ms).

[---------------------------------------------- jacrev : device cpu : chunks 100 ----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41        |            46.8               |         50.5         |              46.4
      (128, 128) : chunk_size 165     |           622.2               |        775.2         |             656.0
      (128, 144) : chunk_size 185     |           803.9               |        987.3         |             866.9
      (144, 144) : chunk_size 208     |          1021.1               |       1251.2         |            1088.2

Times are in milliseconds (ms).

[---------------------------------------------- jacrev : device cpu : chunks 200 ----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 20        |            60.9               |         50.2         |              62.3
      (128, 128) : chunk_size 82      |           583.1               |        779.4         |             634.3
      (128, 144) : chunk_size 92      |           834.1               |       1005.8         |             472.3
      (144, 144) : chunk_size 104     |          1053.6               |       1277.0         |            1033.9

Times are in milliseconds (ms).

[---------------------------------------------- jacrev : device cpu : chunks 300 ----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 13        |            77.7               |         50.4         |              79.6
      (128, 128) : chunk_size 55      |           578.9               |        782.3         |             626.9
      (128, 144) : chunk_size 61      |           718.2               |       1024.9         |             800.4
      (144, 144) : chunk_size 69      |           919.7               |       1313.7         |            1023.0

Times are in milliseconds (ms).
```

</details>

Benchmark CUDA: performs better with fewer chunks / bigger `chunk_size`.

<details>

```
[--------------------------------------------- jacrev : device cuda:1 : chunks 2 ----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080      |          1485.7               |        923.8         |            1632.3
      (128, 128) : chunk_size 8256    |         25390.2               |      14103.2         |           33557.4
      (128, 144) : chunk_size 9288    |           801.7               |      16854.1         |           42894.6
      (144, 144) : chunk_size 10440   |          1003.5               |      21386.5         |           59648.5

Times are in microseconds (us).

[--------------------------------------------- jacrev : device cuda:1 : chunks 3 ----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386      |          1474.5               |        924.5         |            1655.5
      (128, 128) : chunk_size 5504    |         25368.9               |      10156.0         |           34022.1
      (128, 144) : chunk_size 6192    |         25223.0               |      12933.7         |           56418.5
      (144, 144) : chunk_size 6960    |         24729.3               |      16367.4         |           68744.7

Times are in microseconds (us).

[--------------------------------------------- jacrev : device cuda:1 : chunks 4 ----------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040      |          1489.2               |        924.4         |            1679.6
      (128, 128) : chunk_size 4128    |         25370.4               |       8987.4         |           57201.3
      (128, 144) : chunk_size 4644    |         32239.1               |      10136.2         |           72406.5
      (144, 144) : chunk_size 5220    |         40994.3               |      12867.8         |          108653.4

Times are in microseconds (us).

[-------------------------------------------- jacrev : device cuda:1 : chunks 100 ---------------------------------------------]
                                      |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ----------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41        |         21121.8               |        924.2         |           22753.5
      (128, 128) : chunk_size 165     |         23679.7               |      14284.4         |           26758.2
      (128, 144) : chunk_size 185     |         30082.3               |      18063.3         |           33553.5
      (144, 144) : chunk_size 208     |         38175.6               |      22839.5         |           42030.0

Times are in microseconds (us).
```

</details>

Benchmark script:

<details>

```python
import functorch
import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle
from torch import profiler
import math


def prod(l):
    prod = 1
    for el in l:
        prod *= el
    return prod


def fn(x, y):
    return x + y, x.sum(0)


shapes = ((64, 64), (128, 128), (128, 144), (144, 144))

for device in ('cpu', 'cuda:1'):
    if device == 'cuda:1':
        chunks = (2, 3, 4, 100,)
    else:
        chunks = (2, 3, 4, 100, 200, 300)

    for chunk in chunks:
        results = []
        for shape in shapes:
            x = torch.zeros(*shape, dtype=torch.float, device=device)
            y = x.sum()
            chunk_size = (prod(shape) + prod(shape[1:])) // chunk
            jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size)
            jacrev_fn_chunked_pre = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size,
                                                     _preallocate_and_copy=True)
            jacrev_fn = functorch.jacrev(fn, (0, 1), chunk_size=None)

            tasks = [("jacrev_fn_chunked(x, y)", "with chunk_size and stacked"),
                     ("jacrev_fn(x, y)", "without chunk_size"),
                     ("jacrev_fn_chunked_pre(x, y)", "with chunk_size and pre-allocated"), ]
            timers = [Timer(stmt=stmt,
                            label=f"jacrev : device {device} : chunks {chunk}",
                            sub_label=f"{(shape)} : chunk_size {chunk_size}",
                            description=desc, globals=globals()) for stmt, desc in tasks]

            for i, timer in enumerate(timers):
                results.append(timer.blocked_autorange(min_run_time=2.))
                print(f"\r{i + 1} / {len(timers)} : Shape {shape} : Device {device} : chunks: {chunk}", end="")
                sys.stdout.flush()
            print()

        comparison = Compare(results)
        comparison.print()
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89376
Approved by: https://github.com/zou3519
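For orientation, here is a minimal usage sketch of the new kwarg (not part of the PR; the function `f` and the shapes below are made up for illustration). All three paths should agree numerically, with `_preallocate_and_copy` selecting the pre-allocation strategy benchmarked above:

```python
import torch
import functorch

def f(x):
    return x.sin().sum(0), x ** 2

x = torch.randn(128, 64)

# Default: one vmap(vjp) over the full standard basis.
jac_full = functorch.jacrev(f)(x)

# Chunked: the basis is pushed through vjp at most 512 rows at a time,
# trading some speed for lower peak memory.
jac_chunked = functorch.jacrev(f, chunk_size=512)(x)

# Same chunking, but chunks are copied into a pre-allocated buffer
# instead of being stacked at the end.
jac_prealloc = functorch.jacrev(f, chunk_size=512, _preallocate_and_copy=True)(x)

for a, b, c in zip(jac_full, jac_chunked, jac_prealloc):
    assert torch.allclose(a, b) and torch.allclose(a, c)
```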
This commit is contained in:
parent
e2dc60c6cb
commit
f02e93b584
```diff
@@ -1872,6 +1872,41 @@ class TestJac(TestCase):
         out_val = out(x, y, z)
         self.assertEqual(out_val, expected_out)
 
+    @parametrize('_preallocate_and_copy', (True, False))
+    def test_chunk_jacrev(self, device, _preallocate_and_copy):
+        x = torch.randn(10, 2, device=device)
+        y = torch.randn(1, 2, device=device)
+
+        def f(x, y):
+            return (x.sin(), x + y), (x + 2, x.sum())
+
+        for chunk_size in (1, 2, 3, 4, 7, 10, 1000):
+            expected = jacrev(f, argnums=(0, 1))(x, y)
+            actual = jacrev(f, argnums=(0, 1),
+                            chunk_size=chunk_size,
+                            _preallocate_and_copy=_preallocate_and_copy)(x, y)
+            self.assertEqual(actual, expected)
+
+        err_msg = "jacrev: `chunk_size` should be greater than 0."
+        with self.assertRaisesRegex(ValueError, err_msg):
+            jacrev(f, argnums=(0, ), chunk_size=0)(x, y)
+
+        with self.assertRaisesRegex(ValueError, err_msg):
+            jacrev(f, argnums=(0, ), chunk_size=-2)(x, y)
+
+    @parametrize('_preallocate_and_copy', (True, False))
+    def test_chunk_jacrev_composition(self, device, _preallocate_and_copy):
+        x = torch.randn(10, 2, device=device)
+        chunk_size = 3
+
+        def f(x):
+            return (x.sin(), x), (x + 2, x.sum())
+
+        expected = vmap(jacrev(jacrev(f)))(x)
+        actual = vmap(jacrev(jacrev(f, chunk_size=chunk_size,
+                                    _preallocate_and_copy=_preallocate_and_copy), chunk_size=chunk_size))(x)
+        self.assertEqual(actual, expected)
+
+
 class TestHessian(TestCase):
     def _test_against_reference(self, f, inputs):
```
```diff
@@ -336,7 +336,9 @@ def _safe_zero_index(x):
     return x[0]
 
 
-def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
+def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
+           chunk_size: Optional[int] = None,
+           _preallocate_and_copy=False):
     """
     Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
     :attr:`argnum` using reverse mode autodiff
```
```diff
@@ -352,6 +354,13 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
             the function to be differentiated and the second element is
             auxiliary objects that will not be differentiated.
             Default: False.
+        chunk_size (None or int): If None (default), use the maximum chunk size
+            (equivalent to doing a single vmap over vjp to compute the jacobian).
+            If not None, then compute the jacobian :attr:`chunk_size` rows at a time
+            (equivalent to doing multiple vmap over vjp).
+            Note that :attr:`chunk_size=1` is equivalent to computing the jacobian
+            row-by-row with a for-loop. If you run into memory issues computing
+            the jacobian, please try to specify a non-None chunk_size.
 
     Returns:
         Returns a function that takes in the same inputs as :attr:`func` and
```
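The docstring's claim that `chunk_size=1` is equivalent to a row-by-row for-loop can be checked directly. Here is a standalone sketch (hypothetical `f` and shapes) built on the `functorch.vjp` API that `jacrev` itself uses:

```python
import torch
import functorch

def f(x):
    return x.sin()

x = torch.randn(5)
out, vjp_fn = functorch.vjp(f, x)

# Row-by-row: one vjp call per one-hot basis vector, i.e. one jacobian
# row at a time, then stack the rows.
rows = [vjp_fn(basis_row)[0] for basis_row in torch.eye(out.numel())]
jac_loop = torch.stack(rows)

# chunk_size=1 should match the explicit loop exactly.
jac_chunk1 = functorch.jacrev(f, chunk_size=1)(x)
assert torch.allclose(jac_loop, jac_chunk1)
```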
```diff
@@ -452,6 +461,9 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
     outer one. This is because ``jacrev`` is a "function transform": its result
     should not depend on the result of a context manager outside of ``f``.
     """
+    if not (chunk_size is None or chunk_size > 0):
+        raise ValueError("jacrev: `chunk_size` should be greater than 0.")
+
     @wraps(func)
     def wrapper_fn(*args):
         vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
```
```diff
@@ -466,22 +478,73 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
         # NB: vjp already checks that all outputs are tensors
         # Step 1: Construct grad_outputs by splitting the standard basis
         flat_output_numels = tuple(out.numel() for out in flat_output)
-        flat_basis = _construct_standard_basis_for(flat_output, flat_output_numels)
-        basis = tree_unflatten(flat_basis, output_spec)
-
-        results = vmap(vjp_fn)(basis)
 
         primals = _slice_argnums(args, argnums)
         flat_primals, primals_spec = tree_flatten(primals)
-        flat_results, results_spec = tree_flatten(results)
+
+        def compute_jacobian_stacked():
+            # Helper function to compute chunked Jacobian
+            # The intermediate chunked calculation are only
+            # scoped at this function level.
+            chunked_results = []
+            for flat_basis_chunk in _chunked_standard_basis_for_(flat_output,
+                                                                 flat_output_numels,
+                                                                 chunk_size=chunk_size):
+                basis = tree_unflatten(flat_basis_chunk, output_spec)
+                chunked_result = vmap(vjp_fn)(basis)
+                flat_results, _ = tree_flatten(chunked_result)
+                chunked_results.append(flat_results)
+
+            if len(chunked_results) == 1:
+                # Short-circuit if we used a single chunk
+                return chunked_results[0]
+
+            # Concatenate chunks.
+            flat_results = []
+            # Iterate and concat the jacobians of different
+            # inputs.
+            for idx in range(len(flat_primals)):
+                r = tuple(map(lambda r_: r_[idx], chunked_results))
+                flat_results.append(torch.cat(r, 0))
+
+            return flat_results
+
+        def compute_jacobian_preallocate_and_copy():
+            # Helper function to compute chunked Jacobian
+            # The intermediate chunked calculation are only
+            # scoped at this function level.
+            out_vec_size = sum(flat_output_numels)
+
+            # Don't pre-allocate if we have a single chunk.
+            if not (chunk_size is None or chunk_size >= out_vec_size):
+                stacked_results = [primal.new_zeros(out_vec_size, *primal.shape) for primal in flat_primals]
+            for idx, flat_basis_chunk in enumerate(_chunked_standard_basis_for_(flat_output,
+                                                                                flat_output_numels,
+                                                                                chunk_size=chunk_size)):
+                basis = tree_unflatten(flat_basis_chunk, output_spec)
+                chunked_result = vmap(vjp_fn)(basis)
+                flat_results, _ = tree_flatten(chunked_result)
+                if chunk_size is None or chunk_size >= out_vec_size:
+                    # Short-circuit if we have a single chunk.
+                    return flat_results
+
+                for r, sr in zip(flat_results, stacked_results):
+                    sr[idx * chunk_size: (idx + 1) * chunk_size].copy_(r)
+
+            return stacked_results
+
+        if _preallocate_and_copy:
+            flat_jacobians_per_input = compute_jacobian_preallocate_and_copy()
+        else:
+            flat_jacobians_per_input = compute_jacobian_stacked()
 
         # Step 2: The returned jacobian is one big tensor per input. In this step,
         # we split each Tensor by output.
-        flat_results = [result.split(flat_output_numels, dim=0) for result in flat_results]
+        flat_jacobians_per_input = [result.split(flat_output_numels, dim=0) for result in flat_jacobians_per_input]
         flat_input_flat_output = [
             tuple(split.view(out.shape + primal.shape)
                   for split, out in zip(splits, flat_output))
-            for splits, primal in zip(flat_results, flat_primals)
+            for splits, primal in zip(flat_jacobians_per_input, flat_primals)
         ]
 
         # Step 3: Right now, `jacobian` is a List[List[Tensor]].
```
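As a standalone illustration of the two assembly strategies in the hunk above (simplified to a single flat primal; the helper names and shapes are hypothetical), both produce the same tensor and differ only in whether intermediate chunks are kept alive until a final `torch.cat` or written into a pre-allocated buffer in place:

```python
import torch

def assemble_stacked(chunks):
    # Stacked approach: collect every chunk, concatenate once at the end.
    return torch.cat(list(chunks), dim=0)

def assemble_preallocated(chunks, out_vec_size, chunk_size, primal_shape):
    # Pre-allocation approach: copy each chunk into a zeros buffer in place.
    buf = torch.zeros(out_vec_size, *primal_shape)
    for idx, r in enumerate(chunks):
        buf[idx * chunk_size: (idx + 1) * chunk_size].copy_(r)
    return buf

chunk_size, out_vec_size, primal_shape = 3, 8, (2,)
chunks = torch.randn(out_vec_size, *primal_shape).split(chunk_size, dim=0)
assert torch.equal(
    assemble_stacked(chunks),
    assemble_preallocated(chunks, out_vec_size, chunk_size, primal_shape),
)
```

This is the trade-off the memory benchmark linked above measures: the stacked variant briefly holds all chunks plus the concatenated result, while the pre-allocated variant writes into a single buffer.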
```diff
@@ -547,7 +610,7 @@ def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False
 # - one of shape [ 3] for the second output
 
 
-def _construct_standard_basis_for(tensors, tensor_numels):
+def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
     # This function:
     # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
     # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
```
```diff
@@ -565,17 +628,37 @@ def _construct_standard_basis_for(tensors, tensor_numels):
     #
     # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
     # for context behind this function.
+    # NOTE: Argument `chunk_size` is used to generate chunked basis instead of
+    #       one huge basis matrix. `chunk_size` dictates the maximum size of the
+    #       basis matrix along dim=0.
     assert len(tensors) == len(tensor_numels)
     assert len(tensors) > 0
+    assert chunk_size is None or chunk_size > 0
     total_numel = sum(tensor_numels)
+    if chunk_size and chunk_size < total_numel:
+        n_chunks = total_numel // chunk_size
+        chunk_numels = [chunk_size] * n_chunks
+        # remainder chunk
+        chunk_numels.append(total_numel % chunk_size)
+    else:  # chunk_size is None or chunk_size >= total_numel
+        chunk_size = total_numel
+        chunk_numels = [total_numel]
+
     diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())
-    chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
-                   for tensor, tensor_numel in zip(tensors, tensor_numels))
-    for chunk, diag_start_idx in zip(chunks, diag_start_indices):
-        chunk.diagonal(diag_start_idx).fill_(1)
-    chunks = tuple(chunk.view(total_numel, *tensor.shape)
-                   for chunk, tensor in zip(chunks, tensors))
-    return chunks
+
+    for chunk_idx, total_numel in enumerate(chunk_numels):
+        chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
+                       for tensor, tensor_numel in zip(tensors, tensor_numels))
+
+        for chunk, diag_start_idx in zip(chunks, diag_start_indices):
+            chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1)
+        chunks = tuple(chunk.view(total_numel, *tensor.shape)
+                       for chunk, tensor in zip(chunks, tensors))
+        yield chunks
+
+def _construct_standard_basis_for(tensors, tensor_numels):
+    for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
+        return basis
 
 
 def _validate_and_wrap_argnum(argnum, num_args):
```
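To see what the new generator yields, here is a simplified single-tensor model of `_chunked_standard_basis_for_` (the helper name `chunked_eye` is hypothetical; the real function additionally handles multiple output tensors via per-tensor diagonal offsets). Concatenating the yielded chunks along dim 0 reproduces the full one-hot basis:

```python
import torch

def chunked_eye(total, chunk_size):
    # Mirrors the hunk's chunking rule: full chunks of `chunk_size` rows
    # plus a smaller remainder chunk at the end.
    for start in range(0, total, chunk_size):
        n = min(chunk_size, total - start)
        chunk = torch.zeros(n, total)
        # A positive diagonal offset shifts the ones right by `start`
        # columns, i.e. row i of this chunk is basis vector e_{start+i}.
        chunk.diagonal(start).fill_(1)
        yield chunk

total, chunk_size = 7, 3
full = torch.cat(list(chunked_eye(total, chunk_size)), dim=0)
assert torch.equal(full, torch.eye(total))
```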