Commit Graph

294 Commits

Author SHA1 Message Date
Nikita Karetnikov
76af71444a [primTorch] Add ref for complex (#88562)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88562
Approved by: https://github.com/ezyang
2022-11-13 20:31:16 +00:00
Ivan Yashchuk
69b2352236 Add min cut partitioner for AOT+nvFuser (#88204)
Here we mark most of `torch.ops.nvprims` as something that can be recomputed in the backward passes (and hopefully fused).

TODO:
- [x] Add a test after https://github.com/pytorch/pytorch/pull/88186 is merged

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88204
Approved by: https://github.com/jjsjann123, https://github.com/jansel
2022-11-09 12:56:55 +00:00
Kurt Mohler
ee28b865ee Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)
Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
2022-11-08 18:11:01 +00:00
jjsjann123
af09270e10 nvprims bookend non compute (#88457)
Cherry-picking: https://github.com/csarofeen/pytorch/pull/2099

1. Enabled the bookend non-compute-ops pass on nvFuser.
2. Fixed the bookend op check on intermediate tensors used as partition inputs.
3. Added python tests for `getitem` special handling and bookend non-compute removal.
4. Patched the DFS to exclude traversal within partitions, avoiding the recursion limit.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88457
Approved by: https://github.com/SherlockNoMad
2022-11-08 12:06:35 +00:00
jjsjann123
52375a0fd2 nvprims native batch norm patch (#88455)
Cherry-picking: https://github.com/csarofeen/pytorch/pull/2104

- [x] Added an explicit cast on inputs to nvprims.native_batch_norm. This avoids the cast that was giving us issues in the fusion definition.
- [x] Added a python repro with dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88455
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
2022-11-05 02:22:27 +00:00
jjsjann123
79abea5683 nvprim python runtime dtype correctness patch (#88452)
Cherry-picking: https://github.com/csarofeen/pytorch/pull/2133

- [x] Casts the FusionDefinition output to the original dtype recorded in the GraphModule.
- [x] Adds a python repro with dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88452
Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
2022-11-04 19:17:07 +00:00
Ivan Yashchuk
4a8382b58e Update caching of tensor arguments for nvFuser's fusion creation (#87860)
Previously nvFuser's fusion definition was cached based on concrete shape and strides of tensor inputs for simplicity and correctness. This PR changes Python's cache to check the number of dimensions, size-1 dimensions, and contiguity information based on given strides and shapes.
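A minimal sketch of the kind of cache key this describes (the helper name and exact key contents are illustrative, not the actual executor code):

```py
import torch

def fusion_cache_key(tensors):
    # Key on rank, which dimensions have size 1, and contiguity -- not on the
    # concrete sizes and strides -- so inputs that differ only in concrete shape
    # can reuse a previously constructed fusion definition.
    key = []
    for t in tensors:
        size_one_dims = tuple(i for i, s in enumerate(t.shape) if s == 1)
        key.append((t.dim(), size_one_dims, t.is_contiguous()))
    return tuple(key)

a = torch.randn(2, 1, 3)
b = torch.randn(5, 1, 7)
assert fusion_cache_key([a]) == fusion_cache_key([b])  # same rank, size-1 dims, contiguity
```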

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87860
Approved by: https://github.com/kevinstephano, https://github.com/jjsjann123, https://github.com/ngimel
2022-11-02 09:29:20 +00:00
Kevin Stephano
8ef9bda1bf Fix nvFuser Fusion Definition printing of Squeeze and Permute (#88041)
NM

cc @jjsjann123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88041
Approved by: https://github.com/IvanYashchuk, https://github.com/jjsjann123, https://github.com/mruberry
2022-11-01 19:02:40 +00:00
Kevin Stephano
323c646ca9 Cleaned up the nvFuser Python Frontend Batch Norm printing (#88057)
* Removed `define_null_tensor` usage in favor of using optional arguments for binding.
* Re-ordered the non-State arguments for easier printing.
* Added a printing function to include booleans `training` and `channels_last`
* Fixed `define_tensor` to print `is_cpu`

cc @jjsjann123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88057
Approved by: https://github.com/IvanYashchuk, https://github.com/jjsjann123, https://github.com/mruberry
2022-11-01 05:05:15 +00:00
Sherlock Huang
5723fd503c Fix meta function for aten.flip and aten.rot90 (#88065)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88065
Approved by: https://github.com/mruberry
2022-10-31 16:52:05 +00:00
Edward Z. Yang
1ff52225f1 Unify SymIntNode and SymFloatNode into SymNode (#87817)
This refactor was prompted by challenges handling mixed int/float
operations in C++.  A previous version of this patch
added overloads for each permutation of int/float and was unwieldy
https://github.com/pytorch/pytorch/pull/87722/  This PR takes a different
approach.

The general outline of the patch is to combine the C++ types SymIntNode
and SymFloatNode into a single type, SymNode.  This is type erased; we
no longer know statically at C++ if we have an int/float and have to test
it with the is_int()/is_float() virtual methods.  This has a number of
knock on effects.

- We no longer have C++ classes to bind to Python.  Instead, we take an
  entirely new approach to our Python API, where we have a SymInt/SymFloat
  class defined entirely in Python, which hold a SymNode (which corresponds
  to the C++ SymNode).  However, SymNode is not pybind11-bound; instead,
  it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode
  when it goes into C++.  This implies a userland rename.

  In principle, it is also possible for the canonical implementation of SymNode
  to be written in C++, and then bound to Python with pybind11 (we have
  this code, although it is commented out.)  However, I did not implement
  this as we currently have no C++ implementations of SymNode.

  Because we do return SymInt/SymFloat from C++ bindings, the C++ binding
  code needs to know how to find these classes.  Currently, this is done
  just by manually importing torch and getting the attributes.

- Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now
  takes SymInt/SymFloat, rather than SymNode, bringing it in line with how
  __torch_dispatch__ works.

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode.  Note that this
  constructor is ambiguous if you pass in a subclass of SymNode,
  so an explicit downcast is necessary.  This means toSymFloat/toSymInt
  are no more.  This is a mild optimization as it means rvalue reference
  works automatically.

- We uniformly use the caster for c10::SymInt/SymFloat, rather than
  going the long way via the SymIntNode/SymFloatNode.

- Removed some unnecessary toSymInt/toSymFloat calls in normalize_*
  functions, pretty sure this doesn't do anything.

- guard_int is now a free function, since to guard on an int you cannot
  assume the method exists.  A function can handle both int and SymInt
  inputs.

- We clean up the magic method definition code for SymInt/SymFloat/SymNode.
  ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets
  plain methods; this is to help avoid confusion between the two types.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
2022-10-27 20:56:02 +00:00
Ivan Yashchuk
ae4fbac819 Enable nvprims.transpose fusions for nvFuser (#86967)
This PR allows transposes to be fused with other operations. If a fusion group is formed only from operations that just manipulate metadata in PyTorch (transpose, view, etc.) then this group is not sent to nvFuser.
On top of that, if we have converted to `nvprims` but then decided not to form a fusion group, we modify the graph to use the `prim.impl_aten` attribute instead of calling `prim(*args, **kwargs)`, which has a higher overhead.
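A hedged sketch of the "metadata-only group" check described above; the op names and helper are illustrative, not the actual implementation:

```py
# Illustrative subset of ops that only manipulate metadata (no data movement).
METADATA_ONLY_OPS = {
    "nvprims.transpose.default",
    "nvprims.view.default",
    "nvprims.permute.default",
}

def is_metadata_only_partition(nodes):
    # If every call in a proposed fusion group only rearranges metadata,
    # don't bother sending the group to nvFuser.
    calls = [n for n in nodes if n.op == "call_function"]
    return bool(calls) and all(str(n.target) in METADATA_ONLY_OPS for n in calls)
```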

cc @kevinstephano @jjsjann123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86967
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
2022-10-26 17:00:07 +00:00
Ivan Yashchuk
72f446b9bc Remove getitem special handling in the partitioner (#87073)
This special handling of getitem unnecessarily splits fusions at functions with tuple outputs.

Example script:
```py
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch.fx.experimental.proxy_tensor import make_fx

def func(x):
    xx = torch.ops.nvprims.add(x, 1)
    var, mean = torch.ops.nvprims.var_mean(x, correction=0)
    var_cos = torch.ops.nvprims.cos(var)
    mean_sin = torch.ops.nvprims.sin(mean)
    return torch.ops.nvprims.add(var_cos, mean_sin)

a = torch.randn(5, 3, 3, device="cuda")
gm = make_fx(func)(a)
gm.graph.print_tabular()

supported_ops = NvfuserPrimOperatorSupport()
partitioner = CapabilityBasedPartitioner(
    gm, supported_ops, allows_single_node_partition=False
)
partitions = partitioner.propose_partitions()
print(partitions)
partitioned_graph = partitioner.fuse_partitions(partitions)
partitioned_graph.graph.print_tabular()
```
Output on master:
```py
opcode         name       target                       args              kwargs
-------------  ---------  ---------------------------  ----------------  -----------------
placeholder    x_1        x_1                          ()                {}
call_function  add        nvprims.add.default          (x_1, 1)          {}
call_function  var_mean   nvprims.var_mean.main        (x_1, [0, 1, 2])  {'correction': 0}
call_function  getitem    <built-in function getitem>  (var_mean, 0)     {}
call_function  getitem_1  <built-in function getitem>  (var_mean, 1)     {}
call_function  cos        nvprims.cos.default          (getitem,)        {}
call_function  sin        nvprims.sin.default          (getitem_1,)      {}
call_function  add_1      nvprims.add.default          (cos, sin)        {}
output         output     output                       (add_1,)          {}
[{cos, sin, add_1}, {var_mean, add, getitem, getitem_1}]
opcode         name       target                       args                    kwargs
-------------  ---------  ---------------------------  ----------------------  --------
placeholder    x_1        x_1                          ()                      {}
call_module    fused_1    fused_1                      (x_1,)                  {}
call_function  getitem_2  <built-in function getitem>  (fused_1, 0)            {}
call_function  getitem_3  <built-in function getitem>  (fused_1, 1)            {}
call_module    fused_0    fused_0                      (getitem_2, getitem_3)  {}
output         output     output                       (fused_0,)              {}
```
Output with this PR:
```
[{var_mean, add_1, cos, sin, add, getitem_1, getitem}]
opcode       name     target    args        kwargs
-----------  -------  --------  ----------  --------
placeholder  x_1      x_1       ()          {}
call_module  fused_0  fused_0   (x_1,)      {}
output       output   output    (fused_0,)  {}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87073
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
2022-10-26 14:18:46 +00:00
Ivan Yashchuk
ff2569bc8c Intercept aten._reshape_alias for nvFuser (#87072)
This helps form larger fusion groups. If the op doesn't end up being executed by nvFuser, the eager-mode implementation calls into `.reshape`: 37e9e89afb/torch/_prims/nvfuser_prims.py (L552-L553)

cc @kevinstephano @jjsjann123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87072
Approved by: https://github.com/ngimel
2022-10-25 21:53:12 +00:00
Sherlock Huang
0b162f5b49 Fix stride for prims.where (#87563)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87563
Approved by: https://github.com/ngimel, https://github.com/mruberry
2022-10-25 21:22:50 +00:00
PyTorch MergeBot
5308886ec3 Revert "Intercept aten._reshape_alias for nvFuser (#87072)"
This reverts commit 163a829caa.

Reverted https://github.com/pytorch/pytorch/pull/87072 on behalf of https://github.com/malfet due to Looks like it broke test_indexing in dynamo shard, see https://github.com/pytorch/pytorch/actions/runs/3318778609/jobs/5483248042
2022-10-25 14:45:14 +00:00
Ivan Yashchuk
163a829caa Intercept aten._reshape_alias for nvFuser (#87072)
This helps form larger fusion groups. If the op doesn't end up being executed by nvFuser, the eager-mode implementation calls into `.reshape`: 37e9e89afb/torch/_prims/nvfuser_prims.py (L552-L553)

cc @kevinstephano @jjsjann123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87072
Approved by: https://github.com/ngimel
2022-10-25 06:56:02 +00:00
Soof Golan
874a94ce94 Fix tensor.stride() type hint (#84177)
`tensor.stride()` now hints at a tuple of variable length instead of a tuple with a constant length of 1.
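Roughly, the hint changes as sketched below (simplified; the function names are just stand-ins for the stub entry):

```py
from typing import Tuple

# Before (the bug in #84176): a tuple of exactly one int.
def stride_before(self) -> Tuple[int]: ...

# After this PR: a tuple of variable length, one entry per dimension.
def stride_after(self) -> Tuple[int, ...]: ...
```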

Fixes #84176

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84177
Approved by: https://github.com/Chillee
2022-10-25 04:43:10 +00:00
Edward Z. Yang
d73d4aa7de Audit for error prone isinstance int/float and add lint (#87345)
We recently fixed a bug on symbolic-shapes branch where
an isinstance(x, int) test failed when passed a SymIntNode.
To prevent this, I've added a lint for all the codepaths
where we may pass SymInt/SymFloat directly to reject
direct isinstance int/float tests, and instead use one of
the aliases.  The lint rule explains the options.  I then
go and fix all of them.
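A hedged illustration of the pattern the lint pushes toward; the alias names and their members are assumptions, not the exact aliases added in this PR:

```py
import torch

# Hypothetical aliases: accept plain Python numbers and their symbolic counterparts.
IntLike = (int, torch.SymInt)
FloatLike = (float, torch.SymFloat)

def check_dim(dim):
    # isinstance(dim, int) would reject a SymInt and take the wrong branch;
    # checking against the alias handles both the concrete and symbolic case.
    if not isinstance(dim, IntLike):
        raise TypeError(f"expected an int-like dim, got {type(dim)}")
    return dim
```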

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87345
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2022-10-21 15:55:24 +00:00
Nikita Karetnikov
841995d53b [primTorch] Add refs for data conversion ops (#86561)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86561
Approved by: https://github.com/lezcano, https://github.com/mruberry, https://github.com/zou3519
2022-10-18 08:38:51 +00:00
Ivan Yashchuk
31931515bc Workarounds for cudnn_batch_norm with TorchRefsNvfuserCapabilityMode (#86796)
This PR adds workarounds to support AOT Autograd's graphs containing `aten.cudnn_batch_norm` and `aten.cudnn_batch_norm_backward` with `TorchRefsNvfuserCapabilityMode`.

The problem with the decomposition of `aten.cudnn_batch_norm` is that it uses a `new_empty` call, which is not supported by nvFuser, and we are conservative about lowering functions to nvprims by default.

The problem with the decomposition of `aten.cudnn_batch_norm_backward` is described here https://github.com/pytorch/pytorch/pull/86115#issue-1394883782, but changing the decomposition directly in that PR makes many tests fail.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86796
Approved by: https://github.com/mruberry
2022-10-17 18:46:28 +00:00
Ivan Yashchuk
fc3afc8407 Remove empty_like+fill from AOT Autograd graphs for nvFuser (#86908)
AOT Autograd records C++ code `1 - tensor` as a sequence of empty_like, fill, and sub (see https://github.com/pytorch/pytorch/issues/86612).

Neither empty_like nor fill is supported yet. This PR is a workaround to enable fusions of `silu_backward`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86908
Approved by: https://github.com/ngimel
2022-10-14 19:49:39 +00:00
Ivan Yashchuk
fd80684784 Add nvFuser support for torch.Tensor.view (#84634)
This is an alternative to https://github.com/pytorch/pytorch/pull/83739. While PrimTorch has `view` as a reference, we would like to use nvFuser's implementation for `view` for now. Later we might transition to PrimTorch's `torch._refs.view`.

See `test_nvprims_view` for examples of things that are now sent to nvFuser. Note that nvFuser's `view` is a copy-like operation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84634
Approved by: https://github.com/kevinstephano, https://github.com/mruberry
2022-10-14 12:08:02 +00:00
Brian Hirsh
6907db3f95 fix aliasing for primtorch view meta kernels (#86285)
Fixes https://github.com/pytorch/pytorch/issues/86284

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86285
Approved by: https://github.com/albanD, https://github.com/mruberry
2022-10-13 14:14:20 +00:00
Khushi Agrawal
77d29bcee2 [primTorch] special: ndtr, ndtri, log_ndtr, erfcx (#86077)
- Adds prims and _refs for `erfcx` and `ndtri`.
- Adds _refs for `ndtr`, and `log_ndtr`.

cc @kshitij12345 @lezcano @mruberry
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86077
Approved by: https://github.com/mruberry
2022-10-13 01:18:30 +00:00
Ivan Yashchuk
cd7c86eaa4 Add prims.clone (#86705)
This simple PR adds `clone` as a primitive.
The current implementation of `clone` is not supported by the nvFuser executor because of the `empty_like` + `copy_to` combination.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86705
Approved by: https://github.com/mruberry
2022-10-12 18:22:00 +00:00
Kevin Stephano
b14f1d7bb8 Add Skip List for Aten Ops that are fused in nvFuser. (#86101)
This Skip List (tuple) is added under the nvprims context manager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86101
Approved by: https://github.com/jjsjann123, https://github.com/mruberry
2022-10-07 03:55:13 +00:00
Khushi
d6b030856b [primTorch] special: j0, j1, spherical_j0 (#86049)
Adds prims and refs for special functions (bessel_j0, bessel_j1, spherical_bessel_j0). Thanks!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86049
Approved by: https://github.com/mruberry
2022-10-04 18:21:46 +00:00
Ivan Yashchuk
68a6113248 Add nvFuser support for torch.native_batch_norm (#85562)
This PR adds nvFuser's implementation for batch_norm as there's no reference yet (https://github.com/pytorch/pytorch/pull/81191) and no in-place copy support (https://github.com/pytorch/pytorch/pull/84545).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85562
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
2022-10-03 15:03:08 +00:00
jjsjann123
fd553c46f4 nvprim op support runtime checks on dtype compatibility on prims.convert_element_type (#85566)
I'm seeing an issue where we lower `_to_copy` into `nvprims.convert_element_type`. In cases where we are casting to a dtype that's not supported by nvFuser, this raises a runtime error.

I added a quick check in the lowering step so each op can peek at the fx.Node and make a runtime decision on whether the given op should be lowered to nvprims.
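A hedged sketch of the kind of per-op runtime check described here; the helper name and the unsupported-dtype set are assumptions:

```py
import torch

NVFUSER_UNSUPPORTED_CAST_DTYPES = {torch.complex32}  # illustrative, not the real list

def can_lower_convert_element_type(fx_node) -> bool:
    # Peek at the fx.Node for the cast and refuse lowering to nvprims
    # when the target dtype is one nvFuser cannot handle.
    dtype = fx_node.kwargs.get("dtype")
    if dtype is None and len(fx_node.args) > 1:
        dtype = fx_node.args[1]
    return dtype not in NVFUSER_UNSUPPORTED_CAST_DTYPES
```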
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85566
Approved by: https://github.com/IvanYashchuk, https://github.com/ngimel
2022-09-30 23:19:25 +00:00
Ivan Yashchuk
b00a5359f7 Add a way to skip lowering to nvprims (#85811)
This PR adds `skip_ops` argument to `TorchRefsNvfuserCapabilityMode` and `NvfuserPrimsMode` which is an iterable of function names to be skipped in the translation to nvprims process.
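A hedged usage sketch based on this description; the exact spelling accepted by `skip_ops` is an assumption:

```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode

def func(a):
    return torch.sigmoid(torch.digamma(a))

a = torch.randn(3, 2, device="cuda")

# Keep torch.digamma as an aten call instead of translating it to nvprims.
with TorchRefsNvfuserCapabilityMode(skip_ops=["torch.digamma"]):
    gm = make_fx(func)(a)
```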

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85811
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
2022-09-30 12:01:45 +00:00
Kevin Stephano
6004c65af8 Fix rand_like nvprim meta function. (#85882)
Really minor fix necessary to work with TorchDynamo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85882
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
2022-09-29 23:06:15 +00:00
jjsjann123
cab6ffa0f7 catches failure on nvprim speculative lowering (#85580)
Fixes #85517

Added a try/catch around the `get_isolated_graphmodule` tracing inside `_is_func_unsupported_nvfuser`. This stops speculative lowering to nvprims when the query errors out.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85580
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
2022-09-29 15:22:45 +00:00
samdow
18d8c548f4 [Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function}

This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (it no longer queries the C++, it's Python only) but also allows the user to see the entire mode stack easily.

Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup

### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like

```python
## PRE-PR UX
def f(mode):
  with mode.restore():  # user needs to understand this restore thing?
    ...

with Mode() as m:
  pass
f(m)
```

Many users were getting errors from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
  with mode:
    ...
f(Mode())
```

**Technical Details**
With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-27 01:04:35 +00:00
Kevin Stephano
c7b17d7eb1 Add nvprims rand_like support for Dropout (#85077)
NM
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85077
Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
2022-09-23 18:03:35 +00:00
Ivan Yashchuk
35943f30cb Reference implementation for torch.Tensor.sum_to_size (#85338)
New ref: `torch._refs.sum_to_size`.

View consistency validation is disabled because the ref returns a view instead of returning the input.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85338
Approved by: https://github.com/mruberry
2022-09-21 18:12:52 +00:00
Ivan Yashchuk
308b26fe4d Add nvFuser support for transpose (#84629)
`torch._refs.t`, `torch._refs.transpose`, and `torch._refs.permute` should all be working now with the nvFuser executor. It also works with graphs processed by AOT Autograd, as these functions are registered to the aten->ref mapping via the "register_decomposition" decorator:
07d398fb26/torch/_refs/__init__.py (L3125-L3126)
07d398fb26/torch/_refs/__init__.py (L3143-L3144)
07d398fb26/torch/_refs/__init__.py (L2548-L2549)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84629
Approved by: https://github.com/ngimel
2022-09-21 12:45:15 +00:00
Mike Ruberry
d561aa944b Adds normal prim, randn reference, and randn OpInfo (#85128)
This PR extends prims support for random operations by adding `prims.normal` and `refs.randn`. Note that in the future we may not want to model draws from distributions as their own prims.

`prims.normal` accepts a shape and the mean and standard deviation of a normal distribution as numbers. This is distinct from `torch.normal` which takes two tensors so every generated datapoint can be drawn from a normal distribution with its own mean and standard deviation. To address this @ngimel and I expect to add `prims.normal_with_tensors`. The current `prims.normal` could be implemented using `prims.normal_with_tensors`, but we expect the case of two numbers is much more common, and that executors will likely want to specialize for it, anyway.
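For contrast, the tensor-parameterized `torch.normal` looks like this; the commented-out call sketches the number-parameterized `prims.normal` described above, with keyword names that are assumptions:

```py
import torch

# torch.normal: every element is drawn from its own distribution,
# parameterized by tensors of means and standard deviations.
per_element = torch.normal(mean=torch.zeros(4), std=torch.ones(4))

# prims.normal (this PR): a shape plus a single scalar mean and std for the
# whole tensor. Keyword names below are assumptions, so the call stays commented.
# whole_tensor = torch.ops.prims.normal((4,), mean=0.0, std=1.0,
#                                        dtype=torch.float32, device="cpu")
```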

In a follow-up PR I plan to add `refs.randn_like`, `prims.normal_with_tensors` (as mentioned above), and `refs.normal`.

While writing this PR I noticed the following issues:

- https://github.com/pytorch/pytorch/issues/85123
- https://github.com/pytorch/pytorch/issues/85121

The latter of which is prohibiting some testing.

In future PRs I plan to add a prim for changing layout, add support for pinned memory, and improve support for testing tensor creation operators, likely with a TensorCreationOpInfo class.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85128
Approved by: https://github.com/ngimel
2022-09-19 10:32:41 +00:00
Horace He
4bdc0af53d Added support for symbolic is_contiguous (#84829)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84829
Approved by: https://github.com/ezyang
2022-09-16 04:54:01 +00:00
Ivan Yashchuk
d802fcfcd8 Add config to PrimTorch's nvFuser executor (#84482)
This PR adds `executor_parameters` keyword argument to `torch._prims.executor.execute`.

For now there are two knobs:
* `use_python_fusion_cache: bool = True` whether to use lru_cache when constructing fusion object or not.
* `allow_single_op_fusion: bool = True` whether to allow fusions with single callable

Behavior can be controlled by passing dict with custom specified values as `executor_parameters` argument.
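A usage sketch following the description above, reusing the tracing pattern from other entries in this log:

```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute

def func(a):
    return torch.sigmoid(a) + 1

a = torch.randn(3, 2, device="cuda")
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

result = execute(
    gm,
    a,
    executor="nvfuser",
    executor_parameters={
        "use_python_fusion_cache": False,  # rebuild the fusion object on every call
        "allow_single_op_fusion": True,    # still allow single-callable fusions
    },
)
```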
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84482
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
2022-09-09 07:58:21 +00:00
jjsjann123
1a33e944b5 nvfuser torchbench patch (#84411)
1. Patched nvfuser_execute to take the aten nvprims fallback when no CUDA tensors are provided as inputs.
2. Extended support of the nvFuser python API for CPU scalar tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84411
Approved by: https://github.com/ngimel, https://github.com/kevinstephano, https://github.com/IvanYashchuk
2022-09-07 05:22:37 +00:00
Ivan Yashchuk
edab44f6dd Support a few corner cases for nvFuser executor (#84416)
This PR adds asserts to the `nvfuser_execute` function for the cases that do not work. Fallback to eager is used in those cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84416
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
2022-09-05 08:49:01 +00:00
Edward Z. Yang
2a332afbf4 Add SymFloat, support SymInt to SymFloat conversion (#84284)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84284
Approved by: https://github.com/albanD
2022-09-03 01:30:32 +00:00
PyTorch MergeBot
2d969dc2ca Revert "Support a few corner cases for nvFuser executor (#84416)"
This reverts commit 3db3845f5f.

Reverted https://github.com/pytorch/pytorch/pull/84416 on behalf of https://github.com/malfet due to Broke both trunk and pull, see 3db3845f5f
2022-09-02 17:40:17 +00:00
Ivan Yashchuk
3db3845f5f Support a few corner cases for nvFuser executor (#84416)
This PR adds asserts to the `nvfuser_execute` function for the cases that do not work. Fallback to eager is used in those cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84416
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
2022-09-02 14:57:05 +00:00
PyTorch MergeBot
0fd173b097 Revert "Support a few corner cases for nvFuser executor (#84416)"
This reverts commit 3ac9f6683d.

Reverted https://github.com/pytorch/pytorch/pull/84416 on behalf of https://github.com/IvanYashchuk due to trunk CI is failing due to sneaked in print_tabular() call
2022-09-02 10:45:41 +00:00
Ivan Yashchuk
3ac9f6683d Support a few corner cases for nvFuser executor (#84416)
This PR adds asserts to the `nvfuser_execute` function for the cases that do not work. Fallback to eager is used in those cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84416
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
2022-09-02 06:42:39 +00:00
Elias Ellison
9c452abcf1 Use reentrant mode when invoking prims, delete global prim_fake_mode (#84090)
Maybe I should be using the meta_impl instead of the prim_impl, but it's not terribly clear why, since the prim impl will be better tested and should work under the re-entrant FakeTensorMode.

Fixes https://github.com/pytorch/pytorch/issues/78613 in the process
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84090
Approved by: https://github.com/ezyang, https://github.com/samdow
2022-08-31 01:58:44 +00:00
Ivan Yashchuk
90161c23cf Add nvfuser support for squeeze (#84117)
"_refs.squeeze" and "refs.unsqueeze" now work with nvfuser executor tests.

Similarly to `_refs.reshape` we need to explicitly save the concrete shape on the trace to pass that info to nvfuser, as it gets lost in translation (https://github.com/pytorch/pytorch/pull/83739#discussion_r950352124).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84117
Approved by: https://github.com/ngimel
2022-08-30 20:36:11 +00:00
Ivan Yashchuk
3aae6ff1e1 Add nvprims.var_mean (#83508)
This PR adds an nvFuser-specific primitive, `var_mean`.
Interpretation `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by `TorchRefsNvfuserCapabilityMode` context manager.

I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).

Layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple comparison of performance with this PR and master (on 3080ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser");
```
run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about 35% improvement in performance using nvfuser executor with this specific normalized shape.

Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).

Ref. https://github.com/pytorch/pytorch/issues/80187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
2022-08-28 18:45:25 +00:00
PyTorch MergeBot
b159a5230f Revert "Add nvprims.var_mean (#83508)"
This reverts commit 7e7694b661.

Reverted https://github.com/pytorch/pytorch/pull/83508 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 11:30:27 +00:00
jjsjann123
b078d242c4 Nvfuser to copy decomp to prim (#83782)
Conditionally decomposing aten::_to_copy to nvprim::convert_element_type to allow fusion with type casting, which is introduced during the type promotion phase of torch decompositions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83782
Approved by: https://github.com/ngimel
2022-08-28 04:26:36 +00:00
Ivan Yashchuk
7e7694b661 Add nvprims.var_mean (#83508)
This PR adds an nvFuser-specific primitive, `var_mean`.
Interpretation `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by `TorchRefsNvfuserCapabilityMode` context manager.

I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).

Layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple comparison of performance with this PR and master (on 3080ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser");
```
run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about 35% improvement in performance using nvfuser executor with this specific normalized shape.

Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).

Ref. https://github.com/pytorch/pytorch/issues/80187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
2022-08-27 09:05:20 +00:00
joncrall
b136f3f310 More doctest refinements. (#83317)
Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
2022-08-22 20:07:26 +00:00
Ivan Yashchuk
108a1fb173 Avoid using fx.Interpreter in nvfuser executor function (#83607)
Using fx.Interpreter is a nice way of modifying the calls inside of FX graphs, but it introduces unnecessary overhead in this case.

Example:
```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
a = torch.randn(3, 2, dtype=torch.float16, device="cuda")
s = torch.sigmoid
d = torch.digamma  # digamma is not supported in nvfuser and aten eager execution is used
def func(a):
    return s(d(s(d(s(d(s(a)))))))
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

%%timeit
execute(gm, a, executor="nvfuser"); torch.cuda.synchronize();
# On master: 350 µs
# With this PR: 130 µs
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83607
Approved by: https://github.com/ezyang
2022-08-19 16:20:35 +00:00
Edward Z. Yang
6679d238fd SymInt'ify schemas for prims (#83528)
I audited these looking for places where ints were accepted for sizes
and turned them into SymInts.  Dimensions and miscellaneous ints were
not modified.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83528
Approved by: https://github.com/ngimel
2022-08-18 02:00:50 +00:00
Ivan Yashchuk
9f03444f70 Add torch.ops.aten -> torch._refs mapping to TorchRefsMode using decomposition_table (#82657)
### Description
This PR adds the possibility to convert `torch.ops.aten` calls to `torch._refs` and consequently prims under TorchRefsMode.
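A hedged sketch of what this enables, following the tracing pattern used elsewhere in this log; the choice of `aten.sigmoid` is just illustrative:

```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsMode

def func(a):
    # An explicit aten call; with this PR it can be mapped to the corresponding
    # torch._refs decomposition (and hence to prims) under TorchRefsMode.
    return torch.ops.aten.sigmoid.default(a)

with TorchRefsMode():
    gm = make_fx(func)(torch.randn(3, 2))
gm.graph.print_tabular()  # expect refs/prims-level ops rather than a single aten.sigmoid
```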

### Testing
New test, `test_aten_overload_to_prims`, in `test/test_prims.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82657
Approved by: https://github.com/jjsjann123, https://github.com/ezyang
2022-08-17 14:46:06 +00:00
Fabio Rocha
2a096e940d [primTorch] support for a few magic methods (#83524)
Added support for mapping `__rsub__`, `__rtruediv__`, `__rfloordiv__`, `__floordiv__`, `__pow__`, and `__rpow__` in TorchRefsMode.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83524
Approved by: https://github.com/ngimel
2022-08-17 09:48:15 +00:00
Edward Z. Yang
e09821f784 Avoid using true division in split_dim (#83527)
This makes it more amenable to tracing with dynamic shapes,
where we don't support SymFloats yet.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83527
Approved by: https://github.com/ngimel
2022-08-17 04:19:29 +00:00
Nikita Karetnikov
4010f96121 [primTorch] Fix off by 1 in canonicalize_dim (#83198)
Also fix an issue in the `unsqueeze` ref due to this change.
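For reference, a minimal sketch of what a `canonicalize_dim`-style utility does (not the exact primTorch implementation this PR patches, which also handles scalar tensors):

```py
def canonicalize_dim(rank: int, dim: int) -> int:
    # Wrap a possibly negative dim into [0, rank). Valid inputs are [-rank, rank - 1];
    # the off-by-one class of bug is getting one of these bounds wrong.
    if rank < 0:
        raise ValueError(f"rank must be non-negative, got {rank}")
    if not -rank <= dim <= rank - 1:
        raise IndexError(
            f"Dimension out of range (expected to be in [{-rank}, {rank - 1}], got {dim})"
        )
    return dim if dim >= 0 else dim + rank

assert canonicalize_dim(4, -1) == 3
```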
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83198
Approved by: https://github.com/ngimel
2022-08-16 17:57:01 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
Ivan Yashchuk
7191ae58a7 Add nvfuser support for prims.sign and refs.sign (#83167)
This short PR adds nvFuser support for `prims.sign` and consequently `refs.sign`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83167
Approved by: https://github.com/ngimel
2022-08-11 10:58:32 +00:00
Peter Bell
a17211f79b Fix prims.div to return the correct dtype (#82949)
The documentation for prims.div says that it truncates integer inputs,
but it actually did `true_divide` for all Tensor inputs. This fixes it
to actually use truncation division for integers.
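A hedged illustration of the intended integer semantics after the fix; the prim is callable through `torch.ops.prims`:

```py
import torch

a = torch.tensor([7, -7])
b = torch.tensor([2, 2])

# Truncation division (rounds toward zero) for integer inputs, per the prims.div docs:
print(torch.ops.prims.div(a, b))  # expected: tensor([ 3, -3])

# What the buggy version effectively computed for all tensor inputs:
print(torch.true_divide(a, b))    # tensor([ 3.5000, -3.5000])
```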
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82949
Approved by: https://github.com/ngimel
2022-08-09 04:38:17 +00:00
Ivan Yashchuk
ea39146507 Add a common wrapper for make_fx to handle args and kwargs (#82965)
Added a helper function `wrapper_and_args_for_make_fx`, since the number of places where we use `make_fx` internally might grow.
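A hedged sketch of what such a wrapper can look like; this is not necessarily the exact helper added in the PR:

```py
import torch
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx

def wrapper_and_args_for_make_fx(fn, args, kwargs):
    # make_fx traces positional arguments only, so flatten (args, kwargs) into a
    # single list and return a wrapper that unflattens it before calling fn.
    flat_args, spec = pytree.tree_flatten((args, kwargs))

    def wrapped(flat_args):
        fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec)
        return fn(*fn_args, **fn_kwargs)

    return wrapped, flat_args

wrapped, flat = wrapper_and_args_for_make_fx(torch.add, (torch.ones(2),), {"other": 1.0})
gm = make_fx(wrapped)(flat)
```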
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82965
Approved by: https://github.com/ngimel
2022-08-08 17:03:21 +00:00
Peter Bell
4f255dbfb3 Remove manual bindings for arange (#81380)
The functional variant of one of the `arange` overloads has a schema mismatch with the out variant. The functional one has `Scalar step`, but the corresponding out variant has `Scalar step=1`. This isn't allowed, so it had to be special-cased in the python codegen and manually bound. This adds the default `step` value to the functional overload and removes the special-casing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81380
Approved by: https://github.com/ngimel
2022-08-07 00:10:27 +00:00
Ivan Yashchuk
ec67c6abbe Add torch.ops.nvprims namespace for nvFuser-specific prims (#82155)
The new namespace `torch.ops.nvprims` is meant for the nvFuser-specific set of primitives. All `impl_nvfuser` attributes are removed from `torch.ops.prims` functions.

`NvfuserPrimsMode()` context manager can be used for automatic rewrite of `torch.ops.prims` calls to `torch.ops.nvprims` when possible.

The previous way to test whether a prim would be executable with nvFuser was to test `impl_nvfuser is not None`; now all functions in the `torch.ops.nvprims` namespace are supposed to have the `impl_nvfuser` attribute and hence all are executable by nvFuser.
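A hedged usage sketch; the import path and the set of rewritable prims are assumptions:

```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import NvfuserPrimsMode  # import path is an assumption

def func(a):
    return torch.ops.prims.sin(torch.ops.prims.add(a, a))

with NvfuserPrimsMode():
    gm = make_fx(func)(torch.randn(3, 2, device="cuda"))
# Where possible, torch.ops.prims.* calls in the trace are rewritten to torch.ops.nvprims.*.
gm.graph.print_tabular()
```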

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82155
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
2022-08-04 16:51:56 +00:00
Ivan Yashchuk
900e93d351 Add context manager for conditional rewrites of torch.* to torch._refs.* calls (#81764)
Adds a new context manager `TorchRefsNvfuserCapabilityMode` for conditional rewrite of `torch.*` calls to `torch._refs.*` based on whether the decomposition consisting of prims supports nvFuser execution or not.

A new optional argument for `TorchRefsMode` is added - `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81764
Approved by: https://github.com/ezyang
2022-08-02 11:02:10 +00:00
Kurt Mohler
14d0296e5c Rename _Typed/_UntypedStorage to Typed/UntypedStorage and update docs (#82438)
### Description

Since the major changes for `_TypedStorage` and `_UntypedStorage` are now complete, they can be renamed to be public.

`TypedStorage._untyped()` is renamed to `TypedStorage.untyped()`.

Documentation for storages is improved as well.

### Issue
Fixes #82436

### Testing
N/A

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82438
Approved by: https://github.com/ezyang
2022-07-30 19:37:08 +00:00
Edward Z. Yang
fd5ac1e6b5 Rename SymbolicIntNode to SymIntNodeImpl (#82350)
Done via

```
git grep -l 'SymbolicIntNode' | xargs sed -i 's/SymbolicIntNode/SymIntNodeImpl/g'
```

Reasoning for the change:

* Sym is shorter than Symbolic, and consistent with SymInt
* You usually will deal in shared_ptr<...>, so we're going to
  reserve the shorter name (SymIntNode) for the shared pointer.

But I don't want to update the Python name, so afterwards I ran

```
 git grep -l _C.SymIntNodeImpl | xargs sed -i 's/_C.SymIntNodeImpl/_C.SymIntNode/'
```

and manually fixed up the binding code

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82350
Approved by: https://github.com/Krovatkin
2022-07-28 18:27:45 +00:00
Edward Z. Yang
98b9dfa129 Add decompositions for zero_, fill_, new_full, new_zeros, new_ones (#82332)
This makes symbolic tracing tests for logsigmoid and xlogy start working again.

While I'm at it, add pin_memory and layout kwargs to empty; but they
don't actually do anything and raise an error if they are non standard.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82332
Approved by: https://github.com/eellison
2022-07-28 04:02:02 +00:00
Elias Ellison
1c0f7bd6d2 Enable complex for meta tensors (#79975)
There weren't really any fundamental blockers
- add support for `aten::complex`
- update `angle` for complex
- remove the error in the fallback kernel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79975
Approved by: https://github.com/ezyang
2022-07-27 22:19:14 +00:00
Edward Z. Yang
3b6b27e9d7 Add a miniature backend select implementation for prims (#82311)
It turns out that for factory function prims (prims with no Tensor
arguments), we were always going to the ATen implementation of
the operator.

Prior to the next PR in this stack, the change is a bit hard to
test, but you can indirectly observe the impact by running arange
with trace dispatching on (well, you need
https://github.com/pytorch/pytorch/pull/82277 patched in too.)

```
$ TORCH_SHOW_DISPATCH_TRACE=1 python -c "import torch._refs; torch._refs.arange(4, device='meta')"
[callBoxed] op=[prims::arange], key=[BackendSelect]
[call] op=[aten::empty_strided], key=[BackendSelect]
[redispatch] op=[aten::empty_strided], key=[Meta]
```

Previously, the prims::arange call was dispatching to Undefined.

For maximum fidelity, technically we're supposed to redispatch to a
specific dispatch key, but the Python bindings to do this don't exist
and it was easy to route to the implementations which we already
intended to go to.  We would have to fix this if we wanted external
backends to register custom implementations to OTHER dispatch keys
via Python op registration.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82311
Approved by: https://github.com/ngimel, https://github.com/bdhirsh
2022-07-27 20:21:20 +00:00
Kevin Stephano
7a4a8f327a Add new NVFuser Python Frontend Record Keeping for Cache enablement. (#81578)
This PR does not include an NVFuser frontend cache, but it decouples the backend Fusion IR exposure and instead builds the IR as needed, as it would if there were a cache, by recording the requested definition for replay to start the process of building a Fusion if it doesn't already exist. Another PR will be put up to include the actual caching.

The main change in the Python Frontend is that the NVFuser Fusion IR is not directly defined by the interface. Currently, there is a direct connection between the Python API and the creation of the Fusion IR and Object. This means the user defines TensorViews and Scalars, and calls Arith Functions (IR Expressions) on those IR Values. The goal is to disconnect the Python API from directly specifying the Fusion IR and enable caching of the IR so a Fusion Object is not necessarily built every time a Fusion Definition is seen.

The FusionDefinition in Python will mostly look the same, except the Definition is now being recorded in a lightweight representation called a "Recording" of Records. If the Description is not already cached, the Records are executed to build the Fusion IR. Initially, there is no caching because I am trying to bring up the representation first and get it working correctly.

This is what the Records look like. The Records are functors that are called if it is necessary to build the Fusion IR:
torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h

**Tensor Definition Record**
_Note: The Tensor Definition will change for runtime contiguity caching; I am just matching what is already there for now._
```
InputTensorRecord(
      std::vector<size_t> _outputs,
      std::vector<int64_t> _symbolic_sizes,
      std::vector<bool> _contiguous_info,
      NvfDataType _dtype)
      : RecordFunctor({}, std::move(_outputs)),
        symbolic_sizes(std::move(_symbolic_sizes)),
        contiguous_info(std::move(_contiguous_info)),
        dtype(_dtype) {}
  void operator()(FusionDefinition& fd) final {
    auto tv = TensorViewBuilder()
                  .ndims(symbolic_sizes.size())
                  .contiguity(contiguous_info)
                  .shape(symbolic_sizes)
                  .dtype(dtype)
                  .build();

    fd.fusion_state.at(outputs.at(0)) = tv;
    fd.addInput(tv);
  }

  std::vector<int64_t> symbolic_sizes;
  std::vector<bool> contiguous_info;
  NvfDataType dtype;
};

```
**Generic Templatized Op Record Definition**
Op Records are notable because they record Fusion IR arith functions as the `fusion_op_`.
```
template <class OutType, class... ArgTypes>
struct OpRecord : RecordFunctor {
  OpRecord(
      std::vector<size_t> _args,
      std::vector<size_t> _outputs,
      std::function<OutType(ArgTypes...)> fusion_op)
      : RecordFunctor(std::move(_args), std::move(_outputs)),
        fusion_op_(fusion_op) {}

  template <class TupleType, std::size_t... Is>
  OutType opFunc(
      FusionDefinition& fd,
      TupleType& tp,
      std::index_sequence<Is...>) {
    return fusion_op_(
        dynamic_cast<typename std::tuple_element<Is, TupleType>::type>(
            fd.fusion_state.at(args.at(Is)))...);
  }

  void operator()(FusionDefinition& fd) final {
    using arg_tuple_t = std::tuple<ArgTypes...>;
    auto indices =
        std::make_index_sequence<std::tuple_size<arg_tuple_t>::value>();
    arg_tuple_t inputs;
    auto output = opFunc(fd, inputs, indices);
    fd.fusion_state.at(outputs.at(0)) = output;
  }

 private:
  std::function<OutType(ArgTypes...)> fusion_op_;
};
```

Perhaps the most confusing aspect of the Python Frontend is the `FusionDefinition`.  The C++ Class that is bound to is very light weight, purposely.  In an attempt to make sure users don't have to touch more than one file when adding new ops, assuming an appropriate Record has already been defined, the Python bindings effectively create functions that act on the FusionDefinition and appear as part of the class in Python but are not part of the class in C++.

Here is an example of a Unary Op Macro.  It is creating the binding to a lambda function that effectively appears as a FusionDefinition operation in Python.  The other way to do this would have been to create a class method directly in the `FusionDefinition` C++ and have a separate binding to that method.

```
#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name)              \
  nvf_ops.def(                                                        \
      op_str,                                                         \
      [](nvfuser::FusionDefinition::Operators& self,                  \
         nvfuser::Tensor* input) -> nvfuser::Tensor* {                \
        nvfuser::Tensor* output = new nvfuser::Tensor(                \
            self.fusion_definition->recording_state.size());          \
        self.fusion_definition->recording_state.emplace_back(output); \
        self.fusion_definition->recording.emplace_back(               \
            new nvfuser::OpRecord<NvfTensorView*, NvfTensorView*>(    \
                {input->index},                                       \
                {output->index},                                      \
                static_cast<NvfTensorView* (*)(NvfTensorView*)>(      \
                    torch::jit::fuser::cuda::op_name)));              \
        return output;                                                \
      },                                                              \
      py::return_value_policy::reference);                            \
```

Here is the `FusionDefinition` class, edited for brevity. The playing back of the records is found under the `exit()` method, where exit refers to exiting the Python context manager. A `FusionDefinition` is captured through a context manager like the following:

```
fusion = Fusion()
with FusionDefinition(fusion) as fd :
    t0 = fd.define_tensor(sizes=[5], strides=[1])
    t1 = fd.ops.abs(t0)
    fd.add_output(t1)
```

```
class FusionDefinition {
 public:
  FusionDefinition(FusionOwner* fusion_owner)
    : fusion_owner_(fusion_owner),
      prev_fusion_(nullptr),
      recording(),
      recording_state(),
      fusion_state(),
      ops(this) {}

  // Context Manager Methods
  FusionDefinition* enter() {
    prev_fusion_ = FusionGuard::getCurFusion();
    FusionGuard::setCurFusion(fusionPtr());
    return this;
  }

  void exit() {
    // Found in the Python Bindings, currently.
    //for (auto& record : recording) {
    //  auto functor = record.get();
    //  (*functor)(self);
    //}

    FusionGuard::setCurFusion(prev_fusion_);
    prev_fusion_ = nullptr;
  }

  void addInput(torch::jit::fuser::cuda::Val* input) {
    fusionPtr()->addInput(input);
  }
  void addOutput(torch::jit::fuser::cuda::Val* output) {
    fusionPtr()->addOutput(output);
  }

  Fusion* fusionPtr() {
    return fusion_owner_->fusionPtr();
  }

 private:
  FusionOwner* fusion_owner_;
  Fusion* prev_fusion_;

 public:
  std::vector<std::unique_ptr<RecordFunctor>> recording;
  std::vector<std::unique_ptr<State>> recording_state;
  std::vector<NvfVal*> fusion_state;

  struct Operators {
    Operators(FusionDefinition* fd) : fusion_definition(fd) {}

    // Python operations are effectively bound here.

    FusionDefinition* fusion_definition;
  };

  Operators ops;
};
```

The Fusion IR doesn't have `define_tensor` or `define_scalar` functions. I made them (and their names) up for the Python `FusionDefinition` as a more understandable/convenient way to define input tensors and scalars. `TensorView` objects and Fusion IR `Val` objects are not typically defined outside of a Fusion IR `Expr` output (typically arith function outputs), except for inputs to a graph. Mechanically speaking, there are two things you need to do to define an input in the Fusion IR. You need to define the IR `TensorView`/`Val` object and then record that the IR `TensorView`/`Val` object is an input in the `Fusion` object that encapsulates the Fusion IR. Since the `FusionDefinition` does not correspond one-to-one with the Fusion IR and `define_tensor` and `define_scalar` are made-up functions, I decided to combine the `Val` object creation and the recording of the input in the `Fusion` object in one step, to reduce the amount of syntax required to define a Fusion in the python interface.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81578
Approved by: https://github.com/jjsjann123, https://github.com/IvanYashchuk, https://github.com/SherlockNoMad
2022-07-26 21:34:20 +00:00
soulitzer
6b0ca72b61 Add prim, ref, and OpInfo for arange (#81734)
Per title.
- See https://github.com/pytorch/pytorch/issues/81959 for discussion on overloading

TODO:
- ~Handle remaining TensorOptions: layout, pin_memory (won't do in this PR)~
- ~Add sample inputs for floating point and complex numbers (done for floating point, won't do for complex in this PR)~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81734
Approved by: https://github.com/ngimel
2022-07-26 19:35:02 +00:00
PyTorch MergeBot
dc57624622 Revert "Add prim, ref, and OpInfo for arange (#81734)"
This reverts commit 67dc9abbd8.

Reverted https://github.com/pytorch/pytorch/pull/81734 on behalf of https://github.com/kit1980 due to Broke trunk slow tests 67dc9abbd8
2022-07-26 05:56:09 +00:00
soulitzer
67dc9abbd8 Add prim, ref, and OpInfo for arange (#81734)
Per title.
- See https://github.com/pytorch/pytorch/issues/81959 for discussion on overloading

TODO:
- ~Handle remaining TensorOptions: layout, pin_memory (won't do in this PR)~
- ~Add sample inputs for floating point and complex numbers (done for floating point, won't do for complex in this PR)~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81734
Approved by: https://github.com/ngimel
2022-07-25 23:05:01 +00:00
lezcano
11fe277b62 [PrimTorch] Add reference for torch.norm (#81765)
This ref does more than `torch.norm`, and it fixes a few bugs that `torch.norm` has. This implementation and the `torch.norm` implementation are reconciled in the next PR of this stack.

We put this PR before, as otherwise `test_decomp.py` was failing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81765
Approved by: https://github.com/ngimel
2022-07-25 19:57:21 +00:00
Horace He
1a18ff3247 Revert "Revert "Added dynamic shape POC (#81093)"" (#82063)
This reverts commit 0888a4844c.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82063
Approved by: https://github.com/ezyang
2022-07-23 22:35:50 +00:00
PyTorch MergeBot
0888a4844c Revert "Added dynamic shape POC (#81093)"
This reverts commit 8169a85dc6.

Reverted https://github.com/pytorch/pytorch/pull/81093 on behalf of https://github.com/janeyx99 due to Broke slow tests on trunk 8169a85dc6.
2022-07-23 11:30:37 +00:00
Horace He
8169a85dc6 Added dynamic shape POC (#81093)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81093
Approved by: https://github.com/ezyang, https://github.com/eellison
2022-07-23 04:46:32 +00:00
samdow
2ac24675cc get rid of push_torch_{dispatch, function}_mode (#78215)
Currently we have two ways of doing the same thing for torch dispatch and function modes:
`with push_torch_dispatch_mode(X)` or `with X.push(...)`
is now equivalent to doing
`with X()`

This removes the first API (which is older and private, so we don't need to go through a deprecation cycle).

There is some risk that this might land-race with a PR that uses the old API, but in general it seems like most code uses the `with X()` API or `enable_torch_dispatch_mode(X())`, which isn't getting removed.

EDIT: kept the `with X.push(...)` API since there were ~3 land races with it over the past day or so, but made it give a warning asking users to use the other API.
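
A minimal sketch of the surviving style, using a `TorchDispatchMode` subclass (the module path below is an assumption of where it lives):
```py
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Log every op that hits the mode, then redispatch to the real kernel.
        print(f"dispatching {func}")
        return func(*args, **(kwargs or {}))

# New style: construct the mode and enter it directly.
with LoggingMode():
    torch.ones(2) + 1
```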
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78215
Approved by: https://github.com/ezyang
2022-07-22 18:56:37 +00:00
Huy Do
12cb26509a Apply ufmt to torch internal (#81643)
This is a big-bang PR; merge conflicts are expected and will be addressed at merge.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81643
Approved by: https://github.com/ezyang
2022-07-22 02:19:50 +00:00
Horace He
a5fb41e3d3 Revert "Revert "Refactored prim utils into _prims_utils folder (#81746)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81746
Approved by: https://github.com/anijain2305, https://github.com/Krovatkin
2022-07-20 23:43:57 +00:00
Ivan Yashchuk
a3d5d2ddf1 Add partitioned nvFuser executor with ATen fallbacks (#81043)
This PR introduces a new nvFuser executor for FX graphs containing different kinds of nodes, not just the `torch.ops.prims` supported by nvFuser. The FX graph is partitioned based on whether nodes are supported by nvFuser, and supported nodes are fused into subgraphs; this builds on Sherlock's work on the partitioner.

This new partition-based executor with fallbacks to ATen is used by default with `executor="nvfuser"`. The previous executor can be used with `executor="strictly_nvfuser"` (naming suggestions are welcome!).
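
A rough usage sketch; the import paths are assumptions based on where these utilities lived at the time:
```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsMode   # assumed location
from torch._prims.executor import execute        # assumed location

def fn(a):
    return torch.sigmoid(a) + 1

a = torch.randn(3, 3, device="cuda")
with TorchRefsMode.push():
    gm = make_fx(fn)(a)

# Partition-based executor: nodes nvFuser can't handle fall back to ATen.
out = execute(gm, a, executor="nvfuser")
# Previous behaviour, requiring the whole graph to be nvFuser-supported:
# out = execute(gm, a, executor="strictly_nvfuser")
```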
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81043
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
2022-07-20 19:51:20 +00:00
PyTorch MergeBot
e43a02c314 Revert "Refactored prim utils into _prims_utils folder (#81088)"
This reverts commit 80231d0a72.

Reverted https://github.com/pytorch/pytorch/pull/81088 on behalf of https://github.com/jeanschmidt due to breaking internal tests
2022-07-19 19:56:41 +00:00
Horace He
80231d0a72 Refactored prim utils into _prims_utils folder (#81088)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81088
Approved by: https://github.com/ngimel
2022-07-19 03:55:51 +00:00
Sergii Dymchenko
6884865009 Remove unnecessary assignment (#81498)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81498
Approved by: https://github.com/seemethere
2022-07-18 16:26:40 +00:00
Edward Z. Yang
471397d0ee Remove spurious assert (#81604)
This assert was triggering after https://github.com/pytorch/functorch/pull/935
in some of the exhaustive python key tests, but the test is fine
even without it.  I added an explanation in code.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81604
Approved by: https://github.com/Chillee
2022-07-18 04:11:14 +00:00
Peter Bell
bf36d8b987 [primTorch] Implement one-dimensional fft transforms (#80570)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80570
Approved by: https://github.com/mruberry
2022-07-15 15:13:43 +00:00
Peter Bell
00459c2c87 [primTorch] Implement constant_pad_nd (#80182)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80182
Approved by: https://github.com/mruberry, https://github.com/ngimel
2022-07-15 15:13:42 +00:00
jjsjann123
2f55050fb5 Bump nvfuser executor lru cache max size (#81461)
The default cache size of 128 has been causing zero cache hits on some benchmarks with more than 128 partitions, so this bumps it up to a more reasonable cache size.
Note that the simple LRU cache doesn't give us any reuse of repetitive patterns, but that shouldn't be much of an issue in our next iteration of the nvFuser Python API.
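
Purely illustratively, the mechanism at play (the actual cache wraps fusion construction; the 512 below is an arbitrary stand-in, not the value chosen in this PR):
```py
import functools

@functools.lru_cache(maxsize=512)  # the default of 128 starts evicting once >128 partitions exist
def compile_partition(partition_key: str) -> str:
    # Stand-in for the expensive nvFuser fusion construction/compilation.
    return f"compiled<{partition_key}>"

for i in range(300):
    compile_partition(f"partition_{i}")
print(compile_partition.cache_info())  # shows hits/misses/currsize
```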

script for running benchmarks vvv
https://github.com/SherlockNoMad/NvFuserSample

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81461
Approved by: https://github.com/SherlockNoMad
2022-07-14 22:15:29 +00:00
PyTorch MergeBot
6997ac79d6 Revert "[primTorch] Implement constant_pad_nd (#80182)"
This reverts commit 77cfa9f7a1.

Reverted https://github.com/pytorch/pytorch/pull/80182 on behalf of https://github.com/clee2000 due to causes failures on trunk / linux-bionic-py3.7-clang9-slow / test (slow https://github.com/pytorch/pytorch/runs/7343337014?check_suite_focus=true
2022-07-14 17:30:51 +00:00
Peter Bell
77cfa9f7a1 [primTorch] Implement constant_pad_nd (#80182)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80182
Approved by: https://github.com/mruberry, https://github.com/ngimel
2022-07-14 15:29:42 +00:00
Peter Bell
924b7951aa [primTorch] Implement conj and conj_physical (#80358)
This adds `prims.conj` and `prims.conj_physical` which only accept
complex tensors, as well as `refs.conj` and `refs.conj_physical` which
pass-through non-complex values and call the appropriate `prims` for
complex types.
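
A rough sketch (assumed structure, not the actual ref code) of that pass-through behaviour:
```py
import torch

def conj_ref(x: torch.Tensor) -> torch.Tensor:
    # Non-complex inputs pass straight through; complex inputs defer to the prim.
    if not x.dtype.is_complex:
        return x
    return torch.conj(x)  # stands in for prims.conj in this sketch

print(conj_ref(torch.tensor([1.0, 2.0])))        # returned unchanged
print(conj_ref(torch.tensor([1 + 2j, 3 - 4j])))  # conjugated
```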
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80358
Approved by: https://github.com/mruberry
2022-07-14 15:29:41 +00:00
Sherlock Huang
6b280e880a Update NvFuserOperatorSupport (#81311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81311
Approved by: https://github.com/davidberard98
2022-07-12 21:19:37 +00:00
Mike Ruberry
8740c68c41 [primTorch] Adds contiguous and expand references (#79820)
I also filed the issues listed below while creating this PR.

This PR...

**Filed issues**

- https://github.com/pytorch/pytorch/issues/79818
- https://github.com/pytorch/pytorch/issues/80154

**prims**

- Fixes prims.squeeze when called with an unsorted list of dimensions
- Removes the clone prim

**refs**
- adds contiguous
- adds expand
- updates clone to call empty_like and copy_to
- updates empty to accept a memory format
- updates empty_like to accept a memory_format

**utils**
- adds helper functions for working with memory formats and channels last tensors, in particular

**tests**

- removes unused clamp sample input functions (mooted by clamp's new reference inputs)
- extends the reference inputs for clone to include different memory formats
- creates reference inputs for contiguous
- xfails operators that depend on clone (including clone) on `test_python_ref` (see issues)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79820
Approved by: https://github.com/ngimel
2022-07-11 17:42:58 +00:00
lezcano
24af7948ca Add prim.svd, refs.linalg.svd, and refs.linalg.svdvals (#78616)
This is the first prim / ref added that has multiple returns.
There is an issue with `out_wrapper_multi` as currently implemented
(left a note). It assumes that the API is `svd(X, U=U, S=S, Vh=Vh)`,
when it's actually `svd(X, out=(U, S, Vh))`.

Even more, if we want to model PyTorch exactly, it should return a
`torch.return_types.svd`, rather than a `Tuple`.

There is an issue with
As per title
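
For reference, the calling convention being described, shown with `torch.linalg.svd`:
```py
import torch

X = torch.randn(4, 3)

# The out parameter takes a tuple of tensors, not separate keyword arguments.
U = torch.empty(4, 4)
S = torch.empty(3)
Vh = torch.empty(3, 3)
torch.linalg.svd(X, out=(U, S, Vh))

# And the plain call returns a namedtuple-like object with U, S, Vh fields.
res = torch.linalg.svd(X)
print(res.U.shape, res.S.shape, res.Vh.shape)
```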
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78616
Approved by: https://github.com/mruberry
2022-07-09 19:42:01 +00:00
lezcano
45b67b65de Fix handling of device (#78615)
Removes an unnecessary auxiliary function (we had already implemented
it), uses DeviceLikeType to denote a str or device, and adds `is_cpu` and
`is_cuda` helper functions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78615
Approved by: https://github.com/mruberry
2022-07-09 19:42:01 +00:00
lezcano
e9a9b50f48 Reference for linalg.vector_norm (#78350)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78350
Approved by: https://github.com/mruberry
2022-07-09 19:42:01 +00:00
Aidyn-A
04ef236c0d [primTorch] Elementwise unary ops vi (#79526)
This PR adds primitives and references for `heaviside` and `hypot`.
Depends on #80146
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79526
Approved by: https://github.com/mruberry
2022-07-08 15:17:45 +00:00
Xiang Gao
dad071d8fe [nvFuser] Add real and imag to nvfuser and its python frontend (#79824)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79824
Approved by: https://github.com/mruberry, https://github.com/jjsjann123, https://github.com/kevinstephano
2022-07-07 17:25:42 +00:00
lezcano
beb98676ba Correct cbrt implementation (#80443)
Following
https://github.com/pytorch/pytorch/pull/80219#discussion_r907680368
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80443
Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
2022-07-07 16:15:42 +00:00
Xiang Gao
9ca561cbd4 Add prims::{real, imag}, refs::{real, imag}, remove prims::is_infinite (#80148)
This is a subset of https://github.com/pytorch/pytorch/pull/78655. I want to land this separately because the world is changing so fast, and https://github.com/pytorch/pytorch/pull/78655 is having lots of conflicts with other parts of the world.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80148
Approved by: https://github.com/mruberry
2022-07-05 23:06:39 +00:00
Nikita Shulga
b62209f047 [Prims] Unbreak CUDA lazy init (#80899)
CUDA calls should not be made in the default codepath

Fixes https://github.com/pytorch/pytorch/issues/80876
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80899
Approved by: https://github.com/ngimel
2022-07-05 21:16:58 +00:00
Ivan Yashchuk
9a12aa6cad Add cached nvFuser's fusion creation for torch._prims.executor (#80525)
In the current setup, a `Fusion` object was constructed from the `GraphModule` and args on each call of the `execute` function, which is expensive.

This PR makes use of `functools.lru_cache` to pay the `Fusion` creation cost once per `GraphModule` and set of args. Currently the shape, strides, and dtype of tensors are static; this can be changed later to make better use of nvFuser's internal caching mechanism (by specifying only ndim, contiguity, and dtype).

On master:
```py
In [2]: a = torch.randn(3, 3, device='cuda')

In [3]: with TorchRefsMode.push():
   ...:     gm = make_fx(lambda x: torch.sigmoid(x))(a)
   ...:

In [4]: %%timeit
   ...: execute(gm, a, executor="nvfuser")
   ...: torch.cuda.synchronize()
175 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
This PR:
```py
In [2]: a = torch.randn(3, 3, device='cuda')

In [3]: with TorchRefsMode.push():
   ...:     gm = make_fx(lambda x: torch.sigmoid(x))(a)
   ...:

In [4]: %%timeit
   ...: execute(gm, a, executor="nvfuser")
   ...: torch.cuda.synchronize()
62.6 µs ± 9.99 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```

In addition, this PR adds support for pytree inputs and extends the test for this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80525
Approved by: https://github.com/kevinstephano, https://github.com/jjsjann123, https://github.com/SherlockNoMad
2022-07-05 17:00:45 +00:00
lezcano
eb0889cf7d Add support for multiple inputs to out_wrapper and strict dtype checking (#80601)
Reland of https://github.com/pytorch/pytorch/pull/79941
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80601
Approved by: https://github.com/albanD
2022-07-05 12:31:21 +00:00
PyTorch MergeBot
184a065ba7 Revert "Add support for multiple inputs to out_wrapper and strict dtype checking (#79941)"
This reverts commit dc7066a8f0.

Reverted https://github.com/pytorch/pytorch/pull/79941 on behalf of https://github.com/suo due to broke master dc7066a8f0
2022-06-30 03:29:30 +00:00
lezcano
dc7066a8f0 Add support for multiple inputs to out_wrapper and strict dtype checking (#79941)
When a function returns multiple parameters in PyTorch, the `out`
parameter takes a tuple of tensors (see `linalg.svd` for example).
The current implementation in `out_wrapper_multi` modelled this wrong,
as it assumed that it would take a number of different named
parameters.

This PR implements the correct behaviour in `out_wrapper`. As a small
side-effect, we now need to call `@out_wrapper()` when the output is
just one tensor.

This PR also implements an additional optional parameter that checks
whether the dtype of the given `out` is exactly the dtype that the meta
function requires. This is the behaviour that we currently have in
PyTorch, and this check is necessary in eager when we call with these
tensors into external libraries.

We also make the functions with several outputs return a namedtuple,
similar to what we do in PyTorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79941
Approved by: https://github.com/mruberry, https://github.com/ezyang
2022-06-30 02:47:16 +00:00
lezcano
2d100eaa40 Correct meta behaviour of prims.resize (#80516)
The previous behaviour did not modify the tensor in-place when it should
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80516
Approved by: https://github.com/ezyang
2022-06-30 02:47:16 +00:00
Aidyn-A
74fb6ee4c5 [primTorch] support one tensor and two scalars in _prims.where (#80146)
Fixes an issue with supporting two scalar arguments for `where` and other functions with a similar set of arguments:

```
refs.where(a, 1, 0)
```

I had to skip `test_python_ref_executor` because the test causes a `Segmentation fault` when running with two scalars.
The issue https://github.com/csarofeen/pytorch/issues/1770 has been fixed by https://github.com/csarofeen/pytorch/pull/1774, so we can lift the skip once it's merged into upstream.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80146
Approved by: https://github.com/ngimel
2022-06-29 19:58:31 +00:00
Ryan Spring
1d0d506e97 Add Div reference (#77936)
Add Prims:
-  trunc
-  Replace _wrap_scalar with scalar_tensor

Add Reference:
-  copysign
- div
- floor_divide
- trunc_divide

Other:
* Add support for `variant_test_name` in _find_referenced_opinfo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77936
Approved by: https://github.com/mruberry
2022-06-27 14:46:17 +00:00
Ivan Yashchuk
0d5bc54114 Fix interpretation torch -> torch._refs in case of nested torch calls under TorchRefsMode (#80135)
torch calls inside `TorchRefsMode.__torch_function__` dispatch should be interpreted as refs calls under `TorchRefsMode`. Fixes https://github.com/pytorch/pytorch/issues/80079.

In addition, this PR enables two more tests for the nvFuser executor.

For example here's the FX trace of `torch._refs.nn.functional.layer_norm` before the proposed change (note the mix of `aten` and `prims`):
```py
opcode         name                    target                      args                              kwargs
-------------  ----------------------  --------------------------  --------------------------------  -----------------
placeholder    a_1                     a_1                         ()                                {}
call_function  convert_element_type    prims.convert_element_type  (a_1, torch.float32)              {}
call_function  var                     prims.var                   (convert_element_type, [0, 1])    {'correction': 0}
call_function  broadcast_in_dim        prims.broadcast_in_dim      (var, [1, 1], [])                 {}
call_function  convert_element_type_1  prims.convert_element_type  (a_1, torch.float32)              {}
call_function  sum_1                   prims.sum                   (convert_element_type_1, [0, 1])  {}
call_function  broadcast_in_dim_1      prims.broadcast_in_dim      (sum_1, [1, 1], [])               {}
call_function  div                     prims.div                   (broadcast_in_dim_1, 9.0)         {}
call_function  add                     aten.add                    (broadcast_in_dim, 1e-05)         {}
call_function  rsqrt                   aten.rsqrt                  (add,)                            {}
call_function  sub                     aten.sub                    (a_1, div)                        {}
call_function  mul                     aten.mul                    (sub, rsqrt)                      {}
call_function  convert_element_type_2  prims.convert_element_type  (mul, torch.float32)              {}
output         output                  output                      (convert_element_type_2,)         {}
```
And with this PR:
```py
opcode         name                    target                      args                              kwargs
-------------  ----------------------  --------------------------  --------------------------------  -----------------
placeholder    a_1                     a_1                         ()                                {}
call_function  convert_element_type    prims.convert_element_type  (a_1, torch.float32)              {}
call_function  var                     prims.var                   (convert_element_type, [0, 1])    {'correction': 0}
call_function  broadcast_in_dim        prims.broadcast_in_dim      (var, [1, 1], [])                 {}
call_function  convert_element_type_1  prims.convert_element_type  (a_1, torch.float32)              {}
call_function  sum_1                   prims.sum                   (convert_element_type_1, [0, 1])  {}
call_function  broadcast_in_dim_1      prims.broadcast_in_dim      (sum_1, [1, 1], [])               {}
call_function  div                     prims.div                   (broadcast_in_dim_1, 9.0)         {}
call_function  add                     prims.add                   (broadcast_in_dim, 1e-05)         {}
call_function  rsqrt                   prims.rsqrt                 (add,)                            {}
call_function  broadcast_in_dim_2      prims.broadcast_in_dim      (div, [3, 3], [0, 1])             {}
call_function  sub                     prims.sub                   (a_1, broadcast_in_dim_2)         {}
call_function  broadcast_in_dim_3      prims.broadcast_in_dim      (rsqrt, [3, 3], [0, 1])           {}
call_function  mul                     prims.mul                   (sub, broadcast_in_dim_3)         {}
call_function  convert_element_type_2  prims.convert_element_type  (mul, torch.float32)              {}
output         output                  output                      (convert_element_type_2,)         {}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80135
Approved by: https://github.com/ngimel
2022-06-25 03:55:04 +00:00
Khushi Agrawal
b00448df6b [primTorch] asinh, atanh (#80210)
Adds asinh and atanh refs and prims. These are prims because both are C++ standard library calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80210
Approved by: https://github.com/mruberry
2022-06-24 16:43:35 +00:00
Ivan Yashchuk
072311bb28 Enable torch._prims.amax/amin for nvFuser executor (#80070)
This PR adds nvFuser implementations for the `torch._prims.amax` and `torch._prims.amin` reduction functions. Currently, nvFuser refuses to reduce 0d tensors, so these inputs are skipped in tests for now.

An accompanying fix replaces `collections.Sequence` -> `collections.abc.Sequence` in refs because `collections.Sequence` is deprecated and removed in Python 3.10

Many ops that were skipped for the nvFuser executor test are now enabled.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80070
Approved by: https://github.com/ngimel
2022-06-23 10:19:57 +00:00
Natalia Gimelshein
9244547a1b small cleanup of executor (#79973)
per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79973
Approved by: https://github.com/mruberry
2022-06-22 00:35:51 +00:00
Mike Ruberry
ca845462a0 Fixes maybe_broadcast to actually broadcast only when needed (#79298)
Adds a `same_shape` util and updates maybe_broadcast to use it; previously maybe_broadcast was always broadcasting because its equality check was always failing.
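
A minimal sketch of what such a helper can look like (an assumed shape of the util, not the actual code):
```py
from typing import Sequence

def same_shape(a: Sequence[int], b: Sequence[int]) -> bool:
    # Element-wise comparison of two shapes, replacing the always-failing
    # equality check described above.
    return len(a) == len(b) and all(x == y for x, y in zip(a, b))

assert same_shape((3, 1, 2), (3, 1, 2))
assert not same_shape((3, 1, 2), (3, 2))
```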
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79298
Approved by: https://github.com/ezyang
2022-06-21 19:26:40 +00:00
Natalia Gimelshein
c0ce4b0de9 make refs executor handle kwargs (#79858)
Mostly fixes #78923
I had to disable function patching in fx for functions with kwonly args, see https://github.com/pytorch/pytorch/compare/ngimel/make_fx_fix?expand=1#diff-090b22122be0779cd14afd2ebaf20d1e7c0bfe837e9eefa1d84e7521bb1defc6R446, cc @jamesr66a
But it looks like it was doing weird things anyway: it was patching the signature of the wrapped function with arbitrary local vars from the wrapper. That can't be right, but I don't know what the intent there was.
A lot of functions now fail with the nvfuser executor, and some still fail with aten, although with different errors than before.
Edit: undid the change to _symbolic_script.py; it turns out inspect-unwrapping the function is not needed, and fx never sees kwargs.
cc @IvanYashchuk, @Chillee

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79858
Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
2022-06-21 18:53:15 +00:00
Horace He
e89676f76c fix logical_not reland issues
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79900

Approved by: https://github.com/ngimel
2022-06-21 03:41:18 +00:00
Nikita Shulga
f5eb05f107 Revert "Reland #2 of "Added {logical_not, trace} refs, moved logical ops to use method overloads""
This reverts commit f3665dd237.

Reverted https://github.com/pytorch/pytorch/pull/79819 on behalf of https://github.com/malfet due to land raced with softshrink refs
2022-06-20 14:22:15 -07:00
Horace He
f3665dd237 Reland #2 of "Added {logical_not, trace} refs, moved logical ops to use method overloads"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79819

Approved by: https://github.com/mruberry
2022-06-20 19:50:43 +00:00
Syed Tousif Ahmed
802efc90c6 [primTorch] Implements refs for gcd, lcm and remainder (#78747)
This PR implements the references for gcd, lcm and remainder. Additionally, `gcd` is added as a prim, since we currently don't have a while loop construct.
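
A rough sketch of why `gcd` is the natural primitive here: `lcm` can be written in terms of it (zero inputs are not handled in this sketch, and this is not the actual ref code):
```py
import torch

def lcm_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # lcm(a, b) = |a / gcd(a, b) * b|
    g = torch.gcd(a, b)
    return (a // g * b).abs()

a = torch.tensor([4, 6, 21])
b = torch.tensor([6, 8, 6])
print(lcm_sketch(a, b))  # tensor([12, 24, 42])
```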
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78747
Approved by: https://github.com/mruberry
2022-06-20 01:51:32 +00:00
Ivan Yashchuk
e10b762537 Enable torch._refs.var for nvFuser executor (#79517)
This PR adds a variance function with a correction argument to nvFuser.

Now it's possible to run
```py
import torch
import torch._refs
from torch._prims.executor import make_traced

def foo1(a):
    return torch._refs.var(a, keepdim=False, unbiased=False)

def foo2(a):
    return torch._refs.var(a, keepdim=False, correction=2)

a = torch.randn(3, 3, device='cuda')
make_traced(foo1)(a, executor="nvfuser")
make_traced(foo2)(a, executor="nvfuser")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79517
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
2022-06-14 23:08:53 +00:00
Ivan Yashchuk
8895862744 Enable torch._refs.mean for nvFuser executor (#79444)
This PR fixes a bug with `broadcast_in_dim` that prevented reduction ops from being used before `broadcast_in_dim`.

With this PR it's possible to run
```py
import torch
import torch._refs
from torch._prims.executor import make_traced

def foo(a):
    return torch._refs.mean(a, keepdim=False)

a = torch.randn(3, 3, device='cuda')
make_traced(foo)(a, executor="nvfuser")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79444
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
2022-06-14 19:42:07 +00:00
PyTorch MergeBot
66460c4a6a Revert "Fixes maybe_broadcast to actually broadcast only when needed (#79298)"
This reverts commit 1cb1c2c08c.

Reverted https://github.com/pytorch/pytorch/pull/79298 on behalf of https://github.com/suo due to Broke FakeTensor tests on master, see: 1cb1c2c08c
2022-06-11 23:36:18 +00:00
Mike Ruberry
1cb1c2c08c Fixes maybe_broadcast to actually broadcast only when needed (#79298)
Adds a `same_shape` util and updates maybe_broadcast to use it; previously maybe_broadcast was always broadcasting because its equality check was always failing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79298
Approved by: https://github.com/ezyang
2022-06-11 22:04:47 +00:00
PyTorch MergeBot
fefff54cad Revert "Revert "Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads"""
This reverts commit a2d2981e8e.

Reverted https://github.com/pytorch/pytorch/pull/79224 on behalf of https://github.com/suo due to broke lots of things a2d2981e8e
2022-06-10 04:40:43 +00:00
Horace He
a2d2981e8e Revert "Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads""
This reverts commit d67309aefb.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79224

Approved by: https://github.com/mruberry
2022-06-10 03:07:14 +00:00
PyTorch MergeBot
d67309aefb Revert "Added {logical_not, trace} refs, moved logical ops to use method overloads"
This reverts commit 64b6bd8c1e.

Reverted https://github.com/pytorch/pytorch/pull/79000 on behalf of https://github.com/malfet due to Introduces test failure, see https://hud.pytorch.org/pr/79000
2022-06-09 13:11:23 +00:00
Horace He
64b6bd8c1e Added {logical_not, trace} refs, moved logical ops to use method overloads
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79000

Approved by: https://github.com/ezyang
2022-06-09 07:16:36 +00:00
Elias Ellison
3c5a3ca9e8 Make FakeTensors return meta within kernel invocation, add FakeTensor op tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78972

Approved by: https://github.com/ezyang
2022-06-09 01:39:27 +00:00
Elias Ellison
290d0979f1 Migrate FakeTensors to always call into FakeTensorMode and have them hold a reference
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78677

Approved by: https://github.com/ezyang
2022-06-08 22:30:34 +00:00
Horace He
e675dbadc4 Ported gelu decomp to ref (#78697)
Ugh... these are actually so painful to write without operator overloading lol.

Decided to just utilize operator overloading, and xfail the ref tests for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78697
Approved by: https://github.com/mruberry
2022-06-06 22:30:20 +00:00
Edward Z. Yang
80f2c175be Follow up on CR for "Replace TensorMeta with FakeTensor"
See https://github.com/pytorch/pytorch/pull/78836

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78895

Approved by: https://github.com/albanD
2022-06-06 22:20:40 +00:00
Kshiteej K
c461d8a977 [primTorch] refs: hsplit, vsplit (#78418)
As per title

TODO:
* [x] Add error inputs (already exist)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78418
Approved by: https://github.com/mruberry
2022-06-06 19:54:05 +00:00
Edward Z. Yang
99882fc492 Make check() strongly typed, fix erroneous call sites
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78896

Approved by: https://github.com/Lezcano, https://github.com/anjali411
2022-06-05 23:10:55 +00:00
Edward Z. Yang
587efdb5fa Replace TensorMeta with FakeTensor
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78836

Approved by: https://github.com/albanD, https://github.com/mruberry
2022-06-05 11:51:27 +00:00
Ivan Yashchuk
df748b60f7 Allow pytrees as output for make_traced and nvfuser executor (#78802)
This PR lifts the restriction that the output of a function traced with `make_traced` and executed with nvFuser must be a single tensor. Now it's possible to return a "pytree", a nested data structure of tensors (see https://github.com/pytorch/pytorch/blob/master/torch/utils/_pytree.py).

I added a test with a function that returns a tuple of two objects where one of the objects is a dictionary with a tensor value.

```py
def fn(a, b):
    d = {}
    d["c"] = torch.add(a, b)
    return (d, torch.add(a, d["c"]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78802
Approved by: https://github.com/mruberry
2022-06-04 08:41:18 +00:00
Edward Z. Yang
83d40a4dba linalg_cholesky_ex meta function
Taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78604

Approved by: https://github.com/bdhirsh, https://github.com/ngimel, https://github.com/Lezcano
2022-06-03 23:11:02 +00:00
Edward Z. Yang
6120a8e05d Implement meta function for aten::index.Tensor
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78527

Approved by: https://github.com/bdhirsh, https://github.com/ngimel, https://github.com/Lezcano
2022-06-03 23:11:02 +00:00
Gao, Xiang
eb88ea01b5 Cleanup impl_nvfuser for unary ops (#78670)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78670
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
2022-06-02 16:17:47 +00:00
Ivan Yashchuk
0be4672a9d [primTorch] Use the same error message as in ATen for canonicalize_dim (#78541)
Fixes https://github.com/pytorch/pytorch/issues/78252.

Locally nothing seems to break when changing the error type and the error message, which means there were no tests covering this.
At least one xfailed test from https://github.com/pytorch/pytorch/pull/78080 wouldn't pass with this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78541
Approved by: https://github.com/ngimel, https://github.com/mruberry
2022-06-02 12:10:41 +00:00
jjsjann123
fea909b43e [primTorch] Adds broadcast_shapes reference (#78612)
1. Added references `_refs.broadcast_shapes`
2. Added OpInfo test for `torch.broadcast_shapes`

A few minor changes:
- `test_python_ref_meta` and `_ref_test_helper` update to avoid non-tensor outputs
- type annotation update for `_resize_meta`
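
For reference, the right-aligned broadcasting rule that `broadcast_shapes` implements:
```py
import torch

# Shapes are right-aligned; each dimension must either match or be 1.
print(torch.broadcast_shapes((3, 1, 5), (1, 4, 5), (5,)))  # torch.Size([3, 4, 5])
```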
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78612
Approved by: https://github.com/mruberry
2022-06-02 08:56:37 +00:00
Xiang Gao
b651148fc3 remove prims::square (#78627)
because it is just `x * x`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78627
Approved by: https://github.com/mruberry
2022-06-02 02:18:17 +00:00
PyTorch MergeBot
6dafefe3d4 Revert "[primTorch] Use the same error message as in ATen for canonicalize_dim (#78541)"
This reverts commit c054993b53.

Reverted https://github.com/pytorch/pytorch/pull/78541 on behalf of https://github.com/malfet due to as it depends on https://github.com/pytorch/pytorch/pull/78080 that caused XLA failures and is getting reverted
2022-06-01 16:48:00 +00:00
Ivan Yashchuk
c054993b53 [primTorch] Use the same error message as in ATen for canonicalize_dim (#78541)
Fixes https://github.com/pytorch/pytorch/issues/78252.

Locally nothing seems to break when changing the error type and the error message, which means there were no tests covering this.
At least one xfailed test from https://github.com/pytorch/pytorch/pull/78080 wouldn't pass with this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78541
Approved by: https://github.com/ngimel, https://github.com/mruberry
2022-06-01 16:09:35 +00:00
Ivan Yashchuk
20c32503eb Add more impl_nvfuser for prims (#78493)
This PR adds `test_nvfuser_impl_is_used` that checks that the corresponding nvfuser op (if available) is used in the prim definition.

Adds `impl_nvfuser=` for atan2, bitwise_and, bitwise_or, bitwise_xor, eq, ne, pow, sub, sum, where, rsqrt, lgamma.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78493
Approved by: https://github.com/mruberry
2022-06-01 14:10:17 +00:00
Edward Z. Yang
b7215de32f prod ref
It turns out the prim was implemented incorrectly, as torch.prod does not accept
a dim list, so I added a little stub for this.
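
The asymmetry being worked around, concretely:
```py
import torch

a = torch.randn(2, 3, 4)
torch.sum(a, dim=[0, 2])   # sum accepts a list of dims
torch.prod(a, dim=1)       # prod only accepts a single int dim
# torch.prod(a, dim=[0, 2])  # raises a TypeError
```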

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78461

Approved by: https://github.com/ngimel
2022-05-31 14:18:49 +00:00
Ryan Spring
2df1da09e1 Add Elementwise unary ops 4 references (#78216)
Add reference implementations for `nan_to_num, positive, sigmoid, signbit, tanhshrink`
Add prims for `minimum_value(dtype)` and `maximum_value(dtype)`
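
A rough sketch (assumed semantics, complex dtypes not handled) of what dtype-extremum helpers like these compute:
```py
import torch

def maximum_value(dtype: torch.dtype):
    # Largest representable value for a floating-point or integer dtype.
    return torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max

def minimum_value(dtype: torch.dtype):
    # Smallest (most negative) representable value for the dtype.
    return torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min

print(maximum_value(torch.float16))  # 65504.0
print(minimum_value(torch.int8))     # -128
```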
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78216
Approved by: https://github.com/mruberry
2022-05-27 21:55:34 +00:00
jjsjann123
1a9a1b8b5e fixing typo (#78417)
primtorch prod is mistakenly using `_sum_doc`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78417
Approved by: https://github.com/malfet
2022-05-27 17:10:15 +00:00
Aidyn-A
31016eb81e [primTorch] Elementwise Binary Ops I (#78023)
This PR is a result of collaboration with @rdspring1 and @mruberry on primTorch.

It adds the following prims:
- `fmax`
- `fmin`
- `fmod`

And adds the following refs:
- `fmax`
- `fmin`
- `fmod`
- `logical_xor`

The work is in progress as there are some tests that fail.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78023
Approved by: https://github.com/mruberry
2022-05-26 20:22:27 +00:00