remove allow-untyped-defs from ./torch/nn/utils/_expanded_weights/embedding_expanded_weights.py (#163475)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163475
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #163478
commit e7d6ea65ca
parent a6974195da
Author: Bob Ren
Date: 2025-09-24 15:18:42 -07:00
Committed by: PyTorch MergeBot


@@ -1,5 +1,4 @@
-# mypy: allow-untyped-defs
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
@@ -15,7 +14,9 @@ from .expanded_weights_utils import (
 @implements_per_sample_grads(F.embedding)
 class EmbeddingPerSampleGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
+    def forward(
+        ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
+    ) -> torch.Tensor:
         expanded_args, expanded_kwargs = standard_kwargs(
             kwarg_names, expanded_args_and_kwargs
         )
@@ -33,7 +34,9 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[Optional[torch.Tensor], ...]:
         input, weight = ctx.input, ctx.weight
         padding_idx, scale_grad_by_freq, sparse = (
             ctx.padding_idx,
@@ -41,7 +44,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
             ctx.sparse,
         )
 
-        def weight_per_sample_grad(weight):
+        def weight_per_sample_grad(weight: torch.Tensor) -> torch.Tensor:
             batch_size = input.shape[0]
             embedding_dim = weight.shape[1]
             index = (
@@ -49,7 +52,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
                 .expand(*input.shape, embedding_dim)
                 .reshape(batch_size, -1, embedding_dim)
             )
-            grad_sample = torch.zeros(
+            grad_sample = torch.zeros(  # type: ignore[attr-defined]
                 batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
            )
             return grad_sample.scatter_add_(
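
For context (not part of the diff above): `# mypy: allow-untyped-defs` is mypy's inline per-file configuration comment. Removing it means unannotated defs in this module are rejected wherever `disallow_untyped_defs` is in effect, which is why the commit annotates `forward`, `backward`, and `weight_per_sample_grad`. A minimal sketch of the before/after pattern, using hypothetical file names and placeholder bodies:

# hypothetical before.py: the top-of-file inline config relaxes the check
# mypy: allow-untyped-defs
import torch

def forward(ctx, kwarg_names, _, *args):  # unannotated signature: accepted
    return torch.empty(0)

# hypothetical after.py: no directive, so full annotations are required
from typing import Any
import torch

def forward(ctx: Any, kwarg_names: list[str], _: Any, *args: Any) -> torch.Tensor:
    return torch.empty(0)  # placeholder body for the sketch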