pytorch/torch/_dynamo/funcname_cache.py
William Wen 71beca4899 [dynamo, logging] Report name of defining class alongside function name in Dynamo logs (#110190)
Implement https://github.com/pytorch/pytorch/issues/109236

Sample code:
```python
import torch

class AAA:
    class DUMMY:
        class DUMMY2:
            pass
    def dummy(self):
        def dummy2():
            pass
    class BBB:
        @staticmethod
        def CCC():
            class DDD:
                if True:
                    @staticmethod
                    def EEE():
                        x = [torch.ones(3, 3) for _ in range(5)]
                        return x
            return DDD

def fn():
    return AAA.BBB.CCC().EEE()

opt_fn = torch.compile(fn, backend="eager")

opt_fn()
```

Logs:
```bash
$ TORCH_LOGS="trace_source" python playground2.py
[2023-09-27 17:38:35,641] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:21 in fn (fn)
[2023-09-27 17:38:35,641] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]     def fn():
[2023-09-27 17:38:35,642] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:22 in fn (fn)
[2023-09-27 17:38:35,642] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         return AAA.BBB.CCC().EEE()
[2023-09-27 17:38:35,661] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:11 in CCC (AAA.BBB) (inline depth: 1)
[2023-09-27 17:38:35,661] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             @staticmethod
[2023-09-27 17:38:35,661] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:13 in CCC (AAA.BBB.CCC.DDD) (inline depth: 1)
[2023-09-27 17:38:35,661] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]                 class DDD:
[2023-09-27 17:38:35,723] [1/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /data/users/williamwen/pytorch/playground2.py:17 in <listcomp> (AAA.BBB.CCC.DDD.EEE)
[2023-09-27 17:38:35,723] [1/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]                             x = [torch.ones(3, 3) for _ in range(5)]
```
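For comparison, Python's built-in `__qualname__` attribute (PEP 3155) also records the chain of enclosing classes, but it marks anything defined inside a function body with a `<locals>` segment. Below is a minimal sketch that uses the same nesting as the sample code, with the `torch` call and the unused `DUMMY`/`dummy` members dropped:

```python
# Same nesting as the sample code above, minus the torch call and the
# unused DUMMY/dummy members.
class AAA:
    class BBB:
        @staticmethod
        def CCC():
            class DDD:
                @staticmethod
                def EEE():
                    return None

            return DDD


# Accessing a staticmethod through its class yields the plain function.
EEE = AAA.BBB.CCC().EEE

# DDD is created inside the body of CCC, so Python's qualified name
# records that step as "<locals>".
print(EEE.__qualname__)  # AAA.BBB.CCC.<locals>.DDD.EEE
```

The names in the log above, such as `AAA.BBB.CCC.DDD.EEE`, are built by `funcname_cache.py` from the source text instead. That is why they contain no `<locals>` segment.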

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110190
Approved by: https://github.com/ezyang, https://github.com/mlazos
2023-10-05 20:41:38 +00:00


import tokenize

cache = {}


def clearcache():
    cache.clear()


def _add_file(filename):
    try:
        with open(filename) as f:
            tokens = list(tokenize.generate_tokens(f.readline))
    except OSError:
        cache[filename] = {}
        return

    # NOTE: undefined behavior if file is not valid Python source,
    # since tokenize will have undefined behavior.
    result = {}
    # current full funcname, e.g. xxx.yyy.zzz
    cur_name = ""
    cur_indent = 0
    significant_indents = []

    for i, token in enumerate(tokens):
        if token.type == tokenize.INDENT:
            cur_indent += 1
        elif token.type == tokenize.DEDENT:
            cur_indent -= 1
            # possible end of function or class
            if significant_indents and cur_indent == significant_indents[-1]:
                significant_indents.pop()
                # pop the last name
                cur_name = cur_name.rpartition(".")[0]
        elif (
            token.type == tokenize.NAME
            and i + 1 < len(tokens)
            and tokens[i + 1].type == tokenize.NAME
            and (token.string == "class" or token.string == "def")
        ):
            # name of class/function always follows class/def token
            significant_indents.append(cur_indent)
            if cur_name:
                cur_name += "."
            cur_name += tokens[i + 1].string
        result[token.start[0]] = cur_name

    cache[filename] = result


def get_funcname(filename, lineno):
    if filename not in cache:
        _add_file(filename)
    return cache[filename].get(lineno, None)
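To make the line-to-name mapping concrete, here is a minimal sketch that applies the same token walk as `_add_file` to a source string instead of a file. The helper name `funcnames_from_source` is hypothetical and not part of `torch._dynamo.funcname_cache`. The sample source is a trimmed version of the `AAA`/`BBB`/`CCC` example from the commit message.

```python
import io
import tokenize


def funcnames_from_source(source):
    """Map 1-based line numbers to dotted class/function names.

    Hypothetical helper: mirrors the logic of _add_file, but reads from a
    string and returns the dict instead of storing it in a cache.
    """
    tokens = list(tokenize.generate_tokens(io.StringIO(source).readline))
    result = {}
    cur_name = ""
    cur_indent = 0
    significant_indents = []  # indent level at which each open class/def began
    for i, token in enumerate(tokens):
        if token.type == tokenize.INDENT:
            cur_indent += 1
        elif token.type == tokenize.DEDENT:
            cur_indent -= 1
            # Back at the level where a class/def opened, so its body has ended.
            if significant_indents and cur_indent == significant_indents[-1]:
                significant_indents.pop()
                cur_name = cur_name.rpartition(".")[0]
        elif (
            token.type == tokenize.NAME
            and token.string in ("class", "def")
            and i + 1 < len(tokens)
            and tokens[i + 1].type == tokenize.NAME
        ):
            significant_indents.append(cur_indent)
            new = tokens[i + 1].string
            cur_name = f"{cur_name}.{new}" if cur_name else new
        # Later tokens on the same line overwrite earlier ones, so a
        # class/def line maps to the newly opened name.
        result[token.start[0]] = cur_name
    return result


SOURCE = """\
class AAA:
    class BBB:
        @staticmethod
        def CCC():
            return 1

def fn():
    return AAA.BBB.CCC()

x = fn()
"""

names = funcnames_from_source(SOURCE)
for lineno in (1, 3, 4, 5, 8, 10):
    print(lineno, repr(names[lineno]))
```

The decorator on line 3 maps to the enclosing class `AAA.BBB` rather than to `CCC`, because the `def` token has not been reached yet. This matches the `@staticmethod ... in CCC (AAA.BBB)` entry in the log above. Module-level lines such as line 10 map to the empty string rather than being absent.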