Enable ruff rule E721 (#165162)
`E721` flags object type comparisons written with `==` and other comparison operators; the recommended idiom for comparing types is `is`, since class objects are compared by identity.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162
Approved by: https://github.com/Skylion007
parent 220a34118f
commit 9e7c19f72b
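As an aside (an illustration of the rule, not part of the commit): `==` dispatches to `__eq__`, which a metaclass can customize, and it reads like a value comparison; since class objects are singletons, `is` expresses an exact-type check directly, and `isinstance()` remains the right tool when subclasses should also match.

```python
class Base: ...
class Child(Base): ...

obj = Child()

# Flagged by E721: equality goes through __eq__ and can be customized
# by a metaclass, so it is not a reliable exact-type test.
assert type(obj) == Child  # noqa: E721

# Preferred: class objects are singletons, so identity is exact.
assert type(obj) is Child

# When subclasses should also be accepted, use isinstance().
assert isinstance(obj, Base)
```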
@@ -367,7 +367,7 @@ class DeepSpeech(nn.Module):
         """
         seq_len = input_length
         for m in self.conv.modules():
-            if type(m) == nn.modules.conv.Conv2d:
+            if type(m) is nn.modules.conv.Conv2d:
                 seq_len = (
                     seq_len
                     + 2 * m.padding[1]
@@ -66,7 +66,7 @@ class GroupedSetup:
 
     def __post_init__(self) -> None:
         for field in dataclasses.fields(self):
-            assert field.type == str
+            assert field.type is str
             value: str = getattr(self, field.name)
             object.__setattr__(self, field.name, textwrap.dedent(value))
 
@@ -113,7 +113,7 @@ class TorchBenchmarkBase(torch.nn.Module):
             value = kargs[key]
             test_name_str.append(
                 ("" if key in skip_key_list else key)
-                + str(value if type(value) != bool else int(value))
+                + str(value if type(value) is not bool else int(value))
             )
         name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "")
         return name
@@ -125,7 +125,7 @@ class CatBenchmark(op_bench.TorchBenchmarkBase):
         random.seed(42)
         inputs = []
         gen_sizes = []
-        if type(sizes) == list and N == -1:
+        if type(sizes) is list and N == -1:
             gen_sizes = sizes
         else:
             for i in range(N):
@@ -61,7 +61,7 @@ class StackBenchmark(op_bench.TorchBenchmarkBase):
         random.seed(42)
         inputs = []
         gen_sizes = []
-        if type(sizes) == list and N == -1:
+        if type(sizes) is list and N == -1:
             gen_sizes = sizes
         else:
             for i in range(N):
@@ -155,7 +155,6 @@ ignore = [
     "E402",
     "C408", # C408 ignored because we like the dict keyword argument syntax
     "E501", # E501 is not flexible enough, we're using B950 instead
-    "E721",
     "E741",
     "EXE001",
     "F405",
@@ -243,7 +243,7 @@ class TestActivationSparsifier(TestCase):
         if mask1 is None:
             assert mask2 is None
         else:
-            assert type(mask1) == type(mask2)
+            assert type(mask1) is type(mask2)
             if isinstance(mask1, list):
                 assert len(mask1) == len(mask2)
                 for idx in range(len(mask1)):
@@ -710,15 +710,15 @@ class TestQuantizationUtils(TestCase):
             **sparse_config,
         )
 
-        assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
+        assert type(model.emb1) is torch.ao.nn.quantized.modules.embedding_ops.Embedding
         assert (
             type(model.embbag1)
-            == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
+            is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
         )
-        assert type(model.emb_seq[0] == nn.Embedding)
-        assert type(model.emb_seq[1] == nn.EmbeddingBag)
-        assert type(model.linear1) == nn.Linear
-        assert type(model.linear2) == nn.Linear
+        assert type(model.emb_seq[0] is nn.Embedding)
+        assert type(model.emb_seq[1] is nn.EmbeddingBag)
+        assert type(model.linear1) is nn.Linear
+        assert type(model.linear2) is nn.Linear
 
         dequant_emb1 = torch.dequantize(model.emb1.weight())
         dequant_embbag1 = torch.dequantize(model.embbag1.weight())
@@ -749,19 +749,21 @@ class TestQuantizationUtils(TestCase):
             model, DataNormSparsifier, sparsify_first=False, **sparse_config
         )
 
-        assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
+        assert type(model.emb1) is torch.ao.nn.quantized.modules.embedding_ops.Embedding
         assert (
             type(model.embbag1)
-            == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
+            is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
         )
-        assert type(
-            model.emb_seq[0] == torch.ao.nn.quantized.modules.embedding_ops.Embedding
+        assert (
+            type(model.emb_seq[0])
+            is torch.ao.nn.quantized.modules.embedding_ops.Embedding
         )
-        assert type(
-            model.emb_seq[1] == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
+        assert (
+            type(model.emb_seq[1])
+            is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
        )
-        assert type(model.linear1) == nn.Linear  # not quantized
-        assert type(model.linear2) == nn.Linear  # not quantized
+        assert type(model.linear1) is nn.Linear  # not quantized
+        assert type(model.linear2) is nn.Linear  # not quantized
 
         dequant_emb1 = torch.dequantize(model.emb1.weight())
         dequant_embbag1 = torch.dequantize(model.embbag1.weight())
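A side note on the `emb_seq` assertions in the two hunks above (an observation about the diff, not a claim from the commit message): with the original parenthesization, `assert type(model.emb_seq[0] == nn.Embedding)` asserts the truthiness of a type object, so it can never fail; the second hunk fixes the grouping as well, while the first performs only the mechanical `==` to `is` rewrite. A minimal sketch of why the old form is vacuous:

```python
import torch.nn as nn

emb = nn.Embedding(10, 3)

# Vacuous: `emb == nn.Embedding` evaluates to False, type(False) is bool,
# and bool (like any class object) is truthy -- the assert always passes.
assert type(emb == nn.Embedding)

# The meaningful exact-type check:
assert type(emb) is nn.Embedding
```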
@@ -291,7 +291,7 @@ class TestWeightNormSparsifier(TestCase):
             assert hasattr(module.parametrizations["weight"][0], "mask")
             # Check parametrization exists and is correct
             assert is_parametrized(module, "weight")
-            assert type(module.parametrizations.weight[0]) == FakeSparsity
+            assert type(module.parametrizations.weight[0]) is FakeSparsity
 
     def test_mask_squash(self):
         model = SimpleLinear()
@@ -415,7 +415,7 @@ class TestNearlyDiagonalSparsifier(TestCase):
             assert hasattr(module.parametrizations["weight"][0], "mask")
             # Check parametrization exists and is correct
             assert is_parametrized(module, "weight")
-            assert type(module.parametrizations.weight[0]) == FakeSparsity
+            assert type(module.parametrizations.weight[0]) is FakeSparsity
 
     def test_mask_squash(self):
         model = SimpleLinear()
@@ -158,7 +158,7 @@ class TestBaseStructuredSparsifier(TestCase):
             assert parametrize.is_parametrized(module)
             assert hasattr(module, "parametrizations")
             # Assume that this is the 1st/only parametrization
-            assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity
+            assert type(module.parametrizations.weight[0]) is FakeStructuredSparsity
 
     def _check_pruner_valid_before_step(self, model, pruner, device):
         for config in pruner.groups:
@@ -116,7 +116,7 @@ class TestTensorType(TestCase):
 
         for dtype, str in dtypes_map.items():
             x = torch.empty(4, 4, dtype=dtype, device="openreg")
-            self.assertTrue(x.type() == str)
+            self.assertTrue(x.type() is str)
 
         # Note that all dtype-d Tensor objects here are only for legacy reasons
         # and should NOT be used.
@@ -134,7 +134,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
                 False,
                 f"Collection length mismatch at {path}: {len(gpu_obj)} vs {len(cpu_obj)}",
             )
-        if type(gpu_obj) != type(cpu_obj):
+        if type(gpu_obj) is not type(cpu_obj):
             return (
                 False,
                 f"Collection type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
@@ -149,7 +149,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
 
     # If objects are custom classes, compare their attributes
    elif hasattr(gpu_obj, "__dict__") and hasattr(cpu_obj, "__dict__"):
-        if type(gpu_obj) != type(cpu_obj):
+        if type(gpu_obj) is not type(cpu_obj):
             return (
                 False,
                 f"Object type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
@@ -165,7 +165,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
 
    # For other types, use direct equality comparison
    else:
-        if type(gpu_obj) != type(cpu_obj):
+        if type(gpu_obj) is not type(cpu_obj):
             return (
                 False,
                 f"Type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
@@ -44,14 +44,14 @@ class TestApply(FSDPTest):
 
     @torch.no_grad()
     def _init_linear_weights(self, m):
-        if type(m) == nn.Linear:
+        if type(m) is nn.Linear:
             m.weight.fill_(1.0)
             m.bias.fill_(1.0)
 
     def check_weights(self, fsdp, expected_tensor_fn, check):
         with FSDP.summon_full_params(fsdp, recurse=True):
             linear_modules = [
-                module for module in fsdp.modules() if type(module) == nn.Linear
+                module for module in fsdp.modules() if type(module) is nn.Linear
             ]
             for module in linear_modules:
                 for param in module.parameters():
@@ -1021,7 +1021,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
         )
         for warning in w:
             self.assertTrue(
-                warning.category != UserWarning
+                warning.category is not UserWarning
                 or not str(warning.message).startswith(warning_prefix)
             )
 
@@ -421,7 +421,7 @@ class TestFSDPOptimState(FSDPTest):
             return False
         for state_name, value1 in state1.items():
             value2 = state2[state_name]
-            if type(value1) != type(value2):
+            if type(value1) is not type(value2):
                 return False
             if torch.is_tensor(value1):  # tensor state
                 assert torch.is_tensor(value2)
@@ -5887,7 +5887,7 @@ class TestKL(DistributionsTestCase):
 
     def test_kl_exponential_family(self):
         for (p, _), (_, q) in self.finite_examples:
-            if type(p) == type(q) and issubclass(type(p), ExponentialFamily):
+            if type(p) is type(q) and issubclass(type(p), ExponentialFamily):
                 actual = kl_divergence(p, q)
                 expected = _kl_expfamily_expfamily(p, q)
                 self.assertEqual(
@@ -3370,9 +3370,9 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
         # Test on non autocast state and autocast cache states.
        self.assertIn("autocast_state", json_guards)
        for key, value in json_guards.items():
-            if type(value) == int:
+            if type(value) is int:
                variant = value + 1
-            elif type(value) == bool:
+            elif type(value) is bool:
                variant = not value
            elif isinstance(value, dict) and key == "autocast_state":
                variant = value.copy()
@@ -59,7 +59,7 @@ class SourceTests(torch._dynamo.test_case.TestCase):
             def forward(self):
                 if (
                     torch.utils._pytree.SUPPORTED_NODES[CausalLMOutputWithPast].type
-                    == int
+                    is int
                 ):
                     x = torch.sin(self.x)
                 else:
@@ -662,7 +662,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
         "comparison",
         [
             subtest(isinstance, "isinstance"),
-            subtest(lambda instance, type_: type(instance) == type_, "equality"),
+            subtest(lambda instance, type_: type(instance) is type_, "equality"),
             subtest(lambda instance, type_: type(instance) is type_, "identity"),
         ],
     )
@@ -38,7 +38,7 @@ class PreDispatchSchemaCheckMode(SchemaCheckMode):
 
     def _may_alias_or_mutate(self, func, types, args, kwargs):
         def unwrap(e):
-            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
+            if isinstance(e, torch.Tensor) and type(e) is not torch.Tensor:
                 try:
                     return e.elem
                 except AttributeError:
@@ -128,7 +128,7 @@ def run_with_nativert(ep):
     flat_results = pytree.tree_leaves(results)
     assert len(flat_results) == len(flat_expected)
     for result, expected in zip(flat_results, flat_expected):
-        assert type(result) == type(expected)
+        assert type(result) is type(expected)
         if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
             assert result.shape == expected.shape
             assert result.dtype == expected.dtype
@@ -323,7 +323,7 @@ class TestNativeRT(TestCase):
         flat_results = pytree.tree_leaves(results)
         assert len(flat_results) == len(flat_expected)
         for result, expected in zip(flat_results, flat_expected):
-            assert type(result) == type(expected)
+            assert type(result) is type(expected)
             if isinstance(result, torch.Tensor) and isinstance(
                 expected, torch.Tensor
             ):
@@ -82,7 +82,7 @@ class TestSerialize(TestCase):
                 return 0
 
             def __eq__(self, other):
-                return type(other) == type(self)
+                return type(other) is type(self)
 
             def __call__(self, *args, **kwargs):
                 return torch.ops.aten.add.Tensor(*args, **kwargs)
@@ -6332,7 +6332,7 @@ def forward(self, tangents_1, tangents_2):
         self.assertEqual(out_ref[0].b, out_test[0].b)
         self.assertEqual(out_ref[1], out_test[1])
 
-        # We compiled our graph assuming type(grad_out[1]) == torch.Tensor,
+        # We compiled our graph assuming type(grad_out[1]) is torch.Tensor,
         # but we were wrong: in the below tests, it is a subclass.
         # This will eventually require a repartition + recompile
         with self.assertRaisesRegex(
@@ -3671,7 +3671,7 @@ class AssociativeScanModels:
                 # Check if val is a list and if it has the same length as combine_fn
                 # If so, then use the individual elements.
                 # If not, duplicate the first element.
-                if type(val) == list and len(val) == chain_len:
+                if type(val) is list and len(val) == chain_len:
                     kwargs_el[key] = val[ind]
                 else:
                     kwargs_el[key] = val
@@ -296,7 +296,7 @@ class TestSplitOutputType(TestCase):
         gm_output = module(inputs)
         split_gm_output = split_gm(inputs)
 
-        self.assertTrue(type(gm_output) == type(split_gm_output))
+        self.assertTrue(type(gm_output) is type(split_gm_output))
         self.assertTrue(torch.equal(gm_output, split_gm_output))
 
 
@@ -514,8 +514,8 @@ class TestSubgraphRewriter(JitTestCase):
         symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
         for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
             if n.op == "placeholder":
-                assert n.type == int
-                assert m.type == int
+                assert n.type is int
+                assert m.type is int
 
     def test_subgraph_rewriter_replace_consecutive_submodules(self):
         def f(x):
@@ -81,9 +81,9 @@ class BinaryFoldingTemplate(TestCase):
                 out_optimized = torch.compile(mod_eager)
 
                 inps = [4, 3, 4]
-                if module == nn.Conv2d:
+                if module is nn.Conv2d:
                     inps.append(inps[-1])
-                if module == nn.Conv3d:
+                if module is nn.Conv3d:
                     inps.append(inps[-1])
                     inps.append(inps[-1])
 
@@ -195,9 +195,9 @@ class BinaryFoldingTemplate(TestCase):
                 )
 
                 inps = [4, 3, 4]
-                if module[0] == nn.Conv2d:
+                if module[0] is nn.Conv2d:
                     inps.append(inps[-1])
-                if module[0] == nn.Conv3d:
+                if module[0] is nn.Conv3d:
                     inps.append(inps[-1])
                     inps.append(inps[-1])
 
@@ -106,9 +106,9 @@ class TestMixin:
         return keys
 
     def key(self: Self, key_type: type[icache.Key]) -> icache.Key:
-        if key_type == str:
+        if key_type is str:
             return f"s{randint(0, 2**32)}"
-        elif key_type == int:
+        elif key_type is int:
             return randint(0, 2**32)
         elif key_type == tuple[Any, ...]:
             return (self.key(str), self.key(int))
@@ -125,13 +125,13 @@ class TestMixin:
         return values
 
    def value(self: Self, value_type: type[icache.Value]) -> icache.Value:
-        if value_type == str:
+        if value_type is str:
             return f"s{randint(0, 2**32)}"
-        elif value_type == int:
+        elif value_type is int:
             return randint(0, 2**32)
         elif value_type == tuple[Any, ...]:
             return (self.value(str), self.value(int))
-        elif value_type == bytes:
+        elif value_type is bytes:
             return self.value(str).encode()
         elif value_type == dict[Any, Any]:
             return {
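Worth noting in the hunk above (an observation, not from the commit message): the `== tuple[Any, ...]` and `== dict[Any, Any]` branches were left as equality checks. Parametrized generics are built as fresh `types.GenericAlias` objects on each subscription, so identity comparison would not work for them; a quick sketch:

```python
# Each subscription of a builtin generic creates a new alias object,
# so `is` would always be False here -- equality is the right check:
assert tuple[int, str] == tuple[int, str]      # compared structurally: True
assert tuple[int, str] is not tuple[int, str]  # two distinct GenericAlias objects
```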
@@ -88,7 +88,7 @@ def _check_if_instances_equal(op1, op2) -> bool:
     if isinstance(op1, (list | tuple)):
         return tuple(op1) == tuple(op2)
 
-    if type(op1) != type(op2):
+    if type(op1) is not type(op2):
         return False
 
     # some classes have __eq__ defined but they may be insufficient
@@ -127,11 +127,11 @@ class EfficientConvBNEvalTemplate(TestCase):
                 spatial_d = (
                     4 if issubclass(module[0], nn.modules.conv._ConvTransposeNd) else 96
                 )
-                if module[0] == nn.Conv1d or module[0] == nn.ConvTranspose1d:
+                if module[0] is nn.Conv1d or module[0] is nn.ConvTranspose1d:
                     inps += [spatial_d] * 1
-                if module[0] == nn.Conv2d or module[0] == nn.ConvTranspose2d:
+                if module[0] is nn.Conv2d or module[0] is nn.ConvTranspose2d:
                     inps += [spatial_d] * 2
-                if module[0] == nn.Conv3d or module[0] == nn.ConvTranspose3d:
+                if module[0] is nn.Conv3d or module[0] is nn.ConvTranspose3d:
                     inps += [spatial_d] * 3
                 inp = torch.rand(inps).to(self.device)
 
@@ -514,11 +514,11 @@ def check_model(
     # print("Graph", graph)
     if check_has_compiled:
         assert called, "Ran graph without calling compile_fx"
-    assert type(actual) == type(correct)
+    assert type(actual) is type(correct)
     if isinstance(actual, (tuple, list)):
         assert len(actual) == len(correct)
         assert all(
-            type(actual_item) == type(correct_item)
+            type(actual_item) is type(correct_item)
             for actual_item, correct_item in zip(actual, correct)
         )
 
@@ -198,7 +198,7 @@ class TestUtils(TestCase):
     @dtypes(torch.float16, torch.bfloat16, torch.float32)
     def test_get_device_tflops(self, dtype):
         ret = get_device_tflops(dtype)
-        self.assertTrue(type(ret) == float)
+        self.assertTrue(type(ret) is float)
 
 
 instantiate_device_type_tests(TestUtils, globals())
@@ -2083,9 +2083,9 @@ class TestFrozenOptimizations(JitTestCase):
 
             mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
             inps = [4, 3, 4]
-            if modules[0] == nn.Conv2d:
+            if modules[0] is nn.Conv2d:
                 inps.append(inps[-1])
-            if modules[0] == nn.Conv3d:
+            if modules[0] is nn.Conv3d:
                 inps.append(inps[-1])
                 inps.append(inps[-1])
 
@@ -2224,9 +2224,9 @@ class TestFrozenOptimizations(JitTestCase):
             mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval()
 
             inps = [4, 3, 4]
-            if module == nn.Conv2d:
+            if module is nn.Conv2d:
                 inps.append(inps[-1])
-            if module == nn.Conv3d:
+            if module is nn.Conv3d:
                 inps.append(inps[-1])
                 inps.append(inps[-1])
 
@@ -2366,10 +2366,10 @@ class TestFrozenOptimizations(JitTestCase):
             mod_eager = LinearBN(32, 32).eval()
 
             inps = [3, 32]
-            if modules[1] == nn.BatchNorm2d:
+            if modules[1] is nn.BatchNorm2d:
                 inps.append(inps[-1])
                 inps.append(inps[-1])
-            if modules[1] == nn.BatchNorm3d:
+            if modules[1] is nn.BatchNorm3d:
                 inps.append(inps[-1])
                 inps.append(inps[-1])
                 inps.append(inps[-1])
@@ -2429,14 +2429,14 @@ class TestFrozenOptimizations(JitTestCase):
 
             N, C = 3, bn_in
             input_shape = [N, C]
-            if modules[1] == nn.BatchNorm1d:
+            if modules[1] is nn.BatchNorm1d:
                 H = linear_in
                 input_shape.append(H)
-            elif modules[1] == nn.BatchNorm2d:
+            elif modules[1] is nn.BatchNorm2d:
                 H, W = 4, linear_in
                 input_shape.append(H)
                 input_shape.append(W)
-            elif modules[1] == nn.BatchNorm3d:
+            elif modules[1] is nn.BatchNorm3d:
                 D, H, W = 4, 4, linear_in
                 input_shape.append(D)
                 input_shape.append(H)
@@ -2504,10 +2504,10 @@ class TestFrozenOptimizations(JitTestCase):
             mod_eager = LinearBN(32, 32).cuda().eval()
 
             inps = [3, 32]
-            if modules[1] == nn.BatchNorm2d:
+            if modules[1] is nn.BatchNorm2d:
                 inps.append(inps[-1])
                 inps.append(inps[-1])
-            if modules[1] == nn.BatchNorm3d:
+            if modules[1] is nn.BatchNorm3d:
                 inps.append(inps[-1])
                 inps.append(inps[-1])
                 inps.append(inps[-1])
@@ -2757,9 +2757,9 @@ class TestFrozenOptimizations(JitTestCase):
         for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]):
             mod = module(3, 32, kernel_size=3, stride=2).eval()
             inps = [4, 3, 4]
-            if module == nn.Conv2d:
+            if module is nn.Conv2d:
                 inps.append(inps[-1])
-            if module == nn.Conv3d:
+            if module is nn.Conv3d:
                 inps.append(inps[-1])
                 inps.append(inps[-1])
 
@@ -2997,7 +2997,7 @@ class TestFrozenOptimizations(JitTestCase):
         mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()
 
         inps = [5, 3, 4, 4]
-        if conv == nn.Conv3d:
+        if conv is nn.Conv3d:
             inps.append(inps[-1])
         inp = torch.rand(inps).cuda()
 
@@ -210,7 +210,7 @@ class TestTyping(JitTestCase):
         li_1, li_2, li_3 = stuff4([True])
         li_3 = li_3[0]
         for li in [li_1, li_2, li_3]:
-            self.assertTrue(type(li[0]) == bool)
+            self.assertTrue(type(li[0]) is bool)
 
     def test_nested_list(self):
         def foo(z):
@@ -3839,9 +3839,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
                 # This is because we have N111 weight that cannot handle
                 # the ambiguous memory_format
                 if w_f == torch.channels_last:
-                    if layer == nn.Conv2d and filter_size * c != 1:
+                    if layer is nn.Conv2d and filter_size * c != 1:
                         output_format = torch.channels_last
-                    if layer == nn.ConvTranspose2d and filter_size * k != 1:
+                    if layer is nn.ConvTranspose2d and filter_size * k != 1:
                         output_format = torch.channels_last
                 self._run_conv(
                     layer,
@@ -474,8 +474,8 @@ def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
                 f"Expected isinstance(src, {cls}) but got {type(src)}"
             )
         assert (
-            type(dest) == torch.Tensor
-            or type(dest) == torch.nn.Parameter
+            type(dest) is torch.Tensor
+            or type(dest) is torch.nn.Parameter
             or issubclass(cls, type(dest))
         )
         if assign:
@@ -3053,7 +3053,7 @@ class TestQuantizedOps(TestCase):
         lstm_quantized = torch.ao.quantization.convert(
             lstm_prepared, convert_custom_config_dict=custom_config_dict
         )
-        assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
+        assert type(lstm_quantized[0]) is torch.ao.nn.quantized.LSTM
         qy = lstm_quantized(qx)
 
         snr = _snr(y, qy)
@@ -138,7 +138,7 @@ class TestObserver(QuantizationTestCase):
         # Calculate Qparams should return with a warning for observers with no data
         qparams = myobs.calculate_qparams()
         input_scale = 2**16 if qdtype is torch.qint32 else 1
-        if type(myobs) == MinMaxObserver:
+        if type(myobs) is MinMaxObserver:
             x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) * input_scale
             y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) * input_scale
         else:
@@ -201,7 +201,7 @@ class TestObserver(QuantizationTestCase):
                 [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
             ]
         )
-        if type(myobs) == MovingAveragePerChannelMinMaxObserver:
+        if type(myobs) is MovingAveragePerChannelMinMaxObserver:
             # Scaling the input tensor to model change in min/max values
             # across batches
             result = myobs(0.5 * x)
@@ -599,7 +599,7 @@ class TestFakeQuantizeOps(TestCase):
         # Output of fake quant is not identical to input
         Y = fq_module(X)
         self.assertNotEqual(Y, X)
-        if type(fq_module) == _LearnableFakeQuantize:
+        if type(fq_module) is _LearnableFakeQuantize:
             fq_module.toggle_fake_quant(False)
         else:
             torch.ao.quantization.disable_fake_quant(fq_module)
@@ -613,7 +613,7 @@ class TestFakeQuantizeOps(TestCase):
         scale = fq_module.scale.detach().clone()
         zero_point = fq_module.zero_point.detach().clone()
 
-        if type(fq_module) == _LearnableFakeQuantize:
+        if type(fq_module) is _LearnableFakeQuantize:
             fq_module.toggle_observer_update(False)
             fq_module.toggle_fake_quant(True)
         else:
@@ -625,7 +625,7 @@ class TestFakeQuantizeOps(TestCase):
         # Observer is disabled, scale and zero-point do not change
         self.assertEqual(fq_module.scale, scale)
         self.assertEqual(fq_module.zero_point, zero_point)
-        if type(fq_module) == _LearnableFakeQuantize:
+        if type(fq_module) is _LearnableFakeQuantize:
             fq_module.toggle_observer_update(True)
         else:
             torch.ao.quantization.enable_observer(fq_module)
@@ -241,7 +241,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
         Args: `mod` a float module, either produced by torch.ao.quantization utilities
         or directly from user
         """
-        assert type(mod) == cls._FLOAT_MODULE, (
+        assert type(mod) is cls._FLOAT_MODULE, (
             "qat."
             + cls.__name__
             + ".from_float only works for "
@@ -1264,8 +1264,8 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
         mp = prepare_qat(m)
         mp(data)
         mq = convert(mp)
-        self.assertTrue(type(mq[1]) == nnq.Linear)
-        self.assertTrue(type(mq[2]) == nn.Identity)
+        self.assertTrue(type(mq[1]) is nnq.Linear)
+        self.assertTrue(type(mq[2]) is nn.Identity)
 
     @skipIfNoXNNPACK
     @override_qengines
@@ -1823,7 +1823,7 @@ class TestFxModelReportVisualizer(QuantizationTestCase):
         plottable_set = set()
 
         for feature_name in b_1_linear_features:
-            if type(b_1_linear_features[feature_name]) == torch.Tensor:
+            if type(b_1_linear_features[feature_name]) is torch.Tensor:
                 plottable_set.add(feature_name)
 
         returned_plottable_feats = mod_rep_visualizer.get_all_unique_feature_names()
@@ -826,7 +826,7 @@ class TestFuseFx(QuantizationTestCase):
         # check conv module has two inputs
         named_modules = dict(m.named_modules())
         for node in m.graph.nodes:
-            if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d:
+            if node.op == "call_module" and type(named_modules[node.target]) is torch.nn.Conv2d:
                 self.assertTrue(len(node.args) == 2, msg="Expecting the fused op to have two arguments")
 
     def test_fusion_pattern_with_matchallnode(self):
@@ -917,7 +917,7 @@ class TestQuantizeFx(QuantizationTestCase):
         m = torch.fx.symbolic_trace(M())
         modules = dict(m.named_modules())
         for n in m.graph.nodes:
-            if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
+            if n.op == 'call_module' and type(modules[n.target]) is nn.ReLU:
                 self.assertTrue(_is_match(modules, n, pattern))
 
     def test_pattern_match_constant(self):
@@ -454,8 +454,8 @@ class TestSubgraphRewriter(JitTestCase):
         symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
         for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
             if n.op == 'placeholder':
-                assert n.type == int
-                assert m.type == int
+                assert n.type is int
+                assert m.type is int
 
     def test_subgraph_writer_replace_consecutive_submodules(self):
 
@@ -332,7 +332,7 @@ class TestHelperModules:
         ) -> None:
             super().__init__()
             self.linear = nn.Linear(4, 4, bias=use_bias)
-            if postop == nn.GELU:
+            if postop is nn.GELU:
                 self.postop = postop(approximate=post_op_algo)
             else:
                 self.postop = postop(inplace=inplace_postop)
@@ -4162,7 +4162,7 @@ class TestBinaryUfuncs(TestCase):
             for i in complex_exponents if exp_dtype.is_complex else exponents:
                 out_dtype_scalar_exp = (
                     torch.complex128
-                    if base_dtype.is_complex or type(i) == complex
+                    if base_dtype.is_complex or type(i) is complex
                     else torch.float64
                 )
                 expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
@@ -4190,7 +4190,7 @@ class TestBinaryUfuncs(TestCase):
             for i in complex_exponents if base_dtype.is_complex else exponents:
                 out_dtype_scalar_base = (
                     torch.complex128
-                    if exp_dtype.is_complex or type(i) == complex
+                    if exp_dtype.is_complex or type(i) is complex
                     else torch.float64
                 )
                 expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))
@@ -4205,9 +4205,9 @@ class TestBinaryUfuncs(TestCase):
     def test_float_power_exceptions(self, device):
         def _promo_helper(x, y):
             for i in (x, y):
-                if type(i) == complex:
+                if type(i) is complex:
                     return torch.complex128
-                elif type(i) == torch.Tensor and i.is_complex():
+                elif type(i) is torch.Tensor and i.is_complex():
                     return torch.complex128
             return torch.double
 
@@ -2478,7 +2478,7 @@ class TestTyping(TestCase):
             else:
                 self.assertFalse(issubinstance(d, S))
             for t in basic_type:
-                if type(d) == t:
+                if type(d) is t:
                     self.assertTrue(issubinstance(d, t))
                 else:
                     self.assertFalse(issubinstance(d, t))
@@ -2577,7 +2577,7 @@ class TestTyping(TestCase):
 
         self.assertTrue(issubclass(DP4, IterDataPipe))
         dp4 = DP4()
-        self.assertTrue(dp4.type.param == tuple)
+        self.assertTrue(dp4.type.param is tuple)
 
         class DP5(IterDataPipe):
             r"""DataPipe without type annotation"""
@@ -2601,7 +2601,7 @@ class TestTyping(TestCase):
 
         self.assertTrue(issubclass(DP6, IterDataPipe))
         dp6 = DP6()
-        self.assertTrue(dp6.type.param == int)
+        self.assertTrue(dp6.type.param is int)
 
         class DP7(IterDataPipe[Awaitable[T_co]]):
             r"""DataPipe with abstract base class"""
@@ -878,7 +878,7 @@ def forward(self, scores_1, mask_1, value_1):
             zip(real_out, decomp_out, real_out_double)
         ):
             if not isinstance(orig, torch.Tensor):
-                assert type(orig) == type(decomp)
+                assert type(orig) is type(decomp)
                 assert orig == decomp
                 continue
             op_assert_ref(
@@ -895,7 +895,7 @@ def forward(self, scores_1, mask_1, value_1):
         else:
             for orig, decomp in zip(real_out, decomp_out):
                 if not isinstance(orig, torch.Tensor):
-                    assert type(orig) == type(decomp)
+                    assert type(orig) is type(decomp)
                     assert orig == decomp
                     continue
                 op_assert_equal(
@@ -2887,9 +2887,9 @@ graph(%Ra, %Rb):
             self.assertTrue(hasattr(input, 'type'))
             self.assertTrue(input.type() is not None)
             self.assertTrue(hasattr(block, 'returnNode'))
-            self.assertTrue(type(block.returnNode()) == torch._C.Node)
+            self.assertTrue(type(block.returnNode()) is torch._C.Node)
             self.assertTrue(hasattr(block, 'paramNode'))
-            self.assertTrue(type(block.paramNode()) == torch._C.Node)
+            self.assertTrue(type(block.paramNode()) is torch._C.Node)
         self.assertTrue(tested_blocks)
 
     def test_export_opnames(self):
@@ -6510,7 +6510,7 @@ a")
             if isinstance(res_python, Exception):
                 continue
 
-            if type(res_python) == type(res_script):
+            if type(res_python) is type(res_script):
                 if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])):
                     continue
                 if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script):
@@ -8646,7 +8646,7 @@ dedent """
         args = args + [1, 1.5]
 
        def isBool(arg):
-            return type(arg) == bool or (type(arg) == str and "torch.bool" in arg)
+            return type(arg) is bool or (type(arg) is str and "torch.bool" in arg)
 
        for op in ops:
            for first_arg in args:
@@ -8655,7 +8655,7 @@ dedent """
                 if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)):
                     continue
                 # div is not implemented correctly for mixed-type or int params
-                if (op == 'div' and (type(first_arg) != type(second_arg) or
+                if (op == 'div' and (type(first_arg) is not type(second_arg) or
                                      isinstance(first_arg, int) or
                                      (isinstance(first_arg, str) and 'int' in first_arg))):
                     continue
@@ -8671,7 +8671,7 @@ dedent """
                 graph = cu.func.graph
                 torch._C._jit_pass_complete_shape_analysis(graph, (), False)
                 # use dim=-1 to represent a python/jit scalar.
-                dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim()
+                dim = -1 if type(first_arg) is not str and type(second_arg) is not str else non_jit_result.dim()
                 dtype = non_jit_result.dtype
                 # jit only supports int/float scalars.
                 if dim < 0:
@@ -211,9 +211,9 @@ def autograd_sharing(queue, ready, master_modified, device, is_parameter):
     is_ok &= var.grad is None
     is_ok &= not var._backward_hooks
     if is_parameter:
-        is_ok &= type(var) == Parameter
+        is_ok &= type(var) is Parameter
     else:
-        is_ok &= type(var) == torch.Tensor
+        is_ok &= type(var) is torch.Tensor
     var._grad = torch.ones(5, 5, device=device)
 
     queue.put(is_ok)
@@ -596,7 +596,7 @@ class TestNumPyInterop(TestCase):
                 if (
                     dtype == torch.complex64
                     and torch.is_tensor(t)
-                    and type(a) == np.complex64
+                    and type(a) is np.complex64
                 ):
                     # TODO: Imaginary part is dropped in this case. Need fix.
                     # https://github.com/pytorch/pytorch/issues/43579
@@ -3327,7 +3327,7 @@ class TestReductions(TestCase):
     """
     def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density):
         def to_np(t):
-            if type(t) == list:
+            if type(t) is list:
                 return list(map(to_np, t))
             if not torch.is_tensor(t):
                 return t
@@ -968,7 +968,7 @@ class TestTypePromotion(TestCase):
             except Exception as e:
                 expected = e
 
-            same_result = (type(expected) == type(actual)) and expected == actual
+            same_result = (type(expected) is type(actual)) and expected == actual
 
             # Note: An "undesired failure," as opposed to an "expected failure"
             # is both expected (we know the test will fail) and
@@ -1128,7 +1128,7 @@ class TestTypePromotion(TestCase):
         maxs = (max_t, max_t[0], max_t[0].item())
         inp = make_tensor((S,), dtype0)
         for min_v, max_v in itertools.product(mins, maxs):
-            if type(max_v) != type(min_v):
+            if type(max_v) is not type(min_v):
                 continue
             if isinstance(min_v, torch.Tensor) and min_v.ndim == 0 and max_v.ndim == 0:
                 continue  # 0d tensors go to scalar overload, and it's tested separately
@@ -2384,7 +2384,7 @@ class TestLikeFuncs(TestCase):
         b = a[:, ::2]  # Ensure b is not contiguous.
         kwargs = {"fill_value": ""} if likefunc == np.full_like else {}
         result = likefunc(b, dtype=dtype, **kwargs)
-        if dtype == str:
+        if dtype is str:
             assert result.strides == (16, 4)
         else:
             # dtype is bytes
@@ -925,7 +925,7 @@ class TestScalarSubclassingMisc(TestCase):
 
         # inheritance has to override, or this is correctly lost:
         res = op(myf_simple1(1), myf_simple2(2))
-        assert type(res) == sctype or type(res) == np.bool_
+        assert type(res) is sctype or type(res) is np.bool_
         assert op(myf_simple1(1), myf_simple2(2)) == op(1, 2)  # inherited
 
         # Two independent subclasses do not really define an order. This could
@@ -955,7 +955,7 @@ class TestScalarSubclassingMisc(TestCase):
         assert op(myt(1), np.float64(2)) == __op__
         assert op(np.float64(1), myt(2)) == __rop__
 
-        if op in {operator.mod, operator.floordiv} and subtype == complex:
+        if op in {operator.mod, operator.floordiv} and subtype is complex:
             return  # module is not support for complex. Do not test.
 
         if __rop__ == __op__:
@@ -968,11 +968,11 @@ class TestScalarSubclassingMisc(TestCase):
             res = op(myt(1), np.float16(2))
             expected = op(subtype(1), np.float16(2))
             assert res == expected
-            assert type(res) == type(expected)
+            assert type(res) is type(expected)
             res = op(np.float32(2), myt(1))
             expected = op(np.float32(2), subtype(1))
             assert res == expected
-            assert type(res) == type(expected)
+            assert type(res) is type(expected)
 
 
 if __name__ == "__main__":
@@ -937,7 +937,7 @@ class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
 @instantiate_parametrized_tests
 class TestDet(DetCases, TestCase):
     def test_zero(self):
-        # NB: comment out tests of type(det) == double : we return zero-dim arrays
+        # NB: comment out tests of type(det) is double : we return zero-dim arrays
         assert_equal(linalg.det([[0.0]]), 0.0)
         # assert_equal(type(linalg.det([[0.0]])), double)
         assert_equal(linalg.det([[0.0j]]), 0.0)
@@ -1103,7 +1103,7 @@ class TestMatrixPower(TestCase):
 
         for mat in self.rshft_all:
             tz(mat.astype(dt))
-            if dt != object:
+            if dt is not object:
                 tz(self.stacked.astype(dt))
 
     @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
@@ -1115,7 +1115,7 @@ class TestMatrixPower(TestCase):
 
         for mat in self.rshft_all:
             tz(mat.astype(dt))
-            if dt != object:
+            if dt is not object:
                 tz(self.stacked.astype(dt))
 
     @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
@@ -1128,7 +1128,7 @@ class TestMatrixPower(TestCase):
 
         for mat in self.rshft_all:
             tz(mat.astype(dt))
-            if dt != object:
+            if dt is not object:
                 tz(self.stacked.astype(dt))
 
     @parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
@@ -661,7 +661,7 @@ class TestIter(TestCase):
         # numpy generates array scalars, we do 0D arrays
         a = np.arange(5)
         lst = list(a)
-        assert all(type(x) == np.ndarray for x in lst), f"{[type(x) for x in lst]}"
+        assert all(type(x) is np.ndarray for x in lst), f"{[type(x) for x in lst]}"
         assert all(x.ndim == 0 for x in lst)
 
     def test_iter_2d(self):
@@ -669,7 +669,8 @@ class TestIter(TestCase):
         a = np.arange(5)[None, :]
         lst = list(a)
         assert len(lst) == 1
-        assert type(lst[0]) == np.ndarray
+        # FIXME: "is" cannot be used here because dynamo fails
+        assert type(lst[0]) == np.ndarray  # noqa: E721
         assert_equal(lst[0], np.arange(5))
 
 
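The hunk above is the one place where the commit keeps `==` and suppresses the rule with a trailing `# noqa: E721` comment instead, because (per the FIXME) compiling the test under dynamo fails with `is`. A generic sketch of that escape hatch:

```python
import numpy as np

lst = list(np.arange(5)[None, :])

# Deliberate equality-based type check, suppressed per line while the
# rest of the codebase enforces E721:
assert type(lst[0]) == np.ndarray  # noqa: E721
```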
@@ -94,7 +94,7 @@ class TestNEP50Table(TestCase):
     def test_nep50_exceptions(self, example):
         old, new = examples[example]
 
-        if new == Exception:
+        if new is Exception:
             with assert_raises(OverflowError):
                 eval(example)
 
@@ -554,7 +554,7 @@ def fuzz_scalar(spec, seed: Optional[int] = None) -> Union[float, int, bool, com
 
 def specs_compatible(spec1: Spec, spec2: Spec) -> bool:
     """Check if two specifications are compatible (one can be used where the other is expected)."""
-    if type(spec1) != type(spec2):
+    if type(spec1) is not type(spec2):
         return False
 
     if isinstance(spec1, ScalarSpec):
@@ -2842,7 +2842,7 @@ def _index_add(
     if alpha != 1:
         python_type = utils.dtype_to_type(x.dtype)
         torch._check(
-            python_type == bool
+            python_type is bool
             or utils.is_weakly_lesser_type(type(alpha), python_type),
             lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
         )
@@ -295,7 +295,7 @@ class PyCodegen:
             output.extend(create_call_function(2, False))
         elif (
             isinstance(value, SymNodeVariable)
-            and value.python_type() == float
+            and value.python_type() is float
             and not self.tx.export
         ):
             # This is a little unusual; force the output convention to be a
@@ -4182,7 +4182,7 @@ def make_torch_function_mode_stack_guard(
             return False
 
         for ty, mode in zip(types, cur_stack):
-            if ty != type(mode):
+            if ty is not type(mode):
                 return False
 
         return True
@@ -1361,7 +1361,7 @@ class TensorVariable(VariableTracker):
         if (len(args) == 1 and isinstance(args[0], SizeVariable)) or (
             len(args) >= 1
             and all(
-                isinstance(a, ConstantVariable) and a.python_type() == int for a in args
+                isinstance(a, ConstantVariable) and a.python_type() is int for a in args
             )
         ):
             from ..symbolic_convert import InstructionTranslator
@@ -64,14 +64,14 @@ def _staged_schema():
             )
         elif o := typing.get_origin(t):
             # Lemme know if there's a better way to do this.
-            if o == list:
+            if o is list:
                 yaml_head, cpp_head, thrift_head, thrift_tail = (
                     "List",
                     "std::vector",
                     "list<",
                     ">",
                 )
-            elif o == dict:
+            elif o is dict:
                 yaml_head, cpp_head, thrift_head, thrift_tail = (
                     "Dict",
                     "std::unordered_map",
@@ -81,7 +81,7 @@ def _staged_schema():
         elif o == Union:
             assert level == 0, "Optional is only supported at the top level."
             args = typing.get_args(t)
-            assert len(args) == 2 and args[1] == type(None)
+            assert len(args) == 2 and args[1] is type(None)
             yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1)
             return (
                 f"Optional[{yaml_type}]",
@@ -83,7 +83,7 @@ class HopPartitionedGraph:
         val1: Union[torch.SymInt, torch.Tensor],
         val2: Union[torch.SymInt, torch.Tensor],
     ) -> bool:
-        if type(val1) != type(val2):
+        if type(val1) is not type(val2):
            return False
 
        if isinstance(val1, torch.SymInt) and isinstance(val2, torch.SymInt):
@@ -1211,7 +1211,7 @@ class CppVecOverrides(CppOverrides):
         return wrapper
 
     for name, method in vars(CppVecOverrides).items():
-        if getattr(method, "__class__", None) == staticmethod and name not in [
+        if getattr(method, "__class__", None) is staticmethod and name not in [
             "masked",
             "index_expr",
         ]:
@@ -220,15 +220,15 @@ class SamplingMethod(Enum):
         if field_name in TYPE_OVERRIDES:
             return random.choice(TYPE_OVERRIDES[field_name])
 
-        if type_hint == bool:
+        if type_hint is bool:
             return random.choice([True, False]) if random_sample else not default
-        elif type_hint == int:
+        elif type_hint is int:
             # NOTE initially tried to use negation of the value, but it doesn't work because most types are ints
             # when they should be natural numbers + zero. Python types to cover these values aren't super convenient.
             return random.randint(0, 1000)
-        elif type_hint == float:
+        elif type_hint is float:
             return random.uniform(0, 1000)
-        elif type_hint == str:
+        elif type_hint is str:
             characters = string.ascii_letters + string.digits + string.punctuation
             return "".join(
                 random.choice(characters) for _ in range(random.randint(1, 20))
@@ -306,7 +306,7 @@ class SamplingMethod(Enum):
             new_type = random.choice(type_hint.__args__)
         else:
             new_type = random.choice(
-                [t for t in type_hint.__args__ if t != type(default)]
+                [t for t in type_hint.__args__ if t is not type(default)]
             )
         try:
             new_default = new_type()
@@ -1208,7 +1208,7 @@ def safe_grad_filter(message, category, filename, lineno, file=None, line=None)
     def user_warning_filter(
         message, category, filename, lineno, file=None, line=None
     ) -> bool:
-        return category != UserWarning
+        return category is not UserWarning
 
 
 @contextlib.contextmanager
@@ -428,7 +428,7 @@ def percentile(
     interpolation: NotImplementedType = None,
 ):
     # np.percentile(float_tensor, 30) : q.dtype is int64 => q / 100.0 is float32
-    if _dtypes_impl.python_type_for_torch(q.dtype) == int:
+    if _dtypes_impl.python_type_for_torch(q.dtype) is int:
         q = q.to(_dtypes_impl.default_dtypes().float_dtype)
     qq = q / 100.0
 
@@ -1179,7 +1179,7 @@ def add(
     if alpha is not None:
         dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
         python_type = utils.dtype_to_type(dtype)
-        if python_type != bool and not utils.is_weakly_lesser_type(
+        if python_type is not bool and not utils.is_weakly_lesser_type(
             type(alpha), python_type
         ):
             msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
@@ -755,7 +755,7 @@ class ExceptionWrapper:
         # Format a message such as: "Caught ValueError in DataLoader worker
         # process 2. Original Traceback:", followed by the traceback.
         msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"  # pyrefly: ignore  # missing-attribute
-        if self.exc_type == KeyError:
+        if self.exc_type is KeyError:
             # KeyError calls repr() on its argument (usually a dict key). This
             # makes stack traces unreadable. It will not be changed in Python
             # (https://bugs.python.org/issue2651), so we work around it.
@@ -317,7 +317,7 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> list[int]:
         node.target in (torch.add, torch.ops.quantized.add, operator.add)
         or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
     ):
-        result = [i for i in range(2) if type(node.args[i]) == Node]
+        result = [i for i in range(2) if type(node.args[i]) is Node]
         return result
     return [0]
 
@@ -589,7 +589,7 @@ def _match_static_pattern(
 
     # Handle cases where the node is wrapped in a ReLU
     if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or (
-        ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU
+        ref_node.op == "call_module" and type(_get_module(ref_node, modules)) is nn.ReLU
     ):
         relu_node = ref_node
         ref_node = relu_node.args[0]
@@ -724,7 +724,7 @@ def _lower_static_weighted_ref_module(
         # If so, we replace the entire fused module with the corresponding quantized module
         if ref_class in STATIC_LOWER_FUSED_MODULE_MAP:
             inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class]
-            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
+            if type(ref_module[0]) is not inner_ref_class:  # type: ignore[index]
                 continue
         else:
             q_class = STATIC_LOWER_MODULE_MAP[ref_class]
@@ -786,7 +786,7 @@ def _lower_static_weighted_ref_module_with_two_inputs(
             inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[
                 ref_class
             ]
-            if type(ref_module[0]) != inner_ref_class:  # type: ignore[index]
+            if type(ref_module[0]) is not inner_ref_class:  # type: ignore[index]
                 continue
         else:
             continue
@@ -846,7 +846,7 @@ def _lower_dynamic_weighted_ref_module(model: GraphModule):
         ref_class = type(ref_module)
         if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP:
             inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class]
-            if type(ref_module[0]) != inner_ref_class:
+            if type(ref_module[0]) is not inner_ref_class:
                 continue
         else:
             q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class)  # type: ignore[assignment]
@@ -1008,7 +1008,7 @@ def _lower_dynamic_weighted_ref_functional(
             func_node.op == "call_function"
             and func_node.target == F.relu
             or func_node.op == "call_module"
-            and type(modules[str(func_node.target)]) == torch.nn.ReLU
+            and type(modules[str(func_node.target)]) is torch.nn.ReLU
         ):
             relu_node = func_node
             func_node = relu_node.args[0]
@@ -132,7 +132,7 @@ class ModelReportVisualizer:
                 # if we need plottable, ensure type of val is tensor
                 if (
                     not plottable_features_only
-                    or type(feature_dict[feature_name]) == torch.Tensor
+                    or type(feature_dict[feature_name]) is torch.Tensor
                 ):
                     unique_feature_names.add(feature_name)
 
@@ -704,7 +704,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
         return a.op == "call_function" and a.target == operator.getitem
 
     def match_tuple(a):
-        return a.op == "call_function" and a.target == tuple
+        return a.op == "call_function" and a.target is tuple
 
     def _match_pattern(match_pattern: list[Callable]) -> Optional[Node]:
         """
@@ -797,7 +797,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
 
         # Iterate through users of this node to find tuple/getitem nodes to match
         for user in node.users:
-            if user.op == "call_function" and user.target == tuple:
+            if user.op == "call_function" and user.target is tuple:
                 for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
                     if user_arg == node:
                         index_stack.append(i)
@@ -826,7 +826,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
     for pattern in matched_patterns:
         first_tuple = pattern[0]
         last_getitem = pattern[-1]
-        assert first_tuple.op == "call_function" and first_tuple.target == tuple
+        assert first_tuple.op == "call_function" and first_tuple.target is tuple
         assert (
             last_getitem.op == "call_function"
             and last_getitem.target == operator.getitem
@@ -699,12 +699,12 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
             state_dict_config = state_dict_config_type()
         if optim_state_dict_config is None:
             optim_state_dict_config = optim_state_dict_config_type()
-        if state_dict_config_type != type(state_dict_config):
+        if state_dict_config_type is not type(state_dict_config):
             raise RuntimeError(
                 f"Expected state_dict_config of type {state_dict_config_type} "
                 f"but got {type(state_dict_config)}"
             )
-        if optim_state_dict_config_type != type(optim_state_dict_config):
+        if optim_state_dict_config_type is not type(optim_state_dict_config):
             raise RuntimeError(
                 f"Expected optim_state_dict_config of type {optim_state_dict_config_type} "
                 f"but got {type(optim_state_dict_config)}"
@@ -180,12 +180,12 @@ def add_inference_rule(n: Node):
     t2 = n.args[1].type
 
     # handle scalar addition
-    if t1 == int and isinstance(t2, TensorType):
+    if t1 is int and isinstance(t2, TensorType):
         n.type = t2
         return n.type
 
     # handle scalar addition
-    elif t2 == int and isinstance(t1, TensorType):
+    elif t2 is int and isinstance(t1, TensorType):
         n.type = t1
         return n.type
 
@@ -542,7 +542,7 @@ def reinplace(gm, *sample_args):
             continue
         if len(node.target._schema.arguments) < 1:
             continue
-        if type(node.target._schema.arguments[0].type) != torch.TensorType:
+        if type(node.target._schema.arguments[0].type) is not torch.TensorType:
             continue
 
         # Step 1a: Check that the self argument we're attempting to reinplace
@@ -78,7 +78,7 @@ def issubtype(left, right, recursive=True):
     if getattr(right, "__origin__", None) is Generic:
         return True
 
-    if right == type(None):
+    if right is type(None):
         return False
 
     # Right-side type