Commit Graph

33 Commits

Author SHA1 Message Date
Ivan Yashchuk
3aae6ff1e1 Add nvprims.var_mean (#83508)
This PR adds nvfuser-specific primitive - `var_mean`.
Interpretation `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by `TorchRefsNvfuserCapabilityMode` context manager.

I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).

Layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple comparison of performance with this PR and master (on 3080ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser");
```
run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about 35% improvement in performance using nvfuser executor with this specific normalized shape.

Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).

Ref. https://github.com/pytorch/pytorch/issues/80187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
2022-08-28 18:45:25 +00:00
PyTorch MergeBot
b159a5230f Revert "Add nvprims.var_mean (#83508)"
This reverts commit 7e7694b661.

Reverted https://github.com/pytorch/pytorch/pull/83508 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 11:30:27 +00:00
jjsjann123
b078d242c4 Nvfuser to copy decomp to prim (#83782)
Conditional decomposing aten::_to_copy to nvprim::convert_element_type to allow fusion with type casting, which is introduced during type promotion phase at torch decomposition.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83782
Approved by: https://github.com/ngimel
2022-08-28 04:26:36 +00:00
Ivan Yashchuk
7e7694b661 Add nvprims.var_mean (#83508)
This PR adds nvfuser-specific primitive - `var_mean`.
Interpretation `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by `TorchRefsNvfuserCapabilityMode` context manager.

I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).

Layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple comparison of performance with this PR and master (on 3080ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser");
```
run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about 35% improvement in performance using nvfuser executor with this specific normalized shape.

Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).

Ref. https://github.com/pytorch/pytorch/issues/80187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
2022-08-27 09:05:20 +00:00
joncrall
b136f3f310 More doctest refinements. (#83317)
Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
2022-08-22 20:07:26 +00:00
Ivan Yashchuk
9f03444f70 Add torch.ops.aten -> torch._refs mapping to TorchRefsMode using decomposition_table (#82657)
### Description
This PR adds the possibility to convert `torch.ops.aten` calls to `torch._refs` and consequently prims under TorchRefsMode.

### Testing
New test, `test_aten_overload_to_prims`, in `test/test_prims.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82657
Approved by: https://github.com/jjsjann123, https://github.com/ezyang
2022-08-17 14:46:06 +00:00
Fabio Rocha
2a096e940d [primTorch] support for a few magic methods (#83524)
Added support for mapping __rsub__, __rtruediv__,
__rfloordiv__, __floordiv__, __pow__,
and __rpow__ in TorchRefsMode.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83524
Approved by: https://github.com/ngimel
2022-08-17 09:48:15 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
Ivan Yashchuk
ec67c6abbe Add torch.ops.nvprims namespace for nvFuser-specific prims (#82155)
New namespace `torch.ops.nvprims` is meant for specific to the nvFuser set of primitives. All `impl_nvfuser` attributes are removed from `torch.ops.prims` functions.

`NvfuserPrimsMode()` context manager can be used for automatic rewrite of `torch.ops.prims` calls to `torch.ops.nvprims` when possible.

The previous way to test whether a prim would be executable with nvFuser was to test `impl_nvfuser is not None`, now all functions in the `torch.ops.nvprims` namespace are supposed to have the `impl_nvfuser` attribute and hence all are executable by nvFuser.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82155
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
2022-08-04 16:51:56 +00:00
Ivan Yashchuk
900e93d351 Add context manager for conditional rewrites of torch.* to torch._refs.* calls (#81764)
Adds a new context manager `TorchRefsNvfuserCapabilityMode` for conditional rewrite of `torch.*` calls to `torch._refs.*` based on whether the decomposition consisting of prims supports nvFuser execution or not.

A new optional argument for `TorchRefsMode` is added - `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81764
Approved by: https://github.com/ezyang
2022-08-02 11:02:10 +00:00
Edward Z. Yang
98b9dfa129 Add decompositions for zero_, fill_, new_full, new_zeros, new_ones (#82332)
This makes symbolic tracing tests for logsigmoid and xlogy start working again.

While I'm at it, add pin_memory and layout kwargs to empty; but they
don't actually do anything and raise an error if they are non standard.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82332
Approved by: https://github.com/eellison
2022-07-28 04:02:02 +00:00
lezcano
11fe277b62 [PrimTorch] Add reference for torch.norm (#81765)
This ref does more things than `torch.norm`, and it fixes a few bugs
that `torch.norm` has. This implementation and the `torch.norm`
implementation come to terms in the next PR of this stack

We put this PR before, as otherwise `test_decomp.py` was failing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81765
Approved by: https://github.com/ngimel
2022-07-25 19:57:21 +00:00
Huy Do
12cb26509a Apply ufmt to torch internal (#81643)
This is a big bang PR, merge conflicts are probably expected and will be addressed at merge.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81643
Approved by: https://github.com/ezyang
2022-07-22 02:19:50 +00:00
Horace He
a5fb41e3d3 Revert "Revert "Refactored prim utils into _prims_utils folder (#81746)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81746
Approved by: https://github.com/anijain2305, https://github.com/Krovatkin
2022-07-20 23:43:57 +00:00
PyTorch MergeBot
e43a02c314 Revert "Refactored prim utils into _prims_utils folder (#81088)"
This reverts commit 80231d0a72.

Reverted https://github.com/pytorch/pytorch/pull/81088 on behalf of https://github.com/jeanschmidt due to breaking internal tests
2022-07-19 19:56:41 +00:00
Horace He
80231d0a72 Refactored prim utils into _prims_utils folder (#81088)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81088
Approved by: https://github.com/ngimel
2022-07-19 03:55:51 +00:00
Peter Bell
bf36d8b987 [primTorch] Implement one-dimensional fft transforms (#80570)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80570
Approved by: https://github.com/mruberry
2022-07-15 15:13:43 +00:00
Ivan Yashchuk
0d5bc54114 Fix interpretation torch -> torch._refs in case of nested torch calls under TorchRefsMode (#80135)
torch calls inside `TorchRefsMode.__torch_function__` dispatch should be interpreted as refs calls under `TorchRefsMode`. Fixes https://github.com/pytorch/pytorch/issues/80079.

In addition, this PR enables two more tests for the nvFuser executor.

For example here's the FX trace of `torch._refs.nn.functional.layer_norm` before the proposed change (note the mix of `aten` and `prims`):
```py
opcode         name                    target                      args                              kwargs
-------------  ----------------------  --------------------------  --------------------------------  -----------------
placeholder    a_1                     a_1                         ()                                {}
call_function  convert_element_type    prims.convert_element_type  (a_1, torch.float32)              {}
call_function  var                     prims.var                   (convert_element_type, [0, 1])    {'correction': 0}
call_function  broadcast_in_dim        prims.broadcast_in_dim      (var, [1, 1], [])                 {}
call_function  convert_element_type_1  prims.convert_element_type  (a_1, torch.float32)              {}
call_function  sum_1                   prims.sum                   (convert_element_type_1, [0, 1])  {}
call_function  broadcast_in_dim_1      prims.broadcast_in_dim      (sum_1, [1, 1], [])               {}
call_function  div                     prims.div                   (broadcast_in_dim_1, 9.0)         {}
call_function  add                     aten.add                    (broadcast_in_dim, 1e-05)         {}
call_function  rsqrt                   aten.rsqrt                  (add,)                            {}
call_function  sub                     aten.sub                    (a_1, div)                        {}
call_function  mul                     aten.mul                    (sub, rsqrt)                      {}
call_function  convert_element_type_2  prims.convert_element_type  (mul, torch.float32)              {}
output         output                  output                      (convert_element_type_2,)         {}
```
And with this PR:
```py
opcode         name                    target                      args                              kwargs
-------------  ----------------------  --------------------------  --------------------------------  -----------------
placeholder    a_1                     a_1                         ()                                {}
call_function  convert_element_type    prims.convert_element_type  (a_1, torch.float32)              {}
call_function  var                     prims.var                   (convert_element_type, [0, 1])    {'correction': 0}
call_function  broadcast_in_dim        prims.broadcast_in_dim      (var, [1, 1], [])                 {}
call_function  convert_element_type_1  prims.convert_element_type  (a_1, torch.float32)              {}
call_function  sum_1                   prims.sum                   (convert_element_type_1, [0, 1])  {}
call_function  broadcast_in_dim_1      prims.broadcast_in_dim      (sum_1, [1, 1], [])               {}
call_function  div                     prims.div                   (broadcast_in_dim_1, 9.0)         {}
call_function  add                     prims.add                   (broadcast_in_dim, 1e-05)         {}
call_function  rsqrt                   prims.rsqrt                 (add,)                            {}
call_function  broadcast_in_dim_2      prims.broadcast_in_dim      (div, [3, 3], [0, 1])             {}
call_function  sub                     prims.sub                   (a_1, broadcast_in_dim_2)         {}
call_function  broadcast_in_dim_3      prims.broadcast_in_dim      (rsqrt, [3, 3], [0, 1])           {}
call_function  mul                     prims.mul                   (sub, broadcast_in_dim_3)         {}
call_function  convert_element_type_2  prims.convert_element_type  (mul, torch.float32)              {}
output         output                  output                      (convert_element_type_2,)         {}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80135
Approved by: https://github.com/ngimel
2022-06-25 03:55:04 +00:00
Horace He
e89676f76c fix logical_not reland issues
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79900

Approved by: https://github.com/ngimel
2022-06-21 03:41:18 +00:00
Nikita Shulga
f5eb05f107 Revert "Reland #2 of "Added {logical_not, trace} refs, moved logical ops to use method overloads""
This reverts commit f3665dd237.

Reverted https://github.com/pytorch/pytorch/pull/79819 on behalf of https://github.com/malfet due to land raced with softshrink refs
2022-06-20 14:22:15 -07:00
Horace He
f3665dd237 Reland #2 of "Added {logical_not, trace} refs, moved logical ops to use method overloads"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79819

Approved by: https://github.com/mruberry
2022-06-20 19:50:43 +00:00
PyTorch MergeBot
fefff54cad Revert "Revert "Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads"""
This reverts commit a2d2981e8e.

Reverted https://github.com/pytorch/pytorch/pull/79224 on behalf of https://github.com/suo due to broke lots of things a2d2981e8e
2022-06-10 04:40:43 +00:00
Horace He
a2d2981e8e Revert "Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads""
This reverts commit d67309aefb.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79224

Approved by: https://github.com/mruberry
2022-06-10 03:07:14 +00:00
PyTorch MergeBot
d67309aefb Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads"
This reverts commit 64b6bd8c1e.

Reverted https://github.com/pytorch/pytorch/pull/79000 on behalf of https://github.com/malfet due to Introduces test failure, see https://hud.pytorch.org/pr/79000
2022-06-09 13:11:23 +00:00
Horace He
64b6bd8c1e Added {logical_not, trace} refs, moved logical ops to use method overloads
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79000

Approved by: https://github.com/ezyang
2022-06-09 07:16:36 +00:00
Horace He
e675dbadc4 Ported gelu decomp to ref (#78697)
Ugh... these are actually so painful to write without operator overloading lol.

Decided to just utilize operator overloading, and xfail the ref tests for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78697
Approved by: https://github.com/mruberry
2022-06-06 22:30:20 +00:00
Edward Z. Yang
587efdb5fa Replace TensorMeta with FakeTensor
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78836

Approved by: https://github.com/albanD, https://github.com/mruberry
2022-06-05 11:51:27 +00:00
Ivan Yashchuk
df748b60f7 Allow pytrees as output for make_traced and nvfuser executor (#78802)
This PR lifts the restriction that the output of a function traced with `make_traced` and executed with nvFuser must be a single tensor. Now it's possible to return a "pytree", a tensor's nested data structure (see https://github.com/pytorch/pytorch/blob/master/torch/utils/_pytree.py).

I added a test with a function that returns a tuple of two objects where one of the objects is a dictionary with a tensor value.

```py
def fn(a, b):
    d = {}
    d["c"] = torch.add(a, b)
    return (d, torch.add(a, d["c"]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78802
Approved by: https://github.com/mruberry
2022-06-04 08:41:18 +00:00
Edward Z. Yang
6b273444c4 Add logit ref; allow non-refs to be called in refs.
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77816

Approved by: https://github.com/mruberry
2022-05-21 02:35:14 +00:00
Edward Z. Yang
4a11678368 Change TensorMeta to use tensor subclass.
This means it can be fed through traditional PyTorch C++ code
(although currently it does not work, as the __torch_dispatch__
implementation is stubbed to always throw an error.)

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76759

Approved by: https://github.com/mruberry
2022-05-04 23:49:47 +00:00
Edward Z. Yang
48eb8d6aad Use TorchFunctionMode to implement PrimTorch tracing context
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76735

Approved by: https://github.com/mruberry
2022-05-04 23:49:46 +00:00
Mike Ruberry
fe1968dea0 [primTorch] Prototype nvFuser integration and test_prims.py
This adds prototype nvFuser integration for the following prims:

- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul

Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.

This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:

```
def foo(a, b):
  return torch.add(a, b)

traced_foo = make_traced(foo)

a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```

Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.

Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
2022-04-29 02:02:25 +00:00
Mike Ruberry
4048d4cdd2 [primTorch] Prototype tracer and elementwise unary reference opinfo class
Adds a prototype tracer with no caching support and the `ElementwiseUnaryPythonRefInfo` class. A reference for `floor` is added to test the latter, and the elementwise binary reference inputs are extended to also return noncontiguous inputs. The SampleInput transform operation has been updated to return an actual SampleInput instead of a tuple to facilitate uniform handling of (transformed) SampleInputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76388
Approved by: https://github.com/ngimel
2022-04-27 14:40:21 +00:00