[inductor] update numbytes_hint for NoneLayout to allow more fusions (#141766)

We found that [this commit](6eca0aee76) caused a ~6% performance drop in ViT INT8, due to its changes to the `numbytes_hint` of `NoneLayout`. In this PR, we revert those `numbytes_hint` changes so that more fusions are allowed again. Take the following add + layernorm pattern as an example:

```python
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(768, 768)
        self.layernorm = torch.nn.LayerNorm(768, eps=1e-12)

    def forward(self, context_layer, hidden_states):
        attention_output = self.dense(context_layer)
        hidden_states = attention_output + hidden_states
        layer_output = self.layernorm(hidden_states)
        return layer_output
```
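
This is the dense + residual add + layernorm pattern found in ViT blocks. As a rough reference, the sketch below shows how such a model is quantized to INT8 and compiled. It assumes the PT2E flow with the X86 Inductor quantizer (`export_for_training`, `prepare_pt2e`/`convert_pt2e`, `X86InductorQuantizer` are the public APIs); the overall recipe is an approximation of what the `_generate_qdq_quantized_model` test helper wraps, not code from this PR, and the exact export APIs may differ across PyTorch versions:

```python
import torch
from torch._inductor import config as inductor_config
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)

model = Model().eval()  # the Model defined above
example_inputs = (torch.rand(1, 197, 768), torch.rand(1, 197, 768))

# Export the model and insert quant/dequant (QDQ) nodes with the X86 Inductor quantizer.
exported = torch.export.export_for_training(model, example_inputs).module()
quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)

# Compile with Inductor; freezing=True matches the new test's config and lets
# Inductor constant-fold / prepack the quantized weights.
with torch.no_grad(), inductor_config.patch(freezing=True):
    compiled = torch.compile(converted)
    compiled(*example_inputs)
```

The new test added below exercises the same flow through the `_generate_qdq_quantized_model` helper and then asserts the number of vectorized kernels generated for the compiled model.
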
The generated code before (left) and after (right) this PR is as follows:
![image](https://github.com/user-attachments/assets/0ec65ae5-103e-4e2c-bf7c-e8bed24fc179)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141766
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel
Commit b9e253cb72 (parent daa27fe59d)
Author: blzheng, 2024-12-08 19:37:46 -08:00; committed by PyTorch MergeBot
2 changed files with 37 additions and 0 deletions


```diff
@@ -4782,6 +4782,33 @@ class CPUReproTests(TestCase):
             code
         )
 
+    @config.patch(freezing=True)
+    def test_add_layernorm(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.dense = torch.nn.Linear(768, 768)
+                self.layernorm = torch.nn.LayerNorm(768, eps=1e-12)
+
+            def forward(self, context_layer, hidden_states):
+                attention_output = self.dense(context_layer)
+                hidden_states = attention_output + hidden_states
+                layer_output = self.layernorm(hidden_states)
+                return layer_output
+
+        model = Model()
+        example_batch = (torch.rand(1, 197, 768), torch.rand(1, 197, 768))
+        from torch.testing._internal.common_quantization import (
+            _generate_qdq_quantized_model,
+        )
+
+        with torch.no_grad():
+            converted_model = _generate_qdq_quantized_model(model, example_batch)
+            torch.ao.quantization.move_exported_model_to_eval(converted_model)
+            metrics.reset()
+            torch.compile(converted_model)(*example_batch)
+            check_metrics_vec_kernel_count(3)
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
```
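
For local verification, the new test can be run on its own, e.g. `python test/inductor/test_cpu_repro.py -k test_add_layernorm` (assuming the standard PyTorch source layout, where `CPUReproTests` lives in `test/inductor/test_cpu_repro.py`).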


```diff
@@ -812,6 +812,16 @@ class GraphLowering(torch.fx.Interpreter):
     def get_dtype(self, buffer_name: str) -> torch.dtype:
         if buffer_name in self.constants:
             return self.constants[buffer_name].dtype
+        # For a mutation op we should return the dtype of the buffer being mutated
+        if (
+            hasattr(self.scheduler, "mutation_real_name")
+            and buffer_name in self.scheduler.mutation_real_name
+        ):
+            mutated_buf = self.scheduler.mutation_real_name[buffer_name]
+            if mutated_buf in self.name_to_buffer:
+                return self.name_to_buffer[mutated_buf].get_dtype()
+            if mutated_buf in self.graph_inputs:
+                return self.graph_inputs[mutated_buf].get_dtype()
         if buffer_name in self.name_to_buffer:
             return self.name_to_buffer[buffer_name].get_dtype()
         if buffer_name in self.graph_inputs:
```
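
A buffer produced by a mutation op has `NoneLayout` and no storage of its own, so its dtype has to be resolved through the buffer it mutates; the scheduler records that mapping in `mutation_real_name`. Below is a minimal standalone sketch of the lookup order added above, with plain dicts and made-up buffer names standing in for Inductor's real `GraphLowering`/`Scheduler` state:

```python
import torch

# Stand-ins for GraphLowering/Scheduler state (hypothetical names and values).
constants = {}                                # name -> constant tensor
mutation_real_name = {"buf3": "buf0"}         # mutation alias -> real buffer
name_to_buffer = {"buf0": torch.float32}      # real buffer name -> dtype
graph_inputs = {"arg0_1": torch.bfloat16}     # graph input name -> dtype


def get_dtype(buffer_name: str) -> torch.dtype:
    if buffer_name in constants:
        return constants[buffer_name].dtype
    # A mutation op's buffer has no storage of its own: resolve it to the
    # buffer it mutates and return that buffer's dtype.
    if buffer_name in mutation_real_name:
        real_name = mutation_real_name[buffer_name]
        if real_name in name_to_buffer:
            return name_to_buffer[real_name]
        if real_name in graph_inputs:
            return graph_inputs[real_name]
    if buffer_name in name_to_buffer:
        return name_to_buffer[buffer_name]
    if buffer_name in graph_inputs:
        return graph_inputs[buffer_name]
    raise KeyError(f"could not infer dtype of {buffer_name}")


assert get_dtype("buf3") is torch.float32     # resolved through mutation_real_name
assert get_dtype("arg0_1") is torch.bfloat16  # plain graph input lookup
```

In the real code the dict values are buffer/input IR objects with a `get_dtype()` method; the dicts here hold dtypes directly just to keep the sketch runnable.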