import contextlib
from typing import List

import torch


@contextlib.contextmanager
def optimized_execution(should_optimize):
    """
    A context manager that controls whether the JIT's executor will run
    optimizations before executing a function.
    """
    stored_flag = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        torch._C._set_graph_executor_optimize(stored_flag)


@contextlib.contextmanager
def fuser(name):
    """
    A context manager that facilitates switching between backend fusers.

    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
    """
    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
    if name == 'fuser0':  # legacy fuser
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
    elif name == 'fuser1':  # NNC
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
    elif name == 'fuser2':  # nvFuser
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        raise ValueError("unrecognized fuser option")
    try:
        yield
    finally:
        if name == 'fuser1':  # NNC
            torch._C._jit_set_profiling_executor(old_profiling_executor)
            torch._C._jit_set_profiling_mode(old_profiling_mode)
        # recover the previous values
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)


# Returns the optimized graph that the executor most recently ran; used as a
# fallback by _graph_for below.
last_executed_optimized_graph = torch._C._last_executed_optimized_graph


def _get_differentiable_graph_node(node, diff_node):
    # Recursively collect every prim::DifferentiableGraph node reachable from
    # `node` (including nodes nested in sub-blocks) into `diff_node`.
    if node.kind() == 'prim::DifferentiableGraph':
        diff_node.append(node)
    else:
        for block in node.blocks():
            for n in block.nodes():
                _get_differentiable_graph_node(n, diff_node)


def _graph_for(self, *args, **kwargs):
    # Return the optimized graph the executor would run for these inputs.
    try:
        dbs = self.get_debug_state()
        eps = list(dbs.execution_plans.values())
        assert len(eps) == 1
        graph = eps[0].graph.copy()

        # graph executor states for the differentiable nodes
        fw_states = eps[0].code.differentiable_op_executor_states()
        diff_nodes: List[torch._C.Node] = []
        for n in graph.nodes():
            _get_differentiable_graph_node(n, diff_nodes)

        assert len(fw_states) == len(diff_nodes)

        # swap each differentiable graph with the optimized graph in its
        # execution plan
        for n, state in zip(diff_nodes, fw_states):
            fw_execution_plans = list(state.execution_plans.values())
            assert len(fw_execution_plans) == 1
            n.g_('Subgraph', fw_execution_plans[0].graph)

        return graph
    except Exception:
        # fallback approach: just run the graph and return the recorded
        # optimized graph
        self(*args, **kwargs)
        return last_executed_optimized_graph()
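

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). It shows
# how the context managers above could be combined: script a small function,
# run it under the NNC fuser with optimizations enabled, then inspect the
# optimized graph. The toy function `_demo`, the input shapes, and the
# warm-up count are assumptions made up for this example; the profiling
# executor needs a few warm-up runs on real inputs before it specializes and
# fuses. In PyTorch, torch.jit installs _graph_for on scripted callables
# under the name `graph_for`.
if __name__ == "__main__":

    @torch.jit.script
    def _demo(x, y):
        # a simple pointwise chain that a fuser can collapse into one kernel
        return (x + y).relu()

    a = torch.randn(4, 4)
    b = torch.randn(4, 4)

    with optimized_execution(True), fuser('fuser1'):  # NNC
        for _ in range(3):  # warm up the profiling executor
            _demo(a, b)
        print(_demo.graph_for(a, b))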