Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Back out "Make grad point to bucket buffer in DDP to save memory usage" (#43557)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43557

Back out the diff that caused some errors in pytext distributed training.

Test Plan: Tested by rayhou, who verified that reverting the diff works.

Differential Revision: D23320238

fbshipit-source-id: caa0fe74404059e336cd95fdb41373f58ecf486e
This commit is contained in:
parent 58666982fb
commit f35e069622
@@ -2119,7 +2119,7 @@ class _DistTestBase(object):
    # Clear gradients manually
    grad = net.module.weight.grad
    if grad is not None:
        grad.requires_grad_(False)
        grad.detach_()
        grad.zero_()
    # Forward + BW
    batch = torch.tensor([rank]).float().cuda(rank)
@@ -269,11 +269,7 @@ Tensor & detach_(Tensor & self) {
      "of detach_(). Alternatively, create this view with an "
      "`unsafe_` version of the function that produced it.");
  } else {
    AT_ERROR("If you are using DistributedDataParallel (DDP) for training, "
        "gradients are views of DDP buckets, and hence detach_() cannot "
        "be called on these gradients. To fix this error, please refer "
        "to the Optimizer.zero_grad() function "
        "in torch/optim/optimizer.py as the solution.");
    AT_ERROR("Can't detach views in-place. Use detach() instead");
  }
}
// I think the choice here is conservative. In principle, doing
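Aside: the two error messages above differ over how gradients are expected to be cleared. The restored message assumes the classic manual pattern from Optimizer.zero_grad(), which detaches each .grad in place before zeroing it; the removed message existed because, with grads aliased to bucket views, that detach_() call becomes illegal. A minimal, hedged sketch of the pattern (plain PyTorch, not part of this diff; the model is made up):

    import torch

    # A throwaway model, just to get parameters with populated .grad.
    model = torch.nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()

    # The manual clearing pattern that zero_grad() historically used:
    # detach the grad from any graph it is part of, then zero it in place.
    # If .grad were a view into a shared bucket buffer (the behavior being
    # backed out here), this detach_() is what the removed message warned about.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()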
@@ -161,11 +161,6 @@ struct TORCH_API AccumulateGrad : public Node {
      // valid operation which adds `new_grad` to `variable_grad` in
      // place. `variable_grad` is thus still referring to the same tensor
      // after the operation.
      // Also DistributedDataParallel(DDP) package relies on grad being
      // mutated in place for saving peak memory usage. DDP will still
      // work correctly if it is mutated out of place here, but DDP will
      // maintain one extra copy of grad tensors in buffer and thus
      // increase peak memory usage.
      variable_grad += new_grad;
      CHECK_RESULT(variable_grad, variable);
      // ^ We could enforce the contract more aggressively here by writing:
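The comment being removed above is about why `variable_grad += new_grad` matters: in-place accumulation reuses the same tensor and storage, so anything aliasing the grad keeps seeing updates. A small illustrative sketch in plain PyTorch (not the AccumulateGrad code itself):

    import torch

    variable_grad = torch.zeros(3)
    new_grad = torch.ones(3)
    alias = variable_grad            # e.g. a bucket slice that grad points into

    # In-place: the storage is reused, so the alias observes the update.
    variable_grad += new_grad
    print(alias)                                          # tensor([1., 1., 1.])
    print(alias.data_ptr() == variable_grad.data_ptr())   # True

    # Out-of-place: a fresh tensor is created and the alias is left behind,
    # which is why DDP would then have to keep an extra copy of the grad.
    variable_grad = variable_grad + new_grad
    print(alias.data_ptr() == variable_grad.data_ptr())   # False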
@@ -168,8 +168,8 @@ PyObject* c10d_init(PyObject* _unused) {
          py::arg("find_unused_parameters") = false,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_forward",
          &::c10d::Reducer::prepare_forward,
          "initialize_buckets",
          &::c10d::Reducer::initialize_buckets,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "prepare_for_backward",
@@ -64,9 +64,7 @@ Reducer::Reducer(

  // Initialize variable bucketing.
  // This can be reinitialized later after capturing runtime information.
  std::unique_lock<std::mutex> lock(this->mutex_);
  initialize_buckets(std::move(bucket_indices));
  lock.unlock();

  // All variables are expected to have their `grad_fn` set to the gradient
  // accumulation function (since they are leafs in the autograd graph).
@@ -317,66 +315,56 @@ void Reducer::mark_variable_ready_dense(VariableIndex index) {
  const auto length = replica.lengths[bucket_index.intra_bucket_index];
  auto& bucket_view = replica.bucket_views[bucket_index.intra_bucket_index];

  // Copy contents of gradient tensor to bucket tensor.
  // If the gradient is not set, we assume it wasn't computed
  // as part of the current backwards pass, and zero the part
  // of the bucket it would otherwise hold.
  runGradCallbackForVariable(variable, [&](auto& grad) {
    if (grad.defined()) {
      // Copy grad to bucket view buffer if grad and bucket_view are pointing
      // to different storages, and then let grad point to bucket_view
      // for saving memory and avoiding copies in subsequent iterations.
      // In most cases, the copy is needed only at first
      // iteration, there will be no copies in subsequent iterations.
      // In rare cases, if users explicitly set grad to be None after every
      // iteration, then it needs to copy grad to bucket_view in every
      // iteration.
      if (!grad.is_alias_of(bucket_view)) {
        // Ensure that the gradient type matches the bucket type.
        TORCH_CHECK(
            grad.options().type_equal(bucket_view.options()),
            "Expected ",
            bucket_view.toString(),
            ", got ",
            grad.toString());
        TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
        TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
        // AccumulateGrad doesn't HAVE to obey the grad layout contract.
        // The penalty for disobedience is reduced performance, not numerical
        // death. Warnings here help diagnose poor DDP performance.
        if (grad.strides() != bucket_view.strides()) {
          TORCH_WARN_ONCE(
              "Grad strides do not match bucket view strides. "
              "This may indicate grad was not created according to the "
              "gradient layout contract, or that the param's strides "
              "changed since DDP was constructed. This is not an error, "
              "but may impair performance.\n"
              "grad.sizes() = ",
              grad.sizes(),
              ", strides() = ",
              grad.strides(),
              "\n",
              "bucket_view.sizes() = ",
              bucket_view.sizes(),
              ", strides() = ",
              bucket_view.strides());
        }
        // See Note [DDP Communication Hook]
        if (comm_hook_ == nullptr) {
          // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp
          auto wrapped =
              c10::scalar_to_tensor(double(1.) / process_group_->getSize());
          wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
          // Divides while copying into the bucket view.
          at::native::mul_out(bucket_view, grad, wrapped);
        } else {
          bucket_view.copy_(grad);
        }
        // Let grad point to bucket_view buffer.
        grad = bucket_view;
        // The grad is modified and needs to be written back.
        return true;
      // Ensure that the gradient type matches the bucket type.
      TORCH_CHECK(
          grad.options().type_equal(bucket_view.options()),
          "Expected ",
          bucket_view.toString(),
          ", got ",
          grad.toString());
      // Assert that the grad tensor and the bucket don't share storage.
      // If they did, we could avoid the copy altogether.
      // The reason for not doing this is that existing code calls
      // `detach_` from `zero_grad`, which is incompatible with views.
      TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view));
      TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
      TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
      // AccumulateGrad doesn't HAVE to obey the grad layout contract.
      // The penalty for disobedience is reduced performance, not numerical
      // death. Warnings here help diagnose poor DDP performance.
      if (grad.strides() != bucket_view.strides()) {
        TORCH_WARN_ONCE(
            "Grad strides do not match bucket view strides. "
            "This may indicate grad was not created according to the "
            "gradient layout contract, or that the param's strides "
            "changed since DDP was constructed. This is not an error, "
            "but may impair performance.\n"
            "grad.sizes() = ",
            grad.sizes(),
            ", strides() = ",
            grad.strides(),
            "\n",
            "bucket_view.sizes() = ",
            bucket_view.sizes(),
            ", strides() = ",
            bucket_view.strides());
      }
      // See Note [DDP Communication Hook]
      if (comm_hook_ == nullptr) {
        // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp
        auto wrapped =
            c10::scalar_to_tensor(double(1.) / process_group_->getSize());
        wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
        // Divides while copying into the bucket view.
        at::native::mul_out(bucket_view, grad, wrapped);
      } else {
        // If grad and bucket view point to the same storage, no need to copy
        if (comm_hook_ == nullptr) {
          bucket_view.div_(process_group_->getSize());
        }
        bucket_view.copy_(grad);
      }
    } else {
      bucket_view.zero_();
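The hunk above copies each ready gradient into its slot of the flat bucket and, when no communication hook is registered, divides by the world size while copying. A rough standalone sketch of that arithmetic (hypothetical names; world_size stands in for process_group_->getSize()):

    import torch

    world_size = 4
    grad = torch.randn(2, 3)
    bucket_contents = torch.empty(6)
    bucket_view = bucket_contents.narrow(0, 0, grad.numel()).view_as(grad)

    # Divide while copying into the bucket view (a single fused write),
    # roughly what mul_out(bucket_view, grad, 1/world_size) does above.
    torch.mul(grad, 1.0 / world_size, out=bucket_view)

    # Equivalent two-step form: copy first, then divide in place.
    bucket_view.copy_(grad)
    bucket_view.div_(world_size)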
@@ -564,10 +552,20 @@ void Reducer::mark_variable_ready(VariableIndex index) {
    const c10::Stream currentStream =
        guard.getStream(replica.contents.device());
    torch::autograd::Engine::get_default_engine().queue_callback([=] {
      std::lock_guard<std::mutex> lock(this->mutex_);
      std::unique_lock<std::mutex> lock(this->mutex_);
      // Run callback with the current stream
      c10::OptionalStreamGuard currentStreamGuard{currentStream};
      this->finalize_backward();
      // Rebuild bucket if this is the first time to rebuild
      if (!rebuilt_params_.empty()) {
        auto rebuilt_bucket_indices = rebuildBuckets();
        // Unlock before initialize_buckets() as initialize_buckets() requires a
        // lock, it could result in self deadlock without unlocking here.
        lock.unlock();
        initialize_buckets(std::move(rebuilt_bucket_indices));
      } else {
        lock.unlock();
      }
    });
  }
}
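The callback above switches to a std::unique_lock so it can release the mutex before calling initialize_buckets(), which re-acquires the same non-reentrant lock; holding it across the call would self-deadlock. The same idea in a short Python threading sketch (illustrative only, not the C++ Reducer):

    import threading

    mutex = threading.Lock()          # non-reentrant, like std::mutex

    def initialize_buckets():
        with mutex:                   # re-acquires the lock internally
            pass                      # (re)build bucket state here

    def finalize_and_maybe_rebuild():
        mutex.acquire()
        try:
            pass                      # finalize_backward-style work under the lock
        finally:
            mutex.release()           # unlock first ...
        initialize_buckets()          # ... so this nested acquire cannot deadlock

    finalize_and_maybe_rebuild()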
@@ -615,16 +613,7 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {

void Reducer::initialize_buckets(
    std::vector<std::vector<size_t>> bucket_indices) {
  // If initialize_buckets is called inside DDP constructor, then
  // it does not matter rpc context ptr is nullptr or not, as grad
  // will not be mutated.
  // If initialize_buckets is called during training loop, e.g, inside
  // rebuild_buckets(), since grad could be mutated and be pointed to
  // bucket_view, then it needs to check rpc context ptr is nullptr or not,
  // If rpc context ptr is nullptr, mutate variable.grad(); otherwise,
  // mutate grad in rpc context.
  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
  std::lock_guard<std::mutex> lock(mutex_);

  // This shouldn't be called if we're expecting autograd hooks to fire.
  TORCH_CHECK(
@@ -708,6 +697,7 @@ void Reducer::initialize_buckets(

    // Allocate bucket contents tensor.
    replica.contents = at::empty({static_cast<long>(offset)}, options);

    // Note: "Gradient Layout Contract"
    //
    // Here, create views into the contents tensor for each variable's grad.

@@ -745,7 +735,7 @@ void Reducer::initialize_buckets(
    // metadata. Checking just once won't catch if someone messes with
    // param layouts over time, but not messing with params after DDP
    // construction is already a documented constraint.
    initialize_bucket_views(replica, replica.contents, true);
    initialize_bucketviews(replica, replica.contents);
  }

  // Add bucket replica to enclosing bucket.

@@ -771,61 +761,29 @@ void Reducer::initialize_buckets(
}

// (see Note: "Gradient Layout Contract" in initialize_buckets).
void Reducer::initialize_bucket_views(
void Reducer::initialize_bucketviews(
    Reducer::BucketReplica& replica,
    at::Tensor& contents,
    bool copy_to_bucket_view) {
    at::Tensor& contents) {
  for (size_t i = 0; i < replica.variables.size(); i++) {
    auto& v = replica.variables[i];
    const auto& v = replica.variables[i];
    const auto offset = replica.offsets[i];
    const auto length = replica.lengths[i];
    at::Tensor bucket_view;
    if (v.is_non_overlapping_and_dense()) {
      // If the param's memory is dense, match its layout, anticipating
      // the autograd engine (AccumulateGrad) will also create gradients
      // matching its layout.
      bucket_view = contents.as_strided(v.sizes(), v.strides(), offset);
      replica.bucket_views.push_back(
          contents.as_strided(v.sizes(), v.strides(), offset));
    } else {
      // Fall back to a C-style contiguous view, again anticipating
      // AccumulateGrad will do the same when stashing grads for non-dense
      // params.
      bucket_view = contents.narrow(0, offset, length).view(v.sizes());
      replica.bucket_views.push_back(
          contents.narrow(0, offset, length).view(v.sizes()));
    }
    replica.bucket_views.push_back(bucket_view);
    // There are three cases to handle:
    // 1. initialize_bucket_views could be called inside communication hook,
    // bucket_view has the updated results in new tensor, just let grad point to
    // bucket_view, copy_to_bucket_view is false in this case.
    // 2. initialize_bucket_views could be called inside initialize_buckets when
    // rebuild_buckets, if grad has already been defined/calculated in previous
    // iteration, old grad needs to be copied into new bucket_view
    // and let grad point to the new bucket_view,
    // copy_to_bucket_view is true in this case.
    // 3. initialize_bucket_views could be called inside initialize_buckets
    // during construction. copy_to_bucket_view is true in this case. But mostly
    // grads are not defined during construction time, when grad is not defined,
    // do not let grad point to bucket_view, because grads should be kept as
    // being undefined for globally unused parameters.
    runGradCallbackForVariable(v, [&](auto& grad) {
      if (grad.defined() && !grad.is_alias_of(bucket_view)) {
        if (copy_to_bucket_view) {
          bucket_view.copy_(grad);
        }
        grad = bucket_view;
        // The grad is modified and needs to be written back.
        return true;
      }
      // The grad is not modified and does not need to be written back.
      return false;
    });
  }
}

void Reducer::prepare_forward() {
  std::lock_guard<std::mutex> lock(mutex_);
  rebuild_buckets();
}

// Traverse the autograd graph starting at the specified output.
// All parameters for which we have a pointer to their gradient accumulation
// functions, but don't show up in the autograd graph will be marked ready for
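Both versions of the function above carve one view per parameter out of the flat contents tensor: as_strided() when the param's memory is dense, narrow().view() otherwise. A compact PyTorch sketch of that packing, using is_contiguous() as a loose stand-in for is_non_overlapping_and_dense() (sizes and offsets are made up):

    import torch

    params = [torch.randn(2, 3), torch.randn(5)]
    contents = torch.empty(sum(p.numel() for p in params))   # flat bucket buffer

    bucket_views, offset = [], 0
    for p in params:
        if p.is_contiguous():
            # Dense case: match the param's layout exactly.
            bucket_views.append(contents.as_strided(p.size(), p.stride(), offset))
        else:
            # Fallback: a C-contiguous slice reshaped to the param's sizes.
            bucket_views.append(contents.narrow(0, offset, p.numel()).view(p.size()))
        offset += p.numel()

    # Writing through a view fills the matching slice of the flat buffer.
    bucket_views[0].copy_(torch.ones(2, 3))
    print(contents[:6])                                      # six ones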
@@ -973,14 +931,13 @@ void Reducer::finalize_bucket_dense(Bucket& bucket) {
    runGradCallbackForVariable(variable, [&](auto& grad) {
      // If a parameter is globally unused, we keep its grad untouched.
      if (!global_unused) {
        // If grad is globally used but locally unused, let grad point to
        // bucket_view
        if (!grad.defined()) {
          grad = bucket_view;
          // Creates grad according to the "Gradient Layout Contract"
          // (see torch/csrc/grad/AccumulateGrad.h)
          grad = torch::autograd::utils::clone_obey_contract(
              bucket_view, variable);
        } else {
          TORCH_INTERNAL_ASSERT(
              grad.is_alias_of(bucket_view),
              "Grad should have been pointed to bucket_view if grad is defined");
          grad.copy_(bucket_view);
        }
        // The grad is modified and needs to be written back.
        return true;
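finalize_bucket_dense above moves data the other way: after reduction, the bucket view is written back into each parameter's grad, cloning a fresh tensor in the parameter's layout when no grad exists yet. A very loose sketch of that unpacking (illustrative names; the empty_strided/copy_ pair stands in for clone_obey_contract):

    import torch

    param = torch.randn(2, 3)
    reduced_bucket = torch.arange(6, dtype=torch.float32)   # pretend: allreduced contents
    bucket_view = reduced_bucket.view_as(param)

    grad = None                                             # case: no grad materialized yet
    if grad is None:
        # Create a grad with the parameter's layout and fill it from the bucket.
        grad = torch.empty_strided(param.size(), param.stride())
        grad.copy_(bucket_view)
    else:
        # Grad already exists: write the reduced values into it in place.
        grad.copy_(bucket_view)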
@@ -1030,7 +987,7 @@ void Reducer::finalize_backward() {
        // Reinitialize bucket_views with the future_result by following
        // the same logic in `initialize_buckets`.
        bucket.replicas[i].bucket_views.clear();
        initialize_bucket_views(bucket.replicas[i], future_result[i], false);
        initialize_bucketviews(bucket.replicas[i], future_result[i]);
      }
    }
  }

@@ -1158,11 +1115,7 @@ void Reducer::sync_bucket_indices(
  }
}

void Reducer::rebuild_buckets() {
  if (rebuilt_params_.empty()) {
    return;
  }

std::vector<std::vector<size_t>> Reducer::rebuildBuckets() {
  TORCH_INTERNAL_ASSERT(
      rebuilt_params_.size() == rebuilt_param_indices_.size(),
      "rebuilt parameter tensors size is not same as rebuilt parameter indices size.");
@@ -1188,7 +1141,7 @@ void Reducer::rebuild_buckets() {
  rebuilt_params_.clear();
  rebuilt_param_indices_.clear();

  initialize_buckets(std::move(rebuilt_bucket_indices));
  return rebuilt_bucket_indices;
}

// See Note [DDP Communication Hook]

@@ -34,13 +34,11 @@ class Reducer {

  ~Reducer() noexcept(false);

  // This function is called before forward computation, e.g.
  // rebuild_buckets.
  // It may allocate new buckets before deallocating old buckets
  // inside rebuild_buckets. To save peak memory usage,
  // call rebuild_buckets before the peak memory usage increases
  // during forward computation.
  void prepare_forward();
  // To (re-)initialize bucket assignment, pass a list of buckets, each
  // of which is specified by a list of indices in the variables list.
  // This function performs validation that the variables within a bucket
  // all live on the same device and have the same dimensionality.
  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);

  // This function is called when the forward function has produced an output,
  // and the user wishes to reduce gradients in the backwards pass.
@@ -125,16 +123,9 @@ class Reducer {

  void finalize_backward();

  // To (re-)initialize bucket assignment, pass a list of buckets, each
  // of which is specified by a list of indices in the variables list.
  // This function performs validation that the variables within a bucket
  // all live on the same device and have the same dimensionality.
  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);

  // Broadcast rebuilt buckets from rank 0 to other ranks before initializing
  // the buckets
  void sync_bucket_indices(std::vector<std::vector<size_t>>& bucket_indices);

  // Rebuild buckets based on rebuilt_params_ and rebuilt_param_indices_
  // TODO this function makes broadcast communication call and
  // could be overlapped with next forward() call, thus
@@ -144,7 +135,7 @@ class Reducer {
  // and parameter indices order may change more frequently.
  // For find_unused_parameters = false case, buckets are only rebuilt once,
  // the performance cost is negligible.
  void rebuild_buckets();
  std::vector<std::vector<size_t>> rebuildBuckets();

  using GradCallback =
      torch::distributed::autograd::DistAutogradContext::GradCallback;
@@ -198,19 +189,11 @@ class Reducer {
  // This function is called inside `initialize_buckets` and
  // `finalize_backward`. The function call in `initialize_bucket` creates views
  // into the contents tensor for each variable's grad. Views serve as entry
  // points to refer to each grad's data of the flat contents tensor. When it is
  // called inside 'initialize_buckets', copy_to_bucket_view is true, meaning grad
  // needs to be copied into bucket_view.
  // The function call in `finalize_backward` happens only if DDP communication
  // hook was registered to recreate views with the result of `future_work`.
  // Before `finalize_backward` call, views must be cleared. In this case,
  // copy_to_bucket_view is false, meaning grad does not need to be copied into
  // bucket_view, as grad has already been mutated in bucket_view, just let grad
  // point to bucket_view here.
  void initialize_bucket_views(
      BucketReplica& replica,
      at::Tensor& contents,
      bool copy_to_bucket_view);
  // points to copy_ each grad's data in/out of the flat contents tensor. The
  // function call in `finalize_backward` happens only if DDP communication hook
  // was registered to recreate views with the result of `future_work`. Before
  // `finalize_backward` call, views must be cleared.
  void initialize_bucketviews(BucketReplica& replica, at::Tensor& contents);

  // A bucket holds N bucket replicas (1 per model replica).
  //

@@ -571,8 +571,6 @@ class DistributedDataParallel(Module):
        if self.require_forward_param_sync:
            self._sync_params()

        self.reducer.prepare_forward()

        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1: