[reland][aotinductor] Add example_value metadata to nodes (#113986)

Test Plan:
`TORCH_LOGS=dynamo,inductor,aot  CUDA_VISIBLE_DEVICES=7 TORCH_COMPILE_DEBUG=0 TORCHINDUCTOR_MAX_AUTOTUNE=1 buck2 run mode/opt-split-dwarf mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010  caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark -- --local-model /tmp/409501788/66/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend="AOT_INDUCTOR"`

Without passes:
`BS: 2048, MFLOPS/BS: 40.51, TFLOP/s: 37.32, Time per iter: 2.22ms, Threads: 1, QPS: 921146.83, Accuracy: True (rtol=0.01), AOT_INDUCTOR lowering duration: 66.15s`

With passes:
`BS: 2048, MFLOPS/BS: 40.51, TFLOP/s: 37.49, Time per iter: 2.21ms, Threads: 1, QPS: 925450.82, Accuracy: True (rtol=0.01), AOT_INDUCTOR lowering duration: 261.11s`

Differential Revision: D51436878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113986
Approved by: https://github.com/zhxchen17
This commit is contained in:
Angela Yi 2023-11-19 07:12:24 +00:00 committed by PyTorch MergeBot
parent 33c6cae13b
commit 72a8329ec9
2 changed files with 22 additions and 0 deletions

View File

@ -11,6 +11,7 @@ import torch._export
import torch._inductor
import torch.fx._pytree as fx_pytree
from torch._dynamo.testing import same
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.exc import CppWrapperCodeGenError
from torch._inductor.utils import aot_inductor_launcher, cache_dir
@ -318,6 +319,21 @@ class AOTInductorTestsTemplate:
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)
def test_simple_split(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)
example_inputs = (torch.randn(2, 8, device=self.device),)
counters.clear()
self.check_model(Model(), example_inputs)
self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)
def test_missing_output(self):
class Model(torch.nn.Module):
def __init__(self):

View File

@ -912,6 +912,8 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
arg.node.meta["val"] = self.current_node.meta["val"]
if "tensor_dict" in self.current_node.meta:
arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
if "example_value" in self.current_node.meta:
arg.node.meta["example_value"] = self.current_node.meta["example_value"]
return arg
def output(self, target, args, kwargs):
@ -925,6 +927,10 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
result_proxy = super().run_node(n)
if "val" in self.current_node.meta:
result_proxy.node.meta["val"] = self.current_node.meta["val"]
if "example_value" in self.current_node.meta:
result_proxy.node.meta["example_value"] = self.current_node.meta[
"example_value"
]
if self.current_node.op != "output":
result_proxy.node._rename(
getattr(self.current_node, "name", result_proxy.node.name)