[BE][PYFMT] migrate PYFMT for test/[i-z]*/ to ruff format (#144556)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144556
Approved by: https://github.com/ezyang
Xuehai Pan 2025-07-24 17:56:13 +08:00 committed by PyTorch MergeBot
parent 19ce1beb05
commit 775788f93b
38 changed files with 234 additions and 321 deletions
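Nearly every hunk below is the same mechanical rewrite. Where black wrapped a long `assert` by parenthesizing the condition and leaving the message dangling, `ruff format` keeps the condition on one line and parenthesizes the message instead; it also joins implicitly concatenated string literals that fit on one line. A minimal before/after sketch of the pattern (the names here are illustrative, not taken from the diff):

    # black style (before): condition wrapped, message dangling
    assert (
        counter == expected
    ), f"Expected {expected}, got {counter}"

    # ruff format style (after): condition inline, message wrapped
    assert counter == expected, (
        f"Expected {expected}, got {counter}"
    )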

View File

@@ -802,7 +802,8 @@ class AddedAttributesTest(JitBackendTestCase):
         # Attach bundled inputs which adds several attributes and functions to the model
         self.lowered_module = (
             torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
-                lowered_module, input  # noqa: F821
+                lowered_module,  # noqa: F821
+                input,
             )
         )
         post_bundled = self.lowered_module(

View File

@@ -279,23 +279,23 @@ class TestModuleContainers(JitTestCase):
                 self.moduledict = CustomModuleDict({"submod": self.submod})
 
             def forward(self, inputs):
-                assert (
-                    self.modulelist[0] is self.submod
-                ), "__getitem__ failing for ModuleList"
+                assert self.modulelist[0] is self.submod, (
+                    "__getitem__ failing for ModuleList"
+                )
                 assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
                 for module in self.modulelist:
                     assert module is self.submod, "__iter__ failing for ModuleList"
 
-                assert (
-                    self.sequential[0] is self.submod
-                ), "__getitem__ failing for Sequential"
+                assert self.sequential[0] is self.submod, (
+                    "__getitem__ failing for Sequential"
+                )
                 assert len(self.sequential) == 1, "__len__ failing for Sequential"
                 for module in self.sequential:
                     assert module is self.submod, "__iter__ failing for Sequential"
 
-                assert (
-                    self.moduledict["submod"] is self.submod
-                ), "__getitem__ failing for ModuleDict"
+                assert self.moduledict["submod"] is self.submod, (
+                    "__getitem__ failing for ModuleDict"
+                )
                 assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
 
                 # note: unable to index moduledict with a string variable currently
@@ -439,9 +439,9 @@ class TestModuleContainers(JitTestCase):
                 self.moduledict = CustomModuleDict()
 
             def forward(self, inputs):
-                assert (
-                    "submod" not in self.moduledict
-                ), "__contains__ fails for ModuleDict"
+                assert "submod" not in self.moduledict, (
+                    "__contains__ fails for ModuleDict"
+                )
                 return inputs
 
         m = MyModule()

View File

@@ -405,9 +405,7 @@ class TestPythonBuiltinOP(JitTestCase):
             def f():
                 x = torch.ones(10, 9, 8, 7, 6)
                 return x{indices}.shape
-            """.format(
-                indices=indices
-            )
+            """.format(indices=indices)
         )
         test_str = test_str.replace(r"'", r"")
         scope = {}

View File

@@ -139,9 +139,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -160,9 +158,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -181,9 +177,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -202,9 +196,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
            ):
                 torch.jit.script(M())
 
@@ -223,9 +215,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -244,9 +234,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -265,9 +253,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -286,9 +272,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -307,9 +291,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -328,9 +310,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())
 
@@ -351,9 +331,7 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
         ):
             with self.assertWarnsRegex(
                 UserWarning,
-                "doesn't support "
-                "instance-level annotations on "
-                "empty non-base types",
+                "doesn't support instance-level annotations on empty non-base types",
             ):
                 torch.jit.script(M())

View File

@@ -960,8 +960,9 @@ class TestTracer(JitTestCase):
         V = Variable
         a, b = V(torch.rand(1)), V(torch.rand(1))
         ge = torch.jit.trace(foo, (a, b))
-        a, b = V(torch.rand(1), requires_grad=True), V(
-            torch.rand(1), requires_grad=True
-        )
+        a, b = (
+            V(torch.rand(1), requires_grad=True),
+            V(torch.rand(1), requires_grad=True),
+        )
         (r,) = ge(a, b)
         da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)

View File

@@ -396,9 +396,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "only int, float, "
-            "complex, Tensor, device and string keys "
-            "are supported",
+            "only int, float, complex, Tensor, device and string keys are supported",
         ):
             torch.jit.script(fn)
 
@@ -602,9 +600,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "y is set to type str"
-            " in the true branch and type int "
-            "in the false branch",
+            "y is set to type str in the true branch and type int in the false branch",
         ):
             torch.jit.script(fn)
 
@@ -622,9 +618,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "previously had type "
-            "str but is now being assigned to a"
-            " value of type int",
+            "previously had type str but is now being assigned to a value of type int",
         ):
             torch.jit.script(fn)
 
@@ -729,8 +723,7 @@ class TestUnion(JitTestCase):
             template,
             "Union[List[str], List[torch.Tensor]]",
             lhs["list_literal_empty"],
-            "there are multiple possible List type "
-            "candidates in the Union annotation",
+            "there are multiple possible List type candidates in the Union annotation",
         )
 
         self._assert_passes(
@@ -902,8 +895,7 @@ class TestUnion(JitTestCase):
             template,
             "Union[Dict[str, torch.Tensor], Dict[str, int]]",
             lhs["dict_literal_of_mixed"],
-            "none of those dict types can hold the "
-            "types of the given keys and values",
+            "none of those dict types can hold the types of the given keys and values",
         )
 
         # TODO: String frontend does not support tuple unpacking

View File

@@ -406,9 +406,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "only int, float, "
-            "complex, Tensor, device and string keys "
-            "are supported",
+            "only int, float, complex, Tensor, device and string keys are supported",
         ):
             torch.jit.script(fn)
 
@@ -612,9 +610,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "y is set to type str"
-            " in the true branch and type int "
-            "in the false branch",
+            "y is set to type str in the true branch and type int in the false branch",
         ):
             torch.jit.script(fn)
 
@@ -632,9 +628,7 @@ class TestUnion(JitTestCase):
 
         with self.assertRaisesRegex(
             RuntimeError,
-            "previously had type "
-            "str but is now being assigned to a"
-            " value of type int",
+            "previously had type str but is now being assigned to a value of type int",
         ):
             torch.jit.script(fn)
 
@@ -739,8 +733,7 @@ class TestUnion(JitTestCase):
             template,
             "List[str] | List[torch.Tensor]",
             lhs["list_literal_empty"],
-            "there are multiple possible List type "
-            "candidates in the Union annotation",
+            "there are multiple possible List type candidates in the Union annotation",
         )
 
         self._assert_passes(
@@ -906,8 +899,7 @@ class TestUnion(JitTestCase):
             template,
             "Dict[str, torch.Tensor] | Dict[str, int]",
             lhs["dict_literal_of_mixed"],
-            "none of those dict types can hold the "
-            "types of the given keys and values",
+            "none of those dict types can hold the types of the given keys and values",
         )
 
         # TODO: String frontend does not support tuple unpacking

View File

@@ -135,12 +135,14 @@ class TestWarn(JitTestCase):
             bar()
 
         FileCheck().check_count(
-            str="UserWarning: I am warning you from foo", count=1, exactly=True
+            str="UserWarning: I am warning you from foo",
+            count=1,
+            exactly=True,
         ).check_count(
-            str="UserWarning: I am warning you from bar", count=1, exactly=True
-        ).run(
-            f.getvalue()
-        )
+            str="UserWarning: I am warning you from bar",
+            count=1,
+            exactly=True,
+        ).run(f.getvalue())
 
 
 if __name__ == "__main__":

View File

@@ -42,12 +42,12 @@ class LazyGeneratorTest(TestCase):
 
         torch._lazy.mark_step()
 
-        assert torch.allclose(
-            cpu_t1, lazy_t1.to("cpu")
-        ), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
-        assert torch.allclose(
-            cpu_t2, lazy_t2.to("cpu")
-        ), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
+        assert torch.allclose(cpu_t1, lazy_t1.to("cpu")), (
+            f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
+        )
+        assert torch.allclose(cpu_t2, lazy_t2.to("cpu")), (
+            f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
+        )
 
     @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
     def test_generator_causes_multiple_compiles(self):
@@ -69,29 +69,29 @@ class LazyGeneratorTest(TestCase):
         torch._lazy.mark_step()
 
         uncached_compile = metrics.counter_value("UncachedCompile")
-        assert (
-            uncached_compile == 1
-        ), f"Expected 1 uncached compiles, got {uncached_compile}"
+        assert uncached_compile == 1, (
+            f"Expected 1 uncached compiles, got {uncached_compile}"
+        )
 
         t = generate_tensor(2)
         torch._lazy.mark_step()
 
         uncached_compile = metrics.counter_value("UncachedCompile")
-        assert (
-            uncached_compile == 2
-        ), f"Expected 2 uncached compiles, got {uncached_compile}"
+        assert uncached_compile == 2, (
+            f"Expected 2 uncached compiles, got {uncached_compile}"
+        )
 
         t = generate_tensor(1)  # noqa: F841
         torch._lazy.mark_step()
 
         uncached_compile = metrics.counter_value("UncachedCompile")
-        assert (
-            uncached_compile == 2
-        ), f"Expected 2 uncached compiles, got {uncached_compile}"
+        assert uncached_compile == 2, (
+            f"Expected 2 uncached compiles, got {uncached_compile}"
+        )
         cached_compile = metrics.counter_value("CachedCompile")
-        assert (
-            cached_compile == 1
-        ), f"Expected 1 cached compile, got {cached_compile}"
+        assert cached_compile == 1, (
+            f"Expected 1 cached compile, got {cached_compile}"
+        )
 
         metrics.reset()

View File

@@ -486,17 +486,9 @@ class TestLiteScriptModule(TestCase):
             "Traceback of TorchScript"
         ).check("self.b.forwardError").check_next(
             "~~~~~~~~~~~~~~~~~~~ <--- HERE"
-        ).check(
-            "return self.call"
-        ).check_next(
-            "~~~~~~~~~ <--- HERE"
-        ).check(
+        ).check("return self.call").check_next("~~~~~~~~~ <--- HERE").check(
             "return torch.ones"
-        ).check_next(
-            "~~~~~~~~~~ <--- HERE"
-        ).run(
-            str(exp)
-        )
+        ).check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
 
 
 class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):

View File

@@ -25,7 +25,7 @@ class TestNnModule(torch.nn.Module):
             torch.nn.ReLU(True),
             # state size. (ngf) x 32 x 32
             torch.nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
-            torch.nn.Tanh()
+            torch.nn.Tanh(),
             # state size. (nc) x 64 x 64
         )
 

View File

@@ -721,9 +721,9 @@ class TestExecutionTrace(TestCase):
         with tempfile.TemporaryDirectory() as temp_dir:
             fp_name = os.path.join(temp_dir, "test.et.json")
-            os.environ[
-                "ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"
-            ] = "aten::gather"
+            os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"] = (
+                "aten::gather"
+            )
             et = ExecutionTraceObserver()
             et.register_callback(fp_name)
             et.set_extra_resource_collection(True)
 

View File

@@ -1468,7 +1468,7 @@ class TestProfiler(TestCase):
         cats = {e.get("cat", None) for e in j["traceEvents"]}
         self.assertTrue(
             "cuda_sync" in cats,
-            "Expected to find cuda_sync event" f" found = {cats}",
+            f"Expected to find cuda_sync event found = {cats}",
         )
 
         print("Testing enable_cuda_sync_events in _ExperimentalConfig")

View File

@@ -47,6 +47,6 @@ class AOMigrationTestCase(TestCase):
         new_dict = getattr(new_location, dict_name)
         assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
         for key in new_dict.keys():
-            assert (
-                old_dict[key] == new_dict[key]
-            ), f"Dicts don't match: {dict_name} for key {key}"
+            assert old_dict[key] == new_dict[key], (
+                f"Dicts don't match: {dict_name} for key {key}"
+            )

View File

@@ -426,7 +426,6 @@ instantiate_device_type_tests(TestFloat4Dtype, globals())
 
 class TestFloat8DtypeCPUOnly(TestCase):
     """
-
     Test of mul implementation

View File

@@ -248,9 +248,9 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
             + cls._FLOAT_MODULE.__name__
         )
         if not qconfig:
-            assert hasattr(
-                mod, "qconfig"
-            ), "Input float module must have qconfig defined"
+            assert hasattr(mod, "qconfig"), (
+                "Input float module must have qconfig defined"
+            )
             assert mod.qconfig, "Input float module must have a valid qconfig"
             qconfig = mod.qconfig
         conv, bn = mod[0], mod[1]

View File

@@ -99,17 +99,17 @@ class OnDevicePTQUtils:
         ):
             raise ValueError("Quantized weight must be produced.")
         fp_weight = weight.inputsAt(0).node()
-        assert (
-            fp_weight.kind() == "prim::GetAttr"
-        ), "Weight must be an attribute of the module."
+        assert fp_weight.kind() == "prim::GetAttr", (
+            "Weight must be an attribute of the module."
+        )
         fp_weight_name = fp_weight.s("name")
         return fp_weight_name
 
     @staticmethod
     def is_per_channel_quantized_packed_param(node):
-        assert (
-            node.kind() == "quantized::linear_prepack"
-        ), "Node must corresponds to linear_prepack."
+        assert node.kind() == "quantized::linear_prepack", (
+            "Node must corresponds to linear_prepack."
+        )
         weight = node.inputsAt(0).node()
         assert (
             weight.kind() != "aten::quantize_per_tensor"

View File

@@ -124,11 +124,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             "aten::dequantize"
         ).check_not("aten::quantize_per_channel").check("aten::dequantize").check_next(
             "aten::conv2d"
-        ).check_next(
-            "aten::quantize_per_tensor"
-        ).check_next(
-            "aten::dequantize"
-        ).run(
+        ).check_next("aten::quantize_per_tensor").check_next("aten::dequantize").run(
             freezed.graph
         )
 
@@ -670,9 +666,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         }
         assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
         assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
-        assert next(iter(activation_dtypes)) != next(
-            iter(weight_dtypes)
-        ), "Expected activation dtype to "
+        assert next(iter(activation_dtypes)) != next(iter(weight_dtypes)), (
+            "Expected activation dtype to "
+        )
         " be different from wegiht dtype"
 
     def test_insert_observers_for_reused_weight(self):
@@ -706,9 +702,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         conv2_observers = attrs_with_prefix(m.conv2, "_observer_")
         assert len(conv1_observers) == 1, "Expected to have 1 observer submodules"
         assert len(conv2_observers) == 1, "Expected to have 1 observer submodules"
-        assert (
-            conv1_observers == conv2_observers
-        ), "Expect conv1 and conv2 to have same observers since the class type is shared"
+        assert conv1_observers == conv2_observers, (
+            "Expect conv1 and conv2 to have same observers since the class type is shared"
+        )
 
     def test_insert_observers_for_general_ops(self):
         """Make sure we skip observers for ops that doesn't require
@@ -734,13 +730,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             'prim::GetAttr[name="conv"]'
         ).check("prim::CallMethod").check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            "aten::flatten"
-        ).check_not(
+        ).check("aten::flatten").check_not(
             'Observer = prim::GetAttr[name="_observer_'
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     # TODO: this is too long, split this to test_insert_observers.py and remove
     # insrt_observers prefix
@@ -770,17 +762,11 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             'prim::GetAttr[name="conv1"]'
         ).check("prim::CallMethod").check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            "aten::flatten"
-        ).check_not(
+        ).check("aten::flatten").check_not(
            'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            'prim::GetAttr[name="conv2"]'
-        ).check(
+        ).check('prim::GetAttr[name="conv2"]').check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     def test_insert_observers_propagate_observed_in_submodule(self):
         """Make sure we propagate observed property through general ops"""
@@ -809,17 +795,11 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             'prim::GetAttr[name="conv1"]'
         ).check("prim::CallMethod").check(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            "prim::CallMethod"
-        ).check_not(
+        ).check("prim::CallMethod").check_not(
             'Observer = prim::GetAttr[name="_observer_'
-        ).check(
-            'prim::GetAttr[name="conv2"]'
-        ).check(
+        ).check('prim::GetAttr[name="conv2"]').check(
            'Observer = prim::GetAttr[name="_observer_'
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     def test_insert_observers_propagate_observed_for_function(self):
         def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
@@ -1055,9 +1035,9 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         m(data)
         m = convert_jit(m, debug=True)
 
-        assert (
-            len(m._modules._c.items()) == 1
-        ), "Expected to have single submodule of conv"
+        assert len(m._modules._c.items()) == 1, (
+            "Expected to have single submodule of conv"
+        )
         # make sure the quantized model is executable
         m(data)
         quant_func = (
@@ -1088,17 +1068,17 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         qconfig_dict = {"": qconfig}
         m = prepare_jit(m, qconfig_dict)
         # observers for input, output and value between conv1/conv2
-        assert (
-            len(attrs_with_prefix(m, "_observer_")) == 3
-        ), "Expected to have 3 obervers"
+        assert len(attrs_with_prefix(m, "_observer_")) == 3, (
+            "Expected to have 3 obervers"
+        )
         # observer for weight
-        assert (
-            len(attrs_with_prefix(m.conv1, "_observer_")) == 1
-        ), "Expected to have 1 obervers"
+        assert len(attrs_with_prefix(m.conv1, "_observer_")) == 1, (
+            "Expected to have 1 obervers"
+        )
         # observer for weight
-        assert (
-            len(attrs_with_prefix(m.conv2, "_observer_")) == 1
-        ), "Expected to have 1 obervers"
+        assert len(attrs_with_prefix(m.conv2, "_observer_")) == 1, (
+            "Expected to have 1 obervers"
+        )
 
         data = torch.randn(1, 3, 10, 10, dtype=torch.float)
         m(data)
@@ -1107,15 +1087,15 @@ class TestQuantizeJitPasses(QuantizationTestCase):
         assert m.conv1._c._type() == m.conv2._c._type()
 
         # check all observers have been removed
-        assert (
-            len(attrs_with_prefix(m, "_observer_")) == 0
-        ), "Expected to have 0 obervers"
-        assert (
-            len(attrs_with_prefix(m.conv1, "_observer_")) == 0
-        ), "Expected to have 0 obervers"
-        assert (
-            len(attrs_with_prefix(m.conv2, "_observer_")) == 0
-        ), "Expected to have 0 obervers"
+        assert len(attrs_with_prefix(m, "_observer_")) == 0, (
+            "Expected to have 0 obervers"
+        )
+        assert len(attrs_with_prefix(m.conv1, "_observer_")) == 0, (
+            "Expected to have 0 obervers"
+        )
+        assert len(attrs_with_prefix(m.conv2, "_observer_")) == 0, (
+            "Expected to have 0 obervers"
+        )
 
         quant_func = (
             "aten::quantize_per_channel"
@@ -1334,11 +1314,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             "aten::avg_pool2d"
         ).check("aten::q_scale").check_next("aten::q_zero_point").check_next(
             "prim::dtype"
-        ).check_next(
-            "aten::quantize_per_tensor"
-        ).check(
-            "aten::dequantize"
-        ).run(
+        ).check_next("aten::quantize_per_tensor").check("aten::dequantize").run(
             model.graph
         )
 
@@ -1757,9 +1733,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu"
         ).check_not(f"quantized::conv{dim}d(").check_not(
             "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_add_alpha(self):
@@ -1910,9 +1884,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not("quantized::add(").check_not(
             "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_add(self):
@@ -2119,9 +2091,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not("quantized::add(").check_not(
             "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_add_scalar_relu(self):
@@ -2205,11 +2175,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not(
             "quantized::add_scalar("
-        ).check_not(
-            "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).check_not("quantized::relu(").run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_cat(self):
@@ -2544,9 +2510,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not("quantized::mul(").check_not(
             "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     @skipIfNoFBGEMM
     def test_quantized_mul_scalar_relu(self):
@@ -2629,11 +2593,7 @@ class TestQuantizeJitOps(QuantizationTestCase):
             "aten::relu("
         ).check_not("aten::relu_(").check_not(
             "quantized::mul_scalar("
-        ).check_not(
-            "quantized::relu("
-        ).run(
-            m.graph
-        )
+        ).check_not("quantized::relu(").run(m.graph)
 
     @override_qengines
     def test_hardswish(self):
@@ -3103,9 +3063,7 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
             'Observer = prim::GetAttr[name="_observer_'
         ).check("prim::CallMethod").check_not(
             'Observer = prim::GetAttr[name="_observer_'
-        ).run(
-            m.graph
-        )
+        ).run(m.graph)
 
     def test_insert_quant_dequant_linear_dynamic(self):
         class M(torch.nn.Module):
@@ -3126,9 +3084,9 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
             else default_dynamic_qconfig
         )
         m = quantize_dynamic_jit(m, {"": qconfig}, debug=True)
-        assert (
-            len(m._modules._c.items()) == 2
-        ), "Expected to have two submodule of linear"
+        assert len(m._modules._c.items()) == 2, (
+            "Expected to have two submodule of linear"
+        )
 
         wt_quant_func = (
             "aten::quantize_per_channel"
@@ -3141,21 +3099,11 @@ class TestQuantizeDynamicJitPasses(QuantizationTestCase):
             act_quant_func
         ).check_next("aten::dequantize").check(
             "aten::_choose_qparams_per_tensor"
-        ).check_next(
-            act_quant_func
-        ).check_next(
-            "aten::dequantize"
-        ).check(
+        ).check_next(act_quant_func).check_next("aten::dequantize").check(
             wt_quant_func
-        ).check_next(
-            "aten::dequantize"
-        ).check_not(
+        ).check_next("aten::dequantize").check_not(
             wt_quant_func
-        ).check(
-            "return"
-        ).run(
-            m.graph
-        )
+        ).check("return").run(m.graph)
 
     @override_qengines
     def test_dynamic_multi_op(self):

View File

@@ -254,16 +254,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                         maxpool_node = node
                input_act = maxpool_node.args[0]
                assert isinstance(input_act, Node)
-                maxpool_node.meta[
-                    "quantization_annotation"
-                ] = QuantizationAnnotation(
-                    input_qspec_map={
-                        input_act: act_qspec,
-                    },
-                    output_qspec=SharedQuantizationSpec(
-                        (input_act, maxpool_node)
-                    ),
-                    _annotated=True,
-                )
+                maxpool_node.meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        input_qspec_map={
+                            input_act: act_qspec,
+                        },
+                        output_qspec=SharedQuantizationSpec(
+                            (input_act, maxpool_node)
+                        ),
+                        _annotated=True,
+                    )
+                )
 
            def validate(self, model: torch.fx.GraphModule) -> None:
@@ -339,9 +339,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                def derive_qparams_fn(
                    obs_or_fqs: list[ObserverOrFakeQuantize],
                ) -> tuple[Tensor, Tensor]:
-                    assert (
-                        len(obs_or_fqs) == 2
-                    ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
+                    assert len(obs_or_fqs) == 2, (
+                        f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}"
+                    )
                    act_obs_or_fq = obs_or_fqs[0]
                    weight_obs_or_fq = obs_or_fqs[1]
                    act_scale, act_zp = act_obs_or_fq.calculate_qparams()
@@ -442,9 +442,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                def derive_qparams_fn(
                    obs_or_fqs: list[ObserverOrFakeQuantize],
                ) -> tuple[Tensor, Tensor]:
-                    assert (
-                        len(obs_or_fqs) == 1
-                    ), f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
+                    assert len(obs_or_fqs) == 1, (
+                        f"Expecting one weight obs/fq, got: {len(obs_or_fqs)}"
+                    )
                    weight_obs_or_fq = obs_or_fqs[0]
                    (
                        weight_scale,
@@ -748,16 +748,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                        (first_input_node, cat_node)
                    )
                    for input_node in input_nodes[1:]:
-                        input_qspec_map[
-                            input_node
-                        ] = share_qparams_with_input_act0_qspec
+                        input_qspec_map[input_node] = (
+                            share_qparams_with_input_act0_qspec
+                        )
 
-                    cat_node.meta[
-                        "quantization_annotation"
-                    ] = QuantizationAnnotation(
-                        input_qspec_map=input_qspec_map,
-                        output_qspec=share_qparams_with_input_act0_qspec,
-                        _annotated=True,
-                    )
+                    cat_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map=input_qspec_map,
+                            output_qspec=share_qparams_with_input_act0_qspec,
+                            _annotated=True,
+                        )
+                    )
 
            def validate(self, model: torch.fx.GraphModule) -> None:
@@ -783,9 +783,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
        obs_ins0 = getattr(m, input0.target)
        obs_ins1 = getattr(m, input1.target)
        assert obs_ins0 == obs_ins1
-        assert (
-            len(conv_output_obs) == 2
-        ), "expecting two observer that follows conv2d ops"
+        assert len(conv_output_obs) == 2, (
+            "expecting two observer that follows conv2d ops"
+        )
        # checking that the output observers for the two convs are shared as well
        assert conv_output_obs[0] == conv_output_obs[1]
 
@@ -850,9 +850,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
 
        obs_ins2 = getattr(m, output_obs.target)
        assert obs_ins0 == obs_ins2, "input observer does not match output"
-        assert (
-            len(conv_output_obs) == 2
-        ), "expecting two observer that follows conv2d ops"
+        assert len(conv_output_obs) == 2, (
+            "expecting two observer that follows conv2d ops"
+        )
        # checking that the output observers for the two convs are shared as well
        assert conv_output_obs[0] == conv_output_obs[1]
@@ -967,16 +967,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                        (first_input_node, cat_node)
                    )
                    for input_node in input_nodes[1:]:
-                        input_qspec_map[
-                            input_node
-                        ] = share_qparams_with_input_act0_qspec
+                        input_qspec_map[input_node] = (
+                            share_qparams_with_input_act0_qspec
+                        )
 
-                    cat_node.meta[
-                        "quantization_annotation"
-                    ] = QuantizationAnnotation(
-                        input_qspec_map=input_qspec_map,
-                        output_qspec=share_qparams_with_input_act0_qspec,
-                        _annotated=True,
-                    )
+                    cat_node.meta["quantization_annotation"] = (
+                        QuantizationAnnotation(
+                            input_qspec_map=input_qspec_map,
+                            output_qspec=share_qparams_with_input_act0_qspec,
+                            _annotated=True,
+                        )
+                    )
 
            def validate(self, model: torch.fx.GraphModule) -> None:
@@ -1063,16 +1063,16 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
                    (second_input_node, cat_node)
                )
-                input_qspec_map[
-                    first_input_node
-                ] = share_qparams_with_input_act1_qspec
+                input_qspec_map[first_input_node] = (
+                    share_qparams_with_input_act1_qspec
+                )
 
-                cat_node.meta[
-                    "quantization_annotation"
-                ] = QuantizationAnnotation(
-                    input_qspec_map=input_qspec_map,
-                    output_qspec=share_qparams_with_input_act1_qspec,
-                    _annotated=True,
-                )
+                cat_node.meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        input_qspec_map=input_qspec_map,
+                        output_qspec=share_qparams_with_input_act1_qspec,
+                        _annotated=True,
+                    )
+                )
 
            def validate(self, model: torch.fx.GraphModule) -> None:
@@ -1121,17 +1121,17 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                share_qparams_with_input_act1_qspec = SharedQuantizationSpec(
                    (second_input_node, add_node)
                )
-                input_qspec_map[
-                    first_input_node
-                ] = share_qparams_with_input_act1_qspec
+                input_qspec_map[first_input_node] = (
+                    share_qparams_with_input_act1_qspec
+                )
 
-                add_node.meta[
-                    "quantization_annotation"
-                ] = QuantizationAnnotation(
-                    input_qspec_map=input_qspec_map,
-                    output_qspec=share_qparams_with_input_act1_qspec,
-                    allow_implicit_sharing=False,
-                    _annotated=True,
-                )
+                add_node.meta["quantization_annotation"] = (
+                    QuantizationAnnotation(
+                        input_qspec_map=input_qspec_map,
+                        output_qspec=share_qparams_with_input_act1_qspec,
+                        allow_implicit_sharing=False,
+                        _annotated=True,
+                    )
+                )
 
            def validate(self, model: torch.fx.GraphModule) -> None:

View File

@@ -277,9 +277,9 @@ class PT2EQATTestCase(QuantizationTestCase):
 
         # Verify: conv literal args
         if expected_conv_literal_args is not None:
-            assert (
-                len(expected_conv_literal_args) == 6
-            ), "wrong num conv args, bad test setup"
+            assert len(expected_conv_literal_args) == 6, (
+                "wrong num conv args, bad test setup"
+            )
             for i in range(6):
                 if i + 3 < len(conv_node.args):
                     self.assertEqual(

View File

@@ -157,9 +157,9 @@ async def run1(coroutine_id):
         gpuid = coroutine_id % GPUS
     else:
         gpu_assignments = args.gpus.split(":")
-        assert args.nproc == len(
-            gpu_assignments
-        ), "Please specify GPU assignment for each process, separated by :"
+        assert args.nproc == len(gpu_assignments), (
+            "Please specify GPU assignment for each process, separated by :"
+        )
         gpuid = gpu_assignments[coroutine_id]
 
     while progress < len(ALL_TESTS):

View File

@@ -1,8 +1,7 @@
 # Owner(s): ["module: dynamo"]
-""" Test functions for limits module.
-"""
+"""Test functions for limits module."""
 
 import functools
 import warnings
 from unittest import expectedFailure as xfail, skipIf

View File

@@ -4104,6 +4104,7 @@ class TestIO(TestCase):
     def test_decimal_period_separator():
         pass
 
+
     def test_decimal_comma_separator():
         with CommaDecimalPointLocale():
             pass
@@ -6786,7 +6787,10 @@ class TestWritebackIfCopy(TestCase):
 class TestArange(TestCase):
     def test_infinite(self):
         assert_raises(
-            (RuntimeError, ValueError), np.arange, 0, np.inf  # "unsupported range",
+            (RuntimeError, ValueError),
+            np.arange,
+            0,
+            np.inf,  # "unsupported range",
         )
 
     def test_nan_step(self):

View File

@@ -2733,10 +2733,18 @@ class TestMoveaxis(TestCase):
         assert_raises(np.AxisError, np.moveaxis, x, 3, 0)  # 'source.*out of bounds',
         assert_raises(np.AxisError, np.moveaxis, x, -4, 0)  # 'source.*out of bounds',
         assert_raises(
-            np.AxisError, np.moveaxis, x, 0, 5  # 'destination.*out of bounds',
+            np.AxisError,
+            np.moveaxis,
+            x,
+            0,
+            5,  # 'destination.*out of bounds',
         )
         assert_raises(
-            ValueError, np.moveaxis, x, [0, 0], [0, 1]  # 'repeated axis in `source`',
+            ValueError,
+            np.moveaxis,
+            x,
+            [0, 0],
+            [0, 1],  # 'repeated axis in `source`',
         )
         assert_raises(
             ValueError,  # 'repeated axis in `destination`',

View File

@@ -3,6 +3,7 @@
 """
 Test the scalar constructors, which also do type-coercion
 """
+
 import functools
 from unittest import skipIf as skipif

View File

@@ -3,6 +3,7 @@
 """
 Test the scalar constructors, which also do type-coercion
 """
+
 import fractions
 import functools
 import types

View File

@@ -1,8 +1,7 @@
 # Owner(s): ["module: dynamo"]
-""" Test printing of scalar types.
-"""
+"""Test printing of scalar types."""
 
 import functools
 from unittest import skipIf as skipif

View File

@@ -811,7 +811,10 @@ class TestBlock(TestCase):
         assert_raises_regex(ValueError, msg, block, [[1], 2])
         assert_raises_regex(ValueError, msg, block, [[], 2])
         assert_raises_regex(
-            ValueError, msg, block, [[[1], [2]], [[3, 4]], [5]]  # missing brackets
+            ValueError,
+            msg,
+            block,
+            [[[1], [2]], [[3, 4]], [5]],  # missing brackets
         )
 
     def test_empty_lists(self, block):

View File

@@ -5,6 +5,7 @@
 Copied from fftpack.helper by Pearu Peterson, October 2005
 """
+
 from torch.testing._internal.common_utils import (
     run_tests,
     TEST_WITH_TORCHDYNAMO,

View File

@@ -372,7 +372,7 @@ class TestFFTThreadSafe(TestCase):
         assert_allclose(
             q.get(timeout=5),
             expected,
-            atol=2e-14
+            atol=2e-14,
             # msg="Function returned wrong value in multithreaded context",
         )

View File

@@ -1,8 +1,7 @@
 # Owner(s): ["module: dynamo"]
-"""Test functions for 1D array set operations.
-"""
+"""Test functions for 1D array set operations."""
 
 from unittest import expectedFailure as xfail, skipIf
 
 import numpy

View File

@@ -2881,7 +2881,8 @@ class TestPercentile(TestCase):
         np.testing.assert_equal(res.dtype, arr.dtype)
 
     H_F_TYPE_CODES = [
-        (int_type, np.float64) for int_type in "Bbhil"  # np.typecodes["AllInteger"]
+        (int_type, np.float64)
+        for int_type in "Bbhil"  # np.typecodes["AllInteger"]
     ] + [
         (np.float16, np.float16),
         (np.float32, np.float32),

View File

@@ -505,8 +505,7 @@ class TestHistogramOptimBinNums(TestCase):
                 assert_equal(
                     len(a),
                     numbins,
-                    err_msg=f"For the {estimator} estimator "
-                    f"with datasize of {testlen}",
+                    err_msg=f"For the {estimator} estimator with datasize of {testlen}",
                 )
 
     def test_small(self):
@@ -552,8 +551,7 @@ class TestHistogramOptimBinNums(TestCase):
                 assert_equal(
                     len(a),
                     expbins,
-                    err_msg=f"For the {estimator} estimator "
-                    f"with datasize of {testlen}",
+                    err_msg=f"For the {estimator} estimator with datasize of {testlen}",
                 )
 
     def test_incorrect_methods(self):

View File

@@ -1,8 +1,7 @@
 # Owner(s): ["module: dynamo"]
-"""Test functions for matrix module
-"""
+"""Test functions for matrix module"""
 
 import functools
 from unittest import expectedFailure as xfail, skipIf as skipif

View File

@@ -1,7 +1,6 @@
 # Owner(s): ["module: dynamo"]
-""" Test functions for linalg module
-"""
+"""Test functions for linalg module"""
 
 import functools
 import itertools
 import os

View File

@@ -1,7 +1,7 @@
 # Owner(s): ["module: dynamo"]
-"""Light smoke test switching between numpy to pytorch random streams.
-"""
+"""Light smoke test switching between numpy to pytorch random streams."""
+
 from contextlib import contextmanager
 from functools import partial

View File

@@ -9,6 +9,7 @@ The goal is to validate on numpy, and tests should work when replacing
 by
 >>> import torch._numpy as np
 """
+
 import operator
 from unittest import skipIf as skip, SkipTest

View File

@@ -39,12 +39,9 @@ USE_BLACK_FILELIST = re.compile(
         # test/**
         # test/[a-h]*/**
         # test/[i-j]*/**
-        "test/j*/**",
         # test/[k-m]*/**
-        "test/[k-m]*/**",
         # test/optim/**
-        # "test/[p-z]*/**",
-        "test/[p-z]*/**",
+        # test/[p-z]*/**,
         # torch/**
         # torch/_[a-c]*/**
         # torch/_[e-h]*/**
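This last hunk is the switch itself: removing the `test/j*/**`, `test/[k-m]*/**`, and `test/[p-z]*/**` entries from `USE_BLACK_FILELIST` means those paths no longer match the black filelist, so the PYFMT linter formats them with ruff instead. A rough sketch of the routing this regex controls (assumed for illustration; the placeholder pattern and helper below are not the adapter's actual code):

    import re

    # Paths matching the filelist keep legacy black formatting; everything
    # else handled by PYFMT now goes through `ruff format`.
    USE_BLACK_FILELIST = re.compile("|".join(["test/dynamo/.*"]))  # placeholder pattern

    def formatter_for(path: str) -> str:
        return "black" if USE_BLACK_FILELIST.match(path) else "ruff format"

    assert formatter_for("test/jit/test_tracer.py") == "ruff format"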