Commit Graph

58 Commits

Author SHA1 Message Date
Horace He
e675dbadc4 Ported gelu decomp to ref (#78697)
Ugh... these are actually so painful to write without operator overloading lol.

Decided to just utilize operator overloading, and xfail the ref tests for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78697
Approved by: https://github.com/mruberry
2022-06-06 22:30:20 +00:00
Edward Z. Yang
587efdb5fa Replace TensorMeta with FakeTensor
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78836

Approved by: https://github.com/albanD, https://github.com/mruberry
2022-06-05 11:51:27 +00:00
Ivan Yashchuk
df748b60f7 Allow pytrees as output for make_traced and nvfuser executor (#78802)
This PR lifts the restriction that the output of a function traced with `make_traced` and executed with nvFuser must be a single tensor. Now it's possible to return a "pytree", a tensor's nested data structure (see https://github.com/pytorch/pytorch/blob/master/torch/utils/_pytree.py).

I added a test with a function that returns a tuple of two objects where one of the objects is a dictionary with a tensor value.

```py
def fn(a, b):
    d = {}
    d["c"] = torch.add(a, b)
    return (d, torch.add(a, d["c"]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78802
Approved by: https://github.com/mruberry
2022-06-04 08:41:18 +00:00
Edward Z. Yang
6b273444c4 Add logit ref; allow non-refs to be called in refs.
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77816

Approved by: https://github.com/mruberry
2022-05-21 02:35:14 +00:00
Edward Z. Yang
4a11678368 Change TensorMeta to use tensor subclass.
This means it can be fed through traditional PyTorch C++ code
(although currently it does not work, as the __torch_dispatch__
implementation is stubbed to always throw an error.)

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76759

Approved by: https://github.com/mruberry
2022-05-04 23:49:47 +00:00
Edward Z. Yang
48eb8d6aad Use TorchFunctionMode to implement PrimTorch tracing context
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76735

Approved by: https://github.com/mruberry
2022-05-04 23:49:46 +00:00
Mike Ruberry
fe1968dea0 [primTorch] Prototype nvFuser integration and test_prims.py
This adds prototype nvFuser integration for the following prims:

- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul

Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.

This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:

```
def foo(a, b):
  return torch.add(a, b)

traced_foo = make_traced(foo)

a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```

Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.

Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
2022-04-29 02:02:25 +00:00
Mike Ruberry
4048d4cdd2 [primTorch] Prototype tracer and elementwise unary reference opinfo class
Adds a prototype tracer with no caching support and the `ElementwiseUnaryPythonRefInfo` class. A reference for `floor` is added to test the latter, and the elementwise binary reference inputs are extended to also return noncontiguous inputs. The SampleInput transform operation has been updated to return an actual SampleInput instead of a tuple to facilitate uniform handling of (transformed) SampleInputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76388
Approved by: https://github.com/ngimel
2022-04-27 14:40:21 +00:00