[Reland][DDP] log bucket sizes (#62625)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62625

Reland of https://github.com/pytorch/pytorch/pull/62232, which ran into a land race.

Test Plan: ci

Reviewed By: SciPioneer

Differential Revision: D30058217

fbshipit-source-id: 1454dd481e630f3de9ec6111b3f2e18cd8976091
parent 1630b86dd6
commit 4d5607bb25
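In effect, the change threads the per-bucket size limits computed by `_compute_bucket_assignment_by_size` through the `Reducer` so that the `Logger` can record two new entries in the DDP logging data: `initial_bucket_size_limits` (captured at construction) and `rebuilt_bucket_size_limits` (captured once DDP has rebuilt its buckets). A minimal sketch of reading the new fields back, assuming an already-initialized process group (the model here is illustrative; the parsing follows the test added at the bottom of this diff):

```python
import torch.distributed as dist  # assumes dist.init_process_group(...) was called
import torch.nn as nn

ddp = nn.parallel.DistributedDataParallel(nn.Linear(1000, 1000))

logging_data = ddp._get_ddp_logging_data()
# Limits are logged as a ", "-joined string, e.g. "26214400, 1048576".
initial_limits = [
    int(b) for b in logging_data["initial_bucket_size_limits"].split(", ")
]
print(initial_limits)  # per-bucket byte limits, in bucket order
```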
@@ -2071,7 +2071,7 @@ class ReducerTest(TestCase):
         model = ReducerModule()
         parameters = list(model.parameters())
         buckets = [list(range(len(parameters)))]
-        dist.Reducer([parameters], buckets, self.process_group)
+        dist.Reducer([parameters], buckets, [dist._DEFAULT_FIRST_BUCKET_BYTES], self.process_group)

     def _create_mixed_precision_model(self):
         model = ReducerModule()

@@ -2088,7 +2088,12 @@ class ReducerTest(TestCase):
         with self.assertRaises(RuntimeError):
             parameters = [list(model.parameters())]
             buckets = [list(range(len(parameters[0])))]
-            dist.Reducer(parameters, buckets, self.process_group)
+            dist.Reducer(
+                parameters,
+                buckets,
+                [dist._DEFAULT_FIRST_BUCKET_BYTES],
+                self.process_group
+            )

     @requires_gloo()
     def test_multi_dtype_multi_bucket(self):

@@ -2098,7 +2103,12 @@ class ReducerTest(TestCase):
             range(len(parameters[0])), key=lambda i: parameters[0][i].dtype
         )
         buckets = [list(indices) for _, indices in group_by_dtype]
-        dist.Reducer(parameters, buckets, self.process_group)
+        dist.Reducer(
+            parameters,
+            buckets,
+            [dist._DEFAULT_FIRST_BUCKET_BYTES for _ in buckets],
+            self.process_group
+        )

     def _create_reducer_for_models(self, models, find_unused_parameters=False):
         parameters = [list(model.parameters()) for model in models]

@@ -2109,6 +2119,7 @@ class ReducerTest(TestCase):
         return dist.Reducer(
             parameters,
             buckets,
+            [dist._DEFAULT_FIRST_BUCKET_BYTES for _ in range(len(buckets))],
             self.process_group,
             find_unused_parameters=find_unused_parameters,
         )

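Every test-side construction of the internal `dist.Reducer` now passes a list of per-bucket byte limits between `bucket_indices` and the process group, one entry per bucket. A hedged sketch of the updated call shape, assuming a gloo process group is already initialized (this is an internal, undocumented API, and `_get_default_group` is a private helper; the argument order simply mirrors the diff):

```python
import torch.distributed as dist  # assumes dist.init_process_group("gloo", ...) was called
import torch.nn as nn

model = nn.Linear(4, 4)
parameters = [list(model.parameters())]      # one model replica
buckets = [list(range(len(parameters[0])))]  # a single bucket of all params
process_group = dist.distributed_c10d._get_default_group()

reducer = dist.Reducer(
    parameters,
    buckets,
    [dist._DEFAULT_FIRST_BUCKET_BYTES for _ in buckets],  # one limit per bucket
    process_group,
)
```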
@@ -338,6 +338,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``
          py::init<
              std::vector<std::vector<at::Tensor>>,
              std::vector<std::vector<size_t>>,
+             std::vector<size_t>,
              c10::intrusive_ptr<::c10d::ProcessGroup>,
              std::vector<std::vector<bool>>,
              int64_t,

@@ -346,6 +347,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``
              std::unordered_map<size_t, std::string>>(),
          py::arg("replicas"),
          py::arg("bucket_indices"),
+         py::arg("per_bucket_size_limits"),
          py::arg("process_group"),
          py::arg("expect_sparse_gradients") = std::vector<std::vector<bool>>(),
          py::arg("bucket_bytes_cap") = ::c10d::kDefaultBucketBytesCap,

@@ -100,6 +100,14 @@ std::vector<int> Logger::get_bucket_sizes() {
   return bucket_sizes;
 }

+std::vector<int> Logger::get_bucket_size_limits() {
+  std::vector<int> bucket_size_limits;
+  for (const auto& bucket : reducer_->buckets_) {
+    bucket_size_limits.push_back(bucket.bucket_size_limit);
+  }
+  return bucket_size_limits;
+}
+
 // Communication hook. Empty string if not set, in which case it will not be
 // logged.
 void Logger::set_comm_hook(const std::string& hook) {

@@ -139,6 +147,9 @@ void Logger::set_construction_data_and_log(
   // A list of bucket sizes (Bytes) calculated during construction time
   ddp_logging_data_->strs_map["bucket_sizes"] =
       c10::Join(", ", get_bucket_sizes());
+  // A list of bucket size limits (bytes) specified during construction time
+  ddp_logging_data_->strs_map["initial_bucket_size_limits"] =
+      c10::Join(", ", get_bucket_size_limits());
   set_env_variables();

   // DistributedDataParallel constructor input parameters

@@ -223,6 +234,8 @@ void Logger::set_runtime_stats_and_log() {
         reducer_->has_rebuilt_bucket_;
     ddp_logging_data_->strs_map["rebuilt_bucket_sizes"] =
         c10::Join(", ", get_bucket_sizes());
+    ddp_logging_data_->strs_map["rebuilt_bucket_size_limits"] =
+        c10::Join(", ", get_bucket_size_limits());
   }

   reset_performance_stats();

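Taken together with the construction-time hunk above, the two fields are recorded at different points: `initial_bucket_size_limits` in `set_construction_data_and_log`, and `rebuilt_bucket_size_limits` in `set_runtime_stats_and_log` once the reducer has rebuilt its buckets. A small sketch of observing that timing from Python, assuming an initialized process group (model and input are illustrative):

```python
import torch
import torch.nn as nn  # assumes torch.distributed.init_process_group(...) was called

ddp = nn.parallel.DistributedDataParallel(nn.Linear(8, 8))
inp = torch.randn(4, 8)

for step in range(4):
    ddp(inp).sum().backward()
    data = ddp._get_ddp_logging_data()
    # Construction-time limits are always present; the rebuilt limits only
    # become meaningful after DDP rebuilds its buckets in early iterations.
    key = ("initial_bucket_size_limits" if step < 2
           else "rebuilt_bucket_size_limits")
    print(step, data.get(key))
```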
@@ -37,6 +37,8 @@ class TORCH_API Logger {
   void set_parameter_stats();
   // Get size of each bucket (Bytes).
   std::vector<int> get_bucket_sizes();
+  // Get bucket size limits specified during DDP construction.
+  std::vector<int> get_bucket_size_limits();
   // Set comm. hook, if used
   void set_comm_hook(const std::string& hook);
   // Set running with uneven input detection (model.join() context manager)

@@ -109,6 +109,7 @@ C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer);
 Reducer::Reducer(
     std::vector<std::vector<at::Tensor>> replicas,
     std::vector<std::vector<size_t>> bucket_indices,
+    std::vector<size_t> per_bucket_size_limits,
     c10::intrusive_ptr<c10d::ProcessGroup> process_group,
     std::vector<std::vector<bool>> expect_sparse_gradients,
     int64_t bucket_bytes_cap,

@@ -174,7 +175,8 @@ Reducer::Reducer(
   // This can be reinitialized later after capturing runtime information.
   {
     std::lock_guard<std::mutex> lock(mutex_);
-    initialize_buckets(std::move(bucket_indices));
+    initialize_buckets(
+        std::move(bucket_indices), std::move(per_bucket_size_limits));
   }

   // All variables are expected to have their `grad_fn` set to the gradient

@@ -939,7 +941,8 @@ void Reducer::mark_bucket_ready(size_t bucket_index) {
 }

 void Reducer::initialize_buckets(
-    std::vector<std::vector<size_t>> bucket_indices) {
+    std::vector<std::vector<size_t>> bucket_indices,
+    std::vector<size_t> per_bucket_sizes) {
   // If initialize_buckets is called inside DDP constructor, then
   // it does not matter rpc context ptr is nullptr or not, as grad
   // will not be mutated.

@@ -970,8 +973,10 @@ void Reducer::initialize_buckets(
   const auto bucket_count = bucket_indices.size();
   const auto replica_count = replicas_.size();
   buckets_.reserve(bucket_count);
+  TORCH_INTERNAL_ASSERT(bucket_count == per_bucket_sizes.size());
   for (const auto bucket_index : c10::irange(bucket_count)) {
     Bucket bucket;
+    bucket.bucket_size_limit = per_bucket_sizes[bucket_index];

     // TODO(@pietern): Validate indices.
     // Must be non-empty, unique, and unique across buckets.

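The new argument is consumed positionally: bucket `i` takes `per_bucket_sizes[i]`, and the assert enforces exactly one limit per bucket. The pairing, restated in Python terms (the function name is hypothetical, for illustration only):

```python
def pair_buckets_with_limits(bucket_indices, per_bucket_sizes):
    # Mirrors TORCH_INTERNAL_ASSERT(bucket_count == per_bucket_sizes.size()).
    assert len(bucket_indices) == len(per_bucket_sizes)
    return [
        {"indices": indices, "bucket_size_limit": limit}
        for indices, limit in zip(bucket_indices, per_bucket_sizes)
    ]

print(pair_buckets_with_limits([[0, 1], [2]], [26214400, 1048576]))
```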
@@ -1685,7 +1690,8 @@ bool Reducer::rebuild_buckets() {
   rebuilt_params_.clear();
   rebuilt_param_indices_.clear();

-  initialize_buckets(std::move(rebuilt_bucket_indices));
+  initialize_buckets(
+      std::move(rebuilt_bucket_indices), std::move(per_bucket_size_limits));
   return true;
 }

@@ -1936,7 +1942,6 @@ compute_bucket_assignment_by_size(
       c10::hash<BucketKey>>
       bucket_size_limit_iterators;

-
   // Keep vector of indices and size accumulator by tensor type and device.
   std::unordered_map<BucketKey, BucketAccumulator, c10::hash<BucketKey>>
       buckets;

@@ -2005,11 +2010,14 @@ compute_bucket_assignment_by_size(
     std::sort(
         result.begin(),
         result.end(),
-        [](const std::tuple<std::vector<size_t>, size_t>& a, const std::tuple<std::vector<size_t>, size_t>& b) {
+        [](const std::tuple<std::vector<size_t>, size_t>& a,
+           const std::tuple<std::vector<size_t>, size_t>& b) {
           auto indices_a = std::get<0>(a);
           auto indices_b = std::get<0>(b);
-          const auto amin = std::min_element(indices_a.begin(), indices_a.end());
-          const auto bmin = std::min_element(indices_b.begin(), indices_b.end());
+          const auto amin =
+              std::min_element(indices_a.begin(), indices_a.end());
+          const auto bmin =
+              std::min_element(indices_b.begin(), indices_b.end());
           return *amin < *bmin;
         });
   }

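This hunk is only a formatting reflow, but the comparator it touches is worth spelling out: buckets in `result` are sorted by the smallest parameter index they contain, so bucket order tracks parameter definition order. The same ordering in Python, with illustrative values:

```python
# Each entry is (parameter_indices, size_limit), one per bucket.
result = [([4, 5], 1048576), ([0, 1], 26214400), ([2, 3], 26214400)]
result.sort(key=lambda bucket: min(bucket[0]))
print(result)
# [([0, 1], 26214400), ([2, 3], 26214400), ([4, 5], 1048576)]
```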
@@ -69,6 +69,7 @@ class TORCH_API Reducer {
   explicit Reducer(
       std::vector<std::vector<at::Tensor>> replicas,
       std::vector<std::vector<size_t>> bucket_indices,
+      std::vector<size_t> per_bucket_size_limits,
       c10::intrusive_ptr<c10d::ProcessGroup> process_group,
       std::vector<std::vector<bool>> expect_sparse_gradients,
       int64_t bucket_bytes_cap,

@@ -82,7 +83,9 @@ class TORCH_API Reducer {
   // of which is specified by a list of indices in the variables list.
   // This function performs validation that the variables within a bucket
   // all live on the same device and have the same dimensionality.
-  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);
+  void initialize_buckets(
+      std::vector<std::vector<size_t>> bucket_indices,
+      std::vector<size_t> per_bucket_sizes);

   // This function is called when the forward function has produced an output,
   // and the user wishes to reduce gradients in the backwards pass.

@@ -385,6 +388,10 @@ class TORCH_API Reducer {
     // If this bucket should expect a single sparse gradient.
     // Implies: replicas[i].variables.size() == 1.
     bool expect_sparse_gradient = false;
+    // "Limit" of cumulative parameter sizes that this bucket manages. It is
+    // actually a soft limit because we don't shard parameters across buckets
+    // so a single parameter may push it over the cap.
+    size_t bucket_size_limit;
   };

   std::vector<Bucket> buckets_;

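The comment on the new `bucket_size_limit` field is worth unpacking: the limit is soft because a parameter is never split across buckets, so the parameter that closes a bucket can push the bucket's actual size past the cap. A toy illustration of that greedy behavior (a deliberate simplification of `compute_bucket_assignment_by_size`; the function name is hypothetical):

```python
def greedy_buckets(param_sizes, limit):
    """Toy bucketing: close a bucket once its accumulated size crosses limit."""
    buckets, current, acc = [], [], 0
    for i, size in enumerate(param_sizes):
        current.append(i)
        acc += size
        if acc >= limit:  # the closing parameter may overshoot the cap
            buckets.append((current, acc))
            current, acc = [], 0
    if current:
        buckets.append((current, acc))
    return buckets

print(greedy_buckets([6, 6, 10], limit=8))
# [([0, 1], 12), ([2], 10)] -- both buckets exceed the soft limit of 8
```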
@@ -611,7 +611,7 @@ class DistributedDataParallel(Module, _Joinable):
         # that are defined first, such that their gradients don't spill into
         # a much larger bucket, adding unnecessary latency after gradient
         # computation finishes. Experiments showed 1MB is a reasonable value.
-        bucket_indices, _ = dist._compute_bucket_assignment_by_size(
+        bucket_indices, per_bucket_size_limits = dist._compute_bucket_assignment_by_size(
             parameters[0],
             [dist._DEFAULT_FIRST_BUCKET_BYTES, self.bucket_bytes_cap],
             expect_sparse_gradient[0],

@@ -623,6 +623,7 @@ class DistributedDataParallel(Module, _Joinable):
         self.reducer = dist.Reducer(
             parameters,
             list(reversed(bucket_indices)),
+            list(reversed(per_bucket_size_limits)),
             self.process_group,
             expect_sparse_gradient,
             self.bucket_bytes_cap,

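Since DDP reverses the bucket list so that bucket 0 covers the last-defined parameters (whose gradients are produced first in the backward pass), the limits must be reversed in lockstep to keep limit `i` attached to bucket `i`. This is also why the test below finds the first-bucket limit at the *end* of `initial_bucket_size_limits`. Illustrative values, assuming the 1 MB first-bucket default (1024**2 bytes) and the 25 MiB cap (25 * 1024**2 = 26214400 bytes):

```python
# Buckets as produced by _compute_bucket_assignment_by_size: the small
# first bucket comes first, larger cap-sized buckets follow.
bucket_indices = [[0, 1], [2, 3, 4], [5, 6, 7]]
per_bucket_size_limits = [1024 ** 2, 25 * 1024 ** 2, 25 * 1024 ** 2]

# Reverse both lists together so limit i stays attached to bucket i.
reversed_indices = list(reversed(bucket_indices))
reversed_limits = list(reversed(per_bucket_size_limits))
print(reversed_limits)  # [26214400, 26214400, 1048576] -- 1 MB limit last
```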
@@ -4979,6 +4979,16 @@ class DistributedTest:
             # type if it didn't exist.
             self.assertEqual(ddp_logging_data.get("unused_parameter_size", 0), 0)
             self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1)
+            init_bucket_lims = ddp_logging_data.get("initial_bucket_size_limits")
+            rebuilt_bucket_lims = ddp_logging_data.get("rebuilt_bucket_size_limits")
+            self.assertEqual(
+                int(init_bucket_lims),
+                dist._DEFAULT_FIRST_BUCKET_BYTES,
+            )
+            self.assertEqual(
+                int(rebuilt_bucket_lims),
+                dist._DEFAULT_FIRST_BUCKET_BYTES,
+            )
             self.assertEqual(
                 ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size)
             )

@@ -7606,6 +7616,59 @@ class DistributedTest:
             self.assertEqual(opt[i]["tensor"].grad_fn, None)
             out.mean().backward()

+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(
+            BACKEND != "nccl" and BACKEND != "gloo",
+            "Only Nccl & Gloo backend support DistributedDataParallel",
+        )
+        def test_ddp_get_bucket_sizes(self):
+            torch.cuda.set_device(self.rank)
+            default_bucket_cap_mb = 25 * (1024 ** 2)
+            first_bucket_bytes_mb = dist._DEFAULT_FIRST_BUCKET_BYTES
+
+            class MyModel(nn.Module):
+                def __init__(self):
+                    super().__init__()
+                    self.model = nn.Sequential(
+                        nn.Linear(2, 4000, bias=False),
+                        *[nn.Linear(4000, 4000, bias=False) for _ in range(10)]
+                    )
+
+                def forward(self, x):
+                    return self.model(x)
+
+            ddp = torch.nn.parallel.DistributedDataParallel(
+                MyModel().cuda(),
+                device_ids=[self.rank]
+            )
+            inp = torch.randn(10, 2)
+            for i in range(6):
+                out = ddp(inp).sum()
+                out.backward()
+                logging_data = ddp._get_ddp_logging_data()
+                if i < 2:
+                    bucket_size_limits = [
+                        int(b) for b in logging_data["initial_bucket_size_limits"].split(", ")
+                    ]
+                    # first_bucket_bytes is actually the last because we reverse
+                    # parameter bucket order.
+                    self.assertEqual(bucket_size_limits[-1], first_bucket_bytes_mb)
+                    for j, bucket_size in enumerate(bucket_size_limits):
+                        if j != len(bucket_size_limits) - 1:
+                            self.assertEqual(bucket_size, default_bucket_cap_mb)
+                else:
+                    bucket_size_limits = [
+                        int(b) for b in logging_data["rebuilt_bucket_size_limits"].split(", ")
+                    ]
+                    # TODO: rebuild buckets places first bucket at beginning, but
+                    # might be better to move it to end.
+                    self.assertEqual(
+                        bucket_size_limits[0], first_bucket_bytes_mb
+                    )
+                    for j, bucket_size in enumerate(bucket_size_limits):
+                        if j != 0:
+                            self.assertEqual(bucket_size, default_bucket_cap_mb)
+
         @skip_if_lt_x_gpu(2)
         @sandcastle_skip_if(
             BACKEND != "nccl" and BACKEND != "gloo",