Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit 321e602692.
Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb because it causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
This commit is contained in:
parent 321e602692
commit 5d7360bb03
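The reverted change came from enabling ruff's flake8-simplify ("SIM") rules across the repository, so each hunk below swaps a simplified expression back to its original, more verbose spelling. As a rough illustration of the recurring equivalences (a made-up snippet for orientation only, not code from this diff; the names backend and kwargs are hypothetical):

backend = "torchao"
kwargs = {"lr": 0.1}
# SIM201: `not x == y` is equivalent to `x != y`
assert (not backend == "inductor") == (backend != "inductor")
# SIM210: `True if cond else False` is equivalent to `bool(cond)`
assert (True if backend == "torchao" else False) == bool(backend == "torchao")
# SIM910: `d.get(key, None)` is equivalent to `d.get(key)`
assert kwargs.get("b", None) == kwargs.get("b")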
@@ -4035,7 +4035,7 @@ def run(runner, args, original_dir=None):
else:
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
experiment = (
speedup_experiment if args.backend != "torchao" else latency_experiment
speedup_experiment if not args.backend == "torchao" else latency_experiment
)
if args.accuracy:
output_filename = f"accuracy_{args.backend}.csv"
@@ -270,7 +270,7 @@ def run_single_backend_sdpa(
if config.calculate_bwd_time:
# TODO: debug backward pass for njt
if eager_sdpa and config.attn_type != "document_mask":
if eager_sdpa and not config.attn_type == "document_mask":
d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2)
backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, d_out, retain_graph=True
@@ -181,7 +181,6 @@ ignore = [
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
"SIM117",
"SIM118",
"SIM300", # Yoda condition detected
"UP007", # keep-runtime-typing
"UP045", # keep-runtime-typing
"TC006",
@@ -197,7 +196,8 @@ select = [
"E",
"EXE",
"F",
"SIM",
"SIM1",
"SIM911",
"W",
# Not included in flake8
"FURB",
@@ -55,7 +55,7 @@ class TestActivationSparsifier(TestCase):
for key, config in sparsifier_defaults.items():
# all the keys in combined_defaults should be present in sparsifier defaults
assert config == combined_defaults.get(key)
assert config == combined_defaults.get(key, None)
def _check_register_layer(
self, activation_sparsifier, defaults, sparse_config, layer_args_list
@@ -3074,7 +3074,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
wrong_dtype_shards, [10, 10], init_rrefs=True
)
tensor_requires_grad = self.rank == 0
tensor_requires_grad = True if self.rank == 0 else False
wrong_requires_grad_shards = [
sharded_tensor.Shard(
torch.randn(
@@ -3121,7 +3121,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase):
wrong_pin_memory_local_shards, [10, 10], init_rrefs=True
)
tensor_pin_memory = self.rank == 0
tensor_pin_memory = True if self.rank == 0 else False
wrong_pin_memory_shards_cross_ranks = [
sharded_tensor.Shard(
torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata
@@ -146,7 +146,7 @@ class TestStorageBase:
self.rank = 0 if not dist.is_initialized() else dist.get_rank()
def _get_ranks(self, name):
return self.fail_conf.get(name, None)
return self.fail_conf[name] if name in self.fail_conf else None
def _fail_rank(self, name):
ranks = self._get_ranks(name)
@@ -155,7 +155,7 @@ class TestFreezingWeights(FSDPTest):
ddp_kwargs = {
"device_ids": [self.rank],
"find_unused_parameters": bool(disable_autograd),
"find_unused_parameters": True if disable_autograd else False,
}
model = self._create_model(
@@ -65,7 +65,7 @@ class MockPipelineStage(_PipelineStageBase):
self.num_stages = kwargs.get("num_stages", 1)
self.group_size = kwargs.get("group_size", 1)
self.group_rank = kwargs.get("group_rank", 0)
self.group = kwargs.get("group")
self.group = kwargs.get("group", None)
def _create_grad_recv_info(self, *args, **kwargs):
return None
@@ -1022,7 +1022,7 @@ class TestDTensorPlacementTypes(DTensorTestBase):
assert_array_equal(expected_pad_sizes, pad_sizes)
is_tensor_empty = [
not splitted_tensor.numel() > 0
False if splitted_tensor.numel() > 0 else True
for splitted_tensor in splitted_tensor_list
]
expected_is_tensor_empty = [True] * self.world_size
@@ -1045,10 +1045,12 @@ class TestDTensorPlacementTypes(DTensorTestBase):
for i, tensor in enumerate(splitted_tensor_list)
]
expected_is_tensor_empty = [
not idx < size for idx, _ in enumerate(range(self.world_size))
False if idx < size else True
for idx, _ in enumerate(range(self.world_size))
]
is_tensor_empty = [
not unpadded_tensor.numel() > 0 for unpadded_tensor in unpadded_list
False if unpadded_tensor.numel() > 0 else True
for unpadded_tensor in unpadded_list
]
assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
@@ -2770,7 +2770,11 @@ class WorkHookTest(MultiProcessTestCase):
# from rank0 to other ranks. However, this is DDP's internal implementation,
# which is subject to change in future versions.
self.assertTrue(num_hook_fired[OpType.BROADCAST] > 0)
ctor_allreduce = num_hook_fired.get(OpType.ALLREDUCE, 0)
ctor_allreduce = (
num_hook_fired[OpType.ALLREDUCE]
if OpType.ALLREDUCE in num_hook_fired
else 0
)
x = torch.zeros(2, 1000).cuda(self.rank)
ddp(x).sum().backward()
@@ -82,7 +82,7 @@ def grad(L, desired_results: list[Variable]) -> list[Variable]:
# look up dL_dentries. If a variable is never used to compute the loss,
# we consider its gradient None, see the note below about zeros for more information.
def gather_grad(entries: list[str]):
return [dL_d.get(entry) for entry in entries]
return [dL_d[entry] if entry in dL_d else None for entry in entries]
# propagate the gradient information backward
for entry in reversed(gradient_tape):
@@ -286,7 +286,7 @@ class OptionalScaledTensor(torch.Tensor):
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
return OptionalScaledTensor(
inner_tensors["_data"],
inner_tensors.get("_scale", None),
inner_tensors["_scale"] if "_scale" in inner_tensors else None,
constant=metadata["_constant"],
)
@@ -38,7 +38,7 @@ class PreDispatchSchemaCheckMode(SchemaCheckMode):
def _may_alias_or_mutate(self, func, types, args, kwargs):
def unwrap(e):
if isinstance(e, torch.Tensor) and type(e) != torch.Tensor:
if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
try:
return e.elem
except AttributeError:
@@ -358,7 +358,9 @@ def _sequential_split_inline_tests():
for i, node in enumerate(insert_locs):
with gm.graph.inserting_before(node):
gm.graph.call_function(torch._C._set_grad_enabled, (i % 2 == 0,), {})
gm.graph.call_function(
torch._C._set_grad_enabled, (True if i % 2 == 0 else False,), {}
)
return gm
x = torch.randn(2, 2)
@@ -2932,7 +2932,9 @@ class GraphModule(torch.nn.Module):
if autograd:
result_flat = pytree.tree_leaves(result)
result_exp_flat = pytree.tree_leaves(result_exp)
exp_grad_mask = [bool(r.requires_grad) for r in result_exp_flat]
exp_grad_mask = [
True if r.requires_grad else False for r in result_exp_flat
]
self.check_autograd(
[r for r, m in zip(result_flat, exp_grad_mask) if m],
[r for r, m in zip(result_exp_flat, exp_grad_mask) if m],
@@ -3739,7 +3741,9 @@ class AssociativeScanTests(TestCase):
):
result_flat = pytree.tree_leaves(result)
result_exp_flat = pytree.tree_leaves(result_exp)
exp_grad_mask = [bool(r.requires_grad) for r in result_exp_flat]
exp_grad_mask = [
True if r.requires_grad else False for r in result_exp_flat
]
self._check_autograd(
[r for r, m in zip(result_flat, exp_grad_mask) if m],
@@ -5706,9 +5710,10 @@ def forward(self, arg0_1):
)
def test_while_loop_tracing(self, while_loop_test):
fn, inp = WHILE_LOOP_TESTS[while_loop_test]
allow_non_fake_inputs = while_loop_test in (
"simple_with_linear",
"nested_with_linear",
allow_non_fake_inputs = (
False
if while_loop_test not in ("simple_with_linear", "nested_with_linear")
else True
)
self._check_tracing(fn, inp, allow_non_fake_inputs)
@ -177,7 +177,9 @@ class TestFXNodeSource(TestCase):
|
|||
for node_name_2 in node_name_to_from_node:
|
||||
if node_name_2 in {
|
||||
node_name_1,
|
||||
same_ancestor_nodes.get(node_name_1),
|
||||
same_ancestor_nodes[node_name_1]
|
||||
if node_name_1 in same_ancestor_nodes
|
||||
else None,
|
||||
}:
|
||||
self.assertEqual(
|
||||
node_name_to_from_node[node_name_1],
|
||||
|
|
|
|||
|
|
@ -164,7 +164,9 @@ class B2BGEMMTest(TestCase):
|
|||
self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code)
|
||||
self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" not in code)
|
||||
|
||||
@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
|
||||
@unittest.skipIf(
|
||||
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
|
||||
)
|
||||
@torch._dynamo.config.patch(recompile_limit=32)
|
||||
def test_plain_b2b_gemm_performance(self):
|
||||
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
|
||||
|
|
@ -217,7 +219,9 @@ class B2BGEMMTest(TestCase):
|
|||
# flaky test assertion: disabled
|
||||
# self.assertTrue(average_speedup > 1)
|
||||
|
||||
@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
|
||||
@unittest.skipIf(
|
||||
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
|
||||
)
|
||||
@torch._dynamo.config.patch(recompile_limit=32)
|
||||
def test_gelu_b2b_gemm_performance(self):
|
||||
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
|
||||
|
|
@ -272,7 +276,9 @@ class B2BGEMMTest(TestCase):
|
|||
# flaky test assertion: disabled
|
||||
# self.assertTrue(average_speedup > 1)
|
||||
|
||||
@unittest.skipIf(os.environ.get("DO_PERF_TEST") != "1", "Perf test not enabled")
|
||||
@unittest.skipIf(
|
||||
not (os.environ.get("DO_PERF_TEST") == "1"), "Perf test not enabled"
|
||||
)
|
||||
@torch._dynamo.config.patch(recompile_limit=32)
|
||||
def test_gelu_mlp_b2b_gemm_performance(self):
|
||||
"""compare torch.compile(f, b2b_gemm = off) with torch.compile(f, b2b_gemm = on)"""
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ class BenchmarkFusionTestTemplate:
|
|||
_, out_code = run_and_get_code(foo_c, m, inp)
|
||||
|
||||
# occasionally, CI will make this one kernel. just skip in this case
|
||||
if out_code[0].count("def triton_") != 2:
|
||||
if not out_code[0].count("def triton_") == 2:
|
||||
return
|
||||
|
||||
# should be multiple triton invocations
|
||||
|
|
|
|||
|
|
@ -289,7 +289,7 @@ def build_opt_kwarg_db():
|
|||
|
||||
has_tensor_lr = False
|
||||
for key, val in kwargs.items():
|
||||
if (key != "lr" and key != "betas") and (
|
||||
if (not key == "lr" and not key == "betas") and (
|
||||
not isinstance(val, bool) or (isinstance(val, bool) and val)
|
||||
):
|
||||
name += "_" + key
|
||||
|
|
@ -450,7 +450,7 @@ def make_test(
|
|||
stack.enter_context(config.patch({"triton.cudagraphs": True}))
|
||||
|
||||
kwargs_compiled = deepcopy(kwargs)
|
||||
if isinstance(kwargs.get("lr"), torch.Tensor):
|
||||
if isinstance(kwargs.get("lr", None), torch.Tensor):
|
||||
kwargs["lr"] = kwargs["lr"].to(device)
|
||||
kwargs_compiled["lr"] = kwargs_compiled["lr"].to(device)
|
||||
|
||||
|
|
|
|||
|
|
@ -583,7 +583,9 @@ class TestFlexAttention(InductorTestCase):
|
|||
)
|
||||
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
|
||||
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
|
||||
sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
|
||||
sdpa_partial = create_attention(
|
||||
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
|
||||
)
|
||||
|
||||
compiled_sdpa = torch.compile(sdpa_partial)
|
||||
golden_out = sdpa_partial(q_gold, k_gold, v_gold)
|
||||
|
|
@ -757,7 +759,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
return_lse=return_lse,
|
||||
block_mask=converted_block_mask,
|
||||
score_mod=converted_score_mod,
|
||||
enable_gqa=(Q_H != KV_H),
|
||||
enable_gqa=(not Q_H == KV_H),
|
||||
kernel_options=kernel_options,
|
||||
)
|
||||
else:
|
||||
|
|
@ -770,7 +772,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
return_lse=return_lse,
|
||||
block_mask=converted_block_mask,
|
||||
score_mod=converted_score_mod,
|
||||
enable_gqa=(Q_H != KV_H),
|
||||
enable_gqa=(not Q_H == KV_H),
|
||||
kernel_options=kernel_options,
|
||||
)
|
||||
return compiled_out, compiled_lse
|
||||
|
|
@ -815,7 +817,9 @@ class TestFlexAttention(InductorTestCase):
|
|||
if block_mask is None:
|
||||
block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device)
|
||||
|
||||
sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
|
||||
sdpa_partial = create_attention(
|
||||
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
|
||||
)
|
||||
golden_out, golden_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
|
||||
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
|
||||
|
||||
|
|
@ -1460,7 +1464,7 @@ class TestFlexAttention(InductorTestCase):
|
|||
|
||||
block_mask = create_block_mask(mask_mod, Bq, 1, S, S, device=device)
|
||||
attention = functools.partial(
|
||||
flex_attention, block_mask=block_mask, enable_gqa=(Hq != Hkv)
|
||||
flex_attention, block_mask=block_mask, enable_gqa=(not Hq == Hkv)
|
||||
)
|
||||
|
||||
self.run_test_with_call(attention, dtype, device, Bq, Hq, S, D, Bkv, Hkv, S, D)
|
||||
|
|
|
|||
|
|
@ -412,7 +412,7 @@ class TestFlexDecoding(InductorTestCase):
|
|||
sdpa_partial = create_attention(
|
||||
score_mod,
|
||||
block_mask,
|
||||
enable_gqa=(Q_H != KV_H),
|
||||
enable_gqa=(not Q_H == KV_H),
|
||||
kernel_options=kernel_options,
|
||||
)
|
||||
compiled_sdpa = torch.compile(sdpa_partial)
|
||||
|
|
@ -607,7 +607,7 @@ class TestFlexDecoding(InductorTestCase):
|
|||
return_lse=True,
|
||||
block_mask=converted_block_mask,
|
||||
score_mod=converted_score_mod,
|
||||
enable_gqa=(Q_H != KV_H),
|
||||
enable_gqa=(not Q_H == KV_H),
|
||||
)
|
||||
else:
|
||||
compiled_lse = None
|
||||
|
|
@ -618,7 +618,7 @@ class TestFlexDecoding(InductorTestCase):
|
|||
return_lse=False,
|
||||
block_mask=converted_block_mask,
|
||||
score_mod=converted_score_mod,
|
||||
enable_gqa=(Q_H != KV_H),
|
||||
enable_gqa=(not Q_H == KV_H),
|
||||
)
|
||||
return compiled_out, compiled_lse
|
||||
|
||||
|
|
@ -664,7 +664,9 @@ class TestFlexDecoding(InductorTestCase):
|
|||
if block_mask is None:
|
||||
block_mask = create_block_mask(noop_mask, Q_B, 1, 1, KV_S, device=device)
|
||||
|
||||
sdpa_partial = create_attention(score_mod, block_mask, enable_gqa=(Q_H != KV_H))
|
||||
sdpa_partial = create_attention(
|
||||
score_mod, block_mask, enable_gqa=(not Q_H == KV_H)
|
||||
)
|
||||
golden_out, gold_lse = sdpa_partial(q_gold, k_gold, v_gold, return_lse=True)
|
||||
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
|
||||
|
||||
|
|
@ -904,7 +906,7 @@ class TestFlexDecoding(InductorTestCase):
|
|||
sdpa_partial = create_attention(
|
||||
score_mod=score_mod,
|
||||
block_mask=None,
|
||||
enable_gqa=(Hq != Hkv),
|
||||
enable_gqa=(not Hq == Hkv),
|
||||
)
|
||||
compiled_sdpa = torch.compile(sdpa_partial)
|
||||
ref_out = sdpa_partial(q, k, v)
|
||||
|
|
@ -1142,7 +1144,7 @@ class TestFlexDecoding(InductorTestCase):
|
|||
|
||||
def head_attention_mod(kv_head_num):
|
||||
head_type = torch.tensor(
|
||||
[i % kv_head_num != 0 for i in range(kv_head_num)],
|
||||
[False if i % kv_head_num == 0 else True for i in range(kv_head_num)],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ except ImportError:
|
|||
HAS_PYDOT = False
|
||||
|
||||
|
||||
HAS_DOT = shutil.which("dot") is not None
|
||||
HAS_DOT = True if shutil.which("dot") is not None else False
|
||||
|
||||
|
||||
class TestGraphTransformObserver(TestCase):
|
||||
|
|
|
|||
|
|
@ -834,7 +834,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
|
||||
for dtype in dtypes:
|
||||
torch._dynamo.reset()
|
||||
autocast_enabled = dtype in [torch.bfloat16, torch.float16]
|
||||
autocast_enabled = (
|
||||
True if dtype in [torch.bfloat16, torch.float16] else False
|
||||
)
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(
|
||||
|
|
@ -4418,12 +4420,14 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
out_feature = 64
|
||||
q_min, q_max = -32, 31
|
||||
# we only test for qlinear_binary in this case
|
||||
test_for_pointwise_binary = bool(
|
||||
M == 1
|
||||
test_for_pointwise_binary = (
|
||||
True
|
||||
if M == 1
|
||||
and inplace_add
|
||||
and not expand_a_scale
|
||||
and not dynamic
|
||||
and not has_bias
|
||||
else False
|
||||
)
|
||||
if test_for_pointwise_binary and not IS_X86:
|
||||
self.skipTest("Some UTs are only supported on x86_64 CPUs")
|
||||
|
|
|
|||
|
|
@ -706,7 +706,7 @@ def check_model_gpu(
|
|||
if check_lowp:
|
||||
|
||||
def downcast_fn(x):
|
||||
if not isinstance(x, torch.Tensor) or x.dtype != torch.float:
|
||||
if not isinstance(x, torch.Tensor) or not x.dtype == torch.float:
|
||||
return x
|
||||
return torch.empty_strided(
|
||||
x.size(), x.stride(), device=GPU_TYPE, dtype=torch.half
|
||||
|
|
@ -4693,7 +4693,7 @@ class CommonTemplate:
|
|||
# Make sure we compute also with fp16 in the reference. Otherwise,
|
||||
# the reference will compute with fp32 and cast back to fp16, which
|
||||
# causes numeric differences beyond tolerance.
|
||||
reference_in_float=not torch.version.hip,
|
||||
reference_in_float=False if torch.version.hip else True,
|
||||
)
|
||||
|
||||
def test_convolution2(self):
|
||||
|
|
@ -4727,7 +4727,7 @@ class CommonTemplate:
|
|||
# Make sure we compute also with fp16 in the reference. Otherwise,
|
||||
# the reference will compute with fp32 and cast back to fp16, which
|
||||
# causes numeric differences beyond tolerance.
|
||||
reference_in_float=not torch.version.hip,
|
||||
reference_in_float=False if torch.version.hip else True,
|
||||
)
|
||||
|
||||
@skip_if_gpu_halide
|
||||
|
|
@ -4778,7 +4778,7 @@ class CommonTemplate:
|
|||
# Make sure we compute also with fp16 in the reference. Otherwise,
|
||||
# the reference will compute with fp32 and cast back to fp16, which
|
||||
# causes numeric differences beyond tolerance.
|
||||
reference_in_float=not torch.version.hip,
|
||||
reference_in_float=False if torch.version.hip else True,
|
||||
)
|
||||
|
||||
def test_conv2d_channels_last(self):
|
||||
|
|
@ -12937,7 +12937,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
)
|
||||
|
||||
res = torch.compile(fn)(20)
|
||||
self.assertTrue(torch.all((res >= 0) & (res < 10)).item())
|
||||
self.assertTrue(torch.all((0 <= res) & (res < 10)).item())
|
||||
|
||||
@torch._inductor.config.patch(force_shape_pad=True)
|
||||
@skip_if_gpu_halide # correctness issue
|
||||
|
|
|
|||
|
|
@ -1217,7 +1217,7 @@ class TestInductorOpInfo(TestCase):
|
|||
# not exercised in test_ops_gradients atm. The problem is not
|
||||
# complex32 per-se (which is supported by data movement only ops)
|
||||
# but that when we do backwards we expect other ops like add to work
|
||||
and dtype != torch.complex32
|
||||
and not dtype == torch.complex32
|
||||
)
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,13 +17,17 @@ with open(sys.argv[1]) as input_yaml_file:
|
|||
for info in model_infos:
|
||||
for op in info["root_operators"]:
|
||||
# aggregate occurance per op
|
||||
root_operators[op] = 1 + (root_operators.get(op, 0))
|
||||
root_operators[op] = 1 + (root_operators[op] if op in root_operators else 0)
|
||||
for op in info["traced_operators"]:
|
||||
# aggregate occurance per op
|
||||
traced_operators[op] = 1 + (traced_operators.get(op, 0))
|
||||
traced_operators[op] = 1 + (
|
||||
traced_operators[op] if op in traced_operators else 0
|
||||
)
|
||||
# merge dtypes for each kernel
|
||||
for kernal, dtypes in info["kernel_metadata"].items():
|
||||
new_dtypes = dtypes + (kernel_metadata.get(kernal, []))
|
||||
new_dtypes = dtypes + (
|
||||
kernel_metadata[kernal] if kernal in kernel_metadata else []
|
||||
)
|
||||
kernel_metadata[kernal] = list(set(new_dtypes))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4879,7 +4879,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
|||
@skipScriptTest()
|
||||
def test_rnn_no_bias(self):
|
||||
def make_model(layers, packed_sequence):
|
||||
batch_first = packed_sequence == 2
|
||||
batch_first = True if packed_sequence == 2 else False
|
||||
model = torch.nn.RNN(
|
||||
RNN_INPUT_SIZE,
|
||||
RNN_HIDDEN_SIZE,
|
||||
|
|
@ -4900,7 +4900,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
|
|||
return model
|
||||
|
||||
def make_input(batch_size, layers, packed_sequence):
|
||||
batch_first = packed_sequence == 2
|
||||
batch_first = True if packed_sequence == 2 else False
|
||||
seq_lengths = np.random.randint(1, RNN_SEQUENCE_LENGTH + 1, size=batch_size)
|
||||
seq_lengths = sorted(map(int, seq_lengths), reverse=True)
|
||||
inputs = [torch.randn(l, RNN_INPUT_SIZE) for l in seq_lengths]
|
||||
|
|
|
|||
|
|
@ -7045,8 +7045,8 @@ class TestQuantizedConv(TestCase):
|
|||
# ONEDNN only supports symmetric quantization of weight
|
||||
if W_zero_point is not None:
|
||||
W_zero_point = len(W_zero_point) * [0]
|
||||
fp32_output = qconv_output_dtype is torch.float32
|
||||
bfloat16_output = qconv_output_dtype is torch.bfloat16
|
||||
fp32_output = True if qconv_output_dtype is torch.float32 else False
|
||||
bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
|
||||
if fp32_output or bfloat16_output:
|
||||
Y_scale = 1.0
|
||||
Y_zero_point = 0
|
||||
|
|
@ -7905,8 +7905,8 @@ class TestQuantizedConv(TestCase):
|
|||
weight_in_channel_last_format=False,
|
||||
):
|
||||
# We assume FP8 quantization is always symmetric
|
||||
fp32_output = qconv_output_dtype is torch.float32
|
||||
bfloat16_output = qconv_output_dtype is torch.bfloat16
|
||||
fp32_output = True if qconv_output_dtype is torch.float32 else False
|
||||
bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
|
||||
if fp32_output or bfloat16_output:
|
||||
Y_scale = 1.0
|
||||
X2_scale = 1.0
|
||||
|
|
|
|||
|
|
@ -11712,7 +11712,7 @@ class TestAutogradDeviceType(TestCase):
|
|||
def test_nonzero(tensor, value, expected):
|
||||
tensor[0] = value
|
||||
self.assertEqual(expected, bool(tensor))
|
||||
self.assertEqual(expected, bool(tensor))
|
||||
self.assertEqual(expected, True if tensor else False)
|
||||
|
||||
test_nonzero(l, 0, False)
|
||||
test_nonzero(l, -2, True)
|
||||
|
|
|
|||
|
|
@ -577,7 +577,7 @@ print(t.is_pinned())
|
|||
src = torch.randn(
|
||||
1000000,
|
||||
device="cuda" if dst == "cpu" else "cpu",
|
||||
pin_memory=dst == "cuda",
|
||||
pin_memory=True if dst == "cuda" else False,
|
||||
)
|
||||
_test_to_non_blocking(src, try_non_blocking, dst)
|
||||
|
||||
|
|
|
|||
|
|
@ -942,7 +942,7 @@ def forward(self, scores_1, mask_1, value_1):
|
|||
# not exercised in test_ops_gradients atm. The problem is not
|
||||
# complex32 per-se (which is supported by data movement only ops)
|
||||
# but that when we do backwards we expect other ops like add to work
|
||||
and dtype != torch.complex32
|
||||
and not dtype == torch.complex32
|
||||
)
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)
|
||||
|
||||
|
|
|
|||
|
|
@ -3583,7 +3583,7 @@ class TestFX(JitTestCase):
|
|||
|
||||
class LeafTracerNotB(Tracer):
|
||||
def is_leaf_module(self, module, name):
|
||||
return "b" not in name
|
||||
return False if "b" in name else True
|
||||
|
||||
# Recompile calls added "for fun", since they
|
||||
# chain __call__ wrappers.
|
||||
|
|
|
|||
|
|
@ -2036,7 +2036,7 @@ class TestIndexing(TestCase):
|
|||
index = torch.tensor([0], device=device)
|
||||
x.index_fill_(1, index, 0)
|
||||
self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
|
||||
if not x.is_complex() and device != "meta":
|
||||
if not x.is_complex() and not device == "meta":
|
||||
with self.assertRaisesRegex(RuntimeError, r"Scalar"):
|
||||
x.index_fill_(1, index, 1 + 1j)
|
||||
# Make sure that the result stays 0-dim while applied to
|
||||
|
|
|
|||
|
|
@ -6723,7 +6723,7 @@ a")
|
|||
@torch.jit.script
|
||||
def testNoThrows(t):
|
||||
c1 = 1
|
||||
if (False and bool(t[1])) or (True or bool(t[1])): # noqa: SIM222,SIM223
|
||||
if (False and bool(t[1])) or (True or bool(t[1])):
|
||||
c1 = 0
|
||||
return c1
|
||||
|
||||
|
|
@ -15758,7 +15758,7 @@ dedent """
|
|||
def fn(d):
|
||||
# type: (Dict[str, int]) -> List[int]
|
||||
out = [1]
|
||||
for i in range(d.get("hi", 6)):
|
||||
for i in range(d["hi"] if "hi" in d else 6):
|
||||
out.append(i) # noqa: PERF402
|
||||
return out
|
||||
|
||||
|
|
@ -16104,7 +16104,7 @@ M = 10
|
|||
S = 5
|
||||
|
||||
def add_nn_module_test(*args, **kwargs):
|
||||
no_grad = kwargs.get('no_grad', False)
|
||||
no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad']
|
||||
|
||||
if 'desc' in kwargs and 'eval' in kwargs['desc']:
|
||||
# eval() is not supported, so skip these tests
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ class TestAutocast(JitTestCase):
|
|||
def test_runtime_autocast_state_expr(self):
|
||||
@torch.jit.script
|
||||
def fn(a, b):
|
||||
with autocast(enabled=bool((a[0][0] > 0.5).item())):
|
||||
with autocast(enabled=True if a[0][0] > 0.5 else False):
|
||||
return torch.mm(a, b)
|
||||
# runtime values for autocast enable argument are not supported
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
|
|
|||
|
|
@ -3522,7 +3522,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
nn.RNN(10, 20, batch_first=True)
|
||||
]
|
||||
# ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
|
||||
first_warn = not torch.version.hip
|
||||
first_warn = False if torch.version.hip else True
|
||||
for rnn in rnns:
|
||||
rnn.cuda()
|
||||
input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ class TestNumPyInterop(TestCase):
|
|||
x = x.conj()
|
||||
y = x.resolve_conj()
|
||||
expect_error = (
|
||||
requires_grad or sparse or conj or device != "cpu"
|
||||
requires_grad or sparse or conj or not device == "cpu"
|
||||
)
|
||||
error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?"
|
||||
if not force and expect_error:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class PruningOpTest(TestCase):
|
|||
def _generate_rowwise_mask(self, embedding_rows):
|
||||
indicator = torch.from_numpy((np.random.random_sample(embedding_rows)).astype(np.float32))
|
||||
threshold = float(np.random.random_sample())
|
||||
mask = torch.BoolTensor([val >= threshold for val in indicator])
|
||||
mask = torch.BoolTensor([True if val >= threshold else False for val in indicator])
|
||||
return mask
|
||||
|
||||
def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype):
|
||||
|
|
|
|||
|
|
@ -1899,7 +1899,7 @@ class TestReductions(TestCase):
|
|||
# Note [all, any uint8 compatibility]: However for compatibility reason,
|
||||
# for `uint8`, they return Tensor of same dtype `uint8`.
|
||||
# Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
|
||||
exact_dtype = dtype != torch.uint8
|
||||
exact_dtype = True if dtype != torch.uint8 else False
|
||||
|
||||
def _test_all_any(x):
|
||||
self.compare_with_numpy(torch.all, np.all, x)
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ class TestSegmentReductions(TestCase):
|
|||
|
||||
for reduction in reductions:
|
||||
for initial in [0, None]:
|
||||
check_backward = initial is not None
|
||||
check_backward = True if initial is not None else False
|
||||
initial_value = initial
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
if reduction == "max":
|
||||
|
|
@ -186,7 +186,7 @@ class TestSegmentReductions(TestCase):
|
|||
|
||||
for reduction in reductions:
|
||||
for initial in [0, None]:
|
||||
check_backward = initial is not None
|
||||
check_backward = True if initial is not None else False
|
||||
initial_value = initial
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
if reduction == "max":
|
||||
|
|
@ -244,7 +244,7 @@ class TestSegmentReductions(TestCase):
|
|||
|
||||
for reduction in reductions:
|
||||
for initial in [0, None]:
|
||||
check_backward = initial is not None
|
||||
check_backward = True if initial is not None else False
|
||||
initial_value = initial
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
if reduction == "max":
|
||||
|
|
|
|||
|
|
@ -4553,7 +4553,7 @@ class TestSerialization(TestCase, SerializationMixin):
|
|||
with TemporaryFileName() as f:
|
||||
torch.save(m, f)
|
||||
try:
|
||||
old_value = os.environ.get(env_var, None)
|
||||
old_value = os.environ[env_var] if env_var in os.environ else None
|
||||
os.environ[env_var] = "1"
|
||||
# if weights_only is explicitly set, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD cannot override it
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, "Weights only load failed"):
|
||||
|
|
|
|||
|
|
@ -3483,7 +3483,7 @@ class TestRandomTensorCreation(TestCase):
|
|||
else:
|
||||
t.uniform_(from_, to_)
|
||||
range_ = to_ - from_
|
||||
if dtype != torch.bfloat16 and not (
|
||||
if not (dtype == torch.bfloat16) and not (
|
||||
dtype == torch.half and device == 'cpu') and not torch.isnan(t).all():
|
||||
delta = alpha * range_
|
||||
double_t = t.to(torch.double)
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class TestBuiltin(TestCase):
|
|||
# dtypes results in False/True when compared to valid dtypes.
|
||||
# Here 7 cannot be converted to dtype. No exceptions should be raised
|
||||
|
||||
assert np.dtype(np.int32) != 7, "dtype richcompare failed for =="
|
||||
assert not np.dtype(np.int32) == 7, "dtype richcompare failed for =="
|
||||
assert np.dtype(np.int32) != 7, "dtype richcompare failed for !="
|
||||
|
||||
@parametrize("operation", [operator.le, operator.lt, operator.ge, operator.gt])
|
||||
|
|
|
|||
|
|
@ -401,7 +401,7 @@ allow_rnn = False
|
|||
# exported FX graph. This flag should become the default eventually
|
||||
# and be removed, but currently provides a way to fall back to old
|
||||
# graph breaking behavior.
|
||||
capture_sparse_compute = not is_fbcode()
|
||||
capture_sparse_compute = False if is_fbcode() else True
|
||||
|
||||
# If true, error if we try to compile a function that has
|
||||
# been seen before.
|
||||
|
|
|
|||
|
|
@ -718,7 +718,11 @@ def validate_args_and_maybe_create_graph_inputs(
|
|||
new_proxy = tracer.create_graph_input(
|
||||
arg_name, a.python_type(), example_value
|
||||
)
|
||||
example_value = node.meta.get("example_value", None)
|
||||
example_value = (
|
||||
node.meta["example_value"]
|
||||
if "example_value" in node.meta
|
||||
else None
|
||||
)
|
||||
a = wrap_fx_proxy_cls(
|
||||
target_cls=type(a),
|
||||
tx=tx,
|
||||
|
|
@ -756,7 +760,9 @@ def validate_args_and_maybe_create_graph_inputs(
|
|||
# If `a` can be put into a graph
|
||||
elif a.maybe_fx_node() is not None:
|
||||
node = a.maybe_fx_node()
|
||||
example_value = node.meta.get("example_value", None)
|
||||
example_value = (
|
||||
node.meta["example_value"] if "example_value" in node.meta else None
|
||||
)
|
||||
arg_name = node.name if sub_args_names is None else sub_args_names[idx]
|
||||
new_proxy = tracer.create_graph_input(
|
||||
arg_name, a.python_type(), example_value
|
||||
|
|
|
|||
|
|
@ -270,7 +270,7 @@ backward_pass_autocast = "same_as_forward"
|
|||
|
||||
# This controls whether we collect donated buffer. This flag must be set
|
||||
# False if a user wants to retain_graph=True for backward.
|
||||
donated_buffer = not is_fbcode()
|
||||
donated_buffer = False if is_fbcode() else True
|
||||
|
||||
# Controls the default graph output format used by draw_graph
|
||||
# Supported formats are defined here https://graphviz.org/docs/outputs/
|
||||
|
|
|
|||
|
|
@ -608,7 +608,8 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
|
|||
# Use position-based lookup for building output
|
||||
# only update the return node args, and remain all other users unchanged
|
||||
output_updated_args = [
|
||||
position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
|
||||
position_to_quant[i] if i in position_to_quant else node
|
||||
for i, node in enumerate(fwd_outputs)
|
||||
]
|
||||
# add the scale nodes to the output find the first sym_node in the output
|
||||
idx = find_first_sym_node(output_updated_args)
|
||||
|
|
|
|||
|
|
@ -414,7 +414,10 @@ class JsonProfile:
|
|||
elif isinstance(dtype, torch.dtype):
|
||||
self.dtype = dtype
|
||||
else:
|
||||
self.dtype = _dtype_map.get(dtype)
|
||||
if dtype in _dtype_map:
|
||||
self.dtype = _dtype_map[dtype]
|
||||
else:
|
||||
self.dtype = None
|
||||
self._create_devices()
|
||||
|
||||
def convert_dtype(self, event: dict[str, Any]) -> Optional[torch.dtype]:
|
||||
|
|
|
|||
|
|
@ -482,11 +482,15 @@ def get_wrapper_codegen_for_device(
|
|||
|
||||
|
||||
def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModulePass]:
|
||||
return custom_backend_passes.get(device)
|
||||
return custom_backend_passes[device] if device in custom_backend_passes else None
|
||||
|
||||
|
||||
def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]:
|
||||
return custom_backend_codegen_configs.get(device)
|
||||
return (
|
||||
custom_backend_codegen_configs[device]
|
||||
if device in custom_backend_codegen_configs
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
@functools.cache
|
||||
|
|
|
|||
|
|
@ -857,7 +857,7 @@ def parallel_compile_enabled_internally() -> bool:
|
|||
|
||||
jk_name = "pytorch/inductor:enable_parallel_compile_version"
|
||||
version = torch._utils_internal.justknobs_getval_int(jk_name)
|
||||
return version <= ENABLE_PARALLEL_COMPILE_VERSION
|
||||
return ENABLE_PARALLEL_COMPILE_VERSION >= version
|
||||
|
||||
|
||||
def decide_compile_threads() -> int:
|
||||
|
|
@ -1259,7 +1259,7 @@ class triton:
|
|||
cudagraph_trees_history_recording = False
|
||||
|
||||
# Enable cudagraph support for mutated inputs from prior cudagraph pool
|
||||
cudagraph_support_input_mutation = not is_fbcode()
|
||||
cudagraph_support_input_mutation = False if is_fbcode() else True
|
||||
|
||||
# Maximal number of allowed cudagraph re-record for a function and
|
||||
# a cudagraph node due to static input tensor address changes or
|
||||
|
|
|
|||
|
|
@ -476,7 +476,9 @@ def build_subgraph_buffer(
|
|||
elif node.op == "call_function":
|
||||
# For call_function we use the default lowerings and pass in the
|
||||
# already created TensorBoxes as args
|
||||
args, kwargs = tree_map(lambda x: env.get(x, x), (node.args, node.kwargs))
|
||||
args, kwargs = tree_map(
|
||||
lambda x: env[x] if x in env else x, (node.args, node.kwargs)
|
||||
)
|
||||
env[node] = lowerings[node.target](*args, **kwargs)
|
||||
elif node.op == "output":
|
||||
|
||||
|
|
@ -689,7 +691,9 @@ def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) ->
|
|||
for node in graph.nodes: # preserve the order of nodes
|
||||
if node in subgraph_node_set:
|
||||
subgraph_node_list.append(node)
|
||||
new_node = new_graph.node_copy(node, lambda x: node_remapping.get(x, x))
|
||||
new_node = new_graph.node_copy(
|
||||
node, lambda x: node_remapping[x] if x in node_remapping else x
|
||||
)
|
||||
node_remapping[node] = new_node
|
||||
if node is inner_mm:
|
||||
new_input_anchor = new_node
|
||||
|
|
|
|||
|
|
@ -531,7 +531,7 @@ def _register_quantized_linear_unary_lowering(
|
|||
)
|
||||
|
||||
# bias
|
||||
b = kwargs.get("b")
|
||||
b = kwargs["b"] if "b" in kwargs else None
|
||||
|
||||
# Output QParams
|
||||
o_inv_scale = kwargs["output_scale"]
|
||||
|
|
@ -593,7 +593,7 @@ def _register_quantized_linear_binary_lowering(
|
|||
kwargs["w_zp"],
|
||||
)
|
||||
# bias
|
||||
b = kwargs.get("b")
|
||||
b = kwargs["b"] if "b" in kwargs else None
|
||||
# Output QParams
|
||||
o_inv_scale = kwargs["output_scale"]
|
||||
o_zero_point = kwargs["output_zero_point"]
|
||||
|
|
@ -885,10 +885,10 @@ def _register_quantized_maxpool2d_lowering(
|
|||
def qmaxpool2d(match: Match, *args, **kwargs):
|
||||
x = kwargs["x"]
|
||||
kernel_size = kwargs["kernel_size"]
|
||||
stride = kwargs.get("stride")
|
||||
padding = kwargs.get("padding", 0)
|
||||
dilation = kwargs.get("dilation", 1)
|
||||
ceil_mode = kwargs.get("ceil_mode", False)
|
||||
stride = kwargs["stride"] if ("stride" in kwargs) else None
|
||||
padding = kwargs["padding"] if ("padding" in kwargs) else 0
|
||||
dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
|
||||
ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False
|
||||
|
||||
if padding == 0:
|
||||
padding = [0, 0]
|
||||
|
|
@ -1976,7 +1976,7 @@ def _register_qlinear_weight_prepack_pass(
|
|||
)
|
||||
|
||||
# Params
|
||||
bias = kwargs.get("b")
|
||||
bias = kwargs["b"] if "b" in kwargs else None
|
||||
|
||||
x_shape = qx.meta.get("tensor_meta").shape
|
||||
if has_free_symbols(x_shape):
|
||||
|
|
@ -2451,7 +2451,7 @@ def _register_linear_dynamic_fp16_weight_prepack_pass(
|
|||
# find params
|
||||
x = kwargs["x"]
|
||||
w = kwargs["w"]
|
||||
bias = kwargs.get("b")
|
||||
bias = kwargs["b"] if "b" in kwargs else None
|
||||
|
||||
# find linear node
|
||||
nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default]
|
||||
|
|
@ -2727,7 +2727,7 @@ def _register_smooth_quant_int_mm_pattern():
|
|||
pass_number=pass_number,
|
||||
)
|
||||
def _int_mm_weight_prepack(match: Match, *args, **kwargs):
|
||||
bias = kwargs.get("bias")
|
||||
bias = kwargs.get("bias", None)
|
||||
x = kwargs["a"]
|
||||
weight = kwargs["b"]
|
||||
dtype = kwargs["dtype"]
|
||||
|
|
@ -2794,7 +2794,7 @@ def _register_smooth_quant_int_mm_pattern():
|
|||
else:
|
||||
# onednn.qlinear does not support per-channel quantization of x
|
||||
# so in this case, we have to apply x scale and add bias ourselves after qlinear
|
||||
in_shape = kwargs.get("in_shape")
|
||||
in_shape = kwargs.get("in_shape", None)
|
||||
if in_shape is None:
|
||||
x_reshaped = x
|
||||
else:
|
||||
|
|
@ -2826,8 +2826,8 @@ def _register_smooth_quant_int_mm_pattern():
|
|||
|
||||
# Add bias and reshape
|
||||
has_outer_reshape = (
|
||||
kwargs.get("out_shape_with_bias") is not None
|
||||
or kwargs.get("out_shape_no_bias") is not None
|
||||
kwargs.get("out_shape_with_bias", None) is not None
|
||||
or kwargs.get("out_shape_no_bias", None) is not None
|
||||
)
|
||||
|
||||
if has_outer_reshape:
|
||||
|
|
@ -3276,7 +3276,7 @@ def _register_qlinear_post_op_fusion_pass(
|
|||
)
|
||||
|
||||
# bias
|
||||
b = kwargs.get("b")
|
||||
b = kwargs["b"] if "b" in kwargs else None
|
||||
|
||||
# Output QParams
|
||||
o_inv_scale = (
|
||||
|
|
|
|||
|
|
@ -1074,13 +1074,13 @@ def _overload_method(func):
|
|||
_check_overload_body(func)
|
||||
qual_name = _qualified_name(func)
|
||||
global _overloaded_methods
|
||||
class_name_map = _overloaded_methods.get(qual_name)
|
||||
class_name_map = _overloaded_methods.get(qual_name, None)
|
||||
if class_name_map is None:
|
||||
class_name_map = {}
|
||||
_overloaded_methods[qual_name] = class_name_map
|
||||
|
||||
class_name, line_no = get_class_name_lineno(func)
|
||||
method_overloads = class_name_map.get(class_name)
|
||||
method_overloads = class_name_map.get(class_name, None)
|
||||
if method_overloads is None:
|
||||
method_overloads = []
|
||||
class_name_map[class_name] = method_overloads
|
||||
|
|
@ -1102,7 +1102,7 @@ def _get_overloaded_methods(method, mod_class):
|
|||
if not hasattr(method, "__name__"):
|
||||
return None
|
||||
qual_name = _qualified_name(method)
|
||||
class_name_map = _overloaded_methods.get(qual_name)
|
||||
class_name_map = _overloaded_methods.get(qual_name, None)
|
||||
if class_name_map is None:
|
||||
return None
|
||||
overloads = class_name_map.get(mod_class.__name__, None)
|
||||
|
|
|
|||
|
|
@ -5303,7 +5303,7 @@ def grid_sampler_3d_backward(
|
|||
|
||||
@register_meta([aten.full.default])
|
||||
def full(size, fill_value, *args, **kwargs):
|
||||
dtype = kwargs.get("dtype")
|
||||
dtype = kwargs.get("dtype", None)
|
||||
if not dtype:
|
||||
dtype = utils.get_dtype(fill_value)
|
||||
kwargs["dtype"] = dtype
|
||||
|
|
|
|||
|
|
@ -1414,7 +1414,7 @@ class _HigherOrderNamespace(types.ModuleType):
|
|||
|
||||
def __getattr__(self, name: str) -> HigherOrderOperator:
|
||||
# Following _OpNamespace.__getattr__, we cache the op on this object.
|
||||
op = _higher_order_ops.get(name)
|
||||
op = _higher_order_ops.get(name, None)
|
||||
if op is None:
|
||||
raise AttributeError(
|
||||
f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'"
|
||||
|
|
|
|||
|
|
@ -84,8 +84,12 @@ class ReferenceQuantizedModule(torch.nn.Module):
|
|||
# store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
|
||||
# for capturing `.item` operations
|
||||
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
|
||||
self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
|
||||
self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")
|
||||
self.weight_quant_min: typing.Optional[int] = weight_qparams.get(
|
||||
"quant_min", None
|
||||
)
|
||||
self.weight_quant_max: typing.Optional[int] = weight_qparams.get(
|
||||
"quant_max", None
|
||||
)
|
||||
|
||||
def get_weight(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -240,29 +240,29 @@ scale_min_lower_bound=None, scale_max_upper_bound=None)
|
|||
"bias_type": torch.dtype
|
||||
"is_dynamic": bool
|
||||
"""
|
||||
input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY)
|
||||
input_dtype = dtype_config_dict.get(INPUT_DTYPE_DICT_KEY, None)
|
||||
if input_dtype is not None and not isinstance(
|
||||
input_dtype, (torch.dtype, DTypeWithConstraints)
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected input_dtype to be a torch.dtype or DTypeWithConstraints"
|
||||
)
|
||||
output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY)
|
||||
output_dtype = dtype_config_dict.get(OUTPUT_DTYPE_DICT_KEY, None)
|
||||
if output_dtype is not None and not isinstance(
|
||||
output_dtype, (torch.dtype, DTypeWithConstraints)
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected output_dtype to be a torch.dtype or DTypeWithConstraints"
|
||||
)
|
||||
weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY)
|
||||
weight_dtype = dtype_config_dict.get(WEIGHT_DTYPE_DICT_KEY, None)
|
||||
if weight_dtype is not None and not isinstance(
|
||||
weight_dtype, (torch.dtype, DTypeWithConstraints)
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected weight_dtype to be a torch.dtype or DTypeWithConstraints"
|
||||
)
|
||||
bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY)
|
||||
is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY)
|
||||
bias_dtype = dtype_config_dict.get(BIAS_DTYPE_DICT_KEY, None)
|
||||
is_dynamic = dtype_config_dict.get(IS_DYNAMIC_DICT_KEY, None)
|
||||
return cls(input_dtype, output_dtype, weight_dtype, bias_dtype, is_dynamic)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -673,23 +673,23 @@ class BackendPatternConfig:
|
|||
for d in backend_pattern_config_dict.get(DTYPE_CONFIGS_DICT_KEY, []):
|
||||
conf.add_dtype_config(_get_dtype_config(d))
|
||||
conf.set_root_module(
|
||||
backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(ROOT_MODULE_DICT_KEY, None) # type: ignore[arg-type]
|
||||
)
|
||||
conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY)) # type: ignore[arg-type]
|
||||
conf.set_qat_module(backend_pattern_config_dict.get(QAT_MODULE_DICT_KEY, None)) # type: ignore[arg-type]
|
||||
conf.set_reference_quantized_module(
|
||||
backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(REFERENCE_QUANTIZED_MODULE_DICT_KEY, None) # type: ignore[arg-type]
|
||||
)
|
||||
conf.set_fused_module(
|
||||
backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(FUSED_MODULE_DICT_KEY, None) # type: ignore[arg-type]
|
||||
)
|
||||
conf.set_fuser_method(
|
||||
backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(FUSER_METHOD_DICT_KEY, None) # type: ignore[arg-type]
|
||||
)
|
||||
conf._set_root_node_getter(
|
||||
backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(ROOT_NODE_GETTER_DICT_KEY, None) # type: ignore[arg-type]
|
||||
)
|
||||
conf._set_extra_inputs_getter(
|
||||
backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY) # type: ignore[arg-type]
|
||||
backend_pattern_config_dict.get(EXTRA_INPUTS_GETTER_DICT_KEY, None) # type: ignore[arg-type]
|
||||
)
|
||||
conf._set_num_tensor_args_to_observation_type(
|
||||
backend_pattern_config_dict.get(
|
||||
|
|
|
|||
|
|
@ -286,7 +286,7 @@ def get_fuser_method_new(
|
|||
op_patterns = _get_valid_patterns(op_pattern)
|
||||
fuser_method = None
|
||||
for op_pattern in op_patterns:
|
||||
fuser_method = fuser_method_mapping.get(op_pattern)
|
||||
fuser_method = fuser_method_mapping.get(op_pattern, None)
|
||||
if fuser_method is not None:
|
||||
break
|
||||
assert fuser_method is not None, f"did not find fuser method for: {op_pattern} "
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ def _find_matches(
|
|||
for node in reversed(graph.nodes):
|
||||
if node.name not in match_map and node.name not in all_matched:
|
||||
for pattern, quantize_handler_cls in patterns.items():
|
||||
root_node_getter = root_node_getter_mapping.get(pattern)
|
||||
root_node_getter = root_node_getter_mapping.get(pattern, None)
|
||||
if _is_match(modules, node, pattern) and node.name not in match_map:
|
||||
matched_node_pattern: list[Node] = []
|
||||
record_match(pattern, node, node, matched_node_pattern, match_map)
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ def _get_qspec_for_arg(
|
|||
) -> Optional[QuantizationSpecBase]:
|
||||
while _is_activation_post_process_node(arg, named_modules):
|
||||
arg = arg.args[0] # type: ignore[assignment]
|
||||
return input_qspec_map.get(arg)
|
||||
return input_qspec_map.get(arg, None)
|
||||
|
||||
|
||||
def _create_obs_or_fq_from_qspec(
|
||||
|
|
|
|||
|
|
@ -163,7 +163,7 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable:
|
|||
torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
|
||||
torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
|
||||
}
|
||||
prepack_op = prepack_ops.get(conv_op)
|
||||
prepack_op = prepack_ops.get(conv_op, None)
|
||||
assert prepack_op, f"Didn't find prepack op for {conv_op}"
|
||||
return prepack_op
|
||||
|
||||
|
|
|
|||
|
|
@ -803,7 +803,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
|
|||
unexpected_keys: list[str],
|
||||
error_msgs: list[str],
|
||||
):
|
||||
version = local_metadata.get("version")
|
||||
version = local_metadata.get("version", None)
|
||||
if version is not None and version < 3:
|
||||
local_state = ["min_vals", "max_vals"]
|
||||
expected_min_name = "min_vals"
|
||||
|
|
|
|||
|
|
@ -366,7 +366,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
|
|||
if input_edge_obs_or_fq is None:
|
||||
return new_arg
|
||||
|
||||
arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg)
|
||||
arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
|
||||
# the arg is observed as the output and is using the same instance as the input_edge
|
||||
# we'll reuse the inserted observer/fake_quant
|
||||
if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
|
||||
|
|
@ -497,7 +497,11 @@ def _maybe_insert_input_and_output_observers_for_node(
|
|||
is_qat: bool,
|
||||
model_device: Optional[torch.device] = None,
|
||||
):
|
||||
this_node_quantization_annotation = node.meta.get("quantization_annotation", None)
|
||||
this_node_quantization_annotation = (
|
||||
node.meta["quantization_annotation"]
|
||||
if "quantization_annotation" in node.meta
|
||||
else None
|
||||
)
|
||||
if this_node_quantization_annotation is None:
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -343,7 +343,7 @@ def get_default_float_to_quantized_operator_mappings() -> dict[
|
|||
# TODO: merge with get_static_quant_module_class
|
||||
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
|
||||
"""Get the quantized operator corresponding to the float operator"""
|
||||
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op)
|
||||
quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
|
||||
assert quantized_op is not None, (
|
||||
f"Operator {str(float_op)} does not have corresponding quantized op"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1352,8 +1352,10 @@ class X86InductorQuantizer(Quantizer):
|
|||
def _annotate_output_share_observer_as_input(
|
||||
self, input_node: Node, source_node: Node
|
||||
):
|
||||
source_node_quantization_annotation = source_node.meta.get(
|
||||
QUANT_ANNOTATION_KEY, None
|
||||
source_node_quantization_annotation = (
|
||||
source_node.meta[QUANT_ANNOTATION_KEY]
|
||||
if QUANT_ANNOTATION_KEY in source_node.meta
|
||||
else None
|
||||
)
|
||||
if (
|
||||
source_node_quantization_annotation
|
||||
|
|
@ -1393,8 +1395,10 @@ class X86InductorQuantizer(Quantizer):
|
|||
return
|
||||
|
||||
# Get the quantization_annotation from getitem_node
|
||||
maxpool_node_quantization_annotation = maxpool_node.meta.get(
|
||||
QUANT_ANNOTATION_KEY, None
|
||||
maxpool_node_quantization_annotation = (
|
||||
maxpool_node.meta[QUANT_ANNOTATION_KEY]
|
||||
if QUANT_ANNOTATION_KEY in maxpool_node.meta
|
||||
else None
|
||||
)
|
||||
if (
|
||||
maxpool_node_quantization_annotation
|
||||
|
|
|
|||
|
|
@ -162,7 +162,10 @@ class EventList(list):
|
|||
if p is not None:
|
||||
assert p.fwd_thread is not None
|
||||
t = (p.sequence_nr, p.fwd_thread)
|
||||
evt.stack = fwd_stacks.get(t, [])
|
||||
if t in fwd_stacks:
|
||||
evt.stack = fwd_stacks[t]
|
||||
else:
|
||||
evt.stack = []
|
||||
|
||||
@property
|
||||
def self_cpu_time_total(self):
|
||||
|
|
|
|||
|
|
@ -214,7 +214,7 @@ def replicate(
|
|||
|
||||
state = replicate.state(module)
|
||||
module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
|
||||
device_mesh = kwargs.get("device_mesh")
|
||||
device_mesh = kwargs.get("device_mesh", None)
|
||||
if device_mesh is not None:
|
||||
from torch.distributed.device_mesh import _mesh_resources
|
||||
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ def replicate_impl(
|
|||
# Place Replicate leftmost for highest priority in the method resolution order
|
||||
for module in modules:
|
||||
cls = module.__class__
|
||||
new_cls = cls_to_replicate_cls.get(cls)
|
||||
new_cls = cls_to_replicate_cls.get(cls, None)
|
||||
if not new_cls:
|
||||
dct = {"__deepcopy__": _unimplemented_deepcopy}
|
||||
new_cls = type(f"Replicate{cls.__name__}", (ReplicateModule, cls), dct)
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ def register_tensor_creation_op(op):
|
|||
takes a ShardedTensor as argument, such as ``torch.zeros_like`` or
|
||||
``torch.full_like``.
|
||||
"""
|
||||
creation_op = tensor_like_creation_op_map.get(op)
|
||||
creation_op = tensor_like_creation_op_map.get(op, None)
|
||||
if creation_op is None:
|
||||
raise RuntimeError(f"Tensor creation {op} not supported!")
|
||||
if kwargs is None:
|
||||
|
|
|
|||
|
|
@ -676,7 +676,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
copy_tensor = kwargs.get("copy", False)
|
||||
non_blocking = kwargs.get("non_blocking", False)
|
||||
memory_format = kwargs.get("memory_format", torch.preserve_format)
|
||||
process_group = kwargs.get("process_group")
|
||||
process_group = kwargs.get("process_group", None)
|
||||
|
||||
if (
|
||||
not copy_tensor
|
||||
|
|
|
|||
|
|
@ -596,7 +596,7 @@ def _distribute_tensors(
|
|||
if pg is None:
|
||||
pg = dist.distributed_c10d._get_default_group()
|
||||
for key in keys:
|
||||
_local_state = local_state_dict.get(key)
|
||||
_local_state = local_state_dict.get(key, None)
|
||||
if _local_state is None or torch.is_tensor(_local_state):
|
||||
continue
|
||||
|
||||
|
|
@ -706,7 +706,7 @@ def _distribute_state_dict(
|
|||
local_state_dict[key] = value.cpu()
|
||||
else:
|
||||
assert isinstance(value, torch.Tensor)
|
||||
local_state = local_state_dict.get(key)
|
||||
local_state = local_state_dict.get(key, None)
|
||||
if local_state is None:
|
||||
continue
|
||||
elif isinstance(local_state, DTensor):
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ def aggregate_stats(
|
|||
}
|
||||
|
||||
for mod in model.modules():
|
||||
if mod_mem_stat := mod_mem_stats.get(mod):
|
||||
if mod_mem_stat := mod_mem_stats.get(mod, None):
|
||||
if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None):
|
||||
sac_runtime = tradeoff_stats.sac_runtime
|
||||
sac_memory = tradeoff_stats.sac_memory
|
||||
|
|
|
|||
|
|
@ -710,7 +710,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
str(i in sac_stats.view_like_ops),
|
||||
str(i in sac_stats.rand_ops),
|
||||
str(i in sac_stats.saved_autograd_ops),
|
||||
str(op_parent.get(i)),
|
||||
str(op_parent.get(i, None)),
|
||||
]
|
||||
table_data.append(row)
|
||||
# Define headers
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ def auto_quantize(func, qtype, quant_loss=None):
|
|||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
group = kwargs.get("group")
|
||||
group = kwargs.get("group", None)
|
||||
async_op = kwargs.get("async_op", False)
|
||||
if async_op is True:
|
||||
raise RuntimeError("The async_op=True mode is not supported yet.")
|
||||
|
|
@ -132,8 +132,8 @@ def auto_quantize(func, qtype, quant_loss=None):
|
|||
|
||||
elif func == dist.all_to_all_single:
|
||||
tensors = args[0]
|
||||
out_splits = kwargs.get("out_splits")
|
||||
in_splits = kwargs.get("in_splits")
|
||||
out_splits = kwargs.get("out_splits", None)
|
||||
in_splits = kwargs.get("in_splits", None)
|
||||
# Quantizing the input/output tensor
|
||||
input_tensors = _quantize_tensor(args[1], qtype)
|
||||
out_tensors = _quantize_tensor(tensors, qtype)
|
||||
|
|
|
|||
|
|
@ -631,7 +631,7 @@ class _FileSystemWriter(StorageWriter):
|
|||
def set_up_storage_writer(
|
||||
self, is_coordinator: bool, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
self.rank = kwargs.get("rank")
|
||||
self.rank = kwargs.get("rank", None)
|
||||
self.use_collectives = kwargs.get("use_collectives", True)
|
||||
|
||||
def _metadata_exists(self) -> bool:
|
||||
|
|
@ -919,7 +919,7 @@ class FileSystemReader(StorageReader):
|
|||
|
||||
# Implementing the abstract function in StorageReader
|
||||
def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata:
|
||||
rank = kwargs.get("rank")
|
||||
rank = kwargs.get("rank", None)
|
||||
path = self._get_metadata_path(rank)
|
||||
with self.fs.create_stream(path, "rb") as metadata_file:
|
||||
metadata = pickle.load(metadata_file)
|
||||
|
|
@ -934,7 +934,7 @@ class FileSystemReader(StorageReader):
|
|||
self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
self.storage_data = metadata.storage_data
|
||||
self.rank = kwargs.get("rank")
|
||||
self.rank = kwargs.get("rank", None)
|
||||
self.use_collectives = kwargs.get("use_collectives", True)
|
||||
assert self.storage_data is not None
|
||||
|
||||
|
|
|
|||
|
|
@ -30,11 +30,11 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
msg_dict = {}

# checkpoint ID can be passed in through the serializer or through the checkpoint id directly
storage_writer = kwargs.get("storage_writer")
storage_reader = kwargs.get("storage_reader")
planner = kwargs.get("planner")
storage_writer = kwargs.get("storage_writer", None)
storage_reader = kwargs.get("storage_reader", None)
planner = kwargs.get("planner", None)

checkpoint_id = kwargs.get("checkpoint_id")
checkpoint_id = kwargs.get("checkpoint_id", None)
if not checkpoint_id and (serializer := storage_writer or storage_reader):
checkpoint_id = getattr(serializer, "checkpoint_id", None)
@ -305,7 +305,7 @@ def _verify_options(
continue

fqns = _get_fqns(model, name)
fqn = fqn_param_mapping.get(param)
fqn = fqn_param_mapping.get(param, None)
if fqn is not None:
cast(set[str], fqn_param_mapping[param]).update(fqns)
shared_params_mapping[param] = fqn_param_mapping[param]
@ -195,7 +195,7 @@ else:
# A root mesh is not created through slicing.
# We considers the root mesh of a root mesh is itself.
root_mesh = self.child_to_root_mapping.get(device_mesh, None)
return root_mesh if root_mesh else device_mesh
return device_mesh if not root_mesh else root_mesh

def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
"""
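The two return lines in the hunk above are equivalent; the revert only restores the original ordering of the conditional expression. A standalone sketch (values invented) of the falsy check both spellings rely on:

for root_mesh in (None, "root0"):
    device_mesh = "child0"
    # both orderings fall back to device_mesh when root_mesh is falsy
    assert (root_mesh if root_mesh else device_mesh) == (
        device_mesh if not root_mesh else root_mesh
    )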
@ -5042,7 +5042,7 @@ def _is_safe_to_split() -> bool:
users must be aware that a pg is only splittable after the first collective is
issued.
"""
return _get_default_group().bound_device_id is not None
return False if _get_default_group().bound_device_id is None else True


@_time_logger
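Both return statements above compute the same boolean: the simplified form returns the `is not None` comparison directly, while the restored form spells it out as a conditional expression. A minimal standalone sketch (the variable name is invented):

for bound_device_id in (None, 0, "cuda:0"):
    # the direct comparison and the False/True conditional always agree
    assert (bound_device_id is not None) == (
        False if bound_device_id is None else True
    )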
@ -88,7 +88,10 @@ def configure(handler: MetricHandler, group: Optional[str] = None):


def getStream(group: str):
handler = _metrics_map.get(group, _default_metrics_handler)
if group in _metrics_map:
handler = _metrics_map[group]
else:
handler = _default_metrics_handler
return MetricStream(group, handler)
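The getStream change above is the dict.get-with-default pattern that ruff's SIM401 rule targets: for a plain dict, the one-line lookup and the explicit membership check pick the same handler. A standalone sketch, with handler and group names invented:

_metrics = {"trainer": "console_handler"}
_default = "null_handler"

def pick_via_get(group):
    # simplified form: .get with an explicit fallback
    return _metrics.get(group, _default)

def pick_via_branch(group):
    # restored form: explicit membership check
    if group in _metrics:
        return _metrics[group]
    return _default

assert pick_via_get("trainer") == pick_via_branch("trainer") == "console_handler"
assert pick_via_get("worker") == pick_via_branch("worker") == "null_handler"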
@ -239,7 +239,7 @@ def fully_shard(
# Place FSDP leftmost for highest priority in the method resolution order
for module in modules:
cls = module.__class__
new_cls = cls_to_fsdp_cls.get(cls)
new_cls = cls_to_fsdp_cls.get(cls, None)
if not new_cls:
dct = {"__deepcopy__": _unimplemented_deepcopy}
new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
@ -1267,7 +1267,7 @@ def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool:
(which usually are FQNs) versus integers (which usually refer to param_ids
from a vanilla torch.optim.Optimizer).
"""
state = optim_state_dict.get("state")
state = optim_state_dict.get("state", None)
if not state:
# If we cannot find a state, assume it is not NamedOptimizer as
# NamedOptimizer has eager initialization.
@ -1715,7 +1715,7 @@ def _convert_state_with_orig_params(
# across ranks
for optim_state_key in all_optim_state_keys:
param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
optim_state_key
optim_state_key, None
)

if param_key is None and not optim_state_key.is_fsdp_managed:
@ -1723,7 +1723,7 @@ def _convert_state_with_orig_params(

if optim_state_key.is_fsdp_managed:
fqn = optim_state_key.unflat_param_names[0]
fsdp_param_info = fqn_to_fsdp_param_info.get(fqn)
fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
if fsdp_param_info is None:
# This can happen if the not all FSDP instances have all the
# parameters. This can happen with FSDP + some MPMD style
@ -1801,7 +1801,7 @@ def _convert_state_with_flat_params(
# across ranks
for optim_state_key in all_optim_state_keys:
param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
optim_state_key
optim_state_key, None
)

assert param_key is not None, (
@ -52,7 +52,8 @@ class _ScriptLocalOptimizer(nn.Module):
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
# apply functional optimizer step with a list of gradients
grads: list[Optional[Tensor]] = [
all_local_grads.get(p, None) for p in self._local_params
all_local_grads[p] if p in all_local_grads else None
for p in self._local_params
]

self.optim.step(grads)
@ -189,7 +189,7 @@ def _insert_stage_symbolic_backward(
output_grads: Union[tuple[Optional[fx.Node], ...], Optional[fx.Node]]
if node in tuples:
stage_output = tuples[node]
output_grads = tuple(val_to_grad.get(n) for n in tuples[node])
output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
outputs_with_grads_idxs = [
i for i, n in enumerate(tuples[node]) if n in live_nodes
]
@ -114,7 +114,7 @@ def get_param_groups(
"intermediates": intersected,
}
for input_node in intersected:
existing = param_groups.get(input_node)
existing = param_groups.get(input_node, None)
if existing is not None:
existing["params"] = existing["params"].union(param_group["params"])
existing["intermediates"] = existing["intermediates"].union(
@ -326,7 +326,8 @@ def _insert_copy_for_mutations(
return_nodes_to_copy[return_node] = copy_node

output_args = tuple(
return_nodes_to_copy.get(node, node) for node in user_output_nodes
return_nodes_to_copy[node] if node in return_nodes_to_copy else node
for node in user_output_nodes
)
with gm.graph.inserting_before(output_node):
# Only return user outputs
@ -46,7 +46,7 @@ class NormalizeArgs(Transformer):

def get_type(arg):
if isinstance(arg, fx.Node):
return n.meta.get("type", None)
return n.meta["type"] if "type" in n.meta else None
return type(arg)

arg_types = map_aggregate(n.args, get_type)
@ -4378,7 +4378,7 @@ class ShapeEnv:
size = []
for i, val in enumerate(tensor_size):
sym = self.create_symbol(
hint_overrides.get(i, val),
val if i not in hint_overrides else hint_overrides[i],
TensorPropertySource(source, TensorProperty.SIZE, i),
dynamic_dims[i],
constraint_dims[i],
@ -4579,7 +4579,7 @@ class ShapeEnv:
sym_sizes = [
self.create_symintnode(
sym,
hint=hint_overrides.get(i, hint),
hint=hint if i not in hint_overrides else hint_overrides[i],
source=TensorPropertySource(source, TensorProperty.SIZE, i),
)
for i, (sym, hint) in enumerate(zip(size, ex_size))
@ -923,7 +923,7 @@ class ParameterDict(Module):
key (str): key to get from the ParameterDict
default (Parameter, optional): value to return if key not present
"""
return self[key] if key in self else default  # noqa: SIM401
return self[key] if key in self else default

def fromkeys(
self, keys: Iterable[str], default: Optional[Any] = None
@ -218,7 +218,7 @@ def _dump_DDP_relevant_env_vars():
]
formatted_output = ""
for var in relevant_env_vars:
value = os.environ.get(var, "N/A")
value = os.environ[var] if var in os.environ else "N/A"
formatted_output += f"env:{var}={value}\n"
print(formatted_output)
@ -774,8 +774,8 @@ class Optimizer:
assert param_groups is not None
for pg in param_groups:
if param_id in pg["params"]:
fused = pg.get("fused", False)
capturable = pg.get("capturable", False)
fused = pg["fused"] if "fused" in pg else False
capturable = pg["capturable"] if "capturable" in pg else False
break
if key == "step":
if capturable or fused:
@ -390,8 +390,8 @@ class DeviceTypeTestBase(TestCase):
return test.tolerance_overrides.get(dtype, tol(self.precision, self.rel_tol))

def _apply_precision_override_for_test(self, test, param_kwargs):
dtype = param_kwargs.get("dtype", None)
dtype = param_kwargs.get("dtypes", dtype)
dtype = param_kwargs["dtype"] if "dtype" in param_kwargs else None
dtype = param_kwargs["dtypes"] if "dtypes" in param_kwargs else dtype
if dtype:
self.precision = self._get_precision_override(test, dtype)
self.precision, self.rel_tol = self._get_tolerance_override(test, dtype)
@ -1915,7 +1915,7 @@ def sample_inputs_new_full(self, device, dtype, requires_grad, **kwargs):
for sample in sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
# The scalar we are passing to new_full must be the same dtype
# as the one of the resulting tensor
use_dtype = sample.kwargs.get('dtype', dtype)
use_dtype = sample.kwargs['dtype'] if 'dtype' in sample.kwargs else dtype
yield SampleInput(
sample.input, *sample.args, get_val(use_dtype), **sample.kwargs)
@ -725,7 +725,7 @@ class DistributedTest:
lines = out.getvalue().splitlines()

def format_line(var):
return f"env:{var}={os.environ.get(var, 'N/A')}"
return f"env:{var}={os.environ[var] if var in os.environ else 'N/A'}"

# Check relevant env vars
vars = [
@ -6212,7 +6212,7 @@ class DistributedTest:
)
def test_ddp_logging_data_cpu(self):
def parse_env(var):
return os.environ.get(var, "N/A")
return os.environ[var] if var in os.environ else "N/A"

dist.set_debug_level(dist.DebugLevel.INFO)
_, group_id, _ = self._init_global_test()
@ -21,7 +21,7 @@ INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)


def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]:
return _MIRROR_REL_OP.get(type)
return _MIRROR_REL_OP.get(type, None)


# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
@ -277,7 +277,7 @@ def create_graph(objects, *, context=None, filter=None):
references = annotated_references(obj)
for referrent in gc.get_referents(obj):
rid = id(referrent)
tidx = id_to_node.get(rid)
tidx = id_to_node.get(rid, None)
if tidx is None:
continue
labels = references.get(rid, ["?"])
@ -150,7 +150,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
assert schema.kind() == SchemaKind.inplace
if not is_mutated_arg(schema.arguments.flat_all[0]):
return None
if len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) != 1:
if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
return None

# Only support cases where all returns are Tensors or vector<Tensor>
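The last hunk restores a negated equality in place of `!=`, the spelling that ruff's SIM201 rule normally rewrites. For the integer count involved here the two forms always agree; a minimal standalone sketch:

for count in (0, 1, 2):
    # counting mutated args: `count != 1` and `not count == 1` are interchangeable
    assert (count != 1) == (not count == 1)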