mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Graph Partition] move custom rules to inductor config (#166458)
This PR adds `custom_should_partition_ops: list[str]` to specify the names of custom ops upon which graph partition happens. It works with caching since it is a `list[str]` in the config file. The op name should be of format "mylib::baz". Closes #165341. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166458 Approved by: https://github.com/ProExpertProg, https://github.com/eellison, https://github.com/zou3519
This commit is contained in:
parent
56a809aa07
commit
bebabd7fce
|
|
@ -945,35 +945,46 @@ if HAS_CUDA_AND_TRITON:
|
|||
self.assertEqual(num_partitions, 1)
|
||||
|
||||
@torch.library.custom_op("mylib::baz", mutates_args=())
|
||||
def baz(x: torch.Tensor, flag: int) -> torch.Tensor:
|
||||
def baz(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.clone()
|
||||
|
||||
@baz.register_fake
|
||||
def _(x, flag):
|
||||
def _(x):
|
||||
return x.clone()
|
||||
|
||||
def should_partition(x, flag):
|
||||
return flag
|
||||
# custom_should_partition_ops takes effect which lead to 2 partitions
|
||||
torch._inductor.config.custom_should_partition_ops = ["mylib::baz"]
|
||||
|
||||
torch._inductor.scheduler.register_should_partition_rule(
|
||||
torch.ops.mylib.baz.default, should_partition
|
||||
)
|
||||
|
||||
def f(x, flag):
|
||||
def f(x):
|
||||
x = x + 1
|
||||
x = baz(x, flag)
|
||||
x = baz(x)
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
|
||||
_, code = run_and_get_code(f_compiled, x, True)
|
||||
_, code = run_and_get_code(f_compiled, x)
|
||||
num_partitions = get_num_partitions(code)
|
||||
self.assertEqual(num_partitions, 2)
|
||||
|
||||
_, code = run_and_get_code(f_compiled, x, False)
|
||||
# update the config should NOT force recompile
|
||||
torch._inductor.config.custom_should_partition_ops = []
|
||||
with torch.compiler.set_stance("fail_on_recompile"):
|
||||
f_compiled(x)
|
||||
|
||||
# run_and_get_code forces recompile. Now we should cache miss, recompile, and
|
||||
# only have 1 partition.
|
||||
_, code = run_and_get_code(f_compiled, x)
|
||||
num_partitions = get_num_partitions(code)
|
||||
self.assertEqual(num_partitions, 1)
|
||||
|
||||
# test that op_overload name takes effect which lead to 2 partitions
|
||||
torch._inductor.config.custom_should_partition_ops = ["mylib::baz.default"]
|
||||
|
||||
f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True)
|
||||
_, code = run_and_get_code(f_compiled, x)
|
||||
num_partitions = get_num_partitions(code)
|
||||
self.assertEqual(num_partitions, 2)
|
||||
|
||||
@torch._inductor.config.patch("graph_partition", True)
|
||||
@torch._inductor.config.patch("implicit_fallbacks", True)
|
||||
def test_graph_partition_with_memory_plan_reuse(self):
|
||||
|
|
|
|||
|
|
@ -483,6 +483,11 @@ graph_partition: bool = (
|
|||
== "1"
|
||||
)
|
||||
|
||||
# register ops upon which inductor should partition the graph. name format should be
|
||||
# "namespace::kernel_name" (e.g., aten::mm) for op overload packet, or
|
||||
# "namespace::kernel_name.overload" (e.g., aten::mm.default).
|
||||
custom_should_partition_ops: list[str] = []
|
||||
|
||||
# whether template autotuning should allow flexible layouts if possible (e.g. only extern choices)
|
||||
max_autotune_allow_flexible_layouts: bool = False
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,6 @@ if TYPE_CHECKING:
|
|||
from collections.abc import Iterator, Sequence
|
||||
from types import ModuleType
|
||||
|
||||
import weakref
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
|
|
@ -97,28 +95,6 @@ _T = TypeVar("_T")
|
|||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
_custom_should_partition_fns: weakref.WeakKeyDictionary[
|
||||
torch._ops.OpOverload, Callable[..., bool]
|
||||
] = weakref.WeakKeyDictionary()
|
||||
|
||||
|
||||
def register_should_partition_rule(
|
||||
op: torch._ops.OpOverload,
|
||||
func: Callable[..., bool],
|
||||
) -> None:
|
||||
"""Register a function that says if Inductor should partition the graph on this op.
|
||||
|
||||
The function should be have the same signature as the operator.
|
||||
Inductor will invoke the function with FakeTensors when it needs to decide
|
||||
if the graph should be partitioned.
|
||||
|
||||
`register_should_partition_rule` is currently private and experimental.
|
||||
Use at your own risk.
|
||||
"""
|
||||
assert isinstance(op, torch._ops.OpOverload)
|
||||
_custom_should_partition_fns[op] = func
|
||||
|
||||
|
||||
class MixOrderReduction:
|
||||
"""
|
||||
This class contains utility functions to decide if we should fuse reductions
|
||||
|
|
@ -4996,21 +4972,21 @@ class Scheduler:
|
|||
# Allow users to manually specify if a node should be partitioned
|
||||
# Can only do this for FallbackKernels
|
||||
ir_node = node.node
|
||||
if isinstance(ir_node, torch._inductor.ir.FallbackKernel):
|
||||
operator = ir_node.op_overload
|
||||
if operator is not None and operator in _custom_should_partition_fns:
|
||||
assert isinstance(operator, torch._ops.OpOverload)
|
||||
should_partition_fn = _custom_should_partition_fns[operator]
|
||||
fx_node = ir_node.get_origin_node()
|
||||
assert fx_node is not None
|
||||
success, fake_args, fake_kwargs = (
|
||||
torch._inductor.fx_utils.get_fake_args_kwargs(fx_node)
|
||||
)
|
||||
assert success, (
|
||||
"If this op came from a custom inductor pass, make sure to run FakeTensorUpdator"
|
||||
)
|
||||
should_partition = should_partition_fn(*fake_args, **fake_kwargs)
|
||||
return should_partition
|
||||
if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
|
||||
op := ir_node.op_overload
|
||||
):
|
||||
op_overload_packet_name = op.name()
|
||||
op_overload_name = (
|
||||
f"{op_overload_packet_name}.{op._overloadname}"
|
||||
if isinstance(op, torch._ops.OpOverload)
|
||||
else op_overload_packet_name
|
||||
)
|
||||
if (
|
||||
op_overload_packet_name in config.custom_should_partition_ops
|
||||
or op_overload_name in config.custom_should_partition_ops
|
||||
):
|
||||
assert isinstance(op, torch._ops.OpOverload)
|
||||
return True
|
||||
|
||||
# When not using cudagraphs, keep all kernels in the `call` function
|
||||
# instead of graph partition functions, since graph partition only brings
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user