mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Ref https://github.com/pytorch/pytorch/issues/61492#issuecomment-1413003480 The array API specifies correction to be `Union[int, float]` while we currently only support integers. https://data-apis.org/array-api/latest/API_specification/generated/array_api.std.html As std/var is calculated currently, the final count of elements is already done in floating point so we can make the correction floating point without any loss of precision or generality. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94073 Approved by: https://github.com/ezyang
116 lines
4.0 KiB
Python
116 lines
4.0 KiB
Python
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
# Shorthand for the ATen operator namespace; used below both to name the
# overloads being decomposed and to call ops such as aten.mean / aten.sum.
aten = torch.ops.aten
|
|
from typing import Optional, List, Dict, Set
|
|
import inspect
|
|
import warnings
|
|
from torch.types import Number
|
|
|
|
# Default registry used by register_decomposition when no registry is passed:
# maps str(aten_op._schema) to the scripted decomposition function.
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
|
|
# Names of every function registered so far; register_decomposition asserts
# that each new name is absent (unique names are needed for jit serialization).
function_name_set: Set[str] = set()
|
|
|
|
def check_decomposition_has_type_annotations(f):
    """Assert that every parameter and the return value of ``f`` are annotated.

    Args:
        f: Callable whose signature is read with ``inspect.signature``.

    Raises:
        AssertionError: If any parameter, or the return value, carries no
            annotation. The message names the function (and the parameter,
            when a parameter is at fault).
    """
    inspect_empty = inspect._empty  # type: ignore[attr-defined]
    # Functions expose ``__name__`` (there is no ``name`` attribute); fall back
    # to repr() so that building the failure message can never itself raise.
    func_name = getattr(f, "__name__", repr(f))
    sig = inspect.signature(f)

    for param in sig.parameters.values():
        assert param.annotation is not inspect_empty, \
            "No signature on param {name} for function {func}".format(name=param.name, func=func_name)

    assert sig.return_annotation is not inspect_empty, \
        "No return annotation for function {func}".format(func=func_name)
|
|
|
|
def signatures_match(decomposition_sig, torch_op_sig):
    """Return whether a decomposition's signature agrees with the op's signature.

    Both arguments are read through ``.parameters`` and ``.return_annotation``.
    They agree when they have the same number of parameters, every positional
    pair of parameters shares the same name and annotation, any default that is
    present on *both* sides is equal, and the return annotations are equal.

    A warning is emitted for each decomposition parameter named ``self``.
    """
    unset = inspect.Parameter.empty
    ours = list(decomposition_sig.parameters.values())
    theirs = list(torch_op_sig.parameters.values())

    if len(ours) != len(theirs):
        return False

    for decomp_param, op_param in zip(ours, theirs):
        # Full equality can't be checked yet because not all fields are
        # correctly deduced in the torch_op_sig (e.g. the default value).
        # 'kind' can't be checked because kwarg-only values with defaults
        # are not yet supported in TS.
        if decomp_param.name == "self":
            warnings.warn("PyTorch uses 'input' instead of 'self' on public api")
        if decomp_param.name != op_param.name:
            return False
        if decomp_param.annotation != op_param.annotation:
            return False

        # A default is not always inferred as present on the torch schema,
        # but when both sides specify one the two must be equal.
        decomp_default, op_default = decomp_param.default, op_param.default
        if decomp_default != unset and op_default != unset and decomp_default != op_default:
            return False

    return decomposition_sig.return_annotation == torch_op_sig.return_annotation
|
|
|
|
def register_decomposition(aten_op, registry=None):
    """Build a decorator that scripts a function and records it as a decomposition.

    Args:
        aten_op: The operator overload being decomposed; asserted to be a
            ``torch._ops.OpOverload`` when the decorator is applied.
        registry: Mapping that receives the scripted function under the key
            ``str(aten_op._schema)``. When ``None``, the module-level
            ``decomposition_table`` is used.

    Returns:
        A decorator. Applying it scripts the function, runs inlining followed
        by two rounds of peephole + constant propagation on its graph, stores
        the scripted function in the registry, and returns the original
        (unscripted) function.
    """
    def decomposition_decorator(f):
        nonlocal registry
        if registry is None:
            registry = decomposition_table

        assert isinstance(aten_op, torch._ops.OpOverload)

        # Need unique name for jit function serialization
        fn_name = f.__name__
        assert fn_name not in function_name_set, "Duplicated function name {}".format(fn_name)
        function_name_set.add(fn_name)

        scripted = torch.jit.script(f)
        graph = scripted.graph
        torch._C._jit_pass_inline(graph)

        # Two rounds of clean-up passes over the inlined graph.
        for _ in range(2):
            torch._C._jit_pass_peephole(graph)
            torch._C._jit_pass_constant_propagation(graph)

        registry[str(aten_op._schema)] = scripted
        return f

    return decomposition_decorator
|
|
|
|
# TODO: replace torch.sigmoid -> aten.sigmoid
|
|
|
|
@register_decomposition(aten.var.correction)
def var_decomposition(input: Tensor, dim: Optional[List[int]] = None,
                      correction: Optional[Number] = None,
                      keepdim: bool = False) -> Tensor:
    """Decompose ``aten.var.correction`` into mean, subtract, square, sum, divide.

    Args:
        input: Tensor to reduce.
        dim: Dimensions to reduce over; ``None`` or an empty list reduces over
            every element of ``input``.
        correction: Value subtracted from the element count to form the
            divisor; ``None`` is treated as 1.
        keepdim: Forwarded to the final ``aten.sum``.

    Returns:
        The sum of squared deviations from the mean divided by
        ``max(0, n - correction)``, where ``n`` is the number of reduced elements.

    Raises:
        RuntimeError: If ``correction`` is neither ``None``, an int nor a float.
    """
    if dim is None:
        # An explicitly typed empty list stands in for "all dimensions".
        no_dims: List[int] = []
        dim = no_dims

    # n = number of elements that take part in each reduction.
    if isinstance(dim, (tuple, list)) and len(dim) == 0:
        n = input.numel()
    else:
        n = 1
        for axis in dim:
            n *= input.shape[axis]

    mean = aten.mean(input, dim, True)
    centered = input - mean
    squared = centered * centered
    total = aten.sum(squared, dim, keepdim)

    if correction is None:
        denom = float(n - 1)
    elif isinstance(correction, int):
        denom = float(n - correction)
    elif isinstance(correction, float):
        denom = float(n) - correction
    else:
        raise RuntimeError("correction must be int or float")

    # The divisor is clamped at zero, never negative.
    return total / max(0, denom)
|
|
|
|
@register_decomposition(aten.var.default)
def var(input: Tensor, unbiased: bool = True) -> Tensor:
    """Decompose ``aten.var.default`` by delegating to ``var_decomposition``.

    ``unbiased=True`` maps to a correction of 1 and ``unbiased=False`` to 0;
    the reduction dimensions and ``keepdim`` are left at their defaults.
    """
    correction = 1 if unbiased else 0
    return var_decomposition(input, correction=correction)
|