Add None return type to init -- tests (#132352)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352
Approved by: https://github.com/ezyang
ghstack dependencies: #132335, #132351

This commit is contained in:
parent a6985c09cb
commit 221350e3a4
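The change is mechanical: each test module's __init__ gains an explicit None return annotation, the form PEP 484 recommends for constructors so that type checkers treat the method as fully annotated. A minimal sketch of the before/after shape (the Model class here is illustrative only, not any particular file in this diff):

    import torch.nn as nn

    class Model(nn.Module):
        # Before this commit the signature was: def __init__(self):
        def __init__(self) -> None:  # explicit None return type added by this commit
            super().__init__()
            self.linear = nn.Linear(4, 4)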
@@ -20,7 +20,7 @@ logging.basicConfig(
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
         self.conv2 = nn.Conv2d(32, 32, kernel_size=3)

@@ -686,7 +686,7 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.emb1 = nn.Embedding(100, 3)
         self.embbag1 = nn.EmbeddingBag(200, 32)

@@ -912,7 +912,7 @@ class TestFPGMPruner(TestCase):
     """

     class SimpleConvFPGM(nn.Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.conv2d1 = nn.Conv2d(
                 in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False
@@ -5,7 +5,7 @@ import torch.nn as nn
 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.linear = nn.Linear(20, 20)

@@ -19,7 +19,7 @@ class SimpleModule(torch.nn.Module):
     a simple module to be compiled
     """
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc = torch.nn.Linear(4, 6)
         self.relu = torch.nn.ReLU()

@@ -20,7 +20,7 @@ class Net(torch.nn.Module):
 class NetWithTensorConstants(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.w = torch.randn(30, 1, device="cuda")
@@ -1098,7 +1098,7 @@ TEST(RunTimeTest, ParseBytecode) {
 //  class Module(torch.nn.Module):
 //
-//    def __init__(self):
+//    def __init__(self) -> None:
 //      super().__init__()
 //
 //    def forward(self, x: int, h: int, xfirst: bool):

@@ -1169,7 +1169,7 @@ TEST(RunTimeTest, ParseOperator) {
 //  PyTorch program:
 //  class Add(torch.nn.Module):
-//    def __init__(self):
+//    def __init__(self) -> None:
 //      super().__init__()
 //
 //    def forward(self, a, b):

@@ -26,7 +26,7 @@ class EvalModeForLoadedModule(FileSetup):
     def setup(self):
         class Model(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.dropout = torch.nn.Dropout(0.1)
@@ -22,21 +22,21 @@ TEST(RunTimeTest, LoadAndForward) {
 //  sequence.ptl source code:
 //  class A(torch.nn.Module):
-//    def __init__(self):
+//    def __init__(self) -> None:
 //      super().__init__()
 //
 //    def forward(self, x):
 //      return x + 1
 //
 //  class B(torch.nn.Module):
-//    def __init__(self):
+//    def __init__(self) -> None:
 //      super().__init__()
 //
 //    def forward(self, x):
 //      return x + 2
 //
 //  class C(torch.nn.Module):
-//    def __init__(self):
+//    def __init__(self) -> None:
 //      super().__init__()
 //      self.A0 = A()
 //      self.B0 = B()
@@ -6,7 +6,7 @@ from torch import nn
 class NeuralNetwork(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.flatten = nn.Flatten()
         self.linear_relu_stack = nn.Sequential(

@@ -18,7 +18,7 @@ def get_custom_op_library_path():
 class Model(torch.jit.ScriptModule):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.p = torch.nn.Parameter(torch.eye(5))
@@ -32,7 +32,7 @@ class PreDispatchSchemaCheckMode(SchemaCheckMode):
     later decompose and become functional.
     """
-    def __init__(self):
+    def __init__(self) -> None:
         self._dispatch_key = torch._C.DispatchKey.PreDispatch
         super().__init__()
@@ -1010,7 +1010,7 @@ class TestConverter(TestCase):
         # Since self.data is only read but not written, it is lifted as
         # constant tensors.
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.data = torch.randn(3, 2)

@@ -1018,7 +1018,7 @@ class TestConverter(TestCase):
                 return x + self.data

         class Goo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.data = torch.randn(3, 2)
                 self.foo = Foo()

@@ -1032,7 +1032,7 @@ class TestConverter(TestCase):
     def test_prim_SetAttr(self):
         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.data = torch.nn.Buffer(torch.ones(3, 2))

@@ -1046,7 +1046,7 @@ class TestConverter(TestCase):
         )

         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.data = torch.nn.Buffer(torch.ones(3, 2))

@@ -1064,7 +1064,7 @@ class TestConverter(TestCase):
         # In converter, we change tensor constants that are assigned as a buffer automatically,
         # since it might be hard to manually register them as buffers.
         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.data = torch.ones(3, 2)

@@ -1082,7 +1082,7 @@ class TestConverter(TestCase):
         )

         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.count = 0

@@ -1165,7 +1165,7 @@ class TestConverter(TestCase):
     def test_context_manager(self):
         class ContextManager:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.count = 0
                 return

@@ -1211,7 +1211,7 @@ class TestConverter(TestCase):
    def test_ts2ep_multi_outputs_on_call_ops(self):
        class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                super().__init__()
                self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)
@@ -18,7 +18,7 @@ class TestExperiment(TestCase):
     def test_with_buffer_as_submodule(self):
         @_mark_strict_experimental
         class B(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buffer1 = torch.nn.Buffer(torch.ones(3))

@@ -31,7 +31,7 @@ class TestExperiment(TestCase):
                 return x.sum() + y.sum() + buffer_updated.sum()

         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.submodule = B()

@@ -86,7 +86,7 @@ def forward(self, arg0_1, arg1_1):
     def test_mark_strict_with_container_type(self):
         @_mark_strict_experimental
         class B(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()

             def forward(self, x):

@@ -94,7 +94,7 @@ def forward(self, arg0_1, arg1_1):
                 return x0.sum()

         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.submodule = B()

@@ -194,7 +194,7 @@ def forward(self, arg0_1, arg1_1):
     def test_joint_basic(self) -> None:
         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)
                 self.loss = torch.nn.CrossEntropyLoss()

@@ -266,7 +266,7 @@ def forward(self, arg0_1, arg1_1):
         from torch.export import Dim

         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.y = torch.nn.Parameter(torch.randn(3))
@@ -359,7 +359,7 @@ graph():
                 return x + x

         class Basic(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.external_add = ExternalMethod().add

@@ -373,7 +373,7 @@ graph():
     def test_colon_parameter(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.register_parameter("foo:bar", torch.nn.Parameter(torch.ones(3, 3)))

@@ -445,7 +445,7 @@ graph():
     def test_basic_non_strict_real_tensor(self):
         class Basic(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.param = torch.nn.Parameter(torch.randn(1, 3))

@@ -459,7 +459,7 @@ graph():
     def test_basic_non_strict_fake_tensor(self):
         class Basic(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.param = torch.nn.Parameter(torch.randn(3, 2))

@@ -476,7 +476,7 @@ graph():
     def test_non_strict_dynamic_shapes(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.u = torch.nn.Buffer(torch.ones(1))
                 self.v = torch.nn.Buffer(torch.ones(1))

@@ -591,7 +591,7 @@ graph():
     def test_state_tensors(self):
         class M(torch.nn.Module): # simple with register buffer
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buf = torch.nn.Buffer(torch.ones(2, 3), persistent=False)

@@ -615,7 +615,7 @@ graph():
         )

         class M(torch.nn.Module): # simple without register buffer
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buf = torch.ones(2, 3)

@@ -635,7 +635,7 @@ graph():
         torch.export.export(M(), (torch.randn(2, 3),), strict=False)

         class M(torch.nn.Module): # complex with register buffer
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 tensors = [torch.ones(2, 3), torch.ones(2, 3)]
                 for i, tensor in enumerate(tensors):

@@ -666,7 +666,7 @@ graph():
         )

         class M(torch.nn.Module): # complex without register buffer
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.tensors = [torch.ones(2, 3), torch.ones(2, 3)]

@@ -694,7 +694,7 @@ graph():
     def test_state_primitives(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.x = 1
                 self.y = {"k": 2}

@@ -713,7 +713,7 @@ graph():
     def test_torch_fn(self):
         class M1(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)
                 self.relu = torch.nn.ReLU()
@@ -741,7 +741,7 @@ graph():
         self.assertEqual(actual_result, expected_result)

         class M2(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()

             def forward(self, x, weight, bias):

@@ -803,7 +803,7 @@ graph():
     def test_export_preserve_linear_at_aot_level(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)

@@ -852,7 +852,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
                 return self.foo(x)

         class CondBranchClassMethod(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.subm = MySubModule()

@@ -1191,7 +1191,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
     def test_keep_composite_ops_invalid(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)

@@ -1228,7 +1228,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
     def test_keep_composite_ops_linear_convd(self):
         class MyLinear(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.weight = torch.randn(20, 98)
                 self.bias = torch.randn(20)

@@ -1237,7 +1237,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
                 return torch.nn.functional.linear(x, self.weight, self.bias)

         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(16, 33, 3)
                 self.conv1d = torch.nn.Conv1d(16, 33, 3)
@@ -1313,7 +1313,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
     def test_keep_composite_ops_linear_convd_for_training_ir(self):
         class MyLinear(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.weight = torch.nn.Buffer(torch.randn(20, 98))
                 self.bias = torch.nn.Buffer(torch.randn(20))

@@ -1322,7 +1322,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
                 return torch.nn.functional.linear(x, self.weight, self.bias)

         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(16, 33, 3)
                 self.conv1d = torch.nn.Conv1d(16, 33, 3)

@@ -1460,7 +1460,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
     def test_simple_export_for_training(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(2, 2)

@@ -1496,7 +1496,7 @@ def forward(self, x):
     def test_export_for_training_with_mutation(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buffer = torch.nn.Buffer(torch.ones(4, 4))

@@ -1540,7 +1540,7 @@ def forward(self, x):
     def test_export_for_training_with_dynamic_shapes(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buffer = torch.nn.Buffer(torch.ones(4, 4))

@@ -1577,7 +1577,7 @@ def forward(self, x):
     def test_export_for_training_with_container_type(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buffer = torch.nn.Buffer(torch.ones(4, 4))

@@ -1605,7 +1605,7 @@ def forward(self, x):
     def test_export_for_training_run_decomp(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buffer = torch.nn.Buffer(torch.ones(2, 2))
                 self.linear = torch.nn.Linear(2, 2)

@@ -1672,7 +1672,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
     def test_static_dim_constraints(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.l = torch.nn.Linear(6, 4)
@ -1896,7 +1896,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
# Just to introduce some indirection: N is a top-level module N that calls
|
||||
# module M, defined next.
|
||||
class N(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.m = M()
|
||||
|
||||
|
|
@ -2408,7 +2408,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
return y[:b]
|
||||
|
||||
class M2(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.m1 = M1()
|
||||
self.m3 = M3()
|
||||
|
|
@ -2462,7 +2462,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
@testing.expectedFailureTrainingIRToRunDecompNonStrict
|
||||
def test_linear_conv(self):
|
||||
class MyLinear(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.randn(20, 98)
|
||||
self.bias = torch.randn(20)
|
||||
|
|
@ -2471,7 +2471,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
return torch.nn.functional.linear(x, self.weight, self.bias)
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(16, 33, 3)
|
||||
self.linear = MyLinear()
|
||||
|
|
@ -3100,7 +3100,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
|
||||
def test_param_util(self):
|
||||
class Basic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.lin = torch.nn.Linear(10, 1)
|
||||
|
||||
|
|
@ -3137,7 +3137,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
|
||||
def test_export_dynamo_config(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.lstm = torch.nn.LSTM(input_size=4, hidden_size=5, num_layers=1)
|
||||
|
||||
|
|
@ -3251,7 +3251,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
|
||||
def test_module(self):
|
||||
class MyLinear(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.randn(20, 98)
|
||||
self.bias = torch.randn(20)
|
||||
|
|
@ -3260,7 +3260,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
return torch.nn.functional.linear(x, self.weight, self.bias)
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(16, 33, 3)
|
||||
self.linear = MyLinear()
|
||||
|
|
@ -3296,7 +3296,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
|
||||
def test_module_with_dict_container_inp_out(self):
|
||||
class MyLinear(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.randn(20, 98)
|
||||
self.bias = torch.randn(20)
|
||||
|
|
@ -3305,7 +3305,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
return torch.nn.functional.linear(x, self.weight, self.bias)
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(16, 33, 3)
|
||||
self.linear = MyLinear()
|
||||
|
|
@ -3364,7 +3364,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
|
||||
def test_decomp_batch_norm_functional_predispatch(self):
|
||||
class ConvBatchnorm(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
|
||||
self.bn = torch.nn.BatchNorm2d(3)
|
||||
|
|
@ -3605,7 +3605,7 @@ def forward(self, x):
|
|||
|
||||
def test_constrain_decomp(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.freq = torch.ones(5, 5)
|
||||
|
||||
|
|
@ -3701,7 +3701,7 @@ def forward(self, x):
|
|||
|
||||
def test_to_module_with_mutated_buffer(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buf = torch.nn.Buffer(torch.zeros(1))
|
||||
|
||||
|
|
@ -3730,7 +3730,7 @@ def forward(self, x):
|
|||
|
||||
def test_to_module_with_mutated_buffer_multiple(self):
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buf = torch.nn.Buffer(torch.ones(1))
|
||||
|
||||
|
|
@ -3739,7 +3739,7 @@ def forward(self, x):
|
|||
return x.sum() + self.buf.sum()
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buf = torch.nn.Buffer(torch.zeros(1))
|
||||
self.bar = Bar()
|
||||
|
|
@ -3829,7 +3829,7 @@ def forward(self, x):
|
|||
|
||||
def test_to_module_with_mutated_buffer_multiple_update_sub_later(self):
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buf = torch.nn.Buffer(torch.ones(1))
|
||||
|
||||
|
|
@ -3838,7 +3838,7 @@ def forward(self, x):
|
|||
return x.sum() + self.buf.sum()
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buf = torch.nn.Buffer(torch.zeros(1))
|
||||
self.bar = Bar()
|
||||
|
|
@ -3882,7 +3882,7 @@ def forward(self, x):
|
|||
|
||||
def test_retracable_ep(self):
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buf = torch.nn.Buffer(torch.ones(1))
|
||||
|
||||
|
|
@ -3891,7 +3891,7 @@ def forward(self, x):
|
|||
return x.sum() + self.buf.sum()
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buf = torch.nn.Buffer(torch.zeros(1))
|
||||
self.bar = Bar()
|
||||
|
|
@ -3938,7 +3938,7 @@ def forward(self, x):
|
|||
|
||||
def test_export_cond_symbool_pred(self):
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buffer = torch.nn.Buffer(torch.ones(6, 4))
|
||||
|
||||
|
|
@ -3946,7 +3946,7 @@ def forward(self, x):
|
|||
return self.buffer.cos()
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.a = A()
|
||||
|
||||
|
|
@ -3991,7 +3991,7 @@ def forward(self, b_a_buffer, x):
|
|||
|
||||
def test_cond_buffers(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_parameter(
|
||||
"param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False)
|
||||
|
|
@ -4024,7 +4024,7 @@ def forward(self, b_a_buffer, x):
|
|||
@unittest.expectedFailure
|
||||
def test_map_buffers(self):
|
||||
class M1(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_parameter(
|
||||
"param", torch.nn.Parameter(torch.tensor(5), requires_grad=False)
|
||||
|
|
@ -4060,7 +4060,7 @@ def forward(self, b_a_buffer, x):
|
|||
@testing.expectedFailureTrainingIRToRunDecompNonStrict
|
||||
def test_retrace_graph_level_meta_preservation(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -4118,7 +4118,7 @@ def forward(self, b_a_buffer, x):
|
|||
|
||||
def test_train_eval_on_exported_preautograd_module(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -4150,7 +4150,7 @@ def forward(self, b_a_buffer, x):
|
|||
self.assertEqual(len(ep.constants), 1)
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.a = torch.tensor(3)
|
||||
|
||||
|
|
@ -4252,7 +4252,7 @@ def forward(self, b_a_buffer, x):
|
|||
|
||||
def test_export_decomps_simple(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.lin = torch.nn.Linear(10, 1)
|
||||
|
||||
|
|
@ -4278,7 +4278,7 @@ def forward(self, b_a_buffer, x):
|
|||
|
||||
def test_export_decomps_dynamic(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.lin = torch.nn.Linear(10, 1)
|
||||
|
||||
|
|
@ -4369,7 +4369,7 @@ def forward(self, b_a_buffer, x):
|
|||
|
||||
def test_constant_output(self):
|
||||
class ModuleConstant(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.b = torch.randn(3, 2)
|
||||
|
||||
|
|
@ -4377,7 +4377,7 @@ def forward(self, b_a_buffer, x):
|
|||
return self.b
|
||||
|
||||
class ModuleNestedConstant(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.bff = torch.randn(3, 2)
|
||||
|
||||
|
|
@ -4486,7 +4486,7 @@ graph():
|
|||
|
||||
def test_nested_module_with_init_buffer(self):
|
||||
class M1(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.b = torch.ones(3, 3)
|
||||
|
||||
|
|
@ -4522,7 +4522,7 @@ graph():
|
|||
@testing.expectedFailureRetraceability # Retracing tensor constants results in buffers
|
||||
def test_nested_module_with_constant_buffer(self):
|
||||
class M1(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.b = torch.tensor(5)
|
||||
|
||||
|
|
@ -4572,7 +4572,7 @@ graph():
|
|||
|
||||
def test_nested_module_with_parameter(self):
|
||||
class M1(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.a = torch.nn.Parameter(torch.ones(3, 3))
|
||||
self.b = torch.nn.Parameter(torch.tensor(5.0))
|
||||
|
|
@ -4630,7 +4630,7 @@ graph():
|
|||
|
||||
def test_retrace_pre_autograd(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.buffer = torch.nn.Buffer(torch.ones(4, 4))
|
||||
|
||||
|
|
@ -4684,7 +4684,7 @@ graph():
|
|||
@unittest.skip("Test is only supposed to work with non-strict mode")
|
||||
def test_issue_113041(self):
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.a = torch.tensor(1.0)
|
||||
|
||||
|
|
@ -4699,7 +4699,7 @@ graph():
|
|||
handle = seq.register_forward_hook(forward_hook)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.seq = seq
|
||||
|
||||
|
|
@ -4791,7 +4791,7 @@ graph():
|
|||
|
||||
def test_run_decomposition_supports_user_input_mutation(self):
|
||||
class SingleOp(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.op = torch.ops.aten.native_batch_norm
|
||||
|
||||
|
|
@ -5052,7 +5052,7 @@ graph():
|
|||
|
||||
def test_check_specialized_int(self):
|
||||
class SingleOp(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.op = torch.ops.aten.scatter_add
|
||||
|
||||
|
|
@ -5085,7 +5085,7 @@ graph():
|
|||
return x / x
|
||||
|
||||
class Child1(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.nested = NestedChild()
|
||||
self.register_parameter(
|
||||
|
|
@ -5097,7 +5097,7 @@ graph():
|
|||
return x + self.child1param
|
||||
|
||||
class Child2(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
|
||||
|
||||
|
|
@ -5105,7 +5105,7 @@ graph():
|
|||
return x - self.child2buffer
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.foo = Child1()
|
||||
self.bar = Child2()
|
||||
|
|
@ -5148,7 +5148,7 @@ graph():
|
|||
|
||||
def test_nn_module_stack(self):
|
||||
class Leaf(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(4, 4)
|
||||
|
||||
|
|
@ -5156,7 +5156,7 @@ graph():
|
|||
return self.linear(x)
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.leaf = Leaf()
|
||||
self.buffer = torch.nn.Buffer(torch.randn(4, 4))
|
||||
|
|
@ -5165,7 +5165,7 @@ graph():
|
|||
return self.buffer.sum() + self.leaf(x).sum()
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.bar = Bar()
|
||||
|
||||
|
|
@ -5205,7 +5205,7 @@ graph():
|
|||
|
||||
def test_nn_module_stack_shared_submodule(self):
|
||||
class Leaf(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(4, 4)
|
||||
|
||||
|
|
@ -5213,7 +5213,7 @@ graph():
|
|||
return self.linear(x)
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.leaf = Leaf()
|
||||
self.buffer = torch.nn.Buffer(torch.randn(4, 4))
|
||||
|
|
@ -5222,7 +5222,7 @@ graph():
|
|||
return self.buffer.sum() + self.leaf(x).sum()
|
||||
|
||||
class BarDifferent(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.leaf = Leaf()
|
||||
|
||||
|
|
@ -5232,7 +5232,7 @@ graph():
|
|||
return a + b
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.bar = Bar()
|
||||
self.bar_different = BarDifferent()
|
||||
|
|
@ -5286,7 +5286,7 @@ graph():
|
|||
|
||||
def test_stack_trace(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(4, 4)
|
||||
|
||||
|
|
@ -5318,7 +5318,7 @@ graph():
|
|||
# Guard validation upsets the guard
|
||||
def test_cond_with_module_stack_export_with(self):
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(4, 4)
|
||||
|
||||
|
|
@ -5332,7 +5332,7 @@ graph():
|
|||
return torch.cond(x.sum() > 4, true_fn, false_fn, [x])
|
||||
|
||||
class CondExport(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.bar = Bar()
|
||||
|
||||
|
|
@ -5371,7 +5371,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
|||
@unittest.expectedFailure
|
||||
def test_cond_with_module_stack_export_with_unflatten(self):
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(4, 4)
|
||||
|
||||
|
|
@ -5385,7 +5385,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
|||
return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
|
||||
|
||||
class CondExport(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.bar = Bar()
|
||||
|
||||
|
|
@ -5412,7 +5412,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
|
|||
|
||||
def test_predispatch_cond(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.pred = torch.nn.Buffer(torch.tensor(False))
|
||||
self.t = torch.nn.Buffer(torch.tensor(10))
|
||||
|
|
@ -5525,7 +5525,7 @@ def forward(self, x, b_t, y):
|
|||
N, C, H, W = 1, 2, 2, 3
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
layer = torch.nn.LayerNorm([C, H, W])
|
||||
self.norms = torch.nn.ModuleList(
|
||||
|
|
@ -5548,7 +5548,7 @@ def forward(self, x, b_t, y):
|
|||
|
||||
def test_non_persistent_buffer(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.foo = torch.nn.Buffer(torch.rand(2, 3), persistent=False)
|
||||
|
||||
|
|
@ -5556,7 +5556,7 @@ def forward(self, x, b_t, y):
|
|||
return self.foo + x
|
||||
|
||||
class MyOuterModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.inner = MyModule()
|
||||
|
||||
|
|
@ -5607,7 +5607,7 @@ def forward(self, x, b_t, y):
|
|||
|
||||
def test_nonstrict_retrace_preserves_metadata(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(4, 4)
|
||||
|
||||
|
|
@ -5625,7 +5625,7 @@ def forward(self, x, b_t, y):
|
|||
|
||||
def test_fake_weights(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.foo = torch.nn.Parameter(torch.randn(4, 4))
|
||||
self.bar = torch.nn.Buffer(torch.randn(4, 4), persistent=False)
|
||||
|
|
@ -5645,7 +5645,7 @@ def forward(self, x, b_t, y):
|
|||
|
||||
def test_fake_inputs(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.foo = torch.nn.Parameter(torch.randn(4, 4))
|
||||
|
||||
|
|
@ -5664,7 +5664,7 @@ def forward(self, x, b_t, y):
|
|||
|
||||
def test_trace_under_fake(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.foo = torch.nn.Parameter(torch.randn(4, 4))
|
||||
|
||||
|
|
@ -5714,7 +5714,7 @@ def forward(self, x, b_t, y):
|
|||
|
||||
def test_user_input_and_buffer_mutation(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.foo = torch.nn.Buffer(torch.randn(4, 4))
|
||||
|
||||
|
|
@ -5739,7 +5739,7 @@ def forward(self, x, b_t, y):
|
|||
|
||||
def test_custom_op_auto_functionalize(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, z):
|
||||
|
|
@ -5766,7 +5766,7 @@ def forward(self, x, b_t, y):
|
|||
|
||||
def test_custom_op_auto_functionalize_pre_dispatch(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -5800,7 +5800,7 @@ def forward(self, x):
|
|||
|
||||
def test_custom_op_auto_warn_pre_dispatch(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -5851,7 +5851,7 @@ def forward(self, x):
|
|||
|
||||
# test collisions between user inputs and params, buffers, constants
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.param = torch.nn.Parameter(torch.randn(4))
|
||||
self.alpha = torch.nn.Buffer(torch.randn(4), persistent=True)
|
||||
|
|
@ -6197,7 +6197,7 @@ def forward(self, x, y):
|
|||
@testing.expectedFailureSerDer
|
||||
def test_preserve_requires_grad_placeholders(self):
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.p = torch.nn.Parameter(torch.randn(3, 3))
|
||||
|
||||
|
|
@ -6216,7 +6216,7 @@ def forward(self, x, y):
|
|||
def test_reshape_view_helper(self):
|
||||
# see: https://github.com/pytorch/pytorch/issues/126607
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -6326,7 +6326,7 @@ def forward(self, x, y):
|
|||
return x + self.foo + self.m2(x)
|
||||
|
||||
class M2(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.foo = torch.ones(3, 3)
|
||||
|
||||
|
|
@ -6352,7 +6352,7 @@ def forward(self, x, y):
|
|||
@testing.expectedFailureRetraceability
|
||||
def test_unused_aliases(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# param
|
||||
self.alpha = torch.nn.Parameter(torch.randn(4))
|
||||
|
|
@ -6512,7 +6512,7 @@ def forward(self, x, y):
|
|||
return y[d:]
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.m1 = M1()
|
||||
|
||||
|
|
@ -6539,7 +6539,7 @@ def forward(self, x, y):
|
|||
|
||||
def test_split_const_gm_with_lifted_constants(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.w_pre = torch.randn(4, 4)
|
||||
self.b = torch.randn(4)
|
||||
|
|
@ -6599,7 +6599,7 @@ class TestOneOffModelExportResult(TestCase):
|
|||
"""
|
||||
|
||||
class ScaledDotProductAttention(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, q, k, v):
|
||||
|
|
@ -6640,7 +6640,7 @@ class TestOneOffModelExportResult(TestCase):
|
|||
"""
|
||||
|
||||
class ScaledDotProductAttention(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, q, k, v):
|
||||
|
|
@ -6866,7 +6866,7 @@ def forward(self, x):
|
|||
|
||||
def test_constant_fqn(self):
|
||||
class Nested(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.constant = torch.rand(2, 3)
|
||||
self.parameter = torch.nn.Parameter(torch.rand(2, 3))
|
||||
|
|
@ -6875,7 +6875,7 @@ def forward(self, x):
|
|||
return x + self.constant
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.nested = Nested()
|
||||
|
||||
|
|
@ -6889,7 +6889,7 @@ def forward(self, x):
|
|||
|
||||
def test_constant_name(self):
|
||||
class Nested(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.constant = torch.rand(2, 3)
|
||||
self.parameter = torch.nn.Parameter(torch.rand(2, 3))
|
||||
|
|
@ -6898,7 +6898,7 @@ def forward(self, x):
|
|||
return x + self.constant
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.nested_1 = Nested()
|
||||
self.nested_2 = Nested()
|
||||
|
|
@ -6933,7 +6933,7 @@ def forward(self, x):
|
|||
|
||||
def test_nested_retrace(self):
|
||||
class Nested(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.param = torch.nn.Parameter(torch.randn(3))
|
||||
|
||||
|
|
@ -6941,7 +6941,7 @@ def forward(self, x):
|
|||
return x + self.param
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.nested = Nested()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from torch.testing._internal.common_utils import (
|
|||
|
||||
|
||||
class GraphBuilder:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.graph = torch.fx.Graph()
|
||||
self.nodes = {}
|
||||
self.values = {}
|
||||
|
|
@ -354,7 +354,7 @@ class TestLift(TestCase):
|
|||
|
||||
def test_unlift_nonpersistent_buffer(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"non_persistent_buf", torch.zeros(1), persistent=False
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class TestPassInfra(TestCase):
|
|||
@unittest.skipIf(IS_WINDOWS, "Windows not supported")
|
||||
def test_cond(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred, x, y):
|
||||
|
|
@ -74,7 +74,7 @@ class TestPassInfra(TestCase):
|
|||
# Tests that graph nodes stay the same for nodes that are not touched
|
||||
# during transformation
|
||||
class CustomModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Define a parameter
|
||||
|
|
@ -110,7 +110,7 @@ class TestPassInfra(TestCase):
|
|||
# Checks that pass infra correctly updates graph signature
|
||||
# after transformations.
|
||||
class CustomModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
|
||||
|
|
@ -152,7 +152,7 @@ class TestPassInfra(TestCase):
|
|||
|
||||
def test_replace_hook_basic(self) -> None:
|
||||
class CustomModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
|
||||
|
|
|
|||
|
|
@ -87,12 +87,12 @@ def _get_output_names(gm: torch.fx.GraphModule) -> List[str]:
|
|||
|
||||
class ModelsWithScriptObjectAttr:
|
||||
class Simple(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
|
||||
class SimpleWithAttrInContainer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
self.pytree_attr2 = [
|
||||
|
|
@ -104,7 +104,7 @@ class ModelsWithScriptObjectAttr:
|
|||
]
|
||||
|
||||
class NestedWithAttrInContainer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
self.pytree_attr2 = [
|
||||
|
|
@ -118,7 +118,7 @@ class ModelsWithScriptObjectAttr:
|
|||
self.sub_mod2 = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()
|
||||
|
||||
class MoreNestedWithAttrInContainer(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
self.pytree_attr2 = [
|
||||
|
|
@ -267,7 +267,7 @@ class TestPasses(TestCase):
|
|||
|
||||
def test_runtime_assert_one_dim(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -290,7 +290,7 @@ class TestPasses(TestCase):
|
|||
|
||||
def test_runtime_assert_multiple_dims(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
|
|
@ -320,7 +320,7 @@ class TestPasses(TestCase):
|
|||
|
||||
def test_runtime_assert_some_dims_not_specified(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
|
|
@ -357,7 +357,7 @@ class TestPasses(TestCase):
|
|||
|
||||
def test_runtime_assert_some_inps_not_used(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y):
|
||||
|
|
@ -389,7 +389,7 @@ class TestPasses(TestCase):
|
|||
|
||||
def test_view_to_view_copy(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -444,7 +444,7 @@ class TestPasses(TestCase):
|
|||
|
||||
def test_custom_obj_tuple_out(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
|
||||
|
|
@ -471,7 +471,7 @@ class TestPasses(TestCase):
|
|||
|
||||
def test_remove_effect_token_kwargs(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
|
||||
|
|
@ -546,7 +546,7 @@ def forward(self, token, obj_attr, x):
|
|||
|
||||
def test_runtime_assert_inline_constraints_for_item(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -569,7 +569,7 @@ def forward(self, token, obj_attr, x):
|
|||
|
||||
def test_runtime_assert_inline_constraints_for_nonzero(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -613,7 +613,7 @@ def forward(self, token, obj_attr, x):
|
|||
# TODO(pianpwk): add back runtime asserts to subgraphs
|
||||
def test_runtime_assert_inline_constraints_for_cond(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred, x, y):
|
||||
|
|
@ -870,7 +870,7 @@ def forward(self, sin, cos):
|
|||
return x + y.add_(1)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.state = torch.nn.Buffer(torch.zeros(1))
|
||||
|
||||
|
|
@ -911,7 +911,7 @@ def forward(self, sin, cos):
|
|||
return (x, x + y.add_(1))
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.state = torch.nn.Buffer(torch.zeros(1))
|
||||
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ class TestSerialize(TestCase):
|
|||
|
||||
def test_predispatch_export_with_autograd_op(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -148,7 +148,7 @@ class TestSerialize(TestCase):
|
|||
class MyModule(torch.nn.Module):
|
||||
"""A test module with that has multiple args and uses kwargs"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.p = torch.nn.Parameter(torch.ones(2, 3))
|
||||
|
||||
|
|
@ -178,7 +178,7 @@ class TestSerialize(TestCase):
|
|||
# Tests that modules with more complicated layer patterns can be serialized
|
||||
# and deserialized correctly.
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
|
|
@ -209,7 +209,7 @@ class TestSerialize(TestCase):
|
|||
|
||||
def test_serialize_constant_outputs(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -231,7 +231,7 @@ class TestSerialize(TestCase):
|
|||
|
||||
def test_serialize_multiple_returns_from_node(self) -> None:
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, w, b):
|
||||
|
|
@ -267,7 +267,7 @@ class TestSerialize(TestCase):
|
|||
|
||||
def test_serialize_list_returns(self) -> None:
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -309,7 +309,7 @@ class TestSerialize(TestCase):
|
|||
"""
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -636,7 +636,7 @@ class TestDeserialize(TestCase):
|
|||
"""
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, w, b):
|
||||
|
|
@ -657,7 +657,7 @@ class TestDeserialize(TestCase):
|
|||
|
||||
def test_basic(self) -> None:
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
|
|
@ -671,7 +671,7 @@ class TestDeserialize(TestCase):
|
|||
|
||||
def test_dynamic(self) -> None:
|
||||
class DynamicShapeSimpleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, a, b, c) -> torch.Tensor:
|
||||
|
|
@ -709,7 +709,7 @@ class TestDeserialize(TestCase):
|
|||
|
||||
def test_module(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(3, 3)
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
|
@ -727,7 +727,7 @@ class TestDeserialize(TestCase):
|
|||
|
||||
def test_module_meta(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.p = torch.nn.Parameter(torch.ones(3, 3))
|
||||
|
||||
|
|
@ -803,7 +803,7 @@ class TestDeserialize(TestCase):
|
|||
|
||||
def test_list_of_optional_tensors(self) -> None:
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, y, z):
|
||||
|
|
@ -906,7 +906,7 @@ def forward(self, x):
|
|||
@unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
|
||||
def test_device(self) -> None:
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
|
@ -923,7 +923,7 @@ def forward(self, x):
|
|||
|
||||
def test_custom_obj_tuple_out(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
|
||||
|
|
@ -939,7 +939,7 @@ def forward(self, x):
|
|||
|
||||
def test_custom_obj(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
|
||||
|
|
@ -954,7 +954,7 @@ def forward(self, x):
|
|||
|
||||
def test_custom_obj_list_out(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
|
||||
|
|
@ -970,7 +970,7 @@ def forward(self, x):
|
|||
|
||||
def test_export_no_inputs(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.p = torch.ones(3, 3)
|
||||
|
||||
|
|
@ -1019,7 +1019,7 @@ class TestSaveLoad(TestCase):
|
|||
inp = (torch.tensor([0.1, 0.1]),)
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(2, 2)
|
||||
|
||||
|
|
@ -1118,7 +1118,7 @@ class TestSaveLoad(TestCase):
|
|||
|
||||
def test_save_constants(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.a = torch.tensor(3)
|
||||
|
||||
|
|
@ -1192,7 +1192,7 @@ class TestSerializeCustomClass(TestCase):
|
|||
|
||||
def test_custom_class_containing_fake_tensor(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.custom_obj = torch.classes._TorchScriptTesting._ContainsTensor(
|
||||
torch.rand(2, 3)
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class SumNet(torch.nn.Module):
|
|||
|
||||
|
||||
class EltwiseNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.relu = torch.nn.ReLU()
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class TestExportTools(TestCase):
|
|||
return x.sin()
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.unsupported = Unsupported()
|
||||
self.supported = Supported()
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ class TestExportTorchbind(TestCase):
|
|||
@parametrize("pre_dispatch", [True, False])
|
||||
def test_none(self, pre_dispatch):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
|
||||
|
|
@@ -212,7 +212,7 @@ def forward(self, token, obj_attr, x, n):
@parametrize("pre_dispatch", [True, False])
def test_attribute(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

@@ -246,7 +246,7 @@ def forward(self, token, obj_attr, x):
@parametrize("pre_dispatch", [True, False])
def test_attribute_as_custom_op_argument(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

@@ -282,7 +282,7 @@ def forward(self, token, obj_attr, x):
cc = torch.classes._TorchScriptTesting._Foo(10, 20)
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x, cc):

@@ -320,7 +320,7 @@ def forward(self, token, x, cc):
cc = torch.classes._TorchScriptTesting._Foo(10, 20)
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x, cc):

@@ -381,7 +381,7 @@ def forward(self, token, x, cc):
return x + torch.ops._TorchScriptTesting.takes_foo(self.foo, x)
class F1(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.alpha = torch.classes._TorchScriptTesting._Foo(10, 20)
self.beta = self.alpha

@@ -417,7 +417,7 @@ def forward(self, token, x, cc):
@parametrize("pre_dispatch", [True, False])
def test_unlift_custom_obj(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

@@ -458,7 +458,7 @@ def forward(self, token, obj_attr, x):
@parametrize("pre_dispatch", [True, False])
def test_custom_obj_list_out(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

@@ -510,7 +510,7 @@ def forward(self, token, obj_attr, x):
@parametrize("pre_dispatch", [True, False])
def test_custom_obj_tuple_out(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

@@ -559,7 +559,7 @@ def forward(self, token, obj_attr, x):
test = self
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(3, 2)
self.check_tq_is_fake = True
@@ -617,7 +617,7 @@ def forward(self, arg0_1, arg1_1):
test = self
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(3, 2)
self.check_tq_is_fake = True

@@ -674,7 +674,7 @@ def forward(self, arg0_1, arg1_1):
def test_non_strict_export_methods(self):
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)

@@ -857,7 +857,7 @@ def forward(self, token, safe_obj):
@parametrize("fallthrough_via", ["lib_impl", "py_impl"])
def test_make_fx_tensor_queue_operators(self, fallthrough_via):
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, tq, x):

@@ -932,7 +932,7 @@ def forward(self, arg0_1, arg1_1):
def test_aot_export_tensor_queue_operators(self):
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, tq, x):

@@ -1072,7 +1072,7 @@ class TestCompileTorchbind(TestCase):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.check_tq_is_fake = True

@@ -1133,7 +1133,7 @@ class TestCompileTorchbind(TestCase):
@parametrize("backend", ["eager", "aot_eager"])
def test_compile_script_object_input_guards(self, backend):
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.check_tq_is_fake = True

@@ -1184,7 +1184,7 @@ class TestCompileTorchbind(TestCase):
def test_compile_script_object_input_automatic_dynamic_shape(self):
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.check_tq_is_fake = True

@@ -1221,7 +1221,7 @@ class TestCompileTorchbind(TestCase):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.check_tq_is_fake = True

@@ -1418,7 +1418,7 @@ def forward(self, token, obj, x):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.tq = _empty_tensor_queue()

@@ -1482,7 +1482,7 @@ class TestRegisterFakeClass(TestCase):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class InvalidFakeFoo:
def __init__(self):
def __init__(self) -> None:
pass
def test_register_fake_class_from_real_not_classmethod(self):
@@ -65,7 +65,7 @@ class TestUnflatten(TestCase):
return x / x
class Child1(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(

@@ -77,7 +77,7 @@ class TestUnflatten(TestCase):
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

@@ -85,7 +85,7 @@ class TestUnflatten(TestCase):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()

@@ -119,7 +119,7 @@ class TestUnflatten(TestCase):
def test_unflatten_buffer_mutation(self):
class Child(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

@@ -128,7 +128,7 @@ class TestUnflatten(TestCase):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.foo = Child()
self.register_parameter(

@@ -155,7 +155,7 @@ class TestUnflatten(TestCase):
def test_unflatten_nested_access(self):
class Child(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

@@ -163,7 +163,7 @@ class TestUnflatten(TestCase):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.foo = Child()
self.register_parameter(

@@ -184,7 +184,7 @@ class TestUnflatten(TestCase):
def test_unflatten_shared_submodule(self):
class Shared(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
layernorm = torch.nn.LayerNorm(10)
self.sub_net = torch.nn.Sequential(

@@ -218,7 +218,7 @@ class TestUnflatten(TestCase):
return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}
class Child1(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()

@@ -228,14 +228,14 @@ class TestUnflatten(TestCase):
return xw["w"] + z - xw["x"]
class Child2(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x - 1
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()

@@ -287,7 +287,7 @@ class TestUnflatten(TestCase):
def test_unflatten_param_list_dict(self):
class Mod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.param_list = torch.nn.ParameterList()
self.param_dict = torch.nn.ParameterDict()

@@ -317,7 +317,7 @@ class TestUnflatten(TestCase):
return x + a, b
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.m1 = M1()

@@ -337,7 +337,7 @@ class TestUnflatten(TestCase):
def test_unflatten_wrong_input(self):
class Mod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.param_list = torch.nn.ParameterList()
self.param_dict = torch.nn.ParameterDict()
@@ -374,7 +374,7 @@ class TestUnflatten(TestCase):
return x / x
class Child1(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(

@@ -386,7 +386,7 @@ class TestUnflatten(TestCase):
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

@@ -394,7 +394,7 @@ class TestUnflatten(TestCase):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()

@@ -420,7 +420,7 @@ class TestUnflatten(TestCase):
def test_fx_trace(self):
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):

@@ -439,14 +439,14 @@ class TestUnflatten(TestCase):
def test_double_nested_submodule(self):
class SubSubMod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x * x
class SubMod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.subsubmod = SubSubMod()

@@ -454,7 +454,7 @@ class TestUnflatten(TestCase):
return x - x
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()

@@ -470,7 +470,7 @@ class TestUnflatten(TestCase):
def test_unflatten_container_type(self):
class Leaf(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 4)

@@ -478,7 +478,7 @@ class TestUnflatten(TestCase):
return self.linear(x)
class Bar(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.leaf = Leaf()
self.buffer = torch.nn.Buffer(torch.randn(4, 4))

@@ -487,7 +487,7 @@ class TestUnflatten(TestCase):
return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()
class Foo(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.bar = Bar()

@@ -506,14 +506,14 @@ class TestUnflatten(TestCase):
def test_unflattened_module_nodes_has_meta_val(self):
class SubMod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x + x, x * x
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
@@ -538,7 +538,7 @@ class TestUnflatten(TestCase):
def test_placeholder_and_get_attr_ordering_after_unflattened(self):
class TransposeModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

@@ -564,7 +564,7 @@ class TestUnflatten(TestCase):
def test_unflatten_constant_tensor(self):
class SubMod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.initializer = 0.1

@@ -572,7 +572,7 @@ class TestUnflatten(TestCase):
return x + torch.tensor(self.initializer)
class Mod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()

@@ -604,7 +604,7 @@ class TestUnflatten(TestCase):
return (self.x + self.y) * z
class SubMod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

@@ -612,7 +612,7 @@ class TestUnflatten(TestCase):
return x + self.attr.add_tensor(x)
class Mod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()

@@ -635,7 +635,7 @@ class TestUnflatten(TestCase):
return x + 1
class Nested(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.leaf = Leaf()

@@ -643,7 +643,7 @@ class TestUnflatten(TestCase):
return self.leaf(x) + 2
class TopLevel(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.nested = Nested()

@@ -661,7 +661,7 @@ class TestUnflatten(TestCase):
def test_unflatten_submodule_ordering(self):
class Module2(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.buffer = torch.nn.Buffer(torch.rand(3, 4))
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

@@ -670,7 +670,7 @@ class TestUnflatten(TestCase):
return x + self.buffer + self.param
class Module1(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.buffer = torch.nn.Buffer(torch.rand(3, 4))
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

@@ -679,7 +679,7 @@ class TestUnflatten(TestCase):
return x + self.buffer + self.param
class Module(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.mod2 = Module2()
self.mod3 = self.mod2

@@ -704,7 +704,7 @@ class TestUnflatten(TestCase):
N, C, H, W = 1, 2, 2, 3
class MyModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
layer = torch.nn.LayerNorm([C, H, W])
self.norms = torch.nn.ModuleList(

@@ -735,7 +735,7 @@ class TestUnflatten(TestCase):
def test_simple_alias(self):
# handle weight sharing, check tensor ids after unflattening
class Foo(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
# alias param
self.bias = torch.nn.Parameter(torch.randn(4))

@@ -753,7 +753,7 @@ class TestUnflatten(TestCase):
# handle aliasing where one alias is unused
class Foo(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.bias = torch.nn.Parameter(torch.randn(4))
self.m = torch.nn.Linear(4, 4)

@@ -809,7 +809,7 @@ class TestUnflatten(TestCase):
return y[:d]
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.m1 = M1()
self.m2 = M2()
@@ -139,7 +139,7 @@ class TestVerifier(TestCase):
def test_ep_verifier_invalid_buffer(self) -> None:
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.a = torch.tensor(3.0)

@@ -160,7 +160,7 @@ class TestVerifier(TestCase):
def test_ep_verifier_buffer_mutate(self) -> None:
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))

@@ -183,7 +183,7 @@ class TestVerifier(TestCase):
def test_ep_verifier_invalid_output(self) -> None:
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
@@ -526,7 +526,7 @@ class TestAOTAutograd(AOTTestCase):
inp = [torch.randn(1, 10, 10, dtype=torch.complex64)]
class F(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(10, 10, dtype=torch.complex64)

@@ -540,7 +540,7 @@ class TestAOTAutograd(AOTTestCase):
# test that this works even though the sparse tensor has no storage.
class F(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.emb = torch.nn.EmbeddingBag(100, 8, sparse=True)

@@ -1004,7 +1004,7 @@ def forward(self, primals_1):
@skipIfTorchDynamo("This test suite already uses dynamo")
def test_composite_impl_compile(self):
class Foo(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(3, 3)

@@ -3584,7 +3584,7 @@ def forward(self, tangents_1):
def test_buffer_copied_in_graph(self):
class MyModel(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.buf = torch.nn.Buffer(torch.zeros(1))
self.w1 = torch.nn.Parameter(torch.zeros(1))

@@ -3639,7 +3639,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4):
def test_buffer_copied_in_graph_with_different_shapes(self):
class MyModel(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.buf = torch.nn.Buffer(torch.ones(4, 4))
self.w = torch.nn.Parameter(

@@ -3694,7 +3694,7 @@ def forward(self, primals_1, primals_2, primals_3):
def test_buffer_batch_norm(self):
class MyModel(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.m = torch.nn.BatchNorm1d(100)

@@ -3816,7 +3816,7 @@ def forward(self, tangents_1):
from functorch.experimental import functionalize
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)

@@ -3865,7 +3865,7 @@ def forward(self, tangents_1):
def test_real_weights_in_symbolic_mode_with_inplace_ops(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.buffer = torch.nn.Buffer(torch.ones(4, 5))

@@ -4142,7 +4142,7 @@ def forward(self, arg0_1, arg1_1):
def test_aot_export_predispatch_composite_implicit_linear(self):
class MM(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)

@@ -4219,7 +4219,7 @@ def forward(self, arg0_1, arg1_1):
def test_aot_export_predispatch_buffer_mutation_metadata(self):
class Foo(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.foo = torch.nn.Buffer(torch.zeros(2, 2))

@@ -4287,7 +4287,7 @@ def forward(self, arg0_1, arg1_1):
)
def test_aot_export_predispatch_with_cond_nested(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):

@@ -4365,7 +4365,7 @@ def forward(self, arg0_1):
)
def test_aot_export_predispatch_map_1(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):

@@ -4449,7 +4449,7 @@ def forward(self, arg0_1, arg1_1):
def test_aot_export_predispatch_map_2(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):

@@ -4492,7 +4492,7 @@ def forward(self, arg0_1, arg1_1):
)
def test_aot_export_predispatch_with_cond(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):

@@ -4540,7 +4540,7 @@ def forward(self, arg0_1):
def test_aot_export_predispatch_conv_and_bn(self):
class ConvBatchnorm(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)

@@ -4607,7 +4607,7 @@ def forward(self, arg0_1):
def test_aot_export_module_joint(self):
class ConvBatchnormRelu(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)

@@ -4793,7 +4793,7 @@ class <lambda>(torch.nn.Module):
def test_aot_export_forward_mutation_no_buffer_mut(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))

@@ -4819,7 +4819,7 @@ def forward(self, arg0_1, arg1_1):
def test_aot_export_forward_mutation_multiple_mut(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))

@@ -4928,7 +4928,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
)
def test_aot_export_with_torch_cond(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
@@ -5065,7 +5065,7 @@ class TestPartitioning(AOTTestCase):
# Following module results in inplace ops while tracing. The test checks
# that the meta tensor information is stored for inplace ops.
class MockModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.weight = torch.nn.Parameter(
torch.randn(3072, 768, requires_grad=True)

@@ -5800,7 +5800,7 @@ def forward(self, tangents_1, tangents_2):
class TestAOTModuleSimplified(AOTTestCase):
def test_aot_module_simplified(self):
class MockModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(20, 30)

@@ -5829,7 +5829,7 @@ class TestAOTModuleSimplified(AOTTestCase):
def test_aot_module_simplified_dynamic(self):
class MockModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(20, 30)

@@ -5892,7 +5892,7 @@ class TestAOTModuleSimplified(AOTTestCase):
def test_inference_python_dispatcher(self):
# Extracted from unet
class MockModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.upsample = torch.nn.Upsample(
scale_factor=2, mode="bilinear", align_corners=True

@@ -5911,7 +5911,7 @@ class TestAOTModuleSimplified(AOTTestCase):
def test_aot_module_simplified_preserves_stack_trace(self):
class MockModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(20, 30)

@@ -5952,7 +5952,7 @@ class TestAOTModuleSimplified(AOTTestCase):
def test_aot_module_simplified_preserves_stack_trace_from_mutation(self):
class MockModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):

@@ -6390,7 +6390,7 @@ class MockFXGraphCache:
In memory version of FXGraphCache so we can isolate testing for FXGraphCache
"""
def __init__(self):
def __init__(self) -> None:
self.cache = {}
def save(self, key, gm):
@@ -135,7 +135,7 @@ def _while_loop_tests():
return while_loop(cond_fn, body_fn, (ci, cj, a, b))
class SimpleWithLinear(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
self.dec = torch.nn.Buffer(torch.tensor(1))

@@ -150,7 +150,7 @@ def _while_loop_tests():
return while_loop(cond_fn, body_fn, (iter, x))
class NestedWithLinear(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.mod = SimpleWithLinear()
self.outer_linear = torch.nn.Linear(2, 2)

@@ -834,7 +834,7 @@ def forward(self, pred_1, x_1):
def test_cond_autograd_user_nn_module(self):
class User_nn_module(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, input):

@@ -2964,7 +2964,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
def test_cond_with_module_param_closure(self):
class Mod(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.register_parameter(
"param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False)
@@ -3661,7 +3661,7 @@ class TestMakeFunctional(TestCase):
@parametrize("disable_autograd_tracking", [True, False])
def test_disable_autograd_tracking(self, disable_autograd_tracking):
class Foo(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(3, 3)

@@ -3679,7 +3679,7 @@ class TestMakeFunctional(TestCase):
def test_parameter_tying(self):
class Foo(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.bias = nn.Parameter(torch.randn(3))
self.linear = nn.Linear(3, 3)

@@ -3708,7 +3708,7 @@ class TestMakeFunctional(TestCase):
def test_buffer_tying(self):
class Foo(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.bias = nn.Parameter(torch.randn(3))
self.linear = nn.Linear(3, 3)

@@ -3740,7 +3740,7 @@ class TestMakeFunctional(TestCase):
@parametrize("disable_autograd_tracking", [True, False])
def test_with_buffers_disable_autograd_tracking(self, disable_autograd_tracking):
class Foo(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(3, 3)
self.buffer = nn.Buffer(torch.randn(3))

@@ -3762,7 +3762,7 @@ class TestMakeFunctional(TestCase):
@parametrize("detach_params", [True, False])
def test_using_detach_functional_call(self, detach_params):
class Foo(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(3, 3)
self.buffer = nn.Buffer(torch.randn(3))

@@ -3788,7 +3788,7 @@ class TestMakeFunctional(TestCase):
def test_parameter_tying_grad(self):
class Foo(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(3, 3)
self.weight = self.linear.weight

@@ -3820,7 +3820,7 @@ class TestMakeFunctional(TestCase):
def test_parameter_tying_ensemble(self):
class Foo(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(3, 3)
self.weight = self.linear.weight

@@ -3854,7 +3854,7 @@ class TestMakeFunctional(TestCase):
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_correctness_mnist(self, mechanism):
class Net(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)

@@ -3965,7 +3965,7 @@ class TestMakeFunctional(TestCase):
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_make_functional_state_correctly_returned_after_forward(self, mechanism):
class Net(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(3, 3)

@@ -4021,7 +4021,7 @@ class TestExamplesCorrectness(TestCase):
@parametrize("mechanism", ["make_functional", "functional_call"])
def test_maml_regression(self, device, mechanism):
class ThreeLayerNet(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.fc1 = nn.Linear(1, 40)
self.relu1 = nn.ReLU()
@@ -86,7 +86,7 @@ class TestMinifier(TestCase):
def test_module(self):
class MockModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.relu = torch.nn.ReLU()
@@ -89,7 +89,7 @@ class TestDCE(TestCase):
"""
class TestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

@@ -105,7 +105,7 @@ class TestDCE(TestCase):
"""
class TestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

@@ -122,7 +122,7 @@ class TestDCE(TestCase):
"""
class TestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

@@ -169,7 +169,7 @@ class TestDCE(TestCase):
_is_impure = True
class TestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.relu = ReLUImpure()
@@ -61,7 +61,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]))
self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]))

@@ -106,7 +106,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
# Note: Named as such to result in name collision.
self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]]))

@@ -168,7 +168,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]]))

@@ -211,7 +211,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]]))
self.attr1 = torch.nn.Parameter(torch.tensor([[1.32]]))

@@ -254,7 +254,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr1 = torch.nn.Parameter(torch.randn(2, 3))
self.attr2 = torch.nn.Parameter(torch.randn(2, 3))

@@ -301,7 +301,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr1 = torch.nn.Parameter(torch.randn(4, 4))
self.attr2 = torch.nn.Parameter(torch.randn(4, 4))

@@ -332,7 +332,7 @@ class TestConstFold(TestCase):
"""
class TracedThroughModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.internal_attr = torch.nn.Parameter(torch.randn(2, 3))

@@ -340,7 +340,7 @@ class TestConstFold(TestCase):
return self.internal_attr
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.my_mod = TracedThroughModule()
self.attr = torch.nn.Parameter(torch.randn(2, 3))

@@ -364,7 +364,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.nn.Parameter(torch.randn(2, 3))

@@ -413,7 +413,7 @@ class TestConstFold(TestCase):
def test_const_fold_has_inlined_call_module_node(self):
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.nn.Parameter(torch.randn(2, 3))
self.mod = torch.nn.Identity()

@@ -434,7 +434,7 @@ class TestConstFold(TestCase):
def test_const_fold_module_attr(self):
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.const = torch.nn.Parameter(torch.randn(2, 3))
self.mod = torch.nn.Identity()

@@ -456,7 +456,7 @@ class TestConstFold(TestCase):
def test_const_fold_unused_placeholder(self):
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.const = torch.nn.Parameter(torch.randn(2, 3))

@@ -475,7 +475,7 @@ class TestConstFold(TestCase):
def test_dict_output(self):
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.const = torch.nn.Parameter(torch.randn(2, 3))

@@ -494,7 +494,7 @@ class TestConstFold(TestCase):
def test_two_outputs(self):
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.const = torch.nn.Parameter(torch.randn(2, 3))

@@ -514,7 +514,7 @@ class TestConstFold(TestCase):
def test_three_outputs(self):
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.const = torch.nn.Parameter(torch.randn(2, 3))

@@ -540,7 +540,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.nn.Parameter(torch.randn(2, 3))

@@ -572,7 +572,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.nn.Parameter(torch.randn(2, 3))

@@ -605,7 +605,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(4, 4))
self.bias = torch.nn.Parameter(torch.randn(4))

@@ -650,7 +650,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.lin_input = torch.nn.Parameter(torch.randn(4, 4))
self.lin = torch.nn.Linear(4, 4)

@@ -676,7 +676,7 @@ class TestConstFold(TestCase):
"""
class ConstFoldTestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)
@@ -154,7 +154,7 @@ class TestSplitByTags(TestCase):
class TestSplitOutputType(TestCase):
class TestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()
@@ -672,7 +672,7 @@ class TypeCheckerTest(TestCase):
def test_type_check_conv2D_maxpool2d_flatten(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)

@@ -761,7 +761,7 @@ class TypeCheckerTest(TestCase):
def test_type_typechecl_maxpool2d_3dinput(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.pool = torch.nn.MaxPool2d(5, 8)

@@ -1119,7 +1119,7 @@ class TypeCheckerTest(TestCase):
def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)
@@ -34,7 +34,7 @@ class WrapperModule(torch.nn.Module):
class TestMatcher(JitTestCase):
def test_subgraph_matcher_with_attributes(self):
class LargeModel(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._weight = torch.nn.Parameter(torch.ones(3, 3))
self._bias = torch.nn.Parameter(torch.ones(3, 3))

@@ -53,7 +53,7 @@ class TestMatcher(JitTestCase):
large_model_graph = symbolic_trace(LargeModel()).graph
class PatternModel(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

@@ -228,7 +228,7 @@ class TestMatcher(JitTestCase):
"""Testing SubgraphMatcherWithNameNodeMap with module pattern"""
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)

@@ -236,7 +236,7 @@ class TestMatcher(JitTestCase):
return self.linear(x)
class Pattern(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
@@ -86,7 +86,7 @@ class TestShapeInference(unittest.TestCase):
def test_infer_shape(self):
class TestModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.w_1 = torch.empty([256, 328])
self.b_1 = torch.empty([256])
@@ -22,7 +22,7 @@ class TestSourceMatcher(JitTestCase):
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_linear_relu_linear(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(3, 3)
self.relu = torch.nn.ReLU()

@@ -140,7 +140,7 @@ class TestSourceMatcher(JitTestCase):
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_functional_conv_relu_conv(self):
class FunctionalConv2d(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.stride = (1, 1)
self.padding = (0, 0)

@@ -159,7 +159,7 @@ class TestSourceMatcher(JitTestCase):
)
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = FunctionalConv2d()
self.conv2 = FunctionalConv2d()

@@ -184,7 +184,7 @@ class TestSourceMatcher(JitTestCase):
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_functional_linear_relu_linear(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x, weight, bias):
@@ -400,7 +400,7 @@ class TestSubgraphRewriter(JitTestCase):
"""
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.dtype = torch.float16

@@ -439,7 +439,7 @@ class TestSubgraphRewriter(JitTestCase):
def test_subgraph_rewriter_replaces_referenced_submodules(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.sigmoid = torch.nn.Sigmoid()
self.submod = torch.nn.ReLU()

@@ -449,7 +449,7 @@ class TestSubgraphRewriter(JitTestCase):
return self.submod(self.sigmoid(x))
class Pattern(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.sigmoid = torch.nn.Sigmoid()
self.submod = torch.nn.ReLU()

@@ -458,7 +458,7 @@ class TestSubgraphRewriter(JitTestCase):
return self.submod(self.sigmoid(x))
class Replacement(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.tanh = torch.nn.Tanh()
self.submod = torch.nn.ReLU()

@@ -467,7 +467,7 @@ class TestSubgraphRewriter(JitTestCase):
return self.submod(self.tanh(x))
class Comparison(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.tanh = torch.nn.Tanh()
self.submod = torch.nn.ReLU()

@@ -904,7 +904,7 @@ def forward(self, x):
def test_replacement_with_attrs(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.a = torch.tensor([1])
self.b = torch.tensor([2])

@@ -913,7 +913,7 @@ def forward(self, x):
return x + self.a - self.b
class Pattern(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.a = torch.tensor([1])

@@ -921,7 +921,7 @@ def forward(self, x):
return x + self.a
class Replacement(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.c = torch.tensor([3])
@@ -391,7 +391,7 @@ class HFOperations(unittest.TestCase):
def test_layer_norm(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.l = torch.nn.LayerNorm((1024,))

@@ -723,7 +723,7 @@ class HFOperations(unittest.TestCase):
def test_embedding(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.embedding = torch.nn.Embedding(256008, 1024, padding_idx=1)

@@ -881,7 +881,7 @@ class HFOperations(unittest.TestCase):
def test_view_mul(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1)

@@ -1003,7 +1003,7 @@ class HFOperations(unittest.TestCase):
"""
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1)

@@ -1068,7 +1068,7 @@ class HFOperations(unittest.TestCase):
"""
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.embed_tokens = torch.nn.Embedding(256008, 1024, padding_idx=1)

@@ -1531,7 +1531,7 @@ class GradualTypes(unittest.TestCase):
class TestSingleOperation(unittest.TestCase):
def test_conv_wrong_example(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(
in_channels=2,

@@ -2188,7 +2188,7 @@ class TestSingleOperation(unittest.TestCase):
def test_conv2D_maxpool2d_flatten(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)

@@ -2225,7 +2225,7 @@ class TestSingleOperation(unittest.TestCase):
def test_conv2D_maxpool2d_flatten_unsat(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)

@@ -2258,7 +2258,7 @@ class TestSingleOperation(unittest.TestCase):
def test_conv2D_maxpool2d_flatten_dyn(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)
@@ -82,7 +82,7 @@ def forward(self, arg1_1):
def test_torchbind_custom_op(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

@@ -108,7 +108,7 @@ def forward(self, arg0_1, arg1_1):
def test_print_with_buffer_mutations(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.buf = torch.nn.Buffer(torch.ones(3))

@@ -143,7 +143,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
def test_print_with_input_mutations(self):
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(self, x):

@@ -304,7 +304,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
return torch.nn.functional.linear(x, self.weight, self.bias)
class MockModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linear = MyLinear(10, 10)
self.register_buffer(
@@ -45,7 +45,7 @@ class ModuleReturnMulti(nn.Module):
# The default fx tracer will convert torch.randn to a constant.. We may need
# a custom tracer.
# class ModuleEagerTensor(nn.Module):
# def __init__(self):
# def __init__(self) -> None:
# super().__init__()
#
# def forward(self, a):

@@ -60,7 +60,7 @@ class ModuleReturnMulti(nn.Module):
# Unfortunately, the default fx tracer convert the return value of the forward
# method to a constant.. Comment out for now
# class ModuleReturnEagerTensorOnDefaultDevice(nn.Module):
# def __init__(self):
# def __init__(self) -> None:
# super().__init__()
#
# def forward(self):
@@ -22,7 +22,7 @@ class LazyFuncionalizationTest(TestCase):
metrics.reset()
class Model(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(4, 2, bias=False)
@@ -4,7 +4,7 @@ import torch
class Module(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(1, 10, 5, 1)
@@ -197,7 +197,7 @@ TEST(LiteInterpreterTest, MultipleOps) {
auto testModelFile = "ModelWithMultipleOps.ptl";
// class ModelWithMultipleOps(torch.nn.Module):
// def __init__(self):
// def __init__(self) -> None:
// super().__init__()
// self.ops = torch.nn.Sequential(
// torch.nn.ReLU(),
@@ -74,7 +74,7 @@ class ModelWithScalarList(torch.nn.Module):
# upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
@save_model
class ModelWithFloatList(torch.nn.Upsample):
def __init__(self):
def __init__(self) -> None:
super().__init__(
scale_factor=(2.0,),
mode="linear",

@@ -95,7 +95,7 @@ class ModelWithListOfOptionalTensors(torch.nn.Module):
# int groups=1) -> Tensor
@save_model
class ModelWithArrayOfInt(torch.nn.Conv2d):
def __init__(self):
def __init__(self) -> None:
super().__init__(1, 2, (2, 2), stride=(1, 1), padding=(1, 1))

@@ -120,7 +120,7 @@ class ModelWithStringOptional(torch.nn.Module):
@save_model
class ModelWithMultipleOps(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.ops = torch.nn.Sequential(
torch.nn.ReLU(),
@@ -5,7 +5,7 @@ import torch.nn.functional as F
# https://pytorch.org/docs/stable/nn.html
class NNConvolutionModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.input1d = torch.randn(1, 4, 36)
self.input2d = torch.randn(1, 4, 30, 10)

@@ -42,7 +42,7 @@ class NNConvolutionModule(torch.nn.Module):
class NNPoolingModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.input1d = torch.randn(1, 16, 50)
self.module1d = nn.ModuleList(

@@ -90,7 +90,7 @@ class NNPoolingModule(torch.nn.Module):
class NNPaddingModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.input1d = torch.randn(1, 4, 50)
self.module1d = nn.ModuleList(

@@ -131,7 +131,7 @@ class NNPaddingModule(torch.nn.Module):
class NNNormalizationModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.input1d = torch.randn(1, 4, 50)
self.module1d = nn.ModuleList(

@@ -172,7 +172,7 @@ class NNNormalizationModule(torch.nn.Module):
class NNActivationModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.activations = nn.ModuleList(
[

@@ -215,7 +215,7 @@ class NNActivationModule(torch.nn.Module):
class NNRecurrentModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.rnn = nn.ModuleList(
[

@@ -245,7 +245,7 @@ class NNRecurrentModule(torch.nn.Module):
class NNTransformerModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.transformers = nn.ModuleList(
[

@@ -271,7 +271,7 @@ class NNTransformerModule(torch.nn.Module):
class NNLinearModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.linears = nn.ModuleList(
[

@@ -329,7 +329,7 @@ class NNDistanceModule(torch.nn.Module):
class NNLossFunctionModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]])
self.y = torch.LongTensor([[3, 0, -1, 1]])

@@ -368,7 +368,7 @@ class NNLossFunctionModule(torch.nn.Module):
class NNVisionModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.input = torch.randn(1, 4, 9, 9)
self.vision_modules = nn.ModuleList(

@@ -398,7 +398,7 @@ class NNVisionModule(torch.nn.Module):
class NNShuffleModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.shuffle = nn.ChannelShuffle(2)

@@ -409,7 +409,7 @@ class NNShuffleModule(torch.nn.Module):
class NNUtilsModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.flatten = nn.Sequential(nn.Linear(50, 50), nn.Unflatten(1, (2, 5, 5)))
@@ -3,7 +3,7 @@ import torch.nn as nn
class GeneralQuantModule(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.embedding = torch.ao.nn.quantized.Embedding(
num_embeddings=10, embedding_dim=12

@@ -47,7 +47,7 @@ class GeneralQuantModule(torch.nn.Module):
class DynamicQuantModule:
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.module = self.M()

@@ -55,7 +55,7 @@ class DynamicQuantModule:
return torch.ao.quantization.quantize_dynamic(self.module, dtype=torch.qint8)
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super(DynamicQuantModule.M, self).__init__()
self.rnn = nn.RNN(4, 8, 2)
self.rnncell = nn.RNNCell(4, 8)

@@ -122,7 +122,7 @@ class StaticQuantModule:
return model_int8
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super(StaticQuantModule.M, self).__init__()
self.quant = torch.ao.quantization.QuantStub()
self.input1d = torch.randn(4, 2, 2)

@@ -182,7 +182,7 @@ class FusedQuantModule:
return model_int8
class M(torch.nn.Module):
def __init__(self):
def __init__(self) -> None:
super(FusedQuantModule.M, self).__init__()
self.quant = torch.ao.quantization.QuantStub()
self.input1d = torch.randn(4, 2, 2)
@ -72,7 +72,7 @@ class TestLiteScriptModule(TestCase):
                 return x * y

         class B(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.A0 = A()
                 self.A1 = A()
@ -177,7 +177,7 @@ class TestLiteScriptModule(TestCase):
     def test_method_calls_with_optional_arg(self):
         class A(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()

             # opt arg in script-to-script invocation
@ -185,7 +185,7 @@ class TestLiteScriptModule(TestCase):
                 return x + two

         class B(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.A0 = A()
@ -218,7 +218,7 @@ class TestLiteScriptModule(TestCase):
     def test_unsupported_classtype(self):
         class Foo:
-            def __init__(self):
+            def __init__(self) -> None:
                 return

             def func(self, x: int, y: int):
@ -243,7 +243,7 @@ class TestLiteScriptModule(TestCase):
                 pass

         class MyTestModuleForListWithModuleClass(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = Foo()
@ -267,7 +267,7 @@ class TestLiteScriptModule(TestCase):
                 pass

         class MyTestModuleForDictWithModuleClass(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = Foo()
@ -288,7 +288,7 @@ class TestLiteScriptModule(TestCase):
     def test_module_export_operator_list(self):
         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.weight = torch.ones((20, 1, 5, 5))
                 self.bias = torch.ones(20)
@ -464,7 +464,7 @@ class TestLiteScriptModule(TestCase):
         class A(torch.nn.Module):
             b: Forward

-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.b = B()
@ -523,7 +523,7 @@ class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
     def test_quantization_example(self):
         # From the example in Static Quantization section of https://pytorch.org/docs/stable/quantization.html
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.conv = torch.nn.Conv2d(1, 1, 1)
@ -42,7 +42,7 @@ class TestLiteScriptModule(TestCase):
             id: torch.Tensor

         class Bar(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = Foo(torch.tensor(1))
@ -101,7 +101,7 @@ class TestLiteScriptModule(TestCase):
             id: torch.Tensor

         class Bar(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = Foo(torch.tensor(1))
@ -144,7 +144,7 @@ class TestLiteScriptModule(TestCase):
             baz: Baz

         class Bar(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = Foo(torch.tensor(1), Baz(torch.tensor(1)))
@ -21,7 +21,7 @@ class TestLiteFuseFx(QuantizationLiteTestCase):
     def test_embedding(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
@ -50,7 +50,7 @@ class TestLiteFuseFx(QuantizationLiteTestCase):
     def test_conv2d(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = nn.Conv2d(1, 1, 1)
                 self.conv2 = nn.Conv2d(1, 1, 1)
@ -753,7 +753,7 @@ class TestLazyModules(TestCase):
     @suppress_warnings
     def test_chained_initialization(self):
         class MyNetwork(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear_1 = torch.nn.LazyLinear(15)
                 self.linear_2 = torch.nn.LazyLinear(10)
@ -223,7 +223,7 @@ class TestLoadStateDict(NNTestCase):
     @swap([True, False])
     def test_load_state_dict_custom(self):
         class CustomState(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.param = torch.nn.Parameter(torch.ones(1))
                 self.sub = torch.nn.Linear(5, 5)
@ -264,7 +264,7 @@ class TestLoadStateDict(NNTestCase):
     @parametrize("keep_vars", [True, False])
     def test_load_state_dict_assign_meta(self, keep_vars):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = nn.Linear(3, 5)
                 self.bn = nn.BatchNorm1d(5)
@ -340,7 +340,7 @@ class TestLoadStateDict(NNTestCase):
     @swap([True, False])
     def test_load_state_dict_assign_with_optimizer(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = nn.Linear(3, 5)
                 self.bn = nn.BatchNorm1d(5)
@ -390,7 +390,7 @@ class TestLoadStateDict(NNTestCase):
         # Assigned tensor is allowed to have different properties than initial
         # tensor except for shape
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = nn.Linear(3, 5)
                 self.bn = nn.BatchNorm1d(5)
@ -426,7 +426,7 @@ class TestLoadStateDict(NNTestCase):
     @swap([True, False])
     def test_load_state_dict_with_unexpected_key(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = torch.nn.Linear(5, 10)
@ -627,7 +627,7 @@ class TestStateDictHooks(TestCase):
         # Test with module instance method as hook
         class MyModule(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = torch.nn.Parameter(torch.rand(10))
@ -699,7 +699,7 @@ class TestStateDictHooks(TestCase):
         hook_called = 0

         class MyModule(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = torch.nn.Parameter(torch.rand(10))
@ -813,7 +813,7 @@ class TestStateDictHooks(TestCase):
     def test_register_state_dict_pre_hook(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.a = nn.Sequential(
                     nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)
@ -827,7 +827,7 @@ class TestStateDictHooks(TestCase):
     def test_register_state_dict_pre_hook_lazy_module(self):
         class MyLazyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.layer1 = nn.LazyLinear(8)
                 self.layer2 = nn.LazyLinear(5)
@ -941,7 +941,7 @@ class TestNNParametrization(NNTestCase):
                 return x + 1.0

         class ModelWithoutDeepcopy(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.weight = nn.Parameter(
                     torch.tensor([1.0, 1.0, 1.0, 1.0]), requires_grad=True
@ -353,7 +353,7 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
         )

         class MLP(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = nn.Linear(2, 4, bias=True)
                 self.fc2 = nn.Linear(4, 2, bias=True)
@ -690,7 +690,7 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime):
         )

         class MLP(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = nn.Linear(2, 4, bias=True)
                 self.fc2 = nn.Linear(4, 2, bias=True)
@ -36,7 +36,7 @@ class SampleModelForDynamicShapes(torch.nn.Module):
 class _LargeModel(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.param = torch.nn.Parameter(torch.randn(2**28))  # 1GB
         self.param2 = torch.nn.Parameter(torch.randn(2**28))  # 1GB
@ -3,7 +3,7 @@ import torch.nn.functional as F
 class MNIST(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@ -29,7 +29,7 @@ class PermuteNet(nn.Module):
 class PReluNet(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.features = nn.Sequential(
             nn.PReLU(3),
@ -41,7 +41,7 @@ class PReluNet(nn.Module):
 class FakeQuantNet(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fake_quant = torch.ao.quantization.FakeQuantize()
         self.fake_quant.disable_observer()
@ -26,7 +26,7 @@ class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase):
                 return g.op("Clip", input, min_f=scalar)

         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.clip = MyClip.apply
@ -52,7 +52,7 @@ class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase):
                 return input.clamp(min=0)

         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.clip = MyClip.apply
                 self.relu = MyRelu.apply
@ -89,7 +89,7 @@ class TestExportAsContribOps(pytorch_test_common.ExportTestCase):
     def test_contrib_op_with_loop(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.gelu = torch.nn.GELU(approximate="none")
@ -23,7 +23,7 @@ from torch.testing._internal import common_utils
 # Smoke tests for export methods
 class TestExportModes(pytorch_test_common.ExportTestCase):
     class MyModel(nn.Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super(TestExportModes.MyModel, self).__init__()

         def forward(self, x):
@ -126,7 +126,7 @@ class TestModularizePass(common_utils.TestCase):
         #
         # Minified repro from Background_Matting. https://github.com/pytorch/benchmark/issues/1768
         class TestModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.unused_relu = torch.nn.ReLU()
                 self.used_gelu = torch.nn.GELU()
@ -172,7 +172,7 @@ class TestModularizePass(common_utils.TestCase):
         self, is_exported_program
     ):
         class TestModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.relu = torch.nn.ReLU()
@ -218,7 +218,7 @@ class TestModularizePass(common_utils.TestCase):
     ):
         # Minified repro from basic_gnn_edgecnn.
         class InnerModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.relu = torch.nn.ReLU()
@ -226,7 +226,7 @@ class TestModularizePass(common_utils.TestCase):
                 return self.relu(x)

         class TestModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.inner_module = InnerModule()
@ -104,7 +104,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
     )
     def test_mnist_exported_with_no_warnings(self, diagnostic_rule):
         class MNISTModel(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                 self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
@ -227,7 +227,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
         self,
     ):
         class TraceModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv2 = torch.nn.Conv2d(
                     16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
@ -340,7 +340,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
                 return output + bias

         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.submodule = SubModule()
@ -402,7 +402,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
     def test_dynamo_export_retains_readable_parameter_and_buffer_names(self):
         class SubModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                 self.fc1 = nn.Linear(9216, 128, bias=False)
@ -419,7 +419,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
                 return tensor_x

         class MNISTModel(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                 self.submodule = SubModule()
@ -649,7 +649,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
     def test_exported_program_torch_distributions_normal_Normal(self):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 self.normal = torch.distributions.normal.Normal(0, 1)
                 super().__init__()
@ -825,7 +825,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
         self, include_initializer, use_fake_mode, use_exported_program
     ):
         class MNISTModel(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                 self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
@ -27,7 +27,7 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
     def test_upsample_bilinear2d(self):
         class TestModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear")
@ -51,7 +51,7 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
     def test_upsample_trilinear3d(self):
         class TestModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.upsample = torch.nn.Upsample(scale_factor=2, mode="trilinear")
@ -273,7 +273,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
     def test_mnist(self):
         class MNISTModel(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=True)
                 self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=True)
@ -302,7 +302,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         # This produces op as `torch.ops.aten.log_sigmoid_forward`, instead of the more
         # conventional `torch.ops.aten.log_sigmoid`.
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.m = torch.nn.LogSigmoid()
@ -419,7 +419,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
     def test_transpose_infer_shape(self):
         class TransposeModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
@ -845,7 +845,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
     )
     def test_fx_symbolic_tracer_large_scale_exporter_with_toy_mlp(self):
         class MLPModel(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc0 = nn.Linear(8, 8, bias=True)
                 self.fc1 = nn.Linear(8, 4, bias=True)
@ -1178,7 +1178,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
     )
     def test_large_scale_exporter_with_toy_mlp(self):
         class MLPModel(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc0 = nn.Linear(8, 8, bias=True)
                 self.fc1 = nn.Linear(8, 4, bias=True)
@ -347,7 +347,7 @@ class TestModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()
     def test_roi_heads(self):
         class RoIHeadsModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.transform = _init_test_generalized_rcnn_transform()
                 self.rpn = _init_test_rpn()
@ -300,7 +300,7 @@ class TestONNXOpset(pytorch_test_common.ExportTestCase):
     def test_dropout(self):
         class MyModule(Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.dropout = torch.nn.Dropout(0.5)
@ -109,7 +109,7 @@ class TestONNXScriptExport(common_utils.TestCase):
         # Control flow is tested for _find_onnxscript_op function in torch/onnx/utils.py,
         # which has recursive logic to go through every nodes with subgraph in model proto
         class NestedLoopsModel(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.selu = torch.nn.SELU()
@ -1061,7 +1061,7 @@ class TestOperators(common_utils.TestCase):
         c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)

         class LSTMModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.rnn = torch.nn.LSTM(
                     RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
@ -1157,7 +1157,7 @@ class TestOperators(common_utils.TestCase):
         )

         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.Embedding(4, 8)
@ -1207,7 +1207,7 @@ class TestOperators(common_utils.TestCase):
         )

         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.Embedding(4, 8)
@ -97,7 +97,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
                 return x.contiguous().transpose(0, 1).sum()

         class TraceMe(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = Foo()
@ -149,7 +149,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
                 return torch.neg(x)

         class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.mod = PythonModule()
@ -169,7 +169,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
                 return torch.neg(x)

         class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))
@ -188,7 +188,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
                 return torch.neg(x)

         class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.mod = ModuleToInline()
@ -251,7 +251,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
     def test_onnx_export_script_inline_params(self):
         class ModuleToInline(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.m = torch.nn.Parameter(torch.ones(3, 3))
                 self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
@ -261,7 +261,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
                 return torch.mm(x, self.m)

         class ModuleToExport(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.mod = ModuleToInline()
                 self.param = torch.nn.Parameter(torch.ones(3, 4))
@ -375,7 +375,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
     def test_source_range_propagation(self):
         class ExpandingModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 # Will be expanded during ONNX export
                 self.ln = torch.nn.LayerNorm([1])
@ -485,7 +485,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
                 "box_coder": BoxCoder,
             }

-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.box_coder = BoxCoder(1.4)
@ -888,7 +888,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
         mask_start_point = 0

         class LSTMTraceWrapper(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()

                 self.rnn = torch.nn.LSTM(
@ -1003,7 +1003,7 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
     def test_onnx_aten_fallback_must_not_fallback(self):
         # For BUILD_CAFFE2=0, aten fallback only when not exportable
         class ONNXExportable(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.fc1 = torch.nn.Linear(12, 8)
@ -166,7 +166,7 @@ def _parametrize_rnn_args(arg_name):
 class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_fuse_conv_bn1d(self):
         class Fuse(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
                 self.bn = torch.nn.BatchNorm1d(33)
@ -181,7 +181,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_fuse_conv_bn2d(self):
         class Fuse(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(
                     3, 2, kernel_size=1, stride=2, padding=3, bias=False
@ -198,7 +198,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_fuse_conv_bn3d(self):
         class Fuse(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv3d(
                     3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False
@ -215,7 +215,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_fuse_conv_in_block(self):
         class Fuse(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(
                     in_channels=5,
@ -1201,7 +1201,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_conv(self):
         class TraceModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
                 self.conv2 = torch.nn.Conv2d(
@ -1222,7 +1222,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_conv_str_padding(self):
         class TraceModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = torch.nn.Conv1d(16, 33, 3, padding="valid")
                 self.conv2 = torch.nn.Conv2d(
@ -1243,7 +1243,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_conv_shape_inference(self):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv2 = torch.nn.Conv2d(
                     16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
@ -1259,7 +1259,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_conv_transpose(self):
         class TraceModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
                 self.conv2 = torch.nn.ConvTranspose2d(
@ -1289,7 +1289,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     # The following test only works when onnx shape inference is enabled.
     def test_transpose_infer_shape(self):
         class TransposeModule(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
@ -2610,7 +2610,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
            return bias

         class ScriptModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.ngram = 2
                 self.max_target_positions = 512
@ -3064,7 +3064,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         )

         class ScriptModule(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.submodule = ScriptModel()
@ -4262,7 +4262,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(16)
     def test_scatter_reduce(self):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()

             def forward(self, x, index, input):
@ -4284,7 +4284,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(16)
     def test_scatter_reduce_self_rank_zero(self):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()

             def forward(self, x, index, input):
@ -4349,7 +4349,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()  # Scripting error: Cannot instantiate nn module
     def test_gather_constant_fold(self):
         class GatherModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.weight = torch.nn.Buffer(torch.ones(5))
                 # torch.nn.Embedding is converted to ONNX::Gather.
@ -4368,7 +4368,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         self.run_test(GatherModule(), (x,))

         class GatherModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.weight = torch.nn.Buffer(torch.ones(2))
@ -4383,7 +4383,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         self.run_test(GatherModule(), (x,))

         class GatherModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.rb = torch.nn.Buffer(torch.randn(1, 1, 3, 1, 1))
@ -4652,7 +4652,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_lstm_no_hidden(self):
         class LSTMModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16)
@ -4665,7 +4665,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_lstm_proj_no_hidden(self):
         class LSTMModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.rnn = torch.nn.LSTM(input_size=16, hidden_size=16, proj_size=8)
@ -4679,7 +4679,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_lstm(self):
         class LSTMModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.rnn = torch.nn.LSTM(
                     RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
@ -4714,7 +4714,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_lstm_default_init_state(self):
         class LSTMModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.rnn = torch.nn.LSTM(
                     RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
@ -4729,7 +4729,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_lstm_fixed_batch_size(self):
         class LSTMModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.lstm = torch.nn.LSTM(
                     RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
@ -4752,7 +4752,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_lstm_post_fix_init_state(self):
         class LSTMModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.lstm = torch.nn.LSTM(
                     RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False
@ -4842,7 +4842,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_lstm_sequence(self):
         class LstmNet(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.rnn1 = torch.nn.LSTM(8, 8, bidirectional=True, batch_first=True)
                 self.linear1 = torch.nn.Linear(8 * 2, 8)
@ -5288,7 +5288,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_gt_primitive(self):
         class GreaterModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.y: int = 2
@ -5629,7 +5629,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_linear(self):
         class LinearModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc = torch.nn.Linear(16, 16)
@ -6580,7 +6580,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_new_zeros_with_dtype(self):
         class MyModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.Embedding(50, 64)
@ -6834,7 +6834,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_inplace_attr_with_loop(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self._bias = torch.arange(
                     12,
@ -6861,7 +6861,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_inplace_attr_copy_with_loop(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self._bias = torch.arange(
                     12,
@ -7282,7 +7282,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_unfold_infer_shape(self):
         class UnfoldModule(torch.jit.ScriptModule):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(3, 1, 3, stride=2)
@ -7347,7 +7347,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_prelu(self):
         class PReluModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.prelu = torch.nn.PReLU()
@ -7370,7 +7370,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_relu6(self):
         class Relu6Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.relu6 = torch.nn.ReLU6()
@ -7389,7 +7389,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_silu(self):
         class SiLUModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.silu = torch.nn.SiLU()
@ -7461,7 +7461,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_mish(self):
         class MishModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.mish = torch.nn.Mish()
@ -9047,7 +9047,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_MSELoss(self):
         class MSELoss(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss1 = torch.nn.MSELoss(reduction="none")
                 self.loss2 = torch.nn.MSELoss(reduction="sum")
@ -9080,7 +9080,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def _kldiv_loss(self, x, y):
         class KLDivLossNone(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
@ -9090,7 +9090,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         self.run_test(KLDivLossNone(), input_args=(x, y))

         class KLDivLossMean(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.KLDivLoss(reduction="mean", log_target=False)
@ -9100,7 +9100,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         self.run_test(KLDivLossMean(), input_args=(x, y))

         class KLDivLossSum(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.KLDivLoss(reduction="sum", log_target=True)
@ -9110,7 +9110,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         self.run_test(KLDivLossSum(), input_args=(x, y))

         class KLDivLossBatchMean(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)
@ -9120,7 +9120,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         self.run_test(KLDivLossBatchMean(), input_args=(x, y))

         class KLDivLossMiniBatchMean(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.KLDivLoss(
                     reduction="batchmean", size_average=False, log_target=True
@ -9134,7 +9134,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss(self):
         class NLLModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.NLLLoss(reduction="none")
                 self.m = torch.nn.LogSoftmax(dim=1)
@ -9154,7 +9154,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss_2d_none(self):
         class NLLModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.NLLLoss(reduction="none")
                 self.conv = torch.nn.Conv2d(16, C, (3, 3))
@ -9175,7 +9175,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss_2d_mean(self):
         class NLLModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.NLLLoss(reduction="mean")
                 self.conv = torch.nn.Conv2d(16, C, (3, 3))
@ -9196,7 +9196,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss_2d_sum(self):
         class NLLModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.NLLLoss(reduction="sum")
                 self.conv = torch.nn.Conv2d(16, C, (3, 3))
@ -9217,7 +9217,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss_2d_mean_weights(self):
         class NLLModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.NLLLoss(reduction="mean", weight=torch.randn(C))
                 self.conv = torch.nn.Conv2d(16, C, (3, 3))
@ -9238,7 +9238,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss_2d_mean_ignore_index(self):
         class NLLModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.NLLLoss(reduction="mean", ignore_index=1)
                 self.conv = torch.nn.Conv2d(16, C, (3, 3))
@ -9296,7 +9296,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss_2d_mean_ignore_index_weights(self):
         class NLLModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.NLLLoss(
                     reduction="mean", weight=torch.randn(C), ignore_index=1
@ -9640,7 +9640,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_dropout(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.dropout = torch.nn.Dropout(0.3)
@ -9657,7 +9657,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_shape_constant_fold(self):
         class ShapeModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.weight = torch.nn.Buffer(torch.ones(5))
@ -9671,7 +9671,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_celu(self):
         class Celu(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.celu = torch.nn.CELU(alpha=1.0)
@ -9684,7 +9684,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
    def test_celu_default(self):
         class Celu(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.celu = torch.nn.CELU()
@ -9697,7 +9697,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_celu_alpha(self):
         class Celu(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.celu = torch.nn.CELU(alpha=2.0)
@ -9710,7 +9710,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_celu_cast(self):
         class Celu(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.celu = torch.nn.CELU()
@ -10046,7 +10046,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_embedding_module(self):
         class EmbedModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.Embedding(4, 3, padding_idx=1)
                 self.emb2 = torch.nn.Embedding(4, 3, padding_idx=1)
@ -10067,7 +10067,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         self.run_test(model, (x,))

         class EmbedModelWithoutPaddingIdx(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.Embedding(4, 3)
@ -10512,7 +10512,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_batchnorm_training(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.bn1 = torch.nn.BatchNorm2d(3, affine=False)
                 self.cv1 = torch.nn.Conv2d(3, 3, 10)
@ -10548,7 +10548,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_batchnorm_training_mode_fix_layer(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
                 self.cv1 = torch.nn.Conv2d(3, 3, 10)
@ -10585,7 +10585,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_batchnorm_eval_mode_train_layer(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.bn1 = torch.nn.BatchNorm2d(3, affine=True)
                 self.cv1 = torch.nn.Conv2d(3, 3, 10)
@ -10622,7 +10622,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_instancenorm_training(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
                 self.cv1 = torch.nn.Conv2d(3, 3, 10)
@ -10658,7 +10658,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_instancenorm_training_mode_fix_layer(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.in1 = torch.nn.InstanceNorm2d(3, affine=True)
                 self.cv1 = torch.nn.Conv2d(3, 3, 10)
@ -10695,7 +10695,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_instancenorm_eval_mode_train_layer(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.in1 = torch.nn.InstanceNorm2d(8, affine=True)
                 self.cv1 = torch.nn.Conv2d(8, 8, 10)
@ -10733,7 +10733,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_dropout_training(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.dropout = torch.nn.Dropout(0.4)
@ -10775,7 +10775,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_dropout_training_zero(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.dropout = torch.nn.Dropout(0.5)
@ -10839,7 +10839,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_conv_bn(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(
                     3, 16, kernel_size=1, stride=2, padding=3, bias=True
@ -10864,7 +10864,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_multiple_conv_bn(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = torch.nn.Conv2d(
                     3, 64, kernel_size=7, stride=2, padding=3, bias=False
@ -11003,7 +11003,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_resize_images(self):
         class TransformModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.transform = _init_test_generalized_rcnn_transform()
@ -11024,7 +11024,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()
     def test_transform_images(self):
         class TransformModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.transform = _init_test_generalized_rcnn_transform()
@ -11055,7 +11055,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()
     def test_rpn(self):
         class RPNModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.rpn = _init_test_rpn()
@ -11094,7 +11094,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()
     def test_multi_scale_roi_align(self):
         class TransformModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.model = torchvision.ops.MultiScaleRoIAlign(
                     ["feat1", "feat2"], 3, 2
@ -11203,7 +11203,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
                 return self.module(x) + self.weights

         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.module = InnerModule(embedding_dim=8)
@ -11243,7 +11243,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
                 ) * self.const

         class Module(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.module = InnerModule(embedding_dim=8)
@ -11256,7 +11256,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     def test_set_attr(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(3, 10, 2)
                 self.b = False
@ -11279,7 +11279,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_set_attr_2(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(10, 3, 3)
                 self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
@ -11304,7 +11304,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_set_attr_3(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(10, 3, 3)
                 self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
@ -11331,7 +11331,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_set_attr_4(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(10, 3, 3)
                 self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
@ -11363,7 +11363,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_set_attr_5(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(10, 3, 3)
                 self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
@ -11394,7 +11394,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_set_attr_in_loop(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(10, 3, 3)
                 self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
@ -11422,7 +11422,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(13)
     def test_set_attr_in_loop_with_list(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv1d(10, 3, 3)
                 self.conv.weight = torch.nn.Parameter(torch.zeros(3, 10))
@ -12196,7 +12196,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_hann_window_periodic(self):
         class HannWindowModule_Periodic(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.window_length = 0
@ -12218,7 +12218,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_hann_window_not_periodic(self):
         class HannWindowModule_NotPeriodic(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.window_length = 0
@ -12241,7 +12241,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()
     def test_hann_window_default_values(self):
         class HannWindowModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.window_length = 0
@ -13098,7 +13098,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(13)
     def test_qat_linear_per_channel(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.linear = torch.nn.Linear(4, 3)
@ -13130,7 +13130,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(13)
     def test_quantized_list_of_inputs_with_cat(self):
         class TestModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.dequant = torch.ao.quantization.DeQuantStub()
@ -13151,7 +13151,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(13)
     def test_qat_relu(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.relu = torch.nn.ReLU()
@ -13173,7 +13173,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(13)
     def test_qat_conv2d(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
@ -13204,7 +13204,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(13)
     def test_qat_conv2d_relu(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
@ -13237,7 +13237,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(13)
     def test_qat_conv2d_relu_fused(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.conv = torch.nn.Conv2d(4, 2, 3, stride=2)
@ -13271,7 +13271,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(13)
     def test_qat_linear_relu_fused(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.linear = torch.nn.Linear(4, 2)
@ -13303,7 +13303,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_qat_maxpool2d(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@ -13776,7 +13776,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
         inputs = (coords0, coords1, edge_from, edge_to)

         class MySAGEConv(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.SAGEConvBlock1 = torch_geometric_nn.SAGEConv(
                     2, 512, normalize=True
@ -49,7 +49,7 @@ class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()
     def test_layer_norm_fp16(self):
         class LayerNormModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.layer_norm = torch.nn.LayerNorm([10, 10])
@ -73,7 +73,7 @@ class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()
     def test_softmaxCrossEntropy_fusion_fp16(self):
         class FusionModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.loss = torch.nn.NLLLoss(reduction="none")
                 self.m = torch.nn.LogSoftmax(dim=1)
@ -97,7 +97,7 @@ class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
     @skipScriptTest()
     def test_apex_o2(self):
         class LinearModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 5)
@ -133,7 +133,7 @@ class TestONNXRuntime_cuda(onnx_test_common._TestONNXRuntime):
     @skipIfNoCuda
     def test_deduplicate_initializers_diff_devices(self):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w = torch.nn.Parameter(
                     torch.ones(2, 3, device=torch.device("cpu"))
@@ -397,7 +397,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_constant_fold_unsqueeze_multi_axies(self):
         class PReluModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.prelu = torch.nn.PReLU()

@@ -490,7 +490,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_constant_fold_lstm(self):
         class GruNet(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)

@@ -521,7 +521,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_constant_fold_transpose_matmul(self):
         class MatMulNet(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.B = torch.nn.Parameter(torch.ones(5, 3))

@@ -694,7 +694,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_constant_fold_shape(self):
         class ShapeModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.weight = torch.nn.Buffer(torch.ones(5))

@@ -845,7 +845,7 @@ class TestUtilityFuns(_BaseTestCase):
                 return x * x

         class Outer(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.inner = torch.jit.script(Inner())

@@ -1137,7 +1137,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_node_scope(self):
         class N(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.relu = torch.nn.ReLU()

@@ -1566,7 +1566,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_unused_initializers(self):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv2 = torch.nn.ConvTranspose2d(
                     16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1)
@@ -1593,7 +1593,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_scripting_param(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(
                     3, 16, kernel_size=1, stride=2, padding=3, bias=True
@@ -1629,7 +1629,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_fuse_conv_bn(self):
         class Fuse(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(
                     3, 2, kernel_size=1, stride=2, padding=3, bias=True
@@ -1701,7 +1701,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_onnx_value_name(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.in_weight = torch.nn.Parameter(torch.Tensor(3, 3))
                 self.in_bias = torch.nn.Parameter(torch.Tensor(3))
@@ -1734,7 +1734,7 @@ class TestUtilityFuns(_BaseTestCase):

     def test_onnx_node_naming(self):
         class MainModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self._module_1 = torch.nn.Linear(10, 10)
                 self._module_2 = torch.nn.Linear(10, 10)
@@ -1773,7 +1773,7 @@ class TestUtilityFuns(_BaseTestCase):

     def _test_deduplicate_initializers(self, torchscript=False):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.layer1 = torch.nn.Linear(3, 3)
                 self.layer2 = torch.nn.Linear(3, 3)
@@ -1841,7 +1841,7 @@ class TestUtilityFuns(_BaseTestCase):
     @skipIfNoCuda
     def test_deduplicate_initializers_diff_devices(self):
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w_cpu = torch.nn.Parameter(
                     torch.ones(3, device=torch.device("cpu"))
@@ -1914,7 +1914,7 @@ class TestUtilityFuns(_BaseTestCase):
         # upsample scale is a constant, not a model parameter,
         # therefore should be ignored by shared weight deduplication.
         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.upsample_1 = torch.nn.Upsample(scale_factor=2)
                 self.upsample_2 = torch.nn.Upsample(scale_factor=2)

@@ -206,7 +206,7 @@ class TestFindMismatch(pytorch_test_common.ExportTestCase):
         )

         class Model(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.layers = torch.nn.Sequential(
                     torch.nn.Linear(3, 4),

@@ -110,7 +110,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):

     def test_onnx_program_supports_retraced_graph(self):
         class Bar(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buf = torch.nn.Buffer(torch.ones(1))

@@ -119,7 +119,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
                 return x.sum() + self.buf.sum()

         class Foo(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.buf = torch.nn.Buffer(torch.zeros(1))
                 self.bar = Bar()
@@ -209,7 +209,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         for persistent in (True, False):

             class CustomModule(torch.nn.Module):
-                def __init__(self):
+                def __init__(self) -> None:
                     super().__init__()
                     self.register_buffer(
                         "my_buffer", torch.tensor(4.0), persistent=persistent

@@ -47,7 +47,7 @@ load_tests = load_tests

 class TestLRScheduler(TestCase):
     class SchedulerTestNet(torch.nn.Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.conv1 = torch.nn.Conv2d(1, 1, 1)
             self.conv2 = torch.nn.Conv2d(1, 1, 1)
@@ -1572,7 +1572,7 @@ class TestLRScheduler(TestCase):

         # Case 3: Custom `scale_fn`, a callable class
         class ScaleFn:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.x = 0.5

             def __call__(self, _):

@@ -34,7 +34,7 @@ class NewModule(torch.nn.Module):
 class UsesInterface(torch.nn.Module):
     proxy_mod: ModuleInterface

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.proxy_mod = OrigModule()

@@ -22,7 +22,7 @@ def uses_script_class(x):


 class IdListFeature:
-    def __init__(self):
+    def __init__(self) -> None:
         self.id_list = torch.ones(1, 1)

     def returns_self(self) -> "IdListFeature":

@@ -7,7 +7,7 @@ try:
     from torchvision.models import resnet18

     class TorchVisionTest(torch.nn.Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.tvmod = resnet18()

@@ -273,7 +273,7 @@ class TestDependencyAPI(PackageTestCase):
             return module

         class BrokenImporter(Importer):
-            def __init__(self):
+            def __init__(self) -> None:
                 self.modules = {
                     "foo": create_module("foo"),
                     "bar": create_module("bar"),

@@ -169,7 +169,7 @@ class TestPackageFX(PackageTestCase):

     def test_package_fx_wrap(self):
         class TestModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()

             def forward(self, a):

@@ -83,7 +83,7 @@ class TestPackageScript(PackageTestCase):
         class UsesInterface(torch.nn.Module):
             proxy_mod: ModuleInterface

-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.proxy_mod = ImplementsInterface()

@@ -246,7 +246,7 @@ class TestPackageScript(PackageTestCase):
                 return input

         class TopMod(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.modB = Submod()

@@ -710,7 +710,7 @@ class TestPackageScript(PackageTestCase):
         """

         class TorchVisionTestInline(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.tvmod = resnet18()

@@ -749,7 +749,7 @@ class TestPackageScript(PackageTestCase):
         """

         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.foo = torch.ones(2, 3)

@@ -66,7 +66,7 @@ class LazyLinear(torch.nn.Module):


 class RecordInputOutputDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
-    def __init__(self):
+    def __init__(self) -> None:
         self.results = []

     def mark_region(self, name: str):

@@ -244,7 +244,7 @@ class TestProfiler(TestCase):
             return w.sum()

         class DummyModule(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(
                     3, 2, kernel_size=1, stride=2, padding=3, bias=False
@@ -351,7 +351,7 @@ class TestProfiler(TestCase):
         end_barrier = threading.Barrier(num_threads, timeout=timeout)

         class Task(threading.Thread):
-            def __init__(self):
+            def __init__(self) -> None:
                 self._end_gate = threading.Event()
                 super().__init__(daemon=True)
                 self.start()
@@ -763,7 +763,7 @@ class TestProfiler(TestCase):
                 return x + 2

         class C(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.A0 = A()
                 self.B0 = B()
@@ -1423,7 +1423,7 @@ from torch.profiler import supported_activities, profile
 from torch.autograd.profiler import KinetoStepTracker

 class SimpleNet(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc1 = nn.Linear(10, 5)
         self.fc2 = nn.Linear(5, 2)
@@ -1914,7 +1914,7 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters


 class SimpleNet(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc1 = nn.Linear(10, 5)
         self.fc2 = nn.Linear(5, 2)

@@ -556,7 +556,7 @@ class TestProfilerTree(TestCase):
     @ProfilerTree.test
     def test_profiler_experimental_tree_with_stack_and_modules(self):
         class MyModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.layers = [
                     torch.nn.ReLU(),

@@ -96,7 +96,7 @@ class TestRecordFunction(TestCase):

     def test_datapipe_delegation_with_profiler(self):
         class IDPIterator(torch.utils.data.IterDataPipe):
-            def __init__(self):
+            def __init__(self) -> None:
                 self.data = list(range(10))
                 self._idx = 0

@@ -47,7 +47,7 @@ def find_node_with_regex(nodes, pattern):


 class SimpleNet(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc1 = nn.Linear(10, 5)
         self.fc2 = nn.Linear(5, 2)

@@ -514,7 +514,7 @@ class TestSerialization(TestCase):
     )
     def test_lstm(self):
         class LSTMModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.lstm = nnqd.LSTM(input_size=3, hidden_size=7, num_layers=1).to(
                     dtype=torch.float
@@ -544,7 +544,7 @@ class TestSerialization(TestCase):

     def test_default_qat_qconfig(self):
         class Model(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = nn.Linear(5, 5)
                 self.relu = nn.ReLU()

@@ -70,7 +70,7 @@ class TestAdaround(QuantizationTestCase):

     def get_feed_forward_wrapper(self):
         class FeedForwardWrapper(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()

             def forward(self, model, sample):
@@ -81,7 +81,7 @@ class TestAdaround(QuantizationTestCase):

     def test_linear_chain(self):
         class LinearChain(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = nn.Linear(3, 4)
                 self.linear2 = nn.Linear(4, 5)
@@ -110,7 +110,7 @@ class TestAdaround(QuantizationTestCase):

     def test_conv_chain(self):
         class ConvChain(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv2d1 = nn.Conv2d(3, 4, 5, 5)
                 self.conv2d2 = nn.Conv2d(4, 5, 5, 5)

@@ -21,7 +21,7 @@ import itertools
 import tempfile

 class Foo(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.qscheme = torch.per_tensor_symmetric

@@ -1414,7 +1414,7 @@ class TestQuantizedTensor(TestCase):
         class M(torch.jit.ScriptModule):
             __constants__ = ['fname']

-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fname = fname

@@ -19,7 +19,7 @@ class TestUtils(TestCase):

     def test_get_fqn_to_example_inputs_simple(self):
         class Sub(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = torch.nn.Linear(5, 5)
                 self.linear2 = torch.nn.Linear(5, 5)
@@ -30,7 +30,7 @@ class TestUtils(TestCase):
                 return x

         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = torch.nn.Linear(5, 5)
                 self.linear2 = torch.nn.Linear(5, 5)
@@ -57,7 +57,7 @@ class TestUtils(TestCase):
         """ Test that we can get example inputs for functions with default keyword arguments
         """
         class Sub(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = torch.nn.Linear(5, 5)
                 self.linear2 = torch.nn.Linear(5, 5)
@@ -68,7 +68,7 @@ class TestUtils(TestCase):
                 return x

         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = torch.nn.Linear(5, 5)
                 self.linear2 = torch.nn.Linear(5, 5)
@@ -98,7 +98,7 @@ class TestUtils(TestCase):
         """ Test that we can record complex example inputs such as lists and dicts
         """
         class Sub(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = torch.nn.Linear(5, 5)
                 self.linear2 = torch.nn.Linear(5, 5)
@@ -109,7 +109,7 @@ class TestUtils(TestCase):
                 return x

         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = torch.nn.Linear(5, 5)
                 self.linear2 = torch.nn.Linear(5, 5)

@@ -996,7 +996,7 @@ class TestDistributed(QuantizationTestCase):
         with override_quantized_engine('fbgemm'):
             # create conv-bn
             class Model(nn.Module):
-                def __init__(self):
+                def __init__(self) -> None:
                     super().__init__()
                     self.conv = nn.Conv2d(4, 1, 3, padding=1)
                     self.bn = nn.BatchNorm2d(1)
@@ -1045,7 +1045,7 @@ class TestDistributed(QuantizationTestCase):
         """
         class Model(nn.Module):

-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = nn.Conv2d(1, 1, 1)
                 self.bn = nn.BatchNorm2d(1)
@@ -1276,7 +1276,7 @@ class TestFusedObsFakeQuantModule(TestCase):

     def test_embedding_bag_qat_config(self):
         class Model(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                   include_last_offset=True, scale_grad_by_freq=False, mode='sum')
@@ -1356,7 +1356,7 @@ class TestFusedObsFakeQuantModule(TestCase):

     def test_default_fused_qat_config(self):
         class Model(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = nn.Linear(2, 2)
                 self.relu = nn.ReLU()

@@ -629,7 +629,7 @@ class TestFakeQuantizeOps(TestCase):

     def test_fake_quant_preserves_qparam_shapes_for_activations(self):
         class Model(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = nn.Linear(4, 4)

@@ -67,7 +67,7 @@ class TestBiasCorrectionEager(QuantizationTestCase):
     @skipIfNoFBGEMM
     def test_linear_chain(self):
         class LinearChain(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = nn.Linear(3, 4)
                 self.linear2 = nn.Linear(4, 5)
@@ -86,7 +86,7 @@ class TestBiasCorrectionEager(QuantizationTestCase):
     @skipIfNoFBGEMM
     def test_conv_chain(self):
         class ConvChain(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv2d1 = nn.Conv2d(3, 4, 5, 5)
                 self.conv2d2 = nn.Conv2d(4, 5, 5, 5)

@@ -72,7 +72,7 @@ class TestEqualizeEager(QuantizationTestCase):
         given the same input
         '''
         class ChainModule(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = nn.Linear(3, 4)
                 self.linear2 = nn.Linear(4, 5)
@@ -108,7 +108,7 @@ class TestEqualizeEager(QuantizationTestCase):
         yield the same output given the same input
         '''
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
                 self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float)
@@ -154,7 +154,7 @@ class TestEqualizeEager(QuantizationTestCase):
         yield the same output given the same input
         '''
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear1 = nn.Linear(3, 4)
                 self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float)

@@ -39,7 +39,7 @@ from torch.testing._internal.common_quantized import override_qengines
 from torch.testing._internal.common_utils import IS_ARM64

 class SubModule(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.qconfig = default_qconfig
         self.mod1 = torch.nn.Conv2d(3, 3, 3, bias=False).to(dtype=torch.float)
@@ -56,7 +56,7 @@ class SubModule(torch.nn.Module):


 class ModelWithSubModules(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.mod1 = SubModule()
         self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
@@ -68,7 +68,7 @@ class ModelWithSubModules(torch.nn.Module):


 class ModelWithFunctionals(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.mycat = nnq.FloatFunctional()
         self.myadd = nnq.FloatFunctional()

@@ -81,7 +81,7 @@ class TestQuantizeEagerOps(QuantizationTestCase):
                        extra_module_kwargs,
                        input_size):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = float_module_class(**extra_module_kwargs)
                 self.quant = QuantStub()
@@ -94,7 +94,7 @@ class TestQuantizeEagerOps(QuantizationTestCase):
                 return x

         class RefM(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = float_module_class(**extra_module_kwargs)
                 self.quant1 = QuantStub()
@@ -203,7 +203,7 @@ class TestQuantizeEagerOps(QuantizationTestCase):
     def test_int16_reference_module(self):

         class RefM(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = nn.ConvTranspose2d(1, 1, 1)
                 self.quant1 = QuantStub()
@@ -277,7 +277,7 @@ class TestQuantizeEagerOps(QuantizationTestCase):
             extra_module_kwargs: keyword args to instantiate the float module
         """
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.activation_op = float_module_class(**extra_module_kwargs)
                 self.quant = QuantStub()
@@ -839,7 +839,7 @@ class TestQuantizeEagerPTQStatic(QuantizationTestCase):
         self.checkScriptable(quantized_model, [[indices, offsets, per_sample_weights]], check_save_load=True)

         class EmbeddingBagWithLinear(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
@@ -861,7 +861,7 @@ class TestQuantizeEagerPTQStatic(QuantizationTestCase):
     @skipIfNoFBGEMM
     def test_custom_module_class(self):
         class CustomModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = torch.nn.Conv2d(1, 1, 1)

@@ -901,7 +901,7 @@ class TestQuantizeEagerPTQStatic(QuantizationTestCase):
                 return quantized

         class Sub(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.custom = CustomModule()

@@ -909,7 +909,7 @@ class TestQuantizeEagerPTQStatic(QuantizationTestCase):
                 return self.custom(x)

         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = QuantStub()
                 self.conv = torch.nn.Conv2d(1, 1, 1)
@@ -924,7 +924,7 @@ class TestQuantizeEagerPTQStatic(QuantizationTestCase):
                 return x

         class RefM(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = QuantStub()
                 self.conv1 = torch.nn.Conv2d(1, 1, 1)
@@ -1031,7 +1031,7 @@ class TestQuantizeEagerPTQStatic(QuantizationTestCase):
         `non_leaf_module_list`.
         """
         class MyModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = QuantStub()
                 self.sigmoid = torch.nn.Sigmoid()
@@ -1477,7 +1477,7 @@ class TestQuantizeEagerPTQDynamic(QuantizationTestCase):
     @skipIfNoFBGEMM
     def test_embedding_bag_dynamic(self):
         class EmbeddingBagWithLinear(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
@@ -1502,7 +1502,7 @@ class TestQuantizeEagerPTQDynamic(QuantizationTestCase):
     @skipIfNoFBGEMM
     def test_embedding_ops_dynamic(self):
         class EmbeddingWithLinear(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.emb = torch.nn.Embedding(
                     num_embeddings=10, embedding_dim=12, scale_grad_by_freq=False)

@@ -555,7 +555,7 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

     def test_add_scalar_uses_input_qparams(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.ff = torch.ao.nn.quantized.FloatFunctional()
@@ -576,7 +576,7 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

     def test_mul_scalar_uses_input_qparams(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.quant = torch.ao.quantization.QuantStub()
                 self.ff = torch.ao.nn.quantized.FloatFunctional()
@@ -642,7 +642,7 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
 class TestQuantizeEagerQATNumerics(QuantizationTestCase):
     def _test_activation_convert_numerics_impl(self, Act, data):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.act = Act()
                 self.quant = QuantStub()
@@ -664,7 +664,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):

     def test_fixed_qparam_ops(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.sigmoid = torch.nn.Sigmoid()
                 self.hardsigmoid = torch.nn.Hardsigmoid()
@@ -717,7 +717,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):

     def test_relu(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.relu = nn.ReLU()

@@ -835,7 +835,7 @@ class TestEqualizeFx(QuantizationTestCase):
         torch.manual_seed(1)

         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                 self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

@@ -84,7 +84,7 @@ FUSION_CONV_LINEAR_EXAMPLE = torch.nn.Sequential(
 # Test class
 # example model to use for tests
 class ThreeOps(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.linear = nn.Linear(3, 3)
         self.bn = nn.BatchNorm2d(3)
@@ -100,7 +100,7 @@ class ThreeOps(nn.Module):
         return (torch.randn(1, 3, 3, 3),)

 class TwoThreeOps(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.block1 = ThreeOps()
         self.block2 = ThreeOps()
@@ -233,7 +233,7 @@ class TestFxModelReportDetector(QuantizationTestCase):

         # we need to design the model
         class ConvLinearModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = torch.nn.Conv2d(3, 3, 2, 1)
                 self.fc1 = torch.nn.Linear(9, 27)
@@ -433,7 +433,7 @@ class TestFxModelReportDetector(QuantizationTestCase):

         # first we want a QAT model
         class QATConvLinearReluModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 # QuantStub converts tensors from floating point to quantized
                 self.quant = torch.ao.quantization.QuantStub()
@@ -505,7 +505,7 @@ Partition on Output

 class TestFxModelReportObserver(QuantizationTestCase):
     class NestedModifiedSingleLayerLinear(torch.nn.Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.obs1 = ModelReportObserver()
             self.mod1 = SingleLayerLinearModel()
@@ -636,7 +636,7 @@ class TestFxModelReportObserver(QuantizationTestCase):

         # model specific to this test
         class NestedModifiedObserverAfterRelu(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.obs1 = ModelReportObserver()
                 self.mod1 = SingleLayerLinearModel()
@@ -673,7 +673,7 @@ class TestFxModelReportObserver(QuantizationTestCase):

         # set up a basic model
         class TinyNestModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.obs1 = ModelReportObserver()
                 self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
@@ -688,7 +688,7 @@ class TestFxModelReportObserver(QuantizationTestCase):
                 return x

         class LargerIncludeNestModel(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.obs1 = ModelReportObserver()
                 self.nested = TinyNestModule()
@@ -727,7 +727,7 @@ class TestFxModelReportObserver(QuantizationTestCase):
                 return x

         class HighDimensionNet(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.obs1 = ModelReportObserver()
                 self.fc1 = torch.nn.Linear(3, 7)
@@ -786,7 +786,7 @@ class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
     @skipIfNoFBGEMM
     def test_nested_detection_case(self):
         class SingleLinear(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)

@@ -795,7 +795,7 @@ class TestFxModelReportDetectDynamicStatic(QuantizationTestCase):
                 return x

         class TwoBlockNet(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.block1 = SingleLinear()
                 self.block2 = SingleLinear()
@@ -1266,7 +1266,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
             return x

     class TwoBlockComplexNet(torch.nn.Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.block1 = TestFxDetectInputWeightEqualization.SimpleConv((3, 32))
             self.block2 = TestFxDetectInputWeightEqualization.SimpleConv((3, 3))
@@ -1292,7 +1292,7 @@ class TestFxDetectInputWeightEqualization(QuantizationTestCase):
             return (torch.randn((1, 3, 28, 28)),)

     class ReluOnly(torch.nn.Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.relu = torch.nn.ReLU()

@@ -100,7 +100,7 @@ from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_h
 # across various different files, speed of debugging on individual test cases
 # decreases.
 class LinearReluFunctional(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.w1 = nn.Parameter(torch.empty(4, 4))
         self.b1 = nn.Parameter(torch.zeros(4))
@@ -113,7 +113,7 @@ class LinearReluFunctional(nn.Module):


 class LinearFunctional(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.w1 = nn.Parameter(torch.empty(4, 4))
         self.b1 = nn.Parameter(torch.zeros(4))
@@ -125,7 +125,7 @@ class LinearFunctional(nn.Module):


 class LinearReluLinearFunctional(nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.w = nn.Parameter(torch.Tensor(4, 4))
         self.b = nn.Parameter(torch.zeros(4))
@@ -150,7 +150,7 @@ class AddMulFunctional(nn.Module):


 class AllConvAndLinearFusionModules(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         # conv1d
         self.conv1d_0 = nn.Conv1d(1, 1, 1)
@@ -331,7 +331,7 @@ class TestFXGraphMatcher(QuantizationTestCase):
     @skipIfNoFBGEMM
     def test_simple_fun(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w = nn.Parameter(torch.empty(1, 4))
                 self.b = nn.Parameter(torch.zeros(1))
@@ -495,7 +495,7 @@ class TestFXGraphMatcher(QuantizationTestCase):
     @skipIfNoFBGEMM
     def test_nodes_with_equal_types_get_matched(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv1 = nn.Conv2d(1, 1, 1)
                 self.conv2 = nn.Conv2d(1, 1, 1)
@@ -1241,7 +1241,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
         Verifies that logging inputs works correctly
         """
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = nn.Conv2d(1, 1, 1)

@@ -1263,7 +1263,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
         signature for fp32 and int8 tensors.
         """
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.max_pool_2d = nn.MaxPool2d(2)

@@ -1347,7 +1347,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
         int8 inputs.
         """
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.relu = nn.ReLU()

@@ -1401,7 +1401,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
                 return (x1, x2)

         class M2(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.m1 = M1()

@@ -1446,7 +1446,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
                 return x

         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.linear = nn.Linear(1, 1)
                 self.user_module = UserModule()
@@ -1682,7 +1682,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
         Verify that NS APIs work on user defined functions
         """
         class M1(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w1 = nn.Parameter(torch.empty(1, 1))
                 self.b1 = nn.Parameter(torch.zeros(1))
@@ -1695,7 +1695,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
                 return x

         class M2(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w1 = nn.Parameter(torch.empty(1, 1))
                 self.b1 = nn.Parameter(torch.zeros(1))
@@ -1881,7 +1881,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
     @skipIfNoFBGEMM
     def test_int8_shadows_fp32_coverage(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.adaptive_avg_pool = nn.AdaptiveAvgPool2d(1)
                 self.conv = nn.Conv2d(1, 1, 1)
@@ -2048,7 +2048,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
     def test_linear_kwargs_shadow(self):

         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w1 = nn.Parameter(torch.empty(4, 4))
                 self.b1 = nn.Parameter(torch.zeros(4))
@@ -2104,7 +2104,7 @@ class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
     @withQNNPACKBackend
     def test_linear_mod(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = nn.Linear(2, 2)

@@ -2122,7 +2122,7 @@ class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
     @withQNNPACKBackend
     def test_linear_relu_mod(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = nn.Linear(2, 2)
                 self.fc2 = nn.Linear(2, 2)
@@ -2148,7 +2148,7 @@ class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
     @withQNNPACKBackend
     def test_conv_bn_relu_mod(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.conv = nn.Conv2d(1, 1, 1)
                 self.bn = nn.BatchNorm2d(1)
@@ -2173,7 +2173,7 @@ class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
     @withQNNPACKBackend
     def test_functions(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w1 = nn.Parameter(torch.randn(2, 2))
                 self.b1 = nn.Parameter(torch.zeros(2))
@@ -2212,7 +2212,7 @@ class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
     @withQNNPACKBackend
     def test_partial_qconfig_mapping(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc = nn.Linear(2, 2)
                 self.w1 = nn.Parameter(torch.randn(2, 2))
@@ -2504,7 +2504,7 @@ class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
     @withQNNPACKBackend
     def test_custom_functions_and_tracer(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.fc1 = nn.Linear(2, 2)
                 self.fc2 = nn.Linear(2, 2)
@@ -2571,7 +2571,7 @@ class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
     @withQNNPACKBackend
     def test_extract_weights_linear(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w1 = nn.Parameter(torch.randn(2, 2))
                 self.b1 = nn.Parameter(torch.randn(2))
@@ -2710,7 +2710,7 @@ class TestFXNumericSuiteNShadows(FXNumericSuiteQuantizationTestCase):
     @withQNNPACKBackend
     def test_add_loggers_functions(self):
         class M(nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.w1 = nn.Parameter(torch.randn(2, 2))
                 self.b1 = nn.Parameter(torch.randn(2))

File diff suppressed because it is too large
@@ -339,7 +339,7 @@ class TestSubgraphRewriter(JitTestCase):
         Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
         """
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.dtype = torch.float16

@@ -378,7 +378,7 @@ class TestSubgraphRewriter(JitTestCase):

     def test_subgraph_rewriter_replaces_referenced_submodules(self):
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.sigmoid = torch.nn.Sigmoid()
                 self.submod = torch.nn.ReLU()
@@ -388,7 +388,7 @@ class TestSubgraphRewriter(JitTestCase):
                 return self.submod(self.sigmoid(x))

         class Pattern(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.sigmoid = torch.nn.Sigmoid()
                 self.submod = torch.nn.ReLU()
@@ -397,7 +397,7 @@ class TestSubgraphRewriter(JitTestCase):
                 return self.submod(self.sigmoid(x))

         class Replacement(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.id = torch.nn.Identity()
                 self.submod = torch.nn.ReLU()
@@ -406,7 +406,7 @@ class TestSubgraphRewriter(JitTestCase):
                 return self.submod(self.id(x))

         class Comparison(torch.nn.Module):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.id = torch.nn.Identity()
                 self.submod = torch.nn.ReLU()

Some files were not shown because too many files have changed in this diff.
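
As a closing note on the pattern repeated throughout this diff: annotating __init__ with an explicit None return type is what lets a type checker look inside the constructor at all. By default, mypy only checks the bodies of functions that carry at least one annotation, so a bare "def __init__(self):" is skipped, while "def __init__(self) -> None:" is checked. A minimal sketch of the annotated style follows; the module name TinyModel is hypothetical and used only for illustration.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # The "-> None" annotation marks __init__ as typed, so tools such as mypy
    # will type check its body instead of skipping it as an untyped function.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)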