expanded weights without fast rules (#70140)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70140

The [Design Doc for Expanded Weights](https://gist.github.com/samdow/fa0a164fec7963f93ff45284989cfc55) gives an overview of the design for Expanded Weights.

This introduces the ExpandedWeights mechanism and the user-facing API without any custom-implemented, faster rules.
- The user-facing API is in `_stateless.py` (with documentation)
- Testing is in test_expanded_weights
- The rest is the implementation of the erroring fallback plus the mechanism for registering faster per sample grad rules. Only linear is implemented here; the rest are implemented in #70141

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D34350950

Pulled By: samdow

fbshipit-source-id: 69c664b0bc3dff6951358d79d7e5d94882f7aef2
This commit is contained in: parent 999cb73e93 · commit ae1620d3b6
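
For orientation before the file-by-file diff below, here is a minimal usage sketch of the API this commit introduces; it mirrors what the new tests in test_expanded_weights.py and the docstring of `call_for_per_sample_grads` check (per-sample gradients land in a `grad_sample` field, while the regular `.grad` stays unset):

```python
import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

model = nn.Linear(4, 3)
batched_input = torch.randn(5, 4)  # batch size of 5
call_for_per_sample_grads(model, 5, batched_input).sum().backward()
print(model.weight.grad_sample.shape)  # torch.Size([5, 3, 4]): one gradient per sample
print(model.weight.grad)               # None: the regular .grad is not populated
```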
387  test/test_expanded_weights.py  Normal file
@@ -0,0 +1,387 @@
# Owner(s): ["module: nn"]

from functools import partial
from itertools import product
import unittest

import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests
from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests
from torch.testing._internal.common_methods_invocations import SampleInput, op_db
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \
    unpack_expanded_weight_or_tensor, sum_over_all_but_batch_and_last_n, standard_kwargs

class TestContext:
    pass

class TestExpandedWeightHelperFunction(TestCase):
    def test_forward_helper(self, device):
        input = torch.randn(3, 4, device=device)
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        for (weight_batched, bias_batched) in product([True, False], [True, False]):
            maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3) if weight_batched else weight
            maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3) if bias_batched else bias
            args = (input, maybe_batched_weight, maybe_batched_bias)
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), args)
            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
            expected = nn.functional.linear(input, weight, bias)
            self.assertEqual(res, expected)

            self.assertEqual(len(expanded_args), 2)
            assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
            assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
            self.assertEqual(len(expanded_kwargs), 1)
            assert expanded_kwargs['bias'] is args[2]  # avoids property checks in assertEquals

    def test_forward_helper_failure_args(self, device):
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."):
            input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3)
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"):
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (3, weight, bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(RuntimeError, r"requires a batch dimension but got an input of size 0"):
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.tensor(3), weight, bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(RuntimeError, r"0 is not a valid batch size for Expanded Weights"):
            expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.randn(0, 1, 2), weight, bias))
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        input = torch.randn(3, 4)
        for (weight_batched, bias_batched) in product([True, False], [True, False]):
            if not weight_batched and not bias_batched:
                continue
            maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4) if weight_batched else weight
            maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4) if bias_batched else bias
            with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"):
                expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias))
                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    def test_set_grad_sample_if_exists(self, device):
        def test_fn(_):
            return True

        orig_weight = torch.randn(4, device=device, requires_grad=True)
        expanded_weight = ExpandedWeight(orig_weight, 3)
        set_grad_sample_if_exists(expanded_weight, test_fn)
        self.assertTrue(hasattr(orig_weight, 'grad_sample'))
        self.assertTrue(orig_weight.grad_sample)

        basic_tensor = torch.randn(4, device=device)
        set_grad_sample_if_exists(basic_tensor, test_fn)
        self.assertFalse(hasattr(basic_tensor, 'grad_sample'))

        non_tensor = 3
        set_grad_sample_if_exists(non_tensor, test_fn)
        self.assertFalse(hasattr(non_tensor, 'grad_sample'))

    def test_set_grad_sample_if_exists_failure(self, device):
        def test_fn(_):
            return True

        grad_tensor = torch.randn(4, requires_grad=True, device=device)
        with self.assertRaisesRegex(RuntimeError, r"does not support a mixture of ExpandedWeight parameters and normal Parameters"):
            set_grad_sample_if_exists(grad_tensor, test_fn)

    def test_unpack_expanded_weight_or_tensor(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3)))

        input.requires_grad_(False)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)

    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3), lambda x: x is input))

        input.requires_grad_(False)
        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None)

    def test_unpack_expanded_weight_or_tensor_failure(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(RuntimeError, r"does not support a mixture of ExpandedWeight parameters and normal Parameters"):
            unpack_expanded_weight_or_tensor(input)

        with self.assertRaisesRegex(RuntimeError, r"does not support a mixture of ExpandedWeight parameters and normal Parameters"):
            unpack_expanded_weight_or_tensor(input, lambda x: x is input)

    def test_sum_over_all_but_batch_and_last_n(self, device):
        input = torch.randn(1, 2, 3, 4, 5, device=device)
        res = sum_over_all_but_batch_and_last_n(input, 2)
        expected = input.sum((1, 2))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 0)
        expected = input.sum((1, 2, 3, 4))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 4)
        self.assertEqual(res, input)

class TestExpandedWeightFunctional(TestCase):
    @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
    def test_expanded_weight_per_sample_grad(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
            input = sample_input.input
            args = sample_input.args
            kwargs = sample_input.kwargs
            batch_size = input.shape[0] if len(input.shape) > 1 else 1

            # get per sample grads with ExpandedWeights objects
            (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
            diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
            diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
            diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
            if not diff_input_list:
                continue
            result = run_op(op, ew_input, *ew_args, **ew_kwargs)
            result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
            expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list)

            # get per sample grads with for loop
            func = partial(run_op, op)
            per_sample_grad = for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs)

            # check equality
            self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
            for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad):
                if result_grad is None:
                    result_grad = torch.zeros_like(expected_grad)
                assert torch.allclose(result_grad, expected_grad), f"Got {result_grad}, expected {expected_grad}"

    @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
    def test_unsupported_expand_weights(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False)
        for sample_input in unsupported_inputs:
            with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"):
                if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
                    sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
                input = sample_input.input

                batch_size = input.shape[0] if len(input.shape) > 1 else 1

                # get per sample grads with ExpandedWeights objects
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
                result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
                diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
                diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list]
                result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__

    @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported)
    def test_expanded_weight_forward(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in supported_inputs(op, sample_inputs):
            batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
            (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size)
            expanded_weight_result = op(ew_input, *ew_args, **ew_kwargs)
            normal_result = op(sample_input.input, *sample_input.args, **sample_input.kwargs)
            self.assertEqual(expanded_weight_result, normal_result)

    def test_expanded_weight_error(self, device):
        batch_size = 3
        sample_input = make_tensor((batch_size, 4), device, torch.float32, requires_grad=True)
        sample_weight = make_tensor((4), device, torch.float32, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"):
            torch.add(sample_input, ExpandedWeight(sample_weight, batch_size))


class TestExpandedWeightModule(TestCase):
    def _do_test(self, module, input):
        batch_size = input.shape[0]
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(module, batch_size, input).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample

            # get per sample grads with a for loop
            expected_res = torch.tensor(0.)
            expected_grads = []
            for i in range(batch_size):
                res = module(input[i].unsqueeze(0)).sum()
                expected_grads.append(torch.autograd.grad(res, module.parameters(), torch.ones_like(res)))
                expected_res += res
            expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        self.assertEqual(actual_res, expected_res)
        assert all(torch.allclose(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads))

    def _do_test_multi_input(self, module, input):
        class TestModule(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, input):
                return self.module(input) + self.module(input)

        batch_size = input.shape[0]
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager, calling .backward() twice
            test_module = TestModule(module)
            actual_res = call_for_per_sample_grads(test_module, batch_size, input).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample

            # get per sample grads with a for loop, running over the input twice
            expected_grads = []
            for i in range(batch_size):
                res = module(input[i].unsqueeze(0)).sum()
                expected_grads.append(torch.autograd.grad(res, module.parameters(), torch.ones_like(res)))
            expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        assert all(torch.allclose(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads))

    def test_per_sample_api_failing(self):
        module = nn.Linear(10, 10)
        input = torch.randn(64, 10)
        with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
            call_for_per_sample_grads("fail", 64, input)
        with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be an integer"):
            call_for_per_sample_grads(module, 6.4, input)
        with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
            call_for_per_sample_grads(module, -64, input)
        with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
            loss = call_for_per_sample_grads(module, 64, input).sum()
            loss.backward()  # populate grad_sample fields
            call_for_per_sample_grads(module, 64, input)

class ContextManagerTests(TestBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', False)

    def test_context_manager(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest("Can't get per sample gradients when no batch dim or batch dim is 0")
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest("Can't get per sample gradients for input of rank 1")
        test_case._do_test(module, input)

    def test_context_manager_multiple_inputs(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest("Can't get per sample gradients when no batch dim or batch dim is 0")
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest("Can't get per sample gradients for input of rank 1")
        test_case._do_test_multi_input(module, input)

# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests
supported_modules = ['Linear']
supported_tests = [t for t in module_tests + new_module_tests if 'module_name' in t and t['module_name'] in supported_modules]
for test_param in supported_tests:
    if 'constructor' not in test_param:
        name = test_param.pop('module_name')
        test_param['constructor'] = getattr(nn, name)
    decorator = test_param.pop('decorator', None)
    test = ContextManagerTests(**test_param)
    test_name = test.get_name()
    if hasattr(TestExpandedWeightModule, test_name):
        raise RuntimeError('Found two tests with the same name: ' + test_name)
    test_name_multi_input = test.get_name() + "_multiple_inputs"
    if hasattr(TestExpandedWeightModule, test_name_multi_input):
        raise RuntimeError('Found two tests with the same name: ' + test_name_multi_input)
    # build the test functions before registering them so an optional decorator can wrap them
    fn = lambda self, test=test: test.test_context_manager(self)
    fn_multi_input = lambda self, test=test: test.test_context_manager_multiple_inputs(self)
    if decorator is not None:
        fn = decorator(fn)
        fn_multi_input = decorator(fn_multi_input)
    setattr(TestExpandedWeightModule, test_name, fn)
    setattr(TestExpandedWeightModule, test_name_multi_input, fn_multi_input)

# ------------- HELPER FUNCTIONS -----------------

def run_op(op, input, *args, **kwargs):
    r"""
    OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative
    of the weight, not the input, which can't be differentiable since its dtype is int. Calls op,
    using the special ordering that Embedding's OpInfo expects for that case.
    """
    if op.name == "nn.functional.embedding":
        return op(args[0], input, **kwargs)
    else:
        return op(input, *args, **kwargs)

def make_expanded_weight(sample_input, batch_size):
    def expanded_weight_or_clone(arg):
        return ExpandedWeight(torch.clone(arg), batch_size) if is_diff_tensor(arg) else clone_if_tensor(arg)

    ew_input = clone_if_tensor(sample_input.input)
    ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
    ew_kwargs = {name: expanded_weight_or_clone(arg) for (name, arg) in sample_input.kwargs.items()}
    return ew_input, ew_args, ew_kwargs

def supported_inputs(op, sample_inputs, supported_inputs=True):
    r"""
    ExpandedWeights currently does not support some use cases when there's no batch dimension or
    operations that would cause inter-batch operations. Removes all of the cases it cannot deal with
    """
    def filter_fn(input):
        if op.name == "nn.functional.linear":
            is_supported_input = len(input.input.shape) > 1  # input of rank 1 means no batch dim
        elif op.name == "nn.functional.layer_norm":
            normalized_shape = input.args[0]
            is_supported_input = input.input.shape != normalized_shape  # would cause inter-batch operations
        elif op.name == "nn.functional.conv2d":
            # currently can't deal with padding computation on Python level
            is_supported_input = 'padding' not in input.kwargs or not isinstance(input.kwargs['padding'], str)
        elif op.name == "nn.functional.embedding":
            idx = input.args[0]
            is_supported_input = len(idx.shape) > 1  # there's no batch size
        else:
            is_supported_input = True
        is_supported_input = is_supported_input and input.input.shape[0] > 0  # 0 is not a valid batch size
        return is_supported_input if supported_inputs else not is_supported_input
    return [input for input in sample_inputs if filter_fn(input)]

def for_loop_per_sample_grad(batch_size, input, func, *args, **kwargs):
    # get per sample grads by getting derivative for each input in a for loop
    per_sample_grad = []
    for i in range(batch_size):
        per_sample_input = input[i]
        result = func(per_sample_input.unsqueeze(0), *args, **kwargs)
        diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
        diff_input_list = [i for i in diff_input_list if isinstance(i, torch.Tensor) and i.requires_grad]
        per_sample_grad.append(torch.autograd.grad(result, diff_input_list, torch.ones_like(result), allow_unused=True))
    if len(per_sample_grad) == batch_size:
        per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad))
    return per_sample_grad

def is_diff_tensor(t):
    return isinstance(t, ExpandedWeight) or (isinstance(t, torch.Tensor) and t.requires_grad)

def clone_if_tensor(t):
    if isinstance(t, torch.Tensor):
        res = torch.clone(t).detach()
        res.requires_grad_(t.requires_grad)
        return res
    else:
        return t

instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
if __name__ == '__main__':
    run_tests()
4  torch/nn/utils/_expanded_weights/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .linear_expanded_weights import LinearPerSampleGrad
from .expanded_weights_impl import ExpandedWeight

__all__ = ['ExpandedWeight']
59  torch/nn/utils/_expanded_weights/expanded_weights_impl.py  Normal file
@@ -0,0 +1,59 @@
from torch._C import _TensorBase
import torch
import functools

from typing import Callable, Dict, cast

HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}

def implements_per_sample_grads(torch_function):
    @functools.wraps(torch_function)
    def decorator(autograd_func):
        HANDLED_FUNCTIONS[torch_function] = autograd_func
        return autograd_func
    return decorator

# ExpandedWeight represents a weight (parameter) Tensor that has an expanded
# batch dimension. Operations on the ExpandedWeight Tensor act exactly like
# those without an expanded batch dimension, but a call to .backward() populates
# the original (unexpanded) tensor with per-sample gradients in the grad_sample field
#
# ExpandedWeight has a fallback that always fails since we cannot know what the batch
# dimension of the input tensor is and therefore cannot know if this is a valid call
#
# This is a __torch_function__ object but it could have also been a Tensor Extension
# with a dispatch key.
#
# Needs to be a tensor subclass to allow reparametrization
class ExpandedWeight(torch.Tensor):
    def __init__(self, orig_weight, batch_size):
        self.batch_size = batch_size
        self.orig_weight = orig_weight

    handled_functions = HANDLED_FUNCTIONS

    def __new__(cls, orig_weight, _):
        if not isinstance(orig_weight, torch.Tensor):
            raise RuntimeError(f"Can only make Expanded Weights of Tensors, got {type(orig_weight).__name__}")
        if not orig_weight.requires_grad:
            raise RuntimeError("Can only build ExpandedWeights objects of tensors that require_grad")
        ret = torch.Tensor._make_subclass(cast(_TensorBase, cls), orig_weight, True)
        return ret

    @classmethod
    def __torch_function__(cls, func, _, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in cls.handled_functions:
            return cls.handled_functions[func].apply(tuple(kwargs.keys()), *(args + tuple(kwargs.values())))
        # We cannot use a fallback here because we do not know the batch dimension for any regular tensor inputs,
        # i.e. torch.add(torch.Tensor, ExpandedWeight)
        raise RuntimeError(f"Expanded Weights encountered but cannot handle function {func.__name__}")

    @property
    def dtype(self):
        return self.orig_weight.dtype

    @property
    def shape(self):
        return self.orig_weight.shape
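
A quick illustration (not part of the diff) of the fail-loudly fallback described in the comments above, mirroring `TestExpandedWeightFunctional.test_expanded_weight_error`; it assumes only that `torch.add` has no per-sample-grad rule registered, which is the case at this commit:

```python
import torch
from torch.nn.utils._expanded_weights import ExpandedWeight

weight = torch.randn(4, requires_grad=True)
x = torch.randn(3, 4)
try:
    # torch.add has no registered per-sample-grad rule, so __torch_function__ raises
    torch.add(x, ExpandedWeight(weight, 3))
except RuntimeError as e:
    print(e)  # "Expanded Weights encountered but cannot handle function add"
```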
100  torch/nn/utils/_expanded_weights/expanded_weights_utils.py  Normal file
@@ -0,0 +1,100 @@
import torch
from .expanded_weights_impl import ExpandedWeight

def standard_kwargs(kwarg_names, expanded_args):
    r'''Most `__torch_function__`s standardize the kwargs that they give, so this will separate
    the args and kwargs they pass. Functions that don't standardize them are linear and convND.
    '''
    kwarg_values = expanded_args[len(expanded_args) - len(kwarg_names):]
    expanded_args_without_kwargs = expanded_args[:len(expanded_args) - len(kwarg_names)]
    expanded_kwargs = {name: value for (name, value) in zip(kwarg_names, kwarg_values)}
    return expanded_args_without_kwargs, expanded_kwargs

def forward_helper(func, expanded_args, expanded_kwargs):
    r'''Forward helper computes the forward pass for a function that has expanded weight(s)
    passed to it. It will run the forward pass where all ExpandedWeights are their original
    weight. It runs checks on the given arguments before calling :attr:`func`.

    .. note:: First argument in :attr:`expanded_args` must be the input with the batch
    dimension as the first element of the shape

    .. note:: :attr:`func` must return a Tensor or tuple of Tensors

    Args:
        func: The function to be called
        expanded_args: Positional arguments to be passed to :attr:`func`. Will include arguments
          that need to be unpacked because they are ExpandedWeights
        expanded_kwargs: Keyword arguments to be passed to :attr:`func`. Will include arguments
          that need to be unpacked because they are ExpandedWeights
    '''
    unexpanded_args, unexpanded_kwargs = _check_and_unexpand_args(func, expanded_args, expanded_kwargs)
    return func(*unexpanded_args, **unexpanded_kwargs)

def _check_and_unexpand_args(func, expanded_args, expanded_kwargs):
    # input must be the first argument passed
    input = expanded_args[0]
    if isinstance(input, ExpandedWeight):
        raise RuntimeError("Expanded Weights do not support inputs that are also ExpandedWeights. "
                           f"Input must be a Tensor, got {type(input).__name__} in function {func.__name__}")
    if not isinstance(input, torch.Tensor):
        raise RuntimeError("Expanded Weights requires a Tensor as the first input to get the batch dimension, "
                           f"got {type(input).__name__} in function {func.__name__}")
    if len(input.shape) == 0:
        raise RuntimeError(f"Expanded Weights requires a batch dimension but got an input of size 0 in function {func.__name__}")
    if input.shape[0] == 0:
        raise RuntimeError("0 is not a valid batch size for Expanded Weights but got input tensor of "
                           f"{input} in function {func.__name__}")
    batch_size = input.shape[0]
    for arg in expanded_args + tuple(expanded_kwargs.values()):
        if isinstance(arg, ExpandedWeight) and arg.batch_size != batch_size:
            raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got "
                               f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}")

    unexpanded_args = tuple(arg.orig_weight if isinstance(arg, ExpandedWeight) else arg for arg in expanded_args)
    unexpanded_kwargs = {name: arg.orig_weight if isinstance(arg, ExpandedWeight) else arg
                         for (name, arg) in expanded_kwargs.items()}
    return unexpanded_args, unexpanded_kwargs

def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn):
    unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight)
    if isinstance(maybe_expanded_weight, ExpandedWeight):
        if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None:
            unpacked.grad_sample = unpacked.grad_sample + per_sample_grad_fn(unpacked)
        else:
            unpacked.grad_sample = per_sample_grad_fn(unpacked)

def unpack_expanded_weight_or_tensor(maybe_expanded_weight, func=lambda x: x):
    if isinstance(maybe_expanded_weight, ExpandedWeight):
        orig_weight = maybe_expanded_weight.orig_weight
        return func(orig_weight)
    elif isinstance(maybe_expanded_weight, torch.Tensor) and not maybe_expanded_weight.requires_grad:
        return func(maybe_expanded_weight)
    elif isinstance(maybe_expanded_weight, torch.Tensor):
        raise RuntimeError("ExpandedWeights currently does not support a mixture of ExpandedWeight parameters "
                           "and normal Parameters. Please file an issue with pytorch/pytorch")

def sum_over_all_but_batch_and_last_n(
    tensor: torch.Tensor, n_dims: int
) -> torch.Tensor:
    r"""
    Calculates the sum over all dimensions, except the first
    (batch dimension), and excluding the last n_dims.
    This function will ignore the first dimension and it will
    not aggregate over the last n_dims dimensions.
    Args:
        tensor: An input tensor of shape ``(B, ..., X[n_dims-1])``.
        n_dims: Number of dimensions to keep.
    Example:
        >>> tensor = torch.ones(1, 2, 3, 4, 5)
        >>> sum_over_all_but_batch_and_last_n(tensor, n_dims=2).shape
        torch.Size([1, 4, 5])
    Returns:
        A tensor of shape ``(B, ..., X[n_dims-1])``
    """
    if tensor.dim() == n_dims + 1:
        return tensor
    else:
        dims = list(range(1, tensor.dim() - n_dims))
        return tensor.sum(dim=dims)
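
As a small orientation aid (not part of the diff), this is what `standard_kwargs` does for the F.linear calling convention, using the same values the new helper-function tests exercise; the variable names here are illustrative only:

```python
import torch
from torch.nn.utils._expanded_weights.expanded_weights_utils import standard_kwargs

inp, weight, bias = torch.randn(3, 4), torch.randn(5, 4), torch.randn(5)
# __torch_function__ hands linear everything positionally; split it back into args/kwargs
expanded_args, expanded_kwargs = standard_kwargs(('bias',), (inp, weight, bias))
assert expanded_args[0] is inp and expanded_args[1] is weight
assert expanded_kwargs['bias'] is bias
```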
38  torch/nn/utils/_expanded_weights/linear_expanded_weights.py  Normal file
@@ -0,0 +1,38 @@
import torch
import torch.nn.functional as F
from .expanded_weights_impl import implements_per_sample_grads
from .expanded_weights_utils import \
    forward_helper, set_grad_sample_if_exists, unpack_expanded_weight_or_tensor
from typing import List, Optional

@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _, *expanded_args_and_kwargs):
        if len(expanded_args_and_kwargs[0].shape) <= 1:
            raise RuntimeError("Input does not have a batch dimension. Expanded Weights expected input "
                               f"of at least rank 2, got input of rank {len(expanded_args_and_kwargs[0].shape)}")
        expanded_kwargs = {'bias': expanded_args_and_kwargs[2] if len(expanded_args_and_kwargs) == 3 else None}
        expanded_args = expanded_args_and_kwargs[:2]
        output = forward_helper(F.linear, expanded_args, expanded_kwargs)
        ctx.args = expanded_args
        ctx.kwargs = expanded_kwargs
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.args
        bias = ctx.kwargs['bias']
        results: List[Optional[torch.Tensor]] = []
        results.append(None)  # for kwarg_names

        if input.requires_grad:
            results.append(grad_output.matmul(unpack_expanded_weight_or_tensor(weight)))
        else:
            results.append(None)
        results.extend([None] * 2)  # weight and bias don't compute batched gradients

        # weight and bias get their grad_sample fields set directly if they exist
        set_grad_sample_if_exists(weight, lambda _: torch.einsum("n...i,n...j->nij", grad_output, input))
        set_grad_sample_if_exists(bias, lambda _: torch.einsum("n...k->nk", grad_output))
        return tuple(results)
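
A small sanity sketch (not part of the diff): the einsum that `backward` uses for the weight's per-sample gradient is a batched outer product of `grad_output` and `input`, which can be checked against an explicit per-sample loop; the shapes below are arbitrary:

```python
import torch

batch, in_features, out_features = 5, 4, 3
x = torch.randn(batch, in_features)    # stands in for the saved input
go = torch.randn(batch, out_features)  # stands in for grad_output of F.linear
per_sample = torch.einsum("n...i,n...j->nij", go, x)  # (batch, out_features, in_features)
manual = torch.stack([torch.outer(go[n], x[n]) for n in range(batch)])
assert torch.allclose(per_sample, manual)
```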
57  torch/nn/utils/_per_sample_grad.py  Normal file
@@ -0,0 +1,57 @@
import torch
from torch.nn.utils._stateless import functional_call
from torch.nn.utils._expanded_weights.expanded_weights_impl import ExpandedWeight

# dependency on `functional_call` means that this can't be exposed in utils
# without creating circular dependency
def call_for_per_sample_grads(module, batch_size, args, kwargs=None):
    r"""
    call_for_per_sample_grads(module, batch_size, args, kwargs=None) -> Tensor
    Invoked just like a forward pass, ``call_for_per_sample_grads`` will produce the same
    forward result. Then, when backward is invoked, the parameters of ``module``
    will have a ``grad_sample`` field populated with the per sample gradients
    instead of the regular gradients

    Args:
        module: The ``nn.Module`` to get per sample gradients with respect to. All trainable
          parameters will compute per sample gradients, located in a ``grad_sample``
          field when ``backward`` is invoked
        batch_size: The batch size of the input. Typically the input's first dimension
        args: Tuple of positional args passed to ``module`` to perform the forward pass
        kwargs: Dict of named args passed to ``module`` to perform the forward pass. Default: None

    Examples::
        >>> model = nn.Linear(4, 3)
        >>> batched_input = torch.randn(5, 4)  # batch size of 5
        >>> res = call_for_per_sample_grads(model, batched_input.shape[0], batched_input).sum()
        >>> res.backward()
        >>> assert model.weight.shape == (3, 4)
        >>> assert model.weight.grad_sample.shape == (5, 3, 4)
        >>> assert model.weight.grad == None
        >>> assert model.bias.shape == (3,)
        >>> assert model.bias.grad_sample.shape == (5, 3)
        >>> assert model.bias.grad == None

    Note::
        Does not work with any `nn.RNN`, including `nn.GRU` or `nn.LSTM`. Please use custom
        rewrites that wrap an `nn.Linear` module. See Opacus for an example
    """
    def maybe_build_expanded_weight(og_tensor):
        if og_tensor.requires_grad:
            return ExpandedWeight(og_tensor, batch_size)
        else:
            return og_tensor

    if not isinstance(module, torch.nn.Module):
        raise RuntimeError(f"Module passed must be nn.Module, got {type(module).__name__}")
    if not isinstance(batch_size, int):
        raise RuntimeError(f"Batch size passed must be an integer, got {type(batch_size).__name__}")
    if batch_size < 1:
        raise RuntimeError(f"Batch size must be positive, got {batch_size}")
    for weight in module.parameters():
        if hasattr(weight, "grad_sample") and weight.grad_sample is not None:  # type: ignore[attr-defined]
            raise RuntimeError("Current Expanded Weights accumulates the gradients, which will be incorrect for multiple "
                               f"calls without clearing gradients. Please clear out the grad_sample parameter of {weight} or "
                               "post an issue to pytorch/pytorch to prioritize correct behavior")
    params = {name: maybe_build_expanded_weight(value) for (name, value) in module.named_parameters()}
    return functional_call(module, params, args, kwargs)
torch/testing/_internal/common_methods_invocations.py
@@ -619,6 +619,8 @@ class OpInfo(object):
                 test_conjugated_samples=True,
                 test_neg_view=True,
                 assert_jit_shape_analysis=False,  # assert that jit shape analysis fully propagates shape
                 # the following metadata relates to ExpandedWeights support and is checked in test_expanded_weights.py
                 supports_expanded_weight=False,
                 ):

        dtypes_args = (dtypes, dtypesIfCPU, dtypesIfCUDA, dtypesIfROCM)
@@ -778,6 +780,7 @@ class OpInfo(object):

        self.test_conjugated_samples = test_conjugated_samples
        self.test_neg_view = test_neg_view
        self.supports_expanded_weight = supports_expanded_weight

    def __call__(self, *args, **kwargs):
        """Calls the function variant of the operator."""
@@ -11330,6 +11333,7 @@ op_db: List[OpInfo] = [
           supports_fwgrad_bwgrad=True,
           # See https://github.com/pytorch/pytorch/issues/66357
           check_batched_forward_grad=False,
           supports_expanded_weight=True,
           supports_out=False),
    OpInfo('nn.functional.bilinear',
           aten_name='bilinear',
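
For context (not part of the diff): the new `supports_expanded_weight` flag is what the `@ops(filter(lambda op: op.supports_expanded_weight, op_db), ...)` decorators in test_expanded_weights.py key off of. Per the commit summary, only linear opts in at this point, so a quick check like the following would be expected to list just `nn.functional.linear`:

```python
from torch.testing._internal.common_methods_invocations import op_db

covered = [op.name for op in op_db if getattr(op, "supports_expanded_weight", False)]
print(covered)  # expected to be just ['nn.functional.linear'] at this commit
```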