Tighten type hints for tensor arithmetic (#135392)
Fixes #124015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135392
Approved by: https://github.com/ezyang
parent 080e0ca584
commit bf5cd8d011
@@ -177,14 +177,18 @@ blocklist = [
     "copy_",
 ]

-binary_ops = (
+shift_ops = (
+    "lshift",
+    "rshift",
+    "ilshift",
+    "irshift",  # inplace ops
+)
+arithmetic_ops = (
     "add",
     "sub",
     "mul",
     "div",
     "pow",
-    "lshift",
-    "rshift",
     "mod",
     "truediv",
     "matmul",
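Both this hunk and the next are from PyTorch's stub generator, tools/pyi/gen_pyi.py: the flat binary_ops tuple is being split into families so that each family can advertise its own right-hand-side type. A small illustrative sketch of how these bare names relate to the dunder names handled by sig_for_ops further down:

# Illustrative only: sig_for_ops (below) receives the dunder form and strips
# the underscores before testing family membership, so these tuples hold bare
# names. Shifts get their own family because only integer (or Tensor)
# right-hand operands make sense for them.
opname = "__lshift__"
name = opname[2:-2]  # -> "lshift"
assert name in ("lshift", "rshift", "ilshift", "irshift")  # i.e. shift_ops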
@@ -195,24 +199,26 @@ binary_ops = (
     "rtruediv",
     "rfloordiv",
     "rpow",  # reverse arithmetic
+    "iadd",
+    "idiv",
+    "imul",
+    "isub",
+    "ifloordiv",
+    "imod",  # inplace ops
+)
+logic_ops = (
     "and",
     "or",
     "xor",
     "rand",
     "ror",
-    "rxor",  # logic
-    "iadd",
+    "rxor",  # reverse logic
     "iand",
-    "idiv",
-    "ilshift",
-    "imul",
     "ior",
-    "irshift",
-    "isub",
-    "ixor",
-    "ifloordiv",
-    "imod",  # inplace ops
+    "ixor",  # inplace ops
 )
+binary_ops = shift_ops + arithmetic_ops + logic_ops

 symmetric_comparison_ops = ("eq", "ne")
 asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
 comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops
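With the families separated, binary_ops is recomposed as their concatenation, so existing code that only needs "is this some binary op" keeps working. A quick sanity sketch of the invariant the if/elif dispatch below relies on (tuples abbreviated; assumes the definitions in the hunk above):

shift_ops = ("lshift", "rshift", "ilshift", "irshift")
logic_ops = ("and", "or", "xor", "rand", "ror", "rxor", "iand", "ior", "ixor")
arithmetic_ops = ("add", "sub", "mul", "div", "pow", "mod", "truediv", "matmul")  # abbreviated
binary_ops = shift_ops + arithmetic_ops + logic_ops

# The families must be disjoint, otherwise sig_for_ops' first-match dispatch
# would silently pick the wrong signature for an op in two families.
assert len(set(binary_ops)) == len(binary_ops)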
@@ -232,14 +238,24 @@ def sig_for_ops(opname: str) -> list[str]:
     assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"

     name = opname[2:-2]
-    if name in binary_ops:
-        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
-    elif name in comparison_ops:
-        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
-        if name in symmetric_comparison_ops:
-            # unsafe override https://github.com/python/mypy/issues/5704
-            sig += "  # type: ignore[override]"
-        return [sig]
+    if name in arithmetic_ops:
+        return [
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
+        ]
+    elif name in logic_ops:
+        return [f"def {opname}(self, other: Union[Tensor, _bool]) -> Tensor: ..."]
+    elif name in shift_ops:
+        return [f"def {opname}(self, other: Union[Tensor, _int]) -> Tensor: ..."]
+    elif name in symmetric_comparison_ops:
+        return [
+            # unsafe override https://github.com/python/mypy/issues/5704
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ...  # type: ignore[override]",
+            f"def {opname}(self, other: Any) -> _bool: ...",
+        ]
+    elif name in asymmetric_comparison_ops:
+        return [
+            f"def {opname}(self, other: Union[Tensor, Number, _complex]) -> Tensor: ..."
+        ]
     elif name in unary_ops:
         return [f"def {opname}(self) -> Tensor: ..."]
     elif name in to_py_type_ops:
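The user-visible effect of the hunk above, sketched below: a type checker running against the regenerated .pyi stubs accepts the operands each family can actually handle and flags the rest. The rejected lines are commented out so the snippet also runs as-is:

import torch

t = torch.ones(3, dtype=torch.int64)

a = t + 2.5   # arithmetic ops take Union[Tensor, Number, _complex]: accepted
b = t << 1    # shift ops take Union[Tensor, _int]: accepted
# t << 1.5    # float shift amount: flagged under the tightened stubs
# t & 0.5     # logic ops take Union[Tensor, _bool]: float RHS is flagged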
@@ -2291,7 +2291,8 @@ def native_batch_norm_backward(
     mean = save_mean_cast
     invstd = save_invstd_cast
     if train:
-        assert save_mean_cast is not None and save_invstd_cast is not None
+        assert mean is not None and invstd is not None
     else:
         assert running_mean_cast is not None and running_var_cast is not None
         mean = running_mean_cast
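The hunk above is a knock-on fix in torch/_decomp/decompositions.py: with the stricter operator hints, the checker needs mean and invstd themselves, not their sources, narrowed to non-Optional before they are used in arithmetic. A minimal sketch of the narrowing behavior this relies on (names illustrative, not the decomposition's real signature):

from typing import Optional
from torch import Tensor

def sketch(save_mean_cast: Optional[Tensor], train: bool) -> None:
    mean = save_mean_cast
    if train:
        # Asserting on `mean` narrows `mean` itself; asserting on
        # `save_mean_cast` instead would leave `mean` typed
        # Optional[Tensor], and `mean * 2` below would not type-check.
        assert mean is not None
        _ = mean * 2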
@@ -33,6 +33,7 @@ def efficient_conv_bn_eval(
     """

     assert bn.running_var is not None
+    assert bn.running_mean is not None

     # These lines of code are designed to deal with various cases
     # like bn without affine transform, and conv without bias
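The efficient_conv_bn_eval pass folds eval-mode batch-norm statistics into the convolution, so both running_mean and running_var must exist; the added assert makes the second requirement explicit for the type checker as well. A minimal sketch of the folding arithmetic involved, as a hypothetical helper that assumes conv has a bias and bn is affine (not the pass's actual implementation):

import torch
from torch import nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> None:
    # Eval-mode BN is an affine map per channel:
    #   bn(y) = (y - running_mean) / sqrt(running_var + eps) * gamma + beta
    assert bn.running_var is not None
    assert bn.running_mean is not None
    assert conv.bias is not None and bn.weight is not None and bn.bias is not None
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    with torch.no_grad():
        # Fold the per-channel scale into the conv kernel and the shift
        # into the conv bias.
        conv.weight.mul_(scale.reshape(-1, 1, 1, 1))
        conv.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)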
@@ -128,8 +128,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
             "module type not supported:", type(module1), " ", type(module2)
         )

-    conv1_has_bias = has_bias(module1)
-    bias = None
+    bias = get_module_bias(module1) if has_bias(module1) else None

     weight1 = get_module_weight(module1)
     weight2 = get_module_weight(module2)
@@ -140,9 +139,6 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
             number input channels of second arg"
         )

-    if conv1_has_bias:
-        bias = get_module_bias(module1)
-
     weight1_range = channel_range(weight1, output_axis)
     weight2_range = channel_range(weight2, input_axis)

@@ -151,7 +147,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
     scaling_factors = torch.sqrt(weight1_range / weight2_range)
     inverse_scaling_factors = torch.reciprocal(scaling_factors)

-    if conv1_has_bias:
+    if bias is not None:
         bias = bias * inverse_scaling_factors

     # formatting the scaling (1D) tensors to be applied on the given argument tensors
@@ -168,7 +164,7 @@ def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
     weight2 = weight2 * scaling_factors

     set_module_weight(module1, weight1)
-    if conv1_has_bias:
+    if bias is not None:
         set_module_bias(module1, bias)
     set_module_weight(module2, weight2)

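In the cross-layer equalization hunks above, the separate conv1_has_bias flag is folded away: the bias object is fetched once and its None-ness serves as the flag, which also lets the checker see bias as a Tensor inside the guarded branches. For intuition on the scaling the function applies, a toy numeric check (values are made up):

import torch

# Per-channel ranges of two adjacent layers' weights.
weight1_range = torch.tensor([4.0, 2.0])
weight2_range = torch.tensor([1.0, 8.0])

scaling_factors = torch.sqrt(weight1_range / weight2_range)  # [2.0, 0.5]
inverse_scaling_factors = torch.reciprocal(scaling_factors)  # [0.5, 2.0]

# Scaling weight1 down and weight2 up by s equalizes their dynamic ranges:
# r1/s == r2*s == sqrt(r1*r2) per channel.
assert torch.allclose(weight1_range * inverse_scaling_factors,
                      weight2_range * scaling_factors)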