import torch
import torch.distributed as dist
from torch.autograd.function import Function


class SyncBatchNorm(Function):

    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        input = input.contiguous()

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count = torch.full((1,), input.numel() // input.size(1),
                           dtype=mean.dtype, device=mean.device)

        num_channels = input.shape[1]
        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
        # world_size * (2C + 1)
        combined_list = [
            torch.empty_like(combined) for _ in range(world_size)
        ]
        # Use allgather instead of allreduce since I don't trust in-place operations ..
        dist.all_gather(combined_list, combined, process_group, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        size = count_all.view(-1).long().sum()
        if size == 1:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all)
        self.process_group = process_group

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out

    @staticmethod
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            self.needs_input_grad[0],
            self.needs_input_grad[1],
            self.needs_input_grad[2]
        )

        if self.needs_input_grad[0]:
            # synchronizing stats used to calculate input gradient.
            # TODO: move div_ into batch_norm_backward_elemt kernel
            num_channels = sum_dy.shape[0]
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            torch.distributed.all_reduce(
                combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
            sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

            divisor = count_tensor.sum()
            mean_dy = sum_dy / divisor
            mean_dy_xmu = sum_dy_xmu / divisor
            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                mean_dy,
                mean_dy_xmu
            )

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


class CrossMapLRN2d(Function):

    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        assert input.dim() == 4

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = channels if pre_pad > channels else pre_pad

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute first feature map normalization
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for next feature maps normalization
        # by adding the next feature map and removing the previous
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height, input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inversePrePad, channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None


class BackwardHookFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args
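

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# smoke test showing how these Functions are driven through Function.apply.
# Tensor shapes and the LRN window size below are arbitrary illustrative
# choices. SyncBatchNorm is not exercised here because its underlying kernels
# (torch.batch_norm_stats and friends) target CUDA and require an initialized
# torch.distributed process group.
if __name__ == "__main__":
    # CrossMapLRN2d: across-channel local response normalization.
    x = torch.randn(2, 8, 5, 5, requires_grad=True)
    y = CrossMapLRN2d.apply(x, 5, 1e-4, 0.75, 1)  # size=5 normalization window
    y.sum().backward()
    print("CrossMapLRN2d output:", tuple(y.shape), "input grad:", tuple(x.grad.shape))

    # BackwardHookFunction: identity in both directions; inputs that do not
    # require grad are marked non-differentiable.
    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3)
    out_a, out_b = BackwardHookFunction.apply(a, b)
    out_a.sum().backward()
    print("BackwardHookFunction grad on a:", a.grad)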