mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
unskipped mobilenet_v3 quantization and mobilenet_v2 quantization plus tests from https://github.com/pytorch/pytorch/issues/125438 (#157786)
These tests now pass on AArch64 in our downstream CI. `test_quantization.py::TestNumericSuiteEager::test_mobilenet_v2 <- test/quantization/eager/test_numeric_suite_eager.py PASSED [2.4434s] [ 35%]` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157786 Approved by: https://github.com/jerryzh168, https://github.com/malfet
This commit is contained in:
parent
9fd5b5f735
commit
3a2c3c8ed3
|
|
@ -1,7 +1,6 @@
|
|||
# Owner(s): ["oncall: quantization"]
|
||||
# ruff: noqa: F841
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.ao.nn.quantized as nnq
|
||||
|
|
@ -38,7 +37,7 @@ from torch.testing._internal.common_quantization import (
|
|||
test_only_eval_fn,
|
||||
)
|
||||
from torch.testing._internal.common_quantized import override_qengines
|
||||
from torch.testing._internal.common_utils import IS_ARM64, raise_on_run_directly
|
||||
from torch.testing._internal.common_utils import raise_on_run_directly
|
||||
|
||||
|
||||
class SubModule(torch.nn.Module):
|
||||
|
|
@ -600,14 +599,12 @@ class TestNumericSuiteEager(QuantizationTestCase):
|
|||
act_compare_dict = get_matching_activations(float_model, qmodel)
|
||||
|
||||
@skip_if_no_torchvision
|
||||
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
|
||||
def test_mobilenet_v2(self):
|
||||
from torchvision.models.quantization import mobilenet_v2
|
||||
|
||||
self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))
|
||||
|
||||
@skip_if_no_torchvision
|
||||
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
|
||||
def test_mobilenet_v3(self):
|
||||
from torchvision.models.quantization import mobilenet_v3_large
|
||||
|
||||
|
|
|
|||
|
|
@ -1401,8 +1401,6 @@ class TestLinalg(TestCase):
|
|||
|
||||
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
|
||||
def test_vector_norm(self, device, dtype):
|
||||
if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]:
|
||||
raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
|
||||
# have to use torch.randn(...).to(bfloat16) instead of
|
||||
# This test compares torch.linalg.vector_norm's output with
|
||||
# torch.linalg.norm given a flattened tensor
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user