Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check

Step 1: uncomment lines in the pyrefly.toml file
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d
After: 0 errors (1,970 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
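For context, a minimal sketch of the suppression pattern this PR adds (the `Example` class below is hypothetical, not code from the repository): pyrefly reads a `# pyrefly: ignore # <error-code>` comment placed directly above the flagged definition and drops that single diagnostic, which is how the `bad-override` errors in the file shown below are silenced.

# Hypothetical sketch only: shows the suppression format, not real PyTorch code.
from torch.autograd.function import Function


class Example(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, x):
        # Without the comment above, pyrefly flags this override of
        # Function.forward (the same "bad-override" code suppressed
        # throughout the file below).
        return x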
# mypy: allow-untyped-defs
import operator
from functools import reduce
from typing_extensions import deprecated

import torch
import torch._utils
from torch.autograd.function import Function


class Type(Function):
    @staticmethod
    @deprecated(
        "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, "
        "please use `torch.tensor.to(dtype=dtype)` instead.",
        category=FutureWarning,
    )
    # pyrefly: ignore # bad-override
    def forward(ctx, i, dest_type):
        # Remember the input's tensor type and CUDA device index (-1 for
        # non-CUDA inputs) so backward can cast the gradient back to match.
        ctx.input_type = type(i)
        ctx.input_device = -1 if not i.is_cuda else i.get_device()
        return i.type(dest_type)

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        if ctx.input_device == -1:
            # Non-CUDA input: cast the gradient back to the saved type.
            return grad_output.type(ctx.input_type), None
        else:
            # CUDA input: cast under the device the input came from.
            with torch.accelerator.device_index(ctx.input_device):
                return grad_output.type(ctx.input_type), None
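
# Usage sketch (illustrative only, not part of the upstream file): the
# deprecated path above next to the replacement named in its warning
# message, for a hypothetical tensor `t`:
#
#     t = torch.randn(3)
#     y = Type.apply(t, torch.float64)   # deprecated, emits a FutureWarning
#     y = t.to(dtype=torch.float64)      # preferred replacement
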
# TODO: deprecate this
class Resize(Function):
    @staticmethod
    # pyrefly: ignore # bad-override
    def forward(ctx, tensor, sizes):
        ctx.sizes = sizes
        ctx.numel = reduce(operator.mul, sizes, 1)
        # Resize may only change the shape; the total element count must match.
        if tensor.numel() != ctx.numel:
            raise RuntimeError(
                (
                    "requested resize to {} ({} elements in total), "
                    "but the given tensor has a size of {} ({} elements). "
                    "autograd's resize can only change the shape of a given "
                    "tensor, while preserving the number of elements. "
                ).format(
                    "x".join(map(str, sizes)),
                    ctx.numel,
                    "x".join(map(str, tensor.size())),
                    tensor.numel(),
                )
            )
        ctx.input_sizes = tensor.size()
        if tensor.is_quantized:
            # Quantized tensors: copy in place, then return a contiguous view.
            tensor.copy_(tensor)
            return tensor.contiguous().view(*sizes)
        if tensor.is_contiguous():
            result = tensor.new(tensor).contiguous().view(*sizes)
            return result
        else:
            return tensor.contiguous().view(*sizes)

    @staticmethod
    # pyrefly: ignore # bad-override
    def backward(ctx, grad_output):
        # The gradient has the same element count, so view it back to the
        # original input shape.
        assert grad_output.numel() == ctx.numel
        return grad_output.contiguous().view(ctx.input_sizes), None
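
# Usage sketch (illustrative only, not part of the upstream file): `Resize`
# acts like a differentiable reshape, so for a hypothetical tensor `t` with
# six elements the two calls below produce the same shape:
#
#     t = torch.randn(2, 3, requires_grad=True)
#     y = Resize.apply(t, (3, 2))   # autograd Function path (slated for deprecation per the TODO above)
#     y = t.view(3, 2)              # the usual way to change shape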