Commit Graph

24 Commits

Author SHA1 Message Date
Kurt Mohler
ee83c646bb Replace _prims_common.check with torch._check* (#103240)
This relands most of the changes from #102219 which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment mentioning that it will be removed in the future and `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage

Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240
Approved by: https://github.com/albanD
2023-06-21 00:46:17 +00:00
Ivan Zaitsev
821493715c Back out "Remove check from _prims_common, replace with torch._check* (#102219)", Back out "Forwatd fix for D46427687" (#103128)
Test Plan: revertitparrot

Reviewed By: malfet

Differential Revision: D46506433

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103128
Approved by: https://github.com/malfet
2023-06-07 01:41:41 +00:00
Kurt Mohler
a84bb2709a Remove check from _prims_common, replace with torch._check* (#102219)
Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102219
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-06-03 02:23:21 +00:00
PyTorch MergeBot
a7efa0ce35 Revert "Remove check from _prims_common, replace with torch._check* (#102219)"
This reverts commit fb79d43649.

Reverted https://github.com/pytorch/pytorch/pull/102219 on behalf of https://github.com/malfet due to Broke lint, see https://github.com/pytorch/pytorch/actions/runs/5158949959/jobs/9293466925 ([comment](https://github.com/pytorch/pytorch/pull/102219#issuecomment-1574245414))
2023-06-02 20:00:48 +00:00
Kurt Mohler
fb79d43649 Remove check from _prims_common, replace with torch._check* (#102219)
Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102219
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-06-02 19:13:45 +00:00
Richard Zou
fc31b3a106 Allow existing "Python RAII guards" to be used as context managers (#102579)
This PR adds a `py_context_manager_DEPRECATED` that converts a C++ RAII
guard to an object that may be either used as Python context manager or
as a "Python RAII guard".

We don't convert all of them to Python context manager only due to BC
reasons; people in OSS and internally actually rely on these APIs and I
don't want to break them. We are justified in breaking BC if we wanted
to, but it seemed like too much work for not a lot of gain.

The API is postfixed with "DEPRECATED" to indicate that people should
really use `py_context_manager` (converts C++ RAII guard to Python
context manager) instead.

Test Plan:
- this PR converts all PyTorch usages of _AutoDispatchBelowAutograd to
context manager. I can do the rest in follow-ups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102579
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2023-05-31 19:55:38 +00:00
Nikita Karetnikov
1e591a8b64 [pt2] add meta function for solve_triangular (#100829)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100829
Approved by: https://github.com/ezyang
2023-05-08 13:48:15 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Sherlock Huang
438f12d91a Rewrite some decomps to allow producing aten ops (#93099)
This introduces a new stop to the decomposition train.
Before reaching prims.view_of, it will stop at aten.alias. Export path wants to get off the train at aten ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93099
Approved by: https://github.com/ngimel
2023-01-31 17:46:20 +00:00
Peter Bell
f0b592dae7 Make masked_fill reference traceable (#90972)
As the comment states, `item()` cannot be used since you can't trace through a
scalar.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90972
Approved by: https://github.com/ngimel
2023-01-18 10:54:42 +00:00
Brian Hirsh
c47bdd7522 *_scatter ops should preserve input stride/storage_offset (#91029)
It turns out that we *do* need to update *_scatter ops to return the exact same strides as their inputs. I added a test to `test/test_functionalization.py`, which now trips thanks to Ed's functionalization stride debugging check. It only actually ends up tripping silent correctness if you try to .backward() on that function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91029
Approved by: https://github.com/ezyang
2022-12-22 19:41:53 +00:00
lezcano
6741443c7c Simplify maybe_resize_out (#88116)
The previous behaviour would call `resize_` on 0-sized elements even
when their size was correct. This would make some test fail, as resize_
may be an in-place operation and it's not supported by some subsystems

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88116
Approved by: https://github.com/mruberry
2022-11-18 14:59:45 +00:00
Fabio Rocha
493ded249e [primTorch] decomposition for bucketize (#86366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86366
Approved by: https://github.com/mruberry
2022-10-12 12:25:42 +00:00
Edward Z. Yang
2a332afbf4 Add SymFloat, support SymInt to SymFloat conversion (#84284)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84284
Approved by: https://github.com/albanD
2022-09-03 01:30:32 +00:00
Ivan Yashchuk
3aae6ff1e1 Add nvprims.var_mean (#83508)
This PR adds nvfuser-specific primitive - `var_mean`.
Interpretation `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by `TorchRefsNvfuserCapabilityMode` context manager.

I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).

Layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple comparison of performance with this PR and master (on 3080ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser");
```
run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about 35% improvement in performance using nvfuser executor with this specific normalized shape.

Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).

Ref. https://github.com/pytorch/pytorch/issues/80187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
2022-08-28 18:45:25 +00:00
PyTorch MergeBot
b159a5230f Revert "Add nvprims.var_mean (#83508)"
This reverts commit 7e7694b661.

Reverted https://github.com/pytorch/pytorch/pull/83508 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 11:30:27 +00:00
Ivan Yashchuk
7e7694b661 Add nvprims.var_mean (#83508)
This PR adds nvfuser-specific primitive - `var_mean`.
Interpretation `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by `TorchRefsNvfuserCapabilityMode` context manager.

I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).

Layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple comparison of performance with this PR and master (on 3080ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser");
```
run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about 35% improvement in performance using nvfuser executor with this specific normalized shape.

Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).

Ref. https://github.com/pytorch/pytorch/issues/80187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
2022-08-27 09:05:20 +00:00
Ivan Yashchuk
cb488e6d2f Allow None arguments for elementwise type promotion wrapper and fix clamp with None arguments (#83586)
Fixes https://github.com/pytorch/torchdynamo/issues/759
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83586
Approved by: https://github.com/ezyang, https://github.com/ngimel
2022-08-23 17:47:10 +00:00
Edward Z. Yang
817a82704f Delete ProxyTensor wrapper subclass (#83330)
I was working on https://github.com/pytorch/torchdynamo/issues/80 and my
working hypothesis for what was causing the error was that proxy tensor
was not advertising correct dispatch keys, causing AMP to operate
differently when you traced.  I could have fixed this directly by
replicating fake tensor's fix for setting dispatch keys to also apply to
proxy tensor, but I was like, "Why must I repeat myself."

This PR is the result.  It completely deletes the ProxyTensor wrapper
subclass, so that when we are tracing, the tensors flowing through the
program are the *original* real or fake tensors, depending on what the
user requested in the top-level API.  There is no more wrapping.  To
store the Proxy objects necessary for actually doing tracing, I store
the property directly on the tensors.  (Note: I never
clean up old entries from the map at the moment, this is easily fixed
by using a weak map)

Benefits of doing this:

* No more tip-toeing around no_dispatch() creation of new ProxyTensors;
  we never create new tensors (except when we call the underlying func),
  so you don't have to worry about accidentally tracing them.

* No more syncing up metadata from in place operators.  In particular
  https://github.com/pytorch/pytorch/issues/81526 is mooted

* This fixes https://github.com/pytorch/torchdynamo/issues/519 as we no longer need to teach proxy tensor to support sparse tensor.

* No more schlepping symbolic integers from the inner fake tensor to the
  outer proxy tensor.  If you can make a fake tensor with symbolic ints,
  you're done, nothing else to do.

To avoid having to rewrite all of the guts, when I get to the actual
proxy tensor handler, I first "fetch" the stored ProxyTensor data from
the weakmap via a tree_map, and then operate on the consequent data as
before.  A more optimized implementation is possible.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83330
Approved by: https://github.com/Chillee
2022-08-18 01:56:07 +00:00
Ivan Yashchuk
ec67c6abbe Add torch.ops.nvprims namespace for nvFuser-specific prims (#82155)
New namespace `torch.ops.nvprims` is meant for specific to the nvFuser set of primitives. All `impl_nvfuser` attributes are removed from `torch.ops.prims` functions.

`NvfuserPrimsMode()` context manager can be used for automatic rewrite of `torch.ops.prims` calls to `torch.ops.nvprims` when possible.

The previous way to test whether a prim would be executable with nvFuser was to test `impl_nvfuser is not None`, now all functions in the `torch.ops.nvprims` namespace are supposed to have the `impl_nvfuser` attribute and hence all are executable by nvFuser.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82155
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
2022-08-04 16:51:56 +00:00
soulitzer
16093a1d81 Fix primtorch out_wrapper semantics for factory functions (#82375)
This PR:
- introduces new OpInfo attribute `is_factory_function`
- updates OpInfo test_out to handle case when `is_factory_function=True`:
- correct primtorch out_wrapper
- update sample inputs for arange, linspace, logspace to not explicitly pass in dtype or device (having this sample is necessary for the test to get triggered)

Fixes https://github.com/pytorch/pytorch/issues/82364

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82375
Approved by: https://github.com/ezyang, https://github.com/ngimel
2022-07-29 00:57:57 +00:00
Horace He
a5fb41e3d3 Revert "Revert "Refactored prim utils into _prims_utils folder (#81746)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81746
Approved by: https://github.com/anijain2305, https://github.com/Krovatkin
2022-07-20 23:43:57 +00:00
PyTorch MergeBot
e43a02c314 Revert "Refactored prim utils into _prims_utils folder (#81088)"
This reverts commit 80231d0a72.

Reverted https://github.com/pytorch/pytorch/pull/81088 on behalf of https://github.com/jeanschmidt due to breaking internal tests
2022-07-19 19:56:41 +00:00
Horace He
80231d0a72 Refactored prim utils into _prims_utils folder (#81088)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81088
Approved by: https://github.com/ngimel
2022-07-19 03:55:51 +00:00