pytorch/torch/fx/passes/backends/cudagraphs.py
Elias Ellison 8bd9fe3f49 Changes to prepare for fake tensors on in functorch by default (#84432)
Fixes some errors you run into in dynamo when turning on fake tensors. I'm waiting on flipping the switch because I need to also get some fixes into dynamo + do benchmarking.

I could manually turn off fake tensors in functorch in dynamo, and then turn it on here if requested, although the changes here are pretty minimal.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84432
Approved by: https://github.com/Chillee
2022-09-08 04:29:30 +00:00

57 lines
2.0 KiB
Python

import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.utils._pytree import tree_map
import operator
class CudaGraphsSupport(OperatorSupport):
    # TODO: why is submodules passed here
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """Decide whether ``node`` can live inside a CUDA-graph partition.

        A node qualifies only when every fake tensor recorded in its own
        meta and in the meta of each of its inputs resides on a CUDA device.
        """
        # Only call_* ops represent computation we could capture.
        if node.op not in CALLABLE_NODE_OPS:
            return False

        # Known-incompatible op: never capture it.
        if node.target == torch.ops.aten.embedding_dense_backward.default:
            return False

        # getitem just projects out of a multi-output node; always allow it
        # so it stays fused with its producer.
        if node.target == operator.getitem:
            return True

        # Collect every non-CUDA tensor observed in the relevant metadata.
        non_cuda_tensors = []

        def fake_of(meta):
            # Prefer "val" (set by FakeTensorProp); otherwise fall back to
            # the legacy "fake_result" key.
            if "val" in meta:
                return meta["val"]
            return meta["fake_result"]

        def record_non_cuda(t):
            if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
                non_cuda_tensors.append(t)

        for producer in node.all_input_nodes:
            tree_map(record_non_cuda, fake_of(producer.meta))

        tree_map(record_non_cuda, fake_of(node.meta))

        # NB: factory function is accounted for because the result would be
        # cpu or cuda

        return not non_cuda_tensors
def partition_cudagraphs(gm, inputs):
    """
    Partition an FX graph into sub-GraphModules that can be validly run under
    CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations
    must involve CUDA tensors only.

    Args:
        gm: the traced ``torch.fx.GraphModule`` to partition.
        inputs: example inputs, unpacked into ``FakeTensorProp.propagate`` so
            that every node's meta carries a fake tensor with device info.

    Returns:
        A ``GraphModule`` in which each maximal CUDA-graph-safe region has
        been fused into its own submodule.
    """
    # Populate node.meta with fake tensors so CudaGraphsSupport can inspect
    # device placement without executing real computation.
    FakeTensorProp(gm).propagate(*inputs)
    supported_ops = CudaGraphsSupport()
    # TODO: single node partition may be wrong due to the pessimization
    # from copying in and out the data. Check in benchmarks, perhaps
    partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True)
    partitions = partitioner.propose_partitions()
    fused_graph = partitioner.fuse_partitions(partitions)
    return fused_graph