[AOTAutogradCache] Allow torch.Tensor and a non-torch op from einops (#152369)
This addresses part of #150706. Specifically, it reduces the warm-start `torch.compile` overhead by 40-50% for GGUF models on:

1. HuggingFace diffusers: [tlparse before, 224s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpqgbdva/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs. [tlparse after, 126s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp950PFy/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)
2. ComfyUI: [tlparse before, 93s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmp7SeJb4/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000) vs. [tlparse after, 51s](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpRwGNqA/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000)

The improvements should generalize to all other GGUF models on these platforms, because the cache miss was induced by framework code that every GGUF model hits.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152369
Approved by: https://github.com/jamesjwu
parent: ce2cf31623
commit: e4994e2f73
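For context on what the allowlist change below enables: the bypass came from `call_function` nodes targeting `torch.Tensor` and `einops.einops.repeat` in graphs produced by GGUF loading code, which `check_node_safe` previously rejected. A minimal sketch of the kind of graph affected; it is not from the PR, assumes `einops` is installed, and the `dequant_stub` helper is purely illustrative:

```python
# Illustrative only (not from the PR): a toy stand-in for framework code that
# expands weights through einops. When torch.compile keeps the einops call as
# a call_function node targeting einops.einops.repeat, check_node_safe()
# previously raised BypassAOTAutogradCache for the whole graph; with this
# change the node is allowlisted, so warm starts can hit AOTAutogradCache.
import torch
from einops import repeat  # resolves to einops.einops.repeat


def dequant_stub(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: tile the columns k times, then scale.
    return repeat(w, "r c -> r (k c)", k=2) * 0.5


compiled = torch.compile(dequant_stub)
print(compiled(torch.randn(4, 8)).shape)  # torch.Size([4, 16])
```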
```diff
@@ -130,12 +130,14 @@ def check_node_safe(node: Node):
 SAFE_TORCH_MODULES = ("torch.functional", "torch.nn.functional")
 SAFE_TORCH_FUNCTIONS = (
     "torch.Size",
+    "torch.Tensor",
     "torch.sym_int",
     "torch._sym_sqrt",
     "torch.sym_float",
     "torch.sym_sum",
     "einops.einops.rearrange",
 )
+SAFE_NON_TORCH_FUNCTIONS = ("einops.einops.repeat",)
 
 def is_public_torch_api(target):
     # Don't blindly allow private functions in the torch namespace
@@ -163,7 +165,7 @@ def check_node_safe(node: Node):
         or function_name in torch._inductor.config.unsafe_marked_cacheable_functions
     )
 
-def is_torch_function(target):
+def is_cacheable_function(target):
     if isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
         return True
     if is_public_torch_api(target):
@@ -177,6 +179,9 @@ def check_node_safe(node: Node):
         return True
     if is_safe_torch_function(target):
         return True
+    function_name = f"{target.__module__}.{target.__name__}"
+    if function_name in SAFE_NON_TORCH_FUNCTIONS:
+        return True
     return False
 
 def is_tensor(target):
@@ -185,9 +190,7 @@ def check_node_safe(node: Node):
     # I'd love to use a match statement here, but it wasn't introduced until py3.10
     if node.op == "call_function":
-        # We support only torch.* functions for now
-        # We can probably add an allowlist of safe non-torch implementations as well
-        if not is_torch_function(node.target):
+        if not is_cacheable_function(node.target):
             module = getattr(node.target, "__module__", None)
             name = getattr(node.target, "__name__", None)
             raise BypassAOTAutogradCache(
```
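Besides the two hardcoded tuples, the lookup above also consults `torch._inductor.config.unsafe_marked_cacheable_functions`, so users can extend the allowlist without patching PyTorch. A hedged sketch of that escape hatch follows; the list-of-strings shape is an assumption inferred from the `function_name in ...` membership test in the diff, and the function name used is purely hypothetical:

```python
import torch

# Hedged sketch (not from the PR): extend the cacheable-function allowlist at
# runtime. The diff checks
#     function_name in torch._inductor.config.unsafe_marked_cacheable_functions
# where function_name is f"{target.__module__}.{target.__name__}", so fully
# qualified names are assumed here. "my_pkg.gguf.dequantize_blocks" is a
# hypothetical helper that would otherwise make check_node_safe() raise
# BypassAOTAutogradCache when it appears as a call_function node.
torch._inductor.config.unsafe_marked_cacheable_functions = [
    "my_pkg.gguf.dequantize_blocks",
]
```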