replace usages of upload_graph in inductor with tlparse (v2) (#148720)

Reland of https://github.com/pytorch/pytorch/pull/148703

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148720
Approved by: https://github.com/mengluy0125
Authored by Brian Hirsh on 2025-03-10 10:35:52 -07:00; committed by PyTorch MergeBot
parent 5bbca7d328
commit 492f3fd5cf
6 changed files with 135 additions and 22 deletions
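
Every hunk below makes the same swap: the internal optimus_scuba_log[...] = upload_graph(...) logging is removed, and each pass instead emits its readable FX graph as a structured-trace artifact that tlparse can render. The following is a minimal, self-contained sketch of that pattern; the log_graph_artifact helper and the toy graph are illustrative only, not part of this commit.

    import torch
    from torch._logging import trace_structured


    def log_graph_artifact(gm: torch.fx.GraphModule, name: str) -> None:
        # Cheap when structured tracing is disabled; when enabled
        # (e.g. TORCH_TRACE=/tmp/trace) the readable graph shows up in the
        # tlparse report under the given artifact name.
        trace_structured(
            "artifact",
            metadata_fn=lambda: {"name": name, "encoding": "string"},
            payload_fn=lambda: gm.print_readable(
                print_output=False, include_stride=True, include_device=True
            ),
        )


    def f(x):
        return x.sin() + 1


    # Toy stand-in for the pre-grad/post-grad passes that now log this way.
    gm = torch.fx.symbolic_trace(f)
    log_graph_artifact(gm, "before_recompile_pre_grad")
    # ... graph transformations would mutate gm.graph here ...
    gm.recompile()
    log_graph_artifact(gm, "after_recompile_pre_grad")

This is what produces the new before_recompile_* / after_recompile_* artifact entries in the StructuredTraceTest expectations below: run a compile under TORCH_TRACE=<dir> and feed the resulting log to tlparse to see the payloads.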


@@ -246,10 +246,14 @@ class StructuredTraceTest(TestCase):
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -274,10 +278,14 @@ class StructuredTraceTest(TestCase):
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -311,10 +319,14 @@ class StructuredTraceTest(TestCase):
{"describe_source": {"describer_id": "ID", "id": 1, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -329,10 +341,14 @@ class StructuredTraceTest(TestCase):
{"create_symbol": {"symbol": "s0", "val": "1", "vr": "[-int_oo, int_oo]", "source": "L['y']", "user_stack": "STACK", "stack": "STACK"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"}
@@ -357,10 +373,14 @@ class StructuredTraceTest(TestCase):
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -410,12 +430,16 @@ class StructuredTraceTest(TestCase):
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack0']"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1}
{"dynamo_output_graph": {"sizes": {"l_stack0_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "sum_1": []}}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"aot_joint_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"aot_forward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"aot_backward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
@@ -425,6 +449,8 @@ class StructuredTraceTest(TestCase):
{"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}
{"compilation_metrics": "METRICS", "frame_id": 3, "frame_compile_id": 0, "attempt": 0}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"}
@@ -489,12 +515,16 @@ class StructuredTraceTest(TestCase):
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_backward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
@@ -620,12 +650,16 @@ class StructuredTraceTest(TestCase):
{"describe_tensor": {"id": 2, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 2, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"inductor_pre_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_joint_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_forward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_backward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -637,12 +671,16 @@ class StructuredTraceTest(TestCase):
{"describe_tensor": {"id": 30, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}
{"inductor_pre_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_joint_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_forward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_backward_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -676,10 +714,14 @@ class StructuredTraceTest(TestCase):
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -833,10 +875,14 @@ def forward(self, x, y):
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_post_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
@@ -849,6 +895,8 @@ def forward(self, x, y):
{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
{"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_pre_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "before_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "after_recompile_pre_grad", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}
{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"}


@@ -6,7 +6,7 @@ import unittest
import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu
@@ -347,7 +347,6 @@ class TestGroupBatchFusion(TestCase):
counters["inductor"]["group_linear"],
2,
)
self.assertNotIn("group_batch_fusion_pre_grad", optimus_scuba_log)
ref.sum().backward()
res.sum().backward()
self.compare_parameters(module, traced)
@@ -360,7 +359,6 @@ class TestGroupBatchFusion(TestCase):
counters["inductor"]["batch_aten_add"],
0,
)
self.assertIn("GroupLinearFusion", optimus_scuba_log)
counters.clear()
@unittest.skipIf(not has_fbgemm, "requires fbgemm")
@@ -603,7 +601,6 @@ class TestPostGradBatchLinearFusion(TestCase):
counters["inductor"]["batch_linear_post_grad"],
2,
)
self.assertIn("PostGradBatchLinearFusion", optimus_scuba_log)
class TestFindIndependentSubsetGreedy(TestCase):


@@ -2,7 +2,7 @@
import torch
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._dynamo.utils import counters
from torch._inductor.fx_passes.misc_patterns import numpy_compat_normalization
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
@@ -300,8 +300,6 @@ class TestSplitCatFxPasses(TestCase):
counters["inductor"]["merge_splits_pass"],
expected_split_merged,
)
if expected_split_merged > 0:
self.assertIn("merge_splits_pass_pre_grad", optimus_scuba_log)
counters.clear()
@patch


@@ -7,8 +7,8 @@ from collections.abc import Iterable, Iterator
from typing import Any, Optional
import torch
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._utils_internal import upload_graph
from torch._dynamo.utils import counters
from torch._logging import trace_structured
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.utils._ordered_set import OrderedSet
@@ -1355,7 +1355,27 @@ def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusion
)
log_to_scuba = True
if log_to_scuba:
optimus_scuba_log[rule.__class__.__name__] = upload_graph(graph)
from torch.fx._lazy_graph_module import _LazyGraphModule
# Force the graph to re-compile; otherwise the output Python code may be broken
gm = graph._owning_module
if isinstance(gm, _LazyGraphModule):
_LazyGraphModule.recompile()
else:
assert isinstance(gm, torch.fx.GraphModule)
gm.recompile()
graph_str = gm.print_readable(
print_output=False, include_stride=True, include_device=True
)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": f"optimus_{str(rule.__class__.__name__)}",
"encoding": "string",
},
payload_fn=lambda: graph_str,
)
def generate_fusion_from_config(config_options: dict[str, Any], pre_grad=True):


@@ -13,11 +13,11 @@ import torch._inductor as inductor
import torch.utils._pytree as pytree
from torch import fx
from torch._decomp import register_decomposition
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._dynamo.utils import counters
from torch._inductor import comms
from torch._inductor.virtualized import ops
from torch._logging import trace_structured
from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
from torch._utils_internal import upload_graph
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
from torch.utils._ordered_set import OrderedSet
@@ -115,7 +115,16 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
if config.pattern_matcher:
lazy_init()
optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "before_recompile_post_grad",
"encoding": "string",
},
payload_fn=lambda: gm.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
functools.partial(group_batch_fusion_passes, pre_grad=False)
)
@@ -139,8 +148,15 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
pattern_matcher_pass.apply
)
if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_post_grad"] = (
upload_graph(gm.graph)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": f"{pattern_matcher_pass.pass_name}_post_grad",
"encoding": "string",
},
payload_fn=lambda: gm.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
if config.b2b_gemm_pass:
B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type]
@@ -186,7 +202,16 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
)
gm.recompile()
optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "after_recompile_post_grad",
"encoding": "string",
},
payload_fn=lambda: gm.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
gm.graph.lint()


@@ -8,8 +8,8 @@ from typing import Optional
import torch
import torch.nn as nn
from torch._dynamo.utils import counters, detect_fake_mode, optimus_scuba_log
from torch._utils_internal import upload_graph
from torch._dynamo.utils import counters, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.experimental.optimization import (
matches_module_pattern,
replace_node_module,
@@ -258,7 +258,16 @@ def pre_grad_passes(
if example_inputs is not None:
gm = fuse_fx(gm, example_inputs)
numpy_compat_normalization(gm.graph)
optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "before_recompile_pre_grad",
"encoding": "string",
},
payload_fn=lambda: gm.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
# We should always do the normalization_pass first
if "normalization_pass" in config.pre_grad_fusion_options:
pattern_matcher_pass = PRE_GRAD_PATTERNS["normalization_pass"]
@@ -277,8 +286,15 @@ def pre_grad_passes(
for _ in range(counter):
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_pre_grad"] = (
upload_graph(gm.graph)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": f"{pattern_matcher_pass.pass_name}_pre_grad",
"encoding": "string",
},
payload_fn=lambda: gm.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
# TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]
@@ -294,7 +310,16 @@ def pre_grad_passes(
gm.graph.lint()
gm.recompile()
optimus_scuba_log["after_recompile_pre_grad"] = upload_graph(gm.graph)
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "after_recompile_pre_grad",
"encoding": "string",
},
payload_fn=lambda: gm.print_readable(
print_output=False, include_stride=True, include_device=True
),
)
if (
config.pattern_matcher