This replaces a number of unnecessary lambdas with the equivalent functions from the operator package. The result is semantically equivalent, but the operator functions are faster and arguably more readable. When the FURB rules are taken out of preview, I will enable this as a ruff check. Pull Request resolved: https://github.com/pytorch/pytorch/pull/116027 Approved by: https://github.com/malfet
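For illustration only, a minimal sketch of the kind of substitution described above (this snippet is not taken from the diff itself): an inline lambda passed to functools.reduce is replaced with the equivalent function from the standard-library operator module.

    from functools import reduce
    import operator

    sizes = (2, 3, 4)
    numel_lambda = reduce(lambda x, y: x * y, sizes, 1)  # before: ad-hoc lambda
    numel_operator = reduce(operator.mul, sizes, 1)      # after: operator.mul
    assert numel_lambda == numel_operator == 24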
64 lines · 2.1 KiB · Python
import operator
import warnings
from functools import reduce

import torch
import torch._utils
from ..function import Function


class Type(Function):
    @staticmethod
    def forward(ctx, i, dest_type):
        warnings.warn(
            "torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use "
            "torch.tensor.to(dtype=dtype) instead."
        )
        # Remember the input's type and device so backward can cast the
        # gradient back; -1 marks a CPU tensor.
        ctx.input_type = type(i)
        ctx.input_device = -1 if not i.is_cuda else i.get_device()
        return i.type(dest_type)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.input_device == -1:
            # CPU input: cast the gradient back to the original type.
            return grad_output.type(ctx.input_type), None
        else:
            # CUDA input: perform the cast on the original device.
            with torch.cuda.device(ctx.input_device):
                return grad_output.type(ctx.input_type), None

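# Editorial sketch (assumption, not part of the original file): the deprecation
# warning in Type.forward points callers at the Tensor.to() API instead of this
# Function, e.g. for a CPU tensor:
#
#     x = torch.randn(3)
#     y = x.to(dtype=torch.float64)  # preferred over the deprecated Type function
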
# TODO: deprecate this
class Resize(Function):
    @staticmethod
    def forward(ctx, tensor, sizes):
        ctx.sizes = sizes
        # Total number of elements implied by the requested shape.
        ctx.numel = reduce(operator.mul, sizes, 1)
        if tensor.numel() != ctx.numel:
            raise RuntimeError(
                (
                    "requested resize to {} ({} elements in total), "
                    "but the given tensor has a size of {} ({} elements). "
                    "autograd's resize can only change the shape of a given "
                    "tensor, while preserving the number of elements. "
                ).format(
                    "x".join(map(str, sizes)),
                    ctx.numel,
                    "x".join(map(str, tensor.size())),
                    tensor.numel(),
                )
            )
        ctx.input_sizes = tensor.size()
        if tensor.is_quantized:
            tensor.copy_(tensor)
            return tensor.contiguous().view(*sizes)
        if tensor.is_contiguous():
            result = tensor.new(tensor).contiguous().view(*sizes)
            return result
        else:
            return tensor.contiguous().view(*sizes)

    @staticmethod
    def backward(ctx, grad_output):
        assert grad_output.numel() == ctx.numel
        # Restore the gradient to the input's original shape.
        return grad_output.contiguous().view(ctx.input_sizes), None
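# Editorial sketch (assumption, not part of the original file): the element-count
# check in Resize.forward enforces the same invariant as Tensor.view, e.g.:
#
#     t = torch.arange(6)
#     t.view(2, 3)  # ok: both shapes describe 6 elements
#     t.view(4, 2)  # RuntimeError: the new shape implies 8 elements, not 6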