pytorch/torch/csrc/distributed/c10d/Functional.cpp
Will Feng 4ee514144b [c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager mode (with `async_op=True`) but waited on inside the compiled region.

This is important for internal use cases such as TorchRec, where we issue the SparseArch all_to_all collectives in eager mode but want to wait on them in the compiled region at the beginning of the OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.

----

**Update**: Two changes were made to prevent regressions in existing use cases:

1. Added a memory-stressed test case to `test_unwaited` in test_c10d_nccl.py to cover the existing usage pattern of not calling `work.wait()` on a non-functional collective.
2. Gated all new `register_work()` / `unregister_work()` calls behind a `c10d::allow_inflight_collective_as_graph_input()` check. This check is only true inside the new `allow_inflight_collective_as_graph_input_ctx()` context manager, which requires explicit user enablement (i.e. it is off by default), so existing users should not be affected. A sketch of the gating pattern is shown after this list.
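
To make item 2 concrete, here is a minimal, self-contained sketch of the mechanism being gated: an opt-in flag controls whether the Work handle of an eager collective is recorded, so a later `wait_tensor()` call can find and complete it. This is not the c10d implementation: `FakeWork`, `allow_inflight_as_graph_input`, `issue_eager_collective`, and the plain `std::map` registry are invented for illustration; only the `register_work()` / `wait_tensor()` names mirror the real API.

```cpp
// Illustrative sketch only -- not c10d's implementation.
#include <map>
#include <memory>
#include <mutex>
#include <utility>

// Stand-in for c10d::Work: a handle to one in-flight collective.
struct FakeWork {
  void wait() { /* block until the collective has completed */ }
};

// Process-wide opt-in flag, analogous to what the
// allow_inflight_collective_as_graph_input_ctx() context manager toggles.
bool allow_inflight_as_graph_input = false;

// Tensor -> Work registry, keyed here by an opaque pointer for simplicity.
std::mutex registry_mutex;
std::map<const void*, std::shared_ptr<FakeWork>> work_registry;

// register_work(): remember which Work produced a given output tensor.
void register_work(const void* tensor_key, std::shared_ptr<FakeWork> work) {
  std::lock_guard<std::mutex> lock(registry_mutex);
  work_registry[tensor_key] = std::move(work);
}

// wait_tensor(): find the Work that produced the tensor, wait on it, and
// drop the entry so the Work can be released (the unregister step).
void wait_tensor(const void* tensor_key) {
  std::shared_ptr<FakeWork> work;
  {
    std::lock_guard<std::mutex> lock(registry_mutex);
    auto it = work_registry.find(tensor_key);
    if (it != work_registry.end()) {
      work = std::move(it->second);
      work_registry.erase(it);
    }
  }
  if (work) {
    work->wait();
  }
}

// The "new" call sites added for eager async_op=True collectives gate the
// registration on the user opt-in, so default behavior is unchanged.
void issue_eager_collective(const void* output_tensor_key) {
  auto work = std::make_shared<FakeWork>(); // pretend a collective launched
  if (allow_inflight_as_graph_input) {
    register_work(output_tensor_key, work);
  }
}

int main() {
  int dummy_tensor = 0; // stands in for the collective's output tensor
  allow_inflight_as_graph_input = true; // what the context manager enables
  issue_eager_collective(&dummy_tensor);
  wait_tensor(&dummy_tensor); // e.g. called from the compiled region
  return 0;
}
```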

The risk of this version of the PR causing regressions should therefore be very low.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-29 03:31:19 +00:00

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/distributed/c10d/Functional.hpp>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

#include <numeric> // std::accumulate (used in all_to_all_single)
#include <utility>
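
// The ops below implement the functional collectives exposed as
// torch.ops._c10d_functional.*. Each op issues the corresponding
// ProcessGroup collective and registers the returned Work handle against
// its output tensor(s) via c10d::register_work(), so that a later
// _c10d_functional::wait_tensor(tensor) call can find the in-flight
// collective and wait on it.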

namespace {

const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
    {"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
    {"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
    {"product", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PRODUCT)},
    {"min", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MIN)},
    {"max", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MAX)},
    {"band", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BAND)},
    {"bor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BOR)},
    {"bxor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BXOR)},
    // TODO: support premul_sum
    // {"premul_sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PREMUL_SUM)},
    {"unused", c10d::ReduceOp(c10d::ReduceOp::RedOpType::UNUSED)}};

c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
  auto it = str_to_reduce_op.find(reduce_op);
  TORCH_CHECK(
      it != str_to_reduce_op.end(), "Unrecognized reduce_op: ", reduce_op);
  return it->second;
}

at::Tensor& all_reduce_(
    at::Tensor& input,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::AllreduceOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  std::vector<at::Tensor> inputs{input};
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce(inputs, opts);
  c10d::register_work(input, work);
  return input;
}

at::Tensor all_reduce(
    const at::Tensor& input,
    std::string reduce_op,
    std::string group_name) {
  auto output = input.clone(at::MemoryFormat::Contiguous);
  return all_reduce_(output, std::move(reduce_op), std::move(group_name));
}

std::vector<at::Tensor> all_reduce_coalesced_(
    std::vector<at::Tensor> inputs,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::AllreduceCoalescedOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allreduce_coalesced(inputs, opts);
  for (const auto& tensor : inputs) {
    c10d::register_work(tensor, work);
  }
  return inputs;
}

std::vector<at::Tensor> all_reduce_coalesced(
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::vector<at::Tensor> inputs,
    std::string reduce_op,
    std::string group_name) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(tensor.clone(at::MemoryFormat::Contiguous));
  }
  return all_reduce_coalesced_(
      outputs, std::move(reduce_op), std::move(group_name));
}

at::Tensor allocate_all_gather_output(
    const at::Tensor& input,
    int64_t group_size) {
  auto output_size = input.sizes().vec();
  output_size[0] *= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> all_gather_into_tensor_coalesced(
    std::vector<at::Tensor> inputs,
    int64_t group_size,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_all_gather_output(tensor, group_size));
  }
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->allgather_into_tensor_coalesced(outputs, inputs);
  for (const auto& tensor : outputs) {
    c10d::register_work(tensor, work);
  }
  return outputs;
}

at::Tensor all_gather_into_tensor(
    const at::Tensor& input,
    int64_t group_size,
    std::string group_name) {
  std::vector<at::Tensor> inputs{input};
  return all_gather_into_tensor_coalesced(
      inputs, group_size, std::move(group_name))[0];
}

at::Tensor& all_gather_into_tensor_out(
    at::Tensor& input,
    int64_t group_size,
    const std::string& group_name,
    at::Tensor& output) {
  c10d::AllgatherOptions opts;
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->_allgather_base(output, input, opts);
  c10d::register_work(output, work);
  return output;
}

at::Tensor allocate_reduce_scatter_output(
    const at::Tensor& input,
    const int64_t group_size) {
  auto output_size = input.sizes().vec();
  if (output_size[0] % group_size != 0) {
    LOG(WARNING) << "The first dimension of the reduce_scatter input ("
                 << output_size[0] << ") is not divisible by the group size ("
                 << group_size << ").";
  }
  output_size[0] /= group_size;
  return at::empty(
      output_size,
      at::TensorOptions().dtype(input.dtype()).device(input.device()));
}

std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
    std::vector<at::Tensor> inputs,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string reduce_op,
    int64_t group_size,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  c10d::ReduceScatterOptions opts;
  opts.reduceOp = to_reduce_op(reduce_op);
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    outputs.push_back(allocate_reduce_scatter_output(tensor, group_size));
  }
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
  for (const auto& tensor : outputs) {
    c10d::register_work(tensor, work);
  }
  return outputs;
}

at::Tensor reduce_scatter_tensor(
    const at::Tensor& input,
    std::string reduce_op,
    int64_t group_size,
    std::string group_name) {
  std::vector<at::Tensor> inputs{input};
  return reduce_scatter_tensor_coalesced(
      inputs, std::move(reduce_op), group_size, std::move(group_name))[0];
}

at::Tensor all_to_all_single(
    const at::Tensor& input,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    // NOLINTNEXTLINE(performance-unnecessary-value-param)
    std::string group_name) {
  std::vector<int64_t> output_sizes = input.sizes().vec();
  output_sizes[0] = std::accumulate(
      output_split_sizes.begin(), output_split_sizes.end(), int64_t(0));
  auto output = input.new_empty(output_sizes);
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->alltoall_base(
      output,
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<at::Tensor&>(input),
      output_split_sizes,
      input_split_sizes);
  c10d::register_work(output, work);
  return output;
}

// NOLINTNEXTLINE(performance-unnecessary-value-param)
at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) {
  c10d::BroadcastOptions opts;
  opts.rootRank = src;
  std::vector<at::Tensor> inputs{input};
  auto group = c10d::resolve_process_group(group_name);
  auto work = group->broadcast(inputs, opts);
  c10d::register_work(input, work);
  return input;
}

at::Tensor broadcast(
    const at::Tensor& input,
    int64_t src,
    std::string group_name) {
  auto output = input.clone(at::MemoryFormat::Contiguous);
  return broadcast_(output, src, std::move(group_name));
}

} // namespace
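
// Register the functional collective ops under the _c10d_functional library.
// All kernels are bound to CompositeExplicitAutograd, i.e. they carry no
// autograd formulas of their own; the differentiable variants are registered
// under _c10d_functional_autograd further below.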
TORCH_LIBRARY(_c10d_functional, m) {
  m.def(
      "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor_out),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::all_gather_into_tensor_coalesced),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::reduce_scatter_tensor),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd,
          ::reduce_scatter_tensor_coalesced),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_to_all_single("
      "Tensor input, "
      "SymInt[] output_split_sizes, "
      "SymInt[] input_split_sizes, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::all_to_all_single),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "broadcast(Tensor input, int src, str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, ::broadcast),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::broadcast_),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "wait_tensor(Tensor tensor) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, c10d::wait_tensor),
      {at::Tag::pt2_compliant_tag});
}
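
// Autograd support: each Function below dispatches to the corresponding
// _c10d_functional op in forward and issues the "mirror" collective in
// backward (all_to_all_single with swapped split sizes, reduce_scatter <->
// all_gather), followed by an explicit wait on the result.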
namespace {

class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::vector<int64_t> output_split_sizes,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::vector<int64_t> input_split_sizes,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::string group_name) {
    // swap sizes for backwards pass
    ctx->saved_data["output_split_sizes"] = input_split_sizes;
    ctx->saved_data["input_split_sizes"] = output_split_sizes;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
        .typed<decltype(all_to_all_single)>()
        .call(input, output_split_sizes, input_split_sizes, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_out_list) {
    const std::vector<int64_t>& output_split_sizes =
        ctx->saved_data["output_split_sizes"].toIntVector();
    const std::vector<int64_t>& input_split_sizes =
        ctx->saved_data["input_split_sizes"].toIntVector();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    auto grad_out = grad_out_list[0].contiguous();

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
            .typed<decltype(all_to_all_single)>()
            .call(grad_out, output_split_sizes, input_split_sizes, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(c10d::wait_tensor)>()
              .call(out);

    return {out, at::Tensor(), at::Tensor(), at::Tensor()};
  }
};

at::Tensor all_to_all_single_autograd(
    const at::Tensor& input,
    const std::vector<int64_t>& output_split_sizes,
    const std::vector<int64_t>& input_split_sizes,
    const std::string& group_name) {
  return AllToAllSingle::apply(
      input, output_split_sizes, input_split_sizes, group_name);
}

class ReduceScatterTensor
    : public torch::autograd::Function<ReduceScatterTensor> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      const std::string& reduce_op,
      int64_t group_size,
      const std::string& group_name) {
    TORCH_CHECK(reduce_op == "sum", "Only sum reduce op is supported");

    ctx->saved_data["group_size"] = group_size;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
        .typed<decltype(reduce_scatter_tensor)>()
        .call(input, reduce_op, group_size, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_out_list) {
    const int64_t group_size = ctx->saved_data["group_size"].toInt();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    const auto& grad_out = grad_out_list[0];

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
            .typed<decltype(all_gather_into_tensor)>()
            .call(grad_out, group_size, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(c10d::wait_tensor)>()
              .call(out);

    return {
        out,
        at::Tensor(),
        at::Tensor(),
        at::Tensor(),
    };
  }
};

at::Tensor reduce_scatter_tensor_autograd(
    const at::Tensor& input,
    const std::string& reduce_op,
    int64_t group_size,
    const std::string& group_name) {
  return ReduceScatterTensor::apply(input, reduce_op, group_size, group_name);
}

class AllGatherIntoTensor
    : public torch::autograd::Function<AllGatherIntoTensor> {
 public:
  static torch::autograd::Variable forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& input,
      int64_t group_size,
      const std::string& group_name) {
    ctx->saved_data["group_size"] = group_size;
    ctx->saved_data["group_name"] = group_name;

    return c10::Dispatcher::singleton()
        .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
        .typed<decltype(all_gather_into_tensor)>()
        .call(input, group_size, group_name);
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_out_list) {
    const int64_t group_size = ctx->saved_data["group_size"].toInt();
    const std::string& group_name = ctx->saved_data["group_name"].toStringRef();

    DCHECK(grad_out_list.size() == 1);
    const auto& grad_out = grad_out_list[0];

    auto out =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
            .typed<decltype(reduce_scatter_tensor)>()
            .call(grad_out, "sum", group_size, group_name);

    // do an explicit wait to avoid cuda stream issues
    // TODO: track active cuda stream in wait
    out = c10::Dispatcher::singleton()
              .findSchemaOrThrow("_c10d_functional::wait_tensor", "")
              .typed<decltype(c10d::wait_tensor)>()
              .call(out);

    return {
        out,
        at::Tensor(),
        at::Tensor(),
    };
  }
};

at::Tensor all_gather_into_tensor_autograd(
    const at::Tensor& input,
    int64_t group_size,
    const std::string& group_name) {
  return AllGatherIntoTensor::apply(input, group_size, group_name);
}

} // namespace
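
// Register the autograd-enabled variants under _c10d_functional_autograd.
// These are bound to the Autograd dispatch key so gradients flow through
// the collectives defined above.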
TORCH_LIBRARY(_c10d_functional_autograd, m) {
  m.def(
      "all_to_all_single("
      "Tensor input, "
      "SymInt[] output_split_sizes, "
      "SymInt[] input_split_sizes, "
      "str group_name) -> Tensor",
      torch::dispatch(c10::DispatchKey::Autograd, ::all_to_all_single_autograd),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "reduce_scatter_tensor("
      "Tensor input, "
      "str reduce_op, "
      "int group_size, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::Autograd, ::reduce_scatter_tensor_autograd),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "all_gather_into_tensor("
      "Tensor input, "
      "int group_size, "
      "str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::Autograd, ::all_gather_into_tensor_autograd),
      {at::Tag::pt2_compliant_tag});
}

namespace {

// DTensor related comm operations, sharing code with functional collectives
// for now
at::Tensor shard_dim_alltoall(
    const at::Tensor& input,
    int64_t gather_dim,
    int64_t shard_dim,
    const std::string& group_name) {
  auto group = c10d::resolve_process_group(group_name);
  auto group_size = group->getSize();
  std::vector<int64_t> output_sizes = input.sizes().vec();
  if (output_sizes[shard_dim] % group_size != 0) {
    LOG(WARNING) << "The shard dimension of the shard_dim_alltoall input ("
                 << output_sizes[shard_dim]
                 << ") is not divisible by the group size (" << group_size
                 << ").";
  }
  output_sizes[shard_dim] = output_sizes[shard_dim] / group_size;

  // split the input along shard_dim into one contiguous chunk per rank
  std::vector<at::Tensor> inputs;
  inputs.reserve(group_size);
  auto length = output_sizes[shard_dim];
  for (int i = 0; i < group_size; i++) {
    inputs.push_back(input.narrow(shard_dim, i * length, length).contiguous());
  }
  // allocate outputs
  std::vector<at::Tensor> outputs;
  outputs.reserve(group_size);
  for (int i = 0; i < group_size; i++) {
    outputs.push_back(input.new_empty(output_sizes).contiguous());
  }
  auto work = group->alltoall(outputs, inputs);
  // TODO: it's very tricky to get the current async behavior to work for
  // shard dim alltoall, so for now we keep this comm op synchronous. We can
  // revisit later how to support the async case with the Work registry.
  work->wait();
  return at::cat(outputs, gather_dim);
}

} // namespace

// DTensor comm op registry
TORCH_LIBRARY(_dtensor, m) {
  m.def(
      "shard_dim_alltoall(Tensor input, int gather_dim, int shard_dim, str group_name) -> Tensor",
      torch::dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, ::shard_dim_alltoall),
      {at::Tag::pt2_compliant_tag});
}