mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This reverts commit 1a55fb0ee8.
Reverted https://github.com/pytorch/pytorch/pull/154725 on behalf of https://github.com/malfet due to This added 2nd copy of raise_on_run to common_utils.py which caused lint failures, see https://github.com/pytorch/pytorch/actions/runs/15445374980/job/43473457466 ([comment](https://github.com/pytorch/pytorch/pull/154725#issuecomment-2940503905))
93 lines
2.9 KiB
Python
93 lines
2.9 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
|
|
# Make the helper files in test/ importable: take the real path of this file,
# strip the file name and then its containing directory (landing in test/),
# and append the result to the module search path.
pytorch_test_dir = os.path.split(
    os.path.split(os.path.realpath(__file__))[0]
)[0]
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
if __name__ == "__main__":
    # This module only defines test cases; they are meant to be run through
    # the test_jit.py driver, so direct execution is refused with a usage hint.
    _usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)
|
|
|
|
# NOTE: FIXING FAILING TESTS
# If you are seeing a test failure from this file, congratulations: you have
# improved parity between the JIT and the Python API. Before you fix the test,
# you must also update the corresponding section of the documentation that
# describes the unsupported behavior.
# See: `jit_unsupported.rst`
|
|
|
|
|
|
class TestUnsupportedOps(JitTestCase):
    """Pin down Python-API features that TorchScript does not support yet.

    Every case calls a small function in eager mode, where it must succeed,
    and then asserts that ``torch.jit.script`` rejects the same function.
    A failure here therefore means JIT and eager behavior have converged.
    """

    def test_factory_ops_requires_grad_fail(self):
        # Keyword argument {name} unknown is a JIT-only error message,
        # so these functions are succeeding in eager and failing in JIT

        # Complete issue and set of ops is https://github.com/pytorch/pytorch/issues/30761
        # only testing some because they should be fixed all at once
        def ones():
            return torch.ones([2], requires_grad=True)

        def randn():
            return torch.randn([2], requires_grad=True)

        def zeros():
            return torch.zeros([2], requires_grad=True)

        # Each scripted function is paired with the source snippet the
        # compiler is expected to highlight in its error message.
        cases = (
            (ones, "torch.ones"),
            (randn, "torch.randn"),
            (zeros, "torch.zeros"),
        )
        for func, highlight in cases:
            # subTest keeps going after a failure and names the offending op,
            # so one op gaining support does not hide the others.
            with self.subTest(op=highlight):
                # Valid in eager mode ...
                func()
                # ... but rejected by the TorchScript compiler.
                with self.assertRaisesRegexWithHighlight(
                    Exception, "Keyword argument requires_grad unknown", highlight
                ):
                    torch.jit.script(func)

    @unittest.skipIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")
    def test_init_ops(self):
        def calculate_gain():
            return torch.nn.init.calculate_gain("leaky_relu", 0.2)

        def eye_():
            return torch.nn.init.eye_(torch.zeros([2, 2]))

        def dirac_():
            return torch.nn.init.dirac_(torch.empty(3, 16, 5, 5))

        # Named after the init function it actually exercises (this helper was
        # previously called kaiming_uniform_ despite calling kaiming_normal_).
        def kaiming_normal_():
            return torch.nn.init.kaiming_normal_(torch.empty(3, 5))

        def orthogonal_():
            return torch.nn.init.orthogonal_(torch.empty(3, 5))

        def sparse():
            return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=0.1)

        for func in [
            calculate_gain,
            eye_,
            dirac_,
            kaiming_normal_,
            orthogonal_,
            sparse,
        ]:
            # Report each init function separately so that one of them
            # becoming scriptable does not mask the rest.
            with self.subTest(func=func.__name__):
                # doesn't error in eager
                func()
                # Any exception message is accepted: only the failure to
                # script matters here.
                with self.assertRaisesRegex(Exception, ""):
                    torch.jit.script(func)
|