From 72a8329ec945aeda366353e171b7c110e30f7736 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Sun, 19 Nov 2023 07:12:24 +0000 Subject: [PATCH] [reland][aotinductor] Add example_value metadata to nodes (#113986) Test Plan: `TORCH_LOGS=dynamo,inductor,aot CUDA_VISIBLE_DEVICES=7 TORCH_COMPILE_DEBUG=0 TORCHINDUCTOR_MAX_AUTOTUNE=1 buck2 run mode/opt-split-dwarf mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark -- --local-model /tmp/409501788/66/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend="AOT_INDUCTOR"` Without passes: `BS: 2048, MFLOPS/BS: 40.51, TFLOP/s: 37.32, Time per iter: 2.22ms, Threads: 1, QPS: 921146.83, Accuracy: True (rtol=0.01), AOT_INDUCTOR lowering duration: 66.15s` With passes: `BS: 2048, MFLOPS/BS: 40.51, TFLOP/s: 37.49, Time per iter: 2.21ms, Threads: 1, QPS: 925450.82, Accuracy: True (rtol=0.01), AOT_INDUCTOR lowering duration: 261.11s` Differential Revision: D51436878 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113986 Approved by: https://github.com/zhxchen17 --- test/inductor/test_aot_inductor.py | 16 ++++++++++++++++ torch/_dynamo/eval_frame.py | 6 ++++++ 2 files changed, 22 insertions(+) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 0fc73c9cf61..b1dc7431816 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -11,6 +11,7 @@ import torch._export import torch._inductor import torch.fx._pytree as fx_pytree from torch._dynamo.testing import same +from torch._dynamo.utils import counters from torch._inductor import config from torch._inductor.exc import CppWrapperCodeGenError from torch._inductor.utils import aot_inductor_launcher, cache_dir @@ -318,6 +319,21 @@ class AOTInductorTestsTemplate: with config.patch({"freezing": True}): self.check_model(Model(self.device), example_inputs) + def test_simple_split(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2) + + example_inputs = (torch.randn(2, 8, device=self.device),) + counters.clear() + self.check_model(Model(), example_inputs) + self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1) + self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1) + self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1) + def test_missing_output(self): class Model(torch.nn.Module): def __init__(self): diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index fe115e438bb..3636fc35816 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -912,6 +912,8 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer): arg.node.meta["val"] = self.current_node.meta["val"] if "tensor_dict" in self.current_node.meta: arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"] + if "example_value" in self.current_node.meta: + arg.node.meta["example_value"] = self.current_node.meta["example_value"] return arg def output(self, target, args, kwargs): @@ -925,6 +927,10 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer): result_proxy = super().run_node(n) if "val" in self.current_node.meta: result_proxy.node.meta["val"] = self.current_node.meta["val"] + if "example_value" in self.current_node.meta: + result_proxy.node.meta["example_value"] = self.current_node.meta[ + "example_value" + ] if self.current_node.op != "output": result_proxy.node._rename( getattr(self.current_node, "name", result_proxy.node.name)