We recently updated `SyncBatchNorm` to support empty input batches. The new code removes stats from ranks with empty inputs. However, this change breaks CUDA graph capture because it forces a CPU sync. This commit uses `is_current_stream_capturing()` to guard the new code path and only runs it when not capturing CUDA graphs. To support empty inputs under CUDA graph capture, we might need to update the CUDA kernels for `batch_norm_backward_elemt` and `batch_norm_gather_stats_with_counts`. See #78656.

Fixes #78549

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78666
Approved by: https://github.com/albanD
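The guarded path in `forward` boils down to skipping the data-dependent filtering while a graph is being captured. A minimal sketch of that pattern, assuming a CUDA build (the helper name `filter_empty_rank_stats` is hypothetical and not part of the file below):

    import torch

    def filter_empty_rank_stats(mean_all, invstd_all, count_all):
        # Boolean-mask indexing forces a CUDA -> CPU sync, which is not allowed
        # while a CUDA graph is being captured, so skip the filtering during
        # capture and keep every rank's stats.
        if not torch.cuda.is_current_stream_capturing():
            mask = count_all.squeeze(-1) >= 1
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]
            count_all = count_all[mask]
        return mean_all, invstd_all, count_all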
276 lines · 11 KiB · Python
import torch
import torch.distributed as dist

from torch.autograd.function import Function


class SyncBatchNorm(Function):

    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        if not input.is_contiguous(memory_format=torch.channels_last):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # For empty input, set the stats and the count to zero. Stats with a
            # zero count will be filtered out later when computing the global mean
            # & invstd, but they still need to participate in the all_gather
            # collective communication to unblock other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use all_gather instead of all_reduce because the count could be different
        # across ranks; a simple all_reduce op cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates the global mean & invstd
        # based on all gathered mean, invstd and count values.
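        # Illustrative layout (hypothetical values; C = 2, world_size = 2, with
        # m = per-channel mean, v = per-channel invstd, n = element count):
        #   rank 0 packs combined = [m0_0, m0_1, v0_0, v0_1, n0]   -> 2C + 1 entries
        #   rank 1 packs combined = [m1_0, m1_1, v1_0, v1_1, n1]
        # After gathering and reshaping to (world_size, 2C + 1), splitting by
        # num_channels along dim=1 yields mean_all of shape (world_size, C),
        # invstd_all of shape (world_size, C) and count_all of shape (world_size, 1).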
        # For the nccl backend, use the optimized version of all_gather.
        if process_group._get_backend_name() == 'nccl':
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not torch.cuda.is_current_stream_capturing():
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the resulting count_all depends on the values in the mask
            # tensor. Such synchronizations break CUDA graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            #        a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate the global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)

    @staticmethod
    def backward(self, grad_output):
        if not grad_output.is_contiguous(memory_format=torch.channels_last):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2]
            )

            if self.needs_input_grad[0]:
                # synchronize the stats used to calculate the input gradient.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for gradient calculation
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor
                )
            # synchronizing grad_weight / grad_bias is not needed because distributed
            # training handles that all-reduce itself.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch the all_reduce to unblock other peer processes
                combined = torch.zeros(
                    2 * num_channels,
                    dtype=saved_input.dtype,
                    device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


class CrossMapLRN2d(Function):

    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        assert input.dim() == 4

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = channels if pre_pad > channels else pre_pad

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute first feature map normalization
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for next feature maps normalization
        # by adding the next feature map and removing the previous one
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height, input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inversePrePad, channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None


class BackwardHookFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args