remove allow-untyped-defs from ./torch/nn/utils/_expanded_weights/embedding_expanded_weights.py (#163475)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163475
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #163478
parent a6974195da
commit e7d6ea65ca
@@ -1,5 +1,4 @@
-# mypy: allow-untyped-defs
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
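Background on the hunk above: "# mypy: allow-untyped-defs" is an inline mypy configuration comment that exempts a single file from the disallow_untyped_defs check. Removing it means every def in the module must be fully annotated, which is what the remaining hunks add. A minimal sketch of the effect, assuming mypy runs with disallow_untyped_defs = True (the file and functions below are hypothetical, not part of this commit):

# example.py -- hypothetical file
# mypy: allow-untyped-defs

# With the directive above present, mypy accepts an untyped def:
def scale(x):
    return x * 2

# With the directive removed, mypy reports
#     error: Function is missing a type annotation  [no-untyped-def]
# until the def is annotated:
def scale_typed(x: float) -> float:
    return x * 2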
@@ -15,7 +14,9 @@ from .expanded_weights_utils import (
 @implements_per_sample_grads(F.embedding)
 class EmbeddingPerSampleGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
+    def forward(
+        ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
+    ) -> torch.Tensor:
         expanded_args, expanded_kwargs = standard_kwargs(
             kwarg_names, expanded_args_and_kwargs
         )
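The kwarg_names / *expanded_args_and_kwargs calling convention flattens keyword arguments into the positional pack (autograd.Function.apply takes positional arguments only) and passes their names alongside so that standard_kwargs can rebuild the (args, kwargs) pair. A sketch of what such a helper might look like (an illustrative reconstruction, not the actual expanded_weights_utils code):

from typing import Any


def standard_kwargs_sketch(
    kwarg_names: list[str], expanded_args: tuple[Any, ...]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
    # The last len(kwarg_names) positional entries are the flattened keyword
    # values; everything before them is a true positional argument.
    split = len(expanded_args) - len(kwarg_names)
    expanded_kwargs = dict(zip(kwarg_names, expanded_args[split:]))
    return expanded_args[:split], expanded_kwargs


# e.g. standard_kwargs_sketch(["padding_idx"], (inp, weight, 0))
#      -> ((inp, weight), {"padding_idx": 0})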
@@ -33,7 +34,9 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[Optional[torch.Tensor], ...]:
         input, weight = ctx.input, ctx.weight
         padding_idx, scale_grad_by_freq, sparse = (
             ctx.padding_idx,
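The new return annotation encodes the autograd.Function contract: backward returns one entry per argument that forward received after ctx, with None in the slots of arguments that need no gradient (here that would include kwarg_names and the placeholder). A minimal, self-contained toy in the same annotation style (not the embedding code):

from typing import Any, Optional

import torch


class ScaleAndShift(torch.autograd.Function):
    # Toy function y = x * alpha + bias, annotated like the diff above.
    @staticmethod
    def forward(
        ctx: Any, x: torch.Tensor, bias: torch.Tensor, alpha: float
    ) -> torch.Tensor:
        ctx.alpha = alpha
        return x * alpha + bias

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> tuple[Optional[torch.Tensor], ...]:
        # One slot per forward input: grads for x and bias, None for alpha.
        return grad_output * ctx.alpha, grad_output, None


x = torch.randn(4, requires_grad=True)
b = torch.zeros(4, requires_grad=True)
ScaleAndShift.apply(x, b, 3.0).sum().backward()  # x.grad is all 3s, b.grad all 1s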
@@ -41,7 +44,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
             ctx.sparse,
         )
 
-        def weight_per_sample_grad(weight):
+        def weight_per_sample_grad(weight: torch.Tensor) -> torch.Tensor:
             batch_size = input.shape[0]
             embedding_dim = weight.shape[1]
             index = (
@@ -49,7 +52,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
                 .expand(*input.shape, embedding_dim)
                 .reshape(batch_size, -1, embedding_dim)
             )
-            grad_sample = torch.zeros(
+            grad_sample = torch.zeros(  # type: ignore[attr-defined]
                 batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
             )
             return grad_sample.scatter_add_(
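For reference, the weight_per_sample_grad closure in the last two hunks computes, for each sample b, grad_sample[b][v] as the sum of grad_output[b, t] over all positions t where input[b, t] == v; the scatter_add_ call realizes that sum without materializing one-hot vectors. A self-contained sketch of the same computation, checked against ordinary autograd on a single sample (illustrative; assumes a dense embedding with no padding_idx or scale_grad_by_freq):

import torch
import torch.nn.functional as F

batch_size, seq_len, num_embeddings, embedding_dim = 4, 5, 10, 3
input = torch.randint(num_embeddings, (batch_size, seq_len))
weight = torch.randn(num_embeddings, embedding_dim, requires_grad=True)

grad_output = torch.randn(batch_size, seq_len, embedding_dim)

# Per-sample gradient via scatter_add_, mirroring weight_per_sample_grad:
index = (
    input.unsqueeze(-1)
    .expand(*input.shape, embedding_dim)
    .reshape(batch_size, -1, embedding_dim)
)
grad_sample = torch.zeros(
    batch_size, *weight.shape, device=weight.device, dtype=grad_output.dtype
)
grad_sample.scatter_add_(1, index, grad_output.reshape(batch_size, -1, embedding_dim))

# Cross-check sample 0 against autograd run on that sample alone:
(g0,) = torch.autograd.grad(F.embedding(input[0], weight), weight, grad_output[0])
assert torch.allclose(grad_sample[0], g0)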