Summary:
In `batch_norm_gather_stats_with_counts_cuda` use `input.scalar_type()` if `running_mean` is not defined.
In `SyncBatchNorm` forward function create count tensor with `torch.float32` type if `running_mean` is None.
Fix a few typos.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50126

Test Plan:
```
python -c "import torch;print(torch.batch_norm_gather_stats_with_counts( torch.randn(1, 3, 3, 3, device='cuda'), mean = torch.ones(2, 3, device='cuda'), invstd = torch.ones(2, 3, device='cuda'), running_mean = None, running_var = None , momentum = .1, eps = 1e-5, counts = torch.ones(2, device='cuda')))"
```

Fixes https://github.com/pytorch/pytorch/issues/49730

Reviewed By: ngimel
Differential Revision: D25797930
Pulled By: malfet
fbshipit-source-id: 22a91e3969b5e9bbb7969d9cc70b45013a42fe83
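For context, a minimal sketch of the count-tensor dtype rule the summary describes. This is not the actual diff: `make_count` is a hypothetical helper, and the non-`None` branch (deriving the dtype from the computed stats) is an assumption that mirrors the existing code below.

```python
import torch

def make_count(input, running_mean, mean):
    # Assumption: fall back to torch.float32 only when running_mean is None, as the summary states;
    # otherwise keep the dtype of the locally computed stats.
    dtype = torch.float32 if running_mean is None else mean.dtype
    return torch.full((1,), input.numel() // input.size(1), dtype=dtype, device=mean.device)
```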
import torch
import torch.distributed as dist

from torch.autograd.function import Function


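# SyncBatchNorm is the autograd Function used by torch.nn.SyncBatchNorm: each process
# computes local batch statistics, all_gathers them (together with the element counts)
# across the process group, and normalizes the input with the resulting global mean/invstd.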
class SyncBatchNorm(Function):

    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        input = input.contiguous()

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        count = torch.full((1,), input.numel() // input.size(1),
                           dtype=mean.dtype,
                           device=mean.device)

        num_channels = input.shape[1]
        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
        # world_size * (2C + 1)
        combined_list = [
            torch.empty_like(combined) for k in range(world_size)
        ]
        # Use allgather instead of allreduce since I don't trust in-place operations ..
        dist.all_gather(combined_list, combined, process_group, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        size = count_all.view(-1).long().sum()
        if size == 1:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all)
        self.process_group = process_group

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out

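    # Backward: the per-process partial sums (sum_dy, sum_dy_xmu) are all_reduced so every
    # rank computes grad_input from global statistics; grad_weight / grad_bias stay local,
    # since distributed training all-reduces parameter gradients separately.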
    @staticmethod
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            self.needs_input_grad[0],
            self.needs_input_grad[1],
            self.needs_input_grad[2]
        )

        if self.needs_input_grad[0]:
            # synchronizing stats used to calculate input gradient.
            # TODO: move div_ into batch_norm_backward_elemt kernel
            num_channels = sum_dy.shape[0]
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            torch.distributed.all_reduce(
                combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
            sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

            divisor = count_tensor.sum()
            mean_dy = sum_dy / divisor
            mean_dy_xmu = sum_dy_xmu / divisor
            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                mean_dy,
                mean_dy_xmu
            )

        # synchronizing of grad_weight / grad_bias is not needed as distributed
        # training would handle all reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None


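# CrossMapLRN2d implements cross-channel local response normalization:
#   scale = k + alpha / size * sum(x^2 over a window of neighboring channels)
#   output = x * scale^(-beta)
# The forward pass builds `scale` with a sliding-window update rather than recomputing
# the channel sum from scratch for every feature map.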
class CrossMapLRN2d(Function):

    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        assert input.dim() == 4

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = channels if pre_pad > channels else pre_pad

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute first feature map normalization
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for next feature maps normalization
        # by adding the next feature map and removing the previous
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

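    # Backward uses the saved output and the cached `scale` buffer:
    #   grad_input = scale^(-beta) * grad_output
    #                - (2 * alpha * beta / size) * x * sum(grad_output * output / scale over the window)
    # The windowed sum is accumulated with the same sliding-window trick as the forward pass.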
    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height,
                                 input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inversePrePad,
                                                  channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None


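# BackwardHookFunction is a pass-through Function: forward returns its inputs unchanged
# (marking tensors that do not require grad as non-differentiable) and backward returns
# the incoming gradients unchanged, giving nn.Module's full backward hooks a graph node
# to attach to.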
class BackwardHookFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args