mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[BE][Easy]: Dedupe a TypeAlias in PrimsCommon (#151565)
Replaces a duplicate TypeAlias with a reference to the global constant for them Pull Request resolved: https://github.com/pytorch/pytorch/pull/151565 Approved by: https://github.com/albanD
This commit is contained in:
parent
c4688af254
commit
da580123a0
|
|
@ -1,7 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
from typing_extensions import ParamSpec, TypeAlias
|
||||
|
||||
import torch
|
||||
from torch import sym_float, Tensor
|
||||
|
|
@ -12,9 +12,10 @@ from torch.masked.maskedtensor.creation import as_masked_tensor
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._prims_common import DimsType
|
||||
from torch.types import _dtype as DType
|
||||
|
||||
DimOrDims = Optional[Union[int, tuple[int, ...], list[int]]]
|
||||
DimOrDims: TypeAlias = Optional[DimsType]
|
||||
else:
|
||||
# The JIT doesn't understand Union, nor torch.dtype here
|
||||
DType = int
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user