pytorch/torch/autograd/_functions/tensor.py
Aaron Gokaslan 6de28e92d2 [BE]: Apply FURB118 (prev): replaces unnecessary lambdas with operator. (#116027)
This replaces a bunch of unnecessary lambdas with equivalent functions from the `operator` module. This is semantically equivalent, but the `operator` functions are faster and arguably more readable. When the FURB rules are taken out of preview, I will enable this rule as a ruff check.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116027
Approved by: https://github.com/malfet
2023-12-20 19:35:08 +00:00

64 lines
2.1 KiB
Python

import operator
import warnings
from functools import reduce
import torch
import torch._utils
from ..function import Function
class Type(Function):
    """Deprecated autograd function that converts a tensor with ``Tensor.type``.

    ``forward`` returns ``i.type(dest_type)``. ``backward`` casts the incoming
    gradient back to the dtype and device the input had in ``forward``, and
    returns no gradient for ``dest_type``.
    """

    @staticmethod
    def forward(ctx, i, dest_type):
        """Convert ``i`` to ``dest_type`` and remember how to convert back.

        Args:
            ctx: autograd context used to stash the input's dtype/device.
            i: the input tensor.
            dest_type: anything accepted by ``Tensor.type`` (e.g. a dtype).

        Returns:
            ``i.type(dest_type)``.
        """
        warnings.warn(
            "torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use "
            "torch.Tensor.to(dtype=dtype) instead."
        )
        # No longer read by backward; still recorded in case external code
        # inspects these ctx attributes.
        ctx.input_type = type(i)
        ctx.input_device = -1 if not i.is_cuda else i.get_device()
        # What backward actually needs. ``type(i)`` is only the Python class
        # (e.g. ``torch.Tensor`` or ``torch.nn.Parameter``), which carries no
        # dtype/device, so casting the gradient to it cannot restore either.
        ctx.input_dtype = i.dtype
        ctx.input_torch_device = i.device
        return i.type(dest_type)

    @staticmethod
    def backward(ctx, grad_output):
        """Cast ``grad_output`` back to the input's dtype and device.

        Returns:
            A pair ``(grad_input, None)``; ``dest_type`` gets no gradient.
        """
        # ``Tensor.to`` moves across devices itself, so no CUDA device
        # context manager is needed here.
        grad_input = grad_output.to(
            device=ctx.input_torch_device, dtype=ctx.input_dtype
        )
        return grad_input, None
# TODO: deprecate this
class Resize(Function):
    """Autograd function that reshapes a tensor to ``sizes``.

    Despite the name, it cannot change the number of elements: the product of
    ``sizes`` must equal ``tensor.numel()``, otherwise ``forward`` raises.
    """

    @staticmethod
    def forward(ctx, tensor, sizes):
        """Return the data of ``tensor`` laid out with shape ``sizes``.

        Args:
            ctx: autograd context; ``sizes``, ``numel`` and ``input_sizes`` are
                saved on it for ``backward``.
            tensor: the input tensor.
            sizes: sequence of ints giving the requested shape (iterated more
                than once, so a one-shot iterator is not suitable).

        Raises:
            RuntimeError: if the product of ``sizes`` differs from
                ``tensor.numel()``.
        """
        ctx.sizes = sizes
        ctx.numel = reduce(operator.mul, sizes, 1)
        if tensor.numel() != ctx.numel:
            raise RuntimeError(
                (
                    "requested resize to {} ({} elements in total), "
                    "but the given tensor has a size of {} ({} elements). "
                    "autograd's resize can only change the shape of a given "
                    "tensor, while preserving the number of elements."
                ).format(
                    "x".join(map(str, sizes)),
                    ctx.numel,
                    "x".join(map(str, tensor.size())),
                    tensor.numel(),
                )
            )
        ctx.input_sizes = tensor.size()
        if tensor.is_quantized:
            # NOTE(review): self-copy preserved from the original; its purpose
            # is not evident from this file -- confirm before removing.
            tensor.copy_(tensor)
            return tensor.contiguous().view(*sizes)
        if tensor.is_contiguous():
            # NOTE(review): the legacy ``Tensor.new(tensor)`` constructor is
            # presumably used to produce a fresh tensor object over the same
            # data before viewing; kept unchanged.
            return tensor.new(tensor).contiguous().view(*sizes)
        return tensor.contiguous().view(*sizes)

    @staticmethod
    def backward(ctx, grad_output):
        """Reshape the gradient back to the input's original size.

        Returns:
            A pair ``(grad_input, None)``; ``sizes`` gets no gradient.

        Raises:
            RuntimeError: if ``grad_output`` does not hold ``ctx.numel``
                elements. This is an explicit check rather than ``assert`` so it
                still runs under ``python -O``.
        """
        if grad_output.numel() != ctx.numel:
            raise RuntimeError(
                f"Resize.backward: expected a gradient with {ctx.numel} "
                f"elements, but got one with {grad_output.numel()} elements."
            )
        return grad_output.contiguous().view(ctx.input_sizes), None