diff --git a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py
index febeecab8b3..24acd7549e6 100644
--- a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py
+++ b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py
@@ -1,5 +1,4 @@
-# mypy: allow-untyped-defs
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
@@ -15,7 +14,9 @@ from .expanded_weights_utils import (
 @implements_per_sample_grads(F.embedding)
 class EmbeddingPerSampleGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
+    def forward(
+        ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
+    ) -> torch.Tensor:
         expanded_args, expanded_kwargs = standard_kwargs(
             kwarg_names, expanded_args_and_kwargs
         )
@@ -33,7 +34,9 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[Optional[torch.Tensor], ...]:
         input, weight = ctx.input, ctx.weight
         padding_idx, scale_grad_by_freq, sparse = (
             ctx.padding_idx,
@@ -41,7 +44,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
             ctx.sparse,
         )
 
-        def weight_per_sample_grad(weight):
+        def weight_per_sample_grad(weight: torch.Tensor) -> torch.Tensor:
             batch_size = input.shape[0]
             embedding_dim = weight.shape[1]
             index = (
@@ -49,7 +52,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
                 .expand(*input.shape, embedding_dim)
                 .reshape(batch_size, -1, embedding_dim)
             )
-            grad_sample = torch.zeros(
+            grad_sample = torch.zeros(  # type: ignore[attr-defined]
                 batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
            )
             return grad_sample.scatter_add_(
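
For reference, a minimal standalone sketch of the scatter_add_-based per-sample embedding gradient that weight_per_sample_grad computes in this file. It is not part of the patch; the tensor sizes, variable names, and the final cross-check against autograd are illustrative assumptions.

```python
# Sketch only: reproduces the per-sample gradient trick used above with
# made-up sizes, then checks it against the ordinary autograd gradient.
import torch
import torch.nn.functional as F

batch_size, seq_len, num_embeddings, embedding_dim = 4, 3, 10, 5
weight = torch.randn(num_embeddings, embedding_dim, requires_grad=True)
input = torch.randint(num_embeddings, (batch_size, seq_len))

output = F.embedding(input, weight)            # (batch_size, seq_len, embedding_dim)
grad_output = torch.randn_like(output)         # stand-in for the upstream gradient

# Expand each index so it addresses a full embedding row, then scatter the
# incoming gradient into a per-sample buffer of shape (batch_size, *weight.shape).
index = (
    input.unsqueeze(-1)
    .expand(*input.shape, embedding_dim)
    .reshape(batch_size, -1, embedding_dim)
)
grad_sample = torch.zeros(batch_size, *weight.shape, dtype=grad_output.dtype)
grad_sample.scatter_add_(1, index, grad_output.reshape(batch_size, -1, embedding_dim))

# Summing the per-sample gradients over the batch recovers the usual gradient.
output.backward(grad_output)
assert torch.allclose(grad_sample.sum(0), weight.grad)
```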