mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes a ton of false negatives throughout the codebase. RUFF also properly validates NOQA comments now and most of the changes are fixing typos there or removing filewide flake8 suppressions that were also silencing ruff issues. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153249 Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/seemethere
171 lines
4.4 KiB
Python
171 lines
4.4 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
from dataclasses import dataclass, field, InitVar
|
|
from enum import Enum
|
|
from typing import List, Optional
|
|
|
|
from hypothesis import given, settings, strategies as st
|
|
|
|
import torch
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
# Example jittable dataclass
|
|
# 2-D point whose Euclidean norm is filled in after construction.
# `order=True` makes the dataclass generate the ordering comparison methods.
@dataclass(order=True)
class Point:
    x: float
    y: float
    # Placeholder only: __post_init__ always overwrites this with a tensor.
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        # sqrt(x^2 + y^2), computed with tensor ops.
        self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5
|
|
|
|
|
|
# Enum whose member values are lists of strings. `test_no_source` in this file
# expects scripting a class that defaults to one of these members to fail.
class MixupScheme(Enum):
    INPUT = ["input"]

    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]
|
|
|
|
|
|
# Dataclass with no declared fields and a hand-written __init__.
# The `scheme` default is a MixupScheme member (list-valued enum);
# `test_no_source` expects torch.jit.script(MixupParams) to raise RuntimeError.
@dataclass
class MixupParams:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
        self.alpha = alpha
        self.scheme = scheme
|
|
|
|
|
|
# Integer-valued counterpart of MixupScheme, used as the default `scheme`
# by MixupParams2 and MixupParams3.
class MixupScheme2(Enum):
    A = 1
    B = 2
|
|
|
|
|
|
# Same shape as MixupParams but defaulting to the integer-valued MixupScheme2;
# `test_no_source` expects torch.jit.script(MixupParams2) not to throw.
@dataclass
class MixupParams2:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme
|
|
|
|
|
|
# Identical in body to MixupParams2. In this file it is never passed to
# torch.jit.script itself; it only appears as a parameter annotation in
# `test_use_unregistered_dataclass_raises`.
@dataclass
class MixupParams3:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme
|
|
|
|
|
|
# Make sure the Meta internal tooling doesn't raise an overflow error:
# generated floats are bounded to [-1e4, 1e4] and NaN is excluded.
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
|
|
|
|
|
|
class TestDataclasses(JitTestCase):
    """Tests for compiling ``@dataclass``-decorated classes with TorchScript."""

    @classmethod
    def tearDownClass(cls):
        # Clear the JIT class registry once the tests in this class are done.
        torch._C._jit_clear_class_registry()

    # An InitVar pseudo-field is passed positionally to the constructor and
    # arrives in __post_init__ as an argument.
    def test_init_vars(self):
        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            norm_p: InitVar[int] = 2
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                self.norm = (
                    torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
                ) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            pt = Point2(x, y, p)
            return pt.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Sort of tests both __post_init__ and optional fields
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        P = torch.jit.script(Point)

        def fn(x: float, y: float):
            pt = P(x, y)
            return pt.norm

        self.checkScript(fn, [x, y])

    # Comparison operators on the scripted `order=True` dataclass, checked
    # through checkScript on pairs of generated points.
    @settings(deadline=None)
    @given(
        st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
    )
    def test_comparators(self, pt1, pt2):
        x1, y1 = pt1
        x2, y2 = pt2
        P = torch.jit.script(Point)

        def compare(x1: float, y1: float, x2: float, y2: float):
            pt1 = P(x1, y1)
            pt2 = P(x2, y2)
            return (
                pt1 == pt2,
                # pt1 != pt2,   # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
                pt1 < pt2,
                pt1 <= pt2,
                pt1 > pt2,
                pt1 >= pt2,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

    # Scripting a dataclass that uses field(default_factory=...) is expected
    # to raise NotImplementedError.
    def test_default_factories(self):
        @dataclass
        class Foo(object):
            x: List[int] = field(default_factory=list)

        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)

            # NOTE(review): indentation was reconstructed. These statements are
            # kept inside the `with` block, so they never run once the call
            # above raises; outside the block, scripting `fn` would presumably
            # hit the same error and fail the test. Confirm the intent.
            def fn():
                foo = Foo()
                return foo.x

            torch.jit.script(fn)()

    # The user should be able to write their own __eq__ implementation
    # without us overriding it.
    def test_custom__eq__(self):
        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: "CustomEq") -> bool:
                return self.a == other.a  # ignore the b field

        def fn(a: int, b1: int, b2: int):
            pt1 = CustomEq(a, b1)
            pt2 = CustomEq(a, b2)
            return pt1 == pt2

        self.checkScript(fn, [1, 2, 3])

    # MixupParams (list-valued enum default) must fail to script, while the
    # otherwise identical MixupParams2 (int-valued enum default) must succeed.
    def test_no_source(self):
        with self.assertRaises(RuntimeError):
            # uses list in Enum is not supported
            torch.jit.script(MixupParams)

        torch.jit.script(MixupParams2)  # don't throw

    # Scripting a function annotated with a dataclass that was never scripted
    # itself is expected to raise OSError.
    def test_use_unregistered_dataclass_raises(self):
        def f(a: MixupParams3):
            return 0

        with self.assertRaises(OSError):
            torch.jit.script(f)
|