Commit Graph

4 Commits

Author SHA1 Message Date
Edward Z. Yang
48eb8d6aad Use TorchFunctionMode to implement PrimTorch tracing context
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76735

Approved by: https://github.com/mruberry
2022-05-04 23:49:46 +00:00
Mike Ruberry
f6bbecf8b5 Adds python ref consistency test, elementwise unary reference inputs, and formats test files
Per title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76626
Approved by: https://github.com/ngimel
2022-05-01 22:42:46 +00:00
Mike Ruberry
fe1968dea0 [primTorch] Prototype nvFuser integration and test_prims.py
This adds prototype nvFuser integration for the following prims:

- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul

Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.

This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:

```
def foo(a, b):
  return torch.add(a, b)

traced_foo = make_traced(foo)

a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```

Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.

Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
2022-04-29 02:02:25 +00:00
Mike Ruberry
4048d4cdd2 [primTorch] Prototype tracer and elementwise unary reference opinfo class
Adds a prototype tracer with no caching support and the `ElementwiseUnaryPythonRefInfo` class. A reference for `floor` is added to test the latter, and the elementwise binary reference inputs are extended to also return noncontiguous inputs. The SampleInput transform operation has been updated to return an actual SampleInput instead of a tuple to facilitate uniform handling of (transformed) SampleInputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76388
Approved by: https://github.com/ngimel
2022-04-27 14:40:21 +00:00