Revert "Fix non-bitwise type annotations for Tensor operators (see #145838) (#146845)"

This reverts commit 59b7e52ad8.

Reverted https://github.com/pytorch/pytorch/pull/146845 on behalf of https://github.com/jeanschmidt because it seems to break a few code dependencies in multiple places ([comment](https://github.com/pytorch/pytorch/pull/146845#issuecomment-2666656834))
PyTorch MergeBot 2025-02-18 19:01:27 +00:00
parent 57060bebf3
commit 302f56a1f2
9 changed files with 89 additions and 81 deletions
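For context, the kind of breakage the revert message refers to: code written against the old `Any` return types can start failing static checks once operators like `Tensor.__floordiv__` return `Tensor`. A hypothetical sketch (not taken from the code that actually broke), assuming it is checked with mypy:

```python
# Hypothetical downstream snippet.  Before #146845, TENSOR // INT was typed
# as Any, so the int annotation below was accepted by mypy; with the PR
# applied the expression is typed as Tensor and the assignment is flagged.
import torch

t = torch.arange(10)
n: int = t // 2  # ok while the result is Any, a type error once it is Tensor
```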


@@ -27,7 +27,7 @@ assert_type(TENSOR >= TENSOR, Tensor)
assert_type(TENSOR + TENSOR, Tensor)
assert_type(TENSOR - TENSOR, Tensor)
assert_type(TENSOR * TENSOR, Tensor)
assert_type(TENSOR // TENSOR, Tensor)
assert_type(TENSOR // TENSOR, Any)
assert_type(TENSOR / TENSOR, Tensor)
assert_type(TENSOR % TENSOR, Tensor)
assert_type(TENSOR**TENSOR, Tensor)
@@ -46,7 +46,7 @@ assert_type(TENSOR >= BOOL, Tensor)
assert_type(TENSOR + BOOL, Tensor)
assert_type(TENSOR - BOOL, Tensor)
assert_type(TENSOR * BOOL, Tensor)
assert_type(TENSOR // BOOL, Tensor)
assert_type(TENSOR // BOOL, Any)
assert_type(TENSOR / BOOL, Tensor)
assert_type(TENSOR % BOOL, Tensor)
assert_type(TENSOR**BOOL, Tensor)
@@ -63,14 +63,14 @@ assert_type(BOOL > TENSOR, Tensor)
assert_type(BOOL <= TENSOR, Tensor)
assert_type(BOOL >= TENSOR, Tensor)
assert_type(BOOL + TENSOR, Tensor)
assert_type(BOOL - TENSOR, Tensor)
assert_type(BOOL - TENSOR, Any)
assert_type(BOOL * TENSOR, Tensor)
assert_type(BOOL // TENSOR, Tensor)
assert_type(BOOL / TENSOR, Tensor)
assert_type(BOOL % TENSOR, Tensor)
assert_type(BOOL**TENSOR, Tensor)
assert_type(BOOL << TENSOR, Tensor)
assert_type(BOOL >> TENSOR, Tensor)
assert_type(BOOL // TENSOR, Any)
assert_type(BOOL / TENSOR, Any)
assert_type(BOOL % TENSOR, Any)
assert_type(BOOL**TENSOR, Any)
assert_type(BOOL << TENSOR, Any)
assert_type(BOOL >> TENSOR, Any)
assert_type(BOOL & TENSOR, Tensor)
assert_type(BOOL | TENSOR, Tensor)
assert_type(BOOL ^ TENSOR, Tensor)
@@ -84,7 +84,7 @@ assert_type(TENSOR >= INT, Tensor)
assert_type(TENSOR + INT, Tensor)
assert_type(TENSOR - INT, Tensor)
assert_type(TENSOR * INT, Tensor)
assert_type(TENSOR // INT, Tensor)
assert_type(TENSOR // INT, Any)
assert_type(TENSOR / INT, Tensor)
assert_type(TENSOR % INT, Tensor)
assert_type(TENSOR**INT, Tensor)
@@ -101,14 +101,14 @@ assert_type(INT > TENSOR, Tensor)
assert_type(INT <= TENSOR, Tensor)
assert_type(INT >= TENSOR, Tensor)
assert_type(INT + TENSOR, Tensor)
assert_type(INT - TENSOR, Tensor)
assert_type(INT - TENSOR, Any)
assert_type(INT * TENSOR, Tensor)
assert_type(INT // TENSOR, Tensor)
assert_type(INT / TENSOR, Tensor)
assert_type(INT % TENSOR, Tensor)
assert_type(INT**TENSOR, Tensor)
assert_type(INT << TENSOR, Tensor)
assert_type(INT >> TENSOR, Tensor)
assert_type(INT // TENSOR, Any)
assert_type(INT / TENSOR, Any)
assert_type(INT % TENSOR, Any)
assert_type(INT**TENSOR, Any)
assert_type(INT << TENSOR, Any)
assert_type(INT >> TENSOR, Any)
assert_type(INT & TENSOR, Any) # type: ignore[operator]
assert_type(INT | TENSOR, Any) # type: ignore[operator]
assert_type(INT ^ TENSOR, Any) # type: ignore[operator]
@@ -122,7 +122,7 @@ assert_type(TENSOR >= FLOAT, Tensor)
assert_type(TENSOR + FLOAT, Tensor)
assert_type(TENSOR - FLOAT, Tensor)
assert_type(TENSOR * FLOAT, Tensor)
assert_type(TENSOR // FLOAT, Tensor)
assert_type(TENSOR // FLOAT, Any)
assert_type(TENSOR / FLOAT, Tensor)
assert_type(TENSOR % FLOAT, Tensor)
assert_type(TENSOR**FLOAT, Tensor)
@@ -139,14 +139,14 @@ assert_type(FLOAT > TENSOR, Tensor)
assert_type(FLOAT <= TENSOR, Tensor)
assert_type(FLOAT >= TENSOR, Tensor)
assert_type(FLOAT + TENSOR, Tensor)
assert_type(FLOAT - TENSOR, Tensor)
assert_type(FLOAT - TENSOR, Any)
assert_type(FLOAT * TENSOR, Tensor)
assert_type(FLOAT // TENSOR, Tensor)
assert_type(FLOAT / TENSOR, Tensor)
assert_type(FLOAT % TENSOR, Tensor)
assert_type(FLOAT**TENSOR, Tensor)
assert_type(FLOAT << TENSOR, Tensor)
assert_type(FLOAT >> TENSOR, Tensor)
assert_type(FLOAT // TENSOR, Any)
assert_type(FLOAT / TENSOR, Any)
assert_type(FLOAT % TENSOR, Any)
assert_type(FLOAT**TENSOR, Any)
assert_type(FLOAT << TENSOR, Any)
assert_type(FLOAT >> TENSOR, Any)
assert_type(FLOAT & TENSOR, Tensor) # type: ignore[operator]
assert_type(FLOAT | TENSOR, Tensor) # type: ignore[operator]
assert_type(FLOAT ^ TENSOR, Tensor) # type: ignore[operator]
@@ -377,6 +377,38 @@ assert_type(BOOL ^ BINARY, Binary)
# Tensor operators whose types could be improved
# This is the "diff" of the first and second sections.
assert_type(BOOL // TENSOR, Any)
assert_type(FLOAT // TENSOR, Any)
assert_type(INT // TENSOR, Any)
assert_type(TENSOR // BOOL, Any)
assert_type(TENSOR // FLOAT, Any)
assert_type(TENSOR // INT, Any)
assert_type(TENSOR // TENSOR, Any)
assert_type(BOOL**TENSOR, Any)
assert_type(FLOAT**TENSOR, Any)
assert_type(INT**TENSOR, Any)
assert_type(BOOL - TENSOR, Any)
assert_type(FLOAT - TENSOR, Any)
assert_type(INT - TENSOR, Any)
assert_type(BOOL / TENSOR, Any)
assert_type(FLOAT / TENSOR, Any)
assert_type(INT / TENSOR, Any)
assert_type(BOOL % TENSOR, Any)
assert_type(FLOAT % TENSOR, Any)
assert_type(INT % TENSOR, Any)
assert_type(BOOL << TENSOR, Any)
assert_type(FLOAT << TENSOR, Any)
assert_type(INT << TENSOR, Any)
assert_type(BOOL >> TENSOR, Any)
assert_type(FLOAT >> TENSOR, Any)
assert_type(INT >> TENSOR, Any)
assert_type(FLOAT & TENSOR, Tensor) # type: ignore[operator]
assert_type(INT & TENSOR, Any) # type: ignore[operator]
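The file above is a static-typing test: `assert_type` is a no-op at runtime and only means something to a type checker, so the revert is literally re-recording which operators fall back to `Any`. A minimal sketch of the pattern, assuming mypy or pyright is run over it:

```python
# Minimal sketch of the assert_type pattern used in the test file above.
# assert_type(expr, T) just returns expr at runtime; a type checker reports
# an error if the inferred type of expr is not exactly T.
from typing import Any, assert_type  # typing_extensions on Python < 3.11

import torch
from torch import Tensor

TENSOR = torch.randn(3)
INT = 4

assert_type(TENSOR + INT, Tensor)  # __add__ is annotated to return Tensor
assert_type(TENSOR // INT, Any)    # the assertion this revert restores
```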


@@ -803,7 +803,7 @@ def slice_scatter(
if start == 0 and end == dim_size and step == 1:
return src.clone()
indices: list[Optional[Tensor]] = [None] * input.dim()
indices = [None] * input.dim()
idx = torch.arange(dim_size, device=input.device)
indices[dim] = (idx - start) // step
@@ -1664,7 +1664,6 @@ def native_layer_norm_backward(
)
mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr]
rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr]
assert input_cast is not None
x_hat = (input_cast - mean) * rstd
if weight_cast is not None:
grad_x_hat = grad_out_cast * weight_cast
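Both hunks in this file appear to roll back type-checker appeasements: the `list[Optional[Tensor]]` annotation and the `assert input_cast is not None` exist only to narrow `Optional` values for mypy, not to change behavior. A small sketch of that narrowing idiom, with illustrative names:

```python
# `assert x is not None` narrows Optional[Tensor] to Tensor for the type
# checker, so the arithmetic after it is accepted without ignores.
from typing import Optional

import torch
from torch import Tensor


def scale(input_cast: Optional[Tensor], rstd: Tensor) -> Tensor:
    assert input_cast is not None  # narrows Optional[Tensor] -> Tensor
    return (input_cast - input_cast.mean()) * rstd


print(scale(torch.ones(3), torch.tensor(2.0)))  # tensor([0., 0., 0.])
```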


@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import cast, Optional
from typing import Optional
import torch
import torch.utils._pytree as pytree
@@ -69,10 +69,12 @@ def philox_rand_offset(
curand4_engine_calls = 4
device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
num = cast(int, numel)
grid_size = (num + block_size - 1) // block_size
grid_size = (numel + block_size - 1) // block_size
grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
return ((num - 1) // (block_size * grid_size * unroll) + 1) * curand4_engine_calls
offset = (
(numel - 1) // (block_size * grid_size * unroll) + 1
) * curand4_engine_calls
return offset
def register_philox_rand():
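The hunk above only drops the `cast(int, numel)` indirection; the arithmetic is unchanged, since `(a + b - 1) // b` and `(a - 1) // b + 1` are both integer ceiling division. A standalone sketch of the same offset computation with illustrative constants (the real code reads `numel`, block counts, and SM limits from the device):

```python
# (a + b - 1) // b  ==  (a - 1) // b + 1  for positive a, so the restored
# offset formula matches the removed one; only the cast() wrapper is gone.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


numel, block_size, unroll, curand4_engine_calls = 10_000, 256, 4, 4
max_grid = 1024  # stand-in for multi_processor_count * blocks_per_sm

grid_size = min(ceil_div(numel, block_size), max_grid)
offset = ceil_div(numel, block_size * grid_size * unroll) * curand4_engine_calls
print(grid_size, offset)  # 40 4
```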


@ -6,8 +6,7 @@ import warnings
from collections import OrderedDict
from copy import deepcopy
from numbers import Number
from typing import Any, Callable, cast, Optional, TypeVar, Union
from typing_extensions import Concatenate, ParamSpec
from typing import Any, Callable, cast, Optional, Union
import torch
import torch._C as _C
@@ -28,21 +27,16 @@ from torch.overrides import (
)
_P = ParamSpec("_P")
_TensorLike = TypeVar("_TensorLike", bound=_C.TensorBase)
def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
assigned = functools.WRAPPER_ASSIGNMENTS
def _handle_torch_function_and_wrap_type_error_to_not_implemented(
f: Callable[Concatenate[_TensorLike, _P], "Tensor"],
) -> Callable[Concatenate[_TensorLike, _P], "Tensor"]:
@functools.wraps(f)
def wrapped(self: _TensorLike, *args: _P.args, **kwargs: _P.kwargs) -> "Tensor":
@functools.wraps(f, assigned=assigned)
def wrapped(*args, **kwargs):
try:
# See https://github.com/pytorch/pytorch/issues/75462
sargs = self, *args
if has_torch_function(sargs):
return handle_torch_function(wrapped, sargs, *sargs, **kwargs)
return f(self, *args, **kwargs)
if has_torch_function(args):
return handle_torch_function(wrapped, args, *args, **kwargs)
return f(*args, **kwargs)
except TypeError:
return NotImplemented
@@ -1100,11 +1094,11 @@ class Tensor(torch._C.TensorBase):
)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
def __rsub__(self, other):
return _C._VariableFunctions.rsub(self, other)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rdiv__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
def __rdiv__(self, other):
return self.reciprocal() * other
__rtruediv__ = __rdiv__
@@ -1119,13 +1113,12 @@ class Tensor(torch._C.TensorBase):
_C.TensorBase.pow
),
)
__ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
_C.TensorBase.pow_
)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
def __rmod__(self, other):
return torch.remainder(other, self)
def __format__(self, format_spec):
@@ -1136,33 +1129,27 @@ class Tensor(torch._C.TensorBase):
return object.__format__(self, format_spec)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
def __rpow__(self, other):
return torch.pow(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
# TODO(rec): the superclass says it accepts complex here,
# but torch.floor_divide says it doesn't.
def __floordiv__(self, other):
return torch.floor_divide(self, other)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rfloordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
def __rfloordiv__(self, other):
return torch.floor_divide(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rlshift__(
self, other: Union["Tensor", int, float, bool, complex]
) -> "Tensor":
def __rlshift__(self, other):
return torch.bitwise_left_shift(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rrshift__(
self, other: Union["Tensor", int, float, bool, complex]
) -> "Tensor":
def __rrshift__(self, other):
return torch.bitwise_right_shift(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rmatmul__(self, other: "Tensor") -> "Tensor":
def __rmatmul__(self, other):
return torch.matmul(other, self)
__pos__ = _C.TensorBase.positive
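The bulk of this file's diff swaps a `ParamSpec`/`Concatenate`-typed decorator and annotated reflected operators (`__rsub__`, `__rfloordiv__`, `__rpow__`, ...) back to their untyped forms. A standalone sketch of that decorator-typing pattern, assuming Python 3.10+ `typing` and leaving out the `has_torch_function` dispatch the real helper does first; names are illustrative:

```python
# Sketch of the typing pattern the revert removes: ParamSpec + Concatenate
# let a wrapper keep the wrapped method's parameter types and return type
# instead of collapsing them to Any.
import functools
from typing import Callable, Concatenate, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def wrap_type_error_to_not_implemented(
    f: Callable[Concatenate[_T, _P], object],
) -> Callable[Concatenate[_T, _P], object]:
    @functools.wraps(f)
    def wrapped(self: _T, *args: _P.args, **kwargs: _P.kwargs) -> object:
        try:
            return f(self, *args, **kwargs)
        except TypeError:  # mismatched operand types -> defer to the other side
            return NotImplemented

    return wrapped
```

In the removed version, `self` is bound to `_C.TensorBase` through `_TensorLike` and the return type is `"Tensor"`, which is what let the reflected operators report `Tensor` rather than `Any` to type checkers.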


@@ -631,7 +631,6 @@ def powerSGD_hook(
if state.use_error_feedback:
# Memorize the local errors.
assert input_tensor_cp is not None
state.error_dict[bucket_index] = input_tensor_cp - input_tensor
if not state.warm_start:
state.p_memory_dict.clear()
@@ -844,7 +843,6 @@ def batched_powerSGD_hook(
if state.use_error_feedback:
# Memorize the local errors.
assert input_tensor_cp is not None
state.error_dict[bucket_index] = input_tensor_cp - input_tensor
# Removing this seemingly unnecessary sync somehow may cause failures.
# See: https://github.com/pytorch/pytorch/pull/54838
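The only change in this file is dropping two `assert input_tensor_cp is not None` narrowings; the surrounding error-feedback logic is untouched. For reference, a toy sketch of that error-feedback idea (keep a copy, compress, remember what was lost, fold it into the next round), with made-up "compression":

```python
# Toy error-feedback loop; rounding stands in for PowerSGD's low-rank
# compression, purely for illustration.
import torch

error = torch.zeros(4)  # plays the role of state.error_dict[bucket_index]

for step in range(3):
    grad = torch.randn(4)
    input_tensor = grad + error             # feed back what was lost before
    input_tensor_cp = input_tensor.clone()  # the copy the asserts guarded
    compressed = input_tensor.round()       # stand-in for compression
    error = input_tensor_cp - compressed    # memorize the local error
```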


@@ -1,6 +1,4 @@
# mypy: allow-untyped-defs
from typing import Union
import torch
from torch import Tensor
from torch.distributions.distribution import Distribution
@@ -57,7 +55,7 @@ class ExponentialFamily(Distribution):
"""
Method to compute the entropy using Bregman divergence of the log normalizer.
"""
result: Union[Tensor, float] = -self._mean_carrier_measure
result = -self._mean_carrier_measure
nparams = [p.detach().requires_grad_() for p in self._natural_params]
lg_normal = self._log_normalizer(*nparams)
gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
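Annotation aside, the method computes entropy from the log normalizer: for an exponential family p(x) = h(x) exp(theta . T(x) - A(theta)), the entropy is A(theta) - theta . grad A(theta) - E[log h(x)], which is why the accumulator starts at `-self._mean_carrier_measure`. A hand-derived numeric check of that identity for the Exponential distribution (illustrative, not the library code path):

```python
# For Exponential(rate): theta = -rate, A(theta) = -log(-theta), h(x) = 1,
# so H = A(theta) - theta * dA/dtheta = 1 - log(rate).
import torch

rate = torch.tensor(2.0)
theta = (-rate).detach().requires_grad_()
log_normalizer = -torch.log(-theta)                 # A(theta)
(grad,) = torch.autograd.grad(log_normalizer, theta)

entropy = log_normalizer - theta * grad             # E[log h(x)] term is 0 here
print(entropy.item())                                          # ~0.3069
print(torch.distributions.Exponential(rate).entropy().item())  # ~0.3069
```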


@@ -451,11 +451,9 @@ def _single_tensor_adam(
# expavg.lerp(grad^2, 1-beta2)
exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
else:
exp_avg_sq.mul_(beta2).addcmul_(
grad, grad, value=cast(float, 1 - beta2)
)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # type: ignore[arg-type]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if capturable or differentiable:
step = step_t
@@ -526,7 +524,7 @@ def _single_tensor_adam(
else:
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
param.addcdiv_(exp_avg, denom, value=-step_size) # type: ignore[arg-type]
param.addcdiv_(exp_avg, denom, value=-step_size)
# Lastly, switch back to complex view
if amsgrad and torch.is_complex(params[i]):
@@ -672,9 +670,7 @@ def _multi_tensor_adam(
# Decay the first and second moment running average coefficient
# Use device beta1 if beta1 is a tensor to ensure all
# tensors are on the same device
torch._foreach_lerp_(
device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
)
torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
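Both branches of the first hunk compute the same second-moment update; the revert only drops the `cast(float, ...)` wrappers and a stale `type: ignore`. A quick check that `lerp_` and `mul_().addcmul_()` really agree, with illustrative values:

```python
# v.lerp_(g*g, 1 - beta2)                       -> v + (1 - beta2) * (g*g - v)
# v.mul_(beta2).addcmul_(g, g, value=1 - beta2) -> beta2*v + (1 - beta2)*g*g
# Algebraically identical; check numerically up to rounding.
import torch

beta2 = 0.999
g = torch.randn(5)

v1 = torch.rand(5)
v2 = v1.clone()

v1.lerp_(torch.square(g), 1 - beta2)
v2.mul_(beta2).addcmul_(g, g, value=1 - beta2)

print(torch.allclose(v1, v2, atol=1e-7))  # True
```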


@@ -366,9 +366,7 @@ def _single_tensor_nadam(
grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product)))
)
param.addcdiv_(
exp_avg,
denom,
value=cast(float, (-lr * mu_next) / (1.0 - mu_product_next)),
exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next)
)
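Same story here: only the `cast(float, ...)` around the step size goes away. For reference, a sketch of what `addcdiv_` does in these optimizer updates (illustrative numbers):

```python
# param.addcdiv_(exp_avg, denom, value=v) updates param in place:
#   param += v * (exp_avg / denom)
import torch

param = torch.zeros(3)
exp_avg = torch.tensor([0.3, -0.6, 0.9])
denom = torch.tensor([1.0, 2.0, 3.0])
step_size = -0.1

param.addcdiv_(exp_avg, denom, value=step_size)
print(param)  # tensor([-0.0300,  0.0300, -0.0300])
```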


@@ -5,7 +5,7 @@ import math
import warnings
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Callable, cast, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union
import torch
from torch import Tensor
@@ -68,9 +68,7 @@ def get_swa_multi_avg_fn():
averaged_param_list[0]
):
torch._foreach_lerp_(
averaged_param_list,
current_param_list,
cast(float, 1 / (num_averaged + 1)),
averaged_param_list, current_param_list, 1 / (num_averaged + 1)
)
else:
diffs = torch._foreach_sub(current_param_list, averaged_param_list)
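The `torch._foreach_lerp_` call implements the running average avg <- avg + (param - avg) / (num_averaged + 1); the revert just removes the `cast(float, ...)` around the weight. A minimal single-tensor sketch of that identity, with illustrative values:

```python
# avg.lerp_(p, 1/(n+1)) == avg + (p - avg)/(n+1): the incremental mean used
# by the SWA averaging function above.
import torch

samples = [torch.randn(4) for _ in range(5)]

avg = samples[0].clone()
for n, p in enumerate(samples[1:], start=1):
    avg.lerp_(p, 1.0 / (n + 1))

print(torch.allclose(avg, torch.stack(samples).mean(0), atol=1e-6))  # True
```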