mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[reland][aotinductor] Add example_value metadata to nodes (#113986)
Test Plan: `TORCH_LOGS=dynamo,inductor,aot CUDA_VISIBLE_DEVICES=7 TORCH_COMPILE_DEBUG=0 TORCHINDUCTOR_MAX_AUTOTUNE=1 buck2 run mode/opt-split-dwarf mode/inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark -- --local-model /tmp/409501788/66/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend="AOT_INDUCTOR"` Without passes: `BS: 2048, MFLOPS/BS: 40.51, TFLOP/s: 37.32, Time per iter: 2.22ms, Threads: 1, QPS: 921146.83, Accuracy: True (rtol=0.01), AOT_INDUCTOR lowering duration: 66.15s` With passes: `BS: 2048, MFLOPS/BS: 40.51, TFLOP/s: 37.49, Time per iter: 2.21ms, Threads: 1, QPS: 925450.82, Accuracy: True (rtol=0.01), AOT_INDUCTOR lowering duration: 261.11s` Differential Revision: D51436878 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113986 Approved by: https://github.com/zhxchen17
This commit is contained in:
parent
33c6cae13b
commit
72a8329ec9
|
|
@ -11,6 +11,7 @@ import torch._export
|
|||
import torch._inductor
|
||||
import torch.fx._pytree as fx_pytree
|
||||
from torch._dynamo.testing import same
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor import config
|
||||
from torch._inductor.exc import CppWrapperCodeGenError
|
||||
from torch._inductor.utils import aot_inductor_launcher, cache_dir
|
||||
|
|
@ -318,6 +319,21 @@ class AOTInductorTestsTemplate:
|
|||
with config.patch({"freezing": True}):
|
||||
self.check_model(Model(self.device), example_inputs)
|
||||
|
||||
def test_simple_split(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cat(tensors=torch.split(x, 4, dim=1), dim=-2)
|
||||
|
||||
example_inputs = (torch.randn(2, 8, device=self.device),)
|
||||
counters.clear()
|
||||
self.check_model(Model(), example_inputs)
|
||||
self.assertEqual(counters["inductor"]["scmerge_split_removed"], 1)
|
||||
self.assertEqual(counters["inductor"]["scmerge_cat_removed"], 1)
|
||||
self.assertEqual(counters["inductor"]["scmerge_split_sections_removed"], 1)
|
||||
|
||||
def test_missing_output(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -912,6 +912,8 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
|
|||
arg.node.meta["val"] = self.current_node.meta["val"]
|
||||
if "tensor_dict" in self.current_node.meta:
|
||||
arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
|
||||
if "example_value" in self.current_node.meta:
|
||||
arg.node.meta["example_value"] = self.current_node.meta["example_value"]
|
||||
return arg
|
||||
|
||||
def output(self, target, args, kwargs):
|
||||
|
|
@ -925,6 +927,10 @@ class FlattenInputOutputSignature(torch.fx.interpreter.Transformer):
|
|||
result_proxy = super().run_node(n)
|
||||
if "val" in self.current_node.meta:
|
||||
result_proxy.node.meta["val"] = self.current_node.meta["val"]
|
||||
if "example_value" in self.current_node.meta:
|
||||
result_proxy.node.meta["example_value"] = self.current_node.meta[
|
||||
"example_value"
|
||||
]
|
||||
if self.current_node.op != "output":
|
||||
result_proxy.node._rename(
|
||||
getattr(self.current_node, "name", result_proxy.node.name)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user