[BE][PYFMT] migrate PYFMT for test/[a-h]*/ to ruff format (#144555)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144555
Approved by: https://github.com/ezyang
ghstack dependencies: #144551, #144554

This commit is contained in: parent e600e044a7, commit 6d5c789ad5

@@ -494,9 +494,7 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase):
         (
             emb1,
             emb2,
-        ) = nn.Embedding(
-            10, 3
-        ), nn.Embedding(20, 3)
+        ) = nn.Embedding(10, 3), nn.Embedding(20, 3)
         emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
 
         emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)

@@ -627,9 +625,7 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase):
         (
             emb1,
             emb2,
-        ) = nn.Embedding(
-            10, 3
-        ), nn.Embedding(20, 3)
+        ) = nn.Embedding(10, 3), nn.Embedding(20, 3)
         emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3)
 
         emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3)
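The two hunks above are the first of the pattern changes in this commit: where black had exploded a right-hand side across several lines, ruff format joins it back onto one line when it fits. A minimal runnable sketch of the before/after (assuming only that torch is installed):

    import torch.nn as nn

    # black (old): the value was split over three lines
    #     ) = nn.Embedding(
    #         10, 3
    #     ), nn.Embedding(20, 3)
    # ruff format (new): the target list keeps its parentheses, the value fits on one line
    (
        emb1,
        emb2,
    ) = nn.Embedding(10, 3), nn.Embedding(20, 3)
    print(emb1.weight.shape, emb2.weight.shape)  # torch.Size([10, 3]) torch.Size([20, 3])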
@@ -218,12 +218,12 @@ def _sparse_layer_test_helper(
         qmodule_to_check = fqn_to_module(qmodel, fqn_to_check)
 
         # check that the modules were converted as expected
-        assert isinstance(
-            sqmodule_to_check, sqmodule_expected_converted_class
-        ), "Convert failed"
-        assert isinstance(
-            qmodule_to_check, qmodule_expected_converted_class
-        ), "Mapping failed"
+        assert isinstance(sqmodule_to_check, sqmodule_expected_converted_class), (
+            "Convert failed"
+        )
+        assert isinstance(qmodule_to_check, qmodule_expected_converted_class), (
+            "Mapping failed"
+        )
 
         row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[
             2:

@@ -1055,9 +1055,9 @@ class TestFPGMPruner(TestCase):
         mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
         mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
         # Check if either of the least-norm filters is not pruned
-        assert (
-            mask1.item() is not False or mask2.item() is not False
-        ), "Do not prune all least-norm filters"
+        assert mask1.item() is not False or mask2.item() is not False, (
+            "Do not prune all least-norm filters"
+        )
 
         # fusion step
         pruned_model = pruner.prune()
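These two hunks show the most frequent rewrite in the whole diff: for a long assert, black parenthesized the condition and left the message dangling after the closing parenthesis, while ruff format keeps the condition on the `assert` line and parenthesizes the message instead. A self-contained sketch (plain booleans stand in for the mask tensors):

    mask1, mask2 = True, False

    # black (old):
    #     assert (
    #         mask1 is not False or mask2 is not False
    #     ), "Do not prune all least-norm filters"
    # ruff format (new): condition first, message wrapped in parentheses
    assert mask1 is not False or mask2 is not False, (
        "Do not prune all least-norm filters"
    )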
@@ -66,9 +66,10 @@ def generate_callgrind_artifacts() -> None:
         json.dump(artifacts, f, indent=4)
 
 
-def load_callgrind_artifacts() -> (
-    tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats]
-):
+def load_callgrind_artifacts() -> tuple[
+    benchmark_utils.CallgrindStats,
+    benchmark_utils.CallgrindStats,
+]:
     """Hermetic artifact to unit test Callgrind wrapper.
 
     In addition to collecting counts, this wrapper provides some facilities for
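For a return annotation that overflows the line, ruff format splits inside the subscripted generic rather than wrapping the whole annotation in parentheses. A runnable sketch with a stand-in class, since benchmark_utils.CallgrindStats lives inside the test harness:

    class CallgrindStats:  # stand-in for benchmark_utils.CallgrindStats
        pass


    def load_artifacts() -> tuple[
        CallgrindStats,
        CallgrindStats,
    ]:
        # the split annotation parses identically to tuple[CallgrindStats, CallgrindStats]
        return CallgrindStats(), CallgrindStats()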
@@ -158,7 +158,8 @@ def compute_functional_name(test_params_dict):
        return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "")
    else:
        raise RuntimeError(
-            f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}"  # noqa: B950
+            "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+            f"{pprint.pformat(test_params_dict)}"
        )
 
 

@@ -179,7 +180,8 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name):
        )
    else:
        raise RuntimeError(
-            f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}"  # noqa: B950
+            "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+            f"{pprint.pformat(test_params_dict)}"
        )
 
 

@@ -217,7 +219,8 @@ def write_test_to_test_class(
        or "cpp_function_call" in test_params_dict
    ), (
        "To enable C++ API parity test, "
-        f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n"  # noqa: B950
+        "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
+        f"{pprint.pformat(test_params_dict)}. \n"
        "If you are interested in adding the C++ API parity test, please see:\n"
        "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n"
        "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this."

@@ -233,14 +236,16 @@ def write_test_to_test_class(
 
    functional_name = compute_functional_name(test_params_dict)
 
-    assert hasattr(
-        torch.nn.functional, functional_name
-    ), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)"  # noqa: B950
+    assert hasattr(torch.nn.functional, functional_name), (
+        f"`torch.nn.functional` doesn't have function `{functional_name}`. "
+        f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
+    )
 
    functional_full_name = "F::" + functional_name
 
    assert functional_full_name in parity_table["torch::nn::functional"], (
-        f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. "
+        f"Please add `{functional_full_name}` entry to `torch::nn::functional` "
+        "section of `test/cpp_api_parity/parity-tracker.md`. "
        f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)"
    )
 
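The four hunks above replace single over-length strings (previously silenced with `# noqa: B950`) by two adjacent literals, which the compiler concatenates into one constant; only the fragment that interpolates a value keeps its `f` prefix. A minimal sketch:

    import pprint

    test_params_dict = {"cpp_options_args": "F::relu"}  # hypothetical contents
    msg = (
        "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n"
        f"{pprint.pformat(test_params_dict)}"
    )
    print(msg)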
@@ -78,9 +78,9 @@ def _kernel_fallback(op, *args, **kwargs):
    elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
        # Only handle inplace ops returning their first arg
        assert len(args) >= 1, f"Inplace {op} needs at least one arg"
-        assert (
-            len(op._schema.returns) == 1
-        ), f"NYI Inplace {op} with more than one return"
+        assert len(op._schema.returns) == 1, (
+            f"NYI Inplace {op} with more than one return"
+        )
        op_name = op.overloadpacket._qualified_op_name
        real_res = args[0]
    elif any(r.alias_info is not None for r in op._schema.returns):

@@ -67,7 +67,7 @@ def prepare_for_sending(args, kwargs):
    def convert(obj):
        if type(obj) not in VALID_QUEUE_TYPES_IN:
            raise RuntimeError(
-                f"Cannot send object of type {type(obj)} " "over openreg device pipe."
+                f"Cannot send object of type {type(obj)} over openreg device pipe."
            )
 
        if isinstance(obj, torch.Tensor):

@@ -82,8 +82,7 @@ def receive_after_sending(allocator, args, kwargs):
    def convert(obj):
        if type(obj) not in VALID_QUEUE_TYPES_OUT:
            raise RuntimeError(
-                f"Received invalid object of type {type(obj)} "
-                "over openreg device pipe."
+                f"Received invalid object of type {type(obj)} over openreg device pipe."
            )
 
        if isinstance(obj, OpenRegTensorMeta):
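The two openreg hunks go the other way: when the concatenated result fits on one line, ruff format merges adjacent literals that black had kept split. Both spellings produce the same constant, as this sketch checks:

    obj = 3.14
    split = f"Cannot send object of type {type(obj)} " "over openreg device pipe."
    merged = f"Cannot send object of type {type(obj)} over openreg device pipe."
    assert split == merged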
@@ -561,8 +561,9 @@ class TestFullyShardPrefetch(FSDPTest):
            FSDPParamGroup.post_backward, events
        )
        # Check the order for normal 1 forward, 1 backward, 1 optimizer step
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
        ):
            for iter_idx in range(3):
                loss = model(inp)
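This hunk introduces the other big pattern of the commit: multi-manager `with` statements rewritten into the parenthesized form that Python formally supports from 3.10. A runnable sketch using contextlib stand-ins for the FSDP patch helpers:

    import contextlib


    @contextlib.contextmanager
    def patch_unshard(events):  # stand-in for the test helper
        events.append("unshard")
        yield


    @contextlib.contextmanager
    def patch_post_backward(events):  # stand-in for the test helper
        events.append("post_backward")
        yield


    events = []
    # old: with patch_unshard(events), patch_post_backward(events):
    with (
        patch_unshard(events),
        patch_post_backward(events),
    ):
        pass
    print(events)  # ['unshard', 'post_backward']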
@@ -617,8 +618,9 @@ class TestFullyShardPrefetch(FSDPTest):
            FSDPParamGroup.post_backward, events
        )
        # Check the order for multiple forwards before 1 backward
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
        ):
            loss1 = model(inp)
            loss2 = model(inp)

@@ -703,8 +705,9 @@ class TestFullyShardPrefetch(FSDPTest):
        post_backward_with_record = self._get_post_backward_with_record(
            FSDPParamGroup.post_backward, events
        )
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
        ):
            loss1, loss2 = model(inp)
            expected_events = [

@@ -794,9 +797,11 @@ class TestFullyShardPrefetch(FSDPTest):
            ("reshard", "", TrainingState.POST_BACKWARD),
            ("post_backward", "", TrainingState.POST_BACKWARD),
        ]
-        with patch_unshard(unshard_with_record), patch_reshard(
-            reshard_with_record
-        ), patch_post_backward(post_backward_with_record):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_reshard(reshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
            set_forward_prefetch(model, num_to_prefetch=1)
            loss = model(inp)
            expected_forward_events = [

@@ -882,9 +887,11 @@ class TestFullyShardPrefetch(FSDPTest):
            ("reshard", "layers.3", TrainingState.FORWARD),
            ("reshard", "", TrainingState.FORWARD),
        ]
-        with patch_unshard(unshard_with_record), patch_reshard(
-            reshard_with_record
-        ), patch_post_backward(post_backward_with_record):
+        with (
+            patch_unshard(unshard_with_record),
+            patch_reshard(reshard_with_record),
+            patch_post_backward(post_backward_with_record),
+        ):
            set_backward_prefetch(model, num_to_prefetch=1)
            loss = model(inp)
            self.assertEqual(events, expected_forward_events)

@@ -967,8 +974,9 @@ class TestFullyShardPrefetch(FSDPTest):
            (2, model_args.max_seq_len),
            device=device_type.type,
        )
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
        ):
            for _ in range(3):
                loss = model(inp)

@@ -1046,8 +1054,9 @@ class TestFullyShardPrefetch(FSDPTest):
            FSDPParamGroup.post_backward, events
        )
        inp = torch.randn((2, 16), device=device_type.type)
-        with patch_unshard(unshard_with_record), patch_post_backward(
-            post_backward_with_record
+        with (
+            patch_unshard(unshard_with_record),
+            patch_post_backward(post_backward_with_record),
        ):
            for _ in range(3):
                loss = model(inp)

@@ -222,9 +222,7 @@ class TestFullyShardCompile(FSDPTest):
            ):
                unsharded_param_graph_inputs.add(node.args[0])
        assert len(unsharded_param_graph_inputs) > 0
-        assert len(unsharded_param_graph_inputs) == len(
-            list(model.parameters())
-        ), """\
+        assert len(unsharded_param_graph_inputs) == len(list(model.parameters())), """\
Expected all model parameters to be wrapped by FSDP2 and
have their unsharded version as graph input, but it's not true!
"""

@@ -237,7 +235,7 @@ have their unsharded version as graph input, but it's not true!
                no_aliased_unsharded_params_in_graph_inputs = False
                err_msg += f"""\n
Found aliased unsharded param in graph inputs: {aliased_graph_inputs},
-val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
+val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]},
"""
        self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg)
 
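The one-character change above is quote normalization: inside a triple-quoted f-string, a nested subscript string may use either quote style (only the full triple-quote sequence terminates the literal), and ruff normalizes it to double quotes to match the rest of the codebase. A sketch with a plain dict standing in for node.meta:

    nodes = [{"val": (2, 3)}]
    # old: f"""val.shape: {[node.meta['val'].shape for node in nodes]}"""
    # new style, legal on all supported Python versions:
    print(f"""val.shape: {[node["val"] for node in nodes]}""")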
@@ -466,10 +464,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_compiled_autograd_ctx(self):
        self.skipTestForOldSm()
-        with torch._dynamo.config.patch(
-            skip_fsdp_hooks=False,
-        ), torch._functorch.config.patch(
-            recompute_views=True,
+        with (
+            torch._dynamo.config.patch(skip_fsdp_hooks=False),
+            torch._functorch.config.patch(recompute_views=True),
        ):
            inputs = torch.randn(8, 8)
            model = torch.nn.Linear(8, 8)

@@ -567,24 +564,28 @@ Unsupported Tensor.backward() call
 
            torch._dynamo.reset()
            torch._dynamo.compiled_autograd.reset()
-            with torch._dynamo.config.patch(
-                compiled_autograd=True,
-                compiled_autograd_kwargs_override={
-                    "fullgraph": True,
-                },
-                inline_inbuilt_nn_modules=True,
-                skip_fsdp_hooks=False,
-            ), torch._functorch.config.patch(
-                enable_autograd_cache=False,
-                recompute_views=True,
-            ), torch._inductor.config.patch(
-                force_disable_caches=True,
-                reorder_for_compute_comm_overlap=True,
-                reorder_for_compute_comm_overlap_passes=[
-                    "sink_waits",
-                    "raise_comms",
-                    "reorder_compute_for_overlap",
-                ],
+            with (
+                torch._dynamo.config.patch(
+                    compiled_autograd=True,
+                    compiled_autograd_kwargs_override={
+                        "fullgraph": True,
+                    },
+                    inline_inbuilt_nn_modules=True,
+                    skip_fsdp_hooks=False,
+                ),
+                torch._functorch.config.patch(
+                    enable_autograd_cache=False,
+                    recompute_views=True,
+                ),
+                torch._inductor.config.patch(
+                    force_disable_caches=True,
+                    reorder_for_compute_comm_overlap=True,
+                    reorder_for_compute_comm_overlap_passes=[
+                        "sink_waits",
+                        "raise_comms",
+                        "reorder_compute_for_overlap",
+                    ],
+                ),
            ):
                losses_compiled = test_compiled()
                losses_eager = test_eager()

@@ -741,20 +742,21 @@ Unsupported Tensor.backward() call
    def _test_nested_fully_shard_backend_inductor_fullgraph_True(self):
        self.skipTestForOldSm()
        for fwd_fullgraph in [True]:
-            with self._reinplace_all_gather_with_optional_checks(
-                fwd_fullgraph
-            ), torch._inductor.config.patch(
-                post_grad_custom_post_pass=(
-                    functools.partial(
-                        self._check_fsdp_copy_and_resize_ops_count_in_graph,
-                        fwd_copy_count=0,
-                        fwd_resize_count=0,
-                        bwd_copy_count=0,
-                        bwd_resize_count=0,
-                    )
-                    if fwd_fullgraph
-                    else None
-                )
+            with (
+                self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+                torch._inductor.config.patch(
+                    post_grad_custom_post_pass=(
+                        functools.partial(
+                            self._check_fsdp_copy_and_resize_ops_count_in_graph,
+                            fwd_copy_count=0,
+                            fwd_resize_count=0,
+                            bwd_copy_count=0,
+                            bwd_resize_count=0,
+                        )
+                        if fwd_fullgraph
+                        else None
+                    )
+                ),
            ):
                _, triton_codes = run_and_get_code(
                    lambda: self._test_traceable_fsdp(

@@ -943,9 +945,10 @@ Unsupported Tensor.backward() call
        for fwd_fullgraph, all_requires_grad in itertools.product(
            [True], [True, False]
        ):
-            with self._maybe_add_graph_break_to_sdpa(
-                fwd_fullgraph
-            ), self._reinplace_all_gather_with_optional_checks(fwd_fullgraph):
+            with (
+                self._maybe_add_graph_break_to_sdpa(fwd_fullgraph),
+                self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+            ):
                self._test_traceable_fsdp(
                    *self._create_transformer_factory_fns(
                        all_requires_grad=all_requires_grad

@@ -982,23 +985,24 @@ Unsupported Tensor.backward() call
                log.warning(
                    f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}"  # noqa: G004, G001, B950
                )
-                with self._reinplace_all_gather_with_optional_checks(
-                    fwd_fullgraph
-                ), torch._inductor.config.patch(
-                    post_grad_custom_post_pass=(
-                        functools.partial(
-                            self._check_fsdp_copy_and_resize_ops_count_in_graph,
-                            # NOTE: For the root unsharded params, we don't reshard after forward since for training,
-                            # the parameters would be freed and all-gathered immediately. Hence we still have
-                            # their resize and copy ops in the graph.
-                            fwd_copy_count=4,
-                            fwd_resize_count=4,
-                            bwd_copy_count=0,
-                            bwd_resize_count=4,
-                        )
-                        if fwd_fullgraph
-                        else None
-                    )
+                with (
+                    self._reinplace_all_gather_with_optional_checks(fwd_fullgraph),
+                    torch._inductor.config.patch(
+                        post_grad_custom_post_pass=(
+                            functools.partial(
+                                self._check_fsdp_copy_and_resize_ops_count_in_graph,
+                                # NOTE: For the root unsharded params, we don't reshard after forward since for training,
+                                # the parameters would be freed and all-gathered immediately. Hence we still have
+                                # their resize and copy ops in the graph.
+                                fwd_copy_count=4,
+                                fwd_resize_count=4,
+                                bwd_copy_count=0,
+                                bwd_resize_count=4,
+                            )
+                            if fwd_fullgraph
+                            else None
+                        )
+                    ),
                ):
                    _, triton_codes = run_and_get_code(
                        lambda: self._test_traceable_fsdp(

@@ -385,9 +385,9 @@ class TestFullyShardAllGatherExtensionsMultiThread(
        only some ranks may require padding, in which case only those ranks
        will error out and the all-gather will timeout.
        """
-        assert (
-            self.world_size >= 2
-        ), f"Assumes world size of at least 2 but got {self.world_size=}"
+        assert self.world_size >= 2, (
+            f"Assumes world size of at least 2 but got {self.world_size=}"
+        )
        model = MLP(dim=3, dim_multiplier=3)
        for module in model.modules():
            for param_name, param in module.named_parameters(recurse=False):

@@ -115,9 +115,10 @@ class TestFullyShardFrozen(FSDPTest):
 
        torch.manual_seed(42 + self.rank + 1)
        device = device_type
-        with patch_reduce_scatter(
-            reduce_scatter
-        ), patch_register_post_backward_hook_backward(backward_with_count):
+        with (
+            patch_reduce_scatter(reduce_scatter),
+            patch_register_post_backward_hook_backward(backward_with_count),
+        ):
            for iter_idx in range(10):
                inp = torch.randn((8, lin_dim), device=device)
                losses: list[torch.Tensor] = []

@@ -910,9 +910,9 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
    @skip_if_lt_x_gpu(1)
    def test_2d_process_group_init(self):
        shard_mesh_dim_size = 2
-        assert (
-            self.world_size % shard_mesh_dim_size == 0
-        ), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}"
+        assert self.world_size % shard_mesh_dim_size == 0, (
+            f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}"
+        )
        replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size
        mesh_dim_names = ("replicate", "shard")
        ref_mesh = init_device_mesh(

@@ -54,9 +54,9 @@ class TestFullyShardMemory(FSDPTest):
            )
        ):
            return  # skip since not a common use case
-        assert (
-            self.world_size == 2
-        ), f"Requires world size of 2 since some values are hard coded: {self.world_size}"
+        assert self.world_size == 2, (
+            f"Requires world size of 2 since some values are hard coded: {self.world_size}"
+        )
        torch.manual_seed(42)
        # Pre-run a linear forward (gemm and bias) and backward (gemm) to
        # allocate the cuBLAS workspaces before measuring the memory usage

@@ -284,9 +284,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
                )  # bf16 reduction
                param.grad = funcol.all_gather_tensor(
                    sharded_grad, gather_dim=0, group=group
-                ).to(
-                    param.dtype
-                )  # upcast to fp32
+                ).to(param.dtype)  # upcast to fp32
            ref_optim.step()  # fp32 optimizer step
 
            self.assertEqual(fsdp_loss, ref_loss)

@@ -139,8 +139,9 @@ class TestFullyShardOverlap(FSDPTest):
            dist.reduce_scatter_tensor(dummy_rs_output, dummy_rs_input)
 
        def fwd_bwd():
-            with patch_all_gather(delayed_all_gather), patch_reduce_scatter(
-                delayed_reduce_scatter
+            with (
+                patch_all_gather(delayed_all_gather),
+                patch_reduce_scatter(delayed_reduce_scatter),
            ):
                loss = model(inp).sum()
                loss.backward()

@@ -74,12 +74,12 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
            # Check that FSDP moved the inputs to GPU, including recursing
            # into the tuple data structure
            assert x.device == device, f"Expects {device} but got {x.device}"
-            assert (
-                ys[0].device == device
-            ), f"Expects {device} but got {ys[0].device}"
-            assert (
-                ys[1].device == device
-            ), f"Expects {device} but got {ys[1].device}"
+            assert ys[0].device == device, (
+                f"Expects {device} but got {ys[0].device}"
+            )
+            assert ys[1].device == device, (
+                f"Expects {device} but got {ys[1].device}"
+            )
            y = ys[0] + ys[1]
            return x + y + 1
 

@@ -234,8 +234,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase):
 
 
if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
 
    run_tests()

@@ -250,9 +250,11 @@ class TestJoin(MultiProcessTestCase):
            else "Detected at least one rank that exhausted inputs. "
            "Throwing across all ranks."
        )
-        with self.assertRaisesRegex(
-            RuntimeError, expected_msg
-        ) if throw_on_early_termination else contextlib.nullcontext():
+        with (
+            self.assertRaisesRegex(RuntimeError, expected_msg)
+            if throw_on_early_termination
+            else contextlib.nullcontext()
+        ):
            with Join(
                allreducers,
                enable=enable,
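The TestJoin hunk shows that the parenthesized form also covers a conditional context manager: a ternary picks which manager to enter. A self-contained sketch:

    import contextlib
    import unittest


    class Demo(unittest.TestCase):
        def check(self, should_raise: bool) -> None:
            with (
                self.assertRaisesRegex(RuntimeError, "boom")
                if should_raise
                else contextlib.nullcontext()
            ):
                if should_raise:
                    raise RuntimeError("boom")


    Demo("check").check(True)   # expects and swallows the RuntimeError
    Demo("check").check(False)  # nullcontext: nothing raised, nothing expected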
@@ -677,9 +677,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
            fully_shard(layer)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
-        torch.optim.lr_scheduler.LambdaLR(
-            optim, lr_lambda=[lambda epoch: 0.95**epoch]
-        )
+        torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[lambda epoch: 0.95**epoch])
        opt_state_dict = ptd_state_dict.get_optimizer_state_dict(
            model,
            optim,

@@ -228,9 +228,9 @@ class TestStateDictStager(TestCase):
 
        # Validate tensor count and bytes
        expected_storage_cnt = 2
-        assert (
-            num_storages == expected_storage_cnt
-        ), f"Expected {expected_storage_cnt} storages, got {num_storages}"
+        assert num_storages == expected_storage_cnt, (
+            f"Expected {expected_storage_cnt} storages, got {num_storages}"
+        )
 
        # Calculate expected bytes
        # Note: Only unique storages are counted in the byte count

@@ -239,9 +239,9 @@ class TestStateDictStager(TestCase):
            + tensor3.numel()  # tensor1 and tensor2 share storage
            * tensor3.element_size()  # tensor3 and its narrow view share storage
        )
-        assert (
-            num_bytes == expected_bytes
-        ), f"Expected {expected_bytes} bytes, got {num_bytes}"
+        assert num_bytes == expected_bytes, (
+            f"Expected {expected_bytes} bytes, got {num_bytes}"
+        )
        # Verify that the CPU state dict is equivalent to the original CUDA state dict
        result, error = compare_state_dicts(state_dict, cpu_state_dict)
        assert result, f"State dicts are not equivalent: {error}"

@@ -301,9 +301,9 @@ class TestStateDictStager(TestCase):
 
        # Verify the first result is correct
        result, error = compare_state_dicts(state_dict, cpu_state_dict1)
-        assert (
-            result
-        ), f"First state dict is not equivalent to original: {error}"
+        assert result, (
+            f"First state dict is not equivalent to original: {error}"
+        )
 
        # Modify the original tensors
        tensor1.fill_(0)

@@ -317,14 +317,14 @@ class TestStateDictStager(TestCase):
 
        # Verify that the second CPU state dict is equivalent to the modified original state dict
        result, error = compare_state_dicts(state_dict, cpu_state_dict2)
-        assert (
-            result
-        ), f"Second state dict is not equivalent to modified original: {error}"
+        assert result, (
+            f"Second state dict is not equivalent to modified original: {error}"
+        )
 
        # Verify that the number of cached storages hasn't changed
-        assert (
-            num_storages1 == num_storages2
-        ), f"Storage count changed: {num_storages1} vs {num_storages2}"
+        assert num_storages1 == num_storages2, (
+            f"Storage count changed: {num_storages1} vs {num_storages2}"
+        )
 
        # Verify that the tensors in the second state dict have the same storage pointers as the first
        assert (

@@ -347,12 +347,12 @@ class TestStateDictStager(TestCase):
        cpu_state_dict3 = stager.stage(state_dict)
 
        # Verify that the third CPU state dict reflects the updated values
-        assert torch.all(
-            cpu_state_dict3["tensor1"] == 42.0
-        ), "Updated values should be reflected in the cached state dict"
-        assert torch.all(
-            cpu_state_dict3["tensor2"] == 42.0
-        ), "Updated values should be reflected in the cached state dict"
+        assert torch.all(cpu_state_dict3["tensor1"] == 42.0), (
+            "Updated values should be reflected in the cached state dict"
+        )
+        assert torch.all(cpu_state_dict3["tensor2"] == 42.0), (
+            "Updated values should be reflected in the cached state dict"
+        )
 
    @requires_cuda
    def test_tensor_attrs(self):

@@ -381,24 +381,24 @@ class TestStateDictStager(TestCase):
        cpu_state_dict = stager.stage(state_dict)
 
        # Verify that tensor attributes are preserved
-        assert hasattr(
-            cpu_state_dict["tensor1"], "a"
-        ), "Tensor attribute 'a' was not preserved"
-        assert (
-            cpu_state_dict["tensor1"].a == 42
-        ), "Tensor attribute 'a' has incorrect value"
-        assert hasattr(
-            cpu_state_dict["tensor1"], "b"
-        ), "Tensor attribute 'b' was not preserved"
-        assert (
-            cpu_state_dict["tensor1"].b == 43
-        ), "Tensor attribute 'b' has incorrect value"
-        assert hasattr(
-            cpu_state_dict["recursive"]["tensor3"], "c"
-        ), "Tensor attribute 'c' was not preserved"
-        assert (
-            cpu_state_dict["recursive"]["tensor3"].c == 44
-        ), "Tensor attribute 'c' has incorrect value"
+        assert hasattr(cpu_state_dict["tensor1"], "a"), (
+            "Tensor attribute 'a' was not preserved"
+        )
+        assert cpu_state_dict["tensor1"].a == 42, (
+            "Tensor attribute 'a' has incorrect value"
+        )
+        assert hasattr(cpu_state_dict["tensor1"], "b"), (
+            "Tensor attribute 'b' was not preserved"
+        )
+        assert cpu_state_dict["tensor1"].b == 43, (
+            "Tensor attribute 'b' has incorrect value"
+        )
+        assert hasattr(cpu_state_dict["recursive"]["tensor3"], "c"), (
+            "Tensor attribute 'c' was not preserved"
+        )
+        assert cpu_state_dict["recursive"]["tensor3"].c == 44, (
+            "Tensor attribute 'c' has incorrect value"
+        )
 
    @requires_cuda
    def test_different_dtypes(self):

@@ -500,11 +500,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
            logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
        )
 
-        with mock.patch.object(
-            mpc, "_is_done", return_value=True
-        ), mock.patch.object(mpc, "_pc"), mock.patch.object(
-            mpc._pc, "join", side_effect=[True, False, False, True]
-        ) as mock_join:
+        with (
+            mock.patch.object(mpc, "_is_done", return_value=True),
+            mock.patch.object(mpc, "_pc"),
+            mock.patch.object(
+                mpc._pc, "join", side_effect=[True, False, False, True]
+            ) as mock_join,
+        ):
            mpc._poll()
            self.assertEqual(4, mock_join.call_count)
 

@@ -56,32 +56,36 @@ class TestDistributedCheckpoint(FSDPTest):
        torch.manual_seed(200)
        new_model = wrap(SkipModel(double_nest=True))
 
-        with FullyShardedDataParallel.summon_full_params(
-            model
-        ), FullyShardedDataParallel.summon_full_params(new_model):
+        with (
+            FullyShardedDataParallel.summon_full_params(model),
+            FullyShardedDataParallel.summon_full_params(new_model),
+        ):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)
 
        writer = FileSystemWriter(self.temp_dir)
        reader = FileSystemReader(self.temp_dir)
-        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
-            new_model, state_dict_type
+        with (
+            FSDP.state_dict_type(model, state_dict_type),
+            FSDP.state_dict_type(new_model, state_dict_type),
        ):
            state_dict = model.state_dict()
 
            save(state_dict, writer)
 
-        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
-            new_model, state_dict_type
+        with (
+            FSDP.state_dict_type(model, state_dict_type),
+            FSDP.state_dict_type(new_model, state_dict_type),
        ):
            state_dict = new_model.state_dict()
            load(state_dict, reader)
            new_model.load_state_dict(state_dict)
 
-        with FullyShardedDataParallel.summon_full_params(
-            model
-        ), FullyShardedDataParallel.summon_full_params(new_model):
+        with (
+            FullyShardedDataParallel.summon_full_params(model),
+            FullyShardedDataParallel.summon_full_params(new_model),
+        ):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)

@@ -242,11 +242,10 @@ class TestCommunication(FSDPTest):
        # and if `use_no_sync=False`, we only run `num_iters` iterations
        # outside `no_sync()`
        num_iters = 3
-        with patch(
-            "torch.distributed.all_gather_into_tensor"
-        ) as mock_all_gather, patch(
-            "torch.distributed.reduce_scatter_tensor"
-        ) as mock_reduce_scatter:
+        with (
+            patch("torch.distributed.all_gather_into_tensor") as mock_all_gather,
+            patch("torch.distributed.reduce_scatter_tensor") as mock_reduce_scatter,
+        ):
 
            def reset_mocks():
                mock_all_gather.reset_mock()

@@ -379,12 +379,15 @@ class TestHooks(FSDPTest):
            register_pre_backward_hooks_call_count += 1
            return orig_register_pre_backward_hooks(*args, **kwargs)
 
-        with mock.patch(
-            "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks",
-            _register_pre_backward_hooks_with_count,
-        ), mock.patch(
-            "torch.distributed.fsdp._runtime_utils._register_post_backward_hook"
-        ) as register_post_bwd_mock:
+        with (
+            mock.patch(
+                "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks",
+                _register_pre_backward_hooks_with_count,
+            ),
+            mock.patch(
+                "torch.distributed.fsdp._runtime_utils._register_post_backward_hook"
+            ) as register_post_bwd_mock,
+        ):
            self.assertEqual(register_pre_backward_hooks_call_count, 0)
            self.assertFalse(register_post_bwd_mock.called)
            fsdp_model(*input)

@@ -152,9 +152,9 @@ class TestGradAcc(FSDPTest):
            batches.append(tuple(permute_tensor(t) for t in batch))
        for batch1, batch2 in itertools.combinations(batches, r=2):
            for t1, t2 in zip(batch1, batch2):
-                assert not torch.all(
-                    t1 == t2
-                ), "Check the test to make sure that batches are distinct"
+                assert not torch.all(t1 == t2), (
+                    "Check the test to make sure that batches are distinct"
+                )
 
        # Concatenate the batches along the given batch dimension
        concat_batch: tuple[torch.Tensor, ...] = tuple(

@@ -121,8 +121,9 @@ class TestFSDPHybridShard(FSDPTest):
    def test_hsdp_save_load_state_dict(self):
        model = MyModel().cuda()
        num_node_devices = torch.cuda.device_count()
-        shard_rank_lists = list(range(0, num_node_devices // 2)), list(
-            range(num_node_devices // 2, num_node_devices)
+        shard_rank_lists = (
+            list(range(0, num_node_devices // 2)),
+            list(range(num_node_devices // 2, num_node_devices)),
        )
        shard_groups = (
            dist.new_group(shard_rank_lists[0]),

@@ -171,8 +172,9 @@ class TestFSDPHybridShard(FSDPTest):
    def test_hsdp_sync_module_state(self):
        model = MyModel().cuda()
        num_node_devices = torch.cuda.device_count()
-        shard_rank_lists = list(range(0, num_node_devices // 2)), list(
-            range(num_node_devices // 2, num_node_devices)
+        shard_rank_lists = (
+            list(range(0, num_node_devices // 2)),
+            list(range(num_node_devices // 2, num_node_devices)),
        )
        shard_groups = (
            dist.new_group(shard_rank_lists[0]),

@@ -310,8 +312,9 @@ class TestFSDPHybridShard(FSDPTest):
        cntr = Counter()
        patched_allreduce = partial(patched_collective, orig_ar, cntr)
        patched_reduce_scatter = partial(patched_collective, orig_rs, cntr)
-        with patch_allreduce(patched_allreduce), patch_reduce_scatter(
-            patched_reduce_scatter
+        with (
+            patch_allreduce(patched_allreduce),
+            patch_reduce_scatter(patched_reduce_scatter),
        ):
            inp = hsdp_model.get_input(device=torch.cuda.current_device())
            out = hsdp_model(inp[0], inp[1])

@@ -355,9 +358,9 @@ class TestFSDPHybridShard(FSDPTest):
            use_orig_params,
            hsdp_process_groups=hsdp_pgs,
        )
-        assert (
-            hsdp_model._inter_node_pg.size() > 1
-        ), "HSDP model initialized without replication"
+        assert hsdp_model._inter_node_pg.size() > 1, (
+            "HSDP model initialized without replication"
+        )
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
        hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2)
        torch.manual_seed(global_pg.rank() + 1)

@@ -766,9 +766,9 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
            if expect_use_full_prec_in_eval:
                assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}"
            else:
-                assert (
-                    x.dtype == low_prec_dtype
-                ), f"Expected {low_prec_dtype}, got {x.dtype}"
+                assert x.dtype == low_prec_dtype, (
+                    f"Expected {low_prec_dtype}, got {x.dtype}"
+                )
            return self.a(x)
 
        mp_config = MixedPrecision(

@@ -91,9 +91,9 @@ class TestTPFSDPIntegration(FSDPTest):
        tensor_parallel_size: int,
    ) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]:
        """ """
-        assert (
-            type(model) is SimpleModel
-        ), "Expects a `SimpleModel` since the sharding cases on the model definition"
+        assert type(model) is SimpleModel, (
+            "Expects a `SimpleModel` since the sharding cases on the model definition"
+        )
        param_name_to_numel = OrderedDict()
        param_name_to_sharding_info = OrderedDict()
        for param_name, param in model.named_parameters():

@@ -654,9 +654,12 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
        losses1 = []
        losses2 = []
        losses = []
-        for _model, _optim in (fsdp_model, optim), (
-            fsdp_model_orig_params,
-            optim_orig_params,
+        for _model, _optim in (
+            (fsdp_model, optim),
+            (
+                fsdp_model_orig_params,
+                optim_orig_params,
+            ),
        ):
            _optim.zero_grad()
            loss1 = _model(*inp1)

@@ -1166,9 +1169,9 @@ class TestFSDPUseOrigParamsFQNs(FSDPTest):
                clean_tensor_name(tup[0]) for tup in self.named_parameters()
            ]
            params = [tup[1] for tup in self.named_parameters()]
-            assert (
-                param_shapes[0] is not None and param_shapes[1] is not None
-            ), "`param_sizes` should be set"
+            assert param_shapes[0] is not None and param_shapes[1] is not None, (
+                "`param_sizes` should be set"
+            )
            assert_equal_fn(
                param_names,
                [

@@ -19,6 +19,7 @@ The script itself is not a test case hence no assertions are made in this script
see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched()
     - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched()
"""
+
import argparse
 
import torch.distributed as dist
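The one-line launcher hunk is a different rule: ruff format (matching black's 2024 stable style) puts exactly one blank line between a module docstring and the code that follows. Given how the separators garbled this hunk, the added blank line after the docstring is the most plausible reading. Sketch of the resulting layout:

    """Module docstring; the formatter inserts the blank line below it."""

    import argparse

    import torch.distributed as dist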
@@ -292,9 +292,9 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer):
            betas=BETAS,
            eps=EPS,
        )
-        assert (
-            len(o.param_groups) == 2
-        ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+        assert len(o.param_groups) == 2, (
+            f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}"
+        )
        assert len(o.optim.param_groups) == 2, (
            "Expected 2 local optimizer param groups, but got "
            f"{len(o.optim.param_groups)}"

@@ -713,9 +713,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
        LR = 1e-3
        MOMENTUM = 0.99
        REFERENCE_RANK = 0
-        assert (
-            REFERENCE_RANK in subgroup_ranks
-        ), "Reference rank must be in the new process group"
+        assert REFERENCE_RANK in subgroup_ranks, (
+            "Reference rank must be in the new process group"
+        )
        loss_fn = torch.nn.L1Loss().to(device)
 
        def check(optimizer):

@@ -1165,22 +1165,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
 
            # Increased tolerances are needed to pass when using TF32
            # See: https://github.com/pytorch/pytorch/issues/67764
-            torch.testing.assert_close(
-                local_loss.cpu(),
-                ddp_loss.cpu(),
-                rtol=1e-03,
-                atol=1e-08,
-            ), "Losses differ between local optimizer and ZeRO"
+            (
+                torch.testing.assert_close(
+                    local_loss.cpu(),
+                    ddp_loss.cpu(),
+                    rtol=1e-03,
+                    atol=1e-08,
+                ),
+                "Losses differ between local optimizer and ZeRO",
+            )
 
            for local_p, ddp_p in zip(
                local_model.parameters(), ddp_model.parameters()
            ):
-                torch.testing.assert_close(
-                    local_p.cpu(),
-                    ddp_p.cpu(),
-                    rtol=1e-03,
-                    atol=1e-04,
-                ), "Models differ after a step"
+                (
+                    torch.testing.assert_close(
+                        local_p.cpu(),
+                        ddp_p.cpu(),
+                        rtol=1e-03,
+                        atol=1e-04,
+                    ),
+                    "Models differ after a step",
+                )
 
    @skipIfHpu
    @skip_if_lt_x_gpu(4)

@@ -89,9 +89,9 @@ class PipeTests(TestCase):
            mb_args=(x, y),
        )
 
-        assert (
-            pipe.num_stages == EXPECTED_N_STAGES[ModelClass]
-        ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
+        assert pipe.num_stages == EXPECTED_N_STAGES[ModelClass], (
+            f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
+        )
 
        ref_out = mod(x, y)
        out = pipe(x, y)[0]

@@ -109,9 +109,7 @@ class PipeTests(TestCase):
        new_names.update(stage_fqns)
 
    if CHECK_FQN_SET_EQUALITY:
-        assert (
-            old_names == new_names
-        ), f"""
+        assert old_names == new_names, f"""
old names {old_names}
new names {new_names}
"""

@@ -60,9 +60,9 @@ class UnflattenTests(TestCase):
        for stage_idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(stage_idx)
            for param_name, _ in stage_mod.named_parameters():
-                assert (
-                    param_name in orig_state_dict
-                ), f"{param_name} not in original state dict"
+                assert param_name in orig_state_dict, (
+                    f"{param_name} not in original state dict"
+                )
        print("Param qualname test passed")
 
        # Check equivalence

@@ -45,9 +45,9 @@ class ShareMemoryRPCPickler(_InternalRPCPickler):
        for t in torch._tensor_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_tensor
        self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor
-        self._dispatch_table[
-            torch.nn.parameter.Parameter
-        ] = TorchMpReductions.reduce_tensor
+        self._dispatch_table[torch.nn.parameter.Parameter] = (
+            TorchMpReductions.reduce_tensor
+        )
 
 
def worker_loop(a):
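Here the assignment target is a subscript: black split inside the square brackets, while ruff format keeps the target on one line and parenthesizes the assigned value, the same shape used for the `os.environ[...]` assignments further down. Sketch:

    def reduce_tensor():  # stand-in for TorchMpReductions.reduce_tensor
        pass


    dispatch_table = {}
    # old:                          new:
    #   dispatch_table[               dispatch_table[SomeLongTypeName] = (
    #       SomeLongTypeName              reduce_tensor
    #   ] = reduce_tensor             )
    dispatch_table[int] = (
        reduce_tensor
    )
    assert dispatch_table[int] is reduce_tensor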
@@ -872,9 +872,7 @@ def forward(self, primals_1):
            "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
        ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check(
            "extern_kernels.mm(buf0,"
-        ).run(
-            code
-        )
+        ).run(code)
 
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(1)

@@ -441,9 +441,11 @@ class DistMathOpsTest(DTensorTestBase):
        out_req_grad: bool
 
        subtest_fails = {}
-        valid_filter = lambda cfg: not (  # noqa: E731
-            cfg.ln_req_grad and not cfg.elementwise_affine
-        ) and any(cfg[2:])
+        valid_filter = (  # noqa: E731
+            lambda cfg: (
+                not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:])
+            )
+        )
        subtest_cfgs = list(
            filter(
                valid_filter,
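Assigned lambdas (the reason for the retained `# noqa: E731`) get wrapped the same way: the whole lambda moves inside parentheses and its body gets its own parenthesized line instead of black's split inside `not (...)`. A runnable sketch with a namedtuple standing in for the subtest config:

    from collections import namedtuple

    Cfg = namedtuple("Cfg", ["ln_req_grad", "elementwise_affine", "a", "b"])

    valid_filter = (  # noqa: E731
        lambda cfg: (
            not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:])
        )
    )
    print(valid_filter(Cfg(False, True, True, False)))  # True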
@ -566,9 +568,9 @@ class DistMathOpsTest(DTensorTestBase):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
subtest_fails[subtest_cfg] = e
|
subtest_fails[subtest_cfg] = e
|
||||||
# if any subtest fails, provide the failed subtests and report the overall failure
|
# if any subtest fails, provide the failed subtests and report the overall failure
|
||||||
assert (
|
assert not subtest_fails, (
|
||||||
not subtest_fails
|
f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
|
||||||
), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"
|
)
|
||||||
|
|
||||||
@with_comms
|
@with_comms
|
||||||
def test_topk(self):
|
def test_topk(self):
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,9 @@ def with_xla(func: Callable) -> Callable:
|
||||||
|
|
||||||
@wraps(func) # pyre-ignore[6]
|
@wraps(func) # pyre-ignore[6]
|
||||||
def wrapper(
|
def wrapper(
|
||||||
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
|
self,
|
||||||
|
*args: tuple[object],
|
||||||
|
**kwargs: dict[str, Any], # type: ignore[misc]
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
|
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
|
||||||
os.environ["XLA_USE_SPMD"] = "1"
|
os.environ["XLA_USE_SPMD"] = "1"
|
||||||
|
|
|
||||||
|
|
@@ -2234,8 +2234,8 @@ class LocalRankTest(MultiProcessTestCase):
 
 
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
 
     run_tests()

@@ -796,13 +796,11 @@ class CompileTest(TestCase):
             .check("buf6 = empty")
             # Expect in-place with inductor allocated buf
             .check(
-                "torch.ops._c10d_functional.all_reduce_coalesced_"
-                ".default([buf0, buf1]"
+                "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]"
             )
             # Expect no in-place with graph input (buf5, buf6 are clones)
             .check(
-                "torch.ops._c10d_functional.all_reduce_coalesced_"
-                ".default([buf5, buf6]"
+                "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]"
             )
             .check("torch.ops._c10d_functional.wait_tensor.default(buf0")
             .check("torch.ops._c10d_functional.wait_tensor.default(buf1")

@@ -2705,8 +2705,8 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase
 
 
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
 
     run_tests()

@@ -2869,9 +2869,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         self.assertTrue(t.is_alive())
 
         if prev_nccl_async_error_handling is not None:
-            os.environ[
-                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-            ] = prev_nccl_async_error_handling
+            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                prev_nccl_async_error_handling
+            )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(3)

@@ -2931,9 +2931,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
         self._test_barrier_error()
         if prev_nccl_async_error_handling is not None:
-            os.environ[
-                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-            ] = prev_nccl_async_error_handling
+            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                prev_nccl_async_error_handling
+            )
 
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")

@@ -2984,9 +2984,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         process_group.abort()
 
         if prev_nccl_async_error_handling is not None:
-            os.environ[
-                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-            ] = prev_nccl_async_error_handling
+            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                prev_nccl_async_error_handling
+            )
 
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")

@@ -3065,9 +3065,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         os.remove(new_file_name)
 
         if prev_nccl_async_error_handling is not None:
-            os.environ[
-                "TORCH_NCCL_ASYNC_ERROR_HANDLING"
-            ] = prev_nccl_async_error_handling
+            os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = (
+                prev_nccl_async_error_handling
+            )
 
     def _run_invalid_nccl_blocking_wait_env(self, val):
         os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
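The hunks above all show the same mechanical rewrite: under ruff format, a multi-line `assert` keeps its condition on the `assert` line and parenthesizes the message, where Black had parenthesized the condition instead. A minimal before/after sketch (hypothetical names, not code from this diff):

```python
cuda_initialized = False  # hypothetical stand-in for torch.cuda._initialized

# Black-era layout: the condition is wrapped to satisfy the line limit.
assert (
    not cuda_initialized
), "test must not have initialized CUDA context on main process"

# ruff-format layout: the condition reads first; only the message wraps.
assert not cuda_initialized, (
    "test must not have initialized CUDA context on main process"
)
```

Both layouts compile to the same bytecode; the rewrite is purely cosmetic, which is why these hunks carry no behavioral risk.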
@@ -3360,9 +3360,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         self.assertEqual(_get_intra_node_comm_usage_counter(), 3)
 
         # Verify that IntraNodeComm is not used beyond 10MB
-        t = torch.full(
-            (10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16
-        ).cuda()
+        t = torch.full((10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda()
         c10d.all_reduce(t, c10d.ReduceOp.SUM)
         self.assertTrue(t.eq(expect).all())
         self.assertEqual(_get_intra_node_comm_usage_counter(), 3)

@@ -4249,9 +4247,9 @@ class SparseCollective(MultiProcessTestCase):
 class NCCLTraceTestBase(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        os.environ[
-            "TORCH_NCCL_ENABLE_TIMING"
-        ] = "0"  # see 'timing_enabled' parametrized tests
+        os.environ["TORCH_NCCL_ENABLE_TIMING"] = (
+            "0"  # see 'timing_enabled' parametrized tests
+        )
         os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"
         os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
         self.tempdir = tempfile.TemporaryDirectory()

@@ -5331,8 +5329,8 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
 
 
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
 
     run_tests()

@@ -1090,8 +1090,8 @@ class UccProcessGroupWithDispatchedCollectivesTests(
 
 
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
 
     run_tests()

@@ -90,9 +90,9 @@ class TestCollectiveUtils(MultiProcessTestCase):
 
         res = all_gather(data_or_fn=func, pg=pg)
         func.assert_called_once()
-        assert res == list(
-            range(self.world_size)
-        ), f"Expect res to be list of 0 through {self.world_size} (got {res})"
+        assert res == list(range(self.world_size)), (
+            f"Expect res to be list of 0 through {self.world_size} (got {res})"
+        )
 
     def test_all_gather_result_no_pg(self) -> None:
         """

@@ -207,8 +207,8 @@ class TestCollectives(TestCase):
 
 
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
 
     run_tests()

@@ -490,8 +490,9 @@ class DeviceMeshTestNDim(DTensorTestBase):
 
         # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7))
         # and assign the correct shard group to each rank
-        shard_rank_lists = list(range(0, self.world_size // 2)), list(
-            range(self.world_size // 2, self.world_size)
+        shard_rank_lists = (
+            list(range(0, self.world_size // 2)),
+            list(range(self.world_size // 2, self.world_size)),
         )
         shard_groups = (
             new_group(shard_rank_lists[0]),

@@ -1830,9 +1830,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
             f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
         ).check(
             f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
-        ).run(
-            GUARDS_FILE.getvalue()
-        )
+        ).run(GUARDS_FILE.getvalue())
 
         self.assertTrue(same(correct_outputs, outputs))
@@ -519,12 +519,13 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
             out = a2a / a2a.sum(dim=0)
             return out
 
-        with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size
-        ), torch._dynamo.config.patch(
-            dynamic_shapes=True,
-            capture_dynamic_output_shape_ops=True,
-            capture_scalar_outputs=True,
-        ):
+        with (
+            _dynamo_dist_per_rank_init(self.rank, self.world_size),
+            torch._dynamo.config.patch(
+                dynamic_shapes=True,
+                capture_dynamic_output_shape_ops=True,
+                capture_scalar_outputs=True,
+            ),
+        ):
             row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
             input_split_sizes_tensor = torch.tensor(

@@ -680,15 +681,15 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
             return torch.ops.custom_ns.foo(a2a)
 
-        with _dynamo_dist_per_rank_init(
-            self.rank, self.world_size
-        ), torch._dynamo.config.patch(
-            dynamic_shapes=True,
-            capture_dynamic_output_shape_ops=True,
-            capture_scalar_outputs=True,
-        ), torch.library._scoped_library(
-            "custom_ns", "FRAGMENT"
-        ) as lib:
+        with (
+            _dynamo_dist_per_rank_init(self.rank, self.world_size),
+            torch._dynamo.config.patch(
+                dynamic_shapes=True,
+                capture_dynamic_output_shape_ops=True,
+                capture_scalar_outputs=True,
+            ),
+            torch.library._scoped_library("custom_ns", "FRAGMENT") as lib,
+        ):
             lib.define(
                 "alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor"  # noqa: B950
             )

@@ -464,8 +464,8 @@ class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
 
 
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_pg_wrapper must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_pg_wrapper must not have initialized CUDA context on main process"
+    )
 
     run_tests()

@@ -372,12 +372,8 @@ class TCPStoreTest(TestCase, StoreTestBase):
         # Use noqa to silence flake8.
         # Need to store in an unused variable here to ensure the first
         # object is not destroyed before the second object is created.
-        store1 = dist.TCPStore(
-            addr, port, 1, True, use_libuv=self._use_libuv
-        )  # noqa: F841
-        store2 = dist.TCPStore(
-            addr, port, 1, True, use_libuv=self._use_libuv
-        )  # noqa: F841
+        store1 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv)  # noqa: F841
+        store2 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv)  # noqa: F841
         self.assertEqual(store1.libuvBackend, self._use_libuv)
         self.assertEqual(store2.libuvBackend, self._use_libuv)
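Several of the preceding hunks collapse Black's chained `with a(...) as x, b(...):` continuation style into a single parenthesized group of context managers, one per line; official grammar support for the parenthesized form landed in Python 3.10. A small runnable sketch using `contextlib.nullcontext` as a stand-in for the mocked managers in these tests:

```python
from contextlib import nullcontext

# Black-era layout: the line break lands inside the first call's parentheses.
with nullcontext(
    "first"
) as a, nullcontext("second") as b:
    print(a, b)

# ruff-format layout (Python 3.10+): one manager per line, trailing comma.
with (
    nullcontext("first") as a,
    nullcontext("second") as b,
):
    print(a, b)
```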
@@ -767,7 +763,7 @@ class RendezvousFileTest(TestCase):
 
     def test_nominal(self):
         with tempfile.NamedTemporaryFile(delete=False) as file:
-            url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2'
+            url = f"file:///{file.name.replace(os.path.sep, '/')}?world_size=2"
             gen0 = dist.rendezvous(url + "&rank=0")
             store0, rank0, size0 = next(gen0)
             self.assertEqual(0, rank0)

@@ -1178,8 +1174,8 @@ class TestClientProtocol(TestCase):
 
 
 if __name__ == "__main__":
-    assert (
-        not torch.cuda._initialized
-    ), "test_distributed must not have initialized CUDA context on main process"
+    assert not torch.cuda._initialized, (
+        "test_distributed must not have initialized CUDA context on main process"
+    )
 
     run_tests()

@@ -81,9 +81,9 @@ def count_ops(
     for node in gm.graph.nodes:
         if match_rng_op(node, op) or node.target == op:
             actual_count += 1
-    assert (
-        actual_count >= freq_ge
-    ), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
+    assert actual_count >= freq_ge, (
+        f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
+    )
     return gm
 

@@ -455,9 +455,9 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
 
                 # Modify gradient using .data (Dangerous: Breaks autograd tracking!)
                 modified_grad = grad_output.clone()
-                modified_grad.data[
-                    input_tensor.data < 0
-                ] = 0  # Zero-out gradients for negative inputs
+                modified_grad.data[input_tensor.data < 0] = (
+                    0  # Zero-out gradients for negative inputs
+                )
 
                 return modified_grad * 3
@@ -195,10 +195,13 @@ class GraphModule(torch.nn.Module):
         def f(x, y):
             return invoke_quant_test(inner, [x, y], scheme="nf4")
 
-        with mock.patch(
-            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
-            True,
-        ), torch.no_grad():
+        with (
+            mock.patch(
+                "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+                True,
+            ),
+            torch.no_grad(),
+        ):
             torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y)
 
         self.assertEqual(len(bk.graphs), 1)

@@ -319,10 +322,13 @@ class GraphModule(torch.nn.Module):
         x = torch.randn(3, 3, requires_grad=False)
         x_clone = x.clone()
         y = torch.randn(3, 3, requires_grad=True)
-        with mock.patch(
-            "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
-            True,
-        ), torch.no_grad():
+        with (
+            mock.patch(
+                "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation",
+                True,
+            ),
+            torch.no_grad(),
+        ):
             compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y)
         self.assertEqual(x, x_clone + 1)
         self.assertEqual(compiled_out, x_clone + y + 1)

@@ -53,8 +53,8 @@ class BytecodeTests(torch._dynamo.test_case.TestCase):
         fn_str = f"""\
 def fn():
     foo.bar(1, 2, 3)
-    {str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
-    l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}]
+    {str(chr(10)).join(" " * 4 + "x" + str(i) + " = 1" for i in range(1 << 9))}
+    l = [{" ".join("x" + str(i) + "," for i in range(1 << 9))}]
 """
         locals = {}
         exec(fn_str, {}, locals)

@@ -26,9 +26,10 @@ class CallbackTests(TestCase):
     def test_callbacks_with_duplicate_prevention(self) -> None:
         trigger = CallbackTrigger.DYNAMO
         compile_id = CompileId(0, 0)
-        with callback_handler.install_callbacks(
-            trigger, compile_id
-        ), callback_handler.install_callbacks(trigger, compile_id):
+        with (
+            callback_handler.install_callbacks(trigger, compile_id),
+            callback_handler.install_callbacks(trigger, compile_id),
+        ):
             self._on_compile_start.assert_called_once()
             self._on_compile_end.assert_called_once()
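The `BytecodeTests` hunk above also shows ruff format's quote normalization inside f-strings: the outer string becomes double-quoted and any nested literal switches to single quotes, inverting the Black-era arrangement. A runnable sketch (hypothetical values):

```python
name = "relu"

old_style = f'{name + "_"}'  # Black-era: single-quoted outer f-string
new_style = f"{name + '_'}"  # ruff format: double-quoted outer f-string
assert old_style == new_style == "relu_"
```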
@@ -1252,10 +1252,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
         def f(x, y):
             return x + y
 
-        x, y = torch.ones(
-            1,
-        ), torch.zeros(
-            1,
-        )
+        x, y = (
+            torch.ones(
+                1,
+            ),
+            torch.zeros(
+                1,
+            ),
+        )
         return f(x, y)

@@ -1289,10 +1292,13 @@ class GraphModule(torch.nn.Module):
         def f(x, y):
             return x + y
 
-        x, y = torch.ones(
-            1,
-        ), torch.zeros(
-            1,
-        )
+        x, y = (
+            torch.ones(
+                1,
+            ),
+            torch.zeros(
+                1,
+            ),
+        )
         return f(x, y)

@@ -1335,10 +1341,13 @@ class GraphModule(torch.nn.Module):
 
             return inner_fn(x, y) + x
 
-        x, y = torch.ones(
-            1,
-        ), torch.zeros(
-            1,
-        )
+        x, y = (
+            torch.ones(
+                1,
+            ),
+            torch.zeros(
+                1,
+            ),
+        )
         return f(x, y)

@@ -19,7 +19,7 @@ from torch.testing._internal.common_utils import (
 
 
 class CustomException(Exception):
-    ...
+    pass
 
 
 class CustomExceptionMeta(type):

@@ -28,7 +28,7 @@ class CustomExceptionMeta(type):
 
 
 class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta):
-    ...
+    pass
 
 
 class CustomExceptionWithArgs(Exception):

@@ -358,7 +358,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
 
     def test_raise_custom_exception(self):
         class Exc(Exception):
-            ...
+            pass
 
         @torch.compile(backend="eager", fullgraph=True)
         def fn(t):

@@ -375,7 +375,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
 
     def test_raise_custom_exception_with_args(self):
         class Exc(Exception):
-            ...
+            pass
 
         @torch.compile(backend="eager", fullgraph=True)
         def fn(t):
@@ -324,9 +324,13 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             distributed_state=None,
             package=None,
         )
-        with compile_context(CompileContext(CompileId(0, 0))), tracing(
-            tracer.output.tracing_context
-        ), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""):
+        with (
+            compile_context(CompileContext(CompileId(0, 0))),
+            tracing(tracer.output.tracing_context),
+            tracer.set_current_tx(),
+            get_metrics_context(),
+            dynamo_timed(""),
+        ):
             tracer.run()
 
         check_fn_manager = CheckFunctionManager(

@@ -1092,9 +1092,7 @@ not ___dict_contains('bbbbbbbb', G['sys'].modules)
 not ___dict_contains('cccccccc', G['sys'].modules)
 str(L['x'].device) == 'cpu'
 str(L['x'].dtype) == 'torch.float32'
-utils_device.CURRENT_DEVICE == None""".split(
-            "\n"
-        ):
+utils_device.CURRENT_DEVICE == None""".split("\n"):
             self.assertIn(
                 line,
                 guard_code_str,

@@ -2806,7 +2804,7 @@ utils_device.CURRENT_DEVICE == None""".split(
             "int",
             np.intp,
             np.int32,
-            np.uint8
+            np.uint8,
             # np.dtype('int') # XXX: as above
         ]
 

@@ -5527,9 +5525,9 @@ utils_device.CURRENT_DEVICE == None""".split(
 
             def forward(self, idx, targets=None):
                 b, t = idx.size()
-                assert (
-                    t <= self.block_size
-                ), "Cannot forward, model block size is exhausted."
+                assert t <= self.block_size, (
+                    "Cannot forward, model block size is exhausted."
+                )
 
                 # forward the GPT model
                 token_embeddings = self.tok_emb(

@@ -6075,15 +6073,17 @@ utils_device.CURRENT_DEVICE == None""".split(
         def count_graph_break_msgs(msgs):
             return sum("Graph break in user code" in msg for msg in msgs)
 
-        with self.assertLogs(
-            logger="torch._dynamo", level=logging.DEBUG
-        ) as log, torch._dynamo.config.patch(verbose=True):
+        with (
+            self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log,
+            torch._dynamo.config.patch(verbose=True),
+        ):
             f1(torch.randn(10), torch.randn(10))
             self.assertGreater(count_graph_break_msgs(log.output), 1)
 
-        with self.assertLogs(
-            logger="torch._dynamo", level=logging.DEBUG
-        ) as log, torch._dynamo.config.patch(verbose=False):
+        with (
+            self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log,
+            torch._dynamo.config.patch(verbose=False),
+        ):
             g1(torch.randn(10), torch.randn(10))
             self.assertEqual(count_graph_break_msgs(log.output), 1)
@@ -8235,8 +8235,9 @@ utils_device.CURRENT_DEVICE == None""".split(
         def f(a):
             return h(a)
 
-        with warnings.catch_warnings(record=True) as w, self.assertRaises(
-            torch._dynamo.exc.BackendCompilerFailed
+        with (
+            warnings.catch_warnings(record=True) as w,
+            self.assertRaises(torch._dynamo.exc.BackendCompilerFailed),
         ):
             f(torch.randn(2, 2, requires_grad=True))
 

@@ -8429,8 +8430,7 @@ utils_device.CURRENT_DEVICE == None""".split(
 
     def test_torch_compile_ctx_on_forward_and_training_step(self):
         class MyModel(torch.nn.Module):
-            def forward(self):
-                ...
+            def forward(self): ...
 
             def training_step(self):
                 self()

@@ -2094,11 +2094,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         mod = MockModule()
         # Each submod is compiled separately and has a different nn module
         # guard. Ensure that recompilation logic is handle correctly.
-        with unittest.mock.patch(
-            "torch._dynamo.config.error_on_recompile", True
-        ), unittest.mock.patch(
-            "torch._dynamo.config.recompile_limit",
-            recompile_limit,
+        with (
+            unittest.mock.patch("torch._dynamo.config.error_on_recompile", True),
+            unittest.mock.patch(
+                "torch._dynamo.config.recompile_limit",
+                recompile_limit,
+            ),
         ):
             x = torch.randn(*size, requires_grad=True)
             mod(x)

@@ -2160,11 +2161,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
         mod = MockModule()
         # Each submod is compiled separately and has a different nn module
         # guard. Ensure that recompilation logic is handle correctly.
-        with unittest.mock.patch(
-            "torch._dynamo.config.error_on_recompile", True
-        ), unittest.mock.patch(
-            "torch._dynamo.config.recompile_limit",
-            recompile_limit,
+        with (
+            unittest.mock.patch("torch._dynamo.config.error_on_recompile", True),
+            unittest.mock.patch(
+                "torch._dynamo.config.recompile_limit",
+                recompile_limit,
+            ),
         ):
             x = torch.randn(*size, requires_grad=True)
             mod(x)

@@ -3,6 +3,7 @@
 PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
 with test_adam in OptimizerTests)
 """
+
 import functools
 
 import torch
@@ -238,9 +238,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
 tensor 'x' size mismatch at index 0. expected 11, actual 12
 tensor 'x' size mismatch at index 0. expected 10, actual 12
 tensor 'x' size mismatch at index 0. expected 9, actual 12
-tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
-            "\n"
-        ):
+tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"):
             self.assertIn(
                 line,
                 failure_str,

@@ -276,9 +274,7 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split(
         opt_f([7, 8])
 
         for line in """\
-len(x) == 3""".split(
-            "\n"
-        ):
+len(x) == 3""".split("\n"):
             self.assertIn(line, filter_reasons())
 
         failure_reasons.clear()

@@ -286,9 +282,7 @@ len(x) == 3""".split(
 
         for line in """\
 len(x) == 2
-len(x) == 3""".split(
-            "\n"
-        ):
+len(x) == 3""".split("\n"):
             self.assertIn(line, filter_reasons())
 
         @torch._dynamo.config.patch(recompile_limit=1)

@@ -179,9 +179,9 @@ def shapes_to_tensor(x, device=None):
     if torch.jit.is_scripting():
         return torch.as_tensor(x, device=device)
     if torch.jit.is_tracing():
-        assert all(
-            isinstance(t, torch.Tensor) for t in x
-        ), "Shape should be tensor during tracing!"
+        assert all(isinstance(t, torch.Tensor) for t in x), (
+            "Shape should be tensor during tracing!"
+        )
         # as_tensor should not be used in tracing because it records a constant
         ret = torch.stack(x)
         if ret.device != device:  # avoid recording a hard-coded device if not necessary

@@ -480,9 +480,9 @@ class PartialT5(torch.nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+            assert len(past_key_value) == 2, (
+                f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+            )
             real_seq_length += (
                 past_key_value[0].shape[2] if query_length is None else query_length
             )

@@ -4877,9 +4877,9 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
         with warnings.catch_warnings(record=True):
             data_len = len(value)
             if len(self._fields):
-                assert (
-                    len(self) == data_len
-                ), f"Adding a field of length {data_len} to a Instances of length {len(self)}"
+                assert len(self) == data_len, (
+                    f"Adding a field of length {data_len} to a Instances of length {len(self)}"
+                )
             self._fields[name] = value
 
     def get(self, name: str) -> Any:
@@ -108,9 +108,10 @@ def get_view_test_cases():
             for requires_grad_1, requires_grad_2 in itertools.product(
                 [True, False], repeat=2
             ):
-                yield partial(
-                    mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
-                ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
+                yield (
+                    partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2),
+                    f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}",
+                )
 
         # (3) obscure case:
         # view is not a leaf (implies requires_grad True)

@@ -118,9 +119,10 @@ def get_view_test_cases():
         yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
 
     # Subclass -> Dense
-    yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
-        0
-    ].clone(), "subclass_dense"
+    yield (
+        lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(),
+        "subclass_dense",
+    )
 
     # Dense -> Subclass -> Dense -> Subclass
     def mk_dense_subclass_dense_subclass():

@@ -151,9 +151,9 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject
                 types.WrapperDescriptorType,
             ),
         ) or is_special_functions(obj):
-            torch_name_rule_map[
-                f"{module.__name__}.{name}"
-            ] = TorchInGraphFunctionVariable
+            torch_name_rule_map[f"{module.__name__}.{name}"] = (
+                TorchInGraphFunctionVariable
+            )
             if c_binding_only:
                 if not hasattr(obj, "__code__"):
                     c_binding_in_graph_functions.add(obj)

@@ -398,12 +398,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
         )
         self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)
 
-        with unittest.mock.patch(
-            "torch._dynamo.trace_rules.torch_name_rule_map",
-            _torch_name_rule_map,
-        ), unittest.mock.patch(
-            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
-            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
+        with (
+            unittest.mock.patch(
+                "torch._dynamo.trace_rules.torch_name_rule_map",
+                _torch_name_rule_map,
+            ),
+            unittest.mock.patch(
+                "torch._dynamo.trace_rules.get_torch_obj_rule_map",
+                torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
+            ),
         ):
             x = torch.rand(3)
             opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)

@@ -419,9 +422,9 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
 
         _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
         # Force inline `mod.func` by setting trace rule.
-        _manual_torch_name_rule_map[
-            f"{mod.__name__}.{func.__name__}"
-        ] = UserFunctionVariable
+        _manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = (
+            UserFunctionVariable
+        )
 
         _torch_name_rule_map = [
             _manual_torch_name_rule_map,

@@ -429,12 +432,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
             torch_non_c_binding_in_graph_functions,
         ]
 
-        with unittest.mock.patch(
-            "torch._dynamo.trace_rules.torch_name_rule_map",
-            _torch_name_rule_map,
-        ), unittest.mock.patch(
-            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
-            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
+        with (
+            unittest.mock.patch(
+                "torch._dynamo.trace_rules.torch_name_rule_map",
+                _torch_name_rule_map,
+            ),
+            unittest.mock.patch(
+                "torch._dynamo.trace_rules.get_torch_obj_rule_map",
+                torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
+            ),
         ):
             # First adding the module to SKIP_DIRS so that it will be skipped by default.
             torch._dynamo.trace_rules.add(mod.__name__)
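The `torch_name_rule_map[...]` hunks above illustrate another recurring rule: when an assignment to a subscripted target overflows the line, ruff format keeps the target intact and parenthesizes the right-hand side rather than splitting inside the square brackets. A sketch with hypothetical names:

```python
rule_map: dict[str, object] = {}
module_name, attr_name = "torch", "add"  # hypothetical key parts

# Black-era layout: the subscript key is split across lines.
rule_map[
    f"{module_name}.{attr_name}"
] = object

# ruff-format layout: the key stays on one line; the value is wrapped.
rule_map[f"{module_name}.{attr_name}"] = (
    object
)
assert rule_map["torch.add"] is object
```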
@@ -593,9 +593,10 @@ class TestDynamoTimed(TestCase):
         )
 
         compilation_events = []
-        with dynamo_config.patch({"automatic_dynamic_shapes": False}), mock.patch(
-            "torch._dynamo.utils.log_compilation_event"
-        ) as log_event:
+        with (
+            dynamo_config.patch({"automatic_dynamic_shapes": False}),
+            mock.patch("torch._dynamo.utils.log_compilation_event") as log_event,
+        ):
 
             @torch.compile()
             def f(x):

@@ -181,9 +181,12 @@ class TestDraftExport(TestCase):
         self.assertEqual(len(report.op_profiles), 1)
         self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1)
 
-        with torch._library.fake_profile.unsafe_generate_fake_kernels(
-            report.op_profiles
-        ), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()):
+        with (
+            torch._library.fake_profile.unsafe_generate_fake_kernels(
+                report.op_profiles
+            ),
+            FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()),
+        ):
             torch.ops.mylib.foo8(*new_inp)
 
         # Existing registration has been updated to match the new

@@ -12834,12 +12834,14 @@ def forward(self, x, y):
             "y": [Dim("dy")],  # y & z incorrect, export is supposed to fail.
             "z": [Dim("dz")],  # suggested fix should be to match these up.
         }
-        with self.assertRaisesRegex(  # if disable=True, suggested fixes should not specialize.
-            torch._dynamo.exc.UserError,
-            r".*Constraints violated(.*\n)*"
-            r"Suggested fixes:(.*\n)*"
-            r".*dz = dy(.*\n)*",
-        ) as msg:
+        with (
+            self.assertRaisesRegex(  # if disable=True, suggested fixes should not specialize.
+                torch._dynamo.exc.UserError,
+                r".*Constraints violated(.*\n)*"
+                r"Suggested fixes:(.*\n)*"
+                r".*dz = dy(.*\n)*",
+            ) as msg
+        ):
             export(
                 Foo(),
                 inputs,

@@ -13675,8 +13677,7 @@ def forward(self, x):
         """Make sure the metadata is kept after exported program run_decompositions."""
 
         @torch.library.custom_op("mylib::add", mutates_args=())
-        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-            ...
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
 
         @torch.library.register_fake("mylib::add")
         def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

@@ -947,10 +947,12 @@ class TestDeserialize(TestCase):
         ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3)))
 
         serialized_program = ExportedProgramSerializer(None, 2).serialize(ep)
-        serialized_program.exported_program.graph_module.signature.input_specs[
-            1
-        ] = schema.InputSpec.create(
-            user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True))
+        serialized_program.exported_program.graph_module.signature.input_specs[1] = (
+            schema.InputSpec.create(
+                user_input=schema.UserInputSpec(
+                    arg=schema.Argument.create(as_none=True)
+                )
+            )
         )
         ep = ExportedProgramDeserializer(None).deserialize(
             serialized_program.exported_program, {}, {}, {}

@@ -22,7 +22,9 @@ from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
         {"strict": False},
         {"strict": True},
     ],
-    class_name_func=lambda cls, _, params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
+    class_name_func=lambda cls,
+    _,
+    params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
 )
 class TestSwap(TestCase):
     def test_unflatten_preserve_signature(self):

@@ -348,8 +348,7 @@ def check_fc(existing_schemas):
                 "\n\t".join(str(s) for s in matching_new_schemas),
             )
             log.warning(
-                "Refer to following reasons for failure "
-                "to find FC schema:\n[\n%s\n]",
+                "Refer to following reasons for failure to find FC schema:\n[\n%s\n]",
                 "\n\t".join(str(r) for r in possible_failure_reasons),
             )
             broken_ops.append(str(existing_schema))
@@ -523,15 +523,15 @@ def decorateForModules(decorator, module_classes, device_type=None, dtypes=None)
         dtypes=dtypes,
     ):
         name_parts = fn.__qualname__.split(".")
-        assert (
-            len(name_parts) == 2
-        ), "Decorator only applies to a test function of a test class"
+        assert len(name_parts) == 2, (
+            "Decorator only applies to a test function of a test class"
+        )
         test_case_name, base_test_name = name_parts
         for module_cls in module_classes:
             matching_module_infos = [m for m in module_db if m.module_cls == module_cls]
-            assert (
-                len(matching_module_infos) == 1
-            ), f"Couldn't find single ModuleInfo for {module_cls}"
+            assert len(matching_module_infos) == 1, (
+                f"Couldn't find single ModuleInfo for {module_cls}"
+            )
             module_info = matching_module_infos[0]
             decorators = list(module_info.decorators)
             new_decorator = DecorateInfo(

@@ -124,9 +124,7 @@ class TestGraphInfoProvider(TestCase):
         )
 
     def test_recomputable_node_only_graph_with_larger_graph_context(self):
-        recomputable_node_only_graph_with_larger_graph_context = (
-            self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context
-        )
+        recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context  # noqa: B950
         expected_nodes = self.all_recomputable_banned_nodes
         # node1 does not have an indirect path to node5 because of node2
         # node2 has an indirect path to node5

@@ -2568,8 +2568,9 @@ def forward(self, primals_1, primals_2):
         def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
             return torch.ops._test._clone_create_graph(x, x1)
 
-        inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn(
-            3, requires_grad=True
+        inp_x, inp_x1 = (
+            torch.randn(3, requires_grad=True),
+            torch.randn(3, requires_grad=True),
         )
 
         ref_x, ref_x1 = inp_x.clone(), inp_x1.clone()

@@ -5283,11 +5284,12 @@ def forward(self, arg0_1):
 
         mod = TestMod(fn)
         inp = torch.randn(2)
-        with patch(
-            "functorch.compile.config.functionalize_rng_ops", True
-        ), self.assertRaisesRegex(
-            RuntimeError,
-            "Functionalized RNG is not currently supported in the aot_export",
+        with (
+            patch("functorch.compile.config.functionalize_rng_ops", True),
+            self.assertRaisesRegex(
+                RuntimeError,
+                "Functionalized RNG is not currently supported in the aot_export",
+            ),
         ):
             aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
             aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)

@@ -4324,9 +4324,7 @@ class TestExamplesCorrectness(TestCase):
 
         def lennard_jones_force(r):
             """Get magnitude of LJ force"""
-            return -epsilon * (
-                (-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)
-            )
+            return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7))
 
         r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device)
         drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device))

@@ -4495,8 +4493,9 @@ class TestExamplesCorrectness(TestCase):
         # This example mimics what a user might do when trying to find the optimal learning rate. They would
         # want to run a bunch of models with the same behavior (including the same dropout!) and have them
         # each run with different learning rates. Specifically, this is an example of using same randomness with vmap
-        points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint(
-            0, 2, (100,), device=device
+        points, labels = (
+            torch.randn(100, 2, 2, 2, 2, device=device),
+            torch.randint(0, 2, (100,), device=device),
         )
 
         class MLPClassifier(nn.Module):
@@ -208,33 +208,33 @@ def check(f, t, delta, check_val=True, graph_input=False):
     old_num_nodes = len(fx_g.graph.nodes)
     new_num_nodes = len(new_graph.nodes)
     if delta == -1:
-        assert (
-            old_num_nodes >= new_num_nodes
-        ), f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
+        assert old_num_nodes >= new_num_nodes, (
+            f"number of nodes increased {old_num_nodes}, {new_num_nodes}"
+        )
     else:
-        assert (
-            old_num_nodes == new_num_nodes + delta
-        ), f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
+        assert old_num_nodes == new_num_nodes + delta, (
+            f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
+        )
 
     # a second pass should not reduce more nodes
     pass_2_graph = fx_graph_cse(new_graph)
     pass_2_num_nodes = len(pass_2_graph.nodes)
-    assert (
-        pass_2_num_nodes == new_num_nodes
-    ), f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
+    assert pass_2_num_nodes == new_num_nodes, (
+        f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
+    )
 
     # check correctness
     if check_val:
         true_result = fx_g(t)
         our_result = new_g(t)
         if true_result is None:  # both return None
-            assert (
-                our_result is None
-            ), f"true result is None, CSE result is {our_result}"
+            assert our_result is None, (
+                f"true result is None, CSE result is {our_result}"
+            )
         else:  # results returned are the same
-            assert torch.all(
-                true_result == our_result
-            ), f"results are different {true_result}, {our_result}"  # check results are the same
+            assert torch.all(true_result == our_result), (
+                f"results are different {true_result}, {our_result}"
+            )  # check results are the same
 
 
 class NoChangeTestCase(TestCase):

@@ -2154,9 +2154,9 @@ class TestOperators(TestCase):
         else:
             weight = torch.randn(weight_shape, device=device)
         target = torch.randint(0, C, target_shape, device=device)
-        target[
-            0
-        ] = 1  # since we're ignoring index 0, at least one element must be non-zero
+        target[0] = (
+            1  # since we're ignoring index 0, at least one element must be non-zero
+        )
 
         fn = functools.partial(
             torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs

@@ -24,6 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
+
 from typing import Any
 from unittest import mock
 

@@ -107,7 +108,7 @@ class TestParsedExpression(TestCase):
         ParsedExpression("(a) ((b c) (d ...))")
 
         # invalid identifiers
-        ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...")
+        ParsedExpression("camelCase under_scored cApiTaLs \u00df ...")
         with self.assertRaises(ValueError):
             ParsedExpression("1a")
         with self.assertRaises(ValueError):

@@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
 
-
 import numpy as np
 
 import torch
@@ -1532,7 +1532,7 @@ class TestVmapOperators(Namespace.TestVmapBase):
         self._test_unary(op, getter, "cpu")
 
         # test in-place
-        method = getattr(Tensor, f'{op.__name__ + "_"}')
+        method = getattr(Tensor, f"{op.__name__ + '_'}")
         self._test_unary(method, getter, "cpu", check_propagates_grad=False)
 
     def test_clone(self):

@@ -2,6 +2,7 @@ r"""
 **This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
 rely on it for anything!**
 """
+
 import operator
 import sys
 

@@ -46,9 +46,9 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
     old_num_nodes = len(fx_g.graph.nodes)
     new_num_nodes = len(new_graph.nodes)
 
-    assert (
-        new_num_nodes < old_num_nodes
-    ) == modified, "modified should be True if the number of nodes decrease"
+    assert (new_num_nodes < old_num_nodes) == modified, (
+        "modified should be True if the number of nodes decrease"
+    )
 
     if delta == -1:
         self.assertTrue(

@@ -1783,8 +1783,9 @@ class TestSingleOperation(unittest.TestCase):
         self.assertEqual(s.check(), z3.sat)
 
         add_result = z3.Const(3, tensor_type)
-        broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const(
-            5, tensor_type
+        broadcast_res1, broadcast_res2 = (
+            z3.Const(4, tensor_type),
+            z3.Const(5, tensor_type),
         )
 
         # print(s.model())
@@ -251,10 +251,13 @@ class TestInvokeSubgraphCompile(TestCase):
         x_clone = x.detach().clone().requires_grad_(True)
         y_clone = y.detach().clone().requires_grad_(True)
         backend = EagerAndRecordGraphs()
-        with mock.patch(
-            "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
-            True,
-        ), torch.no_grad():
+        with (
+            mock.patch(
+                "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation",
+                True,
+            ),
+            torch.no_grad(),
+        ):
             res = torch.compile(fn, backend=backend, fullgraph=True)(
                 mod, x_clone, y_clone
             )

@@ -2399,7 +2402,9 @@ class GraphModule(torch.nn.Module):
         {"strict": False},
         {"strict": True},
     ],
-    class_name_func=lambda cls, _, params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}",
+    class_name_func=lambda cls,
+    _,
+    params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}",
 )
 class TestInvokeSubgraphExport(TestCase):
     def test_simple_func(self):

@@ -38,7 +38,6 @@ USE_BLACK_FILELIST = re.compile(
         # torchgen/**
         # test/**
         # test/[a-h]*/**
-        "test/[a-h]*/**",
         # test/[i-j]*/**
         "test/[i-j]*/**",
         # test/[k-m]*/**