* Added decomposition testing + display infra
* add a couple more decompositions
* changed some stuff
* made some changes
* fix some decompositions
* changed some stuff
* updated generation
* fix test failures
* removed extraneous files
* fixed test failures
* fixed tests
* updated
* fixed tests again
Two main things happened:
- I removed {wrap_key, PythonTensor, pythonkey_trace} from being public
APIs
- I moved all compilation-related things to the functorch.compile
namespace. This includes nnc_jit, which is now
functorch.compile.nnc_jit.
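For reference, the new import path would look like this (a one-line sketch,
assuming the old top-level functorch.nnc_jit alias is no longer exported):
```
# nnc_jit now lives under the functorch.compile namespace.
from functorch.compile import nnc_jit
```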
Concerns:
- nnc_jit was in the functorch namespace for a long time. Should we
leave it there? Are there stakeholders to notify?
Summary: Recomputing the forward pass in the backward pass can improve the
performance of pointwise operators: it reduces memory bandwidth pressure at
the expense of more computation. This PR adds a new partitioning function
to enable this type of recomputation.
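This is the same idea as activation checkpointing. A rough illustration of
the effect (using torch.utils.checkpoint as a stand-in, not the new
partitioner added in this PR):
```
import torch
from torch.utils.checkpoint import checkpoint

def pointwise_chain(x):
    # Cheap pointwise ops: recomputing them in the backward pass can be
    # faster than reading every saved intermediate back from memory.
    return torch.sigmoid(torch.relu(x)).tanh()

x = torch.randn(1024, 1024, requires_grad=True)
# Intermediates are not saved; the forward is recomputed during backward.
y = checkpoint(pointwise_chain, x)
y.sum().backward()
```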
* Support buffers in compiled_module
* Don't compute gradients for inputs that don't require grad
* Add a unit test for batchnorm
* Fix eager compilation tests that change requires_grad
* Create new args for tests without recompilation
* Enable some eager fusion opinfo tests that now work (because we stopped asking for unimplemented derivatives)
Summary: The existing code assumed a single output; this generalizes to tuple
outputs
Test Plan: Compile a simple test program with multiple outputs and check that
outputs/grads are the same as eager.
```
def foo(a, b):
    return a + b, a * b
```
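A rough sketch of such a check (compile_fn here is a hypothetical stand-in
for whatever compilation entry point is under test):
```
import torch

def foo(a, b):
    return a + b, a * b

def check_against_eager(compile_fn):
    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)

    eager_outs = foo(a, b)
    compiled_outs = compile_fn(foo)(a, b)

    # All tuple outputs should match eager.
    for eager_out, compiled_out in zip(eager_outs, compiled_outs):
        torch.testing.assert_close(compiled_out, eager_out)

    # Gradients should match as well.
    eager_grads = torch.autograd.grad(sum(o.sum() for o in eager_outs), (a, b))
    compiled_grads = torch.autograd.grad(sum(o.sum() for o in compiled_outs), (a, b))
    for eager_grad, compiled_grad in zip(eager_grads, compiled_grads):
        torch.testing.assert_close(compiled_grad, eager_grad)
```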
* handled some cases of index.Tensor
* fixed merge errors
* Added batching rules for index, both cases are batched
* fix some issues
* fix tests
* handled some cases of index.Tensor
* fixed merge errors
* fixed tests
Benchmark:
https://gist.github.com/zou3519/f7691a94f8570b27cccc8e16fc8ed13b
It doesn't look like this adds a lot of overhead. If the overhead
becomes a problem we can move this into C++.
The cache works by specializing on the shape/stride/dtype/device of the
input tensors and on any concrete values.
NB: the concrete value cache means that if an integer arg to the
function changes, we will recompile. In the future, when we add
static_argnums, we should change the cache to "not specialize on
specific integers".
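A minimal sketch of this kind of cache key (the names and structure here
are illustrative, not the actual implementation):
```
import torch

def cache_key(args):
    # Specialize on shape/stride/dtype/device for tensor args, and on the
    # concrete value of everything else (so changing an int arg recompiles).
    key = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            key.append(("tensor", tuple(arg.shape), tuple(arg.stride()),
                        arg.dtype, arg.device))
        else:
            key.append(("concrete", arg))
    return tuple(key)

compile_cache = {}

def cached_compile(fn, args, compile_fn):
    key = (fn, cache_key(args))
    if key not in compile_cache:
        compile_cache[key] = compile_fn(fn, args)  # cache miss: recompile
    return compile_cache[key]
```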
Test Plan:
- run tests
Updates make_functional to use the new, improved variants. The new
variants are superior in every way, so we're replacing the previous
variants with them.
If someone wants the older variants, they can be found at:
- make_functional_with_buffers_deprecated_v1
- make_functional_deprecated_v1
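For reference, a quick sketch of the new-style calling convention (the
nn.Linear model here is just for illustration):
```
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers

model = nn.Linear(3, 3)
x = torch.randn(2, 3)

# New variant: returns a stateless function plus the parameters.
fmodel, params = make_functional(model)
out = fmodel(params, x)

# The with-buffers variant also returns buffers (e.g. batchnorm running stats).
fmodel_b, params_b, buffers = make_functional_with_buffers(model)
out_b = fmodel_b(params_b, buffers, x)
```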