import contextlib import itertools import logging from typing import Iterator, List import torch from torch.testing._internal.common_utils import TestCase, run_tests from torch.utils._pytree import tree_map # TODO: move this into library proper @contextlib.contextmanager def no_dispatch() -> Iterator[None]: guard = torch._C._DisableTorchDispatch() try: yield finally: del guard # How the chain of calls works for LoggingTensor: # 1. Call torch.sin # 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely # 3. Enter dispatcher, wind your way through Autograd # 4. Hit Python dispatch key, call __torch_dispatch__ # TODO: TensorBase should work class LoggingTensor(torch.Tensor): elem: torch.Tensor __slots__ = ["elem"] @staticmethod def __new__(cls, elem, *args, **kwargs): # The wrapping tensor (LoggingTensor) is just a meta tensor, so it # doesn't hold any memory (meta tensor is generally the preferred type # of tensor you want to make a subclass from)... r = torch.Tensor._make_subclass(cls, elem.to("meta"), elem.requires_grad) # ...the real tensor is held as an element on the tensor. r.elem = elem return r def __repr__(self): return f"LoggingTensor({self.elem})" @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def unwrap(e): return e.elem if isinstance(e, LoggingTensor) else e def wrap(e): return LoggingTensor(e) if isinstance(e, torch.Tensor) else e rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) logging.getLogger("LoggingTensor").info( f"{func.__module__}.{func.__name__}", args, kwargs, rs ) return rs # https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list class LoggingTensorHandler(logging.Handler): log_list: List[str] next_shortid: int def __init__(self, log_list: List[str]) -> None: logging.Handler.__init__(self) self.log_list = log_list self.next_shortid = 0 # WARNING: not deterministic over multiple threads, this matters for # autograd def _shortid(self, o: object) -> int: if not hasattr(o, "_shortid"): o._shortid = self.next_shortid self.next_shortid += 1 return o._shortid def _fmt(self, a: object) -> str: return f"${self._shortid(a)}" if isinstance(a, LoggingTensor) else repr(a) def emit(self, record): fmt_args = ", ".join( itertools.chain( (self._fmt(a) for a in record.args[0]), (f"{k}={self._fmt(v)}" for k, v in record.args[1].items()), ) ) fmt_rets = ( ", ".join(self._fmt(a) for a in record.args[2]) if isinstance(record.args[2], (list, tuple)) else self._fmt(record.args[2]) ) self.log_list.append(f"{fmt_rets} = {record.msg}({fmt_args})") def log_input(name: str, var: object): logging.getLogger("LoggingTensor").info("input", (name,), {}, (var,)) @contextlib.contextmanager def capture_logs() -> Iterator[List[str]]: logger = logging.getLogger("LoggingTensor") log_list = [] handler = LoggingTensorHandler(log_list) logger.addHandler(handler) logger.setLevel(logging.INFO) try: yield log_list finally: logger.removeHandler(handler) class TestPythonDispatch(TestCase): def test_basic(self) -> None: with capture_logs() as logs: x = LoggingTensor(torch.tensor([3.0], requires_grad=True)) log_input("x", x) y = x * x saved_x = y.grad_fn._saved_self grad_y = LoggingTensor(torch.tensor([1.0])) log_input("grad_y", grad_y) (g,) = torch.autograd.grad((y,), (x,), (grad_y,)) self.assertEqual(g.elem, torch.tensor([6.0])) with torch.no_grad(): self.assertEqual(saved_x, x) self.assertEqual(saved_x._version, x._version) x.add_(2) self.assertEqual(saved_x, x) # TODO: figure out why broken # self.assertEqual(saved_x._version, x._version) self.assertExpectedInline( "\n".join(logs), """\ $0 = input('x') $1 = torch._ops.aten.mul($0, $0) $2 = input('grad_y') $3 = torch._ops.aten.mul($2, $0) $4 = torch._ops.aten.mul($2, $0) $5 = torch._ops.aten.add($4, $3)""", ) def test_out(self) -> None: with capture_logs() as logs: x = LoggingTensor(torch.ones(1)) y = LoggingTensor(torch.zeros(1)) log_input("x", x) log_input("y", y) torch.abs(x, out=y) self.assertEqual(y.elem, torch.ones(1)) # TODO: arguably this shouldn't pass and we should complain # that out isn't a kwarg self.assertExpectedInline( "\n".join(logs), """\ $0 = input('x') $1 = input('y') $2 = torch._ops.aten.abs($0, out=$1)""", ) def test_kwarg_only(self) -> None: with capture_logs() as logs: x = LoggingTensor(torch.ones(1)) y = LoggingTensor(torch.ones(1, 1)) z = LoggingTensor(torch.ones(1)) log_input("x", x) log_input("y", y) log_input("z", z) torch.addmv(x, y, z) torch.addmv(x, y, z, beta=1) torch.addmv(x, y, z, beta=2) torch.addmv(x, y, z, alpha=2) torch.addmv(x, y, z, beta=2, alpha=2) # The expectation is that beta/alpha don't show up when they're # defaulted. This is even if the user explicitly specified it. self.assertExpectedInline( "\n".join(logs), """\ $0 = input('x') $1 = input('y') $2 = input('z') $3 = torch._ops.aten.addmv($0, $1, $2) $4 = torch._ops.aten.addmv($0, $1, $2) $5 = torch._ops.aten.addmv($0, $1, $2, beta=2) $6 = torch._ops.aten.addmv($0, $1, $2, alpha=2) $7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)""", ) def test_kwarg_only_and_positional_default(self) -> None: with capture_logs() as logs: x = LoggingTensor(torch.ones(1)) y = LoggingTensor(torch.ones(1)) log_input("x", x) log_input("y", y) torch.ops.aten.kl_div(x, y) torch.ops.aten.kl_div(x, y, 2) torch.ops.aten.kl_div(x, y, log_target=True) torch.ops.aten.kl_div(x, y, 2, log_target=True) # What we are testing here is that we omit reduction # if it is defaulted, even if a kwarg is set self.assertExpectedInline( "\n".join(logs), """\ $0 = input('x') $1 = input('y') $2 = torch._ops.aten.kl_div($0, $1) $3 = torch._ops.aten.kl_div($0, $1, 2) $4 = torch._ops.aten.kl_div($0, $1, log_target=True) $5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)""", ) def test_list_ret(self) -> None: # test all sequence types are permissible returns for list_type in (list, tuple): class A(torch._C._TensorBase): @staticmethod def __new__(cls, elem): return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if func == torch.ops.aten.split: with no_dispatch(): return list_type(torch.split(*args)) else: raise AssertionError(f"unrecognized func: {func}") self.assertEqual( torch.split(A(torch.tensor([0, 1])), 2), torch.split(torch.tensor([0, 1]), 2), ) def test_invalid_ret(self) -> None: # test invalid return gets reasonable error message class A(torch._C._TensorBase): @staticmethod def __new__(cls, elem): return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return "arf" # Wobbles depending on NDEBUG mode of pybind11 self.assertRaisesRegexp( RuntimeError, "Unable to cast", lambda: A(torch.zeros(1)).neg(), ) self.assertExpectedRaisesInline( RuntimeError, lambda: A(torch.zeros(1)).detach(), """detach returned invalid type str, expected Tensor""", ) def test_metadata_change_not_allowed(self) -> None: x = LoggingTensor(torch.ones(1)) y = x.data self.assertIsInstance(y, LoggingTensor) self.assertRaises(RuntimeError, lambda: y.resize_(4)) def test_version(self) -> None: x = LoggingTensor(torch.ones(1)) prev_vc = x._version x.detach().add_(2) cur_vc = x._version self.assertNotEqual(prev_vc, cur_vc) x.data.add_(2) self.assertEqual(cur_vc, x._version) def test_format(self) -> None: x = LoggingTensor(torch.ones(1)) s1 = str(x) s2 = repr(x) s3 = f"{x}" self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""") self.assertEqual(s1, s2) self.assertEqual(s1, s3) def test_custom_autograd(self) -> None: escape = [None] class Square(torch.autograd.Function): @staticmethod def forward(ctx, x): y = x ** 2 ctx.save_for_backward(x) return y @staticmethod def backward(ctx, grad_output): assert isinstance(grad_output, LoggingTensor) (x,) = ctx.saved_tensors assert isinstance(x, LoggingTensor) escape[0] = x return grad_output * 2 * x with capture_logs() as logs: x = LoggingTensor(torch.ones(1, requires_grad=True)) log_input("x", x) x.grad = LoggingTensor(torch.zeros(1)) log_input("x.grad", x.grad) y = Square.apply(x) grad_output = LoggingTensor(torch.ones(1)) log_input("grad_output", grad_output) y.backward(grad_output) with torch.no_grad(): self.assertEqual(escape[0], x) self.assertEqual(escape[0]._version, x._version) # TODO: figure out why x.requires_grad = False doesn't # trigger an error for LoggingTensor x.add_(2) self.assertEqual(escape[0], x) # TODO: figure out why this is broken # self.assertEqual(escape[0]._version, x._version) self.assertExpectedInline( "\n".join(logs), """\ $0 = input('x') $1 = input('x.grad') $2 = torch._ops.aten.pow($0, 2) $3 = input('grad_output') $4 = torch._ops.aten.mul($3, tensor(2)) $5 = torch._ops.aten.mul($4, $0) $6 = torch._ops.aten.add_($1, $5)""", ) if __name__ == "__main__": run_tests()