pytorch/torch/csrc/jit/script
Zachary DeVito 1a2ec10bd4 Support enough of closures to write autograd functions (#15411)
Summary:
This PR adds enough of the infrastructure for supporting closures (inner script functions) to allow us to express symbolic gradients using them. We never actually run graphs that contain these closures. The symbolic_script infrastructure just extracts them out of the original forward graph and turns them into discrete forward/backward pairs. This cuts down on the type annotations necessary to write forward/backward pairs and aligns closely with the "differentiator" function approach to expressing reverse-mode AD.

Example:

This code:
```
import torch

r = torch.jit.CompilationUnit(
'''
def mul_forward(self, other):
    def backward(grad_output):
        grad_self = (grad_output * other).sum_to_size(self.size())
        grad_other = (grad_output * self).sum_to_size(other.size())
        return grad_self, grad_other
    return self * other, backward
''')

print(r.module.code)
```

Will produce this graph (pretty printed for clarity):

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  backward = (self.__lambda, (other, self))
  return (torch.mul(self, other), backward)

def __lambda(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```

symbolic_script will then apply some modifications to remove the unsupported prim::Function node, yielding:

```
def mul_forward(self,
    self: Tensor,
    other: Tensor) -> Tuple[Tensor, Tuple[None, Tuple[Tensor, Tensor]]]:
  return (torch.mul(self, other), (other, self))

def backward(self,
    context: Tuple[Tensor, Tensor],
    grad_output: Tensor) -> Tuple[Tensor, Tensor]:
  other, self, = context
  grad_self = torch.sum_to_size(torch.mul(grad_output, other), torch.size(self))
  grad_other = torch.sum_to_size(torch.mul(grad_output, self), torch.size(other))
  return (grad_self, grad_other)
```
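The transformation above is ordinary closure conversion: a closure that captures values from its enclosing scope is equivalent to a free-standing function plus an explicit "context" tuple of the captured values. A minimal sketch in plain Python (no torch, plain floats instead of tensors; `mul_forward_converted` and `backward_converted` are hypothetical names for illustration):

```python
# Closure form: backward captures a and b from the enclosing scope.
def mul_forward(a, b):
    def backward(grad_output):
        # d(a*b)/da = b, d(a*b)/db = a
        return grad_output * b, grad_output * a
    return a * b, backward

# Closure-converted form: the captured values become an explicit context
# tuple (mirroring the (other, self) tuple in the graph above), and
# backward becomes a discrete function that takes that context.
def mul_forward_converted(a, b):
    return a * b, (b, a)

def backward_converted(context, grad_output):
    other, a = context
    return grad_output * other, grad_output * a

# Both forms compute the same output and the same gradients.
out, bw = mul_forward(3.0, 4.0)
out2, ctx = mul_forward_converted(3.0, 4.0)
assert out == out2 == 12.0
assert bw(1.0) == backward_converted(ctx, 1.0) == (4.0, 3.0)
```

The `sum_to_size` calls in the real graphs play no role here because scalars do not broadcast; for tensors they reduce each gradient back to the shape of the corresponding input.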
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15411

Differential Revision: D13523340

Pulled By: zdevito

fbshipit-source-id: 4d4a269460e595b16802c00ec55ae00e3e682d49
2018-12-20 14:39:11 -08:00
builtin_functions.cpp Canonicalize all includes in PyTorch. (#14849) 2018-12-08 19:38:30 -08:00
builtin_functions.h Canonicalize all includes in PyTorch. (#14849) 2018-12-08 19:38:30 -08:00
compiler.cpp Support enough of closures to write autograd functions (#15411) 2018-12-20 14:39:11 -08:00
compiler.h Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00
error_report.h Canonicalize all includes in PyTorch. (#14849) 2018-12-08 19:38:30 -08:00
init.cpp Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00
init.h Canonicalize all includes in PyTorch. (#14849) 2018-12-08 19:38:30 -08:00
jit_exception.h First step at adding exceptions (#12789) 2018-10-30 20:25:50 -07:00
lexer.cpp Enable performance-unnecessary-value-param in .clang-tidy (#15026) 2018-12-13 16:15:35 -08:00
lexer.h Enable performance-unnecessary-value-param in .clang-tidy (#15026) 2018-12-13 16:15:35 -08:00
module.cpp Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00
module.h Fix Module::copy_into 2018-12-19 17:09:59 -08:00
parse_string_literal.h Canonicalize all includes in PyTorch. (#14849) 2018-12-08 19:38:30 -08:00
parser.cpp Support enough of closures to write autograd functions (#15411) 2018-12-20 14:39:11 -08:00
parser.h Create parser.cpp (#15238) 2018-12-14 19:31:36 -08:00
python_tree_views.cpp Method returns a single argument (#15289) 2018-12-18 10:44:09 -08:00
python_tree_views.h Canonicalize all includes in PyTorch. (#14849) 2018-12-08 19:38:30 -08:00
schema_matching.cpp Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00
schema_matching.h Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00
sugared_value.cpp Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00
sugared_value.h Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00
tree_views.h Support enough of closures to write autograd functions (#15411) 2018-12-20 14:39:11 -08:00
tree.h Enable all clang-tidy performance checks (#15198) 2018-12-14 13:32:47 -08:00
type_parser.cpp Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00
type_parser.h Split up compiler.cpp (#15355) 2018-12-18 19:43:35 -08:00