Strided masked reduction: mean (2nd try) (#67088)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67088

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D31914868

Pulled By: cpuhrsch

fbshipit-source-id: beda9d32ea65bcae31c2c0181f95ad23c6631075
This commit is contained in:
Pearu Peterson 2021-10-26 11:53:15 -07:00 committed by Facebook GitHub Bot
parent 6c22b96082
commit a33d3d84df
2 changed files with 60 additions and 2 deletions

View File

@ -25,7 +25,6 @@ def _apply_docstring_templates(func):
"""Decorator that applies docstring templates to function docstring """Decorator that applies docstring templates to function docstring
and returns the function instance. and returns the function instance.
""" """
docstring_templates = dict( docstring_templates = dict(
reduction_signature='''\ reduction_signature='''\
{function_name}(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor''', {function_name}(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor''',
@ -106,7 +105,8 @@ and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_ui
sum='sum', sum='sum',
prod='product', prod='product',
amax='maximum', amax='maximum',
amin='minimum')[func.__name__], amin='minimum',
mean='mean')[func.__name__],
'identity_uint8': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.uint8)), 'identity_uint8': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.uint8)),
'identity_int32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.int32)), 'identity_int32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.int32)),
'identity_float32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.float32)), 'identity_float32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.float32)),
@ -172,6 +172,13 @@ def _reduction_identity(op_name: str, input: Tensor):
return torch.tensor(torch.inf, dtype=dtype, device=device) return torch.tensor(torch.inf, dtype=dtype, device=device)
elif torch.is_signed(input) or dtype == torch.uint8: elif torch.is_signed(input) or dtype == torch.uint8:
return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device) return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
elif op_name == 'mean':
# Strictly speaking, the identity value of the mean operation
# is the mean of the input. Since the mean value depends on
# the dim argument and it may be a non-scalar tensor, we
# consider the identity value of the mean operation ambiguous.
# Moreover, the mean value of empty input is undefined.
return None
raise NotImplementedError(f'identity of {op_name} on {dtype} input') raise NotImplementedError(f'identity of {op_name} on {dtype} input')
@ -332,3 +339,36 @@ def amin(input: Tensor,
return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype) return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
else: else:
raise ValueError(f'masked amin expects strided tensor (got {input.layout} tensor)') raise ValueError(f'masked amin expects strided tensor (got {input.layout} tensor)')
@_apply_docstring_templates
def mean(input: Tensor,
dim: DimOrDims = None,
*,
keepdim: Optional[bool] = False,
dtype: Optional[DType] = None,
mask: Optional[Tensor] = None) -> Tensor:
"""\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
if dtype is None:
dtype = input.dtype
if input.layout == torch.strided:
inmask = _input_mask(input, mask=mask)
count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=keepdim, mask=inmask)
total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
return total / count
else:
raise ValueError(f'masked sum expects strided tensor (got {input.layout} tensor)')

View File

@ -11177,6 +11177,24 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_masked_reduction, sample_inputs_func=sample_inputs_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation gradcheck_wrapper=gradcheck_wrapper_masked_operation
), ),
ReductionOpInfo(
'_masked.mean',
ref=reference_reduction_numpy(np.mean) if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
method_variant=None,
nan_policy='propagate',
supports_out=False,
promotes_int_to_float=True,
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
skips=(
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
# RuntimeError: undefined value tensor
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
),
sample_inputs_func=sample_inputs_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation
),
OpInfo( OpInfo(
"nn.functional.nll_loss", "nn.functional.nll_loss",
ref=_NOTHING, ref=_NOTHING,