diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index b718ab87c86..d41f101d4ed 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -31,10 +31,8 @@ from typing import (  # noqa: UP035, F401 # (Dict, List, Tuple) imported by tor
     List,
     Optional,
     Tuple,
-    TypeVar,
     Union,
 )
-from typing_extensions import ParamSpec
 
 import torch
 
@@ -49,9 +47,6 @@ from torch._sources import fake_range, get_source_lines_and_file, parse_def
 from torch.futures import Future
 
 
-_P = ParamSpec("_P")
-_R = TypeVar("_R")
-
 IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
 
 BuiltinUnionType: Union[type, tuple[type, ...]]
@@ -670,7 +665,7 @@ class FunctionModifiers:
     _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
 
 
-def export(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+def export(fn):
     """
     This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
     :class:`ScriptModule` and should be compiled.
@@ -712,11 +707,11 @@ def export(fn: Callable[_P, _R]) -> Callable[_P, _R]:
         # any compiled methods and wasn't decorated with `@torch.jit.export`
         m = torch.jit.script(MyModule())
     """
-    fn._torchscript_modifier = FunctionModifiers.EXPORT  # type:ignore[attr-defined]
+    fn._torchscript_modifier = FunctionModifiers.EXPORT
     return fn
 
 
-def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+def unused(fn):
     """
     This decorator indicates to the compiler that a function or method should
     be ignored and replaced with the raising of an exception. This allows you
@@ -769,7 +764,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 
         return prop
 
-    fn._torchscript_modifier = FunctionModifiers.UNUSED  # type: ignore[attr-defined]
+    fn._torchscript_modifier = FunctionModifiers.UNUSED
     return fn
 
 
@@ -887,13 +882,13 @@ def ignore(drop=False, **kwargs):
     return decorator
 
 
-def _drop(fn: Callable[_P, _R]) -> Callable[_P, _R]:
-    fn._torchscript_modifier = FunctionModifiers._DROP  # type: ignore[attr-defined]
+def _drop(fn):
+    fn._torchscript_modifier = FunctionModifiers._DROP
     return fn
 
 
-def _copy_to_script_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
-    fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER  # type: ignore[attr-defined]
+def _copy_to_script_wrapper(fn):
+    fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
     return fn
 
 
diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
index aada0ab2ab7..75d46cf3f7d 100644
--- a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
+++ b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py
@@ -169,13 +169,13 @@ class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
             False,
             qconfig,
         )
-        qat_linearbn.weight = linear.weight  # type: ignore[assignment]
-        qat_linearbn.bias = linear.bias  # type: ignore[assignment]
-        qat_linearbn.bn.weight = bn.weight  # type: ignore[assignment]
-        qat_linearbn.bn.bias = bn.bias  # type: ignore[assignment]
-        qat_linearbn.bn.running_mean = bn.running_mean  # type: ignore[assignment]
-        qat_linearbn.bn.running_var = bn.running_var  # type: ignore[assignment]
-        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked  # type: ignore[assignment]
+        qat_linearbn.weight = linear.weight
+        qat_linearbn.bias = linear.bias
+        qat_linearbn.bn.weight = bn.weight
+        qat_linearbn.bn.bias = bn.bias
+        qat_linearbn.bn.running_mean = bn.running_mean
+        qat_linearbn.bn.running_var = bn.running_var
+        qat_linearbn.bn.num_batches_tracked = bn.num_batches_tracked
         return qat_linearbn
 
     def to_float(self):
diff --git a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
index a4b4fd1e7f3..5173ed813bf 100644
--- a/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
+++ b/torch/ao/nn/intrinsic/quantized/modules/linear_relu.py
@@ -99,7 +99,7 @@ class LinearLeakyReLU(nnq.Linear):
         activation_post_process = mod.activation_post_process
         leaky_relu = mod[1]
         mod = mod[0]
-        weight_post_process = mod.qconfig.weight()  # type: ignore[union-attr, operator]
+        weight_post_process = mod.qconfig.weight()
         weight_post_process(mod.weight)
         dtype = weight_post_process.dtype
         act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[union-attr,operator]
@@ -108,7 +108,7 @@ class LinearLeakyReLU(nnq.Linear):
         qlinear_leaky_relu = cls(
             mod.in_features, mod.out_features, leaky_relu.negative_slope, dtype=dtype
         )
-        qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)  # type: ignore[arg-type]
+        qlinear_leaky_relu.set_weight_bias(qweight, mod.bias)
         qlinear_leaky_relu.scale = float(act_scale)
         qlinear_leaky_relu.zero_point = int(act_zp)
         return qlinear_leaky_relu
@@ -164,14 +164,14 @@ class LinearTanh(nnq.Linear):
         assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
         activation_post_process = mod.activation_post_process
         mod = mod[0]
-        weight_post_process = mod.qconfig.weight()  # type: ignore[union-attr,operator]
+        weight_post_process = mod.qconfig.weight()
         weight_post_process(mod.weight)
         dtype = weight_post_process.dtype
         act_scale, act_zp = activation_post_process.calculate_qparams()  # type: ignore[union-attr,operator]
         assert dtype == torch.qint8, "Weight observer must have dtype torch.qint8"
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
         qlinear_tanh = cls(mod.in_features, mod.out_features, dtype=dtype)
-        qlinear_tanh.set_weight_bias(qweight, mod.bias)  # type: ignore[arg-type]
+        qlinear_tanh.set_weight_bias(qweight, mod.bias)
         qlinear_tanh.scale = float(act_scale)
         qlinear_tanh.zero_point = int(act_zp)
         return qlinear_tanh
diff --git a/torch/ao/ns/fx/weight_utils.py b/torch/ao/ns/fx/weight_utils.py
index 52cb13c1286..fdd87963c2d 100644
--- a/torch/ao/ns/fx/weight_utils.py
+++ b/torch/ao/ns/fx/weight_utils.py
@@ -52,7 +52,7 @@ def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
     if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
         return mod.weight.detach()
     elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
-        return mod[0].weight.detach()  # type: ignore[operator]
+        return mod[0].weight.detach()
     else:
         return mod._weight_bias()[0]  # type: ignore[operator]
 
@@ -61,7 +61,7 @@ def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
     if isinstance(mod, nn.Linear):
         return mod.weight.detach()
     elif isinstance(mod, nni.LinearReLU):
-        return mod[0].weight.detach()  # type: ignore[operator]
+        return mod[0].weight.detach()
     else:
         return mod._weight_bias()[0]  # type: ignore[operator]
 
@@ -79,12 +79,8 @@ def get_lstm_mod_weights(mod: nn.Module) -> list[torch.Tensor]:
     assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
     res = []
     for weight_value in mod._all_weight_values:
-        res.append(
-            weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0]  # type: ignore[index]
-        )
-        res.append(
-            weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0]  # type: ignore[index]
-        )
+        res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
+        res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
     return res
 
 
diff --git a/torch/ao/quantization/experimental/adaround_fake_quantize.py b/torch/ao/quantization/experimental/adaround_fake_quantize.py
index f9b0067acf2..77c8781ed27 100644
--- a/torch/ao/quantization/experimental/adaround_fake_quantize.py
+++ b/torch/ao/quantization/experimental/adaround_fake_quantize.py
@@ -60,7 +60,7 @@ class AdaroundFakeQuantizer(FakeQuantize):
         self.use_soft_rounding = True
 
     @torch.jit.export
-    def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
+    def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
         return self.scale, self.zero_point
 
     @torch.jit.export
diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py
index 74d91e6733a..e957f04a7ef 100644
--- a/torch/ao/quantization/fake_quantize.py
+++ b/torch/ao/quantization/fake_quantize.py
@@ -392,7 +392,7 @@ class FusedMovingAvgObsFakeQuantize(FakeQuantize):
         )
 
     @torch.jit.export
-    def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
+    def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]:
         return self.activation_post_process.calculate_qparams()
 
     @torch.jit.export
diff --git a/torch/ao/quantization/fx/lstm_utils.py b/torch/ao/quantization/fx/lstm_utils.py
index 83a234fd8e1..f4fcb868944 100644
--- a/torch/ao/quantization/fx/lstm_utils.py
+++ b/torch/ao/quantization/fx/lstm_utils.py
@@ -104,8 +104,7 @@ def _get_lstm_with_individually_observed_parts(
     # Insert observers into each LSTM cell
     # TODO: maybe make this work for layer_bw as well
     for layer in quantizable_lstm.layers:
-        cell = layer.layer_fw.cell  # type: ignore[union-attr]
-        assert isinstance(cell, torch.nn.Module), "cell should be a nn.Module"
+        cell = layer.layer_fw.cell
         cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
         # HACK: Manually replace the activation_post_process following these ops.
         # This is needed for FloatFunctional ops because there is currently no way
@@ -155,7 +154,7 @@ def _get_lstm_with_individually_observed_parts(
                 setattr(
                     cell, activation_post_process_name, activation_post_process_ctr()
                 )
-        layer.layer_fw.cell = cell  # type: ignore[union-attr]
+        layer.layer_fw.cell = cell
     return quantizable_lstm
 
 
@@ -217,5 +216,5 @@
                     node.replace_input_with(arg, arg.args[0])
         cell.graph.eliminate_dead_code()
         cell.recompile()
-        layer.layer_fw.cell = cell  # type: ignore[union-attr]
+        layer.layer_fw.cell = cell
     return quantized_lstm
diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py
index 93e11c63082..cde1ad0005d 100644
--- a/torch/nn/modules/adaptive.py
+++ b/torch/nn/modules/adaptive.py
@@ -172,9 +172,9 @@ class AdaptiveLogSoftmaxWithLoss(Module):
 
     def reset_parameters(self) -> None:
         self.head.reset_parameters()
-        for i2h, h2o in self.tail:  # type: ignore[misc]
-            i2h.reset_parameters()  # type: ignore[has-type]
-            h2o.reset_parameters()  # type: ignore[has-type]
+        for i2h, h2o in self.tail:
+            i2h.reset_parameters()
+            h2o.reset_parameters()
 
     def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
         targ_dim = target_.dim()
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index 298ab149639..9078097aca4 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -1,3 +1,4 @@
+# mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 from __future__ import annotations
 
@@ -28,7 +29,6 @@ __all__ = [
 ]
 
 T = TypeVar("T", bound=Module)
-_V = TypeVar("_V")
 
 
 # Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
@@ -121,7 +121,7 @@ class Sequential(Module):
             for idx, module in enumerate(args):
                 self.add_module(str(idx), module)
 
-    def _get_item_by_idx(self, iterator: Iterable[_V], idx: int) -> _V:
+    def _get_item_by_idx(self, iterator, idx) -> T:  # type: ignore[misc, type-var]
        """Get the idx-th item of the iterator."""
         size = len(self)
         idx = operator.index(idx)
@@ -131,7 +131,7 @@ class Sequential(Module):
         return next(islice(iterator, idx, None))
 
     @_copy_to_script_wrapper
-    def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, Module]:
+    def __getitem__(self, idx: Union[slice, int]) -> Union[Sequential, T]:
         if isinstance(idx, slice):
             return self.__class__(OrderedDict(list(self._modules.items())[idx]))
         else:
@@ -227,7 +227,7 @@ class Sequential(Module):
         return self
 
     @_copy_to_script_wrapper
-    def __dir__(self) -> list[str]:
+    def __dir__(self):
         keys = super().__dir__()
         keys = [key for key in keys if not key.isdigit()]
         return keys
@@ -410,7 +410,7 @@ class ModuleList(Module):
             combined.add_module(str(i), module)
         return combined
 
-    def __repr__(self) -> str:
+    def __repr__(self):
         """Return a custom repr for ModuleList that compresses repeated module representations."""
         list_of_reprs = [repr(item) for item in self]
         if len(list_of_reprs) == 0:
@@ -443,7 +443,7 @@ class ModuleList(Module):
         return main_str
 
     @_copy_to_script_wrapper
-    def __dir__(self) -> list[str]:
+    def __dir__(self):
         keys = super().__dir__()
         keys = [key for key in keys if not key.isdigit()]
         return keys
@@ -580,17 +580,17 @@ class ModuleDict(Module):
         return v
 
     @_copy_to_script_wrapper
-    def keys(self) -> container_abcs.KeysView[str]:
+    def keys(self) -> Iterable[str]:
         r"""Return an iterable of the ModuleDict keys."""
         return self._modules.keys()
 
     @_copy_to_script_wrapper
-    def items(self) -> container_abcs.ItemsView[str, Module]:
+    def items(self) -> Iterable[tuple[str, Module]]:
         r"""Return an iterable of the ModuleDict key/value pairs."""
         return self._modules.items()
 
     @_copy_to_script_wrapper
-    def values(self) -> container_abcs.ValuesView[Module]:
+    def values(self) -> Iterable[Module]:
         r"""Return an iterable of the ModuleDict values."""
         return self._modules.values()
 
@@ -716,7 +716,7 @@ class ParameterList(Module):
     def __iadd__(self, parameters: Iterable[Any]) -> Self:
         return self.extend(parameters)
 
-    def __dir__(self) -> list[str]:
+    def __dir__(self):
         keys = super().__dir__()
         keys = [key for key in keys if not key.isdigit()]
         return keys
@@ -930,7 +930,7 @@ class ParameterDict(Module):
         """
         return ParameterDict((k, default) for k in keys)
 
-    def keys(self) -> container_abcs.KeysView[str]:
+    def keys(self) -> Iterable[str]:
         r"""Return an iterable of the ParameterDict keys."""
         return self._keys.keys()
 
diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py
index b4c88c89819..544397e5378 100644
--- a/torch/nn/utils/parametrize.py
+++ b/torch/nn/utils/parametrize.py
@@ -594,9 +594,9 @@ def register_parametrization(
 
         # add the new parametrization to the parametrization list
         assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
-        module.parametrizations[tensor_name].append(parametrization)  # type: ignore[operator]
+        module.parametrizations[tensor_name].append(parametrization)
         # If unsafe was True in previous parametrization, keep it enabled
-        module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr, operator]
+        module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr]
     elif tensor_name in module._buffers or tensor_name in module._parameters:
         # Set the parametrization mechanism
         # Fetch the original buffer or parameter
@@ -686,7 +686,6 @@ def remove_parametrizations(
     parametrizations = module.parametrizations[tensor_name]
     if parametrizations.is_tensor:
         original = parametrizations.original
-        assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
         if leave_parametrized:
             with torch.no_grad():
                 t = getattr(module, tensor_name)
@@ -793,9 +792,7 @@ def transfer_parametrizations_and_params(
         )
 
         # apply the params's parametrizations to to_module
-        for param_func in from_module.parametrizations[  # type: ignore[attr-defined]
-            parameter_name
-        ]:
+        for param_func in from_module.parametrizations[parameter_name]:
             register_parametrization(to_module, parameter_name, param_func)
 
         assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy