mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144556 Approved by: https://github.com/ezyang
This commit is contained in:
parent
19ce1beb05
commit
775788f93b
|
|
@ -802,7 +802,8 @@ class AddedAttributesTest(JitBackendTestCase):
|
||||||
# Attach bundled inputs which adds several attributes and functions to the model
|
# Attach bundled inputs which adds several attributes and functions to the model
|
||||||
self.lowered_module = (
|
self.lowered_module = (
|
||||||
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
|
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
|
||||||
lowered_module, input # noqa: F821
|
lowered_module, # noqa: F821
|
||||||
|
input,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
post_bundled = self.lowered_module(
|
post_bundled = self.lowered_module(
|
||||||
|
|
|
||||||
|
|
@ -279,23 +279,23 @@ class TestModuleContainers(JitTestCase):
|
||||||
self.moduledict = CustomModuleDict({"submod": self.submod})
|
self.moduledict = CustomModuleDict({"submod": self.submod})
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
assert (
|
assert self.modulelist[0] is self.submod, (
|
||||||
self.modulelist[0] is self.submod
|
"__getitem__ failing for ModuleList"
|
||||||
), "__getitem__ failing for ModuleList"
|
)
|
||||||
assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
|
assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
|
||||||
for module in self.modulelist:
|
for module in self.modulelist:
|
||||||
assert module is self.submod, "__iter__ failing for ModuleList"
|
assert module is self.submod, "__iter__ failing for ModuleList"
|
||||||
|
|
||||||
assert (
|
assert self.sequential[0] is self.submod, (
|
||||||
self.sequential[0] is self.submod
|
"__getitem__ failing for Sequential"
|
||||||
), "__getitem__ failing for Sequential"
|
)
|
||||||
assert len(self.sequential) == 1, "__len__ failing for Sequential"
|
assert len(self.sequential) == 1, "__len__ failing for Sequential"
|
||||||
for module in self.sequential:
|
for module in self.sequential:
|
||||||
assert module is self.submod, "__iter__ failing for Sequential"
|
assert module is self.submod, "__iter__ failing for Sequential"
|
||||||
|
|
||||||
assert (
|
assert self.moduledict["submod"] is self.submod, (
|
||||||
self.moduledict["submod"] is self.submod
|
"__getitem__ failing for ModuleDict"
|
||||||
), "__getitem__ failing for ModuleDict"
|
)
|
||||||
assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
|
assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
|
||||||
|
|
||||||
# note: unable to index moduledict with a string variable currently
|
# note: unable to index moduledict with a string variable currently
|
||||||
|
|
@ -439,9 +439,9 @@ class TestModuleContainers(JitTestCase):
|
||||||
self.moduledict = CustomModuleDict()
|
self.moduledict = CustomModuleDict()
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
assert (
|
assert "submod" not in self.moduledict, (
|
||||||
"submod" not in self.moduledict
|
"__contains__ fails for ModuleDict"
|
||||||
), "__contains__ fails for ModuleDict"
|
)
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
m = MyModule()
|
m = MyModule()
|
||||||
|
|
|
||||||
|
|
@ -405,9 +405,7 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||||
def f():
|
def f():
|
||||||
x = torch.ones(10, 9, 8, 7, 6)
|
x = torch.ones(10, 9, 8, 7, 6)
|
||||||
return x{indices}.shape
|
return x{indices}.shape
|
||||||
""".format(
|
""".format(indices=indices)
|
||||||
indices=indices
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
test_str = test_str.replace(r"'", r"")
|
test_str = test_str.replace(r"'", r"")
|
||||||
scope = {}
|
scope = {}
|
||||||
|
|
|
||||||
|
|
@ -139,9 +139,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -160,9 +158,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -181,9 +177,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -202,9 +196,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -223,9 +215,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -244,9 +234,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -265,9 +253,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -286,9 +272,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -307,9 +291,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -328,9 +310,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
@ -351,9 +331,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||||
):
|
):
|
||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning,
|
UserWarning,
|
||||||
"doesn't support "
|
"doesn't support instance-level annotations on empty non-base types",
|
||||||
"instance-level annotations on "
|
|
||||||
"empty non-base types",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(M())
|
torch.jit.script(M())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -960,8 +960,9 @@ class TestTracer(JitTestCase):
|
||||||
V = Variable
|
V = Variable
|
||||||
a, b = V(torch.rand(1)), V(torch.rand(1))
|
a, b = V(torch.rand(1)), V(torch.rand(1))
|
||||||
ge = torch.jit.trace(foo, (a, b))
|
ge = torch.jit.trace(foo, (a, b))
|
||||||
a, b = V(torch.rand(1), requires_grad=True), V(
|
a, b = (
|
||||||
torch.rand(1), requires_grad=True
|
V(torch.rand(1), requires_grad=True),
|
||||||
|
V(torch.rand(1), requires_grad=True),
|
||||||
)
|
)
|
||||||
(r,) = ge(a, b)
|
(r,) = ge(a, b)
|
||||||
da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
|
da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
|
||||||
|
|
|
||||||
|
|
@ -396,9 +396,7 @@ class TestUnion(JitTestCase):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"only int, float, "
|
"only int, float, complex, Tensor, device and string keys are supported",
|
||||||
"complex, Tensor, device and string keys "
|
|
||||||
"are supported",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(fn)
|
torch.jit.script(fn)
|
||||||
|
|
||||||
|
|
@ -602,9 +600,7 @@ class TestUnion(JitTestCase):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"y is set to type str"
|
"y is set to type str in the true branch and type int in the false branch",
|
||||||
" in the true branch and type int "
|
|
||||||
"in the false branch",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(fn)
|
torch.jit.script(fn)
|
||||||
|
|
||||||
|
|
@ -622,9 +618,7 @@ class TestUnion(JitTestCase):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"previously had type "
|
"previously had type str but is now being assigned to a value of type int",
|
||||||
"str but is now being assigned to a"
|
|
||||||
" value of type int",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(fn)
|
torch.jit.script(fn)
|
||||||
|
|
||||||
|
|
@ -729,8 +723,7 @@ class TestUnion(JitTestCase):
|
||||||
template,
|
template,
|
||||||
"Union[List[str], List[torch.Tensor]]",
|
"Union[List[str], List[torch.Tensor]]",
|
||||||
lhs["list_literal_empty"],
|
lhs["list_literal_empty"],
|
||||||
"there are multiple possible List type "
|
"there are multiple possible List type candidates in the Union annotation",
|
||||||
"candidates in the Union annotation",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._assert_passes(
|
self._assert_passes(
|
||||||
|
|
@ -902,8 +895,7 @@ class TestUnion(JitTestCase):
|
||||||
template,
|
template,
|
||||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||||
lhs["dict_literal_of_mixed"],
|
lhs["dict_literal_of_mixed"],
|
||||||
"none of those dict types can hold the "
|
"none of those dict types can hold the types of the given keys and values",
|
||||||
"types of the given keys and values",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: String frontend does not support tuple unpacking
|
# TODO: String frontend does not support tuple unpacking
|
||||||
|
|
|
||||||
|
|
@ -406,9 +406,7 @@ class TestUnion(JitTestCase):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"only int, float, "
|
"only int, float, complex, Tensor, device and string keys are supported",
|
||||||
"complex, Tensor, device and string keys "
|
|
||||||
"are supported",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(fn)
|
torch.jit.script(fn)
|
||||||
|
|
||||||
|
|
@ -612,9 +610,7 @@ class TestUnion(JitTestCase):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"y is set to type str"
|
"y is set to type str in the true branch and type int in the false branch",
|
||||||
" in the true branch and type int "
|
|
||||||
"in the false branch",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(fn)
|
torch.jit.script(fn)
|
||||||
|
|
||||||
|
|
@ -632,9 +628,7 @@ class TestUnion(JitTestCase):
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
"previously had type "
|
"previously had type str but is now being assigned to a value of type int",
|
||||||
"str but is now being assigned to a"
|
|
||||||
" value of type int",
|
|
||||||
):
|
):
|
||||||
torch.jit.script(fn)
|
torch.jit.script(fn)
|
||||||
|
|
||||||
|
|
@ -739,8 +733,7 @@ class TestUnion(JitTestCase):
|
||||||
template,
|
template,
|
||||||
"List[str] | List[torch.Tensor]",
|
"List[str] | List[torch.Tensor]",
|
||||||
lhs["list_literal_empty"],
|
lhs["list_literal_empty"],
|
||||||
"there are multiple possible List type "
|
"there are multiple possible List type candidates in the Union annotation",
|
||||||
"candidates in the Union annotation",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._assert_passes(
|
self._assert_passes(
|
||||||
|
|
@ -906,8 +899,7 @@ class TestUnion(JitTestCase):
|
||||||
template,
|
template,
|
||||||
"Dict[str, torch.Tensor] | Dict[str, int]",
|
"Dict[str, torch.Tensor] | Dict[str, int]",
|
||||||
lhs["dict_literal_of_mixed"],
|
lhs["dict_literal_of_mixed"],
|
||||||
"none of those dict types can hold the "
|
"none of those dict types can hold the types of the given keys and values",
|
||||||
"types of the given keys and values",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: String frontend does not support tuple unpacking
|
# TODO: String frontend does not support tuple unpacking
|
||||||
|
|
|
||||||
|
|
@ -135,12 +135,14 @@ class TestWarn(JitTestCase):
|
||||||
bar()
|
bar()
|
||||||
|
|
||||||
FileCheck().check_count(
|
FileCheck().check_count(
|
||||||
str="UserWarning: I am warning you from foo", count=1, exactly=True
|
str="UserWarning: I am warning you from foo",
|
||||||
|
count=1,
|
||||||
|
exactly=True,
|
||||||
).check_count(
|
).check_count(
|
||||||
str="UserWarning: I am warning you from bar", count=1, exactly=True
|
str="UserWarning: I am warning you from bar",
|
||||||
).run(
|
count=1,
|
||||||
f.getvalue()
|
exactly=True,
|
||||||
)
|
).run(f.getvalue())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -42,12 +42,12 @@ class LazyGeneratorTest(TestCase):
|
||||||
|
|
||||||
torch._lazy.mark_step()
|
torch._lazy.mark_step()
|
||||||
|
|
||||||
assert torch.allclose(
|
assert torch.allclose(cpu_t1, lazy_t1.to("cpu")), (
|
||||||
cpu_t1, lazy_t1.to("cpu")
|
f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
|
||||||
), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
|
)
|
||||||
assert torch.allclose(
|
assert torch.allclose(cpu_t2, lazy_t2.to("cpu")), (
|
||||||
cpu_t2, lazy_t2.to("cpu")
|
f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
|
||||||
), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
|
)
|
||||||
|
|
||||||
@skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
|
@skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
|
||||||
def test_generator_causes_multiple_compiles(self):
|
def test_generator_causes_multiple_compiles(self):
|
||||||
|
|
@ -69,29 +69,29 @@ class LazyGeneratorTest(TestCase):
|
||||||
torch._lazy.mark_step()
|
torch._lazy.mark_step()
|
||||||
|
|
||||||
uncached_compile = metrics.counter_value("UncachedCompile")
|
uncached_compile = metrics.counter_value("UncachedCompile")
|
||||||
assert (
|
assert uncached_compile == 1, (
|
||||||
uncached_compile == 1
|
f"Expected 1 uncached compiles, got {uncached_compile}"
|
||||||
), f"Expected 1 uncached compiles, got {uncached_compile}"
|
)
|
||||||
|
|
||||||
t = generate_tensor(2)
|
t = generate_tensor(2)
|
||||||
torch._lazy.mark_step()
|
torch._lazy.mark_step()
|
||||||
|
|
||||||
uncached_compile = metrics.counter_value("UncachedCompile")
|
uncached_compile = metrics.counter_value("UncachedCompile")
|
||||||
assert (
|
assert uncached_compile == 2, (
|
||||||
uncached_compile == 2
|
f"Expected 2 uncached compiles, got {uncached_compile}"
|
||||||
), f"Expected 2 uncached compiles, got {uncached_compile}"
|
)
|
||||||
|
|
||||||
t = generate_tensor(1) # noqa: F841
|
t = generate_tensor(1) # noqa: F841
|
||||||
torch._lazy.mark_step()
|
torch._lazy.mark_step()
|
||||||
|
|
||||||
uncached_compile = metrics.counter_value("UncachedCompile")
|
uncached_compile = metrics.counter_value("UncachedCompile")
|
||||||
assert (
|
assert uncached_compile == 2, (
|
||||||
uncached_compile == 2
|
f"Expected 2 uncached compiles, got {uncached_compile}"
|
||||||
), f"Expected 2 uncached compiles, got {uncached_compile}"
|
)
|
||||||
cached_compile = metrics.counter_value("CachedCompile")
|
cached_compile = metrics.counter_value("CachedCompile")
|
||||||
assert (
|
assert cached_compile == 1, (
|
||||||
cached_compile == 1
|
f"Expected 1 cached compile, got {cached_compile}"
|
||||||
), f"Expected 1 cached compile, got {cached_compile}"
|
)
|
||||||
|
|
||||||
metrics.reset()
|
metrics.reset()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -486,17 +486,9 @@ class TestLiteScriptModule(TestCase):
|
||||||
"Traceback of TorchScript"
|
"Traceback of TorchScript"
|
||||||
).check("self.b.forwardError").check_next(
|
).check("self.b.forwardError").check_next(
|
||||||
"~~~~~~~~~~~~~~~~~~~ <--- HERE"
|
"~~~~~~~~~~~~~~~~~~~ <--- HERE"
|
||||||
).check(
|
).check("return self.call").check_next("~~~~~~~~~ <--- HERE").check(
|
||||||
"return self.call"
|
|
||||||
).check_next(
|
|
||||||
"~~~~~~~~~ <--- HERE"
|
|
||||||
).check(
|
|
||||||
"return torch.ones"
|
"return torch.ones"
|
||||||
).check_next(
|
).check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
|
||||||
"~~~~~~~~~~ <--- HERE"
|
|
||||||
).run(
|
|
||||||
str(exp)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
|
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ class TestNnModule(torch.nn.Module):
|
||||||
torch.nn.ReLU(True),
|
torch.nn.ReLU(True),
|
||||||
# state size. (ngf) x 32 x 32
|
# state size. (ngf) x 32 x 32
|
||||||
torch.nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
|
torch.nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
|
||||||
torch.nn.Tanh()
|
torch.nn.Tanh(),
|
||||||
# state size. (nc) x 64 x 64
|
# state size. (nc) x 64 x 64
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -721,9 +721,9 @@ class TestExecutionTrace(TestCase):
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
fp_name = os.path.join(temp_dir, "test.et.json")
|
fp_name = os.path.join(temp_dir, "test.et.json")
|
||||||
|
|
||||||
os.environ[
|
os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"] = (
|
||||||
"ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"
|
"aten::gather"
|
||||||
] = "aten::gather"
|
)
|
||||||
et = ExecutionTraceObserver()
|
et = ExecutionTraceObserver()
|
||||||
et.register_callback(fp_name)
|
et.register_callback(fp_name)
|
||||||
et.set_extra_resource_collection(True)
|
et.set_extra_resource_collection(True)
|
||||||
|
|
|
||||||
|
|
@ -1468,7 +1468,7 @@ class TestProfiler(TestCase):
|
||||||
cats = {e.get("cat", None) for e in j["traceEvents"]}
|
cats = {e.get("cat", None) for e in j["traceEvents"]}
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
"cuda_sync" in cats,
|
"cuda_sync" in cats,
|
||||||
"Expected to find cuda_sync event" f" found = {cats}",
|
f"Expected to find cuda_sync event found = {cats}",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Testing enable_cuda_sync_events in _ExperimentalConfig")
|
print("Testing enable_cuda_sync_events in _ExperimentalConfig")
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,6 @@ class AOMigrationTestCase(TestCase):
|
||||||
new_dict = getattr(new_location, dict_name)
|
new_dict = getattr(new_location, dict_name)
|
||||||
assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
|
assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
|
||||||
for key in new_dict.keys():
|
for key in new_dict.keys():
|
||||||
assert (
|
assert old_dict[key] == new_dict[key], (
|
||||||
old_dict[key] == new_dict[key]
|
f"Dicts don't match: {dict_name} for key {key}"
|
||||||
), f"Dicts don't match: {dict_name} for key {key}"
|
)
|
||||||
|
|
|
||||||
|
|
@ -426,7 +426,6 @@ instantiate_device_type_tests(TestFloat4Dtype, globals())
|
||||||
|
|
||||||
|
|
||||||
class TestFloat8DtypeCPUOnly(TestCase):
|
class TestFloat8DtypeCPUOnly(TestCase):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Test of mul implementation
|
Test of mul implementation
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -248,9 +248,9 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
|
||||||
+ cls._FLOAT_MODULE.__name__
|
+ cls._FLOAT_MODULE.__name__
|
||||||
)
|
)
|
||||||
if not qconfig:
|
if not qconfig:
|
||||||
assert hasattr(
|
assert hasattr(mod, "qconfig"), (
|
||||||
mod, "qconfig"
|
"Input float module must have qconfig defined"
|
||||||
), "Input float module must have qconfig defined"
|
)
|
||||||
assert mod.qconfig, "Input float module must have a valid qconfig"
|
assert mod.qconfig, "Input float module must have a valid qconfig"
|
||||||
qconfig = mod.qconfig
|
qconfig = mod.qconfig
|
||||||
conv, bn = mod[0], mod[1]
|
conv, bn = mod[0], mod[1]
|
||||||
|
|
|
||||||
|
|
@ -99,17 +99,17 @@ class OnDevicePTQUtils:
|
||||||
):
|
):
|
||||||
raise ValueError("Quantized weight must be produced.")
|
raise ValueError("Quantized weight must be produced.")
|
||||||
fp_weight = weight.inputsAt(0).node()
|
fp_weight = weight.inputsAt(0).node()
|
||||||
assert (
|
assert fp_weight.kind() == "prim::GetAttr", (
|
||||||
fp_weight.kind() == "prim::GetAttr"
|
"Weight must be an attribute of the module."
|
||||||
), "Weight must be an attribute of the module."
|
)
|
||||||
fp_weight_name = fp_weight.s("name")
|
fp_weight_name = fp_weight.s("name")
|
||||||
return fp_weight_name
|
return fp_weight_name
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_per_channel_quantized_packed_param(node):
|
def is_per_channel_quantized_packed_param(node):
|
||||||
assert (
|
assert node.kind() == "quantized::linear_prepack", (
|
||||||
node.kind() == "quantized::linear_prepack"
|
"Node must corresponds to linear_prepack."
|
||||||
), "Node must corresponds to linear_prepack."
|
)
|
||||||
weight = node.inputsAt(0).node()
|
weight = node.inputsAt(0).node()
|
||||||
assert (
|
assert (
|
||||||
weight.kind() != "aten::quantize_per_tensor"
|
weight.kind() != "aten::quantize_per_tensor"
|
||||||
|
|
|
||||||
|
|
@ -124,11 +124,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
"aten::dequantize"
|
"aten::dequantize"
|
||||||
).check_not("aten::quantize_per_channel").check("aten::dequantize").check_next(
|
).check_not("aten::quantize_per_channel").check("aten::dequantize").check_next(
|
||||||
"aten::conv2d"
|
"aten::conv2d"
|
||||||
).check_next(
|
).check_next("aten::quantize_per_tensor").check_next("aten::dequantize").run(
|
||||||
"aten::quantize_per_tensor"
|
|
||||||
).check_next(
|
|
||||||
"aten::dequantize"
|
|
||||||
).run(
|
|
||||||
freezed.graph
|
freezed.graph
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -670,9 +666,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
}
|
}
|
||||||
assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
|
assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
|
||||||
assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
|
assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
|
||||||
assert next(iter(activation_dtypes)) != next(
|
assert next(iter(activation_dtypes)) != next(iter(weight_dtypes)), (
|
||||||
iter(weight_dtypes)
|
"Expected activation dtype to "
|
||||||
), "Expected activation dtype to "
|
)
|
||||||
" be different from wegiht dtype"
|
" be different from wegiht dtype"
|
||||||
|
|
||||||
def test_insert_observers_for_reused_weight(self):
|
def test_insert_observers_for_reused_weight(self):
|
||||||
|
|
@ -706,9 +702,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
conv2_observers = attrs_with_prefix(m.conv2, "_observer_")
|
conv2_observers = attrs_with_prefix(m.conv2, "_observer_")
|
||||||
assert len(conv1_observers) == 1, "Expected to have 1 observer submodules"
|
assert len(conv1_observers) == 1, "Expected to have 1 observer submodules"
|
||||||
assert len(conv2_observers) == 1, "Expected to have 1 observer submodules"
|
assert len(conv2_observers) == 1, "Expected to have 1 observer submodules"
|
||||||
assert (
|
assert conv1_observers == conv2_observers, (
|
||||||
conv1_observers == conv2_observers
|
"Expect conv1 and conv2 to have same observers since the class type is shared"
|
||||||
), "Expect conv1 and conv2 to have same observers since the class type is shared"
|
)
|
||||||
|
|
||||||
def test_insert_observers_for_general_ops(self):
|
def test_insert_observers_for_general_ops(self):
|
||||||
"""Make sure we skip observers for ops that doesn't require
|
"""Make sure we skip observers for ops that doesn't require
|
||||||
|
|
@ -734,13 +730,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
'prim::GetAttr[name="conv"]'
|
'prim::GetAttr[name="conv"]'
|
||||||
).check("prim::CallMethod").check(
|
).check("prim::CallMethod").check(
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).check(
|
).check("aten::flatten").check_not(
|
||||||
"aten::flatten"
|
|
||||||
).check_not(
|
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: this is too long, split this to test_insert_observers.py and remove
|
# TODO: this is too long, split this to test_insert_observers.py and remove
|
||||||
# insrt_observers prefix
|
# insrt_observers prefix
|
||||||
|
|
@ -770,17 +762,11 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
'prim::GetAttr[name="conv1"]'
|
'prim::GetAttr[name="conv1"]'
|
||||||
).check("prim::CallMethod").check(
|
).check("prim::CallMethod").check(
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).check(
|
).check("aten::flatten").check_not(
|
||||||
"aten::flatten"
|
|
||||||
).check_not(
|
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).check(
|
).check('prim::GetAttr[name="conv2"]').check(
|
||||||
'prim::GetAttr[name="conv2"]'
|
|
||||||
).check(
|
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_insert_observers_propagate_observed_in_submodule(self):
|
def test_insert_observers_propagate_observed_in_submodule(self):
|
||||||
"""Make sure we propagate observed property through general ops"""
|
"""Make sure we propagate observed property through general ops"""
|
||||||
|
|
@ -809,17 +795,11 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
'prim::GetAttr[name="conv1"]'
|
'prim::GetAttr[name="conv1"]'
|
||||||
).check("prim::CallMethod").check(
|
).check("prim::CallMethod").check(
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).check(
|
).check("prim::CallMethod").check_not(
|
||||||
"prim::CallMethod"
|
|
||||||
).check_not(
|
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).check(
|
).check('prim::GetAttr[name="conv2"]').check(
|
||||||
'prim::GetAttr[name="conv2"]'
|
|
||||||
).check(
|
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_insert_observers_propagate_observed_for_function(self):
|
def test_insert_observers_propagate_observed_for_function(self):
|
||||||
def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
|
def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
|
||||||
|
|
@ -1055,9 +1035,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
|
|
||||||
m(data)
|
m(data)
|
||||||
m = convert_jit(m, debug=True)
|
m = convert_jit(m, debug=True)
|
||||||
assert (
|
assert len(m._modules._c.items()) == 1, (
|
||||||
len(m._modules._c.items()) == 1
|
"Expected to have single submodule of conv"
|
||||||
), "Expected to have single submodule of conv"
|
)
|
||||||
# make sure the quantized model is executable
|
# make sure the quantized model is executable
|
||||||
m(data)
|
m(data)
|
||||||
quant_func = (
|
quant_func = (
|
||||||
|
|
@ -1088,17 +1068,17 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
qconfig_dict = {"": qconfig}
|
qconfig_dict = {"": qconfig}
|
||||||
m = prepare_jit(m, qconfig_dict)
|
m = prepare_jit(m, qconfig_dict)
|
||||||
# observers for input, output and value between conv1/conv2
|
# observers for input, output and value between conv1/conv2
|
||||||
assert (
|
assert len(attrs_with_prefix(m, "_observer_")) == 3, (
|
||||||
len(attrs_with_prefix(m, "_observer_")) == 3
|
"Expected to have 3 obervers"
|
||||||
), "Expected to have 3 obervers"
|
)
|
||||||
# observer for weight
|
# observer for weight
|
||||||
assert (
|
assert len(attrs_with_prefix(m.conv1, "_observer_")) == 1, (
|
||||||
len(attrs_with_prefix(m.conv1, "_observer_")) == 1
|
"Expected to have 1 obervers"
|
||||||
), "Expected to have 1 obervers"
|
)
|
||||||
# observer for weight
|
# observer for weight
|
||||||
assert (
|
assert len(attrs_with_prefix(m.conv2, "_observer_")) == 1, (
|
||||||
len(attrs_with_prefix(m.conv2, "_observer_")) == 1
|
"Expected to have 1 obervers"
|
||||||
), "Expected to have 1 obervers"
|
)
|
||||||
|
|
||||||
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
|
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
|
||||||
m(data)
|
m(data)
|
||||||
|
|
@ -1107,15 +1087,15 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
assert m.conv1._c._type() == m.conv2._c._type()
|
assert m.conv1._c._type() == m.conv2._c._type()
|
||||||
|
|
||||||
# check all observers have been removed
|
# check all observers have been removed
|
||||||
assert (
|
assert len(attrs_with_prefix(m, "_observer_")) == 0, (
|
||||||
len(attrs_with_prefix(m, "_observer_")) == 0
|
"Expected to have 0 obervers"
|
||||||
), "Expected to have 0 obervers"
|
)
|
||||||
assert (
|
assert len(attrs_with_prefix(m.conv1, "_observer_")) == 0, (
|
||||||
len(attrs_with_prefix(m.conv1, "_observer_")) == 0
|
"Expected to have 0 obervers"
|
||||||
), "Expected to have 0 obervers"
|
)
|
||||||
assert (
|
assert len(attrs_with_prefix(m.conv2, "_observer_")) == 0, (
|
||||||
len(attrs_with_prefix(m.conv2, "_observer_")) == 0
|
"Expected to have 0 obervers"
|
||||||
), "Expected to have 0 obervers"
|
)
|
||||||
|
|
||||||
quant_func = (
|
quant_func = (
|
||||||
"aten::quantize_per_channel"
|
"aten::quantize_per_channel"
|
||||||
|
|
@ -1334,11 +1314,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
|
||||||
"aten::avg_pool2d"
|
"aten::avg_pool2d"
|
||||||
).check("aten::q_scale").check_next("aten::q_zero_point").check_next(
|
).check("aten::q_scale").check_next("aten::q_zero_point").check_next(
|
||||||
"prim::dtype"
|
"prim::dtype"
|
||||||
).check_next(
|
).check_next("aten::quantize_per_tensor").check("aten::dequantize").run(
|
||||||
"aten::quantize_per_tensor"
|
|
||||||
).check(
|
|
||||||
"aten::dequantize"
|
|
||||||
).run(
|
|
||||||
model.graph
|
model.graph
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1757,9 +1733,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
|
||||||
"aten::relu"
|
"aten::relu"
|
||||||
).check_not(f"quantized::conv{dim}d(").check_not(
|
).check_not(f"quantized::conv{dim}d(").check_not(
|
||||||
"quantized::relu("
|
"quantized::relu("
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
def test_quantized_add_alpha(self):
|
def test_quantized_add_alpha(self):
|
||||||
|
|
@ -1910,9 +1884,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
|
||||||
"aten::relu("
|
"aten::relu("
|
||||||
).check_not("aten::relu_(").check_not("quantized::add(").check_not(
|
).check_not("aten::relu_(").check_not("quantized::add(").check_not(
|
||||||
"quantized::relu("
|
"quantized::relu("
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
def test_quantized_add(self):
|
def test_quantized_add(self):
|
||||||
|
|
@ -2119,9 +2091,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
|
||||||
"aten::relu("
|
"aten::relu("
|
||||||
).check_not("aten::relu_(").check_not("quantized::add(").check_not(
|
).check_not("aten::relu_(").check_not("quantized::add(").check_not(
|
||||||
"quantized::relu("
|
"quantized::relu("
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
def test_quantized_add_scalar_relu(self):
|
def test_quantized_add_scalar_relu(self):
|
||||||
|
|
@ -2205,11 +2175,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
|
||||||
"aten::relu("
|
"aten::relu("
|
||||||
).check_not("aten::relu_(").check_not(
|
).check_not("aten::relu_(").check_not(
|
||||||
"quantized::add_scalar("
|
"quantized::add_scalar("
|
||||||
).check_not(
|
).check_not("quantized::relu(").run(m.graph)
|
||||||
"quantized::relu("
|
|
||||||
).run(
|
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
def test_quantized_cat(self):
|
def test_quantized_cat(self):
|
||||||
|
|
@ -2544,9 +2510,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
|
||||||
"aten::relu("
|
"aten::relu("
|
||||||
).check_not("aten::relu_(").check_not("quantized::mul(").check_not(
|
).check_not("aten::relu_(").check_not("quantized::mul(").check_not(
|
||||||
"quantized::relu("
|
"quantized::relu("
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
def test_quantized_mul_scalar_relu(self):
|
def test_quantized_mul_scalar_relu(self):
|
||||||
|
|
@ -2629,11 +2593,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
|
||||||
"aten::relu("
|
"aten::relu("
|
||||||
).check_not("aten::relu_(").check_not(
|
).check_not("aten::relu_(").check_not(
|
||||||
"quantized::mul_scalar("
|
"quantized::mul_scalar("
|
||||||
).check_not(
|
).check_not("quantized::relu(").run(m.graph)
|
||||||
"quantized::relu("
|
|
||||||
).run(
|
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
@override_qengines
|
@override_qengines
|
||||||
def test_hardswish(self):
|
def test_hardswish(self):
|
||||||
|
|
@ -3103,9 +3063,7 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).check("prim::CallMethod").check_not(
|
).check("prim::CallMethod").check_not(
|
||||||
'Observer = prim::GetAttr[name="_observer_'
|
'Observer = prim::GetAttr[name="_observer_'
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_insert_quant_dequant_linear_dynamic(self):
|
def test_insert_quant_dequant_linear_dynamic(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
|
|
@ -3126,9 +3084,9 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
|
||||||
else default_dynamic_qconfig
|
else default_dynamic_qconfig
|
||||||
)
|
)
|
||||||
m = quantize_dynamic_jit(m, {"": qconfig}, debug=True)
|
m = quantize_dynamic_jit(m, {"": qconfig}, debug=True)
|
||||||
assert (
|
assert len(m._modules._c.items()) == 2, (
|
||||||
len(m._modules._c.items()) == 2
|
"Expected to have two submodule of linear"
|
||||||
), "Expected to have two submodule of linear"
|
)
|
||||||
|
|
||||||
wt_quant_func = (
|
wt_quant_func = (
|
||||||
"aten::quantize_per_channel"
|
"aten::quantize_per_channel"
|
||||||
|
|
@ -3141,21 +3099,11 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
|
||||||
act_quant_func
|
act_quant_func
|
||||||
).check_next("aten::dequantize").check(
|
).check_next("aten::dequantize").check(
|
||||||
"aten::_choose_qparams_per_tensor"
|
"aten::_choose_qparams_per_tensor"
|
||||||
).check_next(
|
).check_next(act_quant_func).check_next("aten::dequantize").check(
|
||||||
act_quant_func
|
|
||||||
).check_next(
|
|
||||||
"aten::dequantize"
|
|
||||||
).check(
|
|
||||||
wt_quant_func
|
wt_quant_func
|
||||||
).check_next(
|
).check_next("aten::dequantize").check_not(wt_quant_func).check(
|
||||||
"aten::dequantize"
|
|
||||||
).check_not(
|
|
||||||
wt_quant_func
|
|
||||||
).check(
|
|
||||||
"return"
|
"return"
|
||||||
).run(
|
).run(m.graph)
|
||||||
m.graph
|
|
||||||
)
|
|
||||||
|
|
||||||
@override_qengines
|
@override_qengines
|
||||||
def test_dynamic_multi_op(self):
|
def test_dynamic_multi_op(self):
|
||||||
|
|
|
||||||
|
|
@ -254,16 +254,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
maxpool_node = node
|
maxpool_node = node
|
||||||
input_act = maxpool_node.args[0]
|
input_act = maxpool_node.args[0]
|
||||||
assert isinstance(input_act, Node)
|
assert isinstance(input_act, Node)
|
||||||
maxpool_node.meta[
|
maxpool_node.meta["quantization_annotation"] = (
|
||||||
"quantization_annotation"
|
QuantizationAnnotation(
|
||||||
] = QuantizationAnnotation(
|
input_qspec_map={
|
||||||
input_qspec_map={
|
input_act: act_qspec,
|
||||||
input_act: act_qspec,
|
},
|
||||||
},
|
output_qspec=SharedQuantizationSpec(
|
||||||
output_qspec=SharedQuantizationSpec(
|
(input_act, maxpool_node)
|
||||||
(input_act, maxpool_node)
|
),
|
||||||
),
|
_annotated=True,
|
||||||
_annotated=True,
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||||
|
|
@ -339,9 +339,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
def derive_qparams_fn(
|
def derive_qparams_fn(
|
||||||
obs_or_fqs: list[ObserverOrFakeQuantize],
|
obs_or_fqs: list[ObserverOrFakeQuantize],
|
||||||
) -> tuple[Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor]:
|
||||||
assert (
|
assert len(obs_or_fqs) == 2, (
|
||||||
len(obs_or_fqs) == 2
|
f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
|
||||||
), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
|
)
|
||||||
act_obs_or_fq = obs_or_fqs[0]
|
act_obs_or_fq = obs_or_fqs[0]
|
||||||
weight_obs_or_fq = obs_or_fqs[1]
|
weight_obs_or_fq = obs_or_fqs[1]
|
||||||
act_scale, act_zp = act_obs_or_fq.calculate_qparams()
|
act_scale, act_zp = act_obs_or_fq.calculate_qparams()
|
||||||
|
|
@ -442,9 +442,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
def derive_qparams_fn(
|
def derive_qparams_fn(
|
||||||
obs_or_fqs: list[ObserverOrFakeQuantize],
|
obs_or_fqs: list[ObserverOrFakeQuantize],
|
||||||
) -> tuple[Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor]:
|
||||||
assert (
|
assert len(obs_or_fqs) == 1, (
|
||||||
len(obs_or_fqs) == 1
|
f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
|
||||||
), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
|
)
|
||||||
weight_obs_or_fq = obs_or_fqs[0]
|
weight_obs_or_fq = obs_or_fqs[0]
|
||||||
(
|
(
|
||||||
weight_scale,
|
weight_scale,
|
||||||
|
|
@ -748,16 +748,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
(first_input_node, cat_node)
|
(first_input_node, cat_node)
|
||||||
)
|
)
|
||||||
for input_node in input_nodes[1:]:
|
for input_node in input_nodes[1:]:
|
||||||
input_qspec_map[
|
input_qspec_map[input_node] = (
|
||||||
input_node
|
share_qparams_with_input_act0_qspec
|
||||||
] = share_qparams_with_input_act0_qspec
|
)
|
||||||
|
|
||||||
cat_node.meta[
|
cat_node.meta["quantization_annotation"] = (
|
||||||
"quantization_annotation"
|
QuantizationAnnotation(
|
||||||
] = QuantizationAnnotation(
|
input_qspec_map=input_qspec_map,
|
||||||
input_qspec_map=input_qspec_map,
|
output_qspec=share_qparams_with_input_act0_qspec,
|
||||||
output_qspec=share_qparams_with_input_act0_qspec,
|
_annotated=True,
|
||||||
_annotated=True,
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||||
|
|
@ -783,9 +783,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
obs_ins0 = getattr(m, input0.target)
|
obs_ins0 = getattr(m, input0.target)
|
||||||
obs_ins1 = getattr(m, input1.target)
|
obs_ins1 = getattr(m, input1.target)
|
||||||
assert obs_ins0 == obs_ins1
|
assert obs_ins0 == obs_ins1
|
||||||
assert (
|
assert len(conv_output_obs) == 2, (
|
||||||
len(conv_output_obs) == 2
|
"expecting two observer that follows conv2d ops"
|
||||||
), "expecting two observer that follows conv2d ops"
|
)
|
||||||
# checking that the output observers for the two convs are shared as well
|
# checking that the output observers for the two convs are shared as well
|
||||||
assert conv_output_obs[0] == conv_output_obs[1]
|
assert conv_output_obs[0] == conv_output_obs[1]
|
||||||
|
|
||||||
|
|
@ -850,9 +850,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
obs_ins2 = getattr(m, output_obs.target)
|
obs_ins2 = getattr(m, output_obs.target)
|
||||||
assert obs_ins0 == obs_ins2, "input observer does not match output"
|
assert obs_ins0 == obs_ins2, "input observer does not match output"
|
||||||
|
|
||||||
assert (
|
assert len(conv_output_obs) == 2, (
|
||||||
len(conv_output_obs) == 2
|
"expecting two observer that follows conv2d ops"
|
||||||
), "expecting two observer that follows conv2d ops"
|
)
|
||||||
# checking that the output observers for the two convs are shared as well
|
# checking that the output observers for the two convs are shared as well
|
||||||
assert conv_output_obs[0] == conv_output_obs[1]
|
assert conv_output_obs[0] == conv_output_obs[1]
|
||||||
|
|
||||||
|
|
@ -967,16 +967,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
(first_input_node, cat_node)
|
(first_input_node, cat_node)
|
||||||
)
|
)
|
||||||
for input_node in input_nodes[1:]:
|
for input_node in input_nodes[1:]:
|
||||||
input_qspec_map[
|
input_qspec_map[input_node] = (
|
||||||
input_node
|
share_qparams_with_input_act0_qspec
|
||||||
] = share_qparams_with_input_act0_qspec
|
)
|
||||||
|
|
||||||
cat_node.meta[
|
cat_node.meta["quantization_annotation"] = (
|
||||||
"quantization_annotation"
|
QuantizationAnnotation(
|
||||||
] = QuantizationAnnotation(
|
input_qspec_map=input_qspec_map,
|
||||||
input_qspec_map=input_qspec_map,
|
output_qspec=share_qparams_with_input_act0_qspec,
|
||||||
output_qspec=share_qparams_with_input_act0_qspec,
|
_annotated=True,
|
||||||
_annotated=True,
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||||
|
|
@ -1063,16 +1063,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
|
share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
|
||||||
(second_input_node, cat_node)
|
(second_input_node, cat_node)
|
||||||
)
|
)
|
||||||
input_qspec_map[
|
input_qspec_map[first_input_node] = (
|
||||||
first_input_node
|
share_qparams_with_input_act1_qspec
|
||||||
] = share_qparams_with_input_act1_qspec
|
)
|
||||||
|
|
||||||
cat_node.meta[
|
cat_node.meta["quantization_annotation"] = (
|
||||||
"quantization_annotation"
|
QuantizationAnnotation(
|
||||||
] = QuantizationAnnotation(
|
input_qspec_map=input_qspec_map,
|
||||||
input_qspec_map=input_qspec_map,
|
output_qspec=share_qparams_with_input_act1_qspec,
|
||||||
output_qspec=share_qparams_with_input_act1_qspec,
|
_annotated=True,
|
||||||
_annotated=True,
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||||
|
|
@ -1121,17 +1121,17 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||||
share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
|
share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
|
||||||
(second_input_node, add_node)
|
(second_input_node, add_node)
|
||||||
)
|
)
|
||||||
input_qspec_map[
|
input_qspec_map[first_input_node] = (
|
||||||
first_input_node
|
share_qparams_with_input_act1_qspec
|
||||||
] = share_qparams_with_input_act1_qspec
|
)
|
||||||
|
|
||||||
add_node.meta[
|
add_node.meta["quantization_annotation"] = (
|
||||||
"quantization_annotation"
|
QuantizationAnnotation(
|
||||||
] = QuantizationAnnotation(
|
input_qspec_map=input_qspec_map,
|
||||||
input_qspec_map=input_qspec_map,
|
output_qspec=share_qparams_with_input_act1_qspec,
|
||||||
output_qspec=share_qparams_with_input_act1_qspec,
|
allow_implicit_sharing=False,
|
||||||
allow_implicit_sharing=False,
|
_annotated=True,
|
||||||
_annotated=True,
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self, model: torch.fx.GraphModule) -> None:
|
def validate(self, model: torch.fx.GraphModule) -> None:
|
||||||
|
|
|
||||||
|
|
@ -277,9 +277,9 @@ class PT2EQATTestCase(QuantizationTestCase):
|
||||||
|
|
||||||
# Verify: conv literal args
|
# Verify: conv literal args
|
||||||
if expected_conv_literal_args is not None:
|
if expected_conv_literal_args is not None:
|
||||||
assert (
|
assert len(expected_conv_literal_args) == 6, (
|
||||||
len(expected_conv_literal_args) == 6
|
"wrong num conv args, bad test setup"
|
||||||
), "wrong num conv args, bad test setup"
|
)
|
||||||
for i in range(6):
|
for i in range(6):
|
||||||
if i + 3 < len(conv_node.args):
|
if i + 3 < len(conv_node.args):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
|
|
||||||
|
|
@ -157,9 +157,9 @@ async def run1(coroutine_id):
|
||||||
gpuid = coroutine_id % GPUS
|
gpuid = coroutine_id % GPUS
|
||||||
else:
|
else:
|
||||||
gpu_assignments = args.gpus.split(":")
|
gpu_assignments = args.gpus.split(":")
|
||||||
assert args.nproc == len(
|
assert args.nproc == len(gpu_assignments), (
|
||||||
gpu_assignments
|
"Please specify GPU assignment for each process, separated by :"
|
||||||
), "Please specify GPU assignment for each process, separated by :"
|
)
|
||||||
gpuid = gpu_assignments[coroutine_id]
|
gpuid = gpu_assignments[coroutine_id]
|
||||||
|
|
||||||
while progress < len(ALL_TESTS):
|
while progress < len(ALL_TESTS):
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
|
||||||
""" Test functions for limits module.
|
"""Test functions for limits module."""
|
||||||
|
|
||||||
"""
|
|
||||||
import functools
|
import functools
|
||||||
import warnings
|
import warnings
|
||||||
from unittest import expectedFailure as xfail, skipIf
|
from unittest import expectedFailure as xfail, skipIf
|
||||||
|
|
|
||||||
|
|
@ -4104,6 +4104,7 @@ class TestIO(TestCase):
|
||||||
def test_decimal_period_separator():
|
def test_decimal_period_separator():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_decimal_comma_separator():
|
def test_decimal_comma_separator():
|
||||||
with CommaDecimalPointLocale():
|
with CommaDecimalPointLocale():
|
||||||
pass
|
pass
|
||||||
|
|
@ -6786,7 +6787,10 @@ class TestWritebackIfCopy(TestCase):
|
||||||
class TestArange(TestCase):
|
class TestArange(TestCase):
|
||||||
def test_infinite(self):
|
def test_infinite(self):
|
||||||
assert_raises(
|
assert_raises(
|
||||||
(RuntimeError, ValueError), np.arange, 0, np.inf # "unsupported range",
|
(RuntimeError, ValueError),
|
||||||
|
np.arange,
|
||||||
|
0,
|
||||||
|
np.inf, # "unsupported range",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_nan_step(self):
|
def test_nan_step(self):
|
||||||
|
|
|
||||||
|
|
@ -2733,10 +2733,18 @@ class TestMoveaxis(TestCase):
|
||||||
assert_raises(np.AxisError, np.moveaxis, x, 3, 0) # 'source.*out of bounds',
|
assert_raises(np.AxisError, np.moveaxis, x, 3, 0) # 'source.*out of bounds',
|
||||||
assert_raises(np.AxisError, np.moveaxis, x, -4, 0) # 'source.*out of bounds',
|
assert_raises(np.AxisError, np.moveaxis, x, -4, 0) # 'source.*out of bounds',
|
||||||
assert_raises(
|
assert_raises(
|
||||||
np.AxisError, np.moveaxis, x, 0, 5 # 'destination.*out of bounds',
|
np.AxisError,
|
||||||
|
np.moveaxis,
|
||||||
|
x,
|
||||||
|
0,
|
||||||
|
5, # 'destination.*out of bounds',
|
||||||
)
|
)
|
||||||
assert_raises(
|
assert_raises(
|
||||||
ValueError, np.moveaxis, x, [0, 0], [0, 1] # 'repeated axis in `source`',
|
ValueError,
|
||||||
|
np.moveaxis,
|
||||||
|
x,
|
||||||
|
[0, 0],
|
||||||
|
[0, 1], # 'repeated axis in `source`',
|
||||||
)
|
)
|
||||||
assert_raises(
|
assert_raises(
|
||||||
ValueError, # 'repeated axis in `destination`',
|
ValueError, # 'repeated axis in `destination`',
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
"""
|
"""
|
||||||
Test the scalar constructors, which also do type-coercion
|
Test the scalar constructors, which also do type-coercion
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from unittest import skipIf as skipif
|
from unittest import skipIf as skipif
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
"""
|
"""
|
||||||
Test the scalar constructors, which also do type-coercion
|
Test the scalar constructors, which also do type-coercion
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import fractions
|
import fractions
|
||||||
import functools
|
import functools
|
||||||
import types
|
import types
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
|
||||||
""" Test printing of scalar types.
|
"""Test printing of scalar types."""
|
||||||
|
|
||||||
"""
|
|
||||||
import functools
|
import functools
|
||||||
from unittest import skipIf as skipif
|
from unittest import skipIf as skipif
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -811,7 +811,10 @@ class TestBlock(TestCase):
|
||||||
assert_raises_regex(ValueError, msg, block, [[1], 2])
|
assert_raises_regex(ValueError, msg, block, [[1], 2])
|
||||||
assert_raises_regex(ValueError, msg, block, [[], 2])
|
assert_raises_regex(ValueError, msg, block, [[], 2])
|
||||||
assert_raises_regex(
|
assert_raises_regex(
|
||||||
ValueError, msg, block, [[[1], [2]], [[3, 4]], [5]] # missing brackets
|
ValueError,
|
||||||
|
msg,
|
||||||
|
block,
|
||||||
|
[[[1], [2]], [[3, 4]], [5]], # missing brackets
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_empty_lists(self, block):
|
def test_empty_lists(self, block):
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
Copied from fftpack.helper by Pearu Peterson, October 2005
|
Copied from fftpack.helper by Pearu Peterson, October 2005
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
run_tests,
|
run_tests,
|
||||||
TEST_WITH_TORCHDYNAMO,
|
TEST_WITH_TORCHDYNAMO,
|
||||||
|
|
|
||||||
|
|
@ -372,7 +372,7 @@ class TestFFTThreadSafe(TestCase):
|
||||||
assert_allclose(
|
assert_allclose(
|
||||||
q.get(timeout=5),
|
q.get(timeout=5),
|
||||||
expected,
|
expected,
|
||||||
atol=2e-14
|
atol=2e-14,
|
||||||
# msg="Function returned wrong value in multithreaded context",
|
# msg="Function returned wrong value in multithreaded context",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
|
||||||
"""Test functions for 1D array set operations.
|
"""Test functions for 1D array set operations."""
|
||||||
|
|
||||||
"""
|
|
||||||
from unittest import expectedFailure as xfail, skipIf
|
from unittest import expectedFailure as xfail, skipIf
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
|
|
||||||
|
|
@ -2881,7 +2881,8 @@ class TestPercentile(TestCase):
|
||||||
np.testing.assert_equal(res.dtype, arr.dtype)
|
np.testing.assert_equal(res.dtype, arr.dtype)
|
||||||
|
|
||||||
H_F_TYPE_CODES = [
|
H_F_TYPE_CODES = [
|
||||||
(int_type, np.float64) for int_type in "Bbhil" # np.typecodes["AllInteger"]
|
(int_type, np.float64)
|
||||||
|
for int_type in "Bbhil" # np.typecodes["AllInteger"]
|
||||||
] + [
|
] + [
|
||||||
(np.float16, np.float16),
|
(np.float16, np.float16),
|
||||||
(np.float32, np.float32),
|
(np.float32, np.float32),
|
||||||
|
|
|
||||||
|
|
@ -505,8 +505,7 @@ class TestHistogramOptimBinNums(TestCase):
|
||||||
assert_equal(
|
assert_equal(
|
||||||
len(a),
|
len(a),
|
||||||
numbins,
|
numbins,
|
||||||
err_msg=f"For the {estimator} estimator "
|
err_msg=f"For the {estimator} estimator with datasize of {testlen}",
|
||||||
f"with datasize of {testlen}",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_small(self):
|
def test_small(self):
|
||||||
|
|
@ -552,8 +551,7 @@ class TestHistogramOptimBinNums(TestCase):
|
||||||
assert_equal(
|
assert_equal(
|
||||||
len(a),
|
len(a),
|
||||||
expbins,
|
expbins,
|
||||||
err_msg=f"For the {estimator} estimator "
|
err_msg=f"For the {estimator} estimator with datasize of {testlen}",
|
||||||
f"with datasize of {testlen}",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_incorrect_methods(self):
|
def test_incorrect_methods(self):
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
|
||||||
"""Test functions for matrix module
|
"""Test functions for matrix module"""
|
||||||
|
|
||||||
"""
|
|
||||||
import functools
|
import functools
|
||||||
from unittest import expectedFailure as xfail, skipIf as skipif
|
from unittest import expectedFailure as xfail, skipIf as skipif
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
""" Test functions for linalg module
|
"""Test functions for linalg module"""
|
||||||
|
|
||||||
"""
|
|
||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# Owner(s): ["module: dynamo"]
|
# Owner(s): ["module: dynamo"]
|
||||||
|
|
||||||
"""Light smoke test switching between numpy to pytorch random streams.
|
"""Light smoke test switching between numpy to pytorch random streams."""
|
||||||
"""
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ The goal is to validate on numpy, and tests should work when replacing
|
||||||
by
|
by
|
||||||
>>> import torch._numpy as np
|
>>> import torch._numpy as np
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import operator
|
import operator
|
||||||
from unittest import skipIf as skip, SkipTest
|
from unittest import skipIf as skip, SkipTest
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,12 +39,9 @@ USE_BLACK_FILELIST = re.compile(
|
||||||
# test/**
|
# test/**
|
||||||
# test/[a-h]*/**
|
# test/[a-h]*/**
|
||||||
# test/[i-j]*/**
|
# test/[i-j]*/**
|
||||||
"test/j*/**",
|
|
||||||
# test/[k-m]*/**
|
# test/[k-m]*/**
|
||||||
"test/[k-m]*/**",
|
|
||||||
# test/optim/**
|
# test/optim/**
|
||||||
# "test/[p-z]*/**",
|
# test/[p-z]*/**,
|
||||||
"test/[p-z]*/**",
|
|
||||||
# torch/**
|
# torch/**
|
||||||
# torch/_[a-c]*/**
|
# torch/_[a-c]*/**
|
||||||
# torch/_[e-h]*/**
|
# torch/_[e-h]*/**
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user