[AOTAutogradCache] Allow torch.Tensor and a non-torch op from einops (#152369)

This addresses part of #150706.

Specifically, it reduces the warm-start `torch.compile` overhead by
40-50% for GGUF models on
1. HuggingFace diffusers: [tlparse before, 224s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpqgbdva/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs. [tlparse after, 126s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp950PFy/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)
2. ComfyUI: [tlparse before, 93s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp7SeJb4/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs. [tlparse after, 51s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpRwGNqA/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)

The improvement should generalize to all other GGUF models on these
platforms, because the cache miss was triggered by shared framework
code that every GGUF model exercises.
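
For context, a minimal sketch of the offending pattern; `dequantize`, its arguments, and the shapes are illustrative assumptions, not code from either framework:

```python
import torch
import einops

@torch.compile
def dequantize(blocks: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    # einops.repeat is recorded in the FX graph as a call_function node
    # whose target lives outside the torch namespace; before this change,
    # any such node caused an AOTAutogradCache bypass on every warm start.
    scales = einops.repeat(scales, "b -> b k", k=blocks.shape[-1])
    return blocks * scales
```

With `einops.einops.repeat` allowlisted (and `torch.Tensor` added to the safe torch functions), graphs like this hash cleanly and hit the cache on warm starts instead of recompiling.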
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152369
Approved by: https://github.com/jamesjwu
Author: Ryan Guo (2025-04-28 14:07:38 -07:00)
Committer: PyTorch MergeBot
Parent: ce2cf31623
Commit: e4994e2f73


@@ -130,12 +130,14 @@ def check_node_safe(node: Node):
     SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional")
     SAFE_TORCH_FUNCTIONS = (
         "torch.Size",
+        "torch.Tensor",
         "torch.sym_int",
         "torch._sym_sqrt",
         "torch.sym_float",
         "torch.sym_sum",
         "einops.einops.rearrange",
     )
+    SAFE_NON_TORCH_FUNCTIONS = ("einops.einops.repeat",)
 
     def is_public_torch_api(target):
         # Don't blindly allow private functions in the torch namespace
@@ -163,7 +165,7 @@ def check_node_safe(node: Node):
             or function_name in torch._inductor.config.unsafe_marked_cacheable_functions
         )
 
-    def is_torch_function(target):
+    def is_cacheable_function(target):
         if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
             return True
         if is_public_torch_api(target):
@@ -177,6 +179,9 @@ def check_node_safe(node: Node):
             return True
         if is_safe_torch_function(target):
             return True
+        function_name = f"{target.__module__}.{target.__name__}"
+        if function_name in SAFE_NON_TORCH_FUNCTIONS:
+            return True
         return False
 
     def is_tensor(target):
@@ -185,9 +190,7 @@ def check_node_safe(node: Node):
 
     # I'd love to use a match statement here, but it wasn't introduced until py3.10
     if node.op == "call_function":
-        # We support only torch.* functions for now
-        # We can probably add an allowlist of safe non-torch implementations as well
-        if not is_torch_function(node.target):
+        if not is_cacheable_function(node.target):
             module = getattr(node.target, "__module__", None)
             name = getattr(node.target, "__name__", None)
             raise BypassAOTAutogradCache(
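
For intuition, a minimal standalone sketch of why `einops.repeat` now matches the allowlist; this is not the PR's code, and `qualified_name` is an illustrative helper mirroring the f-string in the diff:

```python
import einops

# Same allowlist entry the diff adds.
SAFE_NON_TORCH_FUNCTIONS = ("einops.einops.repeat",)

def qualified_name(fn) -> str:
    # Mirrors how the diff names a call_function target.
    return f"{fn.__module__}.{fn.__name__}"

# einops re-exports repeat from the einops.einops submodule, so the
# qualified name matches the allowlist and the node no longer bypasses
# the cache.
assert qualified_name(einops.repeat) == "einops.einops.repeat"
assert qualified_name(einops.repeat) in SAFE_NON_TORCH_FUNCTIONS
```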