Strided masked reduction: mean (2nd try) (#67088)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67088

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D31914868

Pulled By: cpuhrsch

fbshipit-source-id: beda9d32ea65bcae31c2c0181f95ad23c6631075
This commit is contained in:
parent 6c22b96082
commit a33d3d84df
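In short, the commit adds a masked mean reduction for strided tensors: the mean of the unmasked elements along dim, computed as a masked sum divided by a masked element count, with nan for fully masked-out slices. The hunks below first adjust the docstring-template machinery and the reduction-identity helper, then add the mean implementation and its OpInfo test entry. A minimal standalone sketch of the semantics, using only public torch ops (the helper name masked_mean is mine, not the module's):

import torch

def masked_mean(input, dim, *, keepdim=False, mask=None):
    # Sketch of the semantics this commit implements: mean over the
    # unmasked elements = masked sum / count of unmasked elements.
    if mask is None:
        mask = torch.ones_like(input, dtype=torch.bool)
    total = torch.where(mask, input, torch.zeros_like(input)).sum(dim, keepdim=keepdim)
    count = mask.sum(dim, keepdim=keepdim)
    # A fully masked-out slice divides 0 by 0 and yields nan, matching
    # the documented behavior in the diff below.
    return total / count

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
m = torch.tensor([[True, False, True], [False, False, False]])
print(masked_mean(x, 1, mask=m))  # tensor([2., nan])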
@@ -25,7 +25,6 @@ def _apply_docstring_templates(func):
     """Decorator that applies docstring templates to function docstring
     and returns the function instance.
     """
-
     docstring_templates = dict(
         reduction_signature='''\
 {function_name}(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor''',
@@ -106,7 +105,8 @@ and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_ui
             sum='sum',
             prod='product',
             amax='maximum',
-            amin='minimum')[func.__name__],
+            amin='minimum',
+            mean='mean')[func.__name__],
         'identity_uint8': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.uint8)),
         'identity_int32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.int32)),
         'identity_float32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.float32)),
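The dict above feeds _apply_docstring_templates, which substitutes named templates into each reduction's docstring. A minimal sketch of that pattern (an assumed simplification: the real decorator takes no arguments and builds its template dict internally, as the hunk shows):

# Sketch of docstring templating; passing `templates` as a parameter is my
# simplification -- the real decorator assembles the dict itself.
def apply_docstring_templates(templates):
    def decorator(func):
        func.__doc__ = func.__doc__.format_map(templates)
        return func
    return decorator

@apply_docstring_templates(
    {'reduction_signature':
     'mean(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor'})
def mean(input, dim):
    """{reduction_signature}"""

print(mean.__doc__)  # mean(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor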
@@ -172,6 +172,13 @@ def _reduction_identity(op_name: str, input: Tensor):
             return torch.tensor(torch.inf, dtype=dtype, device=device)
         elif torch.is_signed(input) or dtype == torch.uint8:
             return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
+    elif op_name == 'mean':
+        # Strictly speaking, the identity value of the mean operation
+        # is the mean of the input. Since the mean value depends on
+        # the dim argument and it may be a non-scalar tensor, we
+        # consider the identity value of the mean operation ambiguous.
+        # Moreover, the mean value of empty input is undefined.
+        return None
     raise NotImplementedError(f'identity of {op_name} on {dtype} input')
 
 
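For context: a reduction's identity is the value it yields over an empty slice (0 for sum, 1 for prod). The hunk above makes _reduction_identity return None for mean, signalling that callers must treat fully masked-out slices specially. A standalone sketch of that contract (my simplification, not the module's code):

import torch

# Sketch: the identity is what a reduction returns over an empty slice.
def reduction_identity(op_name, dtype):
    if op_name == 'sum':
        return torch.tensor(0, dtype=dtype)
    if op_name == 'prod':
        return torch.tensor(1, dtype=dtype)
    if op_name == 'mean':
        # No fixed identity: mean over zero elements is 0/0, undefined.
        return None
    raise NotImplementedError(op_name)

print(reduction_identity('sum', torch.float32))   # tensor(0.)
print(reduction_identity('mean', torch.float32))  # None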
@@ -332,3 +339,36 @@ def amin(input: Tensor,
         return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
     else:
         raise ValueError(f'masked amin expects strided tensor (got {input.layout} tensor)')
+
+
+@_apply_docstring_templates
+def mean(input: Tensor,
+         dim: DimOrDims = None,
+         *,
+         keepdim: Optional[bool] = False,
+         dtype: Optional[DType] = None,
+         mask: Optional[Tensor] = None) -> Tensor:
+    """\
+{reduction_signature}
+
+{reduction_descr}
+
+By definition, the identity value of a mean operation is the mean
+value of the tensor. If all elements of the input tensor along given
+dimension(s) :attr:`dim` are masked-out, the identity value of the
+mean is undefined. Due to this ambiguity, the elements of output
+tensor with strided layout, that correspond to fully masked-out
+elements, have ``nan`` values.
+
+{reduction_args}
+
+{reduction_example}"""
+    if dtype is None:
+        dtype = input.dtype
+    if input.layout == torch.strided:
+        inmask = _input_mask(input, mask=mask)
+        count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=keepdim, mask=inmask)
+        total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
+        return total / count
+    else:
+        raise ValueError(f'masked sum expects strided tensor (got {input.layout} tensor)')
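A quick usage sketch of the new op, assuming the prototype module is importable as torch._masked as in this commit; per the docstring above, a fully masked-out row comes back as nan:

import torch
from torch import _masked

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
m = torch.tensor([[True, True, False],
                  [False, False, False]])

# Row 0 averages its two unmasked elements; row 1 is fully masked out.
print(_masked.mean(x, 1, mask=m))  # expected: tensor([0.5000, nan])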
@@ -11177,6 +11177,24 @@ op_db: List[OpInfo] = [
         sample_inputs_func=sample_inputs_masked_reduction,
         gradcheck_wrapper=gradcheck_wrapper_masked_operation
     ),
+    ReductionOpInfo(
+        '_masked.mean',
+        ref=reference_reduction_numpy(np.mean) if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
+        method_variant=None,
+        nan_policy='propagate',
+        supports_out=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
+        skips=(
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation
+    ),
     OpInfo(
         "nn.functional.nll_loss",
         ref=_NOTHING,
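The ref=reference_reduction_numpy(np.mean) line wires the test suite to compare _masked.mean against NumPy. A sketch of the basic equivalence being asserted (with a fully True mask, masked mean should match np.mean along the same dim):

import numpy as np
import torch

x = torch.randn(4, 5, dtype=torch.float64)
mask = torch.ones_like(x, dtype=torch.bool)

# With nothing masked out, masked mean agrees with the NumPy reference.
ours = torch._masked.mean(x, 1, mask=mask)
ref = torch.from_numpy(np.mean(x.numpy(), axis=1))
assert torch.allclose(ours, ref)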