Revert "Tighten type hints for tensor arithmetic (#135392)"

This reverts commit d378819068.

Reverted https://github.com/pytorch/pytorch/pull/135392 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. See D65641103 for more details ([comment](https://github.com/pytorch/pytorch/pull/135392#issuecomment-2465906839))
Author: PyTorch MergeBot
Date:   2024-11-08 23:44:41 +00:00
Parent: 2af5172774
Commit: beae7725be

4 changed files with 28 additions and 46 deletions
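For context on what is being reverted: #135392 grouped the generated Tensor dunder stubs by operator kind and narrowed the type of `other`, and this revert restores the single Any-typed form. Pieced together from the hunks below, the generated stub lines differ roughly as follows (the operators __lshift__, __and__, and __add__ are only representative picks from the shift, logic, and arithmetic groups):

# With #135392 applied (the narrower Union[...] signatures removed below):
def __lshift__(self, other: Union[Tensor, _int]) -> Tensor: ...
def __and__(self, other: Union[Tensor, _bool]) -> Tensor: ...
def __add__(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...

# After this revert (the restored binary_ops branch):
def __lshift__(self, other: Any) -> Tensor: ...
def __and__(self, other: Any) -> Tensor: ...
def __add__(self, other: Any) -> Tensor: ...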


@@ -177,18 +177,14 @@ blocklist = [
"copy_",
]
shift_ops = (
"lshift",
"rshift",
"ilshift",
"irshift", # inplace ops
)
arithmetic_ops = (
binary_ops = (
"add",
"sub",
"mul",
"div",
"pow",
"lshift",
"rshift",
"mod",
"truediv",
"matmul",
@@ -199,26 +195,24 @@ arithmetic_ops = (
"rtruediv",
"rfloordiv",
"rpow", # reverse arithmetic
"iadd",
"idiv",
"imul",
"isub",
"ifloordiv",
"imod", # inplace ops
)
logic_ops = (
"and",
"or",
"xor",
"rand",
"ror",
"rxor", # reverse logic
"rxor", # logic
"iadd",
"iand",
"idiv",
"ilshift",
"imul",
"ior",
"ixor", # inplace ops
"irshift",
"isub",
"ixor",
"ifloordiv",
"imod", # inplace ops
)
binary_ops = shift_ops + arithmetic_ops + logic_ops
symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
@@ -238,28 +232,14 @@ def sig_for_ops(opname: str) -> list[str]:
assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"
name = opname[2:-2]
if name == "rpow":
return [ # somehow required to make mypy ci happy?
f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ... # type: ignore[has-type]"
]
elif name in arithmetic_ops:
return [
f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
]
elif name in logic_ops:
return [f"def {opname}(self, other: Union[Tensor, _bool]) -> Tensor: ..."]
elif name in shift_ops:
return [f"def {opname}(self, other: Union[Tensor, _int]) -> Tensor: ..."]
elif name in symmetric_comparison_ops:
return [
if name in binary_ops:
return [f"def {opname}(self, other: Any) -> Tensor: ..."]
elif name in comparison_ops:
sig = f"def {opname}(self, other: Any) -> Tensor: ..."
if name in symmetric_comparison_ops:
# unsafe override https://github.com/python/mypy/issues/5704
f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ... # type: ignore[override]",
f"def {opname}(self, other: Any) -> _bool: ...",
]
elif name in asymmetric_comparison_ops:
return [
f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
]
sig += " # type: ignore[override]"
return [sig]
elif name in unary_ops:
return [f"def {opname}(self) -> Tensor: ..."]
elif name in to_py_type_ops:
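Disentangling the interleaved removed and added lines above, the restored branch of sig_for_ops for binary and comparison ops appears to read roughly as follows (a sketch; the rpow special case and the remaining branches are elided):

    if name in binary_ops:
        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
    elif name in comparison_ops:
        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
        if name in symmetric_comparison_ops:
            # unsafe override https://github.com/python/mypy/issues/5704
            sig += "  # type: ignore[override]"
        return [sig]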


@@ -2291,8 +2291,7 @@ def native_batch_norm_backward(
mean = save_mean_cast
invstd = save_invstd_cast
if train:
assert mean is not None and invstd is not None
assert save_mean_cast is not None and save_invstd_cast is not None
else:
assert running_mean_cast is not None and running_var_cast is not None
mean = running_mean_cast


@@ -33,7 +33,6 @@ def efficient_conv_bn_eval(
"""
assert bn.running_var is not None
assert bn.running_mean is not None
# These lines of code are designed to deal with various cases
# like bn without affine transform, and conv without bias


@@ -128,7 +128,8 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
"module type not supported:", type(module1), " ", type(module2)
)
bias = get_module_bias(module1) if has_bias(module1) else None
conv1_has_bias = has_bias(module1)
bias = None
weight1 = get_module_weight(module1)
weight2 = get_module_weight(module2)
@@ -139,6 +140,9 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
number input channels of second arg"
)
if conv1_has_bias:
bias = get_module_bias(module1)
weight1_range = channel_range(weight1, output_axis)
weight2_range = channel_range(weight2, input_axis)
@@ -147,7 +151,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
scaling_factors = torch.sqrt(weight1_range / weight2_range)
inverse_scaling_factors = torch.reciprocal(scaling_factors)
if bias is not None:
if conv1_has_bias:
bias = bias * inverse_scaling_factors
# formatting the scaling (1D) tensors to be applied on the given argument tensors
@@ -164,7 +168,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
weight2 = weight2 * scaling_factors
set_module_weight(module1, weight1)
if bias is not None:
if conv1_has_bias:
set_module_bias(module1, bias)
set_module_weight(module2, weight2)
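Assembled from the hunks above, the restored handling of the first module's bias in cross_layer_equalization appears to be roughly the following (a sketch only; unchanged lines between the hunks are elided, and the rescaling of weight1 itself is not visible in these hunks):

    conv1_has_bias = has_bias(module1)
    bias = None
    weight1 = get_module_weight(module1)
    weight2 = get_module_weight(module2)
    ...
    if conv1_has_bias:
        bias = get_module_bias(module1)
    weight1_range = channel_range(weight1, output_axis)
    weight2_range = channel_range(weight2, input_axis)
    scaling_factors = torch.sqrt(weight1_range / weight2_range)
    inverse_scaling_factors = torch.reciprocal(scaling_factors)
    if conv1_has_bias:
        bias = bias * inverse_scaling_factors
    ...
    weight2 = weight2 * scaling_factors
    set_module_weight(module1, weight1)
    if conv1_has_bias:
        set_module_bias(module1, bias)
    set_module_weight(module2, weight2)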