pytorch/test/dynamo/test_utils.py
Shunting Zhang c0735a3dd3 [pt2-bench] fix accuracy failure for a few models (#129941)
This PR batches fixes for a few accuracy failures during training by raising tolerances. I do that only for models that I believe fail for reasons other than a real issue.

## sebotnet33ts_256

The accuracy test for this model started failing around June 05 [link](https://hud.pytorch.org/benchmark/timm_models/inductor_with_cudagraphs?dashboard=torchinductor&startTime=Sun%2C%2002%20Jun%202024%2007%3A19%3A38%20GMT&stopTime=Tue%2C%2002%20Jul%202024%2007%3A19%3A38%20GMT&granularity=day&mode=training&dtype=amp&lBranch=main&lCommit=04a0d856207d83c2031e4b9cb6825ba3e0092850&rBranch=main&rCommit=e62925930f6a62f6aeeb1fe1a661a9bd3352b53d&model=sebotnet33ts_256).

I cannot repro locally, but from the log on the dashboard:
```
RMSE (res-fp64): 0.09441, (ref-fp64): 0.02971 and shape=torch.Size([1536]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.040000
```
Raising the tolerance should fix it.
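For context on how these logged numbers decide pass/fail: the check compares the compiled result's RMSE against the eager reference's RMSE, each measured against an fp64 baseline (the "res-fp64" and "ref-fp64" fields). Below is a minimal sketch; the real check lives in `torch._dynamo.utils.same` and handles many more cases, and the `tol / 10.0` slack term is an assumption inferred from the logged numbers, not a quote of the actual code:

```python
import math


def rmse(xs, ys):
    """Root-mean-square error between two equal-length float sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs))


def rmse_check(res, ref, fp64_ref, tol, multiplier=3.0):
    # "res-fp64" / "ref-fp64" in the logs: each output compared
    # against the fp64 baseline.
    res_error = rmse(res, fp64_ref)
    ref_error = rmse(ref, fp64_ref)
    # Pass if the compiled error stays within `multiplier` times the
    # reference error, plus a small tolerance-derived slack (assumed tol / 10).
    return res_error <= multiplier * ref_error + tol / 10.0
```

Under that assumed formula, the logged sebotnet33ts_256 values give 0.09441 > 3 * 0.02971 + 0.04 / 10 ≈ 0.0931, which would explain the reported failure and why a larger `tol` makes it pass.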

## DebertaForQuestionAnswering

This model fails the accuracy test on the dashboard only in max-autotune mode. I cannot repro locally with the command:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/huggingface.py --accuracy --no-translation-validation --training --amp --backend inductor --device cuda --only DebertaForQuestionAnswering
```

From the error message on the dashboard:
```
RMSE (res-fp64): 0.01803, (ref-fp64): 0.00537 and shape=torch.Size([2]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.010000
```

0.02 tolerance should suppress this error.

## gluon_inception_v3

This model fails on the dashboard in max-autotune mode. I cannot repro locally with the command:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/timm_models.py --accuracy --training --amp --backend inductor --disable-cudagraphs --device cuda --only gluon_inception_v3
```

From the error message on the dashboard:
```
RMSE (res-fp64): 0.02798, (ref-fp64): 0.00730 and shape=torch.Size([384]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.010000
Accuracy failed for key name Mixed_7c.branch3x3dbl_3a.bn.running_var
```
Raising the tolerance should suppress this error.

## mobilenetv3_large_100

Fails in max-autotune mode. I cannot repro locally with the command:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 time python benchmarks/dynamo/timm_models.py --accuracy --training --amp --backend inductor --disable-cudagraphs --device cuda --only
```
The error message on the dashboard is
```
RMSE (res-fp64): 0.29754, (ref-fp64): 0.05205 and shape=torch.Size([]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.040000
```

The tensor is so small that the noise can be high. I use a larger multiplier for smaller tensors in torch._dynamo.utils.same.
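The heuristic can be sketched as follows. Note the numel thresholds (10 and 500) and the bumped multipliers (10.0 and 5.0) are assumptions read off the unit tests below, not necessarily an exact match for the merged code:

```python
def pick_multiplier(numel, base_multiplier=3.0,
                    use_larger_multiplier_for_smaller_tensor=False):
    """Hedged sketch: small tensors average noise over few elements,
    so their RMSE is noisier and gets a more forgiving multiplier.
    Thresholds and values are inferred from the tests in this file."""
    if use_larger_multiplier_for_smaller_tensor:
        if numel <= 10:
            return 10.0  # tiny tensors, e.g. scalars or shape [2]
        if numel <= 500:
            return 5.0
    return base_multiplier
```

With those assumed values, the tests below are consistent: for numel in (10, 500], a 4x deviation passes but 6x fails (multiplier 5.0); for numel <= 10, 7x passes but 20x fails (multiplier 10.0).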

## yolov3

Fails on the dashboard with the error:
```
Error on the dashboard: RMSE (res-fp64): 0.01278, (ref-fp64): 0.00246 and shape=torch.Size([256]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.001000
```

Fix it by using a larger multiplier for smaller tensors and raising the tolerance.

## timm_efficientdet

Fails on the dashboard with the error:
```
E0623 18:37:43.638000 139924418725056 torch/_dynamo/utils.py:1468] RMSE (res-fp64): 0.00096, (ref-fp64): 0.00009 and shape=torch.Size([2]). res.dtype: torch.float32, multiplier: 3.000000, tol: 0.001000
```
But I cannot repro locally with the command:
```
time python benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --only timm_efficientdet  --training
```

Raising the tolerance should fix it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129941
Approved by: https://github.com/jansel
ghstack dependencies: #129996
2024-07-05 10:26:39 +00:00


# Owner(s): ["module: dynamo"]
import torch
from torch._dynamo import utils
from torch._inductor.test_case import TestCase

class TestUtils(TestCase):
    def test_nan(self):
        a = torch.Tensor([float("nan")])
        b = torch.Tensor([float("nan")])
        fp64_ref = torch.DoubleTensor([5.0])
        res = utils.same(a, b, fp64_ref=fp64_ref, equal_nan=True)
        self.assertTrue(res)

    def test_larger_multiplier_for_smaller_tensor(self):
        """
        Tensor numel between (10, 500]
        """
        N = 100
        fp64_ref = torch.full([N], 0.0, dtype=torch.double)
        a = torch.full([N], 1.0)
        tol = 4 * 1e-2
        self.assertTrue(utils.same(a, a * 2, fp64_ref=fp64_ref, tol=tol))
        self.assertFalse(utils.same(a, a * 4, fp64_ref=fp64_ref, tol=tol))
        self.assertTrue(
            utils.same(
                a,
                a * 4,
                fp64_ref=fp64_ref,
                use_larger_multiplier_for_smaller_tensor=True,
                tol=tol,
            )
        )
        self.assertFalse(
            utils.same(
                a,
                a * 6,
                fp64_ref=fp64_ref,
                use_larger_multiplier_for_smaller_tensor=True,
                tol=tol,
            )
        )

    def test_larger_multiplier_for_even_smaller_tensor(self):
        """
        Tensor numel <= 10
        """
        fp64_ref = torch.DoubleTensor([0.0])
        a = torch.Tensor([1.0])
        tol = 4 * 1e-2
        self.assertTrue(utils.same(a, a * 2, fp64_ref=fp64_ref, tol=tol))
        self.assertFalse(utils.same(a, a * 7, fp64_ref=fp64_ref, tol=tol))
        self.assertTrue(
            utils.same(
                a,
                a * 7,
                fp64_ref=fp64_ref,
                use_larger_multiplier_for_smaller_tensor=True,
                tol=tol,
            )
        )
        self.assertFalse(
            utils.same(
                a,
                a * 20,
                fp64_ref=fp64_ref,
                use_larger_multiplier_for_smaller_tensor=True,
                tol=tol,
            )
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()