remove allow-untyped-defs from ./torch/nn/utils/_expanded_weights/embedding_expanded_weights.py (#163475)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163475
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #163478
parent a6974195da
commit e7d6ea65ca
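The directive being deleted, # mypy: allow-untyped-defs, is an inline per-file mypy config that suppresses the disallow_untyped_defs check for the whole file. Once it is removed, every function signature in the file must carry annotations, which is what the rest of this diff adds. A minimal illustration of the before/after (hypothetical snippet and names, not part of the commit):

    from typing import Any, Optional

    import torch

    # With "# mypy: allow-untyped-defs" removed, mypy flags this:
    def backward(ctx, grad_output):  # error: Function is missing a type annotation
        ...

    # Annotating the signature, as this commit does, satisfies the check:
    def backward_typed(
        ctx: Any, grad_output: torch.Tensor
    ) -> tuple[Optional[torch.Tensor], ...]:
        ...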
--- a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py
+++ b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py
@@ -1,5 +1,4 @@
-# mypy: allow-untyped-defs
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
@@ -15,7 +14,9 @@ from .expanded_weights_utils import (
 @implements_per_sample_grads(F.embedding)
 class EmbeddingPerSampleGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
+    def forward(
+        ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
+    ) -> torch.Tensor:
         expanded_args, expanded_kwargs = standard_kwargs(
             kwarg_names, expanded_args_and_kwargs
         )
@@ -33,7 +34,9 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[Optional[torch.Tensor], ...]:
         input, weight = ctx.input, ctx.weight
         padding_idx, scale_grad_by_freq, sparse = (
             ctx.padding_idx,
@@ -41,7 +44,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
             ctx.sparse,
         )
 
-        def weight_per_sample_grad(weight):
+        def weight_per_sample_grad(weight: torch.Tensor) -> torch.Tensor:
             batch_size = input.shape[0]
             embedding_dim = weight.shape[1]
             index = (
@@ -49,7 +52,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
                 .expand(*input.shape, embedding_dim)
                 .reshape(batch_size, -1, embedding_dim)
             )
-            grad_sample = torch.zeros(
+            grad_sample = torch.zeros(  # type: ignore[attr-defined]
                 batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
             )
             return grad_sample.scatter_add_(
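For readers reconstructing what weight_per_sample_grad computes from these hunks: a self-contained sketch of the same scatter_add_ pattern, with toy shapes chosen only for illustration (the variable values are assumptions, not from the commit). Each sample's output gradient is accumulated into that sample's own copy of the embedding-weight gradient:

    import torch

    # Toy shapes for illustration.
    batch_size, seq_len, num_embeddings, embedding_dim = 2, 3, 5, 4
    input = torch.randint(num_embeddings, (batch_size, seq_len))
    weight = torch.randn(num_embeddings, embedding_dim)
    grad_output = torch.randn(batch_size, seq_len, embedding_dim)

    # Broadcast each token id so it addresses a full embedding row.
    index = (
        input.unsqueeze(-1)
        .expand(*input.shape, embedding_dim)
        .reshape(batch_size, -1, embedding_dim)
    )
    # One weight-shaped gradient buffer per sample.
    grad_sample = torch.zeros(
        batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
    )
    # grad_sample[i, input[i, j], k] += grad_output[i, j, k]
    grad_sample.scatter_add_(
        1, index, grad_output.reshape(batch_size, -1, embedding_dim)
    )

Row i of grad_sample is then the gradient of sample i's loss with respect to weight, which is what the expanded-weights per-sample-grad machinery exposes.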