This makes prims look as if they were defined in native_functions.yaml
but they're still all written in Python. You now need to give a full
schema string for your prims. The returned prim object is now
torch.ops.prim overload (prims are not allowed to be overloaded,
so we return the overload, not the overload packet, for speed.)
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77117
Approved by: https://github.com/mruberry, https://github.com/albanD
This PR primarily addresses augmenting the frontend to properly support `broadcast_in_dim`. This required make a new version of the `define_tensor()` that takes in the `size` and `strides` of input tensors in order to properly determine broadcasts.
This PR also has a fix for the `python_example.py` that broke when a new argument was added to reductions to allow the user to specify an output Data Type.
`define_tensor()` Interface Example:
```
fusion2 = Fusion()
input1 = torch.ones(1, 1, 4, device='cuda')
input2 = torch.ones(2, 3, 4, device='cuda')
with FusionDefinition(fusion2) as fd :
t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())
fd.add_input(t0)
fd.add_input(t1)
t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
print("Broadcast TensorView", t0_b)
t2 = fd.Ops.add(t0_b, t1)
fd.add_output(t2)
```
Print statement of defined broadcast tensor:
```
Broadcast TensorView T2_l[ sbS6{1}, sbS7{1}, iS8{i2} ] DataType: float Contiguity: ttt
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76790
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
This adds prototype nvFuser integration for the following prims:
- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul
Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.
This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:
```
def foo(a, b):
return torch.add(a, b)
traced_foo = make_traced(foo)
a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```
Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.
Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel