Commit Graph

5 Commits

Author SHA1 Message Date
jjsjann123
8875453d8b skip primTorch nvfuser tests on rocm (#77468)
Fixes https://github.com/pytorch/pytorch/issues/77237
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77468
Approved by: https://github.com/davidberard98
2022-05-16 18:17:31 +00:00
Edward Z. Yang
0a14a4c280 Register prims as operators.
This makes prims look as if they were defined in native_functions.yaml
but they're still all written in Python.  You now need to give a full
schema string for your prims.  The returned prim object is now a
torch.ops.prim overload (prims are not allowed to be overloaded,
so we return the overload, not the overload packet, for speed).
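
To make "full schema string" concrete, here is a minimal sketch that splits a schema written in the `name(type arg, ...) -> return` style used by native_functions.yaml. The prim name `my_prim` and the helper `split_schema` are hypothetical, invented for illustration; this is not the parser PyTorch uses, and it does not handle defaults or types that contain commas.

```python
def split_schema(schema: str):
    """Split 'name(type arg, ...) -> return_type' into (name, args, returns)."""
    signature, _, returns = schema.partition("->")
    name, _, arg_text = signature.strip().partition("(")
    arg_text = arg_text.rstrip()
    if not arg_text.endswith(")"):
        raise ValueError(f"malformed schema: {schema!r}")
    arg_text = arg_text[:-1]  # drop the closing parenthesis

    args = []
    for piece in arg_text.split(","):
        piece = piece.strip()
        if piece:
            # The last space separates the type from the argument name.
            arg_type, _, arg_name = piece.rpartition(" ")
            args.append((arg_type, arg_name))
    return name.strip(), args, returns.strip()


# Hypothetical schema string, for illustration only.
parsed = split_schema("my_prim(Tensor a, int[] shape) -> Tensor")
print(parsed)
```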

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77117

Approved by: https://github.com/mruberry, https://github.com/albanD
2022-05-11 16:38:14 +00:00
Kevin Stephano
752d496c91 Fix broadcast_in_dim support in NVFuser Frontend (#76790)
This PR primarily augments the frontend to properly support `broadcast_in_dim`.  This required making a new version of `define_tensor()` that takes the `sizes` and `strides` of input tensors in order to properly determine broadcasts.

This PR also fixes `python_example.py`, which broke when a new argument was added to reductions to allow the user to specify an output data type.

`define_tensor()` Interface Example:

```
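import torch
# Assumed import path for the prototype Python frontend at the time of this PR.
from torch._C._nvfuser import Fusion, FusionDefinition

fusion2 = Fusion()

input1 = torch.ones(1, 1, 4, device='cuda')
input2 = torch.ones(2, 3, 4, device='cuda')

with FusionDefinition(fusion2) as fd: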
    t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
    t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
    print("Broadcast TensorView", t0_b)
    t2 = fd.Ops.add(t0_b, t1)

    fd.add_output(t2)
```
Printed output for the defined broadcast tensor:

```
Broadcast TensorView T2_l[ sbS6{1}, sbS7{1}, iS8{i2} ] DataType: float Contiguity: ttt
```
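
The role of the input sizes in determining broadcasts can be sketched in plain Python. The helper below is hypothetical (not part of the nvFuser frontend) and assumes the usual `broadcast_in_dim` convention: entry `i` of the third argument names the output axis that input axis `i` maps to. It reports which output axes would be expanded and which are carried over from the input unchanged.

```python
def expanded_output_axes(in_sizes, out_shape, broadcast_dimensions):
    """Return one bool per output axis: True if that axis is expanded."""
    if len(in_sizes) != len(broadcast_dimensions):
        raise ValueError("need one broadcast dimension per input axis")

    # Output axes that no input axis maps to are always expanded.
    expanded = [True] * len(out_shape)
    for in_axis, out_axis in enumerate(broadcast_dimensions):
        in_size, out_size = in_sizes[in_axis], out_shape[out_axis]
        if in_size == out_size:
            expanded[out_axis] = False  # carried over as-is
        elif in_size != 1:
            raise ValueError(
                f"input axis {in_axis} has size {in_size}, "
                f"which cannot broadcast to {out_size}"
            )
    return expanded


# Same shapes as the example above: (1, 1, 4) broadcast to (2, 3, 4).
flags = expanded_output_axes((1, 1, 4), [2, 3, 4], [0, 1, 2])
print(flags)
```

For the shapes in the example above, the first two output axes are expanded and the last one is carried over. Without the concrete sizes, the size-1 axes could not be told apart from the size-4 axis.
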
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76790
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
2022-05-10 18:13:22 +00:00
Mike Ruberry
f6bbecf8b5 Adds python ref consistency test, elementwise unary reference inputs, and formats test files
Per title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76626
Approved by: https://github.com/ngimel
2022-05-01 22:42:46 +00:00
Mike Ruberry
fe1968dea0 [primTorch] Prototype nvFuser integration and test_prims.py
This adds prototype nvFuser integration for the following prims:

- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul

Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.

This also adds new sugar for running operations with the ATen or nvFuser trace executors. For example:

```
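import torch
# Assumed import path for the executor sugar at the time of this PR.
from torch._prims.executor import make_traced

def foo(a, b):
    return torch.add(a, b)

traced_foo = make_traced(foo)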

a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```

Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.
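
The calling convention of this sugar, a wrapper that picks a backend from an `executor` keyword, can be sketched without PyTorch. Everything below is a hypothetical stand-in: the names `make_traced_sketch`, `run_eagerly`, and `run_with_stub_fuser` are invented, the default executor is an assumption, and both stand-in executors simply call the function, whereas the real executors trace it into prims first.

```python
from functools import wraps


def run_eagerly(fn, args):
    """Stand-in for an 'aten'-style executor: call the function directly."""
    return fn(*args)


def run_with_stub_fuser(fn, args):
    """Stand-in for an 'nvfuser'-style executor; a real one would trace and compile."""
    return fn(*args)


# Hypothetical registry mapping executor names to stand-in implementations.
EXECUTORS = {"aten": run_eagerly, "nvfuser": run_with_stub_fuser}


def make_traced_sketch(fn):
    """Wrap fn so callers can choose a backend with executor=..."""
    @wraps(fn)
    def traced(*args, executor="aten"):  # default chosen for this sketch
        if executor not in EXECUTORS:
            raise ValueError(f"unknown executor: {executor!r}")
        return EXECUTORS[executor](fn, args)
    return traced


def foo(a, b):
    return a + b


traced_foo = make_traced_sketch(foo)
# Plain integers stand in for tensors here.
result = traced_foo(1, 2, executor="nvfuser")
print(result)
```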

Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
2022-04-29 02:02:25 +00:00