[BE][PYFMT] migrate PYFMT for test/[a-h]*/ to ruff format (#144555)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144555
Approved by: https://github.com/ezyang
ghstack dependencies: #144551, #144554
Xuehai Pan 2025-06-24 00:18:25 +00:00 committed by PyTorch MergeBot
parent e600e044a7
commit 6d5c789ad5
84 changed files with 585 additions and 520 deletions
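Most of the 84 files change in the same handful of mechanical ways, so the hunks below are repetitive. As a rough, hand-written sketch (not copied from any single file in this PR), the most frequent rewrite keeps the assert condition on the assert line and wraps the long message in parentheses, instead of parenthesizing the condition:

# Minimal illustrative sketch; the names here are placeholders, and ruff format
# only rewraps lines that exceed the configured line length.
world_size = 2

# Old (Black-era) wrapping parenthesized the condition:
#     assert (
#         world_size >= 2
#     ), f"Assumes world size of at least 2 but got {world_size=}"

# ruff format style: condition stays inline, the message is parenthesized.
assert world_size >= 2, (
    f"Assumes world size of at least 2 but got {world_size=}"
)

The same preference shows up in the raise RuntimeError(...) hunks, where one over-long f-string is split into adjacent implicitly concatenated string literals so the # noqa: B950 suppressions can be dropped.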


@@ -494,9 +494,7 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
         (
             emb1,
             emb2,
-        ) = nn.Embedding(
-            10, 3
-        ), nn.Embedding(20, 3)
+        ) = nn.Embedding(10, 3), nn.Embedding(20, 3)
         emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
         emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
@@ -627,9 +625,7 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
         (
             emb1,
             emb2,
-        ) = nn.Embedding(
-            10, 3
-        ), nn.Embedding(20, 3)
+        ) = nn.Embedding(10, 3), nn.Embedding(20, 3)
         emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
         emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)


@@ -218,12 +218,12 @@ def _sparse_layer_test_helper(
         qmodule_to_check = fqn_to_module(qmodel, fqn_to_check)
         # check that the modules were converted as expected
-        assert isinstance(
-            sqmodule_to_check, sqmodule_expected_converted_class
-        ), "Convert failed"
-        assert isinstance(
-            qmodule_to_check, qmodule_expected_converted_class
-        ), "Mapping failed"
+        assert isinstance(sqmodule_to_check, sqmodule_expected_converted_class), (
+            "Convert failed"
+        )
+        assert isinstance(qmodule_to_check, qmodule_expected_converted_class), (
+            "Mapping failed"
+        )
         row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[
             2:


@@ -1055,9 +1055,9 @@ class TestFPGMPruner(TestCase):
         mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
         mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
         # Check if either of the least-norm filters is not pruned
-        assert (
-            mask1.item() is not False or mask2.item() is not False
-        ), "Do not prune all least-norm filters"
+        assert mask1.item() is not False or mask2.item() is not False, (
+            "Do not prune all least-norm filters"
+        )
         # fusion step
         pruned_model = pruner.prune()


@@ -66,9 +66,10 @@ def generate_callgrind_artifacts() -> None:
         json.dump(artifacts, f, indent=4)
-def load_callgrind_artifacts() -> (
-    tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
-):
+def load_callgrind_artifacts() -> tuple[
+    benchmark_utils.CallgrindStats,
+    benchmark_utils.CallgrindStats,
+]:
     """Hermetic artifact to unit test Callgrind wrapper.
     In addition to collecting counts, this wrapper provides some facilities for


@@ -158,7 +158,8 @@ def compute_functional_name(test_params_dict):
         return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "")
     else:
         raise RuntimeError(
-            f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}"  # noqa: B950
+            "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+            f"{pprint.pformat(test_params_dict)}"
         )
@@ -179,7 +180,8 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
         )
     else:
         raise RuntimeError(
-            f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}"  # noqa: B950
+            "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+            f"{pprint.pformat(test_params_dict)}"
         )
@@ -217,7 +219,8 @@ def write_test_to_test_class(
         or "cpp_function_call" in test_params_dict
     ), (
         "To enable C++ API parity test, "
-        f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n"  # noqa: B950
+        "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+        f"{pprint.pformat(test_params_dict)}. \n"
         "If you are interested in adding the C++ API parity test, please see:\n"
         "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
         "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
@@ -233,14 +236,16 @@ def write_test_to_test_class(
     functional_name = compute_functional_name(test_params_dict)
-    assert hasattr(
-        torch.nn.functional, functional_name
-    ), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)"  # noqa: B950
+    assert hasattr(torch.nn.functional, functional_name), (
+        f"`torch.nn.functional` doesn't have function `{functional_name}`. "
+        f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
+    )
     functional_full_name = "F::" + functional_name
     assert functional_full_name in parity_table["torch::nn::functional"], (
-        f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. "
+        f"Please add `{functional_full_name}` entry to `torch::nn::functional` "
+        "section of `test/cpp_api_parity/parity-tracker.md`. "
         f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
     )


@@ -78,9 +78,9 @@ def _kernel_fallback(op, *args, **kwargs):
     elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
         # Only handle inplace ops returning their first arg
         assert len(args) >= 1, f"Inplace {op} needs at least one arg"
-        assert (
-            len(op._schema.returns) == 1
-        ), f"NYI Inplace {op} with more than one return"
+        assert len(op._schema.returns) == 1, (
+            f"NYI Inplace {op} with more than one return"
+        )
         op_name = op.overloadpacket._qualified_op_name
         real_res = args[0]
     elif any(r.alias_info is not None for r in op._schema.returns):


@@ -67,7 +67,7 @@ def prepare_for_sending(args, kwargs):
     def convert(obj):
         if type(obj) not in VALID_QUEUE_TYPES_IN:
             raise RuntimeError(
-                f"Cannot send object of type {type(obj)} " "over openreg device pipe."
+                f"Cannot send object of type {type(obj)} over openreg device pipe."
             )
         if isinstance(obj, torch.Tensor):
@@ -82,8 +82,7 @@ def receive_after_sending(allocator, args, kwargs):
     def convert(obj):
         if type(obj) not in VALID_QUEUE_TYPES_OUT:
             raise RuntimeError(
-                f"Received invalid object of type {type(obj)} "
-                "over openreg device pipe."
+                f"Received invalid object of type {type(obj)} over openreg device pipe."
             )
         if isinstance(obj, OpenRegTensorMeta):


@@ -561,8 +561,9 @@ class TestFullyShardPrefetch(FSDPTest):
             FSDPParamGroup.post_backward, events
         )
         # Check the order for normal 1 forward, 1 backward, 1 optimizer step
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
-        ):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
             for iter_idx in range(3):
                 loss = model(inp)
@@ -617,8 +618,9 @@ class TestFullyShardPrefetch(FSDPTest):
             FSDPParamGroup.post_backward, events
         )
         # Check the order for multiple forwards before 1 backward
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
-        ):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
             loss1 = model(inp)
             loss2 = model(inp)
@@ -703,8 +705,9 @@ class TestFullyShardPrefetch(FSDPTest):
         post_backward_with_record = self._get_post_backward_with_record(
             FSDPParamGroup.post_backward, events
         )
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
-        ):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
             loss1, loss2 = model(inp)
             expected_events = [
@@ -794,9 +797,11 @@ class TestFullyShardPrefetch(FSDPTest):
             ("reshard", "", TrainingState.POST_BACKWARD),
             ("post_backward", "", TrainingState.POST_BACKWARD),
         ]
-        with patch_unshard(unshard_with_record), patch_reshard(
-            reshard_with_record
-        ), patch_post_backward(post_backward_with_record):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_reshard(reshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
             set_forward_prefetch(model, num_to_prefetch=1)
             loss = model(inp)
             expected_forward_events = [
@@ -882,9 +887,11 @@ class TestFullyShardPrefetch(FSDPTest):
             ("reshard", "layers.3", TrainingState.FORWARD),
             ("reshard", "", TrainingState.FORWARD),
         ]
-        with patch_unshard(unshard_with_record), patch_reshard(
-            reshard_with_record
-        ), patch_post_backward(post_backward_with_record):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_reshard(reshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
             set_backward_prefetch(model, num_to_prefetch=1)
             loss = model(inp)
             self.assertEqual(events, expected_forward_events)
@@ -967,8 +974,9 @@ class TestFullyShardPrefetch(FSDPTest):
             (2, model_args.max_seq_len),
             device=device_type.type,
         )
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
-        ):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
             for _ in range(3):
                 loss = model(inp)
@@ -1046,8 +1054,9 @@ class TestFullyShardPrefetch(FSDPTest):
             FSDPParamGroup.post_backward, events
         )
         inp = torch.randn((2, 16), device=device_type.type)
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
-        ):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
             for _ in range(3):
                 loss = model(inp)
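The FSDP hunks above and below show the second recurring rewrite: chains of context managers that used to wrap across a call's parentheses are regrouped into a single parenthesized with block, one manager per line. A small, self-contained sketch follows; the patch_* helpers here are stand-ins, not the actual test utilities used in these files.

from contextlib import nullcontext


def patch_unshard(record):  # hypothetical stand-in for the real test helper
    return nullcontext(record)


def patch_post_backward(record):  # hypothetical stand-in
    return nullcontext(record)


# Old wrapping split the second manager across its call parentheses:
#     with patch_unshard("unshard"), patch_post_backward(
#         "post_backward"
#     ):
#         ...

# ruff format emits parenthesized context managers (syntax documented in
# Python 3.10), one per line with a trailing comma:
with (
    patch_unshard("unshard"),
    patch_post_backward("post_backward"),
):
    pass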


@@ -222,9 +222,7 @@ class TestFullyShardCompile(FSDPTest):
             ):
                 unsharded_param_graph_inputs.add(node.args[0])
         assert len(unsharded_param_graph_inputs) > 0
-        assert len(unsharded_param_graph_inputs) == len(
-            list(model.parameters())
-        ), """\
+        assert len(unsharded_param_graph_inputs) == len(list(model.parameters())), """\
 Expected all model parameters to be wrapped by FSDP2 and
 have their unsharded version as graph input, but it's not true!
 """
@@ -237,7 +235,7 @@ have their unsharded version as graph input, but it's not true!
                 no_aliased_unsharded_params_in_graph_inputs = False
                 err_msg += f"""\n
 Found aliased unsharded param in graph inputs: {aliased_graph_inputs},
-val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
+val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
 """
         self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg)
@@ -466,10 +464,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_compiled_autograd_ctx(self):
         self.skipTestForOldSm()
-        with torch._dynamo.config.patch(
-            skip_fsdp_hooks=False,
-        ), torch._functorch.config.patch(
-            recompute_views=True,
-        ):
+        with (
+            torch._dynamo.config.patch(skip_fsdp_hooks=False),
+            torch._functorch.config.patch(recompute_views=True),
+        ):
             inputs = torch.randn(8, 8)
             model = torch.nn.Linear(8, 8)
@@ -567,24 +564,28 @@ Unsupported Tensor.backward() call
         torch._dynamo.reset()
         torch._dynamo.compiled_autograd.reset()
-        with torch._dynamo.config.patch(
-            compiled_autograd=True,
-            compiled_autograd_kwargs_override={
-                "fullgraph": True,
-            },
-            inline_inbuilt_nn_modules=True,
-            skip_fsdp_hooks=False,
-        ), torch._functorch.config.patch(
-            enable_autograd_cache=False,
-            recompute_views=True,
-        ), torch._inductor.config.patch(
-            force_disable_caches=True,
-            reorder_for_compute_comm_overlap=True,
-            reorder_for_compute_comm_overlap_passes=[
-                "sink_waits",
-                "raise_comms",
-                "reorder_compute_for_overlap",
-            ],
-        ):
+        with (
+            torch._dynamo.config.patch(
+                compiled_autograd=True,
+                compiled_autograd_kwargs_override={
+                    "fullgraph": True,
+                },
+                inline_inbuilt_nn_modules=True,
+                skip_fsdp_hooks=False,
+            ),
+            torch._functorch.config.patch(
+                enable_autograd_cache=False,
+                recompute_views=True,
+            ),
+            torch._inductor.config.patch(
+                force_disable_caches=True,
+                reorder_for_compute_comm_overlap=True,
+                reorder_for_compute_comm_overlap_passes=[
+                    "sink_waits",
+                    "raise_comms",
+                    "reorder_compute_for_overlap",
+                ],
+            ),
+        ):
             losses_compiled = test_compiled()
             losses_eager = test_eager()
@@ -741,20 +742,21 @@ Unsupported Tensor.backward() call
     def _test_nested_fully_shard_backend_inductor_fullgraph_True(self):
         self.skipTestForOldSm()
         for fwd_fullgraph in [True]:
-            with self._reinplace_all_gather_with_optional_checks(
-                fwd_fullgraph
-            ), torch._inductor.config.patch(
-                post_grad_custom_post_pass=(
-                    functools.partial(
+            with (
+                self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+                torch._inductor.config.patch(
+                    post_grad_custom_post_pass=(
+                        functools.partial(
                             self._check_fsdp_copy_and_resize_ops_count_in_graph,
                             fwd_copy_count=0,
                             fwd_resize_count=0,
                             bwd_copy_count=0,
                             bwd_resize_count=0,
+                        )
+                        if fwd_fullgraph
+                        else None
                     )
-                    if fwd_fullgraph
-                    else None
-                )
-            ):
+                ),
+            ):
                 _, triton_codes = run_and_get_code(
                     lambda: self._test_traceable_fsdp(
@@ -943,9 +945,10 @@ Unsupported Tensor.backward() call
         for fwd_fullgraph, all_requires_grad in itertools.product(
             [True], [True, False]
         ):
-            with self._maybe_add_graph_break_to_sdpa(
-                fwd_fullgraph
-            ), self._reinplace_all_gather_with_optional_checks(fwd_fullgraph):
+            with (
+                self._maybe_add_graph_break_to_sdpa(fwd_fullgraph),
+                self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+            ):
                 self._test_traceable_fsdp(
                     *self._create_transformer_factory_fns(
                         all_requires_grad=all_requires_grad
@@ -982,23 +985,24 @@ Unsupported Tensor.backward() call
                 log.warning(
                     f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}"  # noqa: G004, G001, B950
                 )
-                with self._reinplace_all_gather_with_optional_checks(
-                    fwd_fullgraph
-                ), torch._inductor.config.patch(
-                    post_grad_custom_post_pass=(
-                        functools.partial(
+                with (
+                    self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+                    torch._inductor.config.patch(
+                        post_grad_custom_post_pass=(
+                            functools.partial(
                                 self._check_fsdp_copy_and_resize_ops_count_in_graph,
                                 # NOTE: For the root unsharded params, we don't reshard after forward since for training,
                                 # the parameters would be freed and all-gathered immediately. Hence we still have
                                 # their resize and copy ops in the graph.
                                 fwd_copy_count=4,
                                 fwd_resize_count=4,
                                 bwd_copy_count=0,
                                 bwd_resize_count=4,
+                            )
+                            if fwd_fullgraph
+                            else None
                         )
-                        if fwd_fullgraph
-                        else None
-                    )
-                ):
+                    ),
+                ):
                     _, triton_codes = run_and_get_code(
                         lambda: self._test_traceable_fsdp(


@@ -385,9 +385,9 @@ class TestFullyShardAllGatherExtensionsMultiThread(
         only some ranks may require padding, in which case only those ranks
         will error out and the all-gather will timeout.
         """
-        assert (
-            self.world_size >= 2
-        ), f"Assumes world size of at least 2 but got {self.world_size=}"
+        assert self.world_size >= 2, (
+            f"Assumes world size of at least 2 but got {self.world_size=}"
+        )
         model = MLP(dim=3, dim_multiplier=3)
         for module in model.modules():
             for param_name, param in module.named_parameters(recurse=False):


@@ -115,9 +115,10 @@ class TestFullyShardFrozen(FSDPTest):
         torch.manual_seed(42 + self.rank + 1)
         device = device_type
-        with patch_reduce_scatter(
-            reduce_scatter
-        ), patch_register_post_backward_hook_backward(backward_with_count):
+        with (
+            patch_reduce_scatter(reduce_scatter),
+            patch_register_post_backward_hook_backward(backward_with_count),
+        ):
             for iter_idx in range(10):
                 inp = torch.randn((8, lin_dim), device=device)
                 losses: list[torch.Tensor] = []


@@ -910,9 +910,9 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
     @skip_if_lt_x_gpu(1)
     def test_2d_process_group_init(self):
         shard_mesh_dim_size = 2
-        assert (
-            self.world_size % shard_mesh_dim_size == 0
-        ), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}"
+        assert self.world_size % shard_mesh_dim_size == 0, (
+            f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}"
+        )
         replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size
         mesh_dim_names = ("replicate", "shard")
         ref_mesh = init_device_mesh(


@@ -54,9 +54,9 @@ class TestFullyShardMemory(FSDPTest):
             )
         ):
             return  # skip since not a common use case
-        assert (
-            self.world_size == 2
-        ), f"Requires world size of 2 since some values are hard coded: {self.world_size}"
+        assert self.world_size == 2, (
+            f"Requires world size of 2 since some values are hard coded: {self.world_size}"
+        )
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage


@@ -284,9 +284,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
                 )  # bf16 reduction
                 param.grad = funcol.all_gather_tensor(
                     sharded_grad, gather_dim=0, group=group
-                ).to(
-                    param.dtype
-                )  # upcast to fp32
+                ).to(param.dtype)  # upcast to fp32
             ref_optim.step()  # fp32 optimizer step
         self.assertEqual(fsdp_loss, ref_loss)


@@ -139,8 +139,9 @@ class TestFullyShardOverlap(FSDPTest):
             dist.reduce_scatter_tensor(dummy_rs_output, dummy_rs_input)
         def fwd_bwd():
-            with patch_all_gather(delayed_all_gather), patch_reduce_scatter(
-                delayed_reduce_scatter
-            ):
+            with (
+                patch_all_gather(delayed_all_gather),
+                patch_reduce_scatter(delayed_reduce_scatter),
+            ):
                 loss = model(inp).sum()
                 loss.backward()


@@ -74,12 +74,12 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
             # Check that FSDP moved the inputs to GPU, including recursing
             # into the tuple data structure
             assert x.device == device, f"Expects {device} but got {x.device}"
-            assert (
-                ys[0].device == device
-            ), f"Expects {device} but got {ys[0].device}"
-            assert (
-                ys[1].device == device
-            ), f"Expects {device} but got {ys[1].device}"
+            assert ys[0].device == device, (
+                f"Expects {device} but got {ys[0].device}"
+            )
+            assert ys[1].device == device, (
+                f"Expects {device} but got {ys[1].device}"
+            )
             y = ys[0] + ys[1]
             return x + y + 1


@@ -234,8 +234,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()


@@ -250,9 +250,11 @@ class TestJoin(MultiProcessTestCase):
             else "Detected at least one rank that exhausted inputs. "
            "Throwing across all ranks."
         )
-        with self.assertRaisesRegex(
-            RuntimeError, expected_msg
-        ) if throw_on_early_termination else contextlib.nullcontext():
+        with (
+            self.assertRaisesRegex(RuntimeError, expected_msg)
+            if throw_on_early_termination
+            else contextlib.nullcontext()
+        ):
             with Join(
                 allreducers,
                 enable=enable,


@@ -677,9 +677,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
             fully_shard(layer)
         fully_shard(model)
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-        torch.optim.lr_scheduler.LambdaLR(
-            optim, lr_lambda=[lambda epoch: 0.95**epoch]
-        )
+        torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[lambda epoch: 0.95**epoch])
         opt_state_dict = ptd_state_dict.get_optimizer_state_dict(
             model,
             optim,


@@ -228,9 +228,9 @@ class TestStateDictStager(TestCase):
         # Validate tensor count and bytes
         expected_storage_cnt = 2
-        assert (
-            num_storages == expected_storage_cnt
-        ), f"Expected {expected_storage_cnt} storages, got {num_storages}"
+        assert num_storages == expected_storage_cnt, (
+            f"Expected {expected_storage_cnt} storages, got {num_storages}"
+        )
         # Calculate expected bytes
         # Note: Only unique storages are counted in the byte count
@@ -239,9 +239,9 @@ class TestStateDictStager(TestCase):
             + tensor3.numel()  # tensor1 and tensor2 share storage
             * tensor3.element_size()  # tensor3 and its narrow view share storage
         )
-        assert (
-            num_bytes == expected_bytes
-        ), f"Expected {expected_bytes} bytes, got {num_bytes}"
+        assert num_bytes == expected_bytes, (
+            f"Expected {expected_bytes} bytes, got {num_bytes}"
+        )
         # Verify that the CPU state dict is equivalent to the original CUDA state dict
         result, error = compare_state_dicts(state_dict, cpu_state_dict)
         assert result, f"State dicts are not equivalent: {error}"
@@ -301,9 +301,9 @@ class TestStateDictStager(TestCase):
         # Verify the first result is correct
         result, error = compare_state_dicts(state_dict, cpu_state_dict1)
-        assert (
-            result
-        ), f"First state dict is not equivalent to original: {error}"
+        assert result, (
+            f"First state dict is not equivalent to original: {error}"
+        )
         # Modify the original tensors
         tensor1.fill_(0)
@@ -317,14 +317,14 @@ class TestStateDictStager(TestCase):
         # Verify that the second CPU state dict is equivalent to the modified original state dict
         result, error = compare_state_dicts(state_dict, cpu_state_dict2)
-        assert (
-            result
-        ), f"Second state dict is not equivalent to modified original: {error}"
+        assert result, (
+            f"Second state dict is not equivalent to modified original: {error}"
+        )
         # Verify that the number of cached storages hasn't changed
-        assert (
-            num_storages1 == num_storages2
-        ), f"Storage count changed: {num_storages1} vs {num_storages2}"
+        assert num_storages1 == num_storages2, (
+            f"Storage count changed: {num_storages1} vs {num_storages2}"
+        )
         # Verify that the tensors in the second state dict have the same storage pointers as the first
         assert (
@@ -347,12 +347,12 @@ class TestStateDictStager(TestCase):
         cpu_state_dict3 = stager.stage(state_dict)
         # Verify that the third CPU state dict reflects the updated values
-        assert torch.all(
-            cpu_state_dict3["tensor1"] == 42.0
-        ), "Updated values should be reflected in the cached state dict"
-        assert torch.all(
-            cpu_state_dict3["tensor2"] == 42.0
-        ), "Updated values should be reflected in the cached state dict"
+        assert torch.all(cpu_state_dict3["tensor1"] == 42.0), (
+            "Updated values should be reflected in the cached state dict"
+        )
+        assert torch.all(cpu_state_dict3["tensor2"] == 42.0), (
+            "Updated values should be reflected in the cached state dict"
+        )
     @requires_cuda
     def test_tensor_attrs(self):
@@ -381,24 +381,24 @@ class TestStateDictStager(TestCase):
         cpu_state_dict = stager.stage(state_dict)
         # Verify that tensor attributes are preserved
-        assert hasattr(
-            cpu_state_dict["tensor1"], "a"
-        ), "Tensor attribute 'a' was not preserved"
-        assert (
-            cpu_state_dict["tensor1"].a == 42
-        ), "Tensor attribute 'a' has incorrect value"
-        assert hasattr(
-            cpu_state_dict["tensor1"], "b"
-        ), "Tensor attribute 'b' was not preserved"
-        assert (
-            cpu_state_dict["tensor1"].b == 43
-        ), "Tensor attribute 'b' has incorrect value"
-        assert hasattr(
-            cpu_state_dict["recursive"]["tensor3"], "c"
-        ), "Tensor attribute 'c' was not preserved"
-        assert (
-            cpu_state_dict["recursive"]["tensor3"].c == 44
-        ), "Tensor attribute 'c' has incorrect value"
+        assert hasattr(cpu_state_dict["tensor1"], "a"), (
+            "Tensor attribute 'a' was not preserved"
+        )
+        assert cpu_state_dict["tensor1"].a == 42, (
+            "Tensor attribute 'a' has incorrect value"
+        )
+        assert hasattr(cpu_state_dict["tensor1"], "b"), (
+            "Tensor attribute 'b' was not preserved"
+        )
+        assert cpu_state_dict["tensor1"].b == 43, (
+            "Tensor attribute 'b' has incorrect value"
+        )
+        assert hasattr(cpu_state_dict["recursive"]["tensor3"], "c"), (
+            "Tensor attribute 'c' was not preserved"
+        )
+        assert cpu_state_dict["recursive"]["tensor3"].c == 44, (
+            "Tensor attribute 'c' has incorrect value"
+        )
     @requires_cuda
     def test_different_dtypes(self):


@@ -500,11 +500,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
                 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
             )
-            with mock.patch.object(
-                mpc, "_is_done", return_value=True
-            ), mock.patch.object(mpc, "_pc"), mock.patch.object(
-                mpc._pc, "join", side_effect=[True, False, False, True]
-            ) as mock_join:
+            with (
+                mock.patch.object(mpc, "_is_done", return_value=True),
+                mock.patch.object(mpc, "_pc"),
+                mock.patch.object(
+                    mpc._pc, "join", side_effect=[True, False, False, True]
+                ) as mock_join,
+            ):
                 mpc._poll()
                 self.assertEqual(4, mock_join.call_count)


@@ -56,32 +56,36 @@ class TestDistributedCheckpoint(FSDPTest):
             torch.manual_seed(200)
             new_model = wrap(SkipModel(double_nest=True))
-            with FullyShardedDataParallel.summon_full_params(
-                model
-            ), FullyShardedDataParallel.summon_full_params(new_model):
+            with (
+                FullyShardedDataParallel.summon_full_params(model),
+                FullyShardedDataParallel.summon_full_params(new_model),
+            ):
                 params = list(model.parameters())
                 new_params = list(new_model.parameters())
                 self.assertNotEqual(params, new_params)
             writer = FileSystemWriter(self.temp_dir)
             reader = FileSystemReader(self.temp_dir)
-            with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
-                new_model, state_dict_type
-            ):
+            with (
+                FSDP.state_dict_type(model, state_dict_type),
+                FSDP.state_dict_type(new_model, state_dict_type),
+            ):
                 state_dict = model.state_dict()
                 save(state_dict, writer)
-            with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
-                new_model, state_dict_type
-            ):
+            with (
+                FSDP.state_dict_type(model, state_dict_type),
+                FSDP.state_dict_type(new_model, state_dict_type),
+            ):
                 state_dict = new_model.state_dict()
                 load(state_dict, reader)
                 new_model.load_state_dict(state_dict)
-            with FullyShardedDataParallel.summon_full_params(
-                model
-            ), FullyShardedDataParallel.summon_full_params(new_model):
+            with (
+                FullyShardedDataParallel.summon_full_params(model),
+                FullyShardedDataParallel.summon_full_params(new_model),
+            ):
                 params = list(model.parameters())
                 new_params = list(new_model.parameters())
                 self.assertEqual(params, new_params)


@@ -242,11 +242,10 @@ class TestCommunication(FSDPTest):
         # and if `use_no_sync=False`, we only run `num_iters` iterations
         # outside `no_sync()`
         num_iters = 3
-        with patch(
-            "torch.distributed.all_gather_into_tensor"
-        ) as mock_all_gather, patch(
-            "torch.distributed.reduce_scatter_tensor"
-        ) as mock_reduce_scatter:
+        with (
+            patch("torch.distributed.all_gather_into_tensor") as mock_all_gather,
+            patch("torch.distributed.reduce_scatter_tensor") as mock_reduce_scatter,
+        ):
             def reset_mocks():
                 mock_all_gather.reset_mock()


@@ -379,12 +379,15 @@ class TestHooks(FSDPTest):
             register_pre_backward_hooks_call_count += 1
             return orig_register_pre_backward_hooks(*args, **kwargs)
-        with mock.patch(
-            "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks",
-            _register_pre_backward_hooks_with_count,
-        ), mock.patch(
-            "torch.distributed.fsdp._runtime_utils._register_post_backward_hook"
-        ) as register_post_bwd_mock:
+        with (
+            mock.patch(
+                "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks",
+                _register_pre_backward_hooks_with_count,
+            ),
+            mock.patch(
+                "torch.distributed.fsdp._runtime_utils._register_post_backward_hook"
+            ) as register_post_bwd_mock,
+        ):
             self.assertEqual(register_pre_backward_hooks_call_count, 0)
             self.assertFalse(register_post_bwd_mock.called)
             fsdp_model(*input)


@@ -152,9 +152,9 @@ class TestGradAcc(FSDPTest):
             batches.append(tuple(permute_tensor(t) for t in batch))
         for batch1, batch2 in itertools.combinations(batches, r=2):
             for t1, t2 in zip(batch1, batch2):
-                assert not torch.all(
-                    t1 == t2
-                ), "Check the test to make sure that batches are distinct"
+                assert not torch.all(t1 == t2), (
+                    "Check the test to make sure that batches are distinct"
+                )
         # Concatenate the batches along the given batch dimension
         concat_batch: tuple[torch.Tensor, ...] = tuple(


@@ -121,8 +121,9 @@ class TestFSDPHybridShard(FSDPTest):
     def test_hsdp_save_load_state_dict(self):
         model = MyModel().cuda()
         num_node_devices = torch.cuda.device_count()
-        shard_rank_lists = list(range(0, num_node_devices // 2)), list(
-            range(num_node_devices // 2, num_node_devices)
-        )
+        shard_rank_lists = (
+            list(range(0, num_node_devices // 2)),
+            list(range(num_node_devices // 2, num_node_devices)),
+        )
         shard_groups = (
             dist.new_group(shard_rank_lists[0]),
@@ -171,8 +172,9 @@ class TestFSDPHybridShard(FSDPTest):
     def test_hsdp_sync_module_state(self):
         model = MyModel().cuda()
         num_node_devices = torch.cuda.device_count()
-        shard_rank_lists = list(range(0, num_node_devices // 2)), list(
-            range(num_node_devices // 2, num_node_devices)
-        )
+        shard_rank_lists = (
+            list(range(0, num_node_devices // 2)),
+            list(range(num_node_devices // 2, num_node_devices)),
+        )
         shard_groups = (
             dist.new_group(shard_rank_lists[0]),
@@ -310,8 +312,9 @@ class TestFSDPHybridShard(FSDPTest):
         cntr = Counter()
         patched_allreduce = partial(patched_collective, orig_ar, cntr)
         patched_reduce_scatter = partial(patched_collective, orig_rs, cntr)
-        with patch_allreduce(patched_allreduce), patch_reduce_scatter(
-            patched_reduce_scatter
-        ):
+        with (
+            patch_allreduce(patched_allreduce),
+            patch_reduce_scatter(patched_reduce_scatter),
+        ):
             inp = hsdp_model.get_input(device=torch.cuda.current_device())
             out = hsdp_model(inp[0], inp[1])
@@ -355,9 +358,9 @@ class TestFSDPHybridShard(FSDPTest):
             use_orig_params,
            hsdp_process_groups=hsdp_pgs,
         )
-        assert (
-            hsdp_model._inter_node_pg.size() > 1
-        ), "HSDP model initialized without replication"
+        assert hsdp_model._inter_node_pg.size() > 1, (
+            "HSDP model initialized without replication"
+        )
         fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
         hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2)
         torch.manual_seed(global_pg.rank() + 1)


@@ -766,9 +766,9 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
                 if expect_use_full_prec_in_eval:
                     assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}"
                 else:
-                    assert (
-                        x.dtype == low_prec_dtype
-                    ), f"Expected {low_prec_dtype}, got {x.dtype}"
+                    assert x.dtype == low_prec_dtype, (
+                        f"Expected {low_prec_dtype}, got {x.dtype}"
+                    )
                 return self.a(x)
         mp_config = MixedPrecision(


@@ -91,9 +91,9 @@ class TestTPFSDPIntegration(FSDPTest):
         tensor_parallel_size: int,
     ) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]:
         """ """
-        assert (
-            type(model) is SimpleModel
-        ), "Expects a `SimpleModel` since the sharding cases on the model definition"
+        assert type(model) is SimpleModel, (
+            "Expects a `SimpleModel` since the sharding cases on the model definition"
+        )
         param_name_to_numel = OrderedDict()
         param_name_to_sharding_info = OrderedDict()
         for param_name, param in model.named_parameters():


@@ -654,9 +654,12 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
         losses1 = []
         losses2 = []
         losses = []
-        for _model, _optim in (fsdp_model, optim), (
-            fsdp_model_orig_params,
-            optim_orig_params,
-        ):
+        for _model, _optim in (
+            (fsdp_model, optim),
+            (
+                fsdp_model_orig_params,
+                optim_orig_params,
+            ),
+        ):
             _optim.zero_grad()
             loss1 = _model(*inp1)
@@ -1166,9 +1169,9 @@ class TestFSDPUseOrigParamsFQNs(FSDPTest):
                 clean_tensor_name(tup[0]) for tup in self.named_parameters()
             ]
             params = [tup[1] for tup in self.named_parameters()]
-            assert (
-                param_shapes[0] is not None and param_shapes[1] is not None
-            ), "`param_sizes` should be set"
+            assert param_shapes[0] is not None and param_shapes[1] is not None, (
+                "`param_sizes` should be set"
+            )
             assert_equal_fn(
                 param_names,
                 [


@@ -19,6 +19,7 @@ The script itself is not a test case hence no assertions are made in this script
 see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched()
      - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched()
 """
+
 import argparse
 import torch.distributed as dist


@@ -292,9 +292,9 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
             betas=BETAS,
             eps=EPS,
         )
-        assert (
-            len(o.param_groups) == 2
-        ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+        assert len(o.param_groups) == 2, (
+            f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+        )
         assert len(o.optim.param_groups) == 2, (
             "Expected 2 local optimizer param groups, but got "
             f"{len(o.optim.param_groups)}"
@@ -713,9 +713,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
         LR = 1e-3
         MOMENTUM = 0.99
         REFERENCE_RANK = 0
-        assert (
-            REFERENCE_RANK in subgroup_ranks
-        ), "Reference rank must be in the new process group"
+        assert REFERENCE_RANK in subgroup_ranks, (
+            "Reference rank must be in the new process group"
+        )
         loss_fn = torch.nn.L1Loss().to(device)
         def check(optimizer):
@@ -1165,22 +1165,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
                 # Increased tolerances are needed to pass when using TF32
                 # See: https://github.com/pytorch/pytorch/issues/67764
-                torch.testing.assert_close(
-                    local_loss.cpu(),
-                    ddp_loss.cpu(),
-                    rtol=1e-03,
-                    atol=1e-08,
-                ), "Losses differ between local optimizer and ZeRO"
+                (
+                    torch.testing.assert_close(
+                        local_loss.cpu(),
+                        ddp_loss.cpu(),
+                        rtol=1e-03,
+                        atol=1e-08,
+                    ),
+                    "Losses differ between local optimizer and ZeRO",
+                )
                 for local_p, ddp_p in zip(
                     local_model.parameters(), ddp_model.parameters()
                 ):
-                    torch.testing.assert_close(
-                        local_p.cpu(),
-                        ddp_p.cpu(),
-                        rtol=1e-03,
-                        atol=1e-04,
-                    ), "Models differ after a step"
+                    (
+                        torch.testing.assert_close(
+                            local_p.cpu(),
+                            ddp_p.cpu(),
+                            rtol=1e-03,
+                            atol=1e-04,
+                        ),
+                        "Models differ after a step",
+                    )
     @skipIfHpu
     @skip_if_lt_x_gpu(4)


@@ -89,9 +89,9 @@ class PipeTests(TestCase):
             mb_args=(x, y),
         )
-        assert (
-            pipe.num_stages == EXPECTED_N_STAGES[ModelClass]
-        ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
+        assert pipe.num_stages == EXPECTED_N_STAGES[ModelClass], (
+            f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
+        )
         ref_out = mod(x, y)
         out = pipe(x, y)[0]
@@ -109,9 +109,7 @@ class PipeTests(TestCase):
             new_names.update(stage_fqns)
         if CHECK_FQN_SET_EQUALITY:
-            assert (
-                old_names == new_names
-            ), f"""
+            assert old_names == new_names, f"""
             old names {old_names}
             new names {new_names}
             """


@@ -60,9 +60,9 @@ class UnflattenTests(TestCase):
         for stage_idx in range(pipe.num_stages):
             stage_mod = pipe.get_stage_module(stage_idx)
             for param_name, _ in stage_mod.named_parameters():
-                assert (
-                    param_name in orig_state_dict
-                ), f"{param_name} not in original state dict"
+                assert param_name in orig_state_dict, (
+                    f"{param_name} not in original state dict"
+                )
         print("Param qualname test passed")
         # Check equivalence


@@ -45,9 +45,9 @@ class ShareMemoryRPCPickler(_InternalRPCPickler):
         for t in torch._tensor_classes:
             self._dispatch_table[t] = TorchMpReductions.reduce_tensor
         self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor
-        self._dispatch_table[
-            torch.nn.parameter.Parameter
-        ] = TorchMpReductions.reduce_tensor
+        self._dispatch_table[torch.nn.parameter.Parameter] = (
+            TorchMpReductions.reduce_tensor
+        )
 def worker_loop(a):


@@ -872,9 +872,7 @@ def forward(self, primals_1):
             "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
         ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check(
             "extern_kernels.mm(buf0,"
-        ).run(
-            code
-        )
+        ).run(code)
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_lt_x_gpu(1)


@@ -441,9 +441,11 @@ class DistMathOpsTest(DTensorTestBase):
             out_req_grad: bool
         subtest_fails = {}
-        valid_filter = lambda cfg: not (  # noqa: E731
-            cfg.ln_req_grad and not cfg.elementwise_affine
-        ) and any(cfg[2:])
+        valid_filter = (  # noqa: E731
+            lambda cfg: (
+                not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:])
+            )
+        )
         subtest_cfgs = list(
             filter(
                 valid_filter,
@@ -566,9 +568,9 @@ class DistMathOpsTest(DTensorTestBase):
             except Exception as e:
                 subtest_fails[subtest_cfg] = e
        # if any subtest fails, provide the failed subtests and report the overall failure
-        assert (
-            not subtest_fails
-        ), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
+        assert not subtest_fails, (
+            f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
+        )
     @with_comms
     def test_topk(self):


@@ -26,7 +26,9 @@ def with_xla(func: Callable) -> Callable:
     @wraps(func)  # pyre-ignore[6]
     def wrapper(
-        self, *args: tuple[object], **kwargs: dict[str, Any]  # type: ignore[misc]
+        self,
+        *args: tuple[object],
+        **kwargs: dict[str, Any],  # type: ignore[misc]
     ) -> None:
         # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
         os.environ["XLA_USE_SPMD"] = "1"


@@ -2234,8 +2234,8 @@ class LocalRankTest(MultiProcessTestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()


@@ -796,13 +796,11 @@ class CompileTest(TestCase):
             .check("buf6 = empty")
             # Expect in-place with inductor allocated buf
             .check(
-                "torch.ops._c10d_functional.all_reduce_coalesced_"
-                ".default([buf0, buf1]"
+                "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]"
             )
             # Expect no in-place with graph input (buf5, buf6 are clones)
             .check(
-                "torch.ops._c10d_functional.all_reduce_coalesced_"
-                ".default([buf5, buf6]"
+                "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]"
             )
             .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
             .check("torch.ops._c10d_functional.wait_tensor.default(buf1")


@@ -2705,8 +2705,8 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()


@@ -2869,9 +2869,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             self.assertTrue(t.is_alive())
         if prev_nccl_async_error_handling is not None:
-            os.environ[
-                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-            ] = prev_nccl_async_error_handling
+            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                prev_nccl_async_error_handling
+            )
     @requires_nccl()
     @skip_if_lt_x_gpu(3)
@@ -2931,9 +2931,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
         self._test_barrier_error()
         if prev_nccl_async_error_handling is not None:
-            os.environ[
-                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-            ] = prev_nccl_async_error_handling
+            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                prev_nccl_async_error_handling
+            )
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@@ -2984,9 +2984,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
             process_group.abort()
         if prev_nccl_async_error_handling is not None:
-            os.environ[
-                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-            ] = prev_nccl_async_error_handling
+            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                prev_nccl_async_error_handling
+            )
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@@ -3065,9 +3065,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
            os.remove(new_file_name)
         if prev_nccl_async_error_handling is not None:
-            os.environ[
-                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-            ] = prev_nccl_async_error_handling
+            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                prev_nccl_async_error_handling
+            )
     def _run_invalid_nccl_blocking_wait_env(self, val):
         os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
@@ -3360,9 +3360,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
         # Verify that IntraNodeComm is not used beyond 10MB
-        t = torch.full(
-            (10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16
-        ).cuda()
+        t = torch.full((10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
         c10d.all_reduce(t, c10d.ReduceOp.SUM)
         self.assertTrue(t.eq(expect).all())
         self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
@@ -4249,9 +4247,9 @@ class SparseCollective(MultiProcessTestCase):
 class NCCLTraceTestBase(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        os.environ[
-            "TORCH_NCCL_ENABLE_TIMING"
-        ] = "0"  # see 'timing_enabled' parametrized tests
+        os.environ["TORCH_NCCL_ENABLE_TIMING"] = (
+            "0"  # see 'timing_enabled' parametrized tests
+        )
         os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"
         os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
         self.tempdir = tempfile.TemporaryDirectory()
@@ -5331,8 +5329,8 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()
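The NCCL hunks above illustrate the third recurring pattern: when a subscripted assignment target plus its value no longer fit on one line, ruff format keeps the target intact and parenthesizes the value. A tiny stand-alone sketch, using a plain dict as a stand-in for os.environ:

env = {}  # stand-in for os.environ in this illustration
prev_nccl_async_error_handling = "1"

# Old wrapping broke the subscript target across lines:
#     env[
#         "TORCH_NCCL_ASYNC_ERROR_HANDLING"
#     ] = prev_nccl_async_error_handling

# ruff format parenthesizes the assigned value instead (only when the line
# would otherwise exceed the length limit):
env["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
    prev_nccl_async_error_handling
)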


@@ -1090,8 +1090,8 @@ class UccProcessGroupWithDispatchedCollectivesTests(
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()


@@ -90,9 +90,9 @@ class TestCollectiveUtils(MultiProcessTestCase):
         res = all_gather(data_or_fn=func, pg=pg)
         func.assert_called_once()
-        assert res == list(
-            range(self.world_size)
-        ), f"Expect res to be list of 0 through {self.world_size} (got {res})"
+        assert res == list(range(self.world_size)), (
+            f"Expect res to be list of 0 through {self.world_size} (got {res})"
+        )
     def test_all_gather_result_no_pg(self) -> None:
         """


@@ -207,8 +207,8 @@ class TestCollectives(TestCase):
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
     run_tests()


@@ -490,8 +490,9 @@ class DeviceMeshTestNDim(DTensorTestBase):
         # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
         # and assign the correct shard group to each rank
-        shard_rank_lists = list(range(0, self.world_size // 2)), list(
-            range(self.world_size // 2, self.world_size)
-        )
+        shard_rank_lists = (
+            list(range(0, self.world_size // 2)),
+            list(range(self.world_size // 2, self.world_size)),
+        )
         shard_groups = (
             new_group(shard_rank_lists[0]),


@@ -1830,9 +1830,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
             f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
         ).check(
             f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
-        ).run(
-            GUARDS_FILE.getvalue()
-        )
+        ).run(GUARDS_FILE.getvalue())
         self.assertTrue(same(correct_outputs, outputs))


@@ -519,12 +519,13 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
             out = a2a / a2a.sum(dim=0)
             return out
-        with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size
-        ), torch._dynamo.config.patch(
-            dynamic_shapes=True,
-            capture_dynamic_output_shape_ops=True,
-            capture_scalar_outputs=True,
-        ):
+        with (
+            _dynamo_dist_per_rank_init(self.rank, self.world_size),
+            torch._dynamo.config.patch(
+                dynamic_shapes=True,
+                capture_dynamic_output_shape_ops=True,
+                capture_scalar_outputs=True,
+            ),
+        ):
             row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
             input_split_sizes_tensor = torch.tensor(
@@ -680,15 +681,15 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
                 return torch.ops.custom_ns.foo(a2a)
-        with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size
-        ), torch._dynamo.config.patch(
-            dynamic_shapes=True,
-            capture_dynamic_output_shape_ops=True,
-            capture_scalar_outputs=True,
-        ), torch.library._scoped_library(
-            "custom_ns", "FRAGMENT"
-        ) as lib:
+        with (
+            _dynamo_dist_per_rank_init(self.rank, self.world_size),
+            torch._dynamo.config.patch(
+                dynamic_shapes=True,
+                capture_dynamic_output_shape_ops=True,
+                capture_scalar_outputs=True,
+            ),
+            torch.library._scoped_library("custom_ns", "FRAGMENT") as lib,
+        ):
             lib.define(
                 "alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor"  # noqa: B950
             )


@ -464,8 +464,8 @@ class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
if __name__ == "__main__": if __name__ == "__main__":
assert ( assert not torch.cuda._initialized, (
not torch.cuda._initialized "test_pg_wrapper must not have initialized CUDA context on main process"
), "test_pg_wrapper must not have initialized CUDA context on main process" )
run_tests() run_tests()


@ -372,12 +372,8 @@ class TCPStoreTest(TestCase, StoreTestBase):
# Use noqa to silence flake8. # Use noqa to silence flake8.
# Need to store in an unused variable here to ensure the first # Need to store in an unused variable here to ensure the first
# object is not destroyed before the second object is created. # object is not destroyed before the second object is created.
store1 = dist.TCPStore( store1 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841
addr, port, 1, True, use_libuv=self._use_libuv store2 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841
) # noqa: F841
store2 = dist.TCPStore(
addr, port, 1, True, use_libuv=self._use_libuv
) # noqa: F841
self.assertEqual(store1.libuvBackend, self._use_libuv) self.assertEqual(store1.libuvBackend, self._use_libuv)
self.assertEqual(store2.libuvBackend, self._use_libuv) self.assertEqual(store2.libuvBackend, self._use_libuv)
@ -767,7 +763,7 @@ class RendezvousFileTest(TestCase):
def test_nominal(self): def test_nominal(self):
with tempfile.NamedTemporaryFile(delete=False) as file: with tempfile.NamedTemporaryFile(delete=False) as file:
url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2' url = f"file:///{file.name.replace(os.path.sep, '/')}?world_size=2"
gen0 = dist.rendezvous(url + "&rank=0") gen0 = dist.rendezvous(url + "&rank=0")
store0, rank0, size0 = next(gen0) store0, rank0, size0 = next(gen0)
self.assertEqual(0, rank0) self.assertEqual(0, rank0)
@ -1178,8 +1174,8 @@ class TestClientProtocol(TestCase):
if __name__ == "__main__": if __name__ == "__main__":
assert ( assert not torch.cuda._initialized, (
not torch.cuda._initialized "test_distributed must not have initialized CUDA context on main process"
), "test_distributed must not have initialized CUDA context on main process" )
run_tests() run_tests()


@ -81,9 +81,9 @@ def count_ops(
for node in gm.graph.nodes: for node in gm.graph.nodes:
if match_rng_op(node, op) or node.target == op: if match_rng_op(node, op) or node.target == op:
actual_count += 1 actual_count += 1
assert ( assert actual_count >= freq_ge, (
actual_count >= freq_ge f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." )
return gm return gm


@ -455,9 +455,9 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
# Modify gradient using .data (Dangerous: Breaks autograd tracking!) # Modify gradient using .data (Dangerous: Breaks autograd tracking!)
modified_grad = grad_output.clone() modified_grad = grad_output.clone()
modified_grad.data[ modified_grad.data[input_tensor.data < 0] = (
input_tensor.data < 0 0 # Zero-out gradients for negative inputs
] = 0 # Zero-out gradients for negative inputs )
return modified_grad * 3 return modified_grad * 3


@ -195,10 +195,13 @@ class GraphModule(torch.nn.Module):
def f(x, y): def f(x, y):
return invoke_quant_test(inner, [x, y], scheme="nf4") return invoke_quant_test(inner, [x, y], scheme="nf4")
with mock.patch( with (
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", mock.patch(
True, "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
), torch.no_grad(): True,
),
torch.no_grad(),
):
torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y) torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y)
self.assertEqual(len(bk.graphs), 1) self.assertEqual(len(bk.graphs), 1)
@ -319,10 +322,13 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3, 3, requires_grad=False) x = torch.randn(3, 3, requires_grad=False)
x_clone = x.clone() x_clone = x.clone()
y = torch.randn(3, 3, requires_grad=True) y = torch.randn(3, 3, requires_grad=True)
with mock.patch( with (
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", mock.patch(
True, "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
), torch.no_grad(): True,
),
torch.no_grad(),
):
compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y) compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
self.assertEqual(x, x_clone + 1) self.assertEqual(x, x_clone + 1)
self.assertEqual(compiled_out, x_clone + y + 1) self.assertEqual(compiled_out, x_clone + y + 1)


@ -53,8 +53,8 @@ class BytecodeTests(torch._dynamo.test_case.TestCase):
fn_str = f"""\ fn_str = f"""\
def fn(): def fn():
foo.bar(1, 2, 3) foo.bar(1, 2, 3)
{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))} {str(chr(10)).join(" " * 4 + "x" + str(i) + " = 1" for i in range(1 << 9))}
l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}] l = [{" ".join("x" + str(i) + "," for i in range(1 << 9))}]
""" """
locals = {} locals = {}
exec(fn_str, {}, locals) exec(fn_str, {}, locals)


@ -26,9 +26,10 @@ class CallbackTests(TestCase):
def test_callbacks_with_duplicate_prevention(self) -> None: def test_callbacks_with_duplicate_prevention(self) -> None:
trigger = CallbackTrigger.DYNAMO trigger = CallbackTrigger.DYNAMO
compile_id = CompileId(0, 0) compile_id = CompileId(0, 0)
with callback_handler.install_callbacks( with (
trigger, compile_id callback_handler.install_callbacks(trigger, compile_id),
), callback_handler.install_callbacks(trigger, compile_id): callback_handler.install_callbacks(trigger, compile_id),
):
self._on_compile_start.assert_called_once() self._on_compile_start.assert_called_once()
self._on_compile_end.assert_called_once() self._on_compile_end.assert_called_once()


@ -1252,10 +1252,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
def f(x, y): def f(x, y):
return x + y return x + y
x, y = torch.ones( x, y = (
1, torch.ones(
), torch.zeros( 1,
1, ),
torch.zeros(
1,
),
) )
return f(x, y) return f(x, y)
@ -1289,10 +1292,13 @@ class GraphModule(torch.nn.Module):
def f(x, y): def f(x, y):
return x + y return x + y
x, y = torch.ones( x, y = (
1, torch.ones(
), torch.zeros( 1,
1, ),
torch.zeros(
1,
),
) )
return f(x, y) return f(x, y)
@ -1335,10 +1341,13 @@ class GraphModule(torch.nn.Module):
return inner_fn(x, y) + x return inner_fn(x, y) + x
x, y = torch.ones( x, y = (
1, torch.ones(
), torch.zeros( 1,
1, ),
torch.zeros(
1,
),
) )
return f(x, y) return f(x, y)


@ -19,7 +19,7 @@ from torch.testing._internal.common_utils import (
class CustomException(Exception): class CustomException(Exception):
... pass
class CustomExceptionMeta(type): class CustomExceptionMeta(type):
@ -28,7 +28,7 @@ class CustomExceptionMeta(type):
class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta): class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta):
... pass
class CustomExceptionWithArgs(Exception): class CustomExceptionWithArgs(Exception):
@ -358,7 +358,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
def test_raise_custom_exception(self): def test_raise_custom_exception(self):
class Exc(Exception): class Exc(Exception):
... pass
@torch.compile(backend="eager", fullgraph=True) @torch.compile(backend="eager", fullgraph=True)
def fn(t): def fn(t):
@ -375,7 +375,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
def test_raise_custom_exception_with_args(self): def test_raise_custom_exception_with_args(self):
class Exc(Exception): class Exc(Exception):
... pass
@torch.compile(backend="eager", fullgraph=True) @torch.compile(backend="eager", fullgraph=True)
def fn(t): def fn(t):


@ -324,9 +324,13 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
distributed_state=None, distributed_state=None,
package=None, package=None,
) )
with compile_context(CompileContext(CompileId(0, 0))), tracing( with (
tracer.output.tracing_context compile_context(CompileContext(CompileId(0, 0))),
), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""): tracing(tracer.output.tracing_context),
tracer.set_current_tx(),
get_metrics_context(),
dynamo_timed(""),
):
tracer.run() tracer.run()
check_fn_manager = CheckFunctionManager( check_fn_manager = CheckFunctionManager(


@ -1092,9 +1092,7 @@ not ___dict_contains('bbbbbbbb', G['sys'].modules)
not ___dict_contains('cccccccc', G['sys'].modules) not ___dict_contains('cccccccc', G['sys'].modules)
str(L['x'].device) == 'cpu' str(L['x'].device) == 'cpu'
str(L['x'].dtype) == 'torch.float32' str(L['x'].dtype) == 'torch.float32'
utils_device.CURRENT_DEVICE == None""".split( utils_device.CURRENT_DEVICE == None""".split("\n"):
"\n"
):
self.assertIn( self.assertIn(
line, line,
guard_code_str, guard_code_str,
@ -2806,7 +2804,7 @@ utils_device.CURRENT_DEVICE == None""".split(
"int", "int",
np.intp, np.intp,
np.int32, np.int32,
np.uint8 np.uint8,
# np.dtype('int') # XXX: as above # np.dtype('int') # XXX: as above
] ]
@ -5527,9 +5525,9 @@ utils_device.CURRENT_DEVICE == None""".split(
def forward(self, idx, targets=None): def forward(self, idx, targets=None):
b, t = idx.size() b, t = idx.size()
assert ( assert t <= self.block_size, (
t <= self.block_size "Cannot forward, model block size is exhausted."
), "Cannot forward, model block size is exhausted." )
# forward the GPT model # forward the GPT model
token_embeddings = self.tok_emb( token_embeddings = self.tok_emb(
@ -6075,15 +6073,17 @@ utils_device.CURRENT_DEVICE == None""".split(
def count_graph_break_msgs(msgs): def count_graph_break_msgs(msgs):
return sum("Graph break in user code" in msg for msg in msgs) return sum("Graph break in user code" in msg for msg in msgs)
with self.assertLogs( with (
logger="torch._dynamo", level=logging.DEBUG self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log,
) as log, torch._dynamo.config.patch(verbose=True): torch._dynamo.config.patch(verbose=True),
):
f1(torch.randn(10), torch.randn(10)) f1(torch.randn(10), torch.randn(10))
self.assertGreater(count_graph_break_msgs(log.output), 1) self.assertGreater(count_graph_break_msgs(log.output), 1)
with self.assertLogs( with (
logger="torch._dynamo", level=logging.DEBUG self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log,
) as log, torch._dynamo.config.patch(verbose=False): torch._dynamo.config.patch(verbose=False),
):
g1(torch.randn(10), torch.randn(10)) g1(torch.randn(10), torch.randn(10))
self.assertEqual(count_graph_break_msgs(log.output), 1) self.assertEqual(count_graph_break_msgs(log.output), 1)
@ -8235,8 +8235,9 @@ utils_device.CURRENT_DEVICE == None""".split(
def f(a): def f(a):
return h(a) return h(a)
with warnings.catch_warnings(record=True) as w, self.assertRaises( with (
torch._dynamo.exc.BackendCompilerFailed warnings.catch_warnings(record=True) as w,
self.assertRaises(torch._dynamo.exc.BackendCompilerFailed),
): ):
f(torch.randn(2, 2, requires_grad=True)) f(torch.randn(2, 2, requires_grad=True))
@ -8429,8 +8430,7 @@ utils_device.CURRENT_DEVICE == None""".split(
def test_torch_compile_ctx_on_forward_and_training_step(self): def test_torch_compile_ctx_on_forward_and_training_step(self):
class MyModel(torch.nn.Module): class MyModel(torch.nn.Module):
def forward(self): def forward(self): ...
...
def training_step(self): def training_step(self):
self() self()


@ -2094,11 +2094,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
mod = MockModule() mod = MockModule()
# Each submod is compiled separately and has a different nn module # Each submod is compiled separately and has a different nn module
# guard. Ensure that recompilation logic is handle correctly. # guard. Ensure that recompilation logic is handle correctly.
with unittest.mock.patch( with (
"torch._dynamo.config.error_on_recompile", True unittest.mock.patch("torch._dynamo.config.error_on_recompile", True),
), unittest.mock.patch( unittest.mock.patch(
"torch._dynamo.config.recompile_limit", "torch._dynamo.config.recompile_limit",
recompile_limit, recompile_limit,
),
): ):
x = torch.randn(*size, requires_grad=True) x = torch.randn(*size, requires_grad=True)
mod(x) mod(x)
@ -2160,11 +2161,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
mod = MockModule() mod = MockModule()
# Each submod is compiled separately and has a different nn module # Each submod is compiled separately and has a different nn module
# guard. Ensure that recompilation logic is handle correctly. # guard. Ensure that recompilation logic is handle correctly.
with unittest.mock.patch( with (
"torch._dynamo.config.error_on_recompile", True unittest.mock.patch("torch._dynamo.config.error_on_recompile", True),
), unittest.mock.patch( unittest.mock.patch(
"torch._dynamo.config.recompile_limit", "torch._dynamo.config.recompile_limit",
recompile_limit, recompile_limit,
),
): ):
x = torch.randn(*size, requires_grad=True) x = torch.randn(*size, requires_grad=True)
mod(x) mod(x)


@ -3,6 +3,7 @@
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests) with test_adam in OptimizerTests)
""" """
import functools import functools
import torch import torch


@ -238,9 +238,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
tensor 'x' size mismatch at index 0. expected 11, actual 12 tensor 'x' size mismatch at index 0. expected 11, actual 12
tensor 'x' size mismatch at index 0. expected 10, actual 12 tensor 'x' size mismatch at index 0. expected 10, actual 12
tensor 'x' size mismatch at index 0. expected 9, actual 12 tensor 'x' size mismatch at index 0. expected 9, actual 12
tensor 'x' size mismatch at index 0. expected 8, actual 12""".split( tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
"\n"
):
self.assertIn( self.assertIn(
line, line,
failure_str, failure_str,
@ -276,9 +274,7 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
opt_f([7, 8]) opt_f([7, 8])
for line in """\ for line in """\
len(x) == 3""".split( len(x) == 3""".split("\n"):
"\n"
):
self.assertIn(line, filter_reasons()) self.assertIn(line, filter_reasons())
failure_reasons.clear() failure_reasons.clear()
@ -286,9 +282,7 @@ len(x) == 3""".split(
for line in """\ for line in """\
len(x) == 2 len(x) == 2
len(x) == 3""".split( len(x) == 3""".split("\n"):
"\n"
):
self.assertIn(line, filter_reasons()) self.assertIn(line, filter_reasons())
@torch._dynamo.config.patch(recompile_limit=1) @torch._dynamo.config.patch(recompile_limit=1)


@ -179,9 +179,9 @@ def shapes_to_tensor(x, device=None):
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return torch.as_tensor(x, device=device) return torch.as_tensor(x, device=device)
if torch.jit.is_tracing(): if torch.jit.is_tracing():
assert all( assert all(isinstance(t, torch.Tensor) for t in x), (
isinstance(t, torch.Tensor) for t in x "Shape should be tensor during tracing!"
), "Shape should be tensor during tracing!" )
# as_tensor should not be used in tracing because it records a constant # as_tensor should not be used in tracing because it records a constant
ret = torch.stack(x) ret = torch.stack(x)
if ret.device != device: # avoid recording a hard-coded device if not necessary if ret.device != device: # avoid recording a hard-coded device if not necessary
@ -480,9 +480,9 @@ class PartialT5(torch.nn.Module):
real_seq_length = seq_length real_seq_length = seq_length
if past_key_value is not None: if past_key_value is not None:
assert ( assert len(past_key_value) == 2, (
len(past_key_value) == 2 f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" )
real_seq_length += ( real_seq_length += (
past_key_value[0].shape[2] if query_length is None else query_length past_key_value[0].shape[2] if query_length is None else query_length
) )
@ -4877,9 +4877,9 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
with warnings.catch_warnings(record=True): with warnings.catch_warnings(record=True):
data_len = len(value) data_len = len(value)
if len(self._fields): if len(self._fields):
assert ( assert len(self) == data_len, (
len(self) == data_len f"Adding a field of length {data_len} to a Instances of length {len(self)}"
), f"Adding a field of length {data_len} to a Instances of length {len(self)}" )
self._fields[name] = value self._fields[name] = value
def get(self, name: str) -> Any: def get(self, name: str) -> Any:


@ -108,9 +108,10 @@ def get_view_test_cases():
for requires_grad_1, requires_grad_2 in itertools.product( for requires_grad_1, requires_grad_2 in itertools.product(
[True, False], repeat=2 [True, False], repeat=2
): ):
yield partial( yield (
mk_leaf, base_is_nt, requires_grad_1, requires_grad_2 partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2),
), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}" f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}",
)
# (3) obscure case: # (3) obscure case:
# view is not a leaf (implies requires_grad True) # view is not a leaf (implies requires_grad True)
@ -118,9 +119,10 @@ def get_view_test_cases():
yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure" yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
# Subclass -> Dense # Subclass -> Dense
yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[ yield (
0 lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(),
].clone(), "subclass_dense" "subclass_dense",
)
# Dense -> Subclass -> Dense -> Subclass # Dense -> Subclass -> Dense -> Subclass
def mk_dense_subclass_dense_subclass(): def mk_dense_subclass_dense_subclass():


@ -151,9 +151,9 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject
types.WrapperDescriptorType, types.WrapperDescriptorType,
), ),
) or is_special_functions(obj): ) or is_special_functions(obj):
torch_name_rule_map[ torch_name_rule_map[f"{module.__name__}.{name}"] = (
f"{module.__name__}.{name}" TorchInGraphFunctionVariable
] = TorchInGraphFunctionVariable )
if c_binding_only: if c_binding_only:
if not hasattr(obj, "__code__"): if not hasattr(obj, "__code__"):
c_binding_in_graph_functions.add(obj) c_binding_in_graph_functions.add(obj)
@ -398,12 +398,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
) )
self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST) self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)
with unittest.mock.patch( with (
"torch._dynamo.trace_rules.torch_name_rule_map", unittest.mock.patch(
_torch_name_rule_map, "torch._dynamo.trace_rules.torch_name_rule_map",
), unittest.mock.patch( _torch_name_rule_map,
"torch._dynamo.trace_rules.get_torch_obj_rule_map", ),
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache unittest.mock.patch(
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache
),
): ):
x = torch.rand(3) x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
@ -419,9 +422,9 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
_manual_torch_name_rule_map = manual_torch_name_rule_map.copy() _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
# Force inline `mod.func` by setting trace rule. # Force inline `mod.func` by setting trace rule.
_manual_torch_name_rule_map[ _manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = (
f"{mod.__name__}.{func.__name__}" UserFunctionVariable
] = UserFunctionVariable )
_torch_name_rule_map = [ _torch_name_rule_map = [
_manual_torch_name_rule_map, _manual_torch_name_rule_map,
@ -429,12 +432,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
torch_non_c_binding_in_graph_functions, torch_non_c_binding_in_graph_functions,
] ]
with unittest.mock.patch( with (
"torch._dynamo.trace_rules.torch_name_rule_map", unittest.mock.patch(
_torch_name_rule_map, "torch._dynamo.trace_rules.torch_name_rule_map",
), unittest.mock.patch( _torch_name_rule_map,
"torch._dynamo.trace_rules.get_torch_obj_rule_map", ),
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, unittest.mock.patch(
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
),
): ):
# First adding the module to SKIP_DIRS so that it will be skipped by default. # First adding the module to SKIP_DIRS so that it will be skipped by default.
torch._dynamo.trace_rules.add(mod.__name__) torch._dynamo.trace_rules.add(mod.__name__)


@ -593,9 +593,10 @@ class TestDynamoTimed(TestCase):
) )
compilation_events = [] compilation_events = []
with dynamo_config.patch({"automatic_dynamic_shapes": False}), mock.patch( with (
"torch._dynamo.utils.log_compilation_event" dynamo_config.patch({"automatic_dynamic_shapes": False}),
) as log_event: mock.patch("torch._dynamo.utils.log_compilation_event") as log_event,
):
@torch.compile() @torch.compile()
def f(x): def f(x):


@ -181,9 +181,12 @@ class TestDraftExport(TestCase):
self.assertEqual(len(report.op_profiles), 1) self.assertEqual(len(report.op_profiles), 1)
self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1) self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1)
with torch._library.fake_profile.unsafe_generate_fake_kernels( with (
report.op_profiles torch._library.fake_profile.unsafe_generate_fake_kernels(
), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()): report.op_profiles
),
FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()),
):
torch.ops.mylib.foo8(*new_inp) torch.ops.mylib.foo8(*new_inp)
# Existing registration has been updated to match the new # Existing registration has been updated to match the new


@ -12834,12 +12834,14 @@ def forward(self, x, y):
"y": [Dim("dy")], # y & z incorrect, export is supposed to fail. "y": [Dim("dy")], # y & z incorrect, export is supposed to fail.
"z": [Dim("dz")], # suggested fix should be to match these up. "z": [Dim("dz")], # suggested fix should be to match these up.
} }
with self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize. with (
torch._dynamo.exc.UserError, self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize.
r".*Constraints violated(.*\n)*" torch._dynamo.exc.UserError,
r"Suggested fixes:(.*\n)*" r".*Constraints violated(.*\n)*"
r".*dz = dy(.*\n)*", r"Suggested fixes:(.*\n)*"
) as msg: r".*dz = dy(.*\n)*",
) as msg
):
export( export(
Foo(), Foo(),
inputs, inputs,
@ -13675,8 +13677,7 @@ def forward(self, x):
"""Make sure the metadata is kept after exported program run_decompositions.""" """Make sure the metadata is kept after exported program run_decompositions."""
@torch.library.custom_op("mylib::add", mutates_args=()) @torch.library.custom_op("mylib::add", mutates_args=())
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
...
@torch.library.register_fake("mylib::add") @torch.library.register_fake("mylib::add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


@ -947,10 +947,12 @@ class TestDeserialize(TestCase):
ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3))) ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3)))
serialized_program = ExportedProgramSerializer(None, 2).serialize(ep) serialized_program = ExportedProgramSerializer(None, 2).serialize(ep)
serialized_program.exported_program.graph_module.signature.input_specs[ serialized_program.exported_program.graph_module.signature.input_specs[1] = (
1 schema.InputSpec.create(
] = schema.InputSpec.create( user_input=schema.UserInputSpec(
user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True)) arg=schema.Argument.create(as_none=True)
)
)
) )
ep = ExportedProgramDeserializer(None).deserialize( ep = ExportedProgramDeserializer(None).deserialize(
serialized_program.exported_program, {}, {}, {} serialized_program.exported_program, {}, {}, {}


@ -22,7 +22,9 @@ from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
{"strict": False}, {"strict": False},
{"strict": True}, {"strict": True},
], ],
class_name_func=lambda cls, _, params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}", class_name_func=lambda cls,
_,
params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
) )
class TestSwap(TestCase): class TestSwap(TestCase):
def test_unflatten_preserve_signature(self): def test_unflatten_preserve_signature(self):


@ -348,8 +348,7 @@ def check_fc(existing_schemas):
"\n\t".join(str(s) for s in matching_new_schemas), "\n\t".join(str(s) for s in matching_new_schemas),
) )
log.warning( log.warning(
"Refer to following reasons for failure " "Refer to following reasons for failure to find FC schema:\n[\n%s\n]",
"to find FC schema:\n[\n%s\n]",
"\n\t".join(str(r) for r in possible_failure_reasons), "\n\t".join(str(r) for r in possible_failure_reasons),
) )
broken_ops.append(str(existing_schema)) broken_ops.append(str(existing_schema))


@ -523,15 +523,15 @@ def decorateForModules(decorator, module_classes, device_type=None, dtypes=None)
dtypes=dtypes, dtypes=dtypes,
): ):
name_parts = fn.__qualname__.split(".") name_parts = fn.__qualname__.split(".")
assert ( assert len(name_parts) == 2, (
len(name_parts) == 2 "Decorator only applies to a test function of a test class"
), "Decorator only applies to a test function of a test class" )
test_case_name, base_test_name = name_parts test_case_name, base_test_name = name_parts
for module_cls in module_classes: for module_cls in module_classes:
matching_module_infos = [m for m in module_db if m.module_cls == module_cls] matching_module_infos = [m for m in module_db if m.module_cls == module_cls]
assert ( assert len(matching_module_infos) == 1, (
len(matching_module_infos) == 1 f"Couldn't find single ModuleInfo for {module_cls}"
), f"Couldn't find single ModuleInfo for {module_cls}" )
module_info = matching_module_infos[0] module_info = matching_module_infos[0]
decorators = list(module_info.decorators) decorators = list(module_info.decorators)
new_decorator = DecorateInfo( new_decorator = DecorateInfo(


@ -124,9 +124,7 @@ class TestGraphInfoProvider(TestCase):
) )
def test_recomputable_node_only_graph_with_larger_graph_context(self): def test_recomputable_node_only_graph_with_larger_graph_context(self):
recomputable_node_only_graph_with_larger_graph_context = ( recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context # noqa: B950
self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context
)
expected_nodes = self.all_recomputable_banned_nodes expected_nodes = self.all_recomputable_banned_nodes
# node1 does not have an indirect path to node5 because of node2 # node1 does not have an indirect path to node5 because of node2
# node2 has an indirect path to node5 # node2 has an indirect path to node5


@ -2568,8 +2568,9 @@ def forward(self, primals_1, primals_2):
def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
return torch.ops._test._clone_create_graph(x, x1) return torch.ops._test._clone_create_graph(x, x1)
inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn( inp_x, inp_x1 = (
3, requires_grad=True torch.randn(3, requires_grad=True),
torch.randn(3, requires_grad=True),
) )
ref_x, ref_x1 = inp_x.clone(), inp_x1.clone() ref_x, ref_x1 = inp_x.clone(), inp_x1.clone()
@ -5283,11 +5284,12 @@ def forward(self, arg0_1):
mod = TestMod(fn) mod = TestMod(fn)
inp = torch.randn(2) inp = torch.randn(2)
with patch( with (
"functorch.compile.config.functionalize_rng_ops", True patch("functorch.compile.config.functionalize_rng_ops", True),
), self.assertRaisesRegex( self.assertRaisesRegex(
RuntimeError, RuntimeError,
"Functionalized RNG is not currently supported in the aot_export", "Functionalized RNG is not currently supported in the aot_export",
),
): ):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)


@ -4324,9 +4324,7 @@ class TestExamplesCorrectness(TestCase):
def lennard_jones_force(r): def lennard_jones_force(r):
"""Get magnitude of LJ force""" """Get magnitude of LJ force"""
return -epsilon * ( return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7))
(-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)
)
r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device) r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device)
drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device)) drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device))
@ -4495,8 +4493,9 @@ class TestExamplesCorrectness(TestCase):
# This example mimics what a user might do when trying to find the optimal learning rate. They would # This example mimics what a user might do when trying to find the optimal learning rate. They would
# want to run a bunch of models with the same behavior (including the same dropout!) and have them # want to run a bunch of models with the same behavior (including the same dropout!) and have them
# each run with different learning rates. Specifically, this is an example of using same randomness with vmap # each run with different learning rates. Specifically, this is an example of using same randomness with vmap
points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint( points, labels = (
0, 2, (100,), device=device torch.randn(100, 2, 2, 2, 2, device=device),
torch.randint(0, 2, (100,), device=device),
) )
class MLPClassifier(nn.Module): class MLPClassifier(nn.Module):


@ -208,33 +208,33 @@ def check(f, t, delta, check_val=True, graph_input=False):
old_num_nodes = len(fx_g.graph.nodes) old_num_nodes = len(fx_g.graph.nodes)
new_num_nodes = len(new_graph.nodes) new_num_nodes = len(new_graph.nodes)
if delta == -1: if delta == -1:
assert ( assert old_num_nodes >= new_num_nodes, (
old_num_nodes >= new_num_nodes f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
), f"number of nodes increased {old_num_nodes}, {new_num_nodes}" )
else: else:
assert ( assert old_num_nodes == new_num_nodes + delta, (
old_num_nodes == new_num_nodes + delta f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
), f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}" )
# a second pass should not reduce more nodes # a second pass should not reduce more nodes
pass_2_graph = fx_graph_cse(new_graph) pass_2_graph = fx_graph_cse(new_graph)
pass_2_num_nodes = len(pass_2_graph.nodes) pass_2_num_nodes = len(pass_2_graph.nodes)
assert ( assert pass_2_num_nodes == new_num_nodes, (
pass_2_num_nodes == new_num_nodes f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
), f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}" )
# check correctness # check correctness
if check_val: if check_val:
true_result = fx_g(t) true_result = fx_g(t)
our_result = new_g(t) our_result = new_g(t)
if true_result is None: # both return None if true_result is None: # both return None
assert ( assert our_result is None, (
our_result is None f"true result is None, CSE result is {our_result}"
), f"true result is None, CSE result is {our_result}" )
else: # results returned are the same else: # results returned are the same
assert torch.all( assert torch.all(true_result == our_result), (
true_result == our_result f"results are different {true_result}, {our_result}"
), f"results are different {true_result}, {our_result}" # check results are the same ) # check results are the same
class NoChangeTestCase(TestCase): class NoChangeTestCase(TestCase):


@ -2154,9 +2154,9 @@ class TestOperators(TestCase):
else: else:
weight = torch.randn(weight_shape, device=device) weight = torch.randn(weight_shape, device=device)
target = torch.randint(0, C, target_shape, device=device) target = torch.randint(0, C, target_shape, device=device)
target[ target[0] = (
0 1 # since we're ignoring index 0, at least one element must be non-zero
] = 1 # since we're ignoring index 0, at least one element must be non-zero )
fn = functools.partial( fn = functools.partial(
torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs


@ -24,6 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
""" """
from typing import Any from typing import Any
from unittest import mock from unittest import mock
@ -107,7 +108,7 @@ class TestParsedExpression(TestCase):
ParsedExpression("(a) ((b c) (d ...))") ParsedExpression("(a) ((b c) (d ...))")
# invalid identifiers # invalid identifiers
ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...") ParsedExpression("camelCase under_scored cApiTaLs \u00df ...")
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ParsedExpression("1a") ParsedExpression("1a")
with self.assertRaises(ValueError): with self.assertRaises(ValueError):


@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
""" """
import numpy as np import numpy as np
import torch import torch


@ -1532,7 +1532,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
self._test_unary(op, getter, "cpu") self._test_unary(op, getter, "cpu")
# test in-place # test in-place
method = getattr(Tensor, f'{op.__name__ + "_"}') method = getattr(Tensor, f"{op.__name__ + '_'}")
self._test_unary(method, getter, "cpu", check_propagates_grad=False) self._test_unary(method, getter, "cpu", check_propagates_grad=False)
def test_clone(self): def test_clone(self):


@ -2,6 +2,7 @@ r"""
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not **This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
rely on it for anything!** rely on it for anything!**
""" """
import operator import operator
import sys import sys


@ -46,9 +46,9 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
old_num_nodes = len(fx_g.graph.nodes) old_num_nodes = len(fx_g.graph.nodes)
new_num_nodes = len(new_graph.nodes) new_num_nodes = len(new_graph.nodes)
assert ( assert (new_num_nodes < old_num_nodes) == modified, (
new_num_nodes < old_num_nodes "modified should be True if the number of nodes decrease"
) == modified, "modified should be True if the number of nodes decrease" )
if delta == -1: if delta == -1:
self.assertTrue( self.assertTrue(


@ -1783,8 +1783,9 @@ class TestSingleOperation(unittest.TestCase):
self.assertEqual(s.check(), z3.sat) self.assertEqual(s.check(), z3.sat)
add_result = z3.Const(3, tensor_type) add_result = z3.Const(3, tensor_type)
broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const( broadcast_res1, broadcast_res2 = (
5, tensor_type z3.Const(4, tensor_type),
z3.Const(5, tensor_type),
) )
# print(s.model()) # print(s.model())


@ -251,10 +251,13 @@ class TestInvokeSubgraphCompile(TestCase):
x_clone = x.detach().clone().requires_grad_(True) x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True) y_clone = y.detach().clone().requires_grad_(True)
backend = EagerAndRecordGraphs() backend = EagerAndRecordGraphs()
with mock.patch( with (
"torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", mock.patch(
True, "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
), torch.no_grad(): True,
),
torch.no_grad(),
):
res = torch.compile(fn, backend=backend, fullgraph=True)( res = torch.compile(fn, backend=backend, fullgraph=True)(
mod, x_clone, y_clone mod, x_clone, y_clone
) )
@ -2399,7 +2402,9 @@ class GraphModule(torch.nn.Module):
{"strict": False}, {"strict": False},
{"strict": True}, {"strict": True},
], ],
class_name_func=lambda cls, _, params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}", class_name_func=lambda cls,
_,
params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}",
) )
class TestInvokeSubgraphExport(TestCase): class TestInvokeSubgraphExport(TestCase):
def test_simple_func(self): def test_simple_func(self):


@ -38,7 +38,6 @@ USE_BLACK_FILELIST = re.compile(
# torchgen/** # torchgen/**
# test/** # test/**
# test/[a-h]*/** # test/[a-h]*/**
"test/[a-h]*/**",
# test/[i-j]*/** # test/[i-j]*/**
"test/[i-j]*/**", "test/[i-j]*/**",
# test/[k-m]*/** # test/[k-m]*/**