import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.utils._pytree import tree_map

from typing import Iterator, List
import logging
import contextlib

# TODO: move this into library proper
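# no_dispatch() installs a torch._C._DisableTorchDispatch guard for the duration of
# the block, so operators called inside it bypass __torch_dispatch__ and go straight
# to the regular kernels (used below to redispatch to the real implementation
# without recursing back into the subclass handler).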
@contextlib.contextmanager
def no_dispatch() -> Iterator[None]:
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely
# 3. Enter dispatcher, wind your way through Autograd
# 4. Hit Python dispatch key, call __torch_dispatch__
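# 5. Inside __torch_dispatch__, LoggingTensor unwraps its arguments to the underlying
#    .elem tensors, re-invokes func on them, wraps any tensor results back into
#    LoggingTensor, and records the call on the "LoggingTensor" logger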

# TODO: TensorBase should work
class LoggingTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (LoggingTensor) is just a meta tensor, so it
        # doesn't hold any memory (meta tensor is generally the preferred type
        # of tensor you want to make a subclass from)...
        r = torch.Tensor._make_subclass(cls, elem.to('meta'), elem.requires_grad)
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem
        return r

    def __repr__(self):
        return f"LoggingTensor({self.elem})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, LoggingTensor) else e

        def wrap(e):
            return LoggingTensor(e) if isinstance(e, torch.Tensor) else e

        # TODO: handle kwargs
        assert not kwargs
        rs = tree_map(wrap, func(*tree_map(unwrap, args)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, rs)
        return rs

# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
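# LoggingTensorHandler renders each record as "$ret = module.op($args)", assigning
# short ids ($0, $1, ...) to LoggingTensors in the order they are first seen.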
class LoggingTensorHandler(logging.Handler):
    log_list: List[str]
    next_shortid: int

    def __init__(self, log_list: List[str]) -> None:
        logging.Handler.__init__(self)
        self.log_list = log_list
        self.next_shortid = 0

    # WARNING: not deterministic over multiple threads, this matters for
    # autograd
    def _shortid(self, o: object) -> int:
        if not hasattr(o, '_shortid'):
            o._shortid = self.next_shortid
            self.next_shortid += 1
        return o._shortid

    def _fmt(self, a: object) -> str:
        return f'${self._shortid(a)}' if isinstance(a, LoggingTensor) else repr(a)

    def emit(self, record):
        fmt_args = "(" + ", ".join(self._fmt(a) for a in record.args[0]) + ")"
        fmt_rets = ", ".join(self._fmt(a) for a in record.args[1]) \
            if isinstance(record.args[1], (list, tuple)) else self._fmt(record.args[1])
        self.log_list.append(f'{fmt_rets} = {record.msg}{fmt_args}')

def log_input(name: str, var: object):
    logging.getLogger("LoggingTensor").info("input", (name,), (var,))
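
# capture_logs() installs a LoggingTensorHandler on the "LoggingTensor" logger for the
# duration of the block and yields the list that collects the formatted log lines.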
@contextlib.contextmanager
def capture_logs() -> Iterator[List[str]]:
    logger = logging.getLogger("LoggingTensor")
    log_list = []
    handler = LoggingTensorHandler(log_list)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    try:
        yield log_list
    finally:
        logger.removeHandler(handler)
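
# A minimal usage sketch (the variable name is illustrative; it mirrors what
# test_basic below exercises):
#
#     with capture_logs() as logs:
#         t = LoggingTensor(torch.tensor([2.0]))
#         log_input("t", t)
#         t * t
#
# after which logs should contain lines like:
#     $0 = input('t')
#     $1 = torch._ops.aten.mul($0, $0)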

class TestPythonDispatch(TestCase):
    def test_basic(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0], requires_grad=True))
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            g, = torch.autograd.grad((y,), (x,), (grad_y,))

        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.mul($0, $0)
$2 = input('grad_y')
$3 = torch._ops.aten.mul($2, $0)
$4 = torch._ops.aten.mul($2, $0)
$5 = torch._ops.aten.add($4, $3)''')

    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.abs($0, $1)''')

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):
            class A(torch._C._TensorBase):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2)
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets reasonable error message
        class A(torch._C._TensorBase):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # Wobbles depending on NDEBUG mode of pybind11
        self.assertRaisesRegexp(
            RuntimeError, "Unable to cast", lambda: A(torch.zeros(1)).neg(),
        )
        self.assertExpectedRaisesInline(
            RuntimeError, lambda: A(torch.zeros(1)).detach(),
            """detach returned invalid type str, expected Tensor"""
        )

    def test_metadata_change_not_allowed(self) -> None:
        x = LoggingTensor(torch.ones(1))
        y = x.data
        self.assertIsInstance(y, LoggingTensor)
        self.assertRaises(RuntimeError, lambda: y.resize_(4))

    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x ** 2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                x, = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1, requires_grad=True))
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('x.grad')
$2 = torch._ops.aten.pow($0, 2)
$3 = input('grad_output')
$4 = torch._ops.aten.mul($3, tensor(2))
$5 = torch._ops.aten.mul($4, $0)
$6 = torch._ops.aten.add_($1, $5)''')


if __name__ == '__main__':
    run_tests()