This PR adds an nvFuser-specific primitive, `var_mean`.
Rewriting `torch.var_mean` -> `torch.ops.nvprims.var_mean` is handled by the `TorchRefsNvfuserCapabilityMode` context manager.
I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean")`).
The layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple performance comparison of this PR against master (on a 3080 Ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser")
```
Run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`:
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about a 35% performance improvement with the nvFuser executor for this specific normalized shape (roughly 480 GB/s -> 650 GB/s of achieved bandwidth).
This PR also fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).
Ref. https://github.com/pytorch/pytorch/issues/80187
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
Conditionally decomposes `aten::_to_copy` to `nvprim::convert_element_type` to allow fusion with type casts, which are introduced during the type promotion phase of torch decompositions.
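A rough sketch of the scenario this targets (hedged: the exact graph contents depend on the tracing mode and decompositions in use):

```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx

def func(a):
    # Mixed float16/float32 arithmetic triggers type promotion, which
    # inserts a cast (aten::_to_copy) during decomposition; with this PR
    # the cast can be decomposed to a convert_element_type nvprim so that
    # it fuses with the surrounding computation.
    return a + a.float()

a = torch.randn(4, dtype=torch.float16, device="cuda")
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)
print(gm.graph)
```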
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83782
Approved by: https://github.com/ngimel
The new `torch.ops.nvprims` namespace is meant for the nvFuser-specific set of primitives. All `impl_nvfuser` attributes are removed from `torch.ops.prims` functions.
The `NvfuserPrimsMode()` context manager can be used to automatically rewrite `torch.ops.prims` calls to `torch.ops.nvprims` where possible.
Previously, whether a prim was executable by nvFuser was tested with `impl_nvfuser is not None`. Now all functions in the `torch.ops.nvprims` namespace are expected to have the `impl_nvfuser` attribute, and hence all of them are executable by nvFuser.
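A minimal sketch of the rewrite (assuming `NvfuserPrimsMode` is importable from `torch._prims.context` alongside the other modes):

```py
import torch
from torch._prims.context import NvfuserPrimsMode
from torch.fx.experimental.proxy_tensor import make_fx

def func(a):
    return torch.ops.prims.sin(a)

a = torch.randn(3, device="cuda")
# Inside the mode, prims calls are rewritten to their nvprims
# counterparts whenever one exists.
with NvfuserPrimsMode():
    gm = make_fx(func)(a)
print(gm.graph)  # expected to show nvprims.sin instead of prims.sin
```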
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82155
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
### Description
In PyTorch it's possible to pass an unpacked shape to `Tensor.view`/`Tensor.reshape`, and when such a call was translated to use refs it caused an error.
Now `refs.reshape` and `refs.view` support passing the shape as variable arguments.
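A minimal sketch of the calling convention this enables:

```py
import torch
import torch._refs

a = torch.arange(6)
# The shape can now be passed unpacked (as variable arguments),
# matching the Tensor.view/Tensor.reshape calling convention:
assert torch._refs.reshape(a, 2, 3).shape == a.reshape(2, 3).shape
assert torch._refs.view(a, 2, 3).shape == a.view(2, 3).shape
```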
### Testing
Added a simple test, `test_reshape_view_method`, that compares the results of `Tensor.reshape`/`Tensor.view` with `torch._refs.reshape`/`torch._refs.view`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82651
Approved by: https://github.com/ngimel
Adds a new context manager, `TorchRefsNvfuserCapabilityMode`, for conditional rewriting of `torch.*` calls to `torch._refs.*` based on whether the resulting decomposition into prims is executable by nvFuser.
A new optional argument is added to `TorchRefsMode`: `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.
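A minimal usage sketch (the capability mode is `TorchRefsMode` with a capability-checking fallback predicate; the predicate's exact signature is an internal detail, so only the context-manager usage is shown):

```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx

def func(a):
    return torch.sigmoid(a)

a = torch.randn(3, 3, device="cuda")
# torch.sigmoid is rewritten to its ref only if the prim decomposition
# is executable by nvFuser; otherwise the original call is kept.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)
print(gm.graph)
```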
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81764
Approved by: https://github.com/ezyang
This PR does not include an nvFuser frontend cache, but it decouples the frontend from direct Fusion IR exposure: the requested definition is recorded for replay, so that (once there is a cache) the Fusion IR can be built as needed when it doesn't already exist. Another PR will be put up to add the actual caching.
The main change in the Python frontend is that the nvFuser Fusion IR is no longer defined directly by the interface. Currently there is a direct connection between the Python API and the creation of the Fusion IR and Fusion object: the user defines TensorViews and Scalars and calls arith functions (IR Expressions) on those IR Values. The goal is to disconnect the Python API from directly specifying the Fusion IR and to enable caching of the IR, so a Fusion object is not necessarily built every time a Fusion definition is seen.
The `FusionDefinition` in Python will mostly look the same, except the definition is now captured in a lightweight representation: a "recording" of Records. If the definition is not already cached, the Records are executed to build the Fusion IR. Initially there is no caching, because I am trying to bring up the representation first and get it working correctly.
Here is what the Records look like; they are functors that are called when it is necessary to build the Fusion IR:
`torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h`
**Tensor Definition Record**
_Note: The tensor definition will change for runtime contiguity caching; I am just matching what is already there for now._
```
struct InputTensorRecord : RecordFunctor {
  InputTensorRecord(
      std::vector<size_t> _outputs,
      std::vector<int64_t> _symbolic_sizes,
      std::vector<bool> _contiguous_info,
      NvfDataType _dtype)
      : RecordFunctor({}, std::move(_outputs)),
        symbolic_sizes(std::move(_symbolic_sizes)),
        contiguous_info(std::move(_contiguous_info)),
        dtype(_dtype) {}

  // Builds the input TensorView and registers it as a Fusion input.
  void operator()(FusionDefinition& fd) final {
    auto tv = TensorViewBuilder()
                  .ndims(symbolic_sizes.size())
                  .contiguity(contiguous_info)
                  .shape(symbolic_sizes)
                  .dtype(dtype)
                  .build();

    fd.fusion_state.at(outputs.at(0)) = tv;
    fd.addInput(tv);
  }

  std::vector<int64_t> symbolic_sizes;
  std::vector<bool> contiguous_info;
  NvfDataType dtype;
};
```
**Generic Templatized Op Record Definition**
Op Records are notable because they record Fusion IR arith functions as the `fusion_op_`.
```
template <class OutType, class... ArgTypes>
struct OpRecord : RecordFunctor {
  OpRecord(
      std::vector<size_t> _args,
      std::vector<size_t> _outputs,
      std::function<OutType(ArgTypes...)> fusion_op)
      : RecordFunctor(std::move(_args), std::move(_outputs)),
        fusion_op_(fusion_op) {}

  // Unpacks the recorded argument indices, looks up the corresponding IR
  // Values in the FusionDefinition state, and applies the arith function.
  template <class TupleType, std::size_t... Is>
  OutType opFunc(
      FusionDefinition& fd,
      TupleType& tp,
      std::index_sequence<Is...>) {
    return fusion_op_(
        dynamic_cast<typename std::tuple_element<Is, TupleType>::type>(
            fd.fusion_state.at(args.at(Is)))...);
  }

  void operator()(FusionDefinition& fd) final {
    using arg_tuple_t = std::tuple<ArgTypes...>;
    auto indices =
        std::make_index_sequence<std::tuple_size<arg_tuple_t>::value>();
    // The tuple only carries the argument types for opFunc's deduction.
    arg_tuple_t inputs;
    auto output = opFunc(fd, inputs, indices);
    fd.fusion_state.at(outputs.at(0)) = output;
  }

 private:
  std::function<OutType(ArgTypes...)> fusion_op_;
};
```
Perhaps the most confusing aspect of the Python frontend is the `FusionDefinition`. The C++ class that is bound to is purposely very lightweight. In an attempt to make sure users don't have to touch more than one file when adding new ops (assuming an appropriate Record has already been defined), the Python bindings effectively create functions that act on the `FusionDefinition` and appear as part of the class in Python, but are not part of the class in C++.
Here is an example of a unary op macro. It creates the binding to a lambda function that effectively appears as a `FusionDefinition` operation in Python. The other way to do this would have been to create a class method directly in the `FusionDefinition` C++ class and bind to that method separately.
```
#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name)                \
  nvf_ops.def(                                                          \
      op_str,                                                           \
      [](nvfuser::FusionDefinition::Operators& self,                    \
         nvfuser::Tensor* input) -> nvfuser::Tensor* {                  \
        nvfuser::Tensor* output = new nvfuser::Tensor(                  \
            self.fusion_definition->recording_state.size());            \
        self.fusion_definition->recording_state.emplace_back(output);   \
        self.fusion_definition->recording.emplace_back(                 \
            new nvfuser::OpRecord<NvfTensorView*, NvfTensorView*>(      \
                {input->index},                                         \
                {output->index},                                        \
                static_cast<NvfTensorView* (*)(NvfTensorView*)>(        \
                    torch::jit::fuser::cuda::op_name)));                \
        return output;                                                  \
      },                                                                \
      py::return_value_policy::reference);
```
Here is the `FusionDefinition` class, edited for brevity. The replay of the Records happens in the `exit()` method, where "exit" refers to exiting the Python context manager. A `FusionDefinition` is captured through a context manager like the following:
```py
fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(sizes=[5], strides=[1])
    t1 = fd.ops.abs(t0)
    fd.add_output(t1)
```
```
class FusionDefinition {
 public:
  FusionDefinition(FusionOwner* fusion_owner)
      : fusion_owner_(fusion_owner),
        prev_fusion_(nullptr),
        recording(),
        recording_state(),
        fusion_state(),
        ops(this) {}

  // Context manager methods
  FusionDefinition* enter() {
    prev_fusion_ = FusionGuard::getCurFusion();
    FusionGuard::setCurFusion(fusionPtr());
    return this;
  }
  void exit() {
    // Found in the Python bindings, currently:
    // for (auto& record : recording) {
    //   auto functor = record.get();
    //   (*functor)(self);
    // }
    FusionGuard::setCurFusion(prev_fusion_);
    prev_fusion_ = nullptr;
  }

  void addInput(torch::jit::fuser::cuda::Val* input) {
    fusionPtr()->addInput(input);
  }
  void addOutput(torch::jit::fuser::cuda::Val* output) {
    fusionPtr()->addOutput(output);
  }

  Fusion* fusionPtr() {
    return fusion_owner_->fusionPtr();
  }

 private:
  FusionOwner* fusion_owner_;
  Fusion* prev_fusion_;

 public:
  std::vector<std::unique_ptr<RecordFunctor>> recording;
  std::vector<std::unique_ptr<State>> recording_state;
  std::vector<NvfVal*> fusion_state;

  struct Operators {
    Operators(FusionDefinition* fd) : fusion_definition(fd) {}
    // Python operations are effectively bound here.
    FusionDefinition* fusion_definition;
  };
  Operators ops;
};
```
The Fusion IR doesn't have `define_tensor` or `define_scalar` functions; I made up those names for the Python `FusionDefinition` as a more understandable/convenient way to define input tensors and scalars. `TensorView` objects and Fusion IR `Val` objects are not typically defined outside of a Fusion IR `Expr` output (typically arith function outputs), except for inputs to a graph. Mechanically speaking, there are two things you need to do to define an input in the Fusion IR: define the IR `TensorView`/`Val` object, and then record in the `Fusion` object that encapsulates the Fusion IR that this `TensorView`/`Val` is an input. Since the `FusionDefinition` does not correspond one-to-one with the Fusion IR, and `define_tensor` and `define_scalar` are made-up functions, I decided to combine the `Val` object creation and the recording of the input in the `Fusion` object into one step, to reduce the amount of syntax required to define a fusion in the Python interface.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81578
Approved by: https://github.com/jjsjann123, https://github.com/IvanYashchuk, https://github.com/SherlockNoMad
Currently we have two ways of doing the same thing for torch dispatch and function modes:
`with push_torch_dispatch_mode(X)` or `with X.push(...)`
is now the equivalent of doing
`with X()`
This removes the first API (which is older and private, so we don't need to go through a deprecation cycle).
There is some risk that this might land-race with a PR that uses the old API, but in general it seems like most code uses the `with X()` API or `enable_torch_dispatch_mode(X())`, which isn't being removed.
EDIT: kept the `with X.push(...)` API since there were ~3 land races with it over the past day or so, but made it give a warning that asks users to use the other API.
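A minimal sketch of the new-style usage (the `LoggingMode` here is a hypothetical mode for illustration):

```py
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # log every aten op that passes through the mode
        return func(*args, **(kwargs or {}))

# Old (removed): with push_torch_dispatch_mode(LoggingMode): ...
# New: instantiate the mode directly.
with LoggingMode():
    torch.ones(2) + 1
```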
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78215
Approved by: https://github.com/ezyang
This PR introduces a new nvFuser executor for FX graphs containing different kinds of nodes, not just the `torch.ops.prims` supported by nvFuser. The FX graph is partitioned based on whether nodes are supported by nvFuser, and supported nodes are fused into subgraphs; all of this uses Sherlock's work on the partitioner.
This new partition-based executor with fallback to ATen is used by default with `executor="nvfuser"`, and the previous executor can be used with `executor="strictly_nvfuser"`. Naming suggestions are welcome!
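A minimal sketch of the two execution modes (following the `execute` usage shown in the other PRs):

```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.sigmoid(a) + 1

a = torch.randn(3, 3, device="cuda")
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

# Default: partitioned execution, unsupported nodes fall back to ATen.
out = execute(gm, a, executor="nvfuser")
# Strict: the whole graph must be executable by nvFuser.
out = execute(gm, a, executor="strictly_nvfuser")
```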
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81043
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
This PR makes the following changes; the issues below were also filed while creating it.
**Filed issues**
- https://github.com/pytorch/pytorch/issues/79818
- https://github.com/pytorch/pytorch/issues/80154
**prims**
- Fixes `prims.squeeze` when called with an unsorted list of dimensions
- Removes the `clone` prim
**refs** (see the usage sketch after this list)
- adds `contiguous`
- adds `expand`
- updates `clone` to call `empty_like` and `copy_to`
- updates `empty` to accept a memory format
- updates `empty_like` to accept a `memory_format`
**utils**
- adds helper functions for working with memory formats and, in particular, channels-last tensors
**tests**
- removes unused clamp sample input functions (mooted by clamp's new reference inputs)
- extends the reference inputs for `clone` to include different memory formats
- creates reference inputs for `contiguous`
- xfails operators that depend on `clone` (including `clone` itself) on `test_python_ref` (see issues)
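A minimal usage sketch of the new refs (shapes and assertions are illustrative):

```py
import torch
import torch._refs

a = torch.randn(2, 3).t()  # non-contiguous
c = torch._refs.contiguous(a)                    # like Tensor.contiguous
e = torch._refs.expand(torch.randn(1, 3), 2, 3)  # like Tensor.expand
assert c.is_contiguous() and e.shape == (2, 3)
```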
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79820
Approved by: https://github.com/ngimel
In the current setup, a `Fusion` object was constructed from the `GraphModule` and args on every call of the `execute` function; that's expensive.
This PR uses `functools.lru_cache` to pay the `Fusion` creation cost once per `GraphModule` and set of args. Currently the shape, strides, and dtype of tensors are treated as static; this can be changed later to make better use of nvFuser's internal caching mechanism (by specifying only ndim, contiguity, and dtype).
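A self-contained sketch of the caching pattern (hypothetical helper names, not the actual implementation):

```py
from functools import lru_cache

def expensive_build(key):
    # Stand-in for Fusion construction from a GraphModule and arg specs.
    print("building fusion for", key)
    return object()

# The cache key is the GraphModule plus hashable per-tensor specs
# (shape, strides, dtype), so the build cost is paid once per combination.
@lru_cache(maxsize=None)
def get_fusion(gm, arg_specs):
    return expensive_build((id(gm), arg_specs))

def tensor_spec(t):
    return (tuple(t.shape), tuple(t.stride()), str(t.dtype))
```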
On master:
```py
In [2]: a = torch.randn(3, 3, device='cuda')
In [3]: with TorchRefsMode.push():
...: gm = make_fx(lambda x: torch.sigmoid(x))(a)
...:
In [4]: %%timeit
...: execute(gm, a, executor="nvfuser")
...: torch.cuda.synchronize()
175 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
This PR:
```py
In [2]: a = torch.randn(3, 3, device='cuda')
In [3]: with TorchRefsMode.push():
...: gm = make_fx(lambda x: torch.sigmoid(x))(a)
...:
In [4]: %%timeit
...: execute(gm, a, executor="nvfuser")
...: torch.cuda.synchronize()
62.6 µs ± 9.99 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
In addition, this PR adds support for pytree inputs and extends the test for this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80525
Approved by: https://github.com/kevinstephano, https://github.com/jjsjann123, https://github.com/SherlockNoMad
This PR fixes a bug with `broadcast_in_dim` where reduction ops were not allowed to be used before `broadcast_in_dim`.
With this PR it's possible to run:
```py
import torch
import torch._refs
from torch._prims.executor import make_traced
def foo(a):
    return torch._refs.mean(a, keepdim=False)

a = torch.randn(3, 3, device='cuda')
make_traced(foo)(a, executor="nvfuser")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79444
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
This PR lifts the restriction that the output of a function traced with `make_traced` and executed with nvFuser must be a single tensor. Now it's possible to return a "pytree", a nested data structure of tensors (see https://github.com/pytorch/pytorch/blob/master/torch/utils/_pytree.py).
I added a test with a function that returns a tuple of two objects, where one of the objects is a dictionary with a tensor value:
```py
def fn(a, b):
    d = {}
    d["c"] = torch.add(a, b)
    return (d, torch.add(a, d["c"]))
```
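Tracing and executing it would look like this (a sketch using `fn` from above and the `make_traced` usage shown in the other PRs):

```py
import torch
from torch._prims.executor import make_traced

a = torch.randn(3, 3, device="cuda")
b = torch.randn(3, 3, device="cuda")
result = make_traced(fn)(a, b, executor="nvfuser")
# result is (dict, tensor) -- the same pytree structure fn returns
```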
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78802
Approved by: https://github.com/mruberry
This PR adds `test_nvfuser_impl_is_used`, which checks that the corresponding nvFuser op (if available) is used in the prim definition.
It adds `impl_nvfuser=` for `atan2`, `bitwise_and`, `bitwise_or`, `bitwise_xor`, `eq`, `ne`, `pow`, `sub`, `sum`, `where`, `rsqrt`, and `lgamma`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78493
Approved by: https://github.com/mruberry
This makes prims look as if they were defined in `native_functions.yaml`, but they're still all written in Python. You now need to give a full schema string for your prims. The returned prim object is now a `torch.ops.prims` overload (prims are not allowed to be overloaded, so we return the overload, not the overload packet, for speed).
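For example (the schema text is illustrative):

```py
import torch

# Prims are now registered with a full schema string, e.g. something like
# "add(Tensor self, Tensor other) -> Tensor".
# Since prims are never overloaded, the overload itself is exposed:
print(torch.ops.prims.add.default)
```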
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77117
Approved by: https://github.com/mruberry, https://github.com/albanD
This PR primarily addresses augmenting the frontend to properly support `broadcast_in_dim`. This required making a new version of `define_tensor()` that takes the `size` and `strides` of input tensors in order to properly determine broadcasts.
This PR also has a fix for `python_example.py`, which broke when a new argument was added to reductions to allow the user to specify an output data type.
`define_tensor()` interface example:
```py
fusion2 = Fusion()

input1 = torch.ones(1, 1, 4, device='cuda')
input2 = torch.ones(2, 3, 4, device='cuda')

with FusionDefinition(fusion2) as fd:
    t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
    t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.Ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
    print("Broadcast TensorView", t0_b)
    t2 = fd.Ops.add(t0_b, t1)

    fd.add_output(t2)
```
Output of the print statement for the defined broadcast tensor:
```
Broadcast TensorView T2_l[ sbS6{1}, sbS7{1}, iS8{i2} ] DataType: float Contiguity: ttt
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76790
Approved by: https://github.com/mruberry, https://github.com/jjsjann123
This adds prototype nvFuser integration for the following prims:
- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul
Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.
This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:
```py
def foo(a, b):
    return torch.add(a, b)

traced_foo = make_traced(foo)

a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```
Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.
Finally, this adds a new test file, `test_prims.py`, that just tests the `broadcast_in_dim` prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of `broadcast_in_dim` to make that interesting.
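For illustration, a hypothetical Python reference of `broadcast_in_dim` semantics might look like this (not the actual prim implementation):

```py
import torch

def broadcast_in_dim_ref(a, shape, broadcast_dimensions):
    # Place dim i of `a` at position broadcast_dimensions[i] in the output,
    # fill the remaining positions with size-1 dims, then expand.
    s = [1] * len(shape)
    for i, d in enumerate(broadcast_dimensions):
        s[d] = a.shape[i]
    return a.reshape(s).expand(shape)

a = torch.ones(2, 3)
out = broadcast_in_dim_ref(a, (4, 2, 3), (1, 2))
assert out.shape == (4, 2, 3)
```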
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel