pytorch/torch/_prims/executor.py
Ivan Yashchuk ec67c6abbe Add torch.ops.nvprims namespace for nvFuser-specific prims (#82155)
The new namespace `torch.ops.nvprims` is meant for the nvFuser-specific set of primitives. All `impl_nvfuser` attributes are removed from `torch.ops.prims` functions.

The `NvfuserPrimsMode()` context manager can be used to automatically rewrite `torch.ops.prims` calls to `torch.ops.nvprims` where possible.
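For illustration, a minimal sketch of how the mode composes with `make_fx` tracing (the op and shapes are illustrative, and a CUDA device with nvFuser support is assumed):

```python
import torch
from torch._prims.context import NvfuserPrimsMode
from torch.fx.experimental.proxy_tensor import make_fx

def fn(a, b):
    # Recorded as torch.ops.nvprims.add when traced under NvfuserPrimsMode,
    # since an nvprims equivalent of this prim exists.
    return torch.ops.prims.add(a, b)

with NvfuserPrimsMode():
    gm = make_fx(fn)(torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
# gm's graph now targets torch.ops.nvprims.add rather than torch.ops.prims.add.
```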

Previously, the way to test whether a prim was executable by nvFuser was to check `impl_nvfuser is not None`. Now all functions in the `torch.ops.nvprims` namespace are supposed to have the `impl_nvfuser` attribute, and hence all of them are executable by nvFuser.
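A minimal sketch of the old check versus the new contract, using `add` purely as an example op:

```python
import torch

# Old check: a prim was nvFuser-executable only if the optional
# impl_nvfuser attribute was set on the torch.ops.prims function.
executable_old = getattr(torch.ops.prims.add, "impl_nvfuser", None) is not None

# New contract: every function under torch.ops.nvprims carries impl_nvfuser,
# so resolving as torch.ops.nvprims.<name> already implies executability.
executable_new = torch.ops.nvprims.add.impl_nvfuser is not None  # always True
```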

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82155
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
2022-08-04 16:51:56 +00:00

from typing import Callable

from torch._prims.context import NvfuserPrimsMode, TorchRefsMode
from torch._prims.nvfuser_executor import nvfuser_execute, nvfuser_execute_partitioned
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx


def execute(gm: GraphModule, *args, executor: str = "aten"):
    """
    Prototype ATen executor.

    Just executes the context's graph.
    """

    if executor == "aten":
        return gm.forward(*args)
    elif executor == "nvfuser":
        # Partition the graph and run nvFuser-supported subgraphs with
        # nvFuser, falling back to ATen for the rest.
        return nvfuser_execute_partitioned(gm, *args)
    elif executor == "strictly_nvfuser":
        # Require the entire graph to be executable by nvFuser.
        return nvfuser_execute(gm, *args)

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser, strictly_nvfuser.".format(
        executor
    )
    raise ValueError(msg)


def make_traced(fn: Callable):
    """
    Returns a function that, when called, will
    trace its torch operations to prims and then
    execute those prims on the requested trace executor
    (possibly lowering them to that trace executor first).

    Only supports the torch operations defined in _torch_to_reference_map
    in context.py and operations with positional args. All args must
    be tensors.
    In the near future all these restrictions will be lifted.

    Example usage:

      def foo(a, b):
          return torch.add(a, b)

      traced_foo = make_traced(foo)

      a = torch.randn((1, 2, 3, 4, 5), device='cuda')
      b = torch.randn((1, 2, 3, 4, 5), device='cuda')
      result = traced_foo(a, b, executor='nvfuser')

    Executor may be 'aten', 'nvfuser', or 'strictly_nvfuser'.
    """

    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        nargs = len(args)
        fn_kwargs = kwargs
        flat_fn_kwargs = list(fn_kwargs.values())
        # Flatten kwargs into the positional argument list so make_fx sees
        # a single flat sequence of inputs.
        all_args = list(args) + flat_fn_kwargs

        def wrapped(args):
            # Split the flat argument list back into positional args and
            # kwargs before invoking the original function.
            fn_args = args[:nargs]
            kwargs_keys = list(fn_kwargs.keys())
            kwargs = dict(zip(kwargs_keys, args[nargs:]))
            return fn(*fn_args, **kwargs)

        # TorchRefsMode rewrites torch.* calls to their prim decompositions,
        # and NvfuserPrimsMode rewrites torch.ops.prims calls to
        # torch.ops.nvprims where possible.
        with NvfuserPrimsMode(), TorchRefsMode():
            gm = make_fx(wrapped)(all_args)
        return execute(gm, all_args, executor=executor)

    return _traced
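

# Usage sketch: driving `execute` directly with a graph produced by make_fx,
# mirroring what _traced does above (assumes a CUDA build with nvFuser
# available; the op and shapes are illustrative):
#
#   import torch
#
#   def fn(args):
#       a, b = args
#       return torch.add(a, b)
#
#   inputs = [torch.randn(4, device="cuda"), torch.randn(4, device="cuda")]
#   with NvfuserPrimsMode(), TorchRefsMode():
#       gm = make_fx(fn)(inputs)
#
#   out = execute(gm, inputs, executor="nvfuser")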