The new namespace `torch.ops.nvprims` is meant for the set of primitives specific to nvFuser. All `impl_nvfuser` attributes have been removed from `torch.ops.prims` functions. The `NvfuserPrimsMode()` context manager can be used to automatically rewrite `torch.ops.prims` calls to `torch.ops.nvprims` when possible. Previously, the way to test whether a prim was executable by nvFuser was to check `impl_nvfuser is not None`; now every function in the `torch.ops.nvprims` namespace is expected to carry the `impl_nvfuser` attribute, so all of them are executable by nvFuser.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82155
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
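As a quick sketch of the rewrite described above (the op name `sin` is illustrative, and the snippet assumes a CUDA build of PyTorch with nvFuser available):

    import torch
    from torch._prims.context import NvfuserPrimsMode

    a = torch.randn(3, device="cuda")

    with NvfuserPrimsMode():
        # Under the mode, this call is rewritten to torch.ops.nvprims.sin
        # when an nvprims counterpart exists; otherwise the prim runs as-is.
        b = torch.ops.prims.sin(a)

    # The executability check: membership in the nvprims namespace now implies
    # nvFuser support, so every nvprim carries the impl_nvfuser attribute.
    print(hasattr(torch.ops.nvprims.sin.default, "impl_nvfuser"))  # expected: True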
from typing import Callable

from torch._prims.context import NvfuserPrimsMode, TorchRefsMode
from torch._prims.nvfuser_executor import nvfuser_execute, nvfuser_execute_partitioned
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx


def execute(gm: GraphModule, *args, executor: str = "aten"):
    """
    Prototype ATen executor.

    Just executes the given graph module with the requested executor.
    """

    if executor == "aten":
        # Runs the traced graph eagerly through ATen.
        return gm.forward(*args)
    elif executor == "nvfuser":
        # Partitions the graph into nvFuser-supported subgraphs and runs
        # unsupported operations eagerly.
        return nvfuser_execute_partitioned(gm, *args)
    elif executor == "strictly_nvfuser":
        # Requires the entire graph to be executable by nvFuser.
        return nvfuser_execute(gm, *args)

    msg = (
        "Received unexpected value for 'executor': {0}. "
        "Allowed values are: aten, nvfuser, strictly_nvfuser.".format(executor)
    )
    raise ValueError(msg)
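
# A sketch of calling `execute` directly (illustrative, not part of the
# original module): trace a function with make_fx under TorchRefsMode so that
# torch calls decompose to prims, then run the resulting GraphModule on the
# strict nvFuser path. Assumes `import torch` and CUDA tensors a, b.
#
#     def add_sin(a, b):
#         return torch.sin(a) + b
#
#     with TorchRefsMode():
#         gm = make_fx(add_sin)(a, b)
#     out = execute(gm, a, b, executor="strictly_nvfuser")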


def make_traced(fn: Callable):
    """
    Returns a function that, when called, will trace its torch operations
    to prims and then execute those prims on the requested trace executor
    (possibly lowering them to that trace executor first).

    Only supports the torch operations defined in _torch_to_reference_map
    in context.py and operations with positional args. All args must be
    tensors.
    In the near future all these restrictions will be lifted.

    Example usage:

      def foo(a, b):
          return torch.add(a, b)

      traced_foo = make_traced(foo)

      a = torch.randn((1, 2, 3, 4, 5), device='cuda')
      b = torch.randn((1, 2, 3, 4, 5), device='cuda')
      result = traced_foo(a, b, executor='nvfuser')

    Executor may be 'aten', 'nvfuser', or 'strictly_nvfuser'.
    """

    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        nargs = len(args)
        fn_kwargs = kwargs
        # Flatten keyword arguments onto the end of the positional argument
        # list so the traced wrapper receives a single flat list of tensors.
        flat_fn_kwargs = list(fn_kwargs.values())
        all_args = list(args) + flat_fn_kwargs

        def wrapped(args):
            # Split the flat argument list back into positional and keyword
            # arguments before calling the original function.
            fn_args = args[:nargs]
            kwargs_keys = list(fn_kwargs.keys())
            kwargs = dict(zip(kwargs_keys, args[nargs:]))
            return fn(*fn_args, **kwargs)

        # Trace the wrapped function to an FX graph, rewriting torch calls to
        # reference implementations and prims calls to nvprims when possible.
        with NvfuserPrimsMode(), TorchRefsMode():
            gm = make_fx(wrapped)(all_args)
        return execute(gm, all_args, executor=executor)

    return _traced
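
# A minimal usage sketch (illustrative, not part of the original module);
# assumes a CUDA build of PyTorch with nvFuser available.
if __name__ == "__main__":
    import torch

    def foo(a, b):
        return torch.add(a, b)

    traced_foo = make_traced(foo)

    a = torch.randn((1, 2, 3, 4, 5), device="cuda")
    b = torch.randn((1, 2, 3, 4, 5), device="cuda")

    # "aten" runs the traced graph eagerly; "nvfuser" partitions it and fuses
    # supported subgraphs; "strictly_nvfuser" requires full nvFuser support.
    result = traced_foo(a, b, executor="nvfuser")
    print(result.shape)  # expected: torch.Size([1, 2, 3, 4, 5])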