Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74353

Repatched `d00de0d43598522b8f6ab2de553b6aaf6768faa5` by Nora Belrose (norabelrose), with the following changes:

* Register the fake source of generated methods in linecache so that `inspect.getsource` succeeds.
* The patching is only triggered if the given dataclass was previously passed to `torch.jit.script`. Effectively, this makes the feature opt-in.

## Original Summary

Fixes #72901. Since we can't get access to the source code for synthesized magic methods on dataclasses, we have to synthesize our own versions; torch/jit/_dataclass_impls.py has the code that does this.

What's supported:

* Synthesized `__init__`, `__eq__`, and the comparison magic methods when `order=True` is set on the dataclass decorator
* Default values for fields
* `__post_init__`, including using InitVar fields inside of `__post_init__`, on Python 3.8+
* Overriding `__eq__` or any of the comparison magic methods to provide your own implementation

What's not supported:

* Default factory initializers for fields
* Frozen dataclasses
* InitVar on Python 3.7
* `__repr__` and `__hash__` (these are actually implemented, but the TorchScript interpreter won't call them)
* Using the `!=` operator on dataclasses inside TorchScript. TorchScript requires that you implement `__ne__` to use this operator, whereas in regular Python `!=` resolves to the negation of whatever `__eq__` returns when there's no `__ne__`. Dataclasses don't synthesize an `__ne__` method for this reason. I've been toying with different ways to fix this, but `!=` is not working in this PR at the moment.

Test Plan: unittest

Also ran the previously failing test:
```
buck test mode/dev-nosan //fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests -- --exact 'fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests - test_mixmatch_multiclass (fblearner.flow.projects.fluent2.definition.transformers.contrib.faim.test.faim_mixmatch_test.TestFaimTransformerMixMatch)'
```
It passes.

Differential Revision: D35206262

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74889
Approved by: https://github.com/zhxchen17
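
The flow the summary describes is opt-in: the dataclass itself has to go through `torch.jit.script` before any TorchScript code can use it. A minimal sketch of that flow (not part of the test file below; `Box` and `area` are made-up names used only for illustration):

```
import torch
from dataclasses import dataclass

# Scripting the dataclass is the opt-in step: TorchScript synthesizes sources for
# __init__/__eq__ and registers them so later compilation can find them.
@torch.jit.script
@dataclass
class Box:
    w: float
    h: float

# Once registered, the dataclass can be constructed inside other scripted code.
@torch.jit.script
def area(w: float, h: float) -> float:
    b = Box(w, h)
    return b.w * b.h

assert area(2.0, 3.0) == 6.0
```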
# Owner(s): ["oncall: jit"]
# flake8: noqa

from dataclasses import dataclass, field, InitVar
from hypothesis import given, settings, strategies as st
from torch.testing._internal.jit_utils import JitTestCase
from typing import List, Optional
import sys
import torch
import unittest
from enum import Enum

# Example jittable dataclass
@torch.jit.script
@dataclass(order=True)
class Point:
    x: float
    y: float
    norm: Optional[torch.Tensor] = None

    def __post_init__(self):
        self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5

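
# Fixtures for the tests below: MixupScheme has list-valued Enum members, which
# TorchScript cannot compile, so scripting MixupParams is expected to fail.
# MixupParams2 uses the int-valued MixupScheme2 and scripts cleanly, while
# MixupParams3 is deliberately never passed to torch.jit.script.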
class MixupScheme(Enum):

    INPUT = ["input"]

    MANIFOLD = [
        "input",
        "before_fusion_projection",
        "after_fusion_projection",
        "after_classifier_projection",
    ]


@dataclass
class MixupParams:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme = MixupScheme.INPUT):
        self.alpha = alpha
        self.scheme = scheme


class MixupScheme2(Enum):
    A = 1
    B = 2


@dataclass
class MixupParams2:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme


@dataclass
class MixupParams3:
    def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
        self.alpha = alpha
        self.scheme = scheme


# Make sure the Meta internal tooling doesn't raise an overflow error
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)


class TestDataclasses(JitTestCase):

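    # Clear the TorchScript class registry once all the tests in this class have run.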
    @classmethod
    def tearDownClass(cls):
        torch._C._jit_clear_class_registry()

    # We only support InitVar in JIT dataclasses for Python 3.8+ because it would be very hard
    # to support without the `type` attribute on InitVar (see comment in _dataclass_impls.py).
    @unittest.skipIf(sys.version_info < (3, 8), "InitVar not supported in Python < 3.8")
    def test_init_vars(self):
        @torch.jit.script
        @dataclass(order=True)
        class Point2:
            x: float
            y: float
            norm_p: InitVar[int] = 2
            norm: Optional[torch.Tensor] = None

            def __post_init__(self, norm_p: int):
                self.norm = (torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p) ** (1 / norm_p)

        def fn(x: float, y: float, p: int):
            pt = Point2(x, y, p)
            return pt.norm

        self.checkScript(fn, (1.0, 2.0, 3))

    # Sort of tests both __post_init__ and optional fields
    @settings(deadline=None)
    @given(NonHugeFloats, NonHugeFloats)
    def test__post_init__(self, x, y):
        def fn(x: float, y: float):
            pt = Point(x, y)
            return pt.norm

        self.checkScript(fn, [x, y])

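    # Point is declared with order=True, so the synthesized comparison methods
    # (__lt__, __le__, __gt__, __ge__) are exercised here along with __eq__.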
    @settings(deadline=None)
    @given(st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats))
    def test_comparators(self, pt1, pt2):
        x1, y1 = pt1
        x2, y2 = pt2

        def compare(x1: float, y1: float, x2: float, y2: float):
            pt1 = Point(x1, y1)
            pt2 = Point(x2, y2)
            return (
                pt1 == pt2,
                # pt1 != pt2,  # TODO: Modify interpreter to auto-resolve (a != b) to not (a == b) when there's no __ne__
                pt1 < pt2,
                pt1 <= pt2,
                pt1 > pt2,
                pt1 >= pt2,
            )

        self.checkScript(compare, [x1, y1, x2, y2])

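    # field(default_factory=...) is not supported, so scripting the dataclass
    # should raise NotImplementedError.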
    def test_default_factories(self):
        @dataclass
        class Foo(object):
            x: List[int] = field(default_factory=list)

        with self.assertRaises(NotImplementedError):
            torch.jit.script(Foo)
            def fn():
                foo = Foo()
                return foo.x

            torch.jit.script(fn)()

    # The user should be able to write their own __eq__ implementation
    # without us overriding it.
    def test_custom__eq__(self):
        @torch.jit.script
        @dataclass
        class CustomEq:
            a: int
            b: int

            def __eq__(self, other: 'CustomEq') -> bool:
                return self.a == other.a  # ignore the b field

        def fn(a: int, b1: int, b2: int):
            pt1 = CustomEq(a, b1)
            pt2 = CustomEq(a, b2)
            return pt1 == pt2

        self.checkScript(fn, [1, 2, 3])

    def test_no_source(self):
        with self.assertRaises(RuntimeError):
            # using a list as an Enum value is not supported
            torch.jit.script(MixupParams)

        torch.jit.script(MixupParams2)  # don't throw

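    # MixupParams3 was never passed to torch.jit.script, so no synthesized source is
    # registered for it; compiling a function annotated with it should fail with OSError.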
    def test_use_unregistered_dataclass_raises(self):

        def f(a: MixupParams3):
            return 0

        with self.assertRaises(OSError):
            torch.jit.script(f)