pytorch/torch/_prims/context.py
Ivan Yashchuk 3aae6ff1e1 Add nvprims.var_mean (#83508)
This PR adds an nvFuser-specific primitive, `var_mean`.
The interpretation of `torch.var_mean` as `torch.ops.nvprims.var_mean` is handled by the `TorchRefsNvfuserCapabilityMode` context manager.

I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean"`).

The layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` so that the call can be intercepted. Here is a simple performance comparison between this PR and master (on an RTX 3080 Ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute

def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser");
```
Run it with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py` to print the effective bandwidth of each kernel launch:
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s

# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
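So, with the nvfuser executor and this specific normalized shape, this PR gives about a 35% improvement in effective bandwidth (~650 GB/s vs ~480 GB/s), i.e. roughly 25% lower kernel time (~0.032 ms vs ~0.043 ms).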

Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).

Ref. https://github.com/pytorch/pytorch/issues/80187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
2022-08-28 18:45:25 +00:00

247 lines
8.6 KiB
Python

import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Sequence, Union
import torch
import torch._decomp
import torch._prims
import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch._prims_common import torch_function_passthrough
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add

    The dict is built once and cached, so every caller receives the same
    object; treat it as read-only.
    """
    # (public torch module, matching torch._refs module) pairs. Every name in
    # a refs module's __all__ is mapped from the same-named torch attribute.
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    # Explicit Tensor-method entries (mostly operator dunders whose names
    # differ from the name of the ref they map to).
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.__rsub__: torch._refs.rsub,
        torch.Tensor.__rtruediv__: torch._refs.rtruediv,
        torch.Tensor.__floordiv__: torch._refs.floor_divide,
        torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
        torch.Tensor.__pow__: torch._refs.pow,
        torch.Tensor.__rpow__: torch._refs.rpow,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        torch.Tensor.to: torch._refs.to,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }
    for mod_torch, mod_refs in modules:
        for s in mod_refs.__all__:  # type: ignore[attr-defined]
            torch_fn = mod_torch.__dict__.get(s)
            # A ref without a same-named torch counterpart used to be stored
            # under a single meaningless ``None`` key (each one overwriting the
            # previous); skip those instead.
            if torch_fn is None:
                continue
            r[torch_fn] = mod_refs.__dict__.get(s)
    # Support remapping torch.Tensor.foo to _refs.foo.
    # Build the name set once: __all__ is a list, so testing it for every
    # dir(torch.Tensor) entry would be a linear scan each time.
    refs_names = set(torch._refs.__all__)
    for s in dir(torch.Tensor):
        if s in refs_names:
            r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)
    return r
@functools.lru_cache(None)
def nvfuser_decomp_table():
    """
    Decomposition table needed for nvfuser.

    Returns the result of ``torch._decomp.get_decompositions`` for the aten
    ops listed below. The result is cached, so every caller shares one dict.
    """
    aten = torch.ops.aten
    # A list rather than a set literal, so the value really is the declared
    # Sequence type and no ``type: ignore`` is needed.
    nvfuser_decompositions: Sequence[
        Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]
    ] = [
        # AMP calls `to` in C++, which is not handled by torch mapping
        aten._to_copy,
    ]
    from torch._decomp import get_decompositions

    decomp_table = get_decompositions(nvfuser_decompositions)
    return decomp_table
@functools.lru_cache(None)
def all_prims():
    """
    Return the set of every function exported by ``torch._prims``.

    Intended for membership tests such as ``torch._prims.add in all_prims()``.
    The set is built on first call and cached afterwards.
    """
    prims_namespace = vars(torch._prims)
    exported = torch._prims.__all__
    return set(prims_namespace.get(name) for name in exported)
class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.ops.prims.* functions to
    use nvFuser's prims in torch.ops.nvprims.*

    >>> # xdoctest: +SKIP("undefined vars")
    >>> with NvfuserPrimsMode():
    ...     torch.ops.prims.add(x, y)  # calls torch.ops.nvprims.add(x, y)

    By default, this context manager will fall back on the torch.ops.prims* if the
    nvprim does not exist.
    """

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        call_kwargs = {} if kwargs is None else kwargs
        is_op = isinstance(
            orig_func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
        )
        if is_op:
            # The first two dot-separated components of str(op) are treated
            # as (namespace, op name).
            pieces = str(orig_func).split(".")
            namespace, name = pieces[0], pieces[1]
            if namespace == "prims":
                # Same-named nvprim, if one is registered.
                replacement = getattr(torch.ops.nvprims, name, None)
                if replacement is not None:
                    return replacement(*args, **call_kwargs)
        # Not a prims op, or no nvprim counterpart: run the original.
        return orig_func(*args, **call_kwargs)
class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.)

    >>> # xdoctest: +SKIP
    >>> with TorchRefsMode():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager will fall back on the torch.* if the
    ref does not exist; set strict=True to error if this occurs.
    If the ref exists we still would like to fall back on the torch.* sometimes,
    this behavior can be customized by passing a function to should_fallback_fn.

    Args:
        strict: raise ``RuntimeError`` instead of calling the original
            function when neither a ref nor a decomposition is found.
        should_fallback_fn: called as ``should_fallback_fn(mode, ref, args,
            kwargs)``; a truthy result runs the original function instead of
            the ref.
        prims_mode_cls: zero-argument callable returning a context manager
            that wraps calls to passthrough functions and prims
            (``nullcontext`` by default).
    """

    def __init__(
        self,
        strict=False,
        should_fallback_fn=lambda *_: False,
        prims_mode_cls=nullcontext,
    ):
        self.strict = strict
        self.should_fallback_fn = should_fallback_fn
        self.prims_mode_cls = prims_mode_cls

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        if kwargs is None:
            kwargs = {}

        # Passthrough functions and prims are never replaced by refs; they
        # only run inside prims_mode_cls (e.g. NvfuserPrimsMode to use nvprims).
        is_primitive = (
            orig_func in torch_function_passthrough or orig_func in all_prims()
        )
        if is_primitive:
            with self.prims_mode_cls():
                return orig_func(*args, **kwargs)

        ref = torch_to_refs_map().get(orig_func, None)
        # No mapped ref and orig_func is an OpOverload (torch.ops.aten.*):
        # use the decomposition registered in torch._decomp, which points at
        # a torch._refs or torch._decomp.decompositions implementation.
        # Alternatives were discussed in
        # https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
        if ref is None and isinstance(orig_func, torch._ops.OpOverload):
            ref = torch._decomp.decomposition_table.get(orig_func, None)

        if ref is None:
            if self.strict:
                raise RuntimeError(
                    f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
                )
            return orig_func(*args, **kwargs)

        # A ref exists, but the caller-supplied predicate may still veto it.
        if self.should_fallback_fn(self, ref, args, kwargs):
            return orig_func(*args, **kwargs)

        # torch calls made inside the ref should themselves be interpreted as
        # refs calls, so run the ref with this mode enabled.
        with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
            return ref(*args, **kwargs)
def _is_node_supported_nvfuser(node):
return (
node.op == "call_function"
and getattr(node.target, "impl_nvfuser", None) is not None
)
def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
    """
    Fallback predicate: build an isolated FX graph module for
    ``func(*args, **kwargs)`` and report whether any of its ``call_function``
    nodes is rejected by ``NvfuserPrimOperatorSupport``.

    Returns True as soon as one unsupported node is found (i.e. "fall back to
    the original function"), False if every call_function node is supported.
    """
    # Build the graph with torch_function_mode enabled, so torch calls inside
    # func are interpreted by that mode.
    with torch.overrides.enable_torch_function_mode(
        torch_function_mode, replace=torch_function_mode.inner
    ):
        traced = get_isolated_graphmodule(func, args, kwargs)
    checker = NvfuserPrimOperatorSupport()
    for node in traced.graph.nodes:
        if node.op != "call_function":
            continue
        if not checker.is_node_supported(None, node):
            return True
    return False
class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
    """
    TorchRefsMode configured for nvFuser.

    * ``torch.var_mean`` and ``aten.var_mean`` ops are sent straight to
      ``torch.ops.nvprims.var_mean``.
    * Every other call goes through TorchRefsMode (non-strict). A ref is
      skipped in favour of the original function when
      ``_is_func_unsupported_nvfuser`` reports an unsupported node, and
      prims run under ``NvfuserPrimsMode``.
    """

    def __init__(self):
        super().__init__(
            strict=False,
            should_fallback_fn=_is_func_unsupported_nvfuser,
            prims_mode_cls=NvfuserPrimsMode,
        )

    def _is_var_mean(self, func):
        """Return True if ``func`` is torch.var_mean or an aten.var_mean op."""
        if torch.overrides.resolve_name(func) == "torch.var_mean":
            return True
        if not isinstance(
            func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
        ):
            return False
        # Substring match on str(op), so any aten.var_mean overload matches.
        return "aten.var_mean" in str(func)

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        if kwargs is None:
            kwargs = {}
        # Anything other than var_mean is interpreted by TorchRefsMode.
        if not self._is_var_mean(orig_func):
            return super().__torch_function__(orig_func, types, args, kwargs)
        # nvFuser-specific prims bypass the generic torch._refs lookup.
        # NOTE(review): args/kwargs are forwarded verbatim, so this assumes the
        # caller's positional layout matches nvprims.var_mean's schema (e.g.
        # torch.var_mean(x, unbiased) passes ``unbiased`` positionally) —
        # confirm against the nvprims.var_mean definition.
        return torch.ops.nvprims.var_mean(*args, **kwargs)