Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00

[BE][PYFMT] migrate PYFMT for test/[a-h]*/ to ruff format (#144555)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144555
Approved by: https://github.com/ezyang
ghstack dependencies: #144551, #144554
parent e600e044a7
commit 6d5c789ad5

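Every hunk below applies the same mechanical restyling, so the before/after pairs follow a few recurring patterns. As a rough, hedged illustration (the names and values here are invented, not taken from the changed files), the move from the previous black-based formatting to ruff format mostly shows up as two changes: assert messages move inside trailing parentheses, and chained context managers become a parenthesized `with` group:

```python
# Illustrative sketch only: hypothetical values, not code from the diff below.
import contextlib

world_size = 2

# Old style: the asserted condition is wrapped and the message trails the closing paren.
assert (
    world_size >= 2
), f"Assumes world size of at least 2 but got {world_size=}"

# ruff format style: the condition stays on one line and the message is parenthesized.
assert world_size >= 2, (
    f"Assumes world size of at least 2 but got {world_size=}"
)

# Old style: multiple context managers split across the `with` header.
with contextlib.nullcontext(), contextlib.nullcontext():
    pass

# ruff format style: a parenthesized context-manager group (Python 3.10+ syntax).
with (
    contextlib.nullcontext(),
    contextlib.nullcontext(),
):
    pass
```
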
@@ -494,9 +494,7 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
(
emb1,
emb2,
) = nn.Embedding(
10, 3
), nn.Embedding(20, 3)
) = nn.Embedding(10, 3), nn.Embedding(20, 3)
emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)

emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
@@ -627,9 +625,7 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
(
emb1,
emb2,
) = nn.Embedding(
10, 3
), nn.Embedding(20, 3)
) = nn.Embedding(10, 3), nn.Embedding(20, 3)
emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)

emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
@@ -218,12 +218,12 @@ def _sparse_layer_test_helper(
qmodule_to_check = fqn_to_module(qmodel, fqn_to_check)

# check that the modules were converted as expected
assert isinstance(
sqmodule_to_check, sqmodule_expected_converted_class
), "Convert failed"
assert isinstance(
qmodule_to_check, qmodule_expected_converted_class
), "Mapping failed"
assert isinstance(sqmodule_to_check, sqmodule_expected_converted_class), (
"Convert failed"
)
assert isinstance(qmodule_to_check, qmodule_expected_converted_class), (
"Mapping failed"
)

row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[
2:
@@ -1055,9 +1055,9 @@ class TestFPGMPruner(TestCase):
mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
# Check if either of the least-norm filters is not pruned
assert (
mask1.item() is not False or mask2.item() is not False
), "Do not prune all least-norm filters"
assert mask1.item() is not False or mask2.item() is not False, (
"Do not prune all least-norm filters"
)

# fusion step
pruned_model = pruner.prune()
@@ -66,9 +66,10 @@ def generate_callgrind_artifacts() -> None:
json.dump(artifacts, f, indent=4)


def load_callgrind_artifacts() -> (
tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
):
def load_callgrind_artifacts() -> tuple[
benchmark_utils.CallgrindStats,
benchmark_utils.CallgrindStats,
]:
"""Hermetic artifact to unit test Callgrind wrapper.

In addition to collecting counts, this wrapper provides some facilities for
@@ -158,7 +158,8 @@ def compute_functional_name(test_params_dict):
return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "")
else:
raise RuntimeError(
f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}" # noqa: B950
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
f"{pprint.pformat(test_params_dict)}"
)
@@ -179,7 +180,8 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
)
else:
raise RuntimeError(
f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}" # noqa: B950
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
f"{pprint.pformat(test_params_dict)}"
)
@@ -217,7 +219,8 @@ def write_test_to_test_class(
or "cpp_function_call" in test_params_dict
), (
"To enable C++ API parity test, "
f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n" # noqa: B950
"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
f"{pprint.pformat(test_params_dict)}. \n"
"If you are interested in adding the C++ API parity test, please see:\n"
"NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
"If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."
@@ -233,14 +236,16 @@ def write_test_to_test_class(

functional_name = compute_functional_name(test_params_dict)

assert hasattr(
torch.nn.functional, functional_name
), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)" # noqa: B950
assert hasattr(torch.nn.functional, functional_name), (
f"`torch.nn.functional` doesn't have function `{functional_name}`. "
f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
)

functional_full_name = "F::" + functional_name

assert functional_full_name in parity_table["torch::nn::functional"], (
f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. "
f"Please add `{functional_full_name}` entry to `torch::nn::functional` "
"section of `test/cpp_api_parity/parity-tracker.md`. "
f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
)
@@ -78,9 +78,9 @@ def _kernel_fallback(op, *args, **kwargs):
elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
# Only handle inplace ops returning their first arg
assert len(args) >= 1, f"Inplace {op} needs at least one arg"
assert (
len(op._schema.returns) == 1
), f"NYI Inplace {op} with more than one return"
assert len(op._schema.returns) == 1, (
f"NYI Inplace {op} with more than one return"
)
op_name = op.overloadpacket._qualified_op_name
real_res = args[0]
elif any(r.alias_info is not None for r in op._schema.returns):
@@ -67,7 +67,7 @@ def prepare_for_sending(args, kwargs):
def convert(obj):
if type(obj) not in VALID_QUEUE_TYPES_IN:
raise RuntimeError(
f"Cannot send object of type {type(obj)} " "over openreg device pipe."
f"Cannot send object of type {type(obj)} over openreg device pipe."
)

if isinstance(obj, torch.Tensor):
@@ -82,8 +82,7 @@ def receive_after_sending(allocator, args, kwargs):
def convert(obj):
if type(obj) not in VALID_QUEUE_TYPES_OUT:
raise RuntimeError(
f"Received invalid object of type {type(obj)} "
"over openreg device pipe."
f"Received invalid object of type {type(obj)} over openreg device pipe."
)

if isinstance(obj, OpenRegTensorMeta):
@@ -561,8 +561,9 @@ class TestFullyShardPrefetch(FSDPTest):
FSDPParamGroup.post_backward, events
)
# Check the order for normal 1 forward, 1 backward, 1 optimizer step
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
with (
patch_unshard(unshard_with_record),
patch_post_backward(post_backward_with_record),
):
for iter_idx in range(3):
loss = model(inp)
@@ -617,8 +618,9 @@ class TestFullyShardPrefetch(FSDPTest):
FSDPParamGroup.post_backward, events
)
# Check the order for multiple forwards before 1 backward
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
with (
patch_unshard(unshard_with_record),
patch_post_backward(post_backward_with_record),
):
loss1 = model(inp)
loss2 = model(inp)
@@ -703,8 +705,9 @@ class TestFullyShardPrefetch(FSDPTest):
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
with (
patch_unshard(unshard_with_record),
patch_post_backward(post_backward_with_record),
):
loss1, loss2 = model(inp)
expected_events = [
@@ -794,9 +797,11 @@ class TestFullyShardPrefetch(FSDPTest):
("reshard", "", TrainingState.POST_BACKWARD),
("post_backward", "", TrainingState.POST_BACKWARD),
]
with patch_unshard(unshard_with_record), patch_reshard(
reshard_with_record
), patch_post_backward(post_backward_with_record):
with (
patch_unshard(unshard_with_record),
patch_reshard(reshard_with_record),
patch_post_backward(post_backward_with_record),
):
set_forward_prefetch(model, num_to_prefetch=1)
loss = model(inp)
expected_forward_events = [
@@ -882,9 +887,11 @@ class TestFullyShardPrefetch(FSDPTest):
("reshard", "layers.3", TrainingState.FORWARD),
("reshard", "", TrainingState.FORWARD),
]
with patch_unshard(unshard_with_record), patch_reshard(
reshard_with_record
), patch_post_backward(post_backward_with_record):
with (
patch_unshard(unshard_with_record),
patch_reshard(reshard_with_record),
patch_post_backward(post_backward_with_record),
):
set_backward_prefetch(model, num_to_prefetch=1)
loss = model(inp)
self.assertEqual(events, expected_forward_events)
@@ -967,8 +974,9 @@ class TestFullyShardPrefetch(FSDPTest):
(2, model_args.max_seq_len),
device=device_type.type,
)
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
with (
patch_unshard(unshard_with_record),
patch_post_backward(post_backward_with_record),
):
for _ in range(3):
loss = model(inp)
@@ -1046,8 +1054,9 @@ class TestFullyShardPrefetch(FSDPTest):
FSDPParamGroup.post_backward, events
)
inp = torch.randn((2, 16), device=device_type.type)
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
with (
patch_unshard(unshard_with_record),
patch_post_backward(post_backward_with_record),
):
for _ in range(3):
loss = model(inp)
@@ -222,9 +222,7 @@ class TestFullyShardCompile(FSDPTest):
):
unsharded_param_graph_inputs.add(node.args[0])
assert len(unsharded_param_graph_inputs) > 0
assert len(unsharded_param_graph_inputs) == len(
list(model.parameters())
), """\
assert len(unsharded_param_graph_inputs) == len(list(model.parameters())), """\
Expected all model parameters to be wrapped by FSDP2 and
have their unsharded version as graph input, but it's not true!
"""
@@ -237,7 +235,7 @@ have their unsharded version as graph input, but it's not true!
no_aliased_unsharded_params_in_graph_inputs = False
err_msg += f"""\n
Found aliased unsharded param in graph inputs: {aliased_graph_inputs},
val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
"""
self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg)
@@ -466,10 +464,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_compiled_autograd_ctx(self):
self.skipTestForOldSm()
with torch._dynamo.config.patch(
skip_fsdp_hooks=False,
), torch._functorch.config.patch(
recompute_views=True,
with (
torch._dynamo.config.patch(skip_fsdp_hooks=False),
torch._functorch.config.patch(recompute_views=True),
):
inputs = torch.randn(8, 8)
model = torch.nn.Linear(8, 8)
@@ -567,24 +564,28 @@ Unsupported Tensor.backward() call

torch._dynamo.reset()
torch._dynamo.compiled_autograd.reset()
with torch._dynamo.config.patch(
compiled_autograd=True,
compiled_autograd_kwargs_override={
"fullgraph": True,
},
inline_inbuilt_nn_modules=True,
skip_fsdp_hooks=False,
), torch._functorch.config.patch(
enable_autograd_cache=False,
recompute_views=True,
), torch._inductor.config.patch(
force_disable_caches=True,
reorder_for_compute_comm_overlap=True,
reorder_for_compute_comm_overlap_passes=[
"sink_waits",
"raise_comms",
"reorder_compute_for_overlap",
],
with (
torch._dynamo.config.patch(
compiled_autograd=True,
compiled_autograd_kwargs_override={
"fullgraph": True,
},
inline_inbuilt_nn_modules=True,
skip_fsdp_hooks=False,
),
torch._functorch.config.patch(
enable_autograd_cache=False,
recompute_views=True,
),
torch._inductor.config.patch(
force_disable_caches=True,
reorder_for_compute_comm_overlap=True,
reorder_for_compute_comm_overlap_passes=[
"sink_waits",
"raise_comms",
"reorder_compute_for_overlap",
],
),
):
losses_compiled = test_compiled()
losses_eager = test_eager()
@@ -741,20 +742,21 @@ Unsupported Tensor.backward() call
def _test_nested_fully_shard_backend_inductor_fullgraph_True(self):
self.skipTestForOldSm()
for fwd_fullgraph in [True]:
with self._reinplace_all_gather_with_optional_checks(
fwd_fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=(
functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
fwd_copy_count=0,
fwd_resize_count=0,
bwd_copy_count=0,
bwd_resize_count=0,
with (
self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
torch._inductor.config.patch(
post_grad_custom_post_pass=(
functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
fwd_copy_count=0,
fwd_resize_count=0,
bwd_copy_count=0,
bwd_resize_count=0,
)
if fwd_fullgraph
else None
)
if fwd_fullgraph
else None
)
),
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
@@ -943,9 +945,10 @@ Unsupported Tensor.backward() call
for fwd_fullgraph, all_requires_grad in itertools.product(
[True], [True, False]
):
with self._maybe_add_graph_break_to_sdpa(
fwd_fullgraph
), self._reinplace_all_gather_with_optional_checks(fwd_fullgraph):
with (
self._maybe_add_graph_break_to_sdpa(fwd_fullgraph),
self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(
all_requires_grad=all_requires_grad
@@ -982,23 +985,24 @@ Unsupported Tensor.backward() call
log.warning(
f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001, B950
)
with self._reinplace_all_gather_with_optional_checks(
fwd_fullgraph
), torch._inductor.config.patch(
post_grad_custom_post_pass=(
functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
# NOTE: For the root unsharded params, we don't reshard after forward since for training,
# the parameters would be freed and all-gathered immediately. Hence we still have
# their resize and copy ops in the graph.
fwd_copy_count=4,
fwd_resize_count=4,
bwd_copy_count=0,
bwd_resize_count=4,
with (
self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
torch._inductor.config.patch(
post_grad_custom_post_pass=(
functools.partial(
self._check_fsdp_copy_and_resize_ops_count_in_graph,
# NOTE: For the root unsharded params, we don't reshard after forward since for training,
# the parameters would be freed and all-gathered immediately. Hence we still have
# their resize and copy ops in the graph.
fwd_copy_count=4,
fwd_resize_count=4,
bwd_copy_count=0,
bwd_resize_count=4,
)
if fwd_fullgraph
else None
)
if fwd_fullgraph
else None
)
),
):
_, triton_codes = run_and_get_code(
lambda: self._test_traceable_fsdp(
@@ -385,9 +385,9 @@ class TestFullyShardAllGatherExtensionsMultiThread(
only some ranks may require padding, in which case only those ranks
will error out and the all-gather will timeout.
"""
assert (
self.world_size >= 2
), f"Assumes world size of at least 2 but got {self.world_size=}"
assert self.world_size >= 2, (
f"Assumes world size of at least 2 but got {self.world_size=}"
)
model = MLP(dim=3, dim_multiplier=3)
for module in model.modules():
for param_name, param in module.named_parameters(recurse=False):
@@ -115,9 +115,10 @@ class TestFullyShardFrozen(FSDPTest):

torch.manual_seed(42 + self.rank + 1)
device = device_type
with patch_reduce_scatter(
reduce_scatter
), patch_register_post_backward_hook_backward(backward_with_count):
with (
patch_reduce_scatter(reduce_scatter),
patch_register_post_backward_hook_backward(backward_with_count),
):
for iter_idx in range(10):
inp = torch.randn((8, lin_dim), device=device)
losses: list[torch.Tensor] = []
@@ -910,9 +910,9 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
@skip_if_lt_x_gpu(1)
def test_2d_process_group_init(self):
shard_mesh_dim_size = 2
assert (
self.world_size % shard_mesh_dim_size == 0
), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}"
assert self.world_size % shard_mesh_dim_size == 0, (
f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}"
)
replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size
mesh_dim_names = ("replicate", "shard")
ref_mesh = init_device_mesh(
@@ -54,9 +54,9 @@ class TestFullyShardMemory(FSDPTest):
)
):
return # skip since not a common use case
assert (
self.world_size == 2
), f"Requires world size of 2 since some values are hard coded: {self.world_size}"
assert self.world_size == 2, (
f"Requires world size of 2 since some values are hard coded: {self.world_size}"
)
torch.manual_seed(42)
# Pre-run a linear forward (gemm and bias) and backward (gemm) to
# allocate the cuBLAS workspaces before measuring the memory usage
@@ -284,9 +284,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
) # bf16 reduction
param.grad = funcol.all_gather_tensor(
sharded_grad, gather_dim=0, group=group
).to(
param.dtype
) # upcast to fp32
).to(param.dtype) # upcast to fp32
ref_optim.step() # fp32 optimizer step

self.assertEqual(fsdp_loss, ref_loss)
@@ -139,8 +139,9 @@ class TestFullyShardOverlap(FSDPTest):
dist.reduce_scatter_tensor(dummy_rs_output, dummy_rs_input)

def fwd_bwd():
with patch_all_gather(delayed_all_gather), patch_reduce_scatter(
delayed_reduce_scatter
with (
patch_all_gather(delayed_all_gather),
patch_reduce_scatter(delayed_reduce_scatter),
):
loss = model(inp).sum()
loss.backward()
@@ -74,12 +74,12 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
# Check that FSDP moved the inputs to GPU, including recursing
# into the tuple data structure
assert x.device == device, f"Expects {device} but got {x.device}"
assert (
ys[0].device == device
), f"Expects {device} but got {ys[0].device}"
assert (
ys[1].device == device
), f"Expects {device} but got {ys[1].device}"
assert ys[0].device == device, (
f"Expects {device} but got {ys[0].device}"
)
assert ys[1].device == device, (
f"Expects {device} but got {ys[1].device}"
)
y = ys[0] + ys[1]
return x + y + 1
@@ -234,8 +234,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):


if __name__ == "__main__":
assert (
not torch.cuda._initialized
), "test_distributed must not have initialized CUDA context on main process"
assert not torch.cuda._initialized, (
"test_distributed must not have initialized CUDA context on main process"
)

run_tests()
@ -250,9 +250,11 @@ class TestJoin(MultiProcessTestCase):
|
|||
else "Detected at least one rank that exhausted inputs. "
|
||||
"Throwing across all ranks."
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, expected_msg
|
||||
) if throw_on_early_termination else contextlib.nullcontext():
|
||||
with (
|
||||
self.assertRaisesRegex(RuntimeError, expected_msg)
|
||||
if throw_on_early_termination
|
||||
else contextlib.nullcontext()
|
||||
):
|
||||
with Join(
|
||||
allreducers,
|
||||
enable=enable,
|
||||
|
|
|
|||
|
|
@ -677,9 +677,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
|||
fully_shard(layer)
|
||||
fully_shard(model)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
torch.optim.lr_scheduler.LambdaLR(
|
||||
optim, lr_lambda=[lambda epoch: 0.95**epoch]
|
||||
)
|
||||
torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[lambda epoch: 0.95**epoch])
|
||||
opt_state_dict = ptd_state_dict.get_optimizer_state_dict(
|
||||
model,
|
||||
optim,
|
||||
|
|
|
|||
|
|
@ -228,9 +228,9 @@ class TestStateDictStager(TestCase):
|
|||
|
||||
# Validate tensor count and bytes
|
||||
expected_storage_cnt = 2
|
||||
assert (
|
||||
num_storages == expected_storage_cnt
|
||||
), f"Expected {expected_storage_cnt} storages, got {num_storages}"
|
||||
assert num_storages == expected_storage_cnt, (
|
||||
f"Expected {expected_storage_cnt} storages, got {num_storages}"
|
||||
)
|
||||
|
||||
# Calculate expected bytes
|
||||
# Note: Only unique storages are counted in the byte count
|
||||
|
|
@ -239,9 +239,9 @@ class TestStateDictStager(TestCase):
|
|||
+ tensor3.numel() # tensor1 and tensor2 share storage
|
||||
* tensor3.element_size() # tensor3 and its narrow view share storage
|
||||
)
|
||||
assert (
|
||||
num_bytes == expected_bytes
|
||||
), f"Expected {expected_bytes} bytes, got {num_bytes}"
|
||||
assert num_bytes == expected_bytes, (
|
||||
f"Expected {expected_bytes} bytes, got {num_bytes}"
|
||||
)
|
||||
# Verify that the CPU state dict is equivalent to the original CUDA state dict
|
||||
result, error = compare_state_dicts(state_dict, cpu_state_dict)
|
||||
assert result, f"State dicts are not equivalent: {error}"
|
||||
|
|
@ -301,9 +301,9 @@ class TestStateDictStager(TestCase):
|
|||
|
||||
# Verify the first result is correct
|
||||
result, error = compare_state_dicts(state_dict, cpu_state_dict1)
|
||||
assert (
|
||||
result
|
||||
), f"First state dict is not equivalent to original: {error}"
|
||||
assert result, (
|
||||
f"First state dict is not equivalent to original: {error}"
|
||||
)
|
||||
|
||||
# Modify the original tensors
|
||||
tensor1.fill_(0)
|
||||
|
|
@ -317,14 +317,14 @@ class TestStateDictStager(TestCase):
|
|||
|
||||
# Verify that the second CPU state dict is equivalent to the modified original state dict
|
||||
result, error = compare_state_dicts(state_dict, cpu_state_dict2)
|
||||
assert (
|
||||
result
|
||||
), f"Second state dict is not equivalent to modified original: {error}"
|
||||
assert result, (
|
||||
f"Second state dict is not equivalent to modified original: {error}"
|
||||
)
|
||||
|
||||
# Verify that the number of cached storages hasn't changed
|
||||
assert (
|
||||
num_storages1 == num_storages2
|
||||
), f"Storage count changed: {num_storages1} vs {num_storages2}"
|
||||
assert num_storages1 == num_storages2, (
|
||||
f"Storage count changed: {num_storages1} vs {num_storages2}"
|
||||
)
|
||||
|
||||
# Verify that the tensors in the second state dict have the same storage pointers as the first
|
||||
assert (
|
||||
|
|
@ -347,12 +347,12 @@ class TestStateDictStager(TestCase):
|
|||
cpu_state_dict3 = stager.stage(state_dict)
|
||||
|
||||
# Verify that the third CPU state dict reflects the updated values
|
||||
assert torch.all(
|
||||
cpu_state_dict3["tensor1"] == 42.0
|
||||
), "Updated values should be reflected in the cached state dict"
|
||||
assert torch.all(
|
||||
cpu_state_dict3["tensor2"] == 42.0
|
||||
), "Updated values should be reflected in the cached state dict"
|
||||
assert torch.all(cpu_state_dict3["tensor1"] == 42.0), (
|
||||
"Updated values should be reflected in the cached state dict"
|
||||
)
|
||||
assert torch.all(cpu_state_dict3["tensor2"] == 42.0), (
|
||||
"Updated values should be reflected in the cached state dict"
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_tensor_attrs(self):
|
||||
|
|
@ -381,24 +381,24 @@ class TestStateDictStager(TestCase):
|
|||
cpu_state_dict = stager.stage(state_dict)
|
||||
|
||||
# Verify that tensor attributes are preserved
|
||||
assert hasattr(
|
||||
cpu_state_dict["tensor1"], "a"
|
||||
), "Tensor attribute 'a' was not preserved"
|
||||
assert (
|
||||
cpu_state_dict["tensor1"].a == 42
|
||||
), "Tensor attribute 'a' has incorrect value"
|
||||
assert hasattr(
|
||||
cpu_state_dict["tensor1"], "b"
|
||||
), "Tensor attribute 'b' was not preserved"
|
||||
assert (
|
||||
cpu_state_dict["tensor1"].b == 43
|
||||
), "Tensor attribute 'b' has incorrect value"
|
||||
assert hasattr(
|
||||
cpu_state_dict["recursive"]["tensor3"], "c"
|
||||
), "Tensor attribute 'c' was not preserved"
|
||||
assert (
|
||||
cpu_state_dict["recursive"]["tensor3"].c == 44
|
||||
), "Tensor attribute 'c' has incorrect value"
|
||||
assert hasattr(cpu_state_dict["tensor1"], "a"), (
|
||||
"Tensor attribute 'a' was not preserved"
|
||||
)
|
||||
assert cpu_state_dict["tensor1"].a == 42, (
|
||||
"Tensor attribute 'a' has incorrect value"
|
||||
)
|
||||
assert hasattr(cpu_state_dict["tensor1"], "b"), (
|
||||
"Tensor attribute 'b' was not preserved"
|
||||
)
|
||||
assert cpu_state_dict["tensor1"].b == 43, (
|
||||
"Tensor attribute 'b' has incorrect value"
|
||||
)
|
||||
assert hasattr(cpu_state_dict["recursive"]["tensor3"], "c"), (
|
||||
"Tensor attribute 'c' was not preserved"
|
||||
)
|
||||
assert cpu_state_dict["recursive"]["tensor3"].c == 44, (
|
||||
"Tensor attribute 'c' has incorrect value"
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
def test_different_dtypes(self):
|
||||
|
|
|
|||
|
|
@ -500,11 +500,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
|
|||
logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
mpc, "_is_done", return_value=True
|
||||
), mock.patch.object(mpc, "_pc"), mock.patch.object(
|
||||
mpc._pc, "join", side_effect=[True, False, False, True]
|
||||
) as mock_join:
|
||||
with (
|
||||
mock.patch.object(mpc, "_is_done", return_value=True),
|
||||
mock.patch.object(mpc, "_pc"),
|
||||
mock.patch.object(
|
||||
mpc._pc, "join", side_effect=[True, False, False, True]
|
||||
) as mock_join,
|
||||
):
|
||||
mpc._poll()
|
||||
self.assertEqual(4, mock_join.call_count)
|
||||
|
||||
|
|
|
|||
|
|
@ -56,32 +56,36 @@ class TestDistributedCheckpoint(FSDPTest):
|
|||
torch.manual_seed(200)
|
||||
new_model = wrap(SkipModel(double_nest=True))
|
||||
|
||||
with FullyShardedDataParallel.summon_full_params(
|
||||
model
|
||||
), FullyShardedDataParallel.summon_full_params(new_model):
|
||||
with (
|
||||
FullyShardedDataParallel.summon_full_params(model),
|
||||
FullyShardedDataParallel.summon_full_params(new_model),
|
||||
):
|
||||
params = list(model.parameters())
|
||||
new_params = list(new_model.parameters())
|
||||
self.assertNotEqual(params, new_params)
|
||||
|
||||
writer = FileSystemWriter(self.temp_dir)
|
||||
reader = FileSystemReader(self.temp_dir)
|
||||
with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
|
||||
new_model, state_dict_type
|
||||
with (
|
||||
FSDP.state_dict_type(model, state_dict_type),
|
||||
FSDP.state_dict_type(new_model, state_dict_type),
|
||||
):
|
||||
state_dict = model.state_dict()
|
||||
|
||||
save(state_dict, writer)
|
||||
|
||||
with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
|
||||
new_model, state_dict_type
|
||||
with (
|
||||
FSDP.state_dict_type(model, state_dict_type),
|
||||
FSDP.state_dict_type(new_model, state_dict_type),
|
||||
):
|
||||
state_dict = new_model.state_dict()
|
||||
load(state_dict, reader)
|
||||
new_model.load_state_dict(state_dict)
|
||||
|
||||
with FullyShardedDataParallel.summon_full_params(
|
||||
model
|
||||
), FullyShardedDataParallel.summon_full_params(new_model):
|
||||
with (
|
||||
FullyShardedDataParallel.summon_full_params(model),
|
||||
FullyShardedDataParallel.summon_full_params(new_model),
|
||||
):
|
||||
params = list(model.parameters())
|
||||
new_params = list(new_model.parameters())
|
||||
self.assertEqual(params, new_params)
|
||||
|
|
|
|||
|
|
@ -242,11 +242,10 @@ class TestCommunication(FSDPTest):
|
|||
# and if `use_no_sync=False`, we only run `num_iters` iterations
|
||||
# outside `no_sync()`
|
||||
num_iters = 3
|
||||
with patch(
|
||||
"torch.distributed.all_gather_into_tensor"
|
||||
) as mock_all_gather, patch(
|
||||
"torch.distributed.reduce_scatter_tensor"
|
||||
) as mock_reduce_scatter:
|
||||
with (
|
||||
patch("torch.distributed.all_gather_into_tensor") as mock_all_gather,
|
||||
patch("torch.distributed.reduce_scatter_tensor") as mock_reduce_scatter,
|
||||
):
|
||||
|
||||
def reset_mocks():
|
||||
mock_all_gather.reset_mock()
|
||||
|
|
|
|||
|
|
@ -379,12 +379,15 @@ class TestHooks(FSDPTest):
|
|||
register_pre_backward_hooks_call_count += 1
|
||||
return orig_register_pre_backward_hooks(*args, **kwargs)
|
||||
|
||||
with mock.patch(
|
||||
"torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks",
|
||||
_register_pre_backward_hooks_with_count,
|
||||
), mock.patch(
|
||||
"torch.distributed.fsdp._runtime_utils._register_post_backward_hook"
|
||||
) as register_post_bwd_mock:
|
||||
with (
|
||||
mock.patch(
|
||||
"torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks",
|
||||
_register_pre_backward_hooks_with_count,
|
||||
),
|
||||
mock.patch(
|
||||
"torch.distributed.fsdp._runtime_utils._register_post_backward_hook"
|
||||
) as register_post_bwd_mock,
|
||||
):
|
||||
self.assertEqual(register_pre_backward_hooks_call_count, 0)
|
||||
self.assertFalse(register_post_bwd_mock.called)
|
||||
fsdp_model(*input)
|
||||
|
|
|
|||
|
|
@ -152,9 +152,9 @@ class TestGradAcc(FSDPTest):
|
|||
batches.append(tuple(permute_tensor(t) for t in batch))
|
||||
for batch1, batch2 in itertools.combinations(batches, r=2):
|
||||
for t1, t2 in zip(batch1, batch2):
|
||||
assert not torch.all(
|
||||
t1 == t2
|
||||
), "Check the test to make sure that batches are distinct"
|
||||
assert not torch.all(t1 == t2), (
|
||||
"Check the test to make sure that batches are distinct"
|
||||
)
|
||||
|
||||
# Concatenate the batches along the given batch dimension
|
||||
concat_batch: tuple[torch.Tensor, ...] = tuple(
|
||||
|
|
|
|||
|
|
@ -121,8 +121,9 @@ class TestFSDPHybridShard(FSDPTest):
|
|||
def test_hsdp_save_load_state_dict(self):
|
||||
model = MyModel().cuda()
|
||||
num_node_devices = torch.cuda.device_count()
|
||||
shard_rank_lists = list(range(0, num_node_devices // 2)), list(
|
||||
range(num_node_devices // 2, num_node_devices)
|
||||
shard_rank_lists = (
|
||||
list(range(0, num_node_devices // 2)),
|
||||
list(range(num_node_devices // 2, num_node_devices)),
|
||||
)
|
||||
shard_groups = (
|
||||
dist.new_group(shard_rank_lists[0]),
|
||||
|
|
@ -171,8 +172,9 @@ class TestFSDPHybridShard(FSDPTest):
|
|||
def test_hsdp_sync_module_state(self):
|
||||
model = MyModel().cuda()
|
||||
num_node_devices = torch.cuda.device_count()
|
||||
shard_rank_lists = list(range(0, num_node_devices // 2)), list(
|
||||
range(num_node_devices // 2, num_node_devices)
|
||||
shard_rank_lists = (
|
||||
list(range(0, num_node_devices // 2)),
|
||||
list(range(num_node_devices // 2, num_node_devices)),
|
||||
)
|
||||
shard_groups = (
|
||||
dist.new_group(shard_rank_lists[0]),
|
||||
|
|
@ -310,8 +312,9 @@ class TestFSDPHybridShard(FSDPTest):
|
|||
cntr = Counter()
|
||||
patched_allreduce = partial(patched_collective, orig_ar, cntr)
|
||||
patched_reduce_scatter = partial(patched_collective, orig_rs, cntr)
|
||||
with patch_allreduce(patched_allreduce), patch_reduce_scatter(
|
||||
patched_reduce_scatter
|
||||
with (
|
||||
patch_allreduce(patched_allreduce),
|
||||
patch_reduce_scatter(patched_reduce_scatter),
|
||||
):
|
||||
inp = hsdp_model.get_input(device=torch.cuda.current_device())
|
||||
out = hsdp_model(inp[0], inp[1])
|
||||
|
|
@ -355,9 +358,9 @@ class TestFSDPHybridShard(FSDPTest):
|
|||
use_orig_params,
|
||||
hsdp_process_groups=hsdp_pgs,
|
||||
)
|
||||
assert (
|
||||
hsdp_model._inter_node_pg.size() > 1
|
||||
), "HSDP model initialized without replication"
|
||||
assert hsdp_model._inter_node_pg.size() > 1, (
|
||||
"HSDP model initialized without replication"
|
||||
)
|
||||
fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
|
||||
hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2)
|
||||
torch.manual_seed(global_pg.rank() + 1)
|
||||
|
|
|
|||
|
|
@ -766,9 +766,9 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
|
|||
if expect_use_full_prec_in_eval:
|
||||
assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}"
|
||||
else:
|
||||
assert (
|
||||
x.dtype == low_prec_dtype
|
||||
), f"Expected {low_prec_dtype}, got {x.dtype}"
|
||||
assert x.dtype == low_prec_dtype, (
|
||||
f"Expected {low_prec_dtype}, got {x.dtype}"
|
||||
)
|
||||
return self.a(x)
|
||||
|
||||
mp_config = MixedPrecision(
|
||||
|
|
|
|||
|
|
@ -91,9 +91,9 @@ class TestTPFSDPIntegration(FSDPTest):
|
|||
tensor_parallel_size: int,
|
||||
) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]:
|
||||
""" """
|
||||
assert (
|
||||
type(model) is SimpleModel
|
||||
), "Expects a `SimpleModel` since the sharding cases on the model definition"
|
||||
assert type(model) is SimpleModel, (
|
||||
"Expects a `SimpleModel` since the sharding cases on the model definition"
|
||||
)
|
||||
param_name_to_numel = OrderedDict()
|
||||
param_name_to_sharding_info = OrderedDict()
|
||||
for param_name, param in model.named_parameters():
|
||||
|
|
|
|||
|
|
@ -654,9 +654,12 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
|
|||
losses1 = []
|
||||
losses2 = []
|
||||
losses = []
|
||||
for _model, _optim in (fsdp_model, optim), (
|
||||
fsdp_model_orig_params,
|
||||
optim_orig_params,
|
||||
for _model, _optim in (
|
||||
(fsdp_model, optim),
|
||||
(
|
||||
fsdp_model_orig_params,
|
||||
optim_orig_params,
|
||||
),
|
||||
):
|
||||
_optim.zero_grad()
|
||||
loss1 = _model(*inp1)
|
||||
|
|
@ -1166,9 +1169,9 @@ class TestFSDPUseOrigParamsFQNs(FSDPTest):
|
|||
clean_tensor_name(tup[0]) for tup in self.named_parameters()
|
||||
]
|
||||
params = [tup[1] for tup in self.named_parameters()]
|
||||
assert (
|
||||
param_shapes[0] is not None and param_shapes[1] is not None
|
||||
), "`param_sizes` should be set"
|
||||
assert param_shapes[0] is not None and param_shapes[1] is not None, (
|
||||
"`param_sizes` should be set"
|
||||
)
|
||||
assert_equal_fn(
|
||||
param_names,
|
||||
[
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ The script itself is not a test case hence no assertions are made in this script
|
|||
see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched()
|
||||
- test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched()
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch.distributed as dist
|
||||
|
|
|
|||
|
|
@ -292,9 +292,9 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
|
|||
betas=BETAS,
|
||||
eps=EPS,
|
||||
)
|
||||
assert (
|
||||
len(o.param_groups) == 2
|
||||
), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
|
||||
assert len(o.param_groups) == 2, (
|
||||
f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
|
||||
)
|
||||
assert len(o.optim.param_groups) == 2, (
|
||||
"Expected 2 local optimizer param groups, but got "
|
||||
f"{len(o.optim.param_groups)}"
|
||||
|
|
@ -713,9 +713,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
|||
LR = 1e-3
|
||||
MOMENTUM = 0.99
|
||||
REFERENCE_RANK = 0
|
||||
assert (
|
||||
REFERENCE_RANK in subgroup_ranks
|
||||
), "Reference rank must be in the new process group"
|
||||
assert REFERENCE_RANK in subgroup_ranks, (
|
||||
"Reference rank must be in the new process group"
|
||||
)
|
||||
loss_fn = torch.nn.L1Loss().to(device)
|
||||
|
||||
def check(optimizer):
|
||||
|
|
@ -1165,22 +1165,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
|
|||
|
||||
# Increased tolerances are needed to pass when using TF32
|
||||
# See: https://github.com/pytorch/pytorch/issues/67764
|
||||
torch.testing.assert_close(
|
||||
local_loss.cpu(),
|
||||
ddp_loss.cpu(),
|
||||
rtol=1e-03,
|
||||
atol=1e-08,
|
||||
), "Losses differ between local optimizer and ZeRO"
|
||||
(
|
||||
torch.testing.assert_close(
|
||||
local_loss.cpu(),
|
||||
ddp_loss.cpu(),
|
||||
rtol=1e-03,
|
||||
atol=1e-08,
|
||||
),
|
||||
"Losses differ between local optimizer and ZeRO",
|
||||
)
|
||||
|
||||
for local_p, ddp_p in zip(
|
||||
local_model.parameters(), ddp_model.parameters()
|
||||
):
|
||||
torch.testing.assert_close(
|
||||
local_p.cpu(),
|
||||
ddp_p.cpu(),
|
||||
rtol=1e-03,
|
||||
atol=1e-04,
|
||||
), "Models differ after a step"
|
||||
(
|
||||
torch.testing.assert_close(
|
||||
local_p.cpu(),
|
||||
ddp_p.cpu(),
|
||||
rtol=1e-03,
|
||||
atol=1e-04,
|
||||
),
|
||||
"Models differ after a step",
|
||||
)
|
||||
|
||||
@skipIfHpu
|
||||
@skip_if_lt_x_gpu(4)
|
||||
|
|
|
|||
|
|
@ -89,9 +89,9 @@ class PipeTests(TestCase):
|
|||
mb_args=(x, y),
|
||||
)
|
||||
|
||||
assert (
|
||||
pipe.num_stages == EXPECTED_N_STAGES[ModelClass]
|
||||
), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
|
||||
assert pipe.num_stages == EXPECTED_N_STAGES[ModelClass], (
|
||||
f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
|
||||
)
|
||||
|
||||
ref_out = mod(x, y)
|
||||
out = pipe(x, y)[0]
|
||||
|
|
@ -109,9 +109,7 @@ class PipeTests(TestCase):
|
|||
new_names.update(stage_fqns)
|
||||
|
||||
if CHECK_FQN_SET_EQUALITY:
|
||||
assert (
|
||||
old_names == new_names
|
||||
), f"""
|
||||
assert old_names == new_names, f"""
|
||||
old names {old_names}
|
||||
new names {new_names}
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -60,9 +60,9 @@ class UnflattenTests(TestCase):
|
|||
for stage_idx in range(pipe.num_stages):
|
||||
stage_mod = pipe.get_stage_module(stage_idx)
|
||||
for param_name, _ in stage_mod.named_parameters():
|
||||
assert (
|
||||
param_name in orig_state_dict
|
||||
), f"{param_name} not in original state dict"
|
||||
assert param_name in orig_state_dict, (
|
||||
f"{param_name} not in original state dict"
|
||||
)
|
||||
print("Param qualname test passed")
|
||||
|
||||
# Check equivalence
|
||||
|
|
|
|||
|
|
@ -45,9 +45,9 @@ class ShareMemoryRPCPickler(_InternalRPCPickler):
|
|||
for t in torch._tensor_classes:
|
||||
self._dispatch_table[t] = TorchMpReductions.reduce_tensor
|
||||
self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor
|
||||
self._dispatch_table[
|
||||
torch.nn.parameter.Parameter
|
||||
] = TorchMpReductions.reduce_tensor
|
||||
self._dispatch_table[torch.nn.parameter.Parameter] = (
|
||||
TorchMpReductions.reduce_tensor
|
||||
)
|
||||
|
||||
|
||||
def worker_loop(a):
|
||||
|
|
|
|||
|
|
@ -872,9 +872,7 @@ def forward(self, primals_1):
|
|||
"buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
|
||||
).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check(
|
||||
"extern_kernels.mm(buf0,"
|
||||
).run(
|
||||
code
|
||||
)
|
||||
).run(code)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@skip_if_lt_x_gpu(1)
|
||||
|
|
|
|||
|
|
@ -441,9 +441,11 @@ class DistMathOpsTest(DTensorTestBase):
|
|||
out_req_grad: bool
|
||||
|
||||
subtest_fails = {}
|
||||
valid_filter = lambda cfg: not ( # noqa: E731
|
||||
cfg.ln_req_grad and not cfg.elementwise_affine
|
||||
) and any(cfg[2:])
|
||||
valid_filter = ( # noqa: E731
|
||||
lambda cfg: (
|
||||
not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:])
|
||||
)
|
||||
)
|
||||
subtest_cfgs = list(
|
||||
filter(
|
||||
valid_filter,
|
||||
|
|
@ -566,9 +568,9 @@ class DistMathOpsTest(DTensorTestBase):
|
|||
except Exception as e:
|
||||
subtest_fails[subtest_cfg] = e
|
||||
# if any subtest fails, provide the failed subtests and report the overall failure
|
||||
assert (
|
||||
not subtest_fails
|
||||
), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
|
||||
assert not subtest_fails, (
|
||||
f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_topk(self):
|
||||
|
|
|
|||
|
|
@ -26,7 +26,9 @@ def with_xla(func: Callable) -> Callable:
|
|||
|
||||
@wraps(func) # pyre-ignore[6]
|
||||
def wrapper(
|
||||
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
|
||||
self,
|
||||
*args: tuple[object],
|
||||
**kwargs: dict[str, Any], # type: ignore[misc]
|
||||
) -> None:
|
||||
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
|
||||
os.environ["XLA_USE_SPMD"] = "1"
|
||||
|
|
|
|||
|
|
@ -2234,8 +2234,8 @@ class LocalRankTest(MultiProcessTestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
not torch.cuda._initialized
|
||||
), "test_distributed must not have initialized CUDA context on main process"
|
||||
assert not torch.cuda._initialized, (
|
||||
"test_distributed must not have initialized CUDA context on main process"
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -796,13 +796,11 @@ class CompileTest(TestCase):
|
|||
.check("buf6 = empty")
|
||||
# Expect in-place with inductor allocated buf
|
||||
.check(
|
||||
"torch.ops._c10d_functional.all_reduce_coalesced_"
|
||||
".default([buf0, buf1]"
|
||||
"torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]"
|
||||
)
|
||||
# Expect no in-place with graph input (buf5, buf6 are clones)
|
||||
.check(
|
||||
"torch.ops._c10d_functional.all_reduce_coalesced_"
|
||||
".default([buf5, buf6]"
|
||||
"torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]"
|
||||
)
|
||||
.check("torch.ops._c10d_functional.wait_tensor.default(buf0")
|
||||
.check("torch.ops._c10d_functional.wait_tensor.default(buf1")
|
||||
|
|
|
|||
|
|
@ -2705,8 +2705,8 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
not torch.cuda._initialized
|
||||
), "test_distributed must not have initialized CUDA context on main process"
|
||||
assert not torch.cuda._initialized, (
|
||||
"test_distributed must not have initialized CUDA context on main process"
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -2869,9 +2869,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
|||
self.assertTrue(t.is_alive())
|
||||
|
||||
if prev_nccl_async_error_handling is not None:
|
||||
os.environ[
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
|
||||
] = prev_nccl_async_error_handling
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
|
||||
prev_nccl_async_error_handling
|
||||
)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(3)
|
||||
|
|
@ -2931,9 +2931,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
|||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
|
||||
self._test_barrier_error()
|
||||
if prev_nccl_async_error_handling is not None:
|
||||
os.environ[
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
|
||||
] = prev_nccl_async_error_handling
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
|
||||
prev_nccl_async_error_handling
|
||||
)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
|
|
@ -2984,9 +2984,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
|||
process_group.abort()
|
||||
|
||||
if prev_nccl_async_error_handling is not None:
|
||||
os.environ[
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
|
||||
] = prev_nccl_async_error_handling
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
|
||||
prev_nccl_async_error_handling
|
||||
)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
|
|
@ -3065,9 +3065,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
|||
os.remove(new_file_name)
|
||||
|
||||
if prev_nccl_async_error_handling is not None:
|
||||
os.environ[
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
|
||||
] = prev_nccl_async_error_handling
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
|
||||
prev_nccl_async_error_handling
|
||||
)
|
||||
|
||||
def _run_invalid_nccl_blocking_wait_env(self, val):
|
||||
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
|
||||
|
|
@ -3360,9 +3360,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
|
|||
self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
|
||||
|
||||
# Verify that IntraNodeComm is not used beyond 10MB
|
||||
t = torch.full(
|
||||
(10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16
|
||||
).cuda()
|
||||
t = torch.full((10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
|
||||
c10d.all_reduce(t, c10d.ReduceOp.SUM)
|
||||
self.assertTrue(t.eq(expect).all())
|
||||
self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
|
||||
|
|
@ -4249,9 +4247,9 @@ class SparseCollective(MultiProcessTestCase):
|
|||
class NCCLTraceTestBase(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ[
|
||||
"TORCH_NCCL_ENABLE_TIMING"
|
||||
] = "0" # see 'timing_enabled' parametrized tests
|
||||
os.environ["TORCH_NCCL_ENABLE_TIMING"] = (
|
||||
"0" # see 'timing_enabled' parametrized tests
|
||||
)
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"
|
||||
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
|
||||
self.tempdir = tempfile.TemporaryDirectory()
|
||||
|
|
@ -5331,8 +5329,8 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
not torch.cuda._initialized
|
||||
), "test_distributed must not have initialized CUDA context on main process"
|
||||
assert not torch.cuda._initialized, (
|
||||
"test_distributed must not have initialized CUDA context on main process"
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1090,8 +1090,8 @@ class UccProcessGroupWithDispatchedCollectivesTests(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
not torch.cuda._initialized
|
||||
), "test_distributed must not have initialized CUDA context on main process"
|
||||
assert not torch.cuda._initialized, (
|
||||
"test_distributed must not have initialized CUDA context on main process"
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -90,9 +90,9 @@ class TestCollectiveUtils(MultiProcessTestCase):
|
|||
|
||||
res = all_gather(data_or_fn=func, pg=pg)
|
||||
func.assert_called_once()
|
||||
assert res == list(
|
||||
range(self.world_size)
|
||||
), f"Expect res to be list of 0 through {self.world_size} (got {res})"
|
||||
assert res == list(range(self.world_size)), (
|
||||
f"Expect res to be list of 0 through {self.world_size} (got {res})"
|
||||
)
|
||||
|
||||
def test_all_gather_result_no_pg(self) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -207,8 +207,8 @@ class TestCollectives(TestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
not torch.cuda._initialized
|
||||
), "test_distributed must not have initialized CUDA context on main process"
|
||||
assert not torch.cuda._initialized, (
|
||||
"test_distributed must not have initialized CUDA context on main process"
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -490,8 +490,9 @@ class DeviceMeshTestNDim(DTensorTestBase):
|
|||
|
||||
# Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
|
||||
# and assign the correct shard group to each rank
|
||||
shard_rank_lists = list(range(0, self.world_size // 2)), list(
|
||||
range(self.world_size // 2, self.world_size)
|
||||
shard_rank_lists = (
|
||||
list(range(0, self.world_size // 2)),
|
||||
list(range(self.world_size // 2, self.world_size)),
|
||||
)
|
||||
shard_groups = (
|
||||
new_group(shard_rank_lists[0]),
|
||||
|
|
|
|||
|
|
@ -1830,9 +1830,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
|
|||
f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
|
||||
).check(
|
||||
f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
|
||||
).run(
|
||||
GUARDS_FILE.getvalue()
|
||||
)
|
||||
).run(GUARDS_FILE.getvalue())
|
||||
|
||||
self.assertTrue(same(correct_outputs, outputs))
|
||||
|
||||
|
|
|
|||
|
|
@ -519,12 +519,13 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
out = a2a / a2a.sum(dim=0)
|
||||
return out
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size
|
||||
), torch._dynamo.config.patch(
|
||||
dynamic_shapes=True,
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
capture_scalar_outputs=True,
|
||||
with (
|
||||
_dynamo_dist_per_rank_init(self.rank, self.world_size),
|
||||
torch._dynamo.config.patch(
|
||||
dynamic_shapes=True,
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
capture_scalar_outputs=True,
|
||||
),
|
||||
):
|
||||
row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
|
||||
input_split_sizes_tensor = torch.tensor(
|
||||
|
|
@ -680,15 +681,15 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
|||
|
||||
return torch.ops.custom_ns.foo(a2a)
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank, self.world_size
|
||||
), torch._dynamo.config.patch(
|
||||
dynamic_shapes=True,
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
capture_scalar_outputs=True,
|
||||
), torch.library._scoped_library(
|
||||
"custom_ns", "FRAGMENT"
|
||||
) as lib:
|
||||
with (
|
||||
_dynamo_dist_per_rank_init(self.rank, self.world_size),
|
||||
torch._dynamo.config.patch(
|
||||
dynamic_shapes=True,
|
||||
capture_dynamic_output_shape_ops=True,
|
||||
capture_scalar_outputs=True,
|
||||
),
|
||||
torch.library._scoped_library("custom_ns", "FRAGMENT") as lib,
|
||||
):
|
||||
lib.define(
|
||||
"alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor" # noqa: B950
|
||||
)
|
||||
|
|
|
|||
|
|
@ -464,8 +464,8 @@ class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
not torch.cuda._initialized
|
||||
), "test_pg_wrapper must not have initialized CUDA context on main process"
|
||||
assert not torch.cuda._initialized, (
|
||||
"test_pg_wrapper must not have initialized CUDA context on main process"
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -372,12 +372,8 @@ class TCPStoreTest(TestCase, StoreTestBase):
|
|||
# Use noqa to silence flake8.
|
||||
# Need to store in an unused variable here to ensure the first
|
||||
# object is not destroyed before the second object is created.
|
||||
store1 = dist.TCPStore(
|
||||
addr, port, 1, True, use_libuv=self._use_libuv
|
||||
) # noqa: F841
|
||||
store2 = dist.TCPStore(
|
||||
addr, port, 1, True, use_libuv=self._use_libuv
|
||||
) # noqa: F841
|
||||
store1 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841
|
||||
store2 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841
|
||||
self.assertEqual(store1.libuvBackend, self._use_libuv)
|
||||
self.assertEqual(store2.libuvBackend, self._use_libuv)
|
||||
|
||||
|
|
@ -767,7 +763,7 @@ class RendezvousFileTest(TestCase):
|
|||
|
||||
def test_nominal(self):
|
||||
with tempfile.NamedTemporaryFile(delete=False) as file:
|
||||
url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
|
||||
url = f"file:///{file.name.replace(os.path.sep, '/')}?world_size=2"
|
||||
gen0 = dist.rendezvous(url + "&rank=0")
|
||||
store0, rank0, size0 = next(gen0)
|
||||
self.assertEqual(0, rank0)
|
||||
|
|
@ -1178,8 +1174,8 @@ class TestClientProtocol(TestCase):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
not torch.cuda._initialized
|
||||
), "test_distributed must not have initialized CUDA context on main process"
|
||||
assert not torch.cuda._initialized, (
|
||||
"test_distributed must not have initialized CUDA context on main process"
|
||||
)
|
||||
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -81,9 +81,9 @@ def count_ops(
|
|||
for node in gm.graph.nodes:
|
||||
if match_rng_op(node, op) or node.target == op:
|
||||
actual_count += 1
|
||||
assert (
|
||||
actual_count >= freq_ge
|
||||
), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
|
||||
assert actual_count >= freq_ge, (
|
||||
f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
|
||||
)
|
||||
return gm
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -455,9 +455,9 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
# Modify gradient using .data (Dangerous: Breaks autograd tracking!)
|
||||
modified_grad = grad_output.clone()
|
||||
modified_grad.data[
|
||||
input_tensor.data < 0
|
||||
] = 0 # Zero-out gradients for negative inputs
|
||||
modified_grad.data[input_tensor.data < 0] = (
|
||||
0 # Zero-out gradients for negative inputs
|
||||
)
|
||||
|
||||
return modified_grad * 3
|
||||
|
||||
|
|
|
|||
|
|
@ -195,10 +195,13 @@ class GraphModule(torch.nn.Module):
|
|||
def f(x, y):
|
||||
return invoke_quant_test(inner, [x, y], scheme="nf4")
|
||||
|
||||
with mock.patch(
|
||||
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
|
||||
True,
|
||||
), torch.no_grad():
|
||||
with (
|
||||
mock.patch(
|
||||
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
|
||||
True,
|
||||
),
|
||||
torch.no_grad(),
|
||||
):
|
||||
torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y)
|
||||
|
||||
self.assertEqual(len(bk.graphs), 1)
|
||||
|
|
@ -319,10 +322,13 @@ class GraphModule(torch.nn.Module):
|
|||
x = torch.randn(3, 3, requires_grad=False)
|
||||
x_clone = x.clone()
|
||||
y = torch.randn(3, 3, requires_grad=True)
|
||||
with mock.patch(
|
||||
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
|
||||
True,
|
||||
), torch.no_grad():
|
||||
with (
|
||||
mock.patch(
|
||||
"torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
|
||||
True,
|
||||
),
|
||||
torch.no_grad(),
|
||||
):
|
||||
compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
|
||||
self.assertEqual(x, x_clone + 1)
|
||||
self.assertEqual(compiled_out, x_clone + y + 1)
|
||||
|
|
|
|||
|
|
@ -53,8 +53,8 @@ class BytecodeTests(torch._dynamo.test_case.TestCase):
fn_str = f"""\
def fn():
foo.bar(1, 2, 3)
{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}]
{str(chr(10)).join(" " * 4 + "x" + str(i) + " = 1" for i in range(1 << 9))}
l = [{" ".join("x" + str(i) + "," for i in range(1 << 9))}]
"""
locals = {}
exec(fn_str, {}, locals)

@ -26,9 +26,10 @@ class CallbackTests(TestCase):
def test_callbacks_with_duplicate_prevention(self) -> None:
trigger = CallbackTrigger.DYNAMO
compile_id = CompileId(0, 0)
with callback_handler.install_callbacks(
trigger, compile_id
), callback_handler.install_callbacks(trigger, compile_id):
with (
callback_handler.install_callbacks(trigger, compile_id),
callback_handler.install_callbacks(trigger, compile_id),
):
self._on_compile_start.assert_called_once()
self._on_compile_end.assert_called_once()

@ -1252,10 +1252,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
def f(x, y):
return x + y
x, y = torch.ones(
1,
), torch.zeros(
1,
x, y = (
torch.ones(
1,
),
torch.zeros(
1,
),
)
return f(x, y)

@ -1289,10 +1292,13 @@ class GraphModule(torch.nn.Module):
def f(x, y):
return x + y
x, y = torch.ones(
1,
), torch.zeros(
1,
x, y = (
torch.ones(
1,
),
torch.zeros(
1,
),
)
return f(x, y)

@ -1335,10 +1341,13 @@ class GraphModule(torch.nn.Module):
return inner_fn(x, y) + x
x, y = torch.ones(
1,
), torch.zeros(
1,
x, y = (
torch.ones(
1,
),
torch.zeros(
1,
),
)
return f(x, y)

@ -19,7 +19,7 @@ from torch.testing._internal.common_utils import (
class CustomException(Exception):
...
pass
class CustomExceptionMeta(type):

@ -28,7 +28,7 @@ class CustomExceptionMeta(type):
class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta):
...
pass
class CustomExceptionWithArgs(Exception):

@ -358,7 +358,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
def test_raise_custom_exception(self):
class Exc(Exception):
...
pass
@torch.compile(backend="eager", fullgraph=True)
def fn(t):

@ -375,7 +375,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
def test_raise_custom_exception_with_args(self):
class Exc(Exception):
...
pass
@torch.compile(backend="eager", fullgraph=True)
def fn(t):

@ -324,9 +324,13 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
distributed_state=None,
package=None,
)
with compile_context(CompileContext(CompileId(0, 0))), tracing(
tracer.output.tracing_context
), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""):
with (
compile_context(CompileContext(CompileId(0, 0))),
tracing(tracer.output.tracing_context),
tracer.set_current_tx(),
get_metrics_context(),
dynamo_timed(""),
):
tracer.run()
check_fn_manager = CheckFunctionManager(

@ -1092,9 +1092,7 @@ not ___dict_contains('bbbbbbbb', G['sys'].modules)
not ___dict_contains('cccccccc', G['sys'].modules)
str(L['x'].device) == 'cpu'
str(L['x'].dtype) == 'torch.float32'
utils_device.CURRENT_DEVICE == None""".split(
"\n"
):
utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertIn(
line,
guard_code_str,

@ -2806,7 +2804,7 @@ utils_device.CURRENT_DEVICE == None""".split(
"int",
np.intp,
np.int32,
np.uint8
np.uint8,
# np.dtype('int') # XXX: as above
]

@ -5527,9 +5525,9 @@ utils_device.CURRENT_DEVICE == None""".split(
def forward(self, idx, targets=None):
b, t = idx.size()
assert (
t <= self.block_size
), "Cannot forward, model block size is exhausted."
assert t <= self.block_size, (
"Cannot forward, model block size is exhausted."
)
# forward the GPT model
token_embeddings = self.tok_emb(

@ -6075,15 +6073,17 @@ utils_device.CURRENT_DEVICE == None""".split(
def count_graph_break_msgs(msgs):
return sum("Graph break in user code" in msg for msg in msgs)
with self.assertLogs(
logger="torch._dynamo", level=logging.DEBUG
) as log, torch._dynamo.config.patch(verbose=True):
with (
self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log,
torch._dynamo.config.patch(verbose=True),
):
f1(torch.randn(10), torch.randn(10))
self.assertGreater(count_graph_break_msgs(log.output), 1)
with self.assertLogs(
logger="torch._dynamo", level=logging.DEBUG
) as log, torch._dynamo.config.patch(verbose=False):
with (
self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log,
torch._dynamo.config.patch(verbose=False),
):
g1(torch.randn(10), torch.randn(10))
self.assertEqual(count_graph_break_msgs(log.output), 1)

@ -8235,8 +8235,9 @@ utils_device.CURRENT_DEVICE == None""".split(
def f(a):
return h(a)
with warnings.catch_warnings(record=True) as w, self.assertRaises(
torch._dynamo.exc.BackendCompilerFailed
with (
warnings.catch_warnings(record=True) as w,
self.assertRaises(torch._dynamo.exc.BackendCompilerFailed),
):
f(torch.randn(2, 2, requires_grad=True))

@ -8429,8 +8430,7 @@ utils_device.CURRENT_DEVICE == None""".split(
def test_torch_compile_ctx_on_forward_and_training_step(self):
class MyModel(torch.nn.Module):
def forward(self):
...
def forward(self): ...
def training_step(self):
self()

@ -2094,11 +2094,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
mod = MockModule()
# Each submod is compiled separately and has a different nn module
# guard. Ensure that recompilation logic is handle correctly.
with unittest.mock.patch(
"torch._dynamo.config.error_on_recompile", True
), unittest.mock.patch(
"torch._dynamo.config.recompile_limit",
recompile_limit,
with (
unittest.mock.patch("torch._dynamo.config.error_on_recompile", True),
unittest.mock.patch(
"torch._dynamo.config.recompile_limit",
recompile_limit,
),
):
x = torch.randn(*size, requires_grad=True)
mod(x)

@ -2160,11 +2161,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
mod = MockModule()
# Each submod is compiled separately and has a different nn module
# guard. Ensure that recompilation logic is handle correctly.
with unittest.mock.patch(
"torch._dynamo.config.error_on_recompile", True
), unittest.mock.patch(
"torch._dynamo.config.recompile_limit",
recompile_limit,
with (
unittest.mock.patch("torch._dynamo.config.error_on_recompile", True),
unittest.mock.patch(
"torch._dynamo.config.recompile_limit",
recompile_limit,
),
):
x = torch.randn(*size, requires_grad=True)
mod(x)

@ -3,6 +3,7 @@
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests)
"""
import functools
import torch

@ -238,9 +238,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
tensor 'x' size mismatch at index 0. expected 11, actual 12
tensor 'x' size mismatch at index 0. expected 10, actual 12
tensor 'x' size mismatch at index 0. expected 9, actual 12
tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
"\n"
):
tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
self.assertIn(
line,
failure_str,

@ -276,9 +274,7 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
opt_f([7, 8])
for line in """\
len(x) == 3""".split(
"\n"
):
len(x) == 3""".split("\n"):
self.assertIn(line, filter_reasons())
failure_reasons.clear()

@ -286,9 +282,7 @@ len(x) == 3""".split(
for line in """\
len(x) == 2
len(x) == 3""".split(
"\n"
):
len(x) == 3""".split("\n"):
self.assertIn(line, filter_reasons())
@torch._dynamo.config.patch(recompile_limit=1)

@ -179,9 +179,9 @@ def shapes_to_tensor(x, device=None):
if torch.jit.is_scripting():
return torch.as_tensor(x, device=device)
if torch.jit.is_tracing():
assert all(
isinstance(t, torch.Tensor) for t in x
), "Shape should be tensor during tracing!"
assert all(isinstance(t, torch.Tensor) for t in x), (
"Shape should be tensor during tracing!"
)
# as_tensor should not be used in tracing because it records a constant
ret = torch.stack(x)
if ret.device != device: # avoid recording a hard-coded device if not necessary

@ -480,9 +480,9 @@ class PartialT5(torch.nn.Module):
real_seq_length = seq_length
if past_key_value is not None:
assert (
len(past_key_value) == 2
), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
assert len(past_key_value) == 2, (
f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
)
real_seq_length += (
past_key_value[0].shape[2] if query_length is None else query_length
)

@ -4877,9 +4877,9 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
with warnings.catch_warnings(record=True):
data_len = len(value)
if len(self._fields):
assert (
len(self) == data_len
), f"Adding a field of length {data_len} to a Instances of length {len(self)}"
assert len(self) == data_len, (
f"Adding a field of length {data_len} to a Instances of length {len(self)}"
)
self._fields[name] = value
def get(self, name: str) -> Any:

@ -108,9 +108,10 @@ def get_view_test_cases():
for requires_grad_1, requires_grad_2 in itertools.product(
[True, False], repeat=2
):
yield partial(
mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
yield (
partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2),
f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}",
)
# (3) obscure case:
# view is not a leaf (implies requires_grad True)

@ -118,9 +119,10 @@ def get_view_test_cases():
yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
# Subclass -> Dense
yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
0
].clone(), "subclass_dense"
yield (
lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(),
"subclass_dense",
)
# Dense -> Subclass -> Dense -> Subclass
def mk_dense_subclass_dense_subclass():

@ -151,9 +151,9 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject
types.WrapperDescriptorType,
),
) or is_special_functions(obj):
torch_name_rule_map[
f"{module.__name__}.{name}"
] = TorchInGraphFunctionVariable
torch_name_rule_map[f"{module.__name__}.{name}"] = (
TorchInGraphFunctionVariable
)
if c_binding_only:
if not hasattr(obj, "__code__"):
c_binding_in_graph_functions.add(obj)

@ -398,12 +398,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
)
self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)
with unittest.mock.patch(
"torch._dynamo.trace_rules.torch_name_rule_map",
_torch_name_rule_map,
), unittest.mock.patch(
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache
with (
unittest.mock.patch(
"torch._dynamo.trace_rules.torch_name_rule_map",
_torch_name_rule_map,
),
unittest.mock.patch(
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache
),
):
x = torch.rand(3)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)

@ -419,9 +422,9 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
_manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
# Force inline `mod.func` by setting trace rule.
_manual_torch_name_rule_map[
f"{mod.__name__}.{func.__name__}"
] = UserFunctionVariable
_manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = (
UserFunctionVariable
)
_torch_name_rule_map = [
_manual_torch_name_rule_map,

@ -429,12 +432,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
torch_non_c_binding_in_graph_functions,
]
with unittest.mock.patch(
"torch._dynamo.trace_rules.torch_name_rule_map",
_torch_name_rule_map,
), unittest.mock.patch(
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
with (
unittest.mock.patch(
"torch._dynamo.trace_rules.torch_name_rule_map",
_torch_name_rule_map,
),
unittest.mock.patch(
"torch._dynamo.trace_rules.get_torch_obj_rule_map",
torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
),
):
# First adding the module to SKIP_DIRS so that it will be skipped by default.
torch._dynamo.trace_rules.add(mod.__name__)

@ -593,9 +593,10 @@ class TestDynamoTimed(TestCase):
)
compilation_events = []
with dynamo_config.patch({"automatic_dynamic_shapes": False}), mock.patch(
"torch._dynamo.utils.log_compilation_event"
) as log_event:
with (
dynamo_config.patch({"automatic_dynamic_shapes": False}),
mock.patch("torch._dynamo.utils.log_compilation_event") as log_event,
):
@torch.compile()
def f(x):

@ -181,9 +181,12 @@ class TestDraftExport(TestCase):
self.assertEqual(len(report.op_profiles), 1)
self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1)
with torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()):
with (
torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
),
FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()),
):
torch.ops.mylib.foo8(*new_inp)
# Existing registration has been updated to match the new

@ -12834,12 +12834,14 @@ def forward(self, x, y):
"y": [Dim("dy")], # y & z incorrect, export is supposed to fail.
"z": [Dim("dz")], # suggested fix should be to match these up.
}
with self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize.
torch._dynamo.exc.UserError,
r".*Constraints violated(.*\n)*"
r"Suggested fixes:(.*\n)*"
r".*dz = dy(.*\n)*",
) as msg:
with (
self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize.
torch._dynamo.exc.UserError,
r".*Constraints violated(.*\n)*"
r"Suggested fixes:(.*\n)*"
r".*dz = dy(.*\n)*",
) as msg
):
export(
Foo(),
inputs,

@ -13675,8 +13677,7 @@ def forward(self, x):
"""Make sure the metadata is kept after exported program run_decompositions."""
@torch.library.custom_op("mylib::add", mutates_args=())
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
@torch.library.register_fake("mylib::add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

@ -947,10 +947,12 @@ class TestDeserialize(TestCase):
ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3)))
serialized_program = ExportedProgramSerializer(None, 2).serialize(ep)
serialized_program.exported_program.graph_module.signature.input_specs[
1
] = schema.InputSpec.create(
user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True))
serialized_program.exported_program.graph_module.signature.input_specs[1] = (
schema.InputSpec.create(
user_input=schema.UserInputSpec(
arg=schema.Argument.create(as_none=True)
)
)
)
ep = ExportedProgramDeserializer(None).deserialize(
serialized_program.exported_program, {}, {}, {}

@ -22,7 +22,9 @@ from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
{"strict": False},
{"strict": True},
],
class_name_func=lambda cls, _, params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
class_name_func=lambda cls,
_,
params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
)
class TestSwap(TestCase):
def test_unflatten_preserve_signature(self):

@ -348,8 +348,7 @@ def check_fc(existing_schemas):
"\n\t".join(str(s) for s in matching_new_schemas),
)
log.warning(
"Refer to following reasons for failure "
"to find FC schema:\n[\n%s\n]",
"Refer to following reasons for failure to find FC schema:\n[\n%s\n]",
"\n\t".join(str(r) for r in possible_failure_reasons),
)
broken_ops.append(str(existing_schema))

@ -523,15 +523,15 @@ def decorateForModules(decorator, module_classes, device_type=None, dtypes=None)
dtypes=dtypes,
):
name_parts = fn.__qualname__.split(".")
assert (
len(name_parts) == 2
), "Decorator only applies to a test function of a test class"
assert len(name_parts) == 2, (
"Decorator only applies to a test function of a test class"
)
test_case_name, base_test_name = name_parts
for module_cls in module_classes:
matching_module_infos = [m for m in module_db if m.module_cls == module_cls]
assert (
len(matching_module_infos) == 1
), f"Couldn't find single ModuleInfo for {module_cls}"
assert len(matching_module_infos) == 1, (
f"Couldn't find single ModuleInfo for {module_cls}"
)
module_info = matching_module_infos[0]
decorators = list(module_info.decorators)
new_decorator = DecorateInfo(

@ -124,9 +124,7 @@ class TestGraphInfoProvider(TestCase):
)
def test_recomputable_node_only_graph_with_larger_graph_context(self):
recomputable_node_only_graph_with_larger_graph_context = (
self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context
)
recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context # noqa: B950
expected_nodes = self.all_recomputable_banned_nodes
# node1 does not have an indirect path to node5 because of node2
# node2 has an indirect path to node5

@ -2568,8 +2568,9 @@ def forward(self, primals_1, primals_2):
def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
return torch.ops._test._clone_create_graph(x, x1)
inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn(
3, requires_grad=True
inp_x, inp_x1 = (
torch.randn(3, requires_grad=True),
torch.randn(3, requires_grad=True),
)
ref_x, ref_x1 = inp_x.clone(), inp_x1.clone()

@ -5283,11 +5284,12 @@ def forward(self, arg0_1):
mod = TestMod(fn)
inp = torch.randn(2)
with patch(
"functorch.compile.config.functionalize_rng_ops", True
), self.assertRaisesRegex(
RuntimeError,
"Functionalized RNG is not currently supported in the aot_export",
with (
patch("functorch.compile.config.functionalize_rng_ops", True),
self.assertRaisesRegex(
RuntimeError,
"Functionalized RNG is not currently supported in the aot_export",
),
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)

@ -4324,9 +4324,7 @@ class TestExamplesCorrectness(TestCase):
def lennard_jones_force(r):
"""Get magnitude of LJ force"""
return -epsilon * (
(-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)
)
return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7))
r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device)
drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device))

@ -4495,8 +4493,9 @@ class TestExamplesCorrectness(TestCase):
# This example mimics what a user might do when trying to find the optimal learning rate. They would
# want to run a bunch of models with the same behavior (including the same dropout!) and have them
# each run with different learning rates. Specifically, this is an example of using same randomness with vmap
points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint(
0, 2, (100,), device=device
points, labels = (
torch.randn(100, 2, 2, 2, 2, device=device),
torch.randint(0, 2, (100,), device=device),
)
class MLPClassifier(nn.Module):

@ -208,33 +208,33 @@ def check(f, t, delta, check_val=True, graph_input=False):
old_num_nodes = len(fx_g.graph.nodes)
new_num_nodes = len(new_graph.nodes)
if delta == -1:
assert (
old_num_nodes >= new_num_nodes
), f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
assert old_num_nodes >= new_num_nodes, (
f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
)
else:
assert (
old_num_nodes == new_num_nodes + delta
), f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
assert old_num_nodes == new_num_nodes + delta, (
f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
)
# a second pass should not reduce more nodes
pass_2_graph = fx_graph_cse(new_graph)
pass_2_num_nodes = len(pass_2_graph.nodes)
assert (
pass_2_num_nodes == new_num_nodes
), f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
assert pass_2_num_nodes == new_num_nodes, (
f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
)
# check correctness
if check_val:
true_result = fx_g(t)
our_result = new_g(t)
if true_result is None:  # both return None
assert (
our_result is None
), f"true result is None, CSE result is {our_result}"
assert our_result is None, (
f"true result is None, CSE result is {our_result}"
)
else:  # results returned are the same
assert torch.all(
true_result == our_result
), f"results are different {true_result}, {our_result}"  # check results are the same
assert torch.all(true_result == our_result), (
f"results are different {true_result}, {our_result}"
)  # check results are the same
class NoChangeTestCase(TestCase):

@ -2154,9 +2154,9 @@ class TestOperators(TestCase):
else:
weight = torch.randn(weight_shape, device=device)
target = torch.randint(0, C, target_shape, device=device)
target[
0
] = 1  # since we're ignoring index 0, at least one element must be non-zero
target[0] = (
1  # since we're ignoring index 0, at least one element must be non-zero
)
fn = functools.partial(
torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs

@ -24,6 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import Any
from unittest import mock

@ -107,7 +108,7 @@ class TestParsedExpression(TestCase):
ParsedExpression("(a) ((b c) (d ...))")
# invalid identifiers
ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...")
ParsedExpression("camelCase under_scored cApiTaLs \u00df ...")
with self.assertRaises(ValueError):
ParsedExpression("1a")
with self.assertRaises(ValueError):

@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import numpy as np
import torch

@ -1532,7 +1532,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
self._test_unary(op, getter, "cpu")
# test in-place
method = getattr(Tensor, f'{op.__name__ + "_"}')
method = getattr(Tensor, f"{op.__name__ + '_'}")
self._test_unary(method, getter, "cpu", check_propagates_grad=False)
def test_clone(self):

@ -2,6 +2,7 @@ r"""
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
rely on it for anything!**
"""
import operator
import sys

@ -46,9 +46,9 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
old_num_nodes = len(fx_g.graph.nodes)
new_num_nodes = len(new_graph.nodes)
assert (
new_num_nodes < old_num_nodes
) == modified, "modified should be True if the number of nodes decrease"
assert (new_num_nodes < old_num_nodes) == modified, (
"modified should be True if the number of nodes decrease"
)
if delta == -1:
self.assertTrue(

@ -1783,8 +1783,9 @@ class TestSingleOperation(unittest.TestCase):
self.assertEqual(s.check(), z3.sat)
add_result = z3.Const(3, tensor_type)
broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const(
5, tensor_type
broadcast_res1, broadcast_res2 = (
z3.Const(4, tensor_type),
z3.Const(5, tensor_type),
)
# print(s.model())

@ -251,10 +251,13 @@ class TestInvokeSubgraphCompile(TestCase):
x_clone = x.detach().clone().requires_grad_(True)
y_clone = y.detach().clone().requires_grad_(True)
backend = EagerAndRecordGraphs()
with mock.patch(
"torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
True,
), torch.no_grad():
with (
mock.patch(
"torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
True,
),
torch.no_grad(),
):
res = torch.compile(fn, backend=backend, fullgraph=True)(
mod, x_clone, y_clone
)

@ -2399,7 +2402,9 @@ class GraphModule(torch.nn.Module):
{"strict": False},
{"strict": True},
],
class_name_func=lambda cls, _, params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}",
class_name_func=lambda cls,
_,
params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}",
)
class TestInvokeSubgraphExport(TestCase):
def test_simple_func(self):

@ -38,7 +38,6 @@ USE_BLACK_FILELIST = re.compile(
# torchgen/**
# test/**
# test/[a-h]*/**
"test/[a-h]*/**",
# test/[i-j]*/**
"test/[i-j]*/**",
# test/[k-m]*/**