pytorch/torch/csrc/distributed/c10d/Functional.cpp
Tristan Rice ddd0ed1b43 distributed: templated ring attention (#124215)
This adds a templated version of the ring attention forward function and tests it with memory-efficient attention. This doesn't add support for memory-efficient attention in DTensor; that will be added in a follow-up PR.

This templating also serves as a proof of concept for how to support other attention ops, such as jagged/nested tensors, as well as how to implement striped attention in a scalable way.

Misc changes:

* Fixes the all_to_all_single autograd implementation with CUDA and adds an NCCL test
* Adds compile support to the ring attention implementations (required some tweaks to process groups)

Test plan:

```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215
Approved by: https://github.com/wanchaol
2024-04-19 00:57:08 +00:00

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/DispatchKey.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <numeric>
#include <utility>

namespace {
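// WorkRegistry associates pending c10d::Work objects with the storage of the
// tensors they produce, so that a later wait_tensor() call can look up and
// wait on the correct work object before the tensor is used.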
class WorkRegistry {
public:
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
const auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto [it, inserted] = registry_.emplace(storage, work);
TORCH_CHECK(
inserted || it->second != work,
"The tensor storage is already associated with another work.");
}
c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
const auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto it = registry_.find(storage);
if (it == registry_.end()) {
return nullptr;
}
auto work = it->second;
registry_.erase(it);
return work;
}
~WorkRegistry() {
// If there are still unwaited work objects, their corresponding process
// groups should have already been destroyed at this stage. Any attempts to
// wait for these work objects or to destroy them will only result in
// confusing errors. Therefore, we simply issue a warning and intentionally
// allow the unwaited work objects to leak.
if (!registry_.empty()) {
TORCH_WARN(
"At the time of process termination, there are still ",
registry_.size(),
" unwaited c10d_functional collective calls. "
"Please review your program to ensure c10d_functional.wait_tensor() "
"is invoked on all tensors returned from c10d_functional collective "
"ops before they are used.");
}
for (auto& it : registry_) {
it.second.release();
}
}
private:
std::unordered_map<
c10::weak_intrusive_ptr<c10::StorageImpl>,
c10::intrusive_ptr<c10d::Work>>
registry_;
std::mutex lock_;
};
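// Process-wide registry used when thread isolation mode is disabled. When
// thread isolation is enabled (e.g. with threaded process groups), each rank
// gets its own RankLocal<WorkRegistry> instance instead; the register_work()
// and pop_work() helpers below pick the appropriate registry.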
static WorkRegistry process_registry;
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
if (c10d::get_thread_isolation_mode()) {
c10d::RankLocal<WorkRegistry>::get().register_work(tensor, work);
} else {
process_registry.register_work(tensor, work);
}
}
c10::intrusive_ptr<c10d::Work> pop_work(const at::Tensor& tensor) {
if (c10d::get_thread_isolation_mode()) {
return c10d::RankLocal<WorkRegistry>::get().pop_work(tensor);
} else {
return process_registry.pop_work(tensor);
}
}
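// Maps the string reduce_op accepted by the functional collective ops to the
// corresponding c10d::ReduceOp.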
const std::unordered_map<std::string, c10d::ReduceOp> str_to_reduce_op = {
{"sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::SUM)},
{"avg", c10d::ReduceOp(c10d::ReduceOp::RedOpType::AVG)},
{"product", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PRODUCT)},
{"min", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MIN)},
{"max", c10d::ReduceOp(c10d::ReduceOp::RedOpType::MAX)},
{"band", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BAND)},
{"bor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BOR)},
{"bxor", c10d::ReduceOp(c10d::ReduceOp::RedOpType::BXOR)},
// TODO: support premul_sum
// {"premul_sum", c10d::ReduceOp(c10d::ReduceOp::RedOpType::PREMUL_SUM)},
{"unused", c10d::ReduceOp(c10d::ReduceOp::RedOpType::UNUSED)}};
c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
auto it = str_to_reduce_op.find(reduce_op);
TORCH_CHECK(
it != str_to_reduce_op.end(), "Unrecognized reduce_op: ", reduce_op);
return it->second;
}
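// In-place all_reduce: reduces `input` across the group and registers the
// returned work with the input's storage so a later wait_tensor() call blocks
// until the collective completes.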
at::Tensor& all_reduce_(
at::Tensor& input,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string reduce_op,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name) {
c10d::AllreduceOptions opts;
opts.reduceOp = to_reduce_op(reduce_op);
std::vector<at::Tensor> inputs{input};
auto group = c10d::resolve_process_group(group_name);
auto work = group->allreduce(inputs, opts);
register_work(input, work);
return input;
}
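// Functional all_reduce: clones the input into a contiguous tensor and
// delegates to the in-place variant, leaving the original input untouched.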
at::Tensor all_reduce(
const at::Tensor& input,
std::string reduce_op,
std::string group_name) {
auto output = input.clone(at::MemoryFormat::Contiguous);
return all_reduce_(output, std::move(reduce_op), std::move(group_name));
}
std::vector<at::Tensor> all_reduce_coalesced_(
std::vector<at::Tensor> inputs,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string reduce_op,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name) {
c10d::AllreduceCoalescedOptions opts;
opts.reduceOp = to_reduce_op(reduce_op);
auto group = c10d::resolve_process_group(group_name);
auto work = group->allreduce_coalesced(inputs, opts);
for (const auto& tensor : inputs) {
register_work(tensor, work);
}
return inputs;
}
std::vector<at::Tensor> all_reduce_coalesced(
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<at::Tensor> inputs,
std::string reduce_op,
std::string group_name) {
std::vector<at::Tensor> outputs;
outputs.reserve(inputs.size());
for (const auto& tensor : inputs) {
outputs.push_back(tensor.clone(at::MemoryFormat::Contiguous));
}
return all_reduce_coalesced_(
outputs, std::move(reduce_op), std::move(group_name));
}
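// The all_gather output concatenates the per-rank shards along dim 0, so the
// output's first dimension is the input's first dimension times group_size.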
at::Tensor allocate_all_gather_output(
const at::Tensor& input,
int64_t group_size) {
auto output_size = input.sizes().vec();
output_size[0] *= group_size;
return at::empty(
output_size,
at::TensorOptions().dtype(input.dtype()).device(input.device()));
}
std::vector<at::Tensor> all_gather_into_tensor_coalesced(
std::vector<at::Tensor> inputs,
int64_t group_size,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name) {
std::vector<at::Tensor> outputs;
outputs.reserve(inputs.size());
for (const auto& tensor : inputs) {
outputs.push_back(allocate_all_gather_output(tensor, group_size));
}
auto group = c10d::resolve_process_group(group_name);
auto work = group->allgather_into_tensor_coalesced(outputs, inputs);
for (const auto& tensor : outputs) {
register_work(tensor, work);
}
return outputs;
}
at::Tensor all_gather_into_tensor(
const at::Tensor& input,
int64_t group_size,
std::string group_name) {
std::vector<at::Tensor> inputs{input};
return all_gather_into_tensor_coalesced(
inputs, group_size, std::move(group_name))[0];
}
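// The reduce_scatter output is this rank's shard along dim 0, so the output's
// first dimension is the input's first dimension divided by group_size.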
at::Tensor allocate_reduce_scatter_output(
const at::Tensor& input,
const int64_t group_size) {
auto output_size = input.sizes().vec();
if (output_size[0] % group_size != 0) {
LOG(WARNING) << "The first dimension of the reduce_scatter input ("
<< output_size[0] << ") is not divisible by the group size ("
<< group_size << ").";
}
output_size[0] /= group_size;
return at::empty(
output_size,
at::TensorOptions().dtype(input.dtype()).device(input.device()));
}
std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
std::vector<at::Tensor> inputs,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string reduce_op,
int64_t group_size,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name) {
c10d::ReduceScatterOptions opts;
opts.reduceOp = to_reduce_op(reduce_op);
std::vector<at::Tensor> outputs;
outputs.reserve(inputs.size());
for (const auto& tensor : inputs) {
outputs.push_back(allocate_reduce_scatter_output(tensor, group_size));
}
auto group = c10d::resolve_process_group(group_name);
auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
for (const auto& tensor : outputs) {
register_work(tensor, work);
}
return outputs;
}
at::Tensor reduce_scatter_tensor(
const at::Tensor& input,
std::string reduce_op,
int64_t group_size,
std::string group_name) {
std::vector<at::Tensor> inputs{input};
return reduce_scatter_tensor_coalesced(
inputs, std::move(reduce_op), group_size, std::move(group_name))[0];
}
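// all_to_all_single: each rank sends input_split_sizes[i] rows (along dim 0)
// to rank i and receives output_split_sizes[i] rows from rank i, so the
// output's first dimension is the sum of output_split_sizes.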
at::Tensor all_to_all_single(
const at::Tensor& input,
std::vector<int64_t> output_split_sizes,
std::vector<int64_t> input_split_sizes,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name) {
std::vector<int64_t> output_sizes = input.sizes().vec();
output_sizes[0] = std::accumulate(
output_split_sizes.begin(), output_split_sizes.end(), int64_t(0));
auto output = input.new_empty(output_sizes);
auto group = c10d::resolve_process_group(group_name);
auto work = group->alltoall_base(
output,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<at::Tensor&>(input),
output_split_sizes,
input_split_sizes);
register_work(output, work);
return output;
}
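// In-place broadcast from rank `src` to all ranks in the group.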
// NOLINTNEXTLINE(performance-unnecessary-value-param)
at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) {
c10d::BroadcastOptions opts;
opts.rootRank = src;
std::vector<at::Tensor> inputs{input};
auto group = c10d::resolve_process_group(group_name);
auto work = group->broadcast(inputs, opts);
register_work(input, work);
return input;
}
at::Tensor broadcast(
const at::Tensor& input,
int64_t src,
std::string group_name) {
auto output = input.clone(at::MemoryFormat::Contiguous);
return broadcast_(output, src, std::move(group_name));
}
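// Pops the work object registered for the tensor's storage (if any) and waits
// on it, then returns the tensor unchanged.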
at::Tensor wait_tensor(const at::Tensor& tensor) {
auto work = pop_work(tensor);
if (work != nullptr) {
work->wait();
}
return tensor;
}
} // namespace
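// Register the functional collective ops under the _c10d_functional namespace.
// All ops dispatch on CompositeExplicitAutograd and carry the pt2_compliant
// tag so they can be traced and compiled by PT2.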
TORCH_LIBRARY(_c10d_functional, m) {
m.def(
"all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce),
{at::Tag::pt2_compliant_tag});
m.def(
"all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_),
{at::Tag::pt2_compliant_tag});
m.def(
"all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced),
{at::Tag::pt2_compliant_tag});
m.def(
"all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_reduce_coalesced_),
{at::Tag::pt2_compliant_tag});
m.def(
"all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd,
::all_gather_into_tensor),
{at::Tag::pt2_compliant_tag});
m.def(
"all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd,
::all_gather_into_tensor_coalesced),
{at::Tag::pt2_compliant_tag});
m.def(
"reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::reduce_scatter_tensor),
{at::Tag::pt2_compliant_tag});
m.def(
"reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd,
::reduce_scatter_tensor_coalesced),
{at::Tag::pt2_compliant_tag});
m.def(
"all_to_all_single("
"Tensor input, "
"SymInt[] output_split_sizes, "
"SymInt[] input_split_sizes, "
"str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::all_to_all_single),
{at::Tag::pt2_compliant_tag});
m.def(
"broadcast(Tensor input, int src, str group_name) -> Tensor",
torch::dispatch(c10::DispatchKey::CompositeExplicitAutograd, ::broadcast),
{at::Tag::pt2_compliant_tag});
m.def(
"broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::broadcast_),
{at::Tag::pt2_compliant_tag});
m.def(
"wait_tensor(Tensor tensor) -> Tensor",
torch::dispatch(
c10::DispatchKey::CompositeExplicitAutograd, ::wait_tensor),
{at::Tag::pt2_compliant_tag});
}
namespace {
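// Autograd wrapper for all_to_all_single. The backward of an all-to-all is
// another all-to-all with the input and output split sizes swapped, applied
// to the incoming gradient.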
class AllToAllSingle : public torch::autograd::Function<AllToAllSingle> {
public:
static torch::autograd::Variable forward(
torch::autograd::AutogradContext* ctx,
const at::Tensor& input,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<int64_t> output_split_sizes,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::vector<int64_t> input_split_sizes,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string group_name) {
// swap sizes for backwards pass
ctx->saved_data["output_split_sizes"] = input_split_sizes;
ctx->saved_data["input_split_sizes"] = output_split_sizes;
ctx->saved_data["group_name"] = group_name;
return c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
.typed<decltype(all_to_all_single)>()
.call(input, output_split_sizes, input_split_sizes, group_name);
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_out_list) {
const std::vector<int64_t>& output_split_sizes =
ctx->saved_data["output_split_sizes"].toIntVector();
const std::vector<int64_t>& input_split_sizes =
ctx->saved_data["input_split_sizes"].toIntVector();
const std::string& group_name = ctx->saved_data["group_name"].toStringRef();
DCHECK(grad_out_list.size() == 1);
auto grad_out = grad_out_list[0].contiguous();
auto out =
c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::all_to_all_single", "")
.typed<decltype(all_to_all_single)>()
.call(grad_out, output_split_sizes, input_split_sizes, group_name);
// do an explicit wait to avoid cuda stream issues
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(wait_tensor)>()
.call(out);
return {out, at::Tensor(), at::Tensor(), at::Tensor()};
}
};
at::Tensor all_to_all_single_autograd(
const at::Tensor& input,
const std::vector<int64_t>& output_split_sizes,
const std::vector<int64_t>& input_split_sizes,
const std::string& group_name) {
return AllToAllSingle::apply(
input, output_split_sizes, input_split_sizes, group_name);
}
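// Autograd wrapper for reduce_scatter_tensor (sum only). The backward of a
// sum reduce-scatter is an all-gather of the incoming gradient.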
class ReduceScatterTensor
: public torch::autograd::Function<ReduceScatterTensor> {
public:
static torch::autograd::Variable forward(
torch::autograd::AutogradContext* ctx,
const at::Tensor& input,
std::string reduce_op,
int64_t group_size,
std::string group_name) {
TORCH_CHECK(reduce_op == "sum", "Only sum reduce op is supported");
ctx->saved_data["group_size"] = group_size;
ctx->saved_data["group_name"] = group_name;
return c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
.typed<decltype(reduce_scatter_tensor)>()
.call(input, reduce_op, group_size, group_name);
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_out_list) {
const int64_t group_size = ctx->saved_data["group_size"].toInt();
const std::string& group_name = ctx->saved_data["group_name"].toStringRef();
DCHECK(grad_out_list.size() == 1);
auto grad_out = grad_out_list[0];
auto out =
c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
.typed<decltype(all_gather_into_tensor)>()
.call(grad_out, group_size, group_name);
// do an explicit wait to avoid cuda stream issues
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(wait_tensor)>()
.call(out);
return {
out,
at::Tensor(),
at::Tensor(),
at::Tensor(),
};
}
};
at::Tensor reduce_scatter_tensor_autograd(
const at::Tensor& input,
std::string reduce_op,
int64_t group_size,
std::string group_name) {
return ReduceScatterTensor::apply(input, reduce_op, group_size, group_name);
}
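// Autograd wrapper for all_gather_into_tensor. The backward of an all-gather
// is a sum reduce-scatter of the incoming gradient.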
class AllGatherIntoTensor
: public torch::autograd::Function<AllGatherIntoTensor> {
public:
static torch::autograd::Variable forward(
torch::autograd::AutogradContext* ctx,
const at::Tensor& input,
int64_t group_size,
std::string group_name) {
ctx->saved_data["group_size"] = group_size;
ctx->saved_data["group_name"] = group_name;
return c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "")
.typed<decltype(all_gather_into_tensor)>()
.call(input, group_size, group_name);
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_out_list) {
const int64_t group_size = ctx->saved_data["group_size"].toInt();
const std::string& group_name = ctx->saved_data["group_name"].toStringRef();
DCHECK(grad_out_list.size() == 1);
auto grad_out = grad_out_list[0];
auto out =
c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "")
.typed<decltype(reduce_scatter_tensor)>()
.call(grad_out, "sum", group_size, group_name);
// do an explicit wait to avoid cuda stream issues
// TODO: track active cuda stream in wait
out = c10::Dispatcher::singleton()
.findSchemaOrThrow("_c10d_functional::wait_tensor", "")
.typed<decltype(wait_tensor)>()
.call(out);
return {
out,
at::Tensor(),
at::Tensor(),
};
}
};
at::Tensor all_gather_into_tensor_autograd(
const at::Tensor& input,
int64_t group_size,
std::string group_name) {
return AllGatherIntoTensor::apply(input, group_size, group_name);
}
} // namespace
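// Register differentiable variants under _c10d_functional_autograd. These
// dispatch on the Autograd key and route through the autograd::Function
// wrappers above, which themselves call into the _c10d_functional ops.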
TORCH_LIBRARY(_c10d_functional_autograd, m) {
m.def(
"all_to_all_single("
"Tensor input, "
"SymInt[] output_split_sizes, "
"SymInt[] input_split_sizes, "
"str group_name) -> Tensor",
torch::dispatch(c10::DispatchKey::Autograd, ::all_to_all_single_autograd),
{at::Tag::pt2_compliant_tag});
m.def(
"reduce_scatter_tensor("
"Tensor input, "
"str reduce_op, "
"int group_size, "
"str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::Autograd, ::reduce_scatter_tensor_autograd),
{at::Tag::pt2_compliant_tag});
m.def(
"all_gather_into_tensor("
"Tensor input, "
"int group_size, "
"str group_name) -> Tensor",
torch::dispatch(
c10::DispatchKey::Autograd, ::all_gather_into_tensor_autograd),
{at::Tag::pt2_compliant_tag});
}