From 6bcda3a21a68eb774974472571babcdbcc4fa54f Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Thu, 12 Dec 2024 19:38:17 -0800 Subject: [PATCH] dynamo tracing perf: cache on import_source: 52.9 -> 52.58 (#143058) See #143056 for overall docs. This PR: add cache to `InstructionTranslatorBase.import_source()` Pull Request resolved: https://github.com/pytorch/pytorch/pull/143058 Approved by: https://github.com/jansel ghstack dependencies: #143066, #143056 --- torch/_dynamo/symbolic_convert.py | 4 +++ torch/utils/_functools.py | 44 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 torch/utils/_functools.py diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 9b9a2e466dc..76cd73f03d0 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -26,6 +26,7 @@ import torch import torch._logging from torch._dynamo.exc import TensorifyScalarRestartAnalysis from torch._guards import tracing, TracingContext +from torch.utils._functools import cache_method from . import config, exc, logging as torchdynamo_logging, trace_rules, variables from .bytecode_analysis import ( @@ -1198,6 +1199,9 @@ class InstructionTranslatorBase( unimplemented("Storing handles in globals - NYI") self.output.side_effects.store_global(variable, name, value) + # Cache note: This cache only exists for the duration of this + # InstructionTranslator - so it should be safe to do. + @cache_method def import_source(self, module_name): """Create an alias to a module for use in guards""" if "torch_package" in module_name: diff --git a/torch/utils/_functools.py b/torch/utils/_functools.py new file mode 100644 index 00000000000..40ffd8f80a9 --- /dev/null +++ b/torch/utils/_functools.py @@ -0,0 +1,44 @@ +import functools +from typing import Callable, TypeVar +from typing_extensions import Concatenate, ParamSpec + + +_P = ParamSpec("_P") +_T = TypeVar("_T") +_C = TypeVar("_C") + +# Sentinel used to indicate that cache lookup failed. +_cache_sentinel = object() + + +def cache_method( + f: Callable[Concatenate[_C, _P], _T] +) -> Callable[Concatenate[_C, _P], _T]: + """ + Like `@functools.cache` but for methods. + + `@functools.cache` (and similarly `@functools.lru_cache`) shouldn't be used + on methods because it caches `self`, keeping it alive + forever. `@cache_method` ignores `self` so won't keep `self` alive (assuming + no cycles with `self` in the parameters). + + Footgun warning: This decorator completely ignores self's properties so only + use it when you know that self is frozen or won't change in a meaningful + way (such as the wrapped function being pure). + """ + cache_name = "_cache_method_" + f.__name__ + + @functools.wraps(f) + def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T: + assert not kwargs + if not (cache := getattr(self, cache_name, None)): + cache = {} + setattr(self, cache_name, cache) + cached_value = cache.get(args, _cache_sentinel) + if cached_value is not _cache_sentinel: + return cached_value + value = f(self, *args, **kwargs) + cache[args] = value + return value + + return wrap