[inductor] update numbytes_hint for NoneLayout to allow more fusions (#141766)
We found that [this commit](6eca0aee76) caused a ~6% performance drop in ViT INT8. The regression came from that commit's changes to the `numbytes_hint` for `NoneLayout`. In this PR, we revert the `numbytes_hint` changes to allow more fusions. Take the following model as an example:
```
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(768, 768)
        self.layernorm = torch.nn.LayerNorm(768, eps=1e-12)

    def forward(self, context_layer, hidden_states):
        attention_output = self.dense(context_layer)
        # The add feeds straight into layernorm; this PR lets Inductor
        # fuse the two again.
        hidden_states = attention_output + hidden_states
        layer_output = self.layernorm(hidden_states)
        return layer_output
```
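
For reference, the model can be pushed down the INT8 path and compiled as follows. This mirrors the test added by this PR; `_generate_qdq_quantized_model` is a PyTorch-internal test helper, and the input shapes are taken from that test:

```
from torch.testing._internal.common_quantization import (
    _generate_qdq_quantized_model,
)

model = Model()
example_batch = (torch.rand(1, 197, 768), torch.rand(1, 197, 768))

with torch.no_grad():
    # Insert quantize/dequantize pairs so the model runs through the INT8 path.
    converted_model = _generate_qdq_quantized_model(model, example_batch)
    torch.ao.quantization.move_exported_model_to_eval(converted_model)
    # Compiling through Inductor is where the add/layernorm fusion happens.
    torch.compile(converted_model)(*example_batch)
```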
A side-by-side comparison of the generated code before (left) and after (right) this PR is included in the pull request linked below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141766
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
parent daa27fe59d
commit b9e253cb72
```
@@ -4782,6 +4782,33 @@ class CPUReproTests(TestCase):
                 code
             )
 
+    @config.patch(freezing=True)
+    def test_add_layernorm(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.dense = torch.nn.Linear(768, 768)
+                self.layernorm = torch.nn.LayerNorm(768, eps=1e-12)
+
+            def forward(self, context_layer, hidden_states):
+                attention_output = self.dense(context_layer)
+                hidden_states = attention_output + hidden_states
+                layer_output = self.layernorm(hidden_states)
+                return layer_output
+
+        model = Model()
+        example_batch = (torch.rand(1, 197, 768), torch.rand(1, 197, 768))
+        from torch.testing._internal.common_quantization import (
+            _generate_qdq_quantized_model,
+        )
+
+        with torch.no_grad():
+            converted_model = _generate_qdq_quantized_model(model, example_batch)
+            torch.ao.quantization.move_exported_model_to_eval(converted_model)
+            metrics.reset()
+            torch.compile(converted_model)(*example_batch)
+            check_metrics_vec_kernel_count(3)
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
```
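
The new test pins the vectorized kernel count with `check_metrics_vec_kernel_count(3)`, which guards the fusion behavior this PR restores. For intuition on why `numbytes_hint` feeds that outcome: Inductor's scheduler weighs candidate fusions by the memory traffic they would save, so the byte count a buffer reports shifts which fusions are chosen. Below is a toy sketch of the idea, not Inductor's actual implementation; the class, function, and values are illustrative, and the direction of the reverted change is an assumption:

```
# Toy model of memory-saving-based fusion scoring (illustrative only).
from dataclasses import dataclass

@dataclass
class Buf:
    name: str
    numbytes_hint: int

def score_fusion_memory(common_bufs: list[Buf]) -> int:
    # Fusing two nodes avoids a round trip through memory for each buffer
    # they share, so the score is the total bytes of those buffers.
    return sum(b.numbytes_hint for b in common_bufs)

fp32_row = 768 * 4  # one 768-wide float32 row of the activation

# If a NoneLayout (mutation) buffer reports 0 bytes, the decision rests on
# the real data buffers; a nonzero report changes the relative scores and
# can steer the scheduler toward different fusions (assumed direction).
print(score_fusion_memory([Buf("buf0", fp32_row), Buf("mut", 0)]))         # 3072
print(score_fusion_memory([Buf("buf0", fp32_row), Buf("mut", fp32_row)]))  # 6144
```

The second hunk below adjusts `GraphLowering.get_dtype` so that dtype queries against a mutation's buffer name resolve to the buffer actually being mutated.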
```
@@ -812,6 +812,16 @@ class GraphLowering(torch.fx.Interpreter):
     def get_dtype(self, buffer_name: str) -> torch.dtype:
         if buffer_name in self.constants:
             return self.constants[buffer_name].dtype
+        # For a mutation op we should return the dtype of the buffer being mutated
+        if (
+            hasattr(self.scheduler, "mutation_real_name")
+            and buffer_name in self.scheduler.mutation_real_name
+        ):
+            mutated_buf = self.scheduler.mutation_real_name[buffer_name]
+            if mutated_buf in self.name_to_buffer:
+                return self.name_to_buffer[mutated_buf].get_dtype()
+            if mutated_buf in self.graph_inputs:
+                return self.graph_inputs[mutated_buf].get_dtype()
         if buffer_name in self.name_to_buffer:
             return self.name_to_buffer[buffer_name].get_dtype()
         if buffer_name in self.graph_inputs:
```
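
The `get_dtype` change handles buffers produced by mutation ops: such a buffer has `NoneLayout` and no dtype of its own, so the scheduler's `mutation_real_name` map is consulted to find the buffer actually being mutated. A standalone sketch of that lookup order (the dicts below are simplified stand-ins for `GraphLowering`'s real attributes):

```
import torch

# Simplified stand-ins for GraphLowering state (illustrative only).
constants: dict[str, torch.Tensor] = {}
mutation_real_name = {"buf3": "buf0"}            # mutation buffer -> real buffer
name_to_buffer_dtype = {"buf0": torch.float32}   # buffer name -> dtype
graph_input_dtype = {"arg0_1": torch.bfloat16}   # graph input -> dtype

def get_dtype(buffer_name: str) -> torch.dtype:
    if buffer_name in constants:
        return constants[buffer_name].dtype
    # A mutation op's buffer carries no dtype itself; resolve it to the
    # buffer it mutates first (the real method checks the mutation map,
    # then falls back to the plain name, as in the hunk above).
    name = mutation_real_name.get(buffer_name, buffer_name)
    if name in name_to_buffer_dtype:
        return name_to_buffer_dtype[name]
    if name in graph_input_dtype:
        return graph_input_dtype[name]
    raise KeyError(f"could not infer dtype of {buffer_name}")

print(get_dtype("buf3"))  # torch.float32, resolved via mutation_real_name
```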