Enable ruff rule E721 (#165162)

`E721` checks for object type comparisons using == and other comparison operators. This is useful because it is recommended to use `is` for type comparisons.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162
Approved by: https://github.com/Skylion007
This commit is contained in:
Yuanyuan Chen 2025-10-11 06:43:53 +00:00 committed by PyTorch MergeBot
parent 220a34118f
commit 9e7c19f72b
78 changed files with 166 additions and 164 deletions

View File

@ -367,7 +367,7 @@ class DeepSpeech(nn.Module):
"""
seq_len = input_length
for m in self.conv.modules():
if type(m) == nn.modules.conv.Conv2d:
if type(m) is nn.modules.conv.Conv2d:
seq_len = (
seq_len
+ 2 * m.padding[1]

View File

@ -66,7 +66,7 @@ class GroupedSetup:
def __post_init__(self) -> None:
for field in dataclasses.fields(self):
assert field.type == str
assert field.type is str
value: str = getattr(self, field.name)
object.__setattr__(self, field.name, textwrap.dedent(value))

View File

@ -113,7 +113,7 @@ class TorchBenchmarkBase(torch.nn.Module):
value = kargs[key]
test_name_str.append(
("" if key in skip_key_list else key)
+ str(value if type(value) != bool else int(value))
+ str(value if type(value) is not bool else int(value))
)
name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "")
return name

View File

@ -125,7 +125,7 @@ class CatBenchmark(op_bench.TorchBenchmarkBase):
random.seed(42)
inputs = []
gen_sizes = []
if type(sizes) == list and N == -1:
if type(sizes) is list and N == -1:
gen_sizes = sizes
else:
for i in range(N):

View File

@ -61,7 +61,7 @@ class StackBenchmark(op_bench.TorchBenchmarkBase):
random.seed(42)
inputs = []
gen_sizes = []
if type(sizes) == list and N == -1:
if type(sizes) is list and N == -1:
gen_sizes = sizes
else:
for i in range(N):

View File

@ -155,7 +155,6 @@ ignore = [
"E402",
"C408", # C408 ignored because we like the dict keyword argument syntax
"E501", # E501 is not flexible enough, we're using B950 instead
"E721",
"E741",
"EXE001",
"F405",

View File

@ -243,7 +243,7 @@ class TestActivationSparsifier(TestCase):
if mask1 is None:
assert mask2 is None
else:
assert type(mask1) == type(mask2)
assert type(mask1) is type(mask2)
if isinstance(mask1, list):
assert len(mask1) == len(mask2)
for idx in range(len(mask1)):

View File

@ -710,15 +710,15 @@ class TestQuantizationUtils(TestCase):
**sparse_config,
)
assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert type(model.emb1) is torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert (
type(model.embbag1)
== torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert type(model.emb_seq[0] == nn.Embedding)
assert type(model.emb_seq[1] == nn.EmbeddingBag)
assert type(model.linear1) == nn.Linear
assert type(model.linear2) == nn.Linear
assert type(model.emb_seq[0] is nn.Embedding)
assert type(model.emb_seq[1] is nn.EmbeddingBag)
assert type(model.linear1) is nn.Linear
assert type(model.linear2) is nn.Linear
dequant_emb1 = torch.dequantize(model.emb1.weight())
dequant_embbag1 = torch.dequantize(model.embbag1.weight())
@ -749,19 +749,21 @@ class TestQuantizationUtils(TestCase):
model, DataNormSparsifier, sparsify_first=False, **sparse_config
)
assert type(model.emb1) == torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert type(model.emb1) is torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert (
type(model.embbag1)
== torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert type(
model.emb_seq[0] == torch.ao.nn.quantized.modules.embedding_ops.Embedding
assert (
type(model.emb_seq[0])
is torch.ao.nn.quantized.modules.embedding_ops.Embedding
)
assert type(
model.emb_seq[1] == torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
assert (
type(model.emb_seq[1])
is torch.ao.nn.quantized.modules.embedding_ops.EmbeddingBag
)
assert type(model.linear1) == nn.Linear # not quantized
assert type(model.linear2) == nn.Linear # not quantized
assert type(model.linear1) is nn.Linear # not quantized
assert type(model.linear2) is nn.Linear # not quantized
dequant_emb1 = torch.dequantize(model.emb1.weight())
dequant_embbag1 = torch.dequantize(model.embbag1.weight())

View File

@ -291,7 +291,7 @@ class TestWeightNormSparsifier(TestCase):
assert hasattr(module.parametrizations["weight"][0], "mask")
# Check parametrization exists and is correct
assert is_parametrized(module, "weight")
assert type(module.parametrizations.weight[0]) == FakeSparsity
assert type(module.parametrizations.weight[0]) is FakeSparsity
def test_mask_squash(self):
model = SimpleLinear()
@ -415,7 +415,7 @@ class TestNearlyDiagonalSparsifier(TestCase):
assert hasattr(module.parametrizations["weight"][0], "mask")
# Check parametrization exists and is correct
assert is_parametrized(module, "weight")
assert type(module.parametrizations.weight[0]) == FakeSparsity
assert type(module.parametrizations.weight[0]) is FakeSparsity
def test_mask_squash(self):
model = SimpleLinear()

View File

@ -158,7 +158,7 @@ class TestBaseStructuredSparsifier(TestCase):
assert parametrize.is_parametrized(module)
assert hasattr(module, "parametrizations")
# Assume that this is the 1st/only parametrization
assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity
assert type(module.parametrizations.weight[0]) is FakeStructuredSparsity
def _check_pruner_valid_before_step(self, model, pruner, device):
for config in pruner.groups:

View File

@ -116,7 +116,7 @@ class TestTensorType(TestCase):
for dtype, str in dtypes_map.items():
x = torch.empty(4, 4, dtype=dtype, device="openreg")
self.assertTrue(x.type() == str)
self.assertTrue(x.type() is str)
# Note that all dtype-d Tensor objects here are only for legacy reasons
# and should NOT be used.

View File

@ -134,7 +134,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
False,
f"Collection length mismatch at {path}: {len(gpu_obj)} vs {len(cpu_obj)}",
)
if type(gpu_obj) != type(cpu_obj):
if type(gpu_obj) is not type(cpu_obj):
return (
False,
f"Collection type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
@ -149,7 +149,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
# If objects are custom classes, compare their attributes
elif hasattr(gpu_obj, "__dict__") and hasattr(cpu_obj, "__dict__"):
if type(gpu_obj) != type(cpu_obj):
if type(gpu_obj) is not type(cpu_obj):
return (
False,
f"Object type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",
@ -165,7 +165,7 @@ def compare_state_dicts(gpu_state_dict, cpu_state_dict, rtol=1e-5, atol=1e-8):
# For other types, use direct equality comparison
else:
if type(gpu_obj) != type(cpu_obj):
if type(gpu_obj) is not type(cpu_obj):
return (
False,
f"Type mismatch at {path}: {type(gpu_obj)} vs {type(cpu_obj)}",

View File

@ -44,14 +44,14 @@ class TestApply(FSDPTest):
@torch.no_grad()
def _init_linear_weights(self, m):
if type(m) == nn.Linear:
if type(m) is nn.Linear:
m.weight.fill_(1.0)
m.bias.fill_(1.0)
def check_weights(self, fsdp, expected_tensor_fn, check):
with FSDP.summon_full_params(fsdp, recurse=True):
linear_modules = [
module for module in fsdp.modules() if type(module) == nn.Linear
module for module in fsdp.modules() if type(module) is nn.Linear
]
for module in linear_modules:
for param in module.parameters():

View File

@ -1021,7 +1021,7 @@ class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
)
for warning in w:
self.assertTrue(
warning.category != UserWarning
warning.category is not UserWarning
or not str(warning.message).startswith(warning_prefix)
)

View File

@ -421,7 +421,7 @@ class TestFSDPOptimState(FSDPTest):
return False
for state_name, value1 in state1.items():
value2 = state2[state_name]
if type(value1) != type(value2):
if type(value1) is not type(value2):
return False
if torch.is_tensor(value1): # tensor state
assert torch.is_tensor(value2)

View File

@ -5887,7 +5887,7 @@ class TestKL(DistributionsTestCase):
def test_kl_exponential_family(self):
for (p, _), (_, q) in self.finite_examples:
if type(p) == type(q) and issubclass(type(p), ExponentialFamily):
if type(p) is type(q) and issubclass(type(p), ExponentialFamily):
actual = kl_divergence(p, q)
expected = _kl_expfamily_expfamily(p, q)
self.assertEqual(

View File

@ -3370,9 +3370,9 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
# Test on non autocast state and autocast cache states.
self.assertIn("autocast_state", json_guards)
for key, value in json_guards.items():
if type(value) == int:
if type(value) is int:
variant = value + 1
elif type(value) == bool:
elif type(value) is bool:
variant = not value
elif isinstance(value, dict) and key == "autocast_state":
variant = value.copy()

View File

@ -59,7 +59,7 @@ class SourceTests(torch._dynamo.test_case.TestCase):
def forward(self):
if (
torch.utils._pytree.SUPPORTED_NODES[CausalLMOutputWithPast].type
== int
is int
):
x = torch.sin(self.x)
else:

View File

@ -662,7 +662,7 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
"comparison",
[
subtest(isinstance, "isinstance"),
subtest(lambda instance, type_: type(instance) == type_, "equality"),
subtest(lambda instance, type_: type(instance) is type_, "equality"),
subtest(lambda instance, type_: type(instance) is type_, "identity"),
],
)

View File

@ -38,7 +38,7 @@ class PreDispatchSchemaCheckMode(SchemaCheckMode):
def _may_alias_or_mutate(self, func, types, args, kwargs):
def unwrap(e):
if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
if isinstance(e, torch.Tensor) and type(e) is not torch.Tensor:
try:
return e.elem
except AttributeError:

View File

@ -128,7 +128,7 @@ def run_with_nativert(ep):
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) == type(expected)
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
assert result.shape == expected.shape
assert result.dtype == expected.dtype
@ -323,7 +323,7 @@ class TestNativeRT(TestCase):
flat_results = pytree.tree_leaves(results)
assert len(flat_results) == len(flat_expected)
for result, expected in zip(flat_results, flat_expected):
assert type(result) == type(expected)
assert type(result) is type(expected)
if isinstance(result, torch.Tensor) and isinstance(
expected, torch.Tensor
):

View File

@ -82,7 +82,7 @@ class TestSerialize(TestCase):
return 0
def __eq__(self, other):
return type(other) == type(self)
return type(other) is type(self)
def __call__(self, *args, **kwargs):
return torch.ops.aten.add.Tensor(*args, **kwargs)

View File

@ -6332,7 +6332,7 @@ def forward(self, tangents_1, tangents_2):
self.assertEqual(out_ref[0].b, out_test[0].b)
self.assertEqual(out_ref[1], out_test[1])
# We compiled our graph assuming type(grad_out[1]) == torch.Tensor,
# We compiled our graph assuming type(grad_out[1]) is torch.Tensor,
# but we were wrong: in the below tests, it is a subclass.
# This will eventually require a repartition + recompile
with self.assertRaisesRegex(

View File

@ -3671,7 +3671,7 @@ class AssociativeScanModels:
# Check if val is a list and if it has the same length as combine_fn
# If so, then use the individual elements.
# If not, duplicate the first element.
if type(val) == list and len(val) == chain_len:
if type(val) is list and len(val) == chain_len:
kwargs_el[key] = val[ind]
else:
kwargs_el[key] = val

View File

@ -296,7 +296,7 @@ class TestSplitOutputType(TestCase):
gm_output = module(inputs)
split_gm_output = split_gm(inputs)
self.assertTrue(type(gm_output) == type(split_gm_output))
self.assertTrue(type(gm_output) is type(split_gm_output))
self.assertTrue(torch.equal(gm_output, split_gm_output))

View File

@ -514,8 +514,8 @@ class TestSubgraphRewriter(JitTestCase):
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
if n.op == "placeholder":
assert n.type == int
assert m.type == int
assert n.type is int
assert m.type is int
def test_subgraph_rewriter_replace_consecutive_submodules(self):
def f(x):

View File

@ -81,9 +81,9 @@ class BinaryFoldingTemplate(TestCase):
out_optimized = torch.compile(mod_eager)
inps = [4, 3, 4]
if module == nn.Conv2d:
if module is nn.Conv2d:
inps.append(inps[-1])
if module == nn.Conv3d:
if module is nn.Conv3d:
inps.append(inps[-1])
inps.append(inps[-1])
@ -195,9 +195,9 @@ class BinaryFoldingTemplate(TestCase):
)
inps = [4, 3, 4]
if module[0] == nn.Conv2d:
if module[0] is nn.Conv2d:
inps.append(inps[-1])
if module[0] == nn.Conv3d:
if module[0] is nn.Conv3d:
inps.append(inps[-1])
inps.append(inps[-1])

View File

@ -106,9 +106,9 @@ class TestMixin:
return keys
def key(self: Self, key_type: type[icache.Key]) -> icache.Key:
if key_type == str:
if key_type is str:
return f"s{randint(0, 2**32)}"
elif key_type == int:
elif key_type is int:
return randint(0, 2**32)
elif key_type == tuple[Any, ...]:
return (self.key(str), self.key(int))
@ -125,13 +125,13 @@ class TestMixin:
return values
def value(self: Self, value_type: type[icache.Value]) -> icache.Value:
if value_type == str:
if value_type is str:
return f"s{randint(0, 2**32)}"
elif value_type == int:
elif value_type is int:
return randint(0, 2**32)
elif value_type == tuple[Any, ...]:
return (self.value(str), self.value(int))
elif value_type == bytes:
elif value_type is bytes:
return self.value(str).encode()
elif value_type == dict[Any, Any]:
return {

View File

@ -88,7 +88,7 @@ def _check_if_instances_equal(op1, op2) -> bool:
if isinstance(op1, (list | tuple)):
return tuple(op1) == tuple(op2)
if type(op1) != type(op2):
if type(op1) is not type(op2):
return False
# some classes have __eq__ defined but they may be insufficient

View File

@ -127,11 +127,11 @@ class EfficientConvBNEvalTemplate(TestCase):
spatial_d = (
4 if issubclass(module[0], nn.modules.conv._ConvTransposeNd) else 96
)
if module[0] == nn.Conv1d or module[0] == nn.ConvTranspose1d:
if module[0] is nn.Conv1d or module[0] is nn.ConvTranspose1d:
inps += [spatial_d] * 1
if module[0] == nn.Conv2d or module[0] == nn.ConvTranspose2d:
if module[0] is nn.Conv2d or module[0] is nn.ConvTranspose2d:
inps += [spatial_d] * 2
if module[0] == nn.Conv3d or module[0] == nn.ConvTranspose3d:
if module[0] is nn.Conv3d or module[0] is nn.ConvTranspose3d:
inps += [spatial_d] * 3
inp = torch.rand(inps).to(self.device)

View File

@ -514,11 +514,11 @@ def check_model(
# print("Graph", graph)
if check_has_compiled:
assert called, "Ran graph without calling compile_fx"
assert type(actual) == type(correct)
assert type(actual) is type(correct)
if isinstance(actual, (tuple, list)):
assert len(actual) == len(correct)
assert all(
type(actual_item) == type(correct_item)
type(actual_item) is type(correct_item)
for actual_item, correct_item in zip(actual, correct)
)

View File

@ -198,7 +198,7 @@ class TestUtils(TestCase):
@dtypes(torch.float16, torch.bfloat16, torch.float32)
def test_get_device_tflops(self, dtype):
ret = get_device_tflops(dtype)
self.assertTrue(type(ret) == float)
self.assertTrue(type(ret) is float)
instantiate_device_type_tests(TestUtils, globals())

View File

@ -2083,9 +2083,9 @@ class TestFrozenOptimizations(JitTestCase):
mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
inps = [4, 3, 4]
if modules[0] == nn.Conv2d:
if modules[0] is nn.Conv2d:
inps.append(inps[-1])
if modules[0] == nn.Conv3d:
if modules[0] is nn.Conv3d:
inps.append(inps[-1])
inps.append(inps[-1])
@ -2224,9 +2224,9 @@ class TestFrozenOptimizations(JitTestCase):
mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval()
inps = [4, 3, 4]
if module == nn.Conv2d:
if module is nn.Conv2d:
inps.append(inps[-1])
if module == nn.Conv3d:
if module is nn.Conv3d:
inps.append(inps[-1])
inps.append(inps[-1])
@ -2366,10 +2366,10 @@ class TestFrozenOptimizations(JitTestCase):
mod_eager = LinearBN(32, 32).eval()
inps = [3, 32]
if modules[1] == nn.BatchNorm2d:
if modules[1] is nn.BatchNorm2d:
inps.append(inps[-1])
inps.append(inps[-1])
if modules[1] == nn.BatchNorm3d:
if modules[1] is nn.BatchNorm3d:
inps.append(inps[-1])
inps.append(inps[-1])
inps.append(inps[-1])
@ -2429,14 +2429,14 @@ class TestFrozenOptimizations(JitTestCase):
N, C = 3, bn_in
input_shape = [N, C]
if modules[1] == nn.BatchNorm1d:
if modules[1] is nn.BatchNorm1d:
H = linear_in
input_shape.append(H)
elif modules[1] == nn.BatchNorm2d:
elif modules[1] is nn.BatchNorm2d:
H, W = 4, linear_in
input_shape.append(H)
input_shape.append(W)
elif modules[1] == nn.BatchNorm3d:
elif modules[1] is nn.BatchNorm3d:
D, H, W = 4, 4, linear_in
input_shape.append(D)
input_shape.append(H)
@ -2504,10 +2504,10 @@ class TestFrozenOptimizations(JitTestCase):
mod_eager = LinearBN(32, 32).cuda().eval()
inps = [3, 32]
if modules[1] == nn.BatchNorm2d:
if modules[1] is nn.BatchNorm2d:
inps.append(inps[-1])
inps.append(inps[-1])
if modules[1] == nn.BatchNorm3d:
if modules[1] is nn.BatchNorm3d:
inps.append(inps[-1])
inps.append(inps[-1])
inps.append(inps[-1])
@ -2757,9 +2757,9 @@ class TestFrozenOptimizations(JitTestCase):
for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]):
mod = module(3, 32, kernel_size=3, stride=2).eval()
inps = [4, 3, 4]
if module == nn.Conv2d:
if module is nn.Conv2d:
inps.append(inps[-1])
if module == nn.Conv3d:
if module is nn.Conv3d:
inps.append(inps[-1])
inps.append(inps[-1])
@ -2997,7 +2997,7 @@ class TestFrozenOptimizations(JitTestCase):
mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()
inps = [5, 3, 4, 4]
if conv == nn.Conv3d:
if conv is nn.Conv3d:
inps.append(inps[-1])
inp = torch.rand(inps).cuda()

View File

@ -210,7 +210,7 @@ class TestTyping(JitTestCase):
li_1, li_2, li_3 = stuff4([True])
li_3 = li_3[0]
for li in [li_1, li_2, li_3]:
self.assertTrue(type(li[0]) == bool)
self.assertTrue(type(li[0]) is bool)
def test_nested_list(self):
def foo(z):

View File

@ -3839,9 +3839,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
# This is because we have N111 weight that cannot handle
# the ambiguous memory_format
if w_f == torch.channels_last:
if layer == nn.Conv2d and filter_size * c != 1:
if layer is nn.Conv2d and filter_size * c != 1:
output_format = torch.channels_last
if layer == nn.ConvTranspose2d and filter_size * k != 1:
if layer is nn.ConvTranspose2d and filter_size * k != 1:
output_format = torch.channels_last
self._run_conv(
layer,

View File

@ -474,8 +474,8 @@ def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
f"Expected isinstance(src, {cls}) but got {type(src)}"
)
assert (
type(dest) == torch.Tensor
or type(dest) == torch.nn.Parameter
type(dest) is torch.Tensor
or type(dest) is torch.nn.Parameter
or issubclass(cls, type(dest))
)
if assign:

View File

@ -3053,7 +3053,7 @@ class TestQuantizedOps(TestCase):
lstm_quantized = torch.ao.quantization.convert(
lstm_prepared, convert_custom_config_dict=custom_config_dict
)
assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
assert type(lstm_quantized[0]) is torch.ao.nn.quantized.LSTM
qy = lstm_quantized(qx)
snr = _snr(y, qy)

View File

@ -138,7 +138,7 @@ class TestObserver(QuantizationTestCase):
# Calculate Qparams should return with a warning for observers with no data
qparams = myobs.calculate_qparams()
input_scale = 2**16 if qdtype is torch.qint32 else 1
if type(myobs) == MinMaxObserver:
if type(myobs) is MinMaxObserver:
x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) * input_scale
y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) * input_scale
else:
@ -201,7 +201,7 @@ class TestObserver(QuantizationTestCase):
[[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
]
)
if type(myobs) == MovingAveragePerChannelMinMaxObserver:
if type(myobs) is MovingAveragePerChannelMinMaxObserver:
# Scaling the input tensor to model change in min/max values
# across batches
result = myobs(0.5 * x)

View File

@ -599,7 +599,7 @@ class TestFakeQuantizeOps(TestCase):
# Output of fake quant is not identical to input
Y = fq_module(X)
self.assertNotEqual(Y, X)
if type(fq_module) == _LearnableFakeQuantize:
if type(fq_module) is _LearnableFakeQuantize:
fq_module.toggle_fake_quant(False)
else:
torch.ao.quantization.disable_fake_quant(fq_module)
@ -613,7 +613,7 @@ class TestFakeQuantizeOps(TestCase):
scale = fq_module.scale.detach().clone()
zero_point = fq_module.zero_point.detach().clone()
if type(fq_module) == _LearnableFakeQuantize:
if type(fq_module) is _LearnableFakeQuantize:
fq_module.toggle_observer_update(False)
fq_module.toggle_fake_quant(True)
else:
@ -625,7 +625,7 @@ class TestFakeQuantizeOps(TestCase):
# Observer is disabled, scale and zero-point do not change
self.assertEqual(fq_module.scale, scale)
self.assertEqual(fq_module.zero_point, zero_point)
if type(fq_module) == _LearnableFakeQuantize:
if type(fq_module) is _LearnableFakeQuantize:
fq_module.toggle_observer_update(True)
else:
torch.ao.quantization.enable_observer(fq_module)

View File

@ -241,7 +241,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, (
assert type(mod) is cls._FLOAT_MODULE, (
"qat."
+ cls.__name__
+ ".from_float only works for "
@ -1264,8 +1264,8 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
mp = prepare_qat(m)
mp(data)
mq = convert(mp)
self.assertTrue(type(mq[1]) == nnq.Linear)
self.assertTrue(type(mq[2]) == nn.Identity)
self.assertTrue(type(mq[1]) is nnq.Linear)
self.assertTrue(type(mq[2]) is nn.Identity)
@skipIfNoXNNPACK
@override_qengines

View File

@ -1823,7 +1823,7 @@ class TestFxModelReportVisualizer(QuantizationTestCase):
plottable_set = set()
for feature_name in b_1_linear_features:
if type(b_1_linear_features[feature_name]) == torch.Tensor:
if type(b_1_linear_features[feature_name]) is torch.Tensor:
plottable_set.add(feature_name)
returned_plottable_feats = mod_rep_visualizer.get_all_unique_feature_names()

View File

@ -826,7 +826,7 @@ class TestFuseFx(QuantizationTestCase):
# check conv module has two inputs
named_modules = dict(m.named_modules())
for node in m.graph.nodes:
if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d:
if node.op == "call_module" and type(named_modules[node.target]) is torch.nn.Conv2d:
self.assertTrue(len(node.args) == 2, msg="Expecting the fused op to have two arguments")
def test_fusion_pattern_with_matchallnode(self):
@ -917,7 +917,7 @@ class TestQuantizeFx(QuantizationTestCase):
m = torch.fx.symbolic_trace(M())
modules = dict(m.named_modules())
for n in m.graph.nodes:
if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
if n.op == 'call_module' and type(modules[n.target]) is nn.ReLU:
self.assertTrue(_is_match(modules, n, pattern))
def test_pattern_match_constant(self):

View File

@ -454,8 +454,8 @@ class TestSubgraphRewriter(JitTestCase):
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
if n.op == 'placeholder':
assert n.type == int
assert m.type == int
assert n.type is int
assert m.type is int
def test_subgraph_writer_replace_consecutive_submodules(self):

View File

@ -332,7 +332,7 @@ class TestHelperModules:
) -> None:
super().__init__()
self.linear = nn.Linear(4, 4, bias=use_bias)
if postop == nn.GELU:
if postop is nn.GELU:
self.postop = postop(approximate=post_op_algo)
else:
self.postop = postop(inplace=inplace_postop)

View File

@ -4162,7 +4162,7 @@ class TestBinaryUfuncs(TestCase):
for i in complex_exponents if exp_dtype.is_complex else exponents:
out_dtype_scalar_exp = (
torch.complex128
if base_dtype.is_complex or type(i) == complex
if base_dtype.is_complex or type(i) is complex
else torch.float64
)
expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))
@ -4190,7 +4190,7 @@ class TestBinaryUfuncs(TestCase):
for i in complex_exponents if base_dtype.is_complex else exponents:
out_dtype_scalar_base = (
torch.complex128
if exp_dtype.is_complex or type(i) == complex
if exp_dtype.is_complex or type(i) is complex
else torch.float64
)
expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))
@ -4205,9 +4205,9 @@ class TestBinaryUfuncs(TestCase):
def test_float_power_exceptions(self, device):
def _promo_helper(x, y):
for i in (x, y):
if type(i) == complex:
if type(i) is complex:
return torch.complex128
elif type(i) == torch.Tensor and i.is_complex():
elif type(i) is torch.Tensor and i.is_complex():
return torch.complex128
return torch.double

View File

@ -2478,7 +2478,7 @@ class TestTyping(TestCase):
else:
self.assertFalse(issubinstance(d, S))
for t in basic_type:
if type(d) == t:
if type(d) is t:
self.assertTrue(issubinstance(d, t))
else:
self.assertFalse(issubinstance(d, t))
@ -2577,7 +2577,7 @@ class TestTyping(TestCase):
self.assertTrue(issubclass(DP4, IterDataPipe))
dp4 = DP4()
self.assertTrue(dp4.type.param == tuple)
self.assertTrue(dp4.type.param is tuple)
class DP5(IterDataPipe):
r"""DataPipe without type annotation"""
@ -2601,7 +2601,7 @@ class TestTyping(TestCase):
self.assertTrue(issubclass(DP6, IterDataPipe))
dp6 = DP6()
self.assertTrue(dp6.type.param == int)
self.assertTrue(dp6.type.param is int)
class DP7(IterDataPipe[Awaitable[T_co]]):
r"""DataPipe with abstract base class"""

View File

@ -878,7 +878,7 @@ def forward(self, scores_1, mask_1, value_1):
zip(real_out, decomp_out, real_out_double)
):
if not isinstance(orig, torch.Tensor):
assert type(orig) == type(decomp)
assert type(orig) is type(decomp)
assert orig == decomp
continue
op_assert_ref(
@ -895,7 +895,7 @@ def forward(self, scores_1, mask_1, value_1):
else:
for orig, decomp in zip(real_out, decomp_out):
if not isinstance(orig, torch.Tensor):
assert type(orig) == type(decomp)
assert type(orig) is type(decomp)
assert orig == decomp
continue
op_assert_equal(

View File

@ -2887,9 +2887,9 @@ graph(%Ra, %Rb):
self.assertTrue(hasattr(input, 'type'))
self.assertTrue(input.type() is not None)
self.assertTrue(hasattr(block, 'returnNode'))
self.assertTrue(type(block.returnNode()) == torch._C.Node)
self.assertTrue(type(block.returnNode()) is torch._C.Node)
self.assertTrue(hasattr(block, 'paramNode'))
self.assertTrue(type(block.paramNode()) == torch._C.Node)
self.assertTrue(type(block.paramNode()) is torch._C.Node)
self.assertTrue(tested_blocks)
def test_export_opnames(self):
@ -6510,7 +6510,7 @@ a")
if isinstance(res_python, Exception):
continue
if type(res_python) == type(res_script):
if type(res_python) is type(res_script):
if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])):
continue
if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script):
@ -8646,7 +8646,7 @@ dedent """
args = args + [1, 1.5]
def isBool(arg):
return type(arg) == bool or (type(arg) == str and "torch.bool" in arg)
return type(arg) is bool or (type(arg) is str and "torch.bool" in arg)
for op in ops:
for first_arg in args:
@ -8655,7 +8655,7 @@ dedent """
if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)):
continue
# div is not implemented correctly for mixed-type or int params
if (op == 'div' and (type(first_arg) != type(second_arg) or
if (op == 'div' and (type(first_arg) is not type(second_arg) or
isinstance(first_arg, int) or
(isinstance(first_arg, str) and 'int' in first_arg))):
continue
@ -8671,7 +8671,7 @@ dedent """
graph = cu.func.graph
torch._C._jit_pass_complete_shape_analysis(graph, (), False)
# use dim=-1 to represent a python/jit scalar.
dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim()
dim = -1 if type(first_arg) is not str and type(second_arg) is not str else non_jit_result.dim()
dtype = non_jit_result.dtype
# jit only supports int/float scalars.
if dim < 0:

View File

@ -211,9 +211,9 @@ def autograd_sharing(queue, ready, master_modified, device, is_parameter):
is_ok &= var.grad is None
is_ok &= not var._backward_hooks
if is_parameter:
is_ok &= type(var) == Parameter
is_ok &= type(var) is Parameter
else:
is_ok &= type(var) == torch.Tensor
is_ok &= type(var) is torch.Tensor
var._grad = torch.ones(5, 5, device=device)
queue.put(is_ok)

View File

@ -596,7 +596,7 @@ class TestNumPyInterop(TestCase):
if (
dtype == torch.complex64
and torch.is_tensor(t)
and type(a) == np.complex64
and type(a) is np.complex64
):
# TODO: Imaginary part is dropped in this case. Need fix.
# https://github.com/pytorch/pytorch/issues/43579

View File

@ -3327,7 +3327,7 @@ class TestReductions(TestCase):
"""
def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density):
def to_np(t):
if type(t) == list:
if type(t) is list:
return list(map(to_np, t))
if not torch.is_tensor(t):
return t

View File

@ -968,7 +968,7 @@ class TestTypePromotion(TestCase):
except Exception as e:
expected = e
same_result = (type(expected) == type(actual)) and expected == actual
same_result = (type(expected) is type(actual)) and expected == actual
# Note: An "undesired failure," as opposed to an "expected failure"
# is both expected (we know the test will fail) and
@ -1128,7 +1128,7 @@ class TestTypePromotion(TestCase):
maxs = (max_t, max_t[0], max_t[0].item())
inp = make_tensor((S,), dtype0)
for min_v, max_v in itertools.product(mins, maxs):
if type(max_v) != type(min_v):
if type(max_v) is not type(min_v):
continue
if isinstance(min_v, torch.Tensor) and min_v.ndim == 0 and max_v.ndim == 0:
continue # 0d tensors go to scalar overload, and it's tested separately

View File

@ -2384,7 +2384,7 @@ class TestLikeFuncs(TestCase):
b = a[:, ::2] # Ensure b is not contiguous.
kwargs = {"fill_value": ""} if likefunc == np.full_like else {}
result = likefunc(b, dtype=dtype, **kwargs)
if dtype == str:
if dtype is str:
assert result.strides == (16, 4)
else:
# dtype is bytes

View File

@ -925,7 +925,7 @@ class TestScalarSubclassingMisc(TestCase):
# inheritance has to override, or this is correctly lost:
res = op(myf_simple1(1), myf_simple2(2))
assert type(res) == sctype or type(res) == np.bool_
assert type(res) is sctype or type(res) is np.bool_
assert op(myf_simple1(1), myf_simple2(2)) == op(1, 2) # inherited
# Two independent subclasses do not really define an order. This could
@ -955,7 +955,7 @@ class TestScalarSubclassingMisc(TestCase):
assert op(myt(1), np.float64(2)) == __op__
assert op(np.float64(1), myt(2)) == __rop__
if op in {operator.mod, operator.floordiv} and subtype == complex:
if op in {operator.mod, operator.floordiv} and subtype is complex:
return # module is not support for complex. Do not test.
if __rop__ == __op__:
@ -968,11 +968,11 @@ class TestScalarSubclassingMisc(TestCase):
res = op(myt(1), np.float16(2))
expected = op(subtype(1), np.float16(2))
assert res == expected
assert type(res) == type(expected)
assert type(res) is type(expected)
res = op(np.float32(2), myt(1))
expected = op(np.float32(2), subtype(1))
assert res == expected
assert type(res) == type(expected)
assert type(res) is type(expected)
if __name__ == "__main__":

View File

@ -937,7 +937,7 @@ class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
@instantiate_parametrized_tests
class TestDet(DetCases, TestCase):
def test_zero(self):
# NB: comment out tests of type(det) == double : we return zero-dim arrays
# NB: comment out tests of type(det) is double : we return zero-dim arrays
assert_equal(linalg.det([[0.0]]), 0.0)
# assert_equal(type(linalg.det([[0.0]])), double)
assert_equal(linalg.det([[0.0j]]), 0.0)
@ -1103,7 +1103,7 @@ class TestMatrixPower(TestCase):
for mat in self.rshft_all:
tz(mat.astype(dt))
if dt != object:
if dt is not object:
tz(self.stacked.astype(dt))
@parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
@ -1115,7 +1115,7 @@ class TestMatrixPower(TestCase):
for mat in self.rshft_all:
tz(mat.astype(dt))
if dt != object:
if dt is not object:
tz(self.stacked.astype(dt))
@parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])
@ -1128,7 +1128,7 @@ class TestMatrixPower(TestCase):
for mat in self.rshft_all:
tz(mat.astype(dt))
if dt != object:
if dt is not object:
tz(self.stacked.astype(dt))
@parametrize("dt", [np.dtype(c) for c in "?bBhilefdFD"])

View File

@ -661,7 +661,7 @@ class TestIter(TestCase):
# numpy generates array scalars, we do 0D arrays
a = np.arange(5)
lst = list(a)
assert all(type(x) == np.ndarray for x in lst), f"{[type(x) for x in lst]}"
assert all(type(x) is np.ndarray for x in lst), f"{[type(x) for x in lst]}"
assert all(x.ndim == 0 for x in lst)
def test_iter_2d(self):
@ -669,7 +669,8 @@ class TestIter(TestCase):
a = np.arange(5)[None, :]
lst = list(a)
assert len(lst) == 1
assert type(lst[0]) == np.ndarray
# FIXME: "is" cannot be used here because dynamo fails
assert type(lst[0]) == np.ndarray # noqa: E721
assert_equal(lst[0], np.arange(5))

View File

@ -94,7 +94,7 @@ class TestNEP50Table(TestCase):
def test_nep50_exceptions(self, example):
old, new = examples[example]
if new == Exception:
if new is Exception:
with assert_raises(OverflowError):
eval(example)

View File

@ -554,7 +554,7 @@ def fuzz_scalar(spec, seed: Optional[int] = None) -> Union[float, int, bool, com
def specs_compatible(spec1: Spec, spec2: Spec) -> bool:
"""Check if two specifications are compatible (one can be used where the other is expected)."""
if type(spec1) != type(spec2):
if type(spec1) is not type(spec2):
return False
if isinstance(spec1, ScalarSpec):

View File

@ -2842,7 +2842,7 @@ def _index_add(
if alpha != 1:
python_type = utils.dtype_to_type(x.dtype)
torch._check(
python_type == bool
python_type is bool
or utils.is_weakly_lesser_type(type(alpha), python_type),
lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
)

View File

@ -295,7 +295,7 @@ class PyCodegen:
output.extend(create_call_function(2, False))
elif (
isinstance(value, SymNodeVariable)
and value.python_type() == float
and value.python_type() is float
and not self.tx.export
):
# This is a little unusual; force the output convention to be a

View File

@ -4182,7 +4182,7 @@ def make_torch_function_mode_stack_guard(
return False
for ty, mode in zip(types, cur_stack):
if ty != type(mode):
if ty is not type(mode):
return False
return True

View File

@ -1361,7 +1361,7 @@ class TensorVariable(VariableTracker):
if (len(args) == 1 and isinstance(args[0], SizeVariable)) or (
len(args) >= 1
and all(
isinstance(a, ConstantVariable) and a.python_type() == int for a in args
isinstance(a, ConstantVariable) and a.python_type() is int for a in args
)
):
from ..symbolic_convert import InstructionTranslator

View File

@ -64,14 +64,14 @@ def _staged_schema():
)
elif o := typing.get_origin(t):
# Lemme know if there's a better way to do this.
if o == list:
if o is list:
yaml_head, cpp_head, thrift_head, thrift_tail = (
"List",
"std::vector",
"list<",
">",
)
elif o == dict:
elif o is dict:
yaml_head, cpp_head, thrift_head, thrift_tail = (
"Dict",
"std::unordered_map",
@ -81,7 +81,7 @@ def _staged_schema():
elif o == Union:
assert level == 0, "Optional is only supported at the top level."
args = typing.get_args(t)
assert len(args) == 2 and args[1] == type(None)
assert len(args) == 2 and args[1] is type(None)
yaml_type, cpp_type, thrift_type = dump_type(args[0], level + 1)
return (
f"Optional[{yaml_type}]",

View File

@ -83,7 +83,7 @@ class HopPartitionedGraph:
val1: Union[torch.SymInt, torch.Tensor],
val2: Union[torch.SymInt, torch.Tensor],
) -> bool:
if type(val1) != type(val2):
if type(val1) is not type(val2):
return False
if isinstance(val1, torch.SymInt) and isinstance(val2, torch.SymInt):

View File

@ -1211,7 +1211,7 @@ class CppVecOverrides(CppOverrides):
return wrapper
for name, method in vars(CppVecOverrides).items():
if getattr(method, "__class__", None) == staticmethod and name not in [
if getattr(method, "__class__", None) is staticmethod and name not in [
"masked",
"index_expr",
]:

View File

@ -220,15 +220,15 @@ class SamplingMethod(Enum):
if field_name in TYPE_OVERRIDES:
return random.choice(TYPE_OVERRIDES[field_name])
if type_hint == bool:
if type_hint is bool:
return random.choice([True, False]) if random_sample else not default
elif type_hint == int:
elif type_hint is int:
# NOTE initially tried to use negation of the value, but it doesn't work because most types are ints
# when they should be natural numbers + zero. Python types to cover these values aren't super convenient.
return random.randint(0, 1000)
elif type_hint == float:
elif type_hint is float:
return random.uniform(0, 1000)
elif type_hint == str:
elif type_hint is str:
characters = string.ascii_letters + string.digits + string.punctuation
return "".join(
random.choice(characters) for _ in range(random.randint(1, 20))
@ -306,7 +306,7 @@ class SamplingMethod(Enum):
new_type = random.choice(type_hint.__args__)
else:
new_type = random.choice(
[t for t in type_hint.__args__ if t != type(default)]
[t for t in type_hint.__args__ if t is not type(default)]
)
try:
new_default = new_type()

View File

@ -1208,7 +1208,7 @@ def safe_grad_filter(message, category, filename, lineno, file=None, line=None)
def user_warning_filter(
message, category, filename, lineno, file=None, line=None
) -> bool:
return category != UserWarning
return category is not UserWarning
@contextlib.contextmanager

View File

@ -428,7 +428,7 @@ def percentile(
interpolation: NotImplementedType = None,
):
# np.percentile(float_tensor, 30) : q.dtype is int64 => q / 100.0 is float32
if _dtypes_impl.python_type_for_torch(q.dtype) == int:
if _dtypes_impl.python_type_for_torch(q.dtype) is int:
q = q.to(_dtypes_impl.default_dtypes().float_dtype)
qq = q / 100.0

View File

@ -1179,7 +1179,7 @@ def add(
if alpha is not None:
dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr]
python_type = utils.dtype_to_type(dtype)
if python_type != bool and not utils.is_weakly_lesser_type(
if python_type is not bool and not utils.is_weakly_lesser_type(
type(alpha), python_type
):
msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"

View File

@ -755,7 +755,7 @@ class ExceptionWrapper:
# Format a message such as: "Caught ValueError in DataLoader worker
# process 2. Original Traceback:", followed by the traceback.
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute
if self.exc_type == KeyError:
if self.exc_type is KeyError:
# KeyError calls repr() on its argument (usually a dict key). This
# makes stack traces unreadable. It will not be changed in Python
# (https://bugs.python.org/issue2651), so we work around it.

View File

@ -317,7 +317,7 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> list[int]:
node.target in (torch.add, torch.ops.quantized.add, operator.add)
or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
):
result = [i for i in range(2) if type(node.args[i]) == Node]
result = [i for i in range(2) if type(node.args[i]) is Node]
return result
return [0]

View File

@ -589,7 +589,7 @@ def _match_static_pattern(
# Handle cases where the node is wrapped in a ReLU
if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or (
ref_node.op == "call_module" and type(_get_module(ref_node, modules)) == nn.ReLU
ref_node.op == "call_module" and type(_get_module(ref_node, modules)) is nn.ReLU
):
relu_node = ref_node
ref_node = relu_node.args[0]
@ -724,7 +724,7 @@ def _lower_static_weighted_ref_module(
# If so, we replace the entire fused module with the corresponding quantized module
if ref_class in STATIC_LOWER_FUSED_MODULE_MAP:
inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class]
if type(ref_module[0]) != inner_ref_class: # type: ignore[index]
if type(ref_module[0]) is not inner_ref_class: # type: ignore[index]
continue
else:
q_class = STATIC_LOWER_MODULE_MAP[ref_class]
@ -786,7 +786,7 @@ def _lower_static_weighted_ref_module_with_two_inputs(
inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP[
ref_class
]
if type(ref_module[0]) != inner_ref_class: # type: ignore[index]
if type(ref_module[0]) is not inner_ref_class: # type: ignore[index]
continue
else:
continue
@ -846,7 +846,7 @@ def _lower_dynamic_weighted_ref_module(model: GraphModule):
ref_class = type(ref_module)
if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP:
inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class]
if type(ref_module[0]) != inner_ref_class:
if type(ref_module[0]) is not inner_ref_class:
continue
else:
q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class) # type: ignore[assignment]
@ -1008,7 +1008,7 @@ def _lower_dynamic_weighted_ref_functional(
func_node.op == "call_function"
and func_node.target == F.relu
or func_node.op == "call_module"
and type(modules[str(func_node.target)]) == torch.nn.ReLU
and type(modules[str(func_node.target)]) is torch.nn.ReLU
):
relu_node = func_node
func_node = relu_node.args[0]

View File

@ -132,7 +132,7 @@ class ModelReportVisualizer:
# if we need plottable, ensure type of val is tensor
if (
not plottable_features_only
or type(feature_dict[feature_name]) == torch.Tensor
or type(feature_dict[feature_name]) is torch.Tensor
):
unique_feature_names.add(feature_name)

View File

@ -704,7 +704,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
return a.op == "call_function" and a.target == operator.getitem
def match_tuple(a):
return a.op == "call_function" and a.target == tuple
return a.op == "call_function" and a.target is tuple
def _match_pattern(match_pattern: list[Callable]) -> Optional[Node]:
"""
@ -797,7 +797,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
# Iterate through users of this node to find tuple/getitem nodes to match
for user in node.users:
if user.op == "call_function" and user.target == tuple:
if user.op == "call_function" and user.target is tuple:
for i, user_arg in enumerate(user.args[0]): # type: ignore[arg-type]
if user_arg == node:
index_stack.append(i)
@ -826,7 +826,7 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
for pattern in matched_patterns:
first_tuple = pattern[0]
last_getitem = pattern[-1]
assert first_tuple.op == "call_function" and first_tuple.target == tuple
assert first_tuple.op == "call_function" and first_tuple.target is tuple
assert (
last_getitem.op == "call_function"
and last_getitem.target == operator.getitem

View File

@ -699,12 +699,12 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
state_dict_config = state_dict_config_type()
if optim_state_dict_config is None:
optim_state_dict_config = optim_state_dict_config_type()
if state_dict_config_type != type(state_dict_config):
if state_dict_config_type is not type(state_dict_config):
raise RuntimeError(
f"Expected state_dict_config of type {state_dict_config_type} "
f"but got {type(state_dict_config)}"
)
if optim_state_dict_config_type != type(optim_state_dict_config):
if optim_state_dict_config_type is not type(optim_state_dict_config):
raise RuntimeError(
f"Expected optim_state_dict_config of type {optim_state_dict_config_type} "
f"but got {type(optim_state_dict_config)}"

View File

@ -180,12 +180,12 @@ def add_inference_rule(n: Node):
t2 = n.args[1].type
# handle scalar addition
if t1 == int and isinstance(t2, TensorType):
if t1 is int and isinstance(t2, TensorType):
n.type = t2
return n.type
# handle scalar addition
elif t2 == int and isinstance(t1, TensorType):
elif t2 is int and isinstance(t1, TensorType):
n.type = t1
return n.type

View File

@ -542,7 +542,7 @@ def reinplace(gm, *sample_args):
continue
if len(node.target._schema.arguments) < 1:
continue
if type(node.target._schema.arguments[0].type) != torch.TensorType:
if type(node.target._schema.arguments[0].type) is not torch.TensorType:
continue
# Step 1a: Check that the self argument we're attempting to reinplace

View File

@ -78,7 +78,7 @@ def issubtype(left, right, recursive=True):
if getattr(right, "__origin__", None) is Generic:
return True
if right == type(None):
if right is type(None):
return False
# Right-side type