remove allow-untyped-defs from ./torch/nn/utils/_expanded_weights/embedding_expanded_weights.py (#163475)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163475
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #163478
Bob Ren, 2025-09-24 15:18:42 -07:00, committed by PyTorch MergeBot
parent a6974195da
commit e7d6ea65ca


@@ -1,5 +1,4 @@
-# mypy: allow-untyped-defs
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
@@ -15,7 +14,9 @@ from .expanded_weights_utils import (
 @implements_per_sample_grads(F.embedding)
 class EmbeddingPerSampleGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
+    def forward(
+        ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
+    ) -> torch.Tensor:
         expanded_args, expanded_kwargs = standard_kwargs(
             kwarg_names, expanded_args_and_kwargs
         )
@@ -33,7 +34,9 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[Optional[torch.Tensor], ...]:
         input, weight = ctx.input, ctx.weight
         padding_idx, scale_grad_by_freq, sparse = (
             ctx.padding_idx,
@@ -41,7 +44,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
             ctx.sparse,
         )
 
-        def weight_per_sample_grad(weight):
+        def weight_per_sample_grad(weight: torch.Tensor) -> torch.Tensor:
             batch_size = input.shape[0]
             embedding_dim = weight.shape[1]
             index = (
@@ -49,7 +52,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
                 .expand(*input.shape, embedding_dim)
                 .reshape(batch_size, -1, embedding_dim)
             )
-            grad_sample = torch.zeros(
+            grad_sample = torch.zeros(  # type: ignore[attr-defined]
                 batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
             )
             return grad_sample.scatter_add_(
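
For context, below is a minimal, self-contained sketch (not taken from the PyTorch source) of the scatter_add_-based per-sample gradient computation that weight_per_sample_grad performs in the diff above. The shapes, variable names, and the per-sample autograd sanity check are illustrative assumptions.

# Sketch of per-sample embedding gradients via scatter_add_,
# mirroring weight_per_sample_grad above. Shapes are illustrative.
import torch
import torch.nn.functional as F

batch_size, seq_len, num_embeddings, embedding_dim = 4, 3, 10, 5
input = torch.randint(0, num_embeddings, (batch_size, seq_len))
weight = torch.randn(num_embeddings, embedding_dim)
# Stand-in for the gradient flowing into F.embedding's output.
grad_output = torch.randn(batch_size, seq_len, embedding_dim)

# Broadcast each token index across the embedding dimension so it can be
# used as a scatter index of shape (batch_size, seq_len, embedding_dim).
index = (
    input.unsqueeze(-1)
    .expand(*input.shape, embedding_dim)
    .reshape(batch_size, -1, embedding_dim)
)
grad_sample = torch.zeros(
    batch_size, num_embeddings, embedding_dim, dtype=grad_output.dtype
)
# For sample b and position t, accumulate grad_output[b, t] into row
# input[b, t] of that sample's weight gradient.
grad_sample.scatter_add_(
    1, index, grad_output.reshape(batch_size, -1, embedding_dim)
)

# Sanity check against autograd, one sample at a time.
for b in range(batch_size):
    w = weight.clone().requires_grad_()
    F.embedding(input[b], w).backward(grad_output[b])
    assert torch.allclose(w.grad, grad_sample[b])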