pytorch/torch/_prims/executor.py
Mike Ruberry fe1968dea0 [primTorch] Prototype nvFuser integration and test_prims.py
This adds prototype nvFuser integration for the following prims:

- broadcast_in_dim
- convert_element_type
- add
- div
- ge
- gt
- le
- lt
- mul

Adding it for additional prims supported by nvFuser's prototype Python frontend should be easy.

This also adds a new sugar to run operations using the ATen or nvFuser trace executors. For example:

```
def foo(a, b):
  return torch.add(a, b)

traced_foo = make_traced(foo)

a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
```

Currently only operations with tensor inputs and one tensor output are supported, and the operation must be composed exclusively of reference or prim operations.

Finally, this adds a new test, test_prims.py, that just tests the broadcast_in_dim prim for now. In the future we'll likely have OpInfos for each prim, but we'll need a reference implementation of broadcast_in_dim to make that interesting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76560
Approved by: https://github.com/ngimel
2022-04-29 02:02:25 +00:00

124 lines
3.9 KiB
Python

from typing import Callable
import torch
from torch.fx import GraphModule
from torch._prims.utils import TensorMeta
from torch._prims.context import PrimContext
import torch._C._nvfuser as nvfuser # type: ignore[import]
DataType = nvfuser.DataType # type: ignore[attr-defined]
Fusion = nvfuser.Fusion # type: ignore[attr-defined]
FusionDefinition = nvfuser.FusionDefinition # type: ignore[attr-defined]
# TODO: refactor me into a common place
_torch_dtype_to_nvfuser_dtype_map = {
torch.cdouble: DataType.ComplexDouble,
torch.cfloat: DataType.ComplexFloat,
torch.double: DataType.Double,
torch.float: DataType.Float,
torch.half: DataType.Half,
torch.bfloat16: DataType.BFloat16,
torch.long: DataType.Int,
torch.int: DataType.Int32,
torch.bool: DataType.Bool,
}
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
"""
Prototype ATen executor.
Just executes the context's graph.
"""
if executor == "aten":
gm = GraphModule({}, ctx.graph)
return gm.forward(*args, **kwargs)
elif executor == "nvfuser":
# PROTOTYPE nvfuser executor
# Only accepts tensor inputs and single tensor outputs
# Does not handle kwargs
# Does not support reusing the same ctx to execute!
assert len(kwargs) == 0
# TODO: make this a proper trace -> trace transform that
# doesn't mutate the context
graph_fd = ctx.graph.placeholder("fd")
ctx.graph._root.append(graph_fd)
fusion = Fusion()
with FusionDefinition(fusion) as fd:
# Transforms graph to call nvfuser lowerings
nv_args = [fd]
for arg in args:
if isinstance(arg, torch.Tensor):
x = fd.define_tensor(
arg.ndim, _torch_dtype_to_nvfuser_dtype_map[arg.dtype]
)
fd.add_input(x)
nv_args.append(x)
else:
nv_args.append(x)
for x in ctx.graph.nodes:
if x.op == "call_function":
x.target = x.target.impl_nvfuser
x.args = (graph_fd,) + x.args
gm = GraphModule({}, ctx.graph)
out = gm.forward(*nv_args)
fd.add_output(out)
return fusion.execute(
tuple(arg for arg in args if isinstance(arg, torch.Tensor))
)[0]
msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
executor
)
raise ValueError(msg)
def make_traced(fn: Callable):
"""
Returns a function that, when called, will
trace its torch operations to prims and then
execute those prims on the requested trace executor
(possibly lowering them to that trace executor first).
Only supports the torch operations defined in _torch_to_reference_map
in context.py and operations with positional args. All args must
be tensors and the function must return a single tensor. In the
near future all these restrictions will be lifted.
Example usage:
def foo(a, b):
return torch.add(a, b)
traced_foo = make_traced(foo)
a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
Executor may be either 'aten' or 'nvfuser'.
"""
def _traced(*args, executor="aten"):
ctx = PrimContext()
with ctx:
placeholders = []
for arg in args:
if isinstance(arg, torch.Tensor):
placeholders.append(ctx.placeholder(TensorMeta(arg)))
else:
placeholders.append(ctx.placeholder(arg))
result = fn(*placeholders)
ctx.output(result)
return execute(ctx, *args, executor=executor)
return _traced