diff --git a/torch/_masked/__init__.py b/torch/_masked/__init__.py
index 16f8300ab7c..0ac1259f4bd 100644
--- a/torch/_masked/__init__.py
+++ b/torch/_masked/__init__.py
@@ -25,7 +25,6 @@ def _apply_docstring_templates(func):
     """Decorator that applies docstring templates to function docstring
     and returns the function instance.
     """
-
     docstring_templates = dict(
         reduction_signature='''\
 {function_name}(input, dim, *, keepdim=False, dtype=None, mask=None) -> Tensor''',
@@ -106,7 +105,8 @@ and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_ui
             sum='sum',
             prod='product',
             amax='maximum',
-            amin='minimum')[func.__name__],
+            amin='minimum',
+            mean='mean')[func.__name__],
         'identity_uint8': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.uint8)),
         'identity_int32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.int32)),
         'identity_float32': _reduction_identity(func.__name__, torch.tensor(0, dtype=torch.float32)),
@@ -172,6 +172,13 @@ def _reduction_identity(op_name: str, input: Tensor):
             return torch.tensor(torch.inf, dtype=dtype, device=device)
         elif torch.is_signed(input) or dtype == torch.uint8:
             return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
+    elif op_name == 'mean':
+        # Strictly speaking, the identity value of the mean operation
+        # is the mean of the input. Since the mean value depends on
+        # the dim argument and may be a non-scalar tensor, we consider
+        # the identity value of the mean operation ambiguous.
+        # Moreover, the mean value of an empty input is undefined.
+        return None
     raise NotImplementedError(f'identity of {op_name} on {dtype} input')
@@ -332,3 +339,36 @@ def amin(input: Tensor,
         return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
     else:
         raise ValueError(f'masked amin expects strided tensor (got {input.layout} tensor)')
+
+
+@_apply_docstring_templates
+def mean(input: Tensor,
+         dim: DimOrDims = None,
+         *,
+         keepdim: Optional[bool] = False,
+         dtype: Optional[DType] = None,
+         mask: Optional[Tensor] = None) -> Tensor:
+    """\
+{reduction_signature}
+
+{reduction_descr}
+
+By definition, the identity value of a mean operation is the mean
+value of the tensor. If all elements of the input tensor along the
+given dimension(s) :attr:`dim` are masked-out, the identity value
+of the mean is undefined. Due to this ambiguity, the elements of
+the output tensor with strided layout that correspond to fully
+masked-out elements have ``nan`` values.
+
+{reduction_args}
+
+{reduction_example}"""
+    if dtype is None:
+        dtype = input.dtype
+    if input.layout == torch.strided:
+        inmask = _input_mask(input, mask=mask)
+        count = sum(inmask.new_ones(input.shape, dtype=torch.int64), dim, keepdim=keepdim, mask=inmask)
+        total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
+        return total / count
+    else:
+        raise ValueError(f'masked mean expects strided tensor (got {input.layout} tensor)')
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 2105db0c441..317d3396bea 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11177,6 +11177,24 @@ op_db: List[OpInfo] = [
         sample_inputs_func=sample_inputs_masked_reduction,
         gradcheck_wrapper=gradcheck_wrapper_masked_operation
     ),
+    ReductionOpInfo(
+        '_masked.mean',
+        ref=reference_reduction_numpy(np.mean) if np.lib.NumpyVersion(np.__version__) >= '1.20.0' else None,
+        method_variant=None,
+        nan_policy='propagate',
+        supports_out=False,
+        promotes_int_to_float=True,
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
+        skips=(
+            # FIXME: sum reduces all dimensions when dim=[]
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # RuntimeError: undefined value tensor
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+        ),
+        sample_inputs_func=sample_inputs_masked_reduction,
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation
+    ),
     OpInfo(
         "nn.functional.nll_loss",
         ref=_NOTHING,
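
Note on the `_reduction_identity` change above: unlike `sum` (identity 0) or `amax` (identity -inf), `mean` has no fixed element-wise identity, so the helper now returns `None` and the caller is responsible for the fully masked-out case. A minimal sketch of the resulting behavior, assuming the private `_reduction_identity` helper is importable from the prototype `torch._masked` module:

    import torch
    from torch._masked import _reduction_identity

    t = torch.zeros(3, dtype=torch.float32)
    print(_reduction_identity('sum', t))   # tensor(0.)   -- adding zero is a no-op
    print(_reduction_identity('amax', t))  # tensor(-inf) -- max against -inf is a no-op
    print(_reduction_identity('mean', t))  # None -- no element-wise identity exists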
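
End to end, the new `total / count` implementation yields `nan` exactly where a slice is fully masked out (0/0). A usage sketch, assuming a build where the prototype `torch._masked.mean` is available:

    import torch

    x = torch.tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])
    mask = torch.tensor([[True, False, True],
                         [False, False, False]])

    # Row 0 averages only the unmasked elements: (1 + 3) / 2 = 2.0.
    # Row 1 is fully masked out, so its result is nan, as documented.
    print(torch._masked.mean(x, dim=1, mask=mask))  # tensor([2., nan])

    # The same result computed by hand, mirroring the implementation:
    total = torch.where(mask, x, torch.zeros_like(x)).sum(dim=1)
    count = mask.sum(dim=1)
    print(total / count)  # tensor([2., nan])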