import torch
from torch import Tensor
from torch._decomp import register_decomposition
from enum import Enum
from typing import Tuple, Optional, List, Callable
import torch.nn.functional as F
import functools
from torch.utils._pytree import tree_map, tree_flatten
import torch._prims.utils as utils
from torch._prims.wrappers import out_wrapper_multi

# None of these functions are publicly accessible; get at them
# from torch._decomp
__all__: List[str] = []

aten = torch.ops.aten


class Reduction(Enum):
    NONE = 0
    MEAN = 1
    SUM = 2


# This wraps a decomposition and performs various type promotion logic within it,
# depending on the strategy provided.
# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the
# usages are on non-elementwise ops. Will need to validate the non-elementwise uses.
def type_casts(f: Callable, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        flat_args = [
            x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor)
        ]
        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *flat_args, type_promotion_kind=type_promotion
        )

        # TODO: pretty sure this is not quite right
        def increase_prec(x):
            if isinstance(x, Tensor):
                return x.to(computation_dtype)
            else:
                return x

        def decrease_prec(x):
            if isinstance(x, Tensor):
                return x.to(result_dtype)
            else:
                return x

        r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        return tree_map(decrease_prec, r)

    return inner


pw_cast_for_opmath = functools.partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
reduction_complex_to_real = functools.partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
)
pw_cast_for_int_to_real = functools.partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
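

# Illustrative sketch (hypothetical helper, not part of the decomposition table):
# shows what the wrappers above are meant to do.  Assuming the DEFAULT promotion
# kind upcasts float16 to float32 for the computation (the exact table lives in
# torch._prims.utils.elementwise_dtypes), the wrapped body should see float32
# inputs while the caller still gets a float16 result back.
def _demo_pw_cast_for_opmath():
    seen = {}

    @pw_cast_for_opmath
    def toy(x: Tensor) -> Tensor:
        seen["inner_dtype"] = x.dtype  # dtype the body actually computes in
        return x * x

    out = toy(torch.ones(2, dtype=torch.float16))
    # Expected under the assumption above: (torch.float32, torch.float16)
    return seen["inner_dtype"], out.dtype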


# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int):
    for _ in range(dim - x.dim()):
        x = x.unsqueeze(-1)
    return x


@register_decomposition(aten.tanh_backward)
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (1 - y * y).conj_physical()


@register_decomposition(aten.sigmoid_backward)
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (y * (1 - y)).conj_physical()


@register_decomposition(aten.softplus_backward)
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
    z = (x * beta).exp()
    return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))


@register_decomposition(aten.elu)
@pw_cast_for_opmath
def elu(
    self: Tensor, alpha: float = 1, scale: float = 1, input_scale: float = 1
) -> Tensor:
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    return torch.where(
        self > 0, self * poscoef, (torch.exp(self * negiptcoef) - 1) * negcoef
    )


@register_decomposition(aten.elu_backward)
@pw_cast_for_opmath
def elu_backward(
    grad_output: Tensor,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: Tensor,
):
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * (self_or_result + negcoef),
            grad_output * poscoef,
        )
    else:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
            grad_output * poscoef,
        )


@register_decomposition(aten.hardsigmoid)
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
    return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardsigmoid_backward)
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
    return torch.where(
        (self > -3.0) & (self < 3.0),
        grad_output * (1.0 / 6.0),
        grad_output.new_zeros(()),
    )


@register_decomposition(aten.hardtanh)
@pw_cast_for_opmath
def hardtanh(self: Tensor, min_val: float = -1, max_val: float = 1) -> Tensor:
    return torch.clamp(self, min_val, max_val)


@register_decomposition(aten.hardtanh_backward)
@pw_cast_for_opmath
def hardtanh_backward(
    grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
    return torch.where(
        (self <= min_val) | (self >= max_val), grad_output.new_zeros(()), grad_output
    )


@register_decomposition(aten.hardshrink_backward)
@pw_cast_for_opmath
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
    return torch.where(
        (self >= -lambd) & (self <= lambd), grad_out.new_zeros(()), grad_out
    )


@register_decomposition(aten.hardswish)
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
    return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardswish_backward)
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    return torch.where(
        self < -3,
        grad_output.new_zeros(()),
        torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
    )


@register_decomposition(aten.threshold_backward)
@pw_cast_for_opmath
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
    return torch.where(self <= threshold, grad_output.new_zeros(()), grad_output)


@register_decomposition(aten.leaky_relu_backward)
@pw_cast_for_opmath
def leaky_relu_backward(
    grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
    return torch.where(self > 0, grad_output, grad_output * negative_slope)
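

# Illustrative sketch (hypothetical helper, not registered as a decomposition):
# numerically compares sigmoid_backward above against autograd through
# torch.sigmoid.  float64 is used so the opmath cast is a no-op; the tolerances
# are torch.allclose defaults and are an assumption.
def _check_sigmoid_backward_against_autograd():
    x = torch.randn(16, dtype=torch.float64, requires_grad=True)
    y = torch.sigmoid(x)
    (ref,) = torch.autograd.grad(y.sum(), x)
    got = sigmoid_backward(torch.ones_like(y), y.detach())
    assert torch.allclose(got, ref)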


@register_decomposition(aten.gelu)
@pw_cast_for_opmath
def gelu(self: Tensor, approximate: str = 'none') -> Tensor:
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == 'tanh':
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_cube = self * self * self
        inner = kBeta * (self + kKappa * x_cube)
        return 0.5 * self * (1 + torch.tanh(inner))
    else:
        kAlpha = M_SQRT1_2
        return self * 0.5 * (1 + torch.erf(self * kAlpha))


@register_decomposition(aten.gelu_backward)
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == 'tanh':
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_sq = self * self
        x_cube = x_sq * self
        inner = kBeta * (self + kKappa * x_cube)
        tanh_inner = torch.tanh(inner)

        left = 0.5 * self
        right = 1 + tanh_inner

        left_derivative = 0.5 * right

        tanh_derivative = 1 - tanh_inner * tanh_inner
        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
        right_derivative = left * tanh_derivative * inner_derivative

        return grad * (left_derivative + right_derivative)
    else:
        kAlpha = M_SQRT1_2
        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * kAlpha))
        pdf = kBeta * torch.exp(self * self * -0.5)
        return grad * (cdf + self * pdf)


@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
    input_tanh_softplus = torch.tanh(F.softplus(input))
    input_sigmoid = torch.sigmoid(input)
    out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
    return grad_output * (input_tanh_softplus + out)


@register_decomposition(aten.silu)
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
    return self * torch.sigmoid(self)


@register_decomposition(aten.silu_backward)
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    sigmoid = 1 / (1 + torch.exp(-self))
    return grad_output * sigmoid * (1 + self * (1 - sigmoid))


@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
    return torch.where(
        (self >= -lambd) & (self <= lambd), grad_output.new_zeros(()), grad_output
    )
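

# Illustrative sketch (hypothetical helper): checks the exact ('none') branch of
# gelu_backward above against autograd through F.gelu.  Assumes float64 inputs so
# the opmath cast is a no-op; default allclose tolerances.
def _check_gelu_backward_against_autograd():
    x = torch.randn(16, dtype=torch.float64, requires_grad=True)
    (ref,) = torch.autograd.grad(F.gelu(x).sum(), x)
    got = gelu_backward(torch.ones(16, dtype=torch.float64), x.detach())
    assert torch.allclose(got, ref)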


@register_decomposition(aten.prelu_backward)
@pw_cast_for_opmath
def prelu_backward(
    grad_output: Tensor, self: Tensor, weight: Tensor
) -> Tuple[Tensor, Tensor]:
    # Logic is more complicated than I would like.  Basically, weight can either
    # be a scalar or a vector of size [C], and in the forward pass it's
    # broadcast against [N, C, ...]. So now, we need to do the corresponding
    # reduction, which is harder than we'd like...
    cur_weight = weight
    for _ in range(2, grad_output.dim()):
        cur_weight = cur_weight.unsqueeze(-1)
    input_grad = torch.where(self > 0, grad_output, cur_weight * grad_output)
    weight_grad_collector = torch.where(
        self > 0, grad_output.new_zeros(()), self * grad_output
    )
    out = weight_grad_collector.sum_to_size(cur_weight.shape)
    while out.dim() > weight.dim():
        out = out.squeeze(-1)
    return (input_grad, out)


@register_decomposition(aten.rrelu_with_noise_backward)
@pw_cast_for_opmath
def rrelu_with_noise_backward(
    grad_output: Tensor,
    self: Tensor,
    noise: Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
) -> Tensor:
    if training and upper - lower > 1e-6:
        return grad_output.mul(noise)
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu_backward(
            grad_output, self, negative_slope, self_is_result
        )


@register_decomposition(aten.log_sigmoid_backward)
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    in_negative = self < 0
    max_deriv = torch.where(in_negative, 1, 0)
    sign = torch.where(in_negative, 1, -1)
    z = torch.exp(-torch.abs(self))
    return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output


def apply_loss_reduction(loss: Tensor, reduction: int):
    if reduction == Reduction.MEAN.value:
        return torch.mean(loss)
    elif reduction == Reduction.SUM.value:
        return torch.sum(loss)
    else:
        return loss


def to_real_dtype(dtype: torch.dtype):
    if dtype == torch.complex32:
        return torch.float16
    elif dtype == torch.complex64:
        return torch.float32
    elif dtype == torch.complex128:
        return torch.float64


# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction
@register_decomposition(aten.l1_loss)
def l1_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    loss = (self - target).abs()
    # PyTorch semantics result in the output of l1_loss having the corresponding
    # real dtype to self. This may not happen without explicit casting if, say,
    # self: complex64 and target: float64, which results in loss: float64
    float_type = to_real_dtype(self.dtype)
    return apply_loss_reduction(loss, reduction).to(float_type)


@register_decomposition(aten.l1_loss_backward)
@pw_cast_for_opmath
def l1_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
):
    sign = torch.sign(self - target)
    norm = sign / self.numel() if reduction == Reduction.MEAN.value else sign
    return grad_output * norm


@register_decomposition(aten.mse_loss)
@pw_cast_for_opmath
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    loss = (self - target) ** 2
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.mse_loss_backward)
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
    return norm * (input - target) * grad_output
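

# Illustrative sketch (hypothetical helper): compares mse_loss_backward above with
# autograd through F.mse_loss for the mean reduction.  grad_output is the scalar
# ones tensor a mean-reduced loss would receive; float64, default tolerances.
def _check_mse_loss_backward_against_autograd():
    x = torch.randn(8, dtype=torch.float64, requires_grad=True)
    t = torch.randn(8, dtype=torch.float64)
    (ref,) = torch.autograd.grad(F.mse_loss(x, t), x)
    got = mse_loss_backward(
        torch.ones((), dtype=torch.float64), x.detach(), t, Reduction.MEAN.value
    )
    assert torch.allclose(got, ref)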


@register_decomposition(aten.huber_loss)
@pw_cast_for_opmath
def huber_loss(
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
    delta: float = 1.0,
) -> Tensor:
    assert delta > 0, "huber_loss does not support non-positive values for delta."
    z = (self - target).abs()
    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.huber_loss_backward)
@pw_cast_for_opmath
def huber_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    return torch.where(
        x < -delta,
        -norm * grad_output * delta,
        torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
    )


def _nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    channel_dim = 0 if self.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    grad_input = torch.zeros_like(self)
    grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)

    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        new_shape = [1 for _ in range(self.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        grad_output = grad_output * weight

    has_ignore_index = ignore_index >= 0
    if has_ignore_index:
        ignore_index_mask = target != ignore_index
        grad_output = grad_output * ignore_index_mask

    return grad_input * grad_output


@register_decomposition(aten.nll_loss_backward)
def nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"

    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: "
        f"{total_weight.shape} ({total_weight.numel()} elements)"
    )

    assert (
        weight is None or weight.numel() == self.shape[-1]
    ), "weight tensor should be defined either for all or no classes"

    if reduction == Reduction.NONE.value and self.dim() == 2:
        assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
            f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
            f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
        )
    else:
        assert (
            grad_output.dim() <= 1 and grad_output.numel() == 1
        ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )


@register_decomposition(aten.nll_loss2d_backward)
def nll_loss2d_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert (
        self.dim() == 4
    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"

    assert (
        target.dim() == 3
    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"

    assert (
        self.shape[0] == target.shape[0]
        and self.shape[2] == target.shape[1]
        and self.shape[3] == target.shape[2]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"

    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
    )

    return _nll_loss_backward(
        grad_output, self, target, weight, reduction, ignore_index, total_weight
    )


@register_decomposition(aten.binary_cross_entropy)
@pw_cast_for_opmath
def binary_cross_entropy(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    # We cannot currently model this without introducing data-dependent control flow
    # TORCH_CHECK(
    #     (input_val >= 0) && (input_val <= 1),
    #     "all elements of input should be between 0 and 1"
    # )
    loss = (target - 1) * torch.maximum(
        torch.log(1 - self), self.new_full((), -100)
    ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
    if weight is not None:
        loss = loss * weight
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.binary_cross_entropy_backward)
@pw_cast_for_opmath
def binary_cross_entropy_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    EPSILON = 1e-12
    result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
    if weight is not None:
        result = result * weight
    if reduction == Reduction.MEAN.value:
        result = result / self.numel()
    return result


@register_decomposition(aten._euclidean_dist)
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
    x1_norm = x1.pow(2).sum(-1, True)
    x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
    x2_norm = x2.pow(2).sum(-1, True)
    x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
    x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
    result = x1_.matmul(x2_.mT)
    return result.clamp_min(0).sqrt()


@register_decomposition(aten.slice_backward)
def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)


@register_decomposition(aten.select_backward)
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.select_scatter(grad_input, grad_output, dim, index)


@register_decomposition(aten.diagonal_backward)
def diagonal_backward(
    grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)


@register_decomposition(aten._softmax_backward_data)
@pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
    new_grad = grad_output * output
    return new_grad - output * torch.sum(new_grad, dim=dim, keepdim=True)


@register_decomposition(aten._log_softmax_backward_data)
@pw_cast_for_opmath
def _log_softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
    grad_input = grad_output - torch.exp(output) * torch.sum(
        grad_output, dim=dim, keepdim=True
    )
    return grad_input


# TODO: the type annotations on arguments are not quite right
@register_decomposition(aten.im2col_backward)
def im2col_backward(
    grad_output: Tensor,
    input_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    return F.fold(grad_output, input_size, kernel_size, dilation, padding, stride)  # type: ignore[arg-type]
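

# Illustrative sketch (hypothetical helper): compares the binary_cross_entropy
# decomposition above with F.binary_cross_entropy.  Inputs are kept away from 0
# and 1 so the -100 log clamp never triggers; float64, default tolerances.
def _check_binary_cross_entropy():
    x = torch.rand(8, dtype=torch.float64) * 0.9 + 0.05
    t = torch.rand(8, dtype=torch.float64)
    assert torch.allclose(binary_cross_entropy(x, t), F.binary_cross_entropy(x, t))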


@register_decomposition(aten.col2im_backward)
def col2im_backward(
    grad_output: Tensor,
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    return F.unfold(grad_output, kernel_size, dilation, padding, stride)  # type: ignore[arg-type]


@register_decomposition(aten.masked_fill.Scalar)
def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
    return torch.where(mask, utils.dtype_to_type(self.dtype)(value), self)


@register_decomposition(aten.masked_fill.Tensor)
def masked_fill_Tensor(self: Tensor, mask: Tensor, value: Tensor) -> Tensor:
    return torch.where(mask, value, self)


@register_decomposition(aten.native_dropout_backward)
@pw_cast_for_opmath
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    return grad_output * (mask.type_as(grad_output) * scale)


@register_decomposition(aten.logit_backward.default)
@pw_cast_for_opmath
def logit_backward(
    grad_output: Tensor, self: Tensor, eps: Optional[float] = None
) -> Tensor:
    if eps is not None:
        lo = eps
        hi = 1.0 - lo
        return torch.where(
            torch.logical_and(self >= lo, self <= hi),
            grad_output / (self * (1.0 - self)),
            self.new_zeros(()),
        )
    else:
        return torch.where(
            torch.logical_and(self >= 0.0, self <= 1.0),
            grad_output / (self * (1.0 - self)),
            self.new_full((), float("nan")),
        )


@register_decomposition(aten.native_dropout)
def native_dropout(input: Tensor, p: float, train: Optional[bool]):
    if train:
        bool_mask = torch.rand_like(input) > p
        res = bool_mask * input * float(1.0 / (1.0 - p))
        return (res, bool_mask)
    else:
        return (input, torch.ones_like(input, dtype=torch.bool))


# TODO: Correct the type promotion semantics
@register_decomposition(aten._softmax)
@pw_cast_for_opmath
def _softmax(x: Tensor, dim: int, half_to_float: bool):
    x_max = torch.max(x, dim, keepdim=True)[0]
    unnormalized = torch.exp(x - x_max)
    return unnormalized / torch.sum(unnormalized, dim, keepdim=True)


# TODO: Correct the type promotion semantics
@register_decomposition(aten._log_softmax)
@pw_cast_for_opmath
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
    x_max = torch.max(x, dim, keepdim=True)[0]
    shifted = x - x_max
    shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
    return shifted - shifted_logsumexp


@register_decomposition(aten.addcdiv)
@pw_cast_for_opmath
def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    return self + value * (tensor1 / tensor2)
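

# Illustrative sketch (hypothetical helper): checks the _softmax decomposition
# above against torch.softmax on the last dim.  half_to_float is passed as False;
# float64, default tolerances.
def _check_softmax_decomposition():
    x = torch.randn(3, 5, dtype=torch.float64)
    assert torch.allclose(_softmax(x, -1, False), torch.softmax(x, -1))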


# Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
@register_decomposition(aten.addcmul)
@pw_cast_for_opmath
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    if self.is_floating_point() or self.is_complex():
        return self + value * tensor1 * tensor2
    else:
        return self + int(value) * tensor1 * tensor2


@register_decomposition(aten.rsub.Tensor)
def rsub_Tensor(self: Tensor, other: Tensor, alpha: float = 1) -> Tensor:
    return torch.sub(other, self, alpha=alpha)


@register_decomposition(aten.rsub.Scalar)
def rsub_Scalar(self: Tensor, other: float, alpha: float = 1) -> Tensor:
    return torch.sub(other, self, alpha=alpha)


@register_decomposition(aten.embedding)
def embedding(
    weight: Tensor,
    indices: Tensor,
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    assert weight.dim() == 2, "'weight' must be 2-D"
    # TODO: Assert not ported over yet
    #   auto indices_arg = TensorArg(indices, "indices", 1);
    #   checkScalarTypes("embedding", indices_arg, {kLong, kInt});

    if indices.dim() == 1:
        return weight.index_select(0, indices)

    size = list(indices.shape)
    for d in weight.shape[1:]:
        size.append(d)

    return weight.index_select(0, indices.reshape(-1)).view(size)


# TODO: Correct the type promotion semantics
@register_decomposition(aten.embedding_dense_backward)
def embedding_dense_backward(
    grad_output: Tensor,
    indices: Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    numel = indices.numel()
    grad = grad_output.view(numel, grad_output.size(-1))
    grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1]))
    indices_rank1 = indices.view(numel)
    if scale_grad_by_freq:
        counts = indices.new_zeros((num_weights,))
        ones = indices.new_ones((numel,))
        counts = counts.index_put([indices_rank1], ones, accumulate=True)
        grad_weights_scale = counts[indices_rank1]
        grad = grad / grad_weights_scale.unsqueeze(1)
    skip_padding = (indices_rank1 != padding_idx).unsqueeze(1)
    skip_padding = skip_padding.expand_as(grad)
    zero_grad = torch.full_like(grad, 0)
    return grad_weight.index_put(
        [indices_rank1], torch.where(skip_padding, grad, zero_grad), accumulate=True
    )


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r


@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
    self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
    num_splits = len(split_sizes)
    splits = []
    start_idx = 0
    for i in range(num_splits):
        length = split_sizes[i]
        splits.append(self.narrow(dim, start_idx, length))
        start_idx += length
    return splits


@register_decomposition(aten.split.Tensor)
def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
    input_sizes = self.shape
    dim_size = input_sizes[dim]
    if split_size == 0:
        assert dim_size == 0
        return [self]
    chunks = (dim_size + split_size - 1) // split_size
    split_sizes = [split_size for i in range(chunks)]
    split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
    return torch.split(self, split_sizes, dim)


# TODO: this doesn't appear to have enough precision in bfloat16
@register_decomposition(aten.addmm)
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    out = alpha * torch.mm(mat1, mat2)
    if beta == 0:
        return out
    return beta * self + out
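

# Illustrative sketch (hypothetical helper): compares the addmm decomposition
# above with torch.addmm using non-default alpha/beta.  float64, default
# tolerances; in low precision the bfloat16 TODO above applies.
def _check_addmm_decomposition():
    bias = torch.randn(3, 4, dtype=torch.float64)
    m1 = torch.randn(3, 5, dtype=torch.float64)
    m2 = torch.randn(5, 4, dtype=torch.float64)
    got = addmm(bias, m1, m2, beta=0.5, alpha=2.0)
    ref = torch.addmm(bias, m1, m2, beta=0.5, alpha=2.0)
    assert torch.allclose(got, ref)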


# This computes the mean and variance along the specified normalization dims,
# then normalizes along those dims. Finally, it returns the mean and variance of
# the normalized dims. Note that it intentionally leaves outputs upcasted.
# Example:
# input: [2, 3, 4, 5], norm_dims: [1, 3]
# mean: [2, 1, 4, 1]
def normalize(input, norm_dims, eps):
    computation_dtype = utils.get_computation_dtype(input.dtype)
    input_acc = input.to(dtype=computation_dtype)
    biased_var = torch.var(input_acc, dim=norm_dims, unbiased=False, keepdim=True)
    mean = torch.mean(input_acc, dim=norm_dims, keepdim=True)
    rstd = torch.rsqrt(biased_var + eps)

    out = (input - mean) * rstd
    return out, mean, rstd


@register_decomposition(aten.native_layer_norm.default)
def native_layer_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    computation_dtype = utils.get_computation_dtype(input.dtype)
    axis = input.dim() - len(normalized_shape)

    if prod(list(input.shape[:axis])) == 0:
        mean = input.new_zeros((0,), dtype=computation_dtype)
        rstd = input.new_zeros((0,), dtype=computation_dtype)
        out = input
    else:
        reduction_dims = list(range(axis, input.dim()))
        out, mean, rstd = normalize(input, reduction_dims, eps)

        if weight is not None:
            out = out * weight
        if bias is not None:
            out = out + bias

    out = out.to(dtype=input.dtype)
    if input.device.type == 'cpu':
        mean = mean.to(dtype=input.dtype)
        rstd = rstd.to(dtype=input.dtype)
    return (out, mean, rstd)


@register_decomposition(aten.native_group_norm.default, disable_meta=True)
def native_group_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    N: int,
    C: int,
    HxW: int,
    group: int,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    orig_shape = input.shape
    input = input.view(N, group, C // group, HxW)
    reduction_dims = [2, 3]
    out, mean, rstd = normalize(input, reduction_dims, eps)
    mean = _squeeze_multiple(mean, reduction_dims)
    rstd = _squeeze_multiple(rstd, reduction_dims)
    out = out.view(orig_shape)
    if weight is not None:
        weight = _unsqueeze_to_dim(weight, out.dim() - 1)
        out = out * weight
    if bias is not None:
        bias = _unsqueeze_to_dim(bias, out.dim() - 1)
        out = out + bias

    out = out.to(dtype=input.dtype)
    mean = mean.to(dtype=input.dtype)
    rstd = rstd.to(dtype=input.dtype)
    return (out, mean, rstd)


def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
    if x is not None:
        return x.to(dtype)
    return x


# TODO: Take a closer look at the type promotion semantics
@register_decomposition(aten.native_layer_norm_backward)
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()
    computation_dtype = utils.get_computation_dtype(input.dtype)
    grad_out_cast, input_cast, weight_cast, bias_cast = [
        x.to(computation_dtype) if x is not None else x
        for x in (grad_out, input, weight, bias)
    ]
    assert grad_out_cast is not None

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: List[int] = []
    outer_dim_indices: List[int] = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    x_hat = (input_cast - mean) * rstd
    if weight_cast is not None:
        grad_x_hat = grad_out_cast * weight_cast
    else:
        grad_x_hat = grad_out_cast
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)

    inner = a - b - c3
    d_input: Optional[Tensor] = None
    d_weight: Optional[Tensor] = None
    d_bias: Optional[Tensor] = None
    if output_mask[0]:
        d_input = (rstd / N) * inner

    if output_mask[1] and weight_cast is not None:
        if len(outer_dim_indices) > 0:
            d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
        else:
            d_weight = grad_out_cast * x_hat

    if output_mask[2] and bias_cast is not None:
        if len(outer_dim_indices) > 0:
            d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
        else:
            d_bias = grad_out_cast

    return (
        _maybe_cast(d_input, input.dtype),
        _maybe_cast(d_weight, input.dtype),
        _maybe_cast(d_bias, input.dtype),
    )


@register_decomposition(aten.native_batch_norm)
def native_batch_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    reduction_dims = [0] + list(range(2, input.dim()))
    computation_dtype = utils.get_computation_dtype(input.dtype)
    if training:
        output, mean, rstd = normalize(input, reduction_dims, eps)

        save_mean = _squeeze_multiple(mean, reduction_dims)
        save_rstd = _squeeze_multiple(rstd, reduction_dims)
        if running_mean is not None:
            running_mean.copy_(momentum * save_mean + (1 - momentum) * running_mean)
        if running_var is not None:
            n = input.numel() / input.shape[1]
            # This doesn't strictly match eager's numerics, which accumulates var sum
            # and then directly applies the correction. But... that would require
            # re-implementing var here, for negligible numerics gain on a tensor
            # whose numerics probably don't matter.
            unbiased_var = torch.var(input, reduction_dims, unbiased=False) * (
                n / (n - 1)
            )
            running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var)
    else:
        assert running_mean is not None and running_var is not None
        running_mean = running_mean.to(dtype=computation_dtype)
        running_var = running_var.to(dtype=computation_dtype)
        mean = running_mean
        invstd = 1 / (torch.sqrt(running_var + eps))
        # Very annoying inconsistency where CPU and CUDA give different shapes
        if input.device.type != "cpu":
            save_mean = running_mean
            save_rstd = invstd
        else:
            save_mean = input.new_zeros((0,))
            save_rstd = input.new_zeros((0,))
        mean = _unsqueeze_to_dim(mean, input.dim() - 1)
        invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
        output = (input - mean) * invstd

    if weight is None:
        weight = input.new_ones(())

    if bias is None:
        bias = input.new_zeros(())

    weight = _unsqueeze_to_dim(weight, input.dim() - 1)
    bias = _unsqueeze_to_dim(bias, input.dim() - 1)
    output = output * weight + bias
    if input.device.type == 'cpu':
        save_mean = save_mean.to(dtype=input.dtype)
        save_rstd = save_rstd.to(dtype=input.dtype)
    return output.to(dtype=input.dtype), save_mean, save_rstd


@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
    return torch.clamp(self, min=min)


@register_decomposition(aten.clamp_max)
def clamp_max(self: Tensor, max: float):
    return torch.clamp(self, max=max)


@register_decomposition(aten._fused_dropout)
@pw_cast_for_opmath
def _fused_dropout_decomposition(input, p, generator=None):
    mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
    res = mask.type_as(input) * input * (1.0 / p)
    return (res, mask)


@register_decomposition(aten.logical_not)
def logical_not(self: Tensor) -> Tensor:
    return ~self.to(dtype=torch.bool)
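

# Illustrative sketch (hypothetical helper): compares the native_layer_norm
# decomposition above with F.layer_norm over the last dimension.  float64 on CPU,
# default tolerances.
def _check_native_layer_norm():
    x = torch.randn(2, 3, 8, dtype=torch.float64)
    w = torch.randn(8, dtype=torch.float64)
    b = torch.randn(8, dtype=torch.float64)
    out, _mean, _rstd = native_layer_norm(x, [8], w, b, 1e-5)
    assert torch.allclose(out, F.layer_norm(x, [8], w, b, 1e-5))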


@register_decomposition(aten.xlogy.Tensor)
@pw_cast_for_int_to_real
def xlogy(self: Tensor, other: Tensor) -> Tensor:
    return aten.where(
        aten.isnan(self),
        self,
        aten.where(
            self == aten.new_zeros(self, ()),
            aten.new_zeros(self, ()),
            self * aten.log(other),
        ),
    )


@register_decomposition(aten.var.correction)
@reduction_complex_to_real
def var_correction(
    x: Tensor,
    dims: Optional[List[int]],
    correction: Optional[int] = None,
    keepdim: bool = False,
):
    if dims is None:
        dims = []

    if x.is_complex():
        # For complex, calculate variance of real and imaginary components
        # separately then add to get overall variance.
        real_in = x.real
        var_real = torch.var(real_in, dims, correction=correction, keepdim=keepdim)
        imag_in = x.imag
        var_imag = torch.var(imag_in, dims, correction=correction, keepdim=keepdim)
        return var_real + var_imag

    if correction is None:
        correction = 0

    if len(dims) == 0:
        n = prod(x.shape)  # type: ignore[arg-type]
    else:
        n = 1
        for dim in dims:
            n *= x.shape[dim]

    mean = torch.mean(x, dims, True)
    sub = x - mean
    sq = sub * sub
    sum = torch.sum(sq, dims, keepdim)

    if correction:
        n = n - correction

    return sum / n


@register_decomposition(aten.std.correction)
@reduction_complex_to_real
def std_decomposition(
    x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False
):
    return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))


# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the
# backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
@register_decomposition(aten.detach, disable_meta=True)
def detach_decomposition(x):
    return x


@register_decomposition(aten.cudnn_batch_norm)
def cudnn_batch_norm(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )
    # Cudnn returns the running mean and variance when training is True
    if training:
        return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
    return (
        a,
        input.new_zeros((0,)),
        input.new_zeros((0,)),
        input.new_zeros((0,), dtype=torch.uint8),
    )


@register_decomposition(aten.cudnn_batch_norm_backward)
def cudnn_batch_norm_backward(
    input: Tensor,
    grad_output: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    epsilon: float,
    reserveSpace: Tensor,
):
    return aten.native_batch_norm_backward(
        grad_output,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        True,
        epsilon,
        [True, True, True],
    )
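

# Illustrative sketch (hypothetical helper): compares var_correction above with
# torch.var through the same correction-keyword call that the complex branch
# already relies on.  float64, default tolerances.
def _check_var_correction():
    x = torch.randn(5, 7, dtype=torch.float64)
    got = var_correction(x, [1], correction=1)
    ref = torch.var(x, [1], correction=1)
    assert torch.allclose(got, ref)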


@register_decomposition(aten.rot90.default)
def rot90(self: Tensor, k: int = 1, dims: List[int] = [0, 1]) -> Tensor:  # noqa: B006
    total_dims = self.dim()
    total_rot_dims = len(dims)
    assert total_rot_dims == 2, f"expected total rotation dims == 2, but got dims = {total_rot_dims}"
    assert total_dims >= 2, f"expected total dims >= 2, but got total dims = {total_dims}"
    assert dims[0] != dims[1] and abs(dims[0] - dims[1]) != total_dims, \
        f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
    assert dims[0] < total_dims and dims[0] >= -total_dims, f"Rotation dim0 out of range, dim0 = {dims[0]}"
    assert dims[1] < total_dims and dims[1] >= -total_dims, f"Rotation dim1 out of range, dim1 = {dims[1]}"
    k = k % 4
    if k == 1:
        return self.flip(dims[1]).transpose(dims[0], dims[1])
    elif k == 2:
        return self.flip(dims)
    elif k == 3:
        return self.flip(dims[0]).transpose(dims[0], dims[1])
    else:
        return self.clone(memory_format=torch.contiguous_format)


@register_decomposition(aten.transpose.int)
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
    dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1))  # type: ignore[misc]

    if self.dim() <= 1:
        return self

    if dim0 == dim1:
        return self
    perm = list(range(self.dim()))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return torch.permute(self, perm)


@register_decomposition(aten.t.default)
def t(self: Tensor) -> Tensor:
    return self.transpose(0, 0 if self.dim() < 2 else 1)


def check_stack_inputs(tensors: List[Tensor]):
    entry_shape = tensors[0].shape
    for i in range(1, len(tensors)):
        assert tensors[i].shape == entry_shape, (
            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
            f"and {tensors[i].shape} at entry {i}"
        )


def get_stack_inputs(tensors: List[Tensor], dim: int):
    check_stack_inputs(tensors)
    return [t.unsqueeze(dim) for t in tensors]


@register_decomposition(aten.stack.default)
def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    wrapped_dim = utils.canonicalize_dim(tensors[0].dim() + 1, dim)
    if wrapped_dim < tensors[0].dim() and not tensors[0].is_sparse:
        check_stack_inputs(tensors)
        result_sizes = list(tensors[0].shape)
        result_sizes.insert(wrapped_dim, len(tensors))
        out = torch.cat(tensors, wrapped_dim)
        return out.view(result_sizes)
    else:
        return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)


def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
    ndim = self.dim()
    wrapped_dims = utils.canonicalize_dims(ndim, dims)
    assert isinstance(wrapped_dims, tuple)
    for idx in range(ndim - 1, -1, -1):
        if idx in wrapped_dims:
            self = self.squeeze(idx)
    return self


@register_decomposition(aten.logsumexp.default)
@pw_cast_for_int_to_real
def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(self, dim, keepdim=True)
    maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
    maxes_squeezed = torch.masked_fill(
        maxes_squeezed, maxes_squeezed.abs() == float('inf'), 0
    )
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)


@register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    return torch.sum(torch.diag(self))


# nb: Should use acc_t, not op_math
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper_multi('output', 'buffer')
@pw_cast_for_opmath
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer
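

# Illustrative sketch (hypothetical helper): compares the logsumexp decomposition
# above with torch.logsumexp, exercising the max-subtraction trick on large
# inputs.  float64, default tolerances.
def _check_logsumexp_decomposition():
    x = torch.randn(4, 6, dtype=torch.float64) * 100
    assert torch.allclose(logsumexp(x, [1]), torch.logsumexp(x, [1]))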