[Dynamo] Allow inlining into AO quantization modules (#152934)
This adds Dynamo inlining into `torch.ao.quantization.fake_quantize`. This is needed for QAT compatibility with an RL training model.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152934
Approved by: https://github.com/williamwen42
This commit is contained in:
parent 5bf0c3518c
commit 20e2ca3e29
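Note: as a quick illustration of what this change unblocks, here is a minimal sketch of a QAT-style module whose forward calls the fused fake-quantize observer from the PR's test. The module name, layer sizes, and input shape are hypothetical (not from the PR); the pattern of compiling with fullgraph=True and comparing against eager follows the new test below.

import torch
import torch.nn as nn

# Hypothetical QAT-style block: a linear layer followed by the fused
# fake-quantize observer that this PR makes inlinable under Dynamo.
class TinyQATBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.fake_quant = torch.ao.quantization.FusedMovingAvgObsFakeQuantize()

    def forward(self, x):
        return self.fake_quant(self.linear(x))

model = TinyQATBlock()
# Before this change, fullgraph=True would fail on the graph break
# inside the fake-quantize call; with inlining it traces through.
opt_model = torch.compile(model, backend="eager", fullgraph=True)
x = torch.randn(2, 4)
torch.testing.assert_close(opt_model(x), model(x))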
@@ -6193,6 +6193,21 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         with torch.no_grad():
             model(x)
 
+    def test_ao_fake_quantize_tracing(self):
+        import torch.ao.quantization.fake_quantize
+
+        q = torch.ao.quantization.FusedMovingAvgObsFakeQuantize()
+
+        def fn(x):
+            return q(x)
+
+        x = torch.ones(2, 2)
+        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
+        res = opt_fn(x)
+        eager_res = fn(x)
+
+        self.assertEqual(res, eager_res)
+
     def test_typed_dict(self):
         class LlavaImagePixelInputs(TypedDict):
             type: Literal["pixel_values"]
@@ -3298,6 +3298,7 @@ MOD_INLINELIST = [
     "torch._tensor",
     "torch.amp.autocast_mode",
     "torch.ao.nn",
+    "torch.ao.quantization.fake_quantize",
     "torch.autograd.function",
     "torch.backends.cuda",
     "torch.cuda.amp.autocast_mode",
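Note: MOD_INLINELIST in Dynamo's trace rules lists the modules whose functions Dynamo may inline into the traced graph instead of skipping (and graph-breaking on). A minimal sketch of how the effect can be observed, assuming the torch._dynamo.explain API and its graph_count / graph_break_count fields as in recent PyTorch 2.x releases:

import torch

# Same setup as the new test: call a fused fake-quantize observer
# from a function that Dynamo compiles.
q = torch.ao.quantization.FusedMovingAvgObsFakeQuantize()

def fn(x):
    return q(x)

# torch._dynamo.explain reports how Dynamo partitioned the function;
# with torch.ao.quantization.fake_quantize in MOD_INLINELIST, the
# observer call is inlined, so a single graph with no breaks is expected.
explanation = torch._dynamo.explain(fn)(torch.ones(2, 2))
print(explanation.graph_count, explanation.graph_break_count)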