[Dynamo] Allow inlining into AO quantization modules (#152934)

This adds dynamo inlining into `torch.ao.quantization.fake_quantize`.

This is needed for QAT (quantization-aware training) compatibility with an RL training model.
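
For illustration, a minimal sketch (not part of this PR; the module and names below are hypothetical) of the kind of QAT-style code path this unblocks, mirroring the new test: a module containing a `FusedMovingAvgObsFakeQuantize` can be captured with `fullgraph=True` instead of graph-breaking at the fake-quantize call.

import torch
import torch.nn as nn
from torch.ao.quantization import FusedMovingAvgObsFakeQuantize

# Hypothetical QAT-style block: a linear layer whose output runs through a
# fused fake-quantize observer, similar to what prepare_qat-style workflows insert.
class TinyQATBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.fake_quant = FusedMovingAvgObsFakeQuantize()

    def forward(self, x):
        return self.fake_quant(self.linear(x))

model = TinyQATBlock()
x = torch.randn(2, 4)

# With torch.ao.quantization.fake_quantize inlined, Dynamo traces through the
# fake-quantize forward instead of graph-breaking on it, so fullgraph capture works.
compiled = torch.compile(model, backend="eager", fullgraph=True)
torch.testing.assert_close(compiled(x), model(x))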

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152934
Approved by: https://github.com/williamwen42
Michael Lazos 2025-05-07 23:58:11 +00:00 committed by PyTorch MergeBot
parent 5bf0c3518c
commit 20e2ca3e29
2 changed files with 16 additions and 0 deletions


@@ -6193,6 +6193,21 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
        with torch.no_grad():
            model(x)

    def test_ao_fake_quantize_tracing(self):
        import torch.ao.quantization.fake_quantize

        q = torch.ao.quantization.FusedMovingAvgObsFakeQuantize()

        def fn(x):
            return q(x)

        x = torch.ones(2, 2)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        eager_res = fn(x)
        self.assertEqual(res, eager_res)

    def test_typed_dict(self):
        class LlavaImagePixelInputs(TypedDict):
            type: Literal["pixel_values"]


@@ -3298,6 +3298,7 @@ MOD_INLINELIST = [
    "torch._tensor",
    "torch.amp.autocast_mode",
    "torch.ao.nn",
    "torch.ao.quantization.fake_quantize",
    "torch.autograd.function",
    "torch.backends.cuda",
    "torch.cuda.amp.autocast_mode",