Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
This reverts commit 59b7e52ad8.
Reverted https://github.com/pytorch/pytorch/pull/146845 on behalf of https://github.com/jeanschmidt due to Seems to break a few code dependencies in multiple places ([comment](https://github.com/pytorch/pytorch/pull/146845#issuecomment-2666656834))
This commit is contained in:
parent 57060bebf3
commit 302f56a1f2
@@ -27,7 +27,7 @@ assert_type(TENSOR >= TENSOR, Tensor)
 assert_type(TENSOR + TENSOR, Tensor)
 assert_type(TENSOR - TENSOR, Tensor)
 assert_type(TENSOR * TENSOR, Tensor)
-assert_type(TENSOR // TENSOR, Tensor)
+assert_type(TENSOR // TENSOR, Any)
 assert_type(TENSOR / TENSOR, Tensor)
 assert_type(TENSOR % TENSOR, Tensor)
 assert_type(TENSOR**TENSOR, Tensor)
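Context: assert_type(value, T) is a no-op at runtime; a static checker such as mypy fails the check if the inferred type of value is not exactly T. A minimal, self-contained sketch of the pattern these hunks touch (the TENSOR/BOOL/INT/FLOAT fixtures below are assumed stand-ins, not copied from the test file):

from typing_extensions import assert_type
import torch

TENSOR = torch.tensor([1.0])
BOOL, INT, FLOAT = True, 2, 3.0

# No effect at runtime; mypy verifies the inferred type matches the second argument.
assert_type(TENSOR + TENSOR, torch.Tensor)
assert_type(TENSOR + INT, torch.Tensor)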
@@ -46,7 +46,7 @@ assert_type(TENSOR >= BOOL, Tensor)
 assert_type(TENSOR + BOOL, Tensor)
 assert_type(TENSOR - BOOL, Tensor)
 assert_type(TENSOR * BOOL, Tensor)
-assert_type(TENSOR // BOOL, Tensor)
+assert_type(TENSOR // BOOL, Any)
 assert_type(TENSOR / BOOL, Tensor)
 assert_type(TENSOR % BOOL, Tensor)
 assert_type(TENSOR**BOOL, Tensor)
@@ -63,14 +63,14 @@ assert_type(BOOL > TENSOR, Tensor)
 assert_type(BOOL <= TENSOR, Tensor)
 assert_type(BOOL >= TENSOR, Tensor)
 assert_type(BOOL + TENSOR, Tensor)
-assert_type(BOOL - TENSOR, Tensor)
+assert_type(BOOL - TENSOR, Any)
 assert_type(BOOL * TENSOR, Tensor)
-assert_type(BOOL // TENSOR, Tensor)
-assert_type(BOOL / TENSOR, Tensor)
-assert_type(BOOL % TENSOR, Tensor)
-assert_type(BOOL**TENSOR, Tensor)
-assert_type(BOOL << TENSOR, Tensor)
-assert_type(BOOL >> TENSOR, Tensor)
+assert_type(BOOL // TENSOR, Any)
+assert_type(BOOL / TENSOR, Any)
+assert_type(BOOL % TENSOR, Any)
+assert_type(BOOL**TENSOR, Any)
+assert_type(BOOL << TENSOR, Any)
+assert_type(BOOL >> TENSOR, Any)
 assert_type(BOOL & TENSOR, Tensor)
 assert_type(BOOL | TENSOR, Tensor)
 assert_type(BOOL ^ TENSOR, Tensor)
@@ -84,7 +84,7 @@ assert_type(TENSOR >= INT, Tensor)
 assert_type(TENSOR + INT, Tensor)
 assert_type(TENSOR - INT, Tensor)
 assert_type(TENSOR * INT, Tensor)
-assert_type(TENSOR // INT, Tensor)
+assert_type(TENSOR // INT, Any)
 assert_type(TENSOR / INT, Tensor)
 assert_type(TENSOR % INT, Tensor)
 assert_type(TENSOR**INT, Tensor)
@@ -101,14 +101,14 @@ assert_type(INT > TENSOR, Tensor)
 assert_type(INT <= TENSOR, Tensor)
 assert_type(INT >= TENSOR, Tensor)
 assert_type(INT + TENSOR, Tensor)
-assert_type(INT - TENSOR, Tensor)
+assert_type(INT - TENSOR, Any)
 assert_type(INT * TENSOR, Tensor)
-assert_type(INT // TENSOR, Tensor)
-assert_type(INT / TENSOR, Tensor)
-assert_type(INT % TENSOR, Tensor)
-assert_type(INT**TENSOR, Tensor)
-assert_type(INT << TENSOR, Tensor)
-assert_type(INT >> TENSOR, Tensor)
+assert_type(INT // TENSOR, Any)
+assert_type(INT / TENSOR, Any)
+assert_type(INT % TENSOR, Any)
+assert_type(INT**TENSOR, Any)
+assert_type(INT << TENSOR, Any)
+assert_type(INT >> TENSOR, Any)
 assert_type(INT & TENSOR, Any)  # type: ignore[operator]
 assert_type(INT | TENSOR, Any)  # type: ignore[operator]
 assert_type(INT ^ TENSOR, Any)  # type: ignore[operator]
@@ -122,7 +122,7 @@ assert_type(TENSOR >= FLOAT, Tensor)
 assert_type(TENSOR + FLOAT, Tensor)
 assert_type(TENSOR - FLOAT, Tensor)
 assert_type(TENSOR * FLOAT, Tensor)
-assert_type(TENSOR // FLOAT, Tensor)
+assert_type(TENSOR // FLOAT, Any)
 assert_type(TENSOR / FLOAT, Tensor)
 assert_type(TENSOR % FLOAT, Tensor)
 assert_type(TENSOR**FLOAT, Tensor)
@@ -139,14 +139,14 @@ assert_type(FLOAT > TENSOR, Tensor)
 assert_type(FLOAT <= TENSOR, Tensor)
 assert_type(FLOAT >= TENSOR, Tensor)
 assert_type(FLOAT + TENSOR, Tensor)
-assert_type(FLOAT - TENSOR, Tensor)
+assert_type(FLOAT - TENSOR, Any)
 assert_type(FLOAT * TENSOR, Tensor)
-assert_type(FLOAT // TENSOR, Tensor)
-assert_type(FLOAT / TENSOR, Tensor)
-assert_type(FLOAT % TENSOR, Tensor)
-assert_type(FLOAT**TENSOR, Tensor)
-assert_type(FLOAT << TENSOR, Tensor)
-assert_type(FLOAT >> TENSOR, Tensor)
+assert_type(FLOAT // TENSOR, Any)
+assert_type(FLOAT / TENSOR, Any)
+assert_type(FLOAT % TENSOR, Any)
+assert_type(FLOAT**TENSOR, Any)
+assert_type(FLOAT << TENSOR, Any)
+assert_type(FLOAT >> TENSOR, Any)
 assert_type(FLOAT & TENSOR, Tensor)  # type: ignore[operator]
 assert_type(FLOAT | TENSOR, Tensor)  # type: ignore[operator]
 assert_type(FLOAT ^ TENSOR, Tensor)  # type: ignore[operator]
@@ -377,6 +377,38 @@ assert_type(BOOL ^ BINARY, Binary)
 # Tensor operators whose types could be improved
 # This is the "diff" of the first and second sections.
 
+assert_type(BOOL // TENSOR, Any)
+assert_type(FLOAT // TENSOR, Any)
+assert_type(INT // TENSOR, Any)
+assert_type(TENSOR // BOOL, Any)
+assert_type(TENSOR // FLOAT, Any)
+assert_type(TENSOR // INT, Any)
+assert_type(TENSOR // TENSOR, Any)
+
+assert_type(BOOL**TENSOR, Any)
+assert_type(FLOAT**TENSOR, Any)
+assert_type(INT**TENSOR, Any)
+
+assert_type(BOOL - TENSOR, Any)
+assert_type(FLOAT - TENSOR, Any)
+assert_type(INT - TENSOR, Any)
+
+assert_type(BOOL / TENSOR, Any)
+assert_type(FLOAT / TENSOR, Any)
+assert_type(INT / TENSOR, Any)
+
+assert_type(BOOL % TENSOR, Any)
+assert_type(FLOAT % TENSOR, Any)
+assert_type(INT % TENSOR, Any)
+
+assert_type(BOOL << TENSOR, Any)
+assert_type(FLOAT << TENSOR, Any)
+assert_type(INT << TENSOR, Any)
+
+assert_type(BOOL >> TENSOR, Any)
+assert_type(FLOAT >> TENSOR, Any)
+assert_type(INT >> TENSOR, Any)
+
 assert_type(FLOAT & TENSOR, Tensor)  # type: ignore[operator]
 assert_type(INT & TENSOR, Any)  # type: ignore[operator]
 
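Context: the Any results in the restored block typically come from reflected dunder methods that carry no annotations, so a checker such as mypy falls back to Any for the whole expression. A minimal illustration under that assumption (Dummy is an invented class, not PyTorch code):

from typing import Any
from typing_extensions import assert_type

class Dummy:
    def __rfloordiv__(self, other):  # un-annotated, so the checker infers Any
        return self

assert_type(3 // Dummy(), Any)  # int.__floordiv__ defers to Dummy.__rfloordiv__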
@@ -803,7 +803,7 @@ def slice_scatter(
     if start == 0 and end == dim_size and step == 1:
         return src.clone()
 
-    indices: list[Optional[Tensor]] = [None] * input.dim()
+    indices = [None] * input.dim()
     idx = torch.arange(dim_size, device=input.device)
     indices[dim] = (idx - start) // step
 
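Context: the dropped annotation matters to the checker because [None] * input.dim() is otherwise inferred as list[None], which rejects the later tensor assignment. A small standalone sketch of that point (assumed values, not the decomposition itself):

from typing import Optional
import torch

indices: list[Optional[torch.Tensor]] = [None] * 3  # without the annotation: list[None]
indices[1] = torch.arange(4)  # accepted only when the element type allows Tensor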
@@ -1664,7 +1664,6 @@ def native_layer_norm_backward(
     )
     mean = _unsqueeze_to_dim(mean, input_cast.dim())  # type: ignore[union-attr]
     rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
-    assert input_cast is not None
     x_hat = (input_cast - mean) * rstd
     if weight_cast is not None:
         grad_x_hat = grad_out_cast * weight_cast
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import cast, Optional
+from typing import Optional
 
 import torch
 import torch.utils._pytree as pytree
@@ -69,10 +69,12 @@ def philox_rand_offset(
     curand4_engine_calls = 4
     device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
     blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
-    num = cast(int, numel)
-    grid_size = (num + block_size - 1) // block_size
+    grid_size = (numel + block_size - 1) // block_size
     grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
-    return ((num - 1) // (block_size * grid_size * unroll) + 1) * curand4_engine_calls
+    offset = (
+        (numel - 1) // (block_size * grid_size * unroll) + 1
+    ) * curand4_engine_calls
+    return offset
 
 
 def register_philox_rand():
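Context: both versions of philox_rand_offset compute the same offset when numel is a plain Python int; the cast only narrowed the type for the checker. A quick standalone check of that equivalence (block_size, unroll, and the block cap are assumed example values):

def offset_with_cast(numel, block_size=256, unroll=4, max_blocks=1024):
    curand4_engine_calls = 4
    num = int(numel)  # stands in for cast(int, numel)
    grid = min((num + block_size - 1) // block_size, max_blocks)
    return ((num - 1) // (block_size * grid * unroll) + 1) * curand4_engine_calls

def offset_reverted(numel, block_size=256, unroll=4, max_blocks=1024):
    curand4_engine_calls = 4
    grid = min((numel + block_size - 1) // block_size, max_blocks)
    return ((numel - 1) // (block_size * grid * unroll) + 1) * curand4_engine_calls

assert offset_with_cast(1_000_000) == offset_reverted(1_000_000)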
@@ -6,8 +6,7 @@ import warnings
 from collections import OrderedDict
 from copy import deepcopy
 from numbers import Number
-from typing import Any, Callable, cast, Optional, TypeVar, Union
-from typing_extensions import Concatenate, ParamSpec
+from typing import Any, Callable, cast, Optional, Union
 
 import torch
 import torch._C as _C
@@ -28,21 +27,16 @@ from torch.overrides import (
 )
 
 
-_P = ParamSpec("_P")
-_TensorLike = TypeVar("_TensorLike", bound=_C.TensorBase)
-
-
-def _handle_torch_function_and_wrap_type_error_to_not_implemented(
-    f: Callable[Concatenate[_TensorLike, _P], "Tensor"],
-) -> Callable[Concatenate[_TensorLike, _P], "Tensor"]:
-    @functools.wraps(f)
-    def wrapped(self: _TensorLike, *args: _P.args, **kwargs: _P.kwargs) -> "Tensor":
+def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
+    assigned = functools.WRAPPER_ASSIGNMENTS
+
+    @functools.wraps(f, assigned=assigned)
+    def wrapped(*args, **kwargs):
         try:
             # See https://github.com/pytorch/pytorch/issues/75462
-            sargs = self, *args
-            if has_torch_function(sargs):
-                return handle_torch_function(wrapped, sargs, *sargs, **kwargs)
-            return f(self, *args, **kwargs)
+            if has_torch_function(args):
+                return handle_torch_function(wrapped, args, *args, **kwargs)
+            return f(*args, **kwargs)
         except TypeError:
             return NotImplemented
 
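Context: the removed annotations use the ParamSpec/Concatenate pattern, which lets a decorator preserve the wrapped method's parameter types while prepending self. A simplified, self-contained sketch of that pattern (names and the Any return type are illustrative choices, not PyTorch's exact signature):

import functools
from typing import Any, Callable, TypeVar
from typing_extensions import Concatenate, ParamSpec

_P = ParamSpec("_P")
_Self = TypeVar("_Self")
_R = TypeVar("_R")

def wrap_type_error_to_not_implemented(
    f: Callable[Concatenate[_Self, _P], _R],
) -> Callable[Concatenate[_Self, _P], Any]:
    # Checkers still see the original parameter list; the body converts TypeError
    # into NotImplemented so Python can try the other operand's reflected method.
    @functools.wraps(f)
    def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
        try:
            return f(self, *args, **kwargs)
        except TypeError:
            return NotImplemented
    return wrapped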
@@ -1100,11 +1094,11 @@ class Tensor(torch._C.TensorBase):
         )
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
+    def __rsub__(self, other):
         return _C._VariableFunctions.rsub(self, other)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rdiv__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
+    def __rdiv__(self, other):
         return self.reciprocal() * other
 
     __rtruediv__ = __rdiv__
@@ -1119,13 +1113,12 @@ class Tensor(torch._C.TensorBase):
             _C.TensorBase.pow
-        ),
     )
 
     __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
         _C.TensorBase.pow_
     )
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
+    def __rmod__(self, other):
         return torch.remainder(other, self)
 
     def __format__(self, format_spec):
@@ -1136,33 +1129,27 @@ class Tensor(torch._C.TensorBase):
         return object.__format__(self, format_spec)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
+    def __rpow__(self, other):
         return torch.pow(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor":  # type: ignore[override]
-        # TODO(rec): the superclass says it accepts complex here,
-        # but torch.floor_divide says it doesn't.
+    def __floordiv__(self, other):
         return torch.floor_divide(self, other)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rfloordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor":  # type: ignore[override]
+    def __rfloordiv__(self, other):
         return torch.floor_divide(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rlshift__(
-        self, other: Union["Tensor", int, float, bool, complex]
-    ) -> "Tensor":
+    def __rlshift__(self, other):
         return torch.bitwise_left_shift(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rrshift__(
-        self, other: Union["Tensor", int, float, bool, complex]
-    ) -> "Tensor":
+    def __rrshift__(self, other):
         return torch.bitwise_right_shift(other, self)
 
     @_handle_torch_function_and_wrap_type_error_to_not_implemented
-    def __rmatmul__(self, other: "Tensor") -> "Tensor":
+    def __rmatmul__(self, other):
         return torch.matmul(other, self)
 
     __pos__ = _C.TensorBase.positive
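Context: all of these dunders are wrapped so that a TypeError becomes NotImplemented, which is what lets Python fall back to the other operand's reflected method. A toy example of that protocol (A and B are invented classes, not PyTorch types):

class A:
    def __sub__(self, other):
        if not isinstance(other, A):
            return NotImplemented  # defer to other.__rsub__
        return "A - A"

class B:
    def __rsub__(self, other):
        return "handled by B.__rsub__"

print(A() - B())  # -> handled by B.__rsub__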
@@ -631,7 +631,6 @@ def powerSGD_hook(
 
     if state.use_error_feedback:
         # Memorize the local errors.
-        assert input_tensor_cp is not None
         state.error_dict[bucket_index] = input_tensor_cp - input_tensor
     if not state.warm_start:
         state.p_memory_dict.clear()
@@ -844,7 +843,6 @@ def batched_powerSGD_hook(
 
     if state.use_error_feedback:
         # Memorize the local errors.
-        assert input_tensor_cp is not None
         state.error_dict[bucket_index] = input_tensor_cp - input_tensor
     # Removing this seemingly unnecessary sync somehow may cause failures.
     # See: https://github.com/pytorch/pytorch/pull/54838
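Context: the removed asserts appear to exist only to narrow an Optional tensor to Tensor for the type checker before the subtraction. A minimal sketch of that narrowing pattern (variable names here are illustrative, not the hook's actual state):

from typing import Optional
import torch

input_tensor_cp: Optional[torch.Tensor] = torch.zeros(2)
assert input_tensor_cp is not None  # narrows Optional[Tensor] to Tensor for the checker
error = input_tensor_cp - torch.ones(2)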
@@ -1,6 +1,4 @@
 # mypy: allow-untyped-defs
-from typing import Union
 
 import torch
-from torch import Tensor
 from torch.distributions.distribution import Distribution
@@ -57,7 +55,7 @@ class ExponentialFamily(Distribution):
         """
         Method to compute the entropy using Bregman divergence of the log normalizer.
         """
-        result: Union[Tensor, float] = -self._mean_carrier_measure
+        result = -self._mean_carrier_measure
         nparams = [p.detach().requires_grad_() for p in self._natural_params]
         lg_normal = self._log_normalizer(*nparams)
         gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
@@ -451,11 +451,9 @@ def _single_tensor_adam(
                 # expavg.lerp(grad^2, 1-beta2)
                 exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
             else:
-                exp_avg_sq.mul_(beta2).addcmul_(
-                    grad, grad, value=cast(float, 1 - beta2)
-                )
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
         else:
-            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # type: ignore[arg-type]
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
 
         if capturable or differentiable:
             step = step_t
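Context: the touched lines implement Adam's second-moment update, v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2; the cast(float, 1 - beta2) appears to serve only the static checker when beta2 may be a tensor. A standalone sketch of the update itself (example values assumed):

import torch

beta2 = 0.999
grad = torch.randn(4)
exp_avg_sq = torch.zeros(4)

# In-place: exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)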
@@ -526,7 +524,7 @@ def _single_tensor_adam(
         else:
             denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
 
-        param.addcdiv_(exp_avg, denom, value=-step_size)  # type: ignore[arg-type]
+        param.addcdiv_(exp_avg, denom, value=-step_size)
 
         # Lastly, switch back to complex view
        if amsgrad and torch.is_complex(params[i]):
@@ -672,9 +670,7 @@ def _multi_tensor_adam(
         # Decay the first and second moment running average coefficient
         # Use device beta1 if beta1 is a tensor to ensure all
         # tensors are on the same device
-        torch._foreach_lerp_(
-            device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
-        )
+        torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1)
 
         torch._foreach_mul_(device_exp_avg_sqs, beta2)
 
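Context: torch._foreach_lerp_ applies the first-moment update m_t = m_{t-1} + (1 - beta1) * (grad - m_{t-1}) to a whole list of tensors at once; the cast(float, ...) there is again a typing aid. A small sketch with assumed values:

import torch

beta1 = 0.9
exp_avgs = [torch.zeros(3), torch.zeros(3)]
grads = [torch.ones(3), 2 * torch.ones(3)]

torch._foreach_lerp_(exp_avgs, grads, 1 - beta1)  # each exp_avg becomes 0.1 * grad here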
@@ -366,9 +366,7 @@ def _single_tensor_nadam(
             grad, denom, value=(-lr * (1.0 - mu) / (1.0 - _get_value(mu_product)))
         )
         param.addcdiv_(
-            exp_avg,
-            denom,
-            value=cast(float, (-lr * mu_next) / (1.0 - mu_product_next)),
+            exp_avg, denom, value=(-lr * mu_next) / (1.0 - mu_product_next)
         )
 
 
@@ -5,7 +5,7 @@ import math
 import warnings
 from collections.abc import Iterable
 from copy import deepcopy
-from typing import Any, Callable, cast, Literal, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
 
 import torch
 from torch import Tensor
@@ -68,9 +68,7 @@ def get_swa_multi_avg_fn():
                 averaged_param_list[0]
             ):
                 torch._foreach_lerp_(
-                    averaged_param_list,
-                    current_param_list,
-                    cast(float, 1 / (num_averaged + 1)),
+                    averaged_param_list, current_param_list, 1 / (num_averaged + 1)
                 )
             else:
                 diffs = torch._foreach_sub(current_param_list, averaged_param_list)
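Context: the SWA averaging step is a lerp with weight 1 / (n + 1), i.e. avg_{n+1} = avg_n + (param - avg_n) / (n + 1); the removed cast(float, ...) only changed the declared type of that weight. A standalone sketch with assumed values:

import torch

avg = torch.tensor([1.0, 2.0])
param = torch.tensor([3.0, 6.0])
num_averaged = 3

avg.lerp_(param, 1 / (num_averaged + 1))  # -> tensor([1.5000, 3.0000])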