Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[export] Fix custom ops in subgraphs (#160004)
Fixes https://github.com/pytorch/pytorch/issues/159995

Currently there are two problems with extern kernels in subgraphs:

1. They don't get serialized to the extern kernel json file, because serialization only looks at the toplevel graph.
2. Since each extern_kernels list is scoped to its own subgraph, the indices referencing the operators are wrong: each subgraph starts counting from 0.

So, this PR moves the extern_kernels list to a global view (under virtualized) so that extern kernels are counted across the subgraphs and the toplevel graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160004
Approved by: https://github.com/ydwu4
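To make the second problem concrete: the proxy executor looks up extern kernels by their integer position in the serialized json, so two subgraphs that each count from zero emit colliding indices. Below is a minimal sketch of the collision, using plain Python lists as stand-ins for the per-subgraph extern_kernels state (the variable names are illustrative, not Inductor internals):

# Illustrative sketch; plain lists stand in for Inductor's per-subgraph state.

# Before this PR: each subgraph counted from 0, so codegen for both
# branches of a torch.cond referenced "extern kernel #0".
true_branch_kernels = ["mylib::foo_add1"]   # index 0 within its subgraph
false_branch_kernels = ["mylib::foo_add2"]  # also index 0: collides at runtime
# Neither list reached the extern kernel json file, since serialization
# only walked the toplevel graph.

# After this PR: one global list shared by the toplevel graph and all
# subgraphs, so every extern kernel gets a unique index.
extern_kernel_nodes = []
extern_kernel_nodes.append("mylib::foo_add1")  # index 0
extern_kernel_nodes.append("mylib::foo_add2")  # index 1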
parent 1091165826
commit 3c8c509a9c
@@ -6772,6 +6772,49 @@ class AOTInductorTestsTemplate:
         # compare against eager
         self.assertEqual(optimized(**model_kwargs), model(**model_kwargs))
 
+    def test_custom_op_in_subgraph(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo_add1",
+                "(Tensor a) -> Tensor",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo_add1", "CompositeExplicitAutograd", lib=lib)
+            @torch.library.register_fake("mylib::foo_add1", lib=lib)
+            def foo_add1_impl(a: torch.Tensor) -> torch.Tensor:
+                return a + 1
+
+            torch.library.define(
+                "mylib::foo_add2",
+                "(Tensor a) -> Tensor",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo_add2", "CompositeExplicitAutograd", lib=lib)
+            @torch.library.register_fake("mylib::foo_add2", lib=lib)
+            def foo_add2_impl(a: torch.Tensor) -> torch.Tensor:
+                return a + 2
+
+            class M(torch.nn.Module):
+                def forward(self, x):
+                    return torch.cond(
+                        x.shape[0] < 5,
+                        torch.ops.mylib.foo_add1,
+                        torch.ops.mylib.foo_add2,
+                        (x,),
+                    )
+
+            list_example_inputs = [
+                (torch.ones(6, device=self.device),),
+                (torch.ones(3, device=self.device),),
+            ]
+            self.check_model_with_multiple_inputs(
+                M(), list_example_inputs, dynamic_shapes=({0: Dim.DYNAMIC},)
+            )
+
     def test_clamp_decomposition(self):
         class Model1(torch.nn.Module):
             def forward(self, x):
@@ -70,6 +70,7 @@ CPU_TEST_FAILURES = {
     "test_cond_with_multiple_outputs": fail_minimal_arrayref_interface(),
     "test_cond_with_parameters": fail_minimal_arrayref_interface(),
     "test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(),
+    "test_custom_op_in_subgraph": fail_minimal_arrayref_interface(),
     "test_cond_share_predicte": fail_stack_allocation(is_skip=True),
     "test_cond_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(),
     "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(),
@@ -2604,7 +2604,7 @@ if (!custom_op_wrapper) {
             "AtenTensorHandle", tensor_call_args, force_mutable=True
         )
 
-        extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1
+        extern_kernel_node_index = len(V.extern_kernel_nodes) - 1
         self.writeline(
             f"aoti_torch_proxy_executor_call_function(proxy_executor, "
             f"{extern_kernel_node_index}, "
@@ -1390,7 +1390,10 @@ class _InProcessFxCompile(FxCompile):
                     is_backward=is_backward,
                     is_const_graph=True,
                 )
-                with V.set_graph_handler(const_graph):
+                with (
+                    V.set_graph_handler(const_graph),
+                    V.set_extern_kernel_nodes([]),
+                ):
                     assert cpp_wrapper, "AOT mode only supports C++ wrapper"
                     const_graph.run()
                     const_wrapper_code, const_kernel_code = (
@@ -1425,7 +1428,7 @@ class _InProcessFxCompile(FxCompile):
            # We are going to start code generating runtime asserts, so make sure
            # you don't start adding new ones in the lowering process
            graph.freeze_runtime_asserts()
-           with V.set_graph_handler(graph):
+           with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]):
                graph.run(*example_inputs)
                output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
                if graph.graph_outputs is not None:
@@ -1472,11 +1475,9 @@ class _InProcessFxCompile(FxCompile):
                )
 
                serialized_extern_kernel_nodes = None
-               if graph.extern_kernel_nodes:
+               if V.extern_kernel_nodes:
                    serialized_extern_kernel_nodes = (
-                       graph.extern_node_serializer(
-                           graph.extern_kernel_nodes
-                       )
+                       graph.extern_node_serializer(V.extern_kernel_nodes)
                    )
                    output_code_log.debug(
                        "Serialized Extern Kernel Nodes: \n%s",
@@ -392,8 +392,6 @@ class GraphLowering(torch.fx.Interpreter):
         self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
         self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
         self.wrapper_code: PythonWrapperCodegen = None  # type: ignore[assignment]
-        # See `ProxyExecutor Design Note` in ir.py for more details
-        self.extern_kernel_nodes: list[ir.ExternKernelNode] = []
 
         from torch._inductor.extern_node_serializer import extern_node_json_serializer
 
@@ -7656,7 +7656,7 @@ class FallbackKernel(ExternKernelAlloc):
             ),
         )
 
-        V.graph.extern_kernel_nodes.append(node)
+        V.extern_kernel_nodes.append(node)
 
         return [*args, *ordered_kwargs]
 
@@ -80,6 +80,7 @@ if TYPE_CHECKING:
     from torch._inductor.codegen.cpp_utils import LocalBufferContext
     from torch._inductor.debug import DebugContext
     from torch._inductor.graph import GraphLowering
+    from torch._inductor.ir import ExternKernelNode
     from torch._inductor.loop_body import InterpreterShim
     from torch._subclasses import FakeTensorMode
 
@@ -183,6 +184,9 @@ _ops: Virtualized[OpsHandler[Any]] = Virtualized(
     "ops", cast(type[OpsHandler[Any]], MockHandler)
 )
 _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
+_extern_kernel_nodes: Virtualized[list[ExternKernelNode]] = Virtualized(
+    "extern_kernel_nodes", NullHandler
+)
 _real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
 _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
 _kernel: Virtualized[NullKernelHandler] = Virtualized(
@@ -343,6 +347,9 @@ class _V:
     )
     get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
     set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
+    set_extern_kernel_nodes: Callable[[list[ExternKernelNode]], Any] = (
+        _extern_kernel_nodes._set_handler
+    )
     set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
     get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
     set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
@@ -368,6 +375,15 @@ class _V:
         """The graph currently being generated"""
         return _graph._get_handler()
 
+    @property
+    def extern_kernel_nodes(self) -> list[ExternKernelNode]:
+        """
+        The extern_kernel_nodes needed for the entire graph, including the
+        subgraphs.
+        See `ProxyExecutor Design Note` in ir.py for more details
+        """
+        return _extern_kernel_nodes._get_handler()
+
     @property
     def real_inputs(self):
         """non-fake example inputs"""
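For context on the new Virtualized slot: the compile driver installs one shared list for the whole compilation, every graph handler (toplevel or subgraph) appends to it through V, and serialization reads it back at the end. A minimal usage sketch follows, assuming only the internal APIs added in this diff (V.set_extern_kernel_nodes, V.extern_kernel_nodes); the wrapper function name is hypothetical:

# Minimal sketch of the new plumbing; these are internal Inductor APIs
# introduced in this diff, not a public interface.
from torch._inductor.virtualized import V

def lower_and_serialize(graph, example_inputs):
    # One shared list for the whole compilation; subgraph lowerings see
    # the same list through V instead of a per-graph attribute.
    with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]):
        graph.run(*example_inputs)
        # FallbackKernel appends via V.extern_kernel_nodes.append(node),
        # so len(V.extern_kernel_nodes) - 1 is globally unique.
        if V.extern_kernel_nodes:
            return graph.extern_node_serializer(V.extern_kernel_nodes)
        return None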