This PR lifts the restriction that the output of a function traced with `make_traced` and executed with nvFuser must be a single tensor. Now it's possible to return a "pytree", a nested data structure of tensors (see https://github.com/pytorch/pytorch/blob/master/torch/utils/_pytree.py).
I added a test with a function that returns a tuple of two objects, one of which is a dictionary with a tensor value:
```py
def fn(a, b):
    d = {}
    d["c"] = torch.add(a, b)
    return (d, torch.add(a, d["c"]))
```
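For illustration, a rough sketch of how such a function might be traced and executed end to end (assuming the prototype `make_traced` entry point lives in `torch._prims.executor`; exact names may differ):
```py
import torch
from torch._prims.executor import make_traced  # assumed prototype location

def fn(a, b):
    d = {}
    d["c"] = torch.add(a, b)
    return (d, torch.add(a, d["c"]))

traced_fn = make_traced(fn)
a = torch.randn(3, 4, device="cuda")
b = torch.randn(3, 4, device="cuda")
d, e = traced_fn(a, b, executor="nvfuser")
# The result preserves the pytree structure: d is a dict holding a tensor
# under "c", and e is a tensor.
```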
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78802
Approved by: https://github.com/mruberry
This means it can be fed through traditional PyTorch C++ code (although currently it does not work, as the `__torch_dispatch__` implementation is stubbed to always throw an error).
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76759
Approved by: https://github.com/mruberry
This adds prototype nvFuser integration for the following prims:
- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul
Adding integration for additional prims supported by nvFuser's prototype Python frontend should be easy.
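For orientation, a hedged sketch of what these prims look like when called directly through the prototype `torch._prims` namespace (names assumed from the prototype; the nvFuser path is exercised through the tracing sugar described below):
```py
import torch
import torch._prims as prims  # prototype prims namespace (assumed)

a = torch.randn(2, 3, device='cuda')
b = torch.randn(3, device='cuda')

b2 = prims.broadcast_in_dim(b, a.shape, (1,))       # explicit broadcast to a's shape
c = prims.add(a, b2)                                 # elementwise add (no implicit broadcasting)
c_int = prims.convert_element_type(c, torch.int64)   # dtype conversion
mask = prims.ge(c, prims.mul(a, b2))                 # comparisons return bool tensors
```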
This also adds new sugar for running operations through the ATen or nvFuser trace executors. For example:
```py
import torch
from torch._prims.executor import make_traced  # assumed location of the prototype entry point

def foo(a, b):
    return torch.add(a, b)

traced_foo = make_traced(foo)

a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```
Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.
Finally, this adds a new test, `test_prims.py`, that just tests the `broadcast_in_dim` prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of `broadcast_in_dim` to make that interesting.
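For context, a rough sketch of what a reference implementation of `broadcast_in_dim` might look like in terms of existing tensor ops (a hypothetical helper, not part of this PR):
```py
import torch

def broadcast_in_dim_ref(a, shape, broadcast_dimensions):
    # Build an intermediate shape with the same rank as the output, placing
    # a's sizes at the positions named in broadcast_dimensions and 1 elsewhere,
    # then expand (without copying) to the requested shape.
    intermediate = [1] * len(shape)
    for a_dim, out_dim in enumerate(broadcast_dimensions):
        intermediate[out_dim] = a.shape[a_dim]
    return a.reshape(intermediate).expand(shape)

# e.g. broadcast a length-3 vector into output dimension 1 of a (2, 3, 4) tensor
result = broadcast_in_dim_ref(torch.randn(3), (2, 3, 4), (1,))
assert result.shape == (2, 3, 4)
```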
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
Adds a prototype tracer without caching support and the `ElementwiseUnaryPythonRefInfo` class. A reference for `floor` is added to test the latter, and the elementwise binary reference inputs are extended to also return noncontiguous inputs. The `SampleInput` transform operation has been updated to return an actual `SampleInput` instead of a tuple, to facilitate uniform handling of (transformed) SampleInputs.
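As a quick illustration of how such a Python reference is checked against its `torch` counterpart (a hedged sketch; `torch._refs` is assumed to be the prototype references namespace):
```py
import torch
import torch._refs as refs  # prototype Python references (assumed namespace)

a = torch.randn(4, 5)
# The reference should match the eager op, including on noncontiguous inputs.
torch.testing.assert_close(refs.floor(a), torch.floor(a))
torch.testing.assert_close(refs.floor(a.t()), torch.floor(a.t()))
```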
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76388
Approved by: https://github.com/ngimel