Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
- Don't copy inputs in cudagraphs wrapping, since the copies will distort timing and triton do_bench will clear the cache anyway
- Don't skip an op if there is a fallback, since we have both fallbacks and lowerings for some ops
- Add an option for channels last

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103110
Approved by: https://github.com/desertfire
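The file below registers this wrapper as the "cudagraphs" backend, which can be selected by name through torch.compile. A minimal sketch of how it might be invoked (the toy model, shapes, and device placement are illustrative assumptions, not part of the file):

import torch

# Hypothetical toy model, for illustration only; assumes a CUDA device is available.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda()
compiled = torch.compile(model, backend="cudagraphs")  # backend name registered in this file

x = torch.randn(32, 64, device="cuda")
out = compiled(x)  # early calls warm up and record; subsequent calls replay the CUDA graph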
185 lines · 6.0 KiB · Python
import logging
import operator
from collections import defaultdict
from typing import Set

import torch

from torch.fx import GraphModule
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn import Module
from torch.utils._pytree import tree_map

from .common import aot_autograd
from .registry import register_backend

log = logging.getLogger(__name__)


def cloner(t):
    if isinstance(t, torch.Tensor):
        return t.clone()
    else:
        return t


class CudaGraphModule(Module):
    gm: GraphModule
    mutated_inputs: Set[int]

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    warmed_up = False

    # these are all None or all filled
    graph = None
    static_inputs = None
    static_outputs = None

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into static, replays
        # the cuda graph, then copies out.  First condition is the hotpath,
        # needs optimizing
        if self.graph is not None:
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        elif self.warmed_up:
            # record
            self.static_inputs = [x.clone() for x in args]
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_outputs = self.gm(*self.static_inputs)
            # NB: recording doesn't actually run the operations, so
            # now we immediately replay the graph to serve up the result
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        else:
            # warmup
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r


# Interpreter versions of these passes can be found at
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23


def find_input_mutations(g):
    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]
        # TODO: error on unrecognized nodes
    return mutated_inputs


# Mutates input graph
def apply_cuda_graphs(gm):
    for n in gm.graph.nodes:
        if n.op == "call_module":
            assert not n.kwargs
            submod = gm.get_submodule(n.target)
            gm.delete_submodule(n.target)
            mutated_inputs = find_input_mutations(submod.graph)
            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
    # NB: we didn't actually change the graph, no need for recompile


def cudagraphs(model, inputs):
    model = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(model)
    return model


aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)

# aot_cudagraphs only applies CUDA graphs to the graph.  It is also helpful
# for debugging and can serve as a perf baseline.
# TODO(jansel): rename to just "cudagraphs"?
register_backend(name="cudagraphs", compiler_fn=aot_cudagraphs)


def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))
    if copy_inputs:
        static_inputs = [torch.zeros_like(x) for x in inputs]
    else:
        static_inputs = list(inputs)

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for dst, src in zip(static_inputs, new_inputs):
                dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
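As its docstring notes, cudagraphs_inner is used by benchmark harnesses rather than registered as a backend. A minimal sketch of that usage, with copy_inputs=False in the spirit of the change described in the commit message above (the model, shapes, iteration count, and timing loop are illustrative assumptions):

import torch

# Hypothetical benchmark setup; assumes a CUDA device is available.
model = torch.nn.Linear(128, 128).cuda()
inputs = [torch.randn(256, 128, device="cuda")]

# Wrap once: warms up, records the graph, and returns a replay function.
run = cudagraphs_inner(model, inputs, copy_outputs=False, copy_inputs=False)

# Time replays only; with copy_inputs=False the per-call input copies that
# would distort timing are skipped (the inputs are treated as static buffers).
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    run(*inputs)
end.record()
torch.cuda.synchronize()
print(f"avg ms per replay: {start.elapsed_time(end) / 100:.3f}")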