This change adds "/ctrl/" to the name of the dummy constant added by graph partitioning. As a result, it becomes easier to determine in a profile trace whether a Send/Recv edge was added due to a control or data dependency.
PiperOrigin-RevId: 566393419
This tool lets you convert an HloModule from stdin or a file to another format,
run a set of expander passes, and dump the output to stdout or a file.
The tool performs the following steps:
1. Load HloModule from stdin or file.
2. Add a set of passes to the HloPassPipeline.
3. Run a set of passes on the module.
4. Optionally print the output to stdout.
5. Optionally write the output to file in the specified format.
Usage:
hlo-expand \
[--input_format=[hlo|pb|pbtxt]] \
[--optional_flags] \
[path/to/hlo_module]
PiperOrigin-RevId: 566374517
Imported from GitHub PR https://github.com/openxla/xla/pull/5184
This is a follow-up PR that cleans up the redundant, almost-duplicated runners and MLIR ops.
There are currently 4 runners and 3 MLIR ops defined for fmha fwd; a different runner is chosen depending on whether bias/mask is present.
There are currently 2 runners and 2 MLIR ops defined for fmha bwd; a different runner is chosen depending on whether mask is present.
We can merge all fwd runners/MLIR ops into one and all bwd runners/MLIR ops into one using `AttrSizedOperandSegments`, since flash attention will add more optional buffers. There is no point in keeping multiple runners/MLIR ops; that makes the code hard to maintain and introduces bloat.
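As an illustration only, the merged-runner idea can be sketched in C++ with `std::optional` guarding the optional buffers, so one entry point covers every bias/mask combination. All names and the toy math below are hypothetical, not the actual XLA fMHA runner API:

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Illustrative stand-in for a merged fwd runner: bias and mask are
// optional buffers guarded by std::optional, so one runner handles all
// four bias/mask combinations that previously needed separate runners.
std::vector<float> RunFmhaFwd(const std::vector<float>& input,
                              const std::optional<std::vector<float>>& bias,
                              const std::optional<std::vector<float>>& mask) {
  std::vector<float> out = input;
  if (bias) {  // only touch the buffer when the caller provided it
    for (size_t i = 0; i < out.size(); ++i) out[i] += (*bias)[i];
  }
  if (mask) {
    for (size_t i = 0; i < out.size(); ++i) out[i] *= (*mask)[i];
  }
  return out;
}
```

The same shape extends naturally when flash attention adds further optional buffers: each new buffer is one more `std::optional` parameter rather than a doubling of runner variants.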
Copybara import of the project:
--
41063ec170108374701721038f4feb602dfee436 by cjkkkk <ske@nvidia.com>:
clean up fmha runner
--
fa04b2e5396e2cce1ecaad8e28a96a0b6f257cd2 by cjkkkk <ske@nvidia.com>:
fix compilation error
--
f3a25aaebd3110529ab383ceb17652eeb22da5e2 by cjkkkk <ske@nvidia.com>:
rebased on fmha runtime changes
--
0007719d2292d5750b8666ba5e1e5a223a420e38 by cjkkkk <ske@nvidia.com>:
add std::optional to guard optional buffers
--
60437ffbe17bbd13173c2ec773f5bf945ecf6432 by cjkkkk <ske@nvidia.com>:
use more informative uid
--
25aa1baa87abb259e3e0350737c4d39a715183a0 by cjkkkk <ske@nvidia.com>:
fix data_ptrs_vec.size() == data_uids_vec.size() check
Merging this change closes #5184
PiperOrigin-RevId: 566359023
1) Move `_copy_trackable_to_cpu` to `ShardedVariable` (from `ShardedVariableMixin`, which is inherited by other objects).
2) Fix a bug that excluded handling of the TPUEmbedding object.
PiperOrigin-RevId: 566355521
I stumbled upon this just by chance and found a solution.
Here is how I understand it:
- Multiple instances of SnapshotManager share one instance of SnapshotAssignmentManager.
- SnapshotManager recently became thread-safe and multiple instances are being used from multiple threads.
- SnapshotAssignmentManager is NOT thread-safe, but it is used from multiple instances of SnapshotManager on multiple threads.
My fix just adds a mutex to SnapshotAssignmentManager.
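A minimal sketch of the fix, with hypothetical class and method names (not the actual tf.data service API): every access to the shared state is serialized by a mutex, making the shared manager safe to use from multiple threads.

```cpp
#include <cassert>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

// Hypothetical sketch: a single assignment manager shared by several
// snapshot managers running on different threads. Every access to the
// shared map takes the mutex, so concurrent calls no longer race.
class AssignmentManager {
 public:
  void Assign(const std::string& snapshot, int stream) {
    std::lock_guard<std::mutex> lock(mu_);  // serialize concurrent writers
    assignments_[snapshot] = stream;
  }
  size_t NumAssignments() {
    std::lock_guard<std::mutex> lock(mu_);  // readers take the lock too
    return assignments_.size();
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, int> assignments_;
};
```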
PiperOrigin-RevId: 566347773
Add a test for startup fault tolerance.
Make the connection step more explicit; previously, the connection happened only within context initialization during a list_logical_devices call, which was unintuitive.
PiperOrigin-RevId: 566344263
PR #5300: A new pass to optimize the AllGather->Binary_Op order sequence
Imported from GitHub PR https://github.com/openxla/xla/pull/5300
This is a new GPU SPMD optimization pass that rewrites the following pattern:
binary-op(all-gather(a), all-gather(b))
to
all-gather(binary-op(a, b))
PiperOrigin-RevId: 566340142
In the error case, `snapshot_manager_` will not be deleted.
To fix that, we can create the `unique_ptr` before calling `Start` instead of afterwards.
I just saw this by chance when looking into cl/566268125.
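The real code presumably propagates a Status rather than throwing, but the ownership bug can be sketched with exceptions; all names below are illustrative, not the actual tf.data service classes:

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Illustrative manager that counts live instances so a leak is visible.
struct Manager {
  static int live;
  Manager() { ++live; }
  ~Manager() { --live; }
  // Stand-in for the Start step that can fail.
  void Start(bool fail) {
    if (fail) throw std::runtime_error("start failed");
  }
};
int Manager::live = 0;

// Buggy order: if Start() fails, the raw pointer is never deleted.
void CreateLeaky(bool fail) {
  Manager* m = new Manager();
  m->Start(fail);  // on failure, control leaves before ownership transfer
  std::unique_ptr<Manager> owned(m);
}

// Fixed order: wrap in a unique_ptr first, so every exit path deletes it.
void CreateSafe(bool fail) {
  auto owned = std::make_unique<Manager>();
  owned->Start(fail);  // on failure, owned still destroys the Manager
}
```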
PiperOrigin-RevId: 566334445