mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[be] fix flaky test aot_export_ cond caused by free symbol lifting and automatic dynamic shape (#145330)
Fixes https://github.com/pytorch/pytorch/issues/139998#issuecomment-2605908426. It seems to be an issue caused by the interaction between dynamoed hop X automatic dynamic shape X auto_lift_free symbols. The immediate error is that the asserteExpectedInline of the graph can sometimes be different e.g. see https://hud.pytorch.org/flakytest?name=test_aot_export_with_torch_cond&suite=TestAOTExport&limit=100, where sometimes the shapes are lifted as input to the cond and sometimes they're not. The root cause of the flakyness is that the two invocations of torch.cond triggers two torch.compile on the same code object ([code](https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/cond.py#L192)), and triggers automatic dynamic shape because in test_aot_export_with_torch_cond, x has shape (3, 4) while the pre_dispatch one has shape (2, 2). Because of we auto lift free symbols for dynamic shaped input, this causes cond sometimes have the shape as arguments and sometimes not. This PR adds a simple fix by adding a _dynamo.reset before each torch.cond tests. This fixes the error by not triggering automatic dynamic shape. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145330 Approved by: https://github.com/zou3519
This commit is contained in:
parent
3c247ee8c4
commit
bdc2c2a237
|
|
@ -3960,6 +3960,10 @@ class TestMod(torch.nn.Module):
|
|||
|
||||
|
||||
class TestAOTExport(AOTTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
torch._dynamo.reset()
|
||||
|
||||
def test_aot_export_ban_dropout_mut_pre_dispatch(self):
|
||||
def fn(p, x):
|
||||
y = torch.ops.aten.dropout.default(x, 0.1, train=False)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user