Summary:
Minor logging cleanup in distributed library
1. Don't use f-string formatting in logging calls, addressing the linter issues (see the sketch after this list).
2. Nit: make use of the otherwise-unused `e` (error) in a few logs.
3. Change info -> debug logging as requested in issue #113545.
4. Nit: rename log -> logger in a few files for consistency
5. Fix a linter error.
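A minimal sketch of the pattern these changes move toward (illustrative helper and names, not the actual diff): lazy `%`-style logging arguments instead of f-strings, the caught exception surfaced in the message, `info` downgraded to `debug`, and `logger` as the module-level name.
```python
import logging

logger = logging.getLogger(__name__)  # renamed from `log` for consistency


def _notify_store(store, key: str) -> None:
    # Hypothetical helper for illustration only.
    try:
        store.set(key, "1")
    except RuntimeError as e:
        # Before: logger.info(f"store set failed for {key}")  # f-string, `e` unused
        # After: lazy %-formatting (linter-friendly), error included, info -> debug.
        logger.debug("store set failed for %s: %s", key, e)
```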
Test Plan:
1. Local build passes.
2. Linter is happy.
Reviewers: wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122921
Approved by: https://github.com/wanchaol
**Expected behavior**: when rank 0 has an inf grad, ranks 1...k should get `found_inf=1` after `dist.all_reduce`.
**Bug addressed in this PR**: for a CPU-offloaded `param.grad`, when rank 0 has an inf, ranks 1...k would not get `found_inf=1`, because `found_inf` was copied before `future.wait` on the async `dist.all_reduce`.
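A minimal sketch of the ordering the fix enforces (a hypothetical helper, not the actual `ShardedGradScaler` code): complete the async reduction before reading or copying `found_inf`.
```python
import torch
import torch.distributed as dist


def _sync_found_inf(found_inf: torch.Tensor) -> torch.Tensor:
    # Async all_reduce returns a handle; the result is only valid after wait().
    work = dist.all_reduce(found_inf, async_op=True)
    # Bug pattern: copying found_inf here (e.g. back to the CPU-offload device),
    # before wait(), can miss the inf reported by another rank.
    work.wait()
    # Safe: copy only after the reduction has completed on this rank.
    return found_inf.clone()
```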
Repro the bug using the newly added unit test: `pytest test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py -k test_sharded_grad_scaler_found_inf`
```
File "/data/users/weif/pytorch/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py", line 320, in _test_sharded_grad_scaler_found_inf
self.assertEqual(
File "/data/users/weif/pytorch/torch/testing/_internal/common_utils.py", line 3576, in assertEqual
raise error_metas.pop()[0].to_error(
AssertionError: Scalars are not close!
Expected 1.0 but got 2.0.
Absolute difference: 1.0 (up to 1e-05 allowed)
Relative difference: 1.0 (up to 1.3e-06 allowed)
rank: 0 iter: 0 expect origin scale 2.0 to be backed off by 0.5 but got 2.0
```
Verify the bug is fixed: `pytest test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py -k test_sharded_grad_scaler_found_inf`
```
test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py dist init r=1, world=8
dist init r=3, world=8
dist init r=7, world=8
dist init r=4, world=8
dist init r=6, world=8
dist init r=2, world=8
dist init r=0, world=8
dist init r=5, world=8
NCCL version 2.19.3+cuda12.0
. [100%]
====================================================================== 1 passed, 19 deselected in 27.43s =========================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115710
Approved by: https://github.com/awgu
Fixes [#82206](https://github.com/pytorch/pytorch/issues/82206)
When executing a `ShardedGradScaler` step in the context of `cpu_offload`, [the function](ecd2c71871/torch/distributed/fsdp/sharded_grad_scaler.py (L151-L152)) `_foreach_non_finite_check_and_unscale_cpu_` is grindingly slow. This issue is due to the elementwise op dispatching/redispatching/execution that is engendered by the current approach to gradient tensor validation:
ecd2c71871/torch/distributed/fsdp/sharded_grad_scaler.py (L159-L163)
The subsequent `isinf` and `isnan` checks with associated `any` checks result in unscalable elementwise op dispatches:
ecd2c71871/torch/distributed/fsdp/sharded_grad_scaler.py (L173-L181)
This inefficiency is of course hidden in the current FSDP tests given their (appropriately) trivial parameter dimensionality. In the perf analysis below, the example test configures only the final `Linear(4, 8)` module's parameters to require grad, so there are just 40 elements to iterate through. However, if one increases the dimensionality to a still-modest 320008 elements (changing the final module to `Linear(40000, 8)`), the execution time/CPU cost of the test is dominated by the elementwise op dispatching/redispatching/execution of the `any` validation ops in this function.
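To make the dispatch blow-up concrete, here is a hedged illustration of the two patterns (not the exact pre-/post-PR code): iterating a flattened gradient yields 0-d tensors, so every `isinf`/`isnan`/`any` call incurs a full dispatcher round-trip per element, whereas the whole-tensor formulation dispatches a constant number of ops per gradient.
```python
import torch

grad = torch.randn(4_096)  # stand-in for a flattened, CPU-offloaded grad shard

# Per-element pattern (analogous to the pre-PR loop): the number of op
# dispatches scales with numel(), which is what dominates the profile below.
found_inf_slow = any(
    bool(torch.isinf(v).any()) or bool(torch.isnan(v).any()) for v in grad
)

# Whole-tensor pattern: two reduction ops per gradient tensor, independent of numel().
found_inf_fast = bool(torch.isinf(grad).any()) or bool(torch.isnan(grad).any())

assert found_inf_slow == found_inf_fast
```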
To characterize the current behavior, I use a slightly modified version of an existing `ShardedGradScaler` test [^1]. The following modifications to the test are made to allow the analysis:
1. Run just `CUDAInitMode.CUDA_BEFORE` for clarity instead of additional scenarios
2. Increase the final module to `Linear(40000, 8)`, along with modifying the preceding module to make the dimensions work (see the parameter-count snippet after this list)
3. For the cProfile run (but not valgrind or perf) the test runs just a single [`_train_for_several_steps`](ecd2c71871/torch/testing/_internal/common_fsdp.py (L926-L934)) step per rank (instead of 2 steps)
4. I temporarily reduce `init_scale` further to ensure we don't hit any `inf`s, which would short-circuit the loop and hence our analysis
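For concreteness, the dimensionality bump in modification 2 amounts to the following (parameter counts only; the actual test model lives in `torch/testing/_internal/common_fsdp.py`):
```python
import torch.nn as nn

small = nn.Linear(4, 8)       # original final module: 4 * 8 + 8 = 40 elements
large = nn.Linear(40_000, 8)  # modified final module: 40_000 * 8 + 8 = 320_008 elements

print(sum(p.numel() for p in small.parameters()))  # 40
print(sum(p.numel() for p in large.parameters()))  # 320008
```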
### Current behavior
The most relevant call subgraph:

Note that:
1. Instead of dispatching to the relevant autograd op and then redispatching to the relevant CPU op implementation 8 times per test (2 train steps x 2 `any` calls per parameter per step x 2 original parameters), we (I believe unnecessarily) run that dispatch flow elementwise, so 640016 times! (Only 1 node in this trace, so 320008 elements / 2 x 2 train steps x 2 calls per element per step.)
2. Nearly 50% of the relative (inclusive) instruction reads for the entire test in `callgrind` are executed by the `isnan` (320008 execs), `isinf` (320008 execs) and `any` (640016 execs) calls.
3. The `any` pre-dispatch entry point IRs (`torch::autograd::THPVariable_any`) vs actual op implementation IRs (`at::native::structured_any_all_out::impl`) are below to give one a sense of the relative dispatch and op execution cost in an elementwise context[^3].


Using cprofile stats:
```bash
python -c "import pstats; stats=pstats.Stats('/tmp/fsdp_cprofile_8wa9uw39.stats'); stats.print_stats()"
...
ncalls tottime percall cumtime percall filename:lineno(function)
1 20.159 20.159 66.805 66.805 torch/distributed/fsdp/sharded_grad_scaler.py:151(_foreach_non_finite_check_and_unscale_cpu_)
160004 18.427 0.000 18.427 0.000 {built-in method torch.isinf}
160004 6.026 0.000 6.026 0.000 {built-in method torch.isnan}
```
We see that a single step of the scaler runs for more than a minute. Though there is non-trivial cProfile overhead, we can infer from this that each per-element op dispatch/execution costs on the order of 100 ns.
A dispatch cost on that order is acceptable with typical tensor access patterns, but if we dispatch per element for each op, everything obviously comes to a grinding halt for many practical use cases.
(Given the cost of this function is currently O(n) in the number of gradient elements, feel free to set `TORCH_SHOW_DISPATCH_TRACE=1` if you want to make this function cry 🤣)
I've attached a flamegraph at the bottom of the PR[^2] that more intuitively demonstrates the manner and extent of resource consumption attributable to this function with just a modest number of gradient elements.
### After the loop refactor in this PR:
The most relevant call subgraph:

Note that:
1. Less than 0.4% of the relative (inclusive) instruction reads for the entire test in `callgrind` are executed by the `isnan` (4 execs), `isinf` (4 execs) and `any` (8 execs) calls (versus ~50% and 320008, 320008, 640016 respectively above)
2. The `any` pre-dispatch entry point IRs (`torch::autograd::THPVariable_any`) vs actual op implementation IRs (`at::native::structured_any_all_out::impl`) reflect far less overhead (of secondary importance to item number 1)


Using cprofile stats:
```bash
python -c "import pstats; stats=pstats.Stats('/tmp/fsdp_cprofile_pfap7nwk.stats'); stats.print_stats()"
...
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.013 0.013 0.109 0.109 torch/distributed/fsdp/sharded_grad_scaler.py:151(_foreach_non_finite_check_and_unscale_cpu_)
2 0.022 0.011 0.022 0.011 {built-in method torch.isinf}
2 0.018 0.009 0.018 0.009 {built-in method torch.isnan}
```
We can see our function runtime has dropped from more than a minute to ~100ms.
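Roughly, the refactored validation iterates over whole gradient tensors rather than their elements, along these lines (a hedged sketch of the shape of the change, not a verbatim copy of the new `_foreach_non_finite_check_and_unscale_cpu_`):
```python
import torch


def _check_and_unscale(grads, found_inf: torch.Tensor, inv_scale: torch.Tensor) -> None:
    # A constant number of isinf/isnan/any dispatches per gradient tensor
    # (consistent with the 4/4/8 call counts above), instead of one per element.
    for grad in grads:
        if torch.isinf(grad).any() or torch.isnan(grad).any():
            found_inf.fill_(1.0)
            break
        grad.mul_(inv_scale)
```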
### Assumptions associated with this loop refactor:
The key assumptions here are:
1. The grads are always on CPU in this function so any MTA-safe constraints ([`can_use_fast_route`](efc3887ea5/aten/src/ATen/native/cuda/AmpKernels.cu (L110-L111)) relating to the relevant CUDA kernel path selection, i.e. slower `TensorIterator` gpu kernel vs `multi_tensor_apply_kernel`) do not apply in this context
2. We've already filtered by dtype and device and can assume the presence of a single CPU device. Unless separate CPU devices are manually created with non-default indexes (which I don't think FSDP supports and which should be validated prior to this function), device equality should always be `True` for `cpu`-type devices, so we should just need to check that the current device is of `cpu` type (see the snippet after this list).[^4]
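A minimal illustration of assumption 2 (a hypothetical check, not necessarily the code in this PR): compare only the device *type* rather than full device equality.
```python
import torch


def _all_on_cpu(grads) -> bool:
    # Index comparison is skipped: CPU device indices are normally unset, so once
    # grads have been grouped by dtype and device, a type check is sufficient.
    return all(g.device.type == "cpu" for g in grads)
```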

[^1]: `TestShardedGradScalerParityWithDDP.test_fsdp_ddp_parity_with_grad_scaler_offload_true_none_mixed_precision_use_orig_params` test in `test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py`
[^2]: Note the native frame stacks for `torch::autograd::THPVariable_isinf`, `torch::autograd::THPVariable_isnan`, `torch::autograd::THPVariable_any` in particular.
[^3]: There's more `TensorIterator` etc. setup overhead further up the stack beyond `structured_any_all_out`, but roughly speaking this gives a sense of the split between dispatch and op execution cost.
[^4]: Device equality is based on the [type and index combination](efc3887ea5/c10/core/Device.h (L47-L51)); the CPU device index is -1 by default (`None` on the Python side) and is intended to [always be 0](cf21240f67/c10/core/Device.h (L29)) if set explicitly. Though technically, unless in debug mode, this constraint isn't [actually validated](bb4e9e9124/c10/core/Device.h (L171-L184)), so one can manually create separate `cpu` devices with invalid indices. I suspect it's safe to ignore that potentially incorrect/unusual configuration in this context, but let me know if you'd like to add another `cpu` device equality check.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100108
Approved by: https://github.com/awgu
Fixes #99174
## Enable FSDP ``use_orig_params=True`` mixed precision training when some ranks have no (non-zero sized) parameter shards
### The issue
Now that ``use_orig_params=True`` allows non-uniform ``requires_grad`` (🎉🚀 thanks @awgu!!!) with [#98221](https://github.com/pytorch/pytorch/pull/98221), there will be circumstances wherein some ranks have no (non-zero sized) local shards of the original parameters (and hence no associated gradients).
### Use Cases
For a simple Transformer case, imagine a user wraps all encoder layers in separate FSDP instances but allows the classifier head to be wrapped in the same FSDP instance as the relatively large embeddings layers. While this is a sub-optimal wrapping strategy for most use-cases, I believe it is expected to be supported (full precision training works in that context).
I originally encountered this issue while extending a package I maintain, leveraging the relaxed ``requires_grad`` constraint to simplify multi-phase scheduled fine-tuning FSDP configuration, so a [concrete example is there](https://finetuning-scheduler.readthedocs.io/en/latest/advanced/fsdp_scheduled_fine_tuning.html#basic-scheduled-fine-tuning-with-fsdp).
### Reproduction and Remediation
Currently, ``ShardedGradScaler`` does not accommodate these situations, failing to initialize ``optimizer_state["found_inf_per_device"]`` when ``unscale_`` is called.
In this PR, I extend the existing ``ShardedGradScaler`` tests with a ``use_orig_params=True`` dimension added to the parameterization and test scenarios wherein one rank possesses no (non-zero sized) parameter shards.
The relevant issue can be reproduced with the tests I'm adding in this PR. The current (pre-PR) execution of these tests fails in ``use_orig_params=True`` mode with this error:
```python
./test_fsdp_sharded_grad_scaler.py::TestShardedGradScalerParityWithDDP::test_fsdp_ddp_parity_with_grad_scaler_offload_false_none_mixed_precision_use_orig_params Failed with Error: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 657, in run_test
getattr(self, test_name)()
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 543, in wrapper
fn()
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_utils.py", line 259, in instantiated_test
test(self, **param_kwargs)
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
return func(*args, **kwargs)
File "/home/speediedan/repos/pytorch/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py", line 187, in test_fsdp_ddp_parity_with_grad_scaler
self._test_fsdp_parity(
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_fsdp.py", line 1152, in _test_fsdp_parity
fsdp_loss = self._train_for_several_steps(
File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_fsdp.py", line 1016, in _train_for_several_steps
sharded_grad_scaler.step(optim)
File "/home/speediedan/repos/pytorch/torch/distributed/fsdp/sharded_grad_scaler.py", line 291, in step
return super().step(optimizer, *args, **kwargs)
File "/home/speediedan/repos/pytorch/torch/cuda/amp/grad_scaler.py", line 368, in step
assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
```
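The gist of the remediation, as a rough sketch (hypothetical helper name; the actual change lives in ``ShardedGradScaler``): record a zeroed ``found_inf`` entry for the rank even when it has no gradients to check, so the assertion in ``GradScaler.step`` passes and the rank still participates in the subsequent ``all_reduce``.
```python
import torch


def _ensure_found_inf_recorded(optimizer_state: dict, device: torch.device) -> None:
    # Hypothetical helper: a rank holding no (non-zero sized) parameter shards has
    # no grads to unscale, but "found_inf_per_device" must still be populated so
    # that step()'s "No inf checks were recorded" assertion is satisfied and the
    # rank does not drop out of the found_inf all_reduce.
    if not optimizer_state["found_inf_per_device"]:
        optimizer_state["found_inf_per_device"][device] = torch.full(
            (1,), 0.0, dtype=torch.float32, device=device
        )
```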
A few implementation notes/considerations and questions:
1. Rather than just initialize ``per_device_found_inf``, one could disable the grad scaler altogether for the relevant ranks, altering ``unscale_`` to reduce with a subgroup or some rank-mask construct so that the ``all_reduce``s in ``distributed/fsdp/sharded_grad_scaler.py:unscale_()`` do not hang. Given that users may subsequently add parameter groups to an optimizer that would require re-enabling the scaler, and the complexity associated with maintaining a separate mask construct or process subgroup, I thought this implementation was cleaner.
2. I extended ``_train_for_several_steps`` and ``_test_fsdp_parity`` in ``/torch/testing/_internal/common_fsdp.py`` with the ability to configure ``sharded_grad_scaler_kwargs`` for future testing flexibility.
3. Should the user be warned that no parameter shards were associated with a given rank? My initial thought is that this should be considered an implementation detail, part of supporting ``use_orig_params`` with heterogeneous ``requires_grad``, and therefore should be transparently handled by PyTorch. Should a DEBUG level message be added? If so, likely further upstream rather than at the scaler step level.
4. Rather than extend the existing ``ShardedGradScaler`` tests with a ``use_orig_params=True`` dimension added to the parameterization, let me know if you prefer that I instead narrow the scope of the new testing to a single additional test, e.g.:
```python
# from typing import Optional
from typing import Optional, List
# ...
# use_orig_params = ["enable_use_orig_params", None]
use_orig_params: List[Optional[str]] = [None]
# ...
configs = list(itertools.product(cpu_offload_config, sharding_strategy_config, mixed_precision, use_orig_params))
configs.append((CPUOffload(offload_params=False), None, "enable_mixed_precision", "enable_use_orig_params"))
```
Thanks as always to the PyTorch distributed team for your astonishingly impressive and valuable contributions to the open-source ML engineering community!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99175
Approved by: https://github.com/awgu
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.
There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. Run `lintrunner f` before you finalize your PR; this is now enforced by CI after this PR.
The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.
---
I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma