pytorch/torch/nn/modules/_functions.py
Shen Li 1884d7fbe9 Avoid CPU Sync in SyncBatchNorm When Capturing CUDA Graphs
We recently updated `SyncBatchNorm` to support empty input batches.
The new code removes stats from ranks with empty inputs. However,
this change breaks CUDA graph capture, as it forces a CPU sync. This
commit uses `is_current_stream_capturing()` to guard the new code
path and only runs it when not capturing CUDA Graphs. To
support empty inputs with CUDA graph capturing, we might need to
update CUDA kernels for `batch_norm_backward_elemt` and
`batch_norm_gather_stats_with_counts`. See #78656.

Fixes #78549

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78666

Approved by: https://github.com/albanD
2022-06-03 04:32:57 +00:00
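
For reference, a minimal sketch of the capture-time behavior this commit relies on (illustrative only, assuming a CUDA-capable build; warm-up iterations are omitted for brevity):

import torch

g = torch.cuda.CUDAGraph()
x = torch.ones(4, device="cuda")
with torch.cuda.graph(g):
    # During capture, is_current_stream_capturing() returns True, so
    # data-dependent work that would force a CPU sync (such as the
    # mask-based filtering guarded below in SyncBatchNorm.forward)
    # can be skipped.
    assert torch.cuda.is_current_stream_capturing()
    y = x * 2
g.replay()  # y now holds the captured result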

276 lines
11 KiB
Python

import torch
import torch.distributed as dist
from torch.autograd.function import Function
class SyncBatchNorm(Function):
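# Autograd Function backing torch.nn.SyncBatchNorm: the forward pass all-gathers
# per-rank mean / invstd / count across process_group and reduces them into global
# statistics; the backward pass all-reduces the per-rank gradient sums.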
@staticmethod
def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
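# accept channels_last inputs as-is; any other layout falls back to a
# default-contiguous copy (the same treatment is applied to weight below)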
if not input.is_contiguous(memory_format=torch.channels_last):
input = input.contiguous()
if weight is not None:
weight = weight.contiguous()
size = int(input.numel() // input.size(1))
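# batch statistics are ill-defined when this rank sees only a single value
# per channel and no other rank can contribute samples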
if size == 1 and world_size < 2:
raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
num_channels = input.shape[1]
if input.numel() > 0:
# calculate mean/invstd for input.
mean, invstd = torch.batch_norm_stats(input, eps)
count = torch.full(
(1,),
input.numel() // input.size(1),
dtype=mean.dtype,
device=mean.device
)
# C, C, 1 -> (2C + 1)
combined = torch.cat([mean, invstd, count], dim=0)
else:
# for empty input, set stats and the count to zero. The stats with
# zero count will be filtered out later when computing global mean
# & invstd, but they still need to participate in the all_gather
# collective communication to unblock other peer processes.
combined = torch.zeros(
2 * num_channels + 1,
dtype=input.dtype,
device=input.device
)
# Use allgather instead of allreduce because count could be different across
# ranks; a simple all_reduce op cannot give correct results.
# batch_norm_gather_stats_with_counts calculates global mean & invstd based on
# all gathered mean, invstd and count.
# for the NCCL backend, use the optimized version of all_gather.
if process_group._get_backend_name() == 'nccl':
# world_size * (2C + 1)
combined_size = combined.numel()
combined_flat = torch.empty(1,
combined_size * world_size,
dtype=combined.dtype,
device=combined.device)
dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
combined = torch.reshape(combined_flat, (world_size, combined_size))
# world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
else:
# world_size * (2C + 1)
combined_list = [
torch.empty_like(combined) for _ in range(world_size)
]
dist.all_gather(combined_list, combined, process_group, async_op=False)
combined = torch.stack(combined_list, dim=0)
# world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
if not torch.cuda.is_current_stream_capturing():
# The lines below force a synchronization between CUDA and CPU, because
# the shape of the resulting count_all depends on the values in the mask tensor.
# Such synchronizations break CUDA Graph capturing.
# See https://github.com/pytorch/pytorch/issues/78549
# FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
# a better longer-term solution.
# remove stats from empty inputs
mask = count_all.squeeze(-1) >= 1
count_all = count_all[mask]
mean_all = mean_all[mask]
invstd_all = invstd_all[mask]
# calculate global mean & invstd
mean, invstd = torch.batch_norm_gather_stats_with_counts(
input,
mean_all,
invstd_all,
running_mean,
running_var,
momentum,
eps,
count_all.view(-1)
)
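# stash the inputs and the synced statistics for backward; the per-rank counts
# are converted to int32 for use in batch_norm_backward_elemt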
self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
self.process_group = process_group
# apply element-wise normalization
if input.numel() > 0:
return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
else:
return torch.empty_like(input)
@staticmethod
def backward(self, grad_output):
if not grad_output.is_contiguous(memory_format=torch.channels_last):
grad_output = grad_output.contiguous()
saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
grad_input = grad_weight = grad_bias = None
process_group = self.process_group
if saved_input.numel() > 0:
# calculate local stats as well as grad_weight / grad_bias
sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
grad_output,
saved_input,
mean,
invstd,
weight,
self.needs_input_grad[0],
self.needs_input_grad[1],
self.needs_input_grad[2]
)
if self.needs_input_grad[0]:
# synchronizing stats used to calculate input gradient.
num_channels = sum_dy.shape[0]
combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
torch.distributed.all_reduce(
combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
# backward pass for gradient calculation
grad_input = torch.batch_norm_backward_elemt(
grad_output,
saved_input,
mean,
invstd,
weight,
sum_dy,
sum_dy_xmu,
count_tensor
)
# synchronization of grad_weight / grad_bias is not needed, as distributed
# training will handle the all_reduce.
if weight is None or not self.needs_input_grad[1]:
grad_weight = None
if weight is None or not self.needs_input_grad[2]:
grad_bias = None
else:
# This process got an empty input tensor in the forward pass.
# Although this process could simply set grad_input to an empty
# tensor of zeros, it still needs to participate in the collective
# communication to unblock its peers, as other peer processes might
# have received non-empty inputs.
num_channels = saved_input.shape[1]
if self.needs_input_grad[0]:
# launch all_reduce to unblock other peer processes
combined = torch.zeros(
2 * num_channels,
dtype=saved_input.dtype,
device=saved_input.device
)
torch.distributed.all_reduce(
combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
# Leave grad_input, grad_weight and grad_bias as None, which will be
# interpreted by the autograd engine as Tensors full of zeros.
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
class CrossMapLRN2d(Function):
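# Cross-channel local response normalization (LRN): each element is scaled by
# (k + alpha / size * sum of squared activations over a window of neighboring
# channels) ** -beta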
@staticmethod
def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
ctx.size = size
ctx.alpha = alpha
ctx.beta = beta
ctx.k = k
ctx.scale = None
assert input.dim() == 4
ctx.scale = ctx.scale or input.new()
output = input.new()
batch_size = input.size(0)
channels = input.size(1)
input_height = input.size(2)
input_width = input.size(3)
output.resize_as_(input)
ctx.scale.resize_as_(input)
# use output storage as temporary buffer
input_square = output
torch.pow(input, 2, out=input_square)
pre_pad = int((ctx.size - 1) / 2 + 1)
pre_pad_crop = channels if pre_pad > channels else pre_pad
scale_first = ctx.scale.select(1, 0)
scale_first.zero_()
# compute first feature map normalization
for c in range(pre_pad_crop):
scale_first.add_(input_square.select(1, c))
# reuse computations for the next feature maps' normalization
# by adding the next feature map and removing the previous one
for c in range(1, channels):
scale_previous = ctx.scale.select(1, c - 1)
scale_current = ctx.scale.select(1, c)
scale_current.copy_(scale_previous)
if c < channels - pre_pad + 1:
square_next = input_square.select(1, c + pre_pad - 1)
scale_current.add_(square_next, alpha=1)
if c > pre_pad:
square_previous = input_square.select(1, c - pre_pad)
scale_current.add_(square_previous, alpha=-1)
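# finalize the normalization factor: scale = k + alpha / size * (windowed sum
# of squares); the output is then input * scale ** -beta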
ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)
torch.pow(ctx.scale, -ctx.beta, out=output)
output.mul_(input)
ctx.save_for_backward(input, output)
return output
@staticmethod
def backward(ctx, grad_output):
input, output = ctx.saved_tensors
grad_input = grad_output.new()
batch_size = input.size(0)
channels = input.size(1)
input_height = input.size(2)
input_width = input.size(3)
padded_ratio = input.new(channels + ctx.size - 1, input_height,
input_width)
accum_ratio = input.new(input_height, input_width)
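# 2 * alpha * beta / size is the constant factor contributed by d(scale)/d(input)
# when the chain rule is applied in the accumulation loop below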
cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
inversePrePad = int(ctx.size - (ctx.size - 1) / 2)
grad_input.resize_as_(input)
torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)
padded_ratio.zero_()
padded_ratio_center = padded_ratio.narrow(0, inversePrePad,
channels)
for n in range(batch_size):
torch.mul(grad_output[n], output[n], out=padded_ratio_center)
padded_ratio_center.div_(ctx.scale[n])
torch.sum(
padded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
for c in range(channels):
accum_ratio.add_(padded_ratio[c + ctx.size - 1])
grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
accum_ratio.add_(padded_ratio[c], alpha=-1)
return grad_input, None, None, None, None
class BackwardHookFunction(torch.autograd.Function):
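# Identity autograd Function used by the full backward hook machinery
# (torch.utils.hooks.BackwardHook) so that gradients flowing through a module's
# inputs and outputs can be intercepted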
@staticmethod
def forward(ctx, *args):
ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
return args
@staticmethod
def backward(ctx, *args):
return args