Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00

Summary: Allows an individual DifferentiableGraphOp to display its optimized forward graph. This improves user visibility into graph mutations made by optimization passes, especially fusion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55139
Reviewed By: albanD
Differential Revision: D31330909
Pulled By: dzhulgakov
fbshipit-source-id: c745b482fdc34876dc404cbe3bacd99dcf2ac724
import contextlib

import torch

from typing import List

@contextlib.contextmanager
def optimized_execution(should_optimize):
    """
    A context manager that controls whether the JIT's executor will run
    optimizations before executing a function.
    """
    stored_flag = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        torch._C._set_graph_executor_optimize(stored_flag)

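# A minimal usage sketch of `optimized_execution`: run a scripted function with
# the executor's optimization passes turned off. The helper name
# `_example_unoptimized_call` and the scripted `double` are illustrative only.
def _example_unoptimized_call():
    @torch.jit.script
    def double(x):
        return x + x

    # Inside the context, the graph executor skips its optimization passes.
    with optimized_execution(False):
        return double(torch.ones(2))
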
@contextlib.contextmanager
def fuser(name):
    """
    A context manager that facilitates switching between
    backend fusers.

    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
    """
    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
    if name == 'fuser0':  # legacy fuser
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
    elif name == 'fuser1':  # NNC
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
    elif name == 'fuser2':  # nvFuser
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        raise Exception("unrecognized fuser option")
    try:
        yield
    finally:
        if name == 'fuser1':  # NNC
            torch._C._jit_set_profiling_executor(old_profiling_executor)
            torch._C._jit_set_profiling_mode(old_profiling_mode)
        # recover the previous values
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)

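# A minimal usage sketch of `fuser`: select NNC ('fuser1') while running a
# scripted function so that elementwise ops can be fused. The helper name
# `_example_nnc_fusion` and the scripted `add_relu` are illustrative only.
def _example_nnc_fusion():
    @torch.jit.script
    def add_relu(x, y):
        return (x + y).relu()

    x, y = torch.randn(4), torch.randn(4)
    with fuser('fuser1'):
        # A few warm-up runs let the profiling executor specialize and fuse.
        for _ in range(3):
            add_relu(x, y)
        return add_relu(x, y)
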
last_executed_optimized_graph = torch._C._last_executed_optimized_graph

def _get_differentiable_graph_node(node, diff_node):
    # Recursively collect every prim::DifferentiableGraph node reachable from
    # `node` (including nodes nested inside blocks) into the `diff_node` list.
    if node.kind() == 'prim::DifferentiableGraph':
        diff_node.append(node)
    else:
        for block in node.blocks():
            for n in block.nodes():
                _get_differentiable_graph_node(n, diff_node)

def _graph_for(self, *args, **kwargs):
    return _script_method_graph_for(self, self, *args, **kwargs)

def _script_method_graph_for(self, parent, *args, **kwargs):
    try:
        dbs = parent.get_debug_state()
        eps = list(dbs.execution_plans.values())
        assert len(eps) == 1
        graph = eps[0].graph.copy()

        # graph_executor_states for the differentiable nodes
        fw_states = eps[0].code.differentiable_op_executor_states()
        diff_nodes: List[torch._C.Node] = []
        for n in graph.nodes():
            _get_differentiable_graph_node(n, diff_nodes)

        assert len(fw_states) == len(diff_nodes)
        # Swap each differentiable graph with the optimized graph in its execution plan.
        for n, state in zip(diff_nodes, fw_states):
            fw_execution_plans = list(state.execution_plans.values())
            # We can only update the subgraph when there is a unique execution
            # plan. Avoid asserting here so that nodes that cannot be updated
            # are skipped while we make a best effort to update the others.
            if len(fw_execution_plans) == 1:
                n.g_('Subgraph', fw_execution_plans[0].graph)

        return graph
    except Exception:
        # Fallback approach: just run the graph and return the recorded
        # optimized graph.
        self(*args, **kwargs)
        return last_executed_optimized_graph()
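
# A minimal usage sketch tying the helpers above to the change described in the
# header: with the profiling executor, `graph_for` (backed by
# `_script_method_graph_for`) returns a graph whose prim::DifferentiableGraph
# nodes carry their optimized forward subgraphs. The helper name
# `_example_inspect_differentiable_graph` and the scripted `f` are illustrative
# only; `graph_for` is assumed to be bound to script functions elsewhere in
# torch.jit.
def _example_inspect_differentiable_graph():
    @torch.jit.script
    def f(x):
        return x.sigmoid() * x

    x = torch.randn(4, requires_grad=True)
    # Warm-up runs let the executor profile, autodiff, and optimize the graph.
    for _ in range(3):
        f(x)
    return f.graph_for(x)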