[SymmMem] Tiled reduce (#162243)
Added op: `tile_reduce(Tensor in_tile, Tensor(a!) out_tile, int root, str group_name, str reduce_op='sum')`
For now it supports only:
- NVSHMEM-backed symmetric tensors;
- 2D tensors and tiles;
- floating dtypes (`torch.float`, `torch.half`, `torch.bfloat16`);
- `reduce_op="sum"`.
Testing on the bottom-right quadrant:
```
rank 0:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.]], device='cuda:0')
PASSED
```
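For reference, a minimal sketch of invoking the new op from Python, modeled on the unit test added in this PR; it assumes a distributed process group and CUDA device selection are already set up on each rank (e.g. by the multi-process test framework):

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Sketch only: assumes dist.init_process_group() and per-rank device
# selection have already happened.
symm_mem.set_backend("NVSHMEM")
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

full = symm_mem.empty(1024, 1024, dtype=torch.float, device="cuda").fill_(dist.get_rank())
out = symm_mem.empty(1024, 1024, dtype=torch.float, device="cuda").fill_(0)

# Reduce a 512x512 tile into rank 0; reduce_op defaults to "sum".
t = slice(0, 512)
torch.ops.symm_mem.tile_reduce(full[t, t], out[t, t], 0, group_name)
```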
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162243
Approved by: https://github.com/ngimel
This commit is contained in:
parent 3040a5d294
commit d444384003
benchmarks/distributed/bench_nvshmem_tile_reduce.py (new file, 191 lines)
@ -0,0 +1,191 @@
```python
#!/usr/bin/env python3
"""
Benchmark for NVSHMEM tile reduce operations.

Usage:
    python benchmarks/distributed/bench_nvshmem_tile_reduce.py

This benchmark measures the performance of tile reduce operations across different
matrix sizes and tile configurations.
"""

import time

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import (
    requires_cuda_p2p_access,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


# Decorator
def requires_nvshmem():
    return skip_but_pass_in_sandcastle_if(
        not symm_mem.is_nvshmem_available(),
        "bench_nvshmem_tile_reduce requires NVSHMEM, skipping benchmark",
    )


# So that benchmarks are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)


@requires_nvshmem()
@requires_cuda_p2p_access()
class NVSHMEMTileReduceBenchmark(MultiProcContinuousTest):
    def _init_device(self) -> None:
        # TODO: relieve this (seems to hang if without)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    def _benchmark_tile_reduce_single(
        self,
        full_size: int,
        tile_size: int,
        warmup_iters: int = 5,
        bench_iters: int = 10,
    ) -> dict:
        """
        Benchmark a single configuration of tile reduce.

        Args:
            full_size: Size of the full matrix (full_size x full_size)
            tile_size: Size of the tile (tile_size x tile_size)
            warmup_iters: Number of warmup iterations
            bench_iters: Number of benchmark iterations

        Returns:
            Dictionary with benchmark results
        """
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float

        # Allocate full matrices
        full_inp = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(self.rank)
        full_out = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(0)

        slice_ut = slice(0, tile_size)
        inp_tile = full_inp[slice_ut, slice_ut]
        out_tile = full_out[slice_ut, slice_ut]

        root = 0

        # Warmup iterations
        for _ in range(warmup_iters):
            torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)
        torch.cuda.synchronize(self.device)

        # Benchmark iterations
        times = []

        dist.barrier()
        torch.cuda.synchronize(self.device)
        start_time = time.perf_counter()

        for _ in range(bench_iters):
            torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)

        torch.cuda.synchronize(self.device)
        end_time = time.perf_counter()
        times.append((end_time - start_time) / bench_iters)

        # Calculate statistics
        times = torch.tensor(times, dtype=torch.float64)
        tile_elements = tile_size * tile_size
        tile_bytes = (
            tile_elements * dtype.itemsize
            if hasattr(dtype, "itemsize")
            else tile_elements * 4
        )

        results = {
            "full_size": full_size,
            "tile_size": tile_size,
            "tile_elements": tile_elements,
            "tile_bytes": tile_bytes,
            "world_size": self.world_size,
            "mean_time_ms": times.mean().item() * 1000,
            "std_time_ms": times.std().item() * 1000,
            "min_time_ms": times.min().item() * 1000,
            "max_time_ms": times.max().item() * 1000,
            "throughput_gb_s": tile_bytes / (times.mean().item() * 1e9),
            "elements_per_sec": tile_elements / times.mean().item(),
        }

        return results

    @skipIfRocm
    def test_benchmark_tile_reduce_various_sizes(self) -> None:
        """
        Benchmark tile reduce across various matrix sizes.
        """
        # Test various matrix sizes
        tile_sizes = [512, 1024, 2048, 4096, 8192, 16384]
        full_size = tile_sizes[-1]
        warmup_iters = 5
        bench_iters = 20

        results = []

        for tile_size in tile_sizes:
            try:
                result = self._benchmark_tile_reduce_single(
                    full_size, tile_size, warmup_iters, bench_iters
                )
                results.append(result)

                if self.rank == 0:
                    print(
                        f"Matrix Size: {full_size}x{full_size}, Tile Size: {tile_size}x{tile_size}"
                    )
                    print(
                        f"  Mean Time: {result['mean_time_ms']:.3f} ± {result['std_time_ms']:.3f} ms"
                    )
                    print(f"  Throughput: {result['throughput_gb_s']:.2f} GB/s")
                    print(f"  Bytes: {result['tile_bytes']:.0f}")
                    print()

            except Exception as e:
                if self.rank == 0:
                    print(f"Failed to benchmark matrix size {full_size}: {e}")

        # Print summary
        if self.rank == 0 and results:
            print("=== BENCHMARK SUMMARY ===")
            print(
                f"{'Matrix Size':<12} {'Tile Size':<10} {'Time (ms)':<12} {'Throughput (GB/s)':<18} {'Bytes':<15}"
            )
            print("-" * 70)

            for result in results:
                print(
                    f"{result['full_size']}x{result['full_size']:<7} "
                    f"{result['tile_size']}x{result['tile_size']:<5} "
                    f"{result['mean_time_ms']:<12.3f} "
                    f"{result['throughput_gb_s']:<18.2f} "
                    f"{result['tile_bytes']:<15.0f}"
                )


if __name__ == "__main__":
    # For standalone usage, you'd need to set up the distributed environment.
    # For now, this is meant to be run via the PyTorch test framework.
    from torch.testing._internal.common_utils import run_tests

    run_tests()
```
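For intuition on the `throughput_gb_s` metric above: it divides the tile's byte count by the mean per-iteration time. A quick worked example with an illustrative (not measured) timing:

```python
# Worked example of the throughput metric reported by the benchmark
# (the 1 ms timing is assumed for illustration, not a measurement).
tile_size = 16384
tile_bytes = tile_size * tile_size * 4      # float32 tile: ~1.07e9 bytes
mean_time_s = 1e-3                          # assume 1 ms mean per iteration
throughput_gb_s = tile_bytes / (mean_time_s * 1e9)
print(f"{throughput_gb_s:.0f} GB/s")        # ~1074 GB/s
```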
@ -701,5 +701,54 @@ class DispatchCombineInSubgroups(MultiProcContinuousTest):
```python
        dispatch_then_combine(self.device, align=8, group=subgroup)


@instantiate_parametrized_tests
@requires_nvshmem()
@requires_cuda_p2p_access()
class NVSHMEMTileCommTest(MultiProcContinuousTest):
    def _init_device(self) -> None:
        # TODO: relieve this (seems to hang if without)
        device_module.set_device(self.device)
        # Set NVSHMEM as SymmMem backend
        symm_mem.set_backend("NVSHMEM")

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    @parametrize("tile_size", [32, 128, 512])
    @parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
    def test_tile_reduce(self, tile_size: int, dtype: torch.dtype) -> None:
        full_size = 1024
        assert tile_size <= full_size

        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        full_inp = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(self.rank)
        full_out = symm_mem.empty(
            full_size, full_size, dtype=dtype, device=self.device
        ).fill_(0)

        slice_ut = slice(tile_size, 2 * tile_size)
        inp_tile = full_inp[slice_ut, slice_ut]
        out_tile = full_out[slice_ut, slice_ut]

        # Reduce the tile
        root = 0
        torch.ops.symm_mem.tile_reduce(inp_tile, out_tile, root, group_name)

        # Check data: the root's tile should hold the sum of all ranks,
        # i.e. 0 + 1 + ... + (world_size - 1)
        expected = torch.zeros_like(full_out)
        expected_tile = expected[slice_ut, slice_ut]
        if self.rank == root:
            expected_tile.fill_(self.world_size * (self.world_size - 1) / 2)

        torch.testing.assert_close(full_out, expected)


if __name__ == "__main__":
    run_tests()
```
@ -510,6 +510,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
```cpp
      "all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor in_splits, Tensor(a!) out_splits_offsets, str group_name, int? major_align=None) -> ()");
  m.def(
      "all_to_all_vdev_2d_offset(Tensor input, Tensor(a!) out, Tensor in_splits_offsets, Tensor(a!) out_splits_offsets, str group_name) -> ()");
  m.def(
      "tile_reduce(Tensor in_tile, Tensor(a!) out_tile, int root, str group_name, str reduce_op='sum') -> ()");
}

TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
```
@ -9,6 +9,7 @@
```cpp
#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>

#include <ATen/ceil_div.h>
// Use torch's cub wrapper instead of CUDA's <cub/cub.cuh>, see #55292
#include <ATen/cuda/cub.cuh>
```
@ -863,6 +864,128 @@ void all_to_all_vdev_2d_offset(
```cpp
      0,
      stream);
}

/* Tiled Communication */

using Shape2D = nvshmemx::shape<int64_t, int64_t>;
using Stride2D = nvshmemx::stride<int64_t, int64_t>;

template <typename T>
__global__ void tile_reduce_kernel(
    T* src_ptr, T* dst_ptr, Shape2D shape, Stride2D strides, int64_t root, nvshmem_team_t* teams) {
#ifndef _NVSHMEM_DEVICELIB_SUPPORTED
  CUDA_KERNEL_ASSERT_MSG(false, "SM arch unsupported for NVSHMEM");
#else
  int bid = blockIdx.x;
  auto team = teams[bid];
  CUDA_KERNEL_ASSERT(team != NVSHMEM_TEAM_INVALID && " invalid team\n");

  // Global tile shape
  auto [rows, cols] = shape;
  auto [stride0, stride1] = strides;

  // Divide rows among CUDA blocks
  auto rows_per_block = at::ceil_div(rows, (int64_t)gridDim.x);
  auto block_start_row = rows_per_block * bid;
  auto block_shape = nvshmemx::make_shape(std::min(rows_per_block, rows - block_start_row), cols);
  auto block_layout = nvshmemx::make_layout(block_shape, strides);

  // Start pointer of each block's sub-tile
  auto block_src_ptr = src_ptr + stride0 * block_start_row;
  auto block_dst_ptr = dst_ptr + stride0 * block_start_row;
  auto block_src_tensor = nvshmemx::Tensor(block_src_ptr, block_layout);
  auto block_dst_tensor = nvshmemx::Tensor(block_dst_ptr, block_layout);

  // Leave these empty to keep nvshmemx::tile_sum_reduce_block() from doing
  // additional range checks
  auto start_coord = nvshmemx::make_shape();
  auto boundary = nvshmemx::make_shape();

  // Use one-shot pull to reduce the tile
  uint64_t flag = 0;
  constexpr auto algo = nvshmemx::tile_coll_algo_t::NVLS_ONE_SHOT_PULL_NBI;
  nvshmemx::tile_sum_reduce_block<decltype(block_src_tensor), decltype(block_dst_tensor), decltype(boundary), algo>(
      team, block_src_tensor, block_dst_tensor, start_coord, boundary, root, flag /* unused */);

  // Wait for the operation to complete
  nvshmemx::tile_collective_wait<algo>(team, flag /* unused */);
#endif
}

#define AT_DISPATCH_CASE_CONVERT(enum_type, scalar_type, ...) \
  case enum_type: {                                           \
    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);              \
    using scalar_t = scalar_type;                             \
    return __VA_ARGS__();                                     \
  }

#define AT_DISPATCH_NVSHMEM_FLOATS(scalar_type, name, ...)                 \
  AT_DISPATCH_SWITCH(                                                      \
      scalar_type, name,                                                   \
      AT_DISPATCH_CASE_CONVERT(at::kBFloat16, __nv_bfloat16, __VA_ARGS__); \
      AT_DISPATCH_CASE_CONVERT(at::kHalf, __half, __VA_ARGS__);            \
      AT_DISPATCH_CASE(at::kFloat, __VA_ARGS__));

void tile_reduce(
    at::Tensor& in_tile,
    at::Tensor& out_tile,
    int64_t root,
    std::string group_name,
    std::string reduce_op) {
  /* Perform a tile reduce operation on the input tensor, with the root rank
   * receiving the reduced tensor. */
  TORCH_CHECK(reduce_op == "sum", "tile_reduce: only sum is supported for now");
  TORCH_CHECK(in_tile.dim() == 2 && out_tile.dim() == 2, "Only 2D tensors are supported");
  TORCH_CHECK_EQ(in_tile.dtype(), out_tile.dtype());
  TORCH_CHECK_EQ(in_tile.sizes(), out_tile.sizes());
  TORCH_CHECK_EQ(in_tile.strides(), out_tile.strides());
  TORCH_CHECK_EQ(in_tile.device(), out_tile.device());

  auto device = in_tile.device();
  c10::cuda::CUDAGuard guard(device);
  auto hdl = c10d::symmetric_memory::rendezvous(in_tile, group_name);
  c10d::symmetric_memory::rendezvous(out_tile, group_name);

  // Ideally 16 bytes per thread
  int nblocks = at::ceil_div(
      in_tile.numel() * in_tile.element_size(),
      (int64_t)THREADS_PER_BLOCK * 16);
  nblocks = std::min(nblocks, 24);

  // Need one team per block
  auto& team_manager = TeamManager::get(device);
  auto [teams, teams_dev] = team_manager.get_n_teams(
      group_name, hdl->get_rank_to_global_rank(), nblocks);
  TORCH_CHECK(
      root < nvshmem_team_n_pes(teams[0]),
      "root must be smaller than group size");
  auto stream = at::cuda::getCurrentCUDAStream();

  // Prepare launch parameters
  auto shape = nvshmemx::make_shape(in_tile.sizes()[0], in_tile.sizes()[1]);
  auto stride = nvshmemx::make_stride(in_tile.strides()[0], in_tile.strides()[1]);
  auto src_ptr = in_tile.const_data_ptr();
  auto dst_ptr = out_tile.mutable_data_ptr();
  void* args[] = {
      &src_ptr,
      &dst_ptr,
      &shape,
      &stride,
      &root,
      &teams_dev};

  AT_DISPATCH_NVSHMEM_FLOATS(in_tile.scalar_type(), "tile_reduce", [&]() {
    nvshmemx_collective_launch(
        (const void*)tile_reduce_kernel<scalar_t>,
        dim3(nblocks),
        dim3(THREADS_PER_BLOCK),
        args,
        0,
        stream);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace c10d::nvshmem_extension
```
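A quick sketch of the launch-size heuristic above (roughly 16 bytes per thread, capped at 24 blocks, one NVSHMEM team per block). `THREADS_PER_BLOCK` is assumed to be 512 here for illustration; the actual constant is defined elsewhere in the extension and may differ:

```python
from math import ceil

THREADS_PER_BLOCK = 512  # assumed value, see note above

def nblocks_for_tile(rows: int, cols: int, elem_size: int) -> int:
    # ~16 bytes per thread, at most 24 blocks (one NVSHMEM team each)
    tile_bytes = rows * cols * elem_size
    return min(ceil(tile_bytes / (THREADS_PER_BLOCK * 16)), 24)

print(nblocks_for_tile(32, 32, 4))    # 4 KiB float tile -> 1 block
print(nblocks_for_tile(512, 512, 4))  # 1 MiB float tile -> 128 by the formula, capped at 24
```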
@ -876,4 +999,5 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
```cpp
  m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev);
  m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d);
  m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset);
  m.impl("tile_reduce", c10d::nvshmem_extension::tile_reduce);
}
```
@ -58,4 +58,11 @@ void all_to_all_vdev_2d_offset(
```cpp
    at::Tensor& out_splits_offsets,
    std::string group_name);

void tile_reduce(
    at::Tensor& in_tile,
    at::Tensor& out_tile,
    int64_t root,
    std::string group_name,
    std::string reduce_op = "sum");

} // namespace c10d::nvshmem_extension
```