Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[AOTAutogradCache] Allow torch.Tensor and a non-torch op from einops (#152369)
This addresses part of #150706. Specifically, it reduces the warm-start `torch.compile` overhead by 40-50% for GGUF models on:

1. HuggingFace diffusers: [tlparse before, 224s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpqgbdva/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs. [tlparse after, 126s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp950PFy/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)
2. ComfyUI: [tlparse before, 93s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp7SeJb4/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs. [tlparse after, 51s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpRwGNqA/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)

The improvement should generalize to all other GGUF models on these platforms, because the cache miss was induced by framework code that every GGUF model hits.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152369
Approved by: https://github.com/jamesjwu
parent ce2cf31623
commit e4994e2f73
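For context on the check being relaxed: the cache safety pass identifies non-torch callables by their fully qualified `module.name` string and compares that against a small allowlist, as the diff below shows. A minimal standalone sketch of that lookup (illustrative only, not the PyTorch source; `qualified_name` and `is_safe_non_torch_function` are made-up helper names):

```python
import math

# Mirrors the allowlist added in this PR: the only non-torch callable
# admitted is einops.einops.repeat, identified by its qualified name.
SAFE_NON_TORCH_FUNCTIONS = ("einops.einops.repeat",)


def qualified_name(fn) -> str:
    # Same construction as in the diff: f"{target.__module__}.{target.__name__}"
    return f"{fn.__module__}.{fn.__name__}"


def is_safe_non_torch_function(fn) -> bool:
    return qualified_name(fn) in SAFE_NON_TORCH_FUNCTIONS


print(qualified_name(math.sqrt))              # "math.sqrt"
print(is_safe_non_torch_function(math.sqrt))  # False: not on the allowlist
```

Because the check is purely name-based, a graph that calls `einops.einops.repeat` no longer has to bypass AOTAutogradCache, which is what was forcing the warm-start recompiles described above.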
```diff
@@ -130,12 +130,14 @@ def check_node_safe(node: Node):
     SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional")
     SAFE_TORCH_FUNCTIONS = (
         "torch.Size",
+        "torch.Tensor",
         "torch.sym_int",
         "torch._sym_sqrt",
         "torch.sym_float",
         "torch.sym_sum",
         "einops.einops.rearrange",
     )
+    SAFE_NON_TORCH_FUNCTIONS = ("einops.einops.repeat",)
 
     def is_public_torch_api(target):
         # Don't blindly allow private functions in the torch namespace
@@ -163,7 +165,7 @@ def check_node_safe(node: Node):
             or function_name in torch._inductor.config.unsafe_marked_cacheable_functions
         )
 
-    def is_torch_function(target):
+    def is_cacheable_function(target):
         if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
             return True
         if is_public_torch_api(target):
@@ -177,6 +179,9 @@ def check_node_safe(node: Node):
             return True
         if is_safe_torch_function(target):
             return True
+        function_name = f"{target.__module__}.{target.__name__}"
+        if function_name in SAFE_NON_TORCH_FUNCTIONS:
+            return True
         return False
 
     def is_tensor(target):
@@ -185,9 +190,7 @@ def check_node_safe(node: Node):
 
     # I'd love to use a match statement here, but it wasn't introduced until py3.10
     if node.op == "call_function":
-        # We support only torch.* functions for now
-        # We can probably add an allowlist of safe non-torch implementations as well
-        if not is_torch_function(node.target):
+        if not is_cacheable_function(node.target):
             module = getattr(node.target, "__module__", None)
             name = getattr(node.target, "__name__", None)
             raise BypassAOTAutogradCache(
```
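The context line in the second hunk also shows an existing escape hatch that sits alongside these allowlists: a function whose qualified name appears in `torch._inductor.config.unsafe_marked_cacheable_functions` is likewise treated as cacheable. A hedged sketch of opting a user function in (the qualified-name format is assumed from the membership check in the diff, and `my_package.my_module.my_helper` is a hypothetical function; check your PyTorch version for the exact container type the config expects):

```python
import torch._inductor.config as inductor_config

# Assumption: entries are fully qualified names, matching the
# f"{target.__module__}.{target.__name__}" lookup in the diff above.
# "my_package.my_module.my_helper" is a placeholder, not a real function.
inductor_config.unsafe_marked_cacheable_functions = [
    "my_package.my_module.my_helper",
]
```

As the "unsafe" prefix suggests, this skips the safety analysis for the listed functions, so it is only appropriate when you are sure the function's behavior is fully captured by the cache key.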