pytorch/torch/autograd/_functions/utils.py
Aaron Gokaslan 6de28e92d2 [BE]: Apply FURB118 (prev): replaces unnecessary lambdas with operator. (#116027)
This replaces a bunch of unnecessary lambdas with the operator package. This is semantically equivalent, but the operator package is faster, and arguably more readable. When the FURB rules are taken out of preview, I will enable it as a ruff check.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116027
Approved by: https://github.com/malfet
2023-12-20 19:35:08 +00:00

63 lines
2.0 KiB
Python

import operator
from functools import reduce
def maybe_view(tensor, size, check_same_size=True):
    """Return ``tensor`` viewed with shape ``size``.

    Args:
        tensor: Object providing ``size()``, ``contiguous()`` and ``view()``.
        size: Target shape; compared with ``tensor.size()`` and passed to
            ``view``.
        check_same_size: When true, a tensor whose size already equals
            ``size`` is handed back untouched.

    Returns:
        ``tensor`` itself when the size check is enabled and matches,
        otherwise ``tensor.contiguous().view(size)``.
    """
    already_shaped = check_same_size and tensor.size() == size
    if not already_shaped:
        # ``contiguous()`` is called first, exactly as before, so that
        # ``view`` operates on a contiguous tensor.
        return tensor.contiguous().view(size)
    return tensor
def maybe_unexpand(tensor, old_size, check_same_size=True):
    """Reduce ``tensor`` by summation until its shape lines up with ``old_size``.

    Args:
        tensor: Object providing ``size()``, ``dim()`` and ``sum()``.
        old_size: Target shape to reduce to.
        check_same_size: When true, a tensor whose size already equals
            ``old_size`` is returned unchanged.

    Returns:
        ``tensor`` itself on the fast path; otherwise the result of summing
        away the extra leading dimensions and summing (with ``keepdim=True``)
        every remaining dimension whose length differs from ``old_size``.
    """
    if check_same_size and tensor.size() == old_size:
        return tensor

    # Number of dimensions ``tensor`` has in front of those in ``old_size``.
    extra_leading = tensor.dim() - len(old_size)
    trailing_sizes = tensor.size()[extra_leading:]

    # Indices are relative to the trailing part, i.e. they are valid once the
    # extra leading dimensions have been summed away below.
    mismatched_axes = [
        axis
        for axis, (current, wanted) in enumerate(zip(trailing_sizes, old_size))
        if current != wanted
    ]

    remaining = extra_leading
    while remaining > 0:
        # Always collapse dimension 0; each pass drops one leading dimension.
        tensor = tensor.sum(0, keepdim=False)
        remaining -= 1

    for axis in mismatched_axes:
        tensor = tensor.sum(axis, keepdim=True)
    return tensor
# Check whether the op enables broadcasting, and whether it is supported by ONNX.
# If dims1 and dims2 are different, then broadcast is True.
# We always assume the combination of dims1 and dims2 is broadcastable.
# The following types of broadcasting are supported in ONNX:
# 1) dims2 holds only one element in total, such as dims2 = [1, 1]
# 2) dims2 is a suffix of dims1, such as dims1 = [2, 3, 4], and dims2 = [3, 4]
# Details can be found here: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm
def check_onnx_broadcast(dims1, dims2):
    """Tell whether combining ``dims1`` and ``dims2`` involves broadcasting.

    The two shapes are assumed to be broadcast-compatible. Broadcasting is
    accepted only when ``dims2`` describes a single element in total (for
    example ``[1, 1]`` or an empty shape) or when ``dims2`` is a suffix of
    ``dims1``.

    Args:
        dims1: Sequence of dimension sizes of the first operand.
        dims2: Sequence of dimension sizes of the second operand.

    Returns:
        ``True`` if the shapes differ (broadcasting is required), ``False``
        if they are identical.

    Raises:
        ValueError: If broadcasting is required but is not one of the two
            accepted forms.
    """
    # Compare as tuples so that a list and an equal tuple (or any other
    # sequence type) are not mistaken for different shapes.
    shape1 = tuple(dims1)
    shape2 = tuple(dims2)
    len1 = len(shape1)
    len2 = len(shape2)
    # The initial value 1 makes an empty (scalar) shape count as one element
    # instead of letting ``reduce`` raise TypeError on an empty sequence.
    numel2 = reduce(operator.mul, shape2, 1)

    if len1 < len2:
        broadcast = True
        supported = numel2 == 1
    elif len1 > len2:
        broadcast = True
        supported = numel2 == 1 or shape1[len1 - len2:] == shape2
    else:
        broadcast = shape1 != shape2
        supported = not broadcast or numel2 == 1

    if not supported:
        raise ValueError(
            f"Numpy style broadcasting is not supported in ONNX. Input dims are: {dims1}, {dims2}"
        )
    return broadcast