This PR adds nvfuser-specific primitive - `var_mean`.
Interpretation `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by `TorchRefsNvfuserCapabilityMode` context manager.
I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).
Layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple comparison of performance with this PR and master (on 3080ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
def func(a):
return torch.native_layer_norm(a, (1024,), None, None, 1e-6)
a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
for _ in range(10):
execute(gm, a, executor="strictly_nvfuser");
```
run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about 35% improvement in performance using nvfuser executor with this specific normalized shape.
Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).
Ref. https://github.com/pytorch/pytorch/issues/80187
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
Conditional decomposing aten::_to_copy to nvprim::convert_element_type to allow fusion with type casting, which is introduced during type promotion phase at torch decomposition.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83782
Approved by: https://github.com/ngimel
This PR adds nvfuser-specific primitive - `var_mean`.
Interpretation `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by `TorchRefsNvfuserCapabilityMode` context manager.
I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).
Layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple comparison of performance with this PR and master (on 3080ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
def func(a):
return torch.native_layer_norm(a, (1024,), None, None, 1e-6)
a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
for _ in range(10):
execute(gm, a, executor="strictly_nvfuser");
```
run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about 35% improvement in performance using nvfuser executor with this specific normalized shape.
Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).
Ref. https://github.com/pytorch/pytorch/issues/80187
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
New namespace `torch.ops.nvprims` is meant for specific to the nvFuser set of primitives. All `impl_nvfuser` attributes are removed from `torch.ops.prims` functions.
`NvfuserPrimsMode()` context manager can be used for automatic rewrite of `torch.ops.prims` calls to `torch.ops.nvprims` when possible.
The previous way to test whether a prim would be executable with nvFuser was to test `impl_nvfuser is not None`, now all functions in the `torch.ops.nvprims` namespace are supposed to have the `impl_nvfuser` attribute and hence all are executable by nvFuser.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82155
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
Adds a new context manager `TorchRefsNvfuserCapabilityMode` for conditional rewrite of `torch.*` calls to `torch._refs.*` based on whether the decomposition consisting of prims supports nvFuser execution or not.
A new optional argument for `TorchRefsMode` is added - `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81764
Approved by: https://github.com/ezyang
This makes symbolic tracing tests for logsigmoid and xlogy start working again.
While I'm at it, add pin_memory and layout kwargs to empty; but they
don't actually do anything and raise an error if they are non standard.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82332
Approved by: https://github.com/eellison
This ref does more things than `torch.norm`, and it fixes a few bugs
that `torch.norm` has. This implementation and the `torch.norm`
implementation come to terms in the next PR of this stack
We put this PR before, as otherwise `test_decomp.py` was failing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81765
Approved by: https://github.com/ngimel
This PR lifts the restriction that the output of a function traced with `make_traced` and executed with nvFuser must be a single tensor. Now it's possible to return a "pytree", a tensor's nested data structure (see https://github.com/pytorch/pytorch/blob/master/torch/utils/_pytree.py).
I added a test with a function that returns a tuple of two objects where one of the objects is a dictionary with a tensor value.
```py
def fn(a, b):
d = {}
d["c"] = torch.add(a, b)
return (d, torch.add(a, d["c"]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78802
Approved by: https://github.com/mruberry
This means it can be fed through traditional PyTorch C++ code
(although currently it does not work, as the __torch_dispatch__
implementation is stubbed to always throw an error.)
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76759
Approved by: https://github.com/mruberry
This adds prototype nvFuser integration for the following prims:
- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul
Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.
This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:
```
def foo(a, b):
return torch.add(a, b)
traced_foo = make_traced(foo)
a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```
Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.
Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
Adds a prototype tracer with no caching support and the `ElementwiseUnaryPythonRefInfo` class. A reference for `floor` is added to test the latter, and the elementwise binary reference inputs are extended to also return noncontiguous inputs. The SampleInput transform operation has been updated to return an actual SampleInput instead of a tuple to facilitate uniform handling of (transformed) SampleInputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76388
Approved by: https://github.com/ngimel