Strided masked normalize. (#68694)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68694

Test Plan: Imported from OSS

Reviewed By: samdow

Differential Revision: D32724552

Pulled By: cpuhrsch

fbshipit-source-id: 82f579a86b0b265e0b9b3715a8a327b775dd55e1
Author: Pearu Peterson
Date: 2021-12-01 10:43:55 -08:00
Committed by: Facebook GitHub Bot
Parent: 23633bdb5c
Commit: 1842364b30
3 changed files with 87 additions and 19 deletions
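The new op, torch._masked.normalize(input, ord, dim, *, eps=1e-12, dtype=None, mask=None), L^p-normalizes a strided tensor along dim while ignoring masked-out elements. A minimal usage sketch based on the implementation in the diff below; the tensor values are illustrative, not from the commit:

import torch

x = torch.tensor([[-3., -2., -1.], [0., 1., 2.]])
mask = torch.tensor([[True, False, True], [False, False, False]])

# L2-normalize each row over its unmasked elements only.
out = torch._masked.normalize(x, 2.0, 1, mask=mask)
# Row 0: unmasked elements are -3 and -1, so the norm is sqrt(10) and the
# result is [-3/sqrt(10), 0, -1/sqrt(10)]; masked slots come back as zeros.
# Row 1: fully masked, so the norm reduces to 0, the denominator is clamped
# to eps, and the whole row stays zero.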


@@ -126,21 +126,26 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
     return output
 
-def apply_masked_normalization_along_dim(op, x, dim, dtype=None, mask=None):
+def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
     """Applies normalization op along given dimension to strided x
     elements that are valid according to mask tensor.
     """
-    if x.ndim == 0:  # scalar input
-        return op(x, dim, dtype=dtype)
-    y = torch.zeros_like(x, dtype=dtype)
-    inpmask = torch._masked._input_mask(x, mask=mask)
-    dim_ = dim % x.ndim
-    left_ranges = tuple(map(range, x.shape[:dim_]))
-    right_ranges = tuple(map(range, x.shape[dim_ + 1:]))
+    mask = kwargs.pop('mask', None)
+    dim_pos = kwargs.pop('dim_position', 0)
+    if input.ndim == 0:  # scalar input
+        return op(input, *args, **kwargs)
+    dtype = kwargs.get('dtype', input.dtype)
+    dim = args[dim_pos]
+    args0 = args[:dim_pos] + (0,) + args[dim_pos + 1:]
+    output = torch.zeros_like(input, dtype=dtype)
+    inpmask = torch._masked._input_mask(input, mask=mask)
+    dim_ = dim % input.ndim
+    left_ranges = tuple(map(range, input.shape[:dim_]))
+    right_ranges = tuple(map(range, input.shape[dim_ + 1:]))
     for s in itertools.product(*(left_ranges + ((slice(None),),) + right_ranges)):
         indices = inpmask[s].argwhere()
-        y[s][indices] = op(x[s][indices], 0, dtype=dtype)
-    return y
+        output[s][indices] = op(input[s][indices], *args0, **kwargs)
+    return output
@@ -148,6 +153,8 @@ reference_functions = dict(
     softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.softmax, *args, **kwargs),
     log_softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.log_softmax, *args, **kwargs),
     softmin=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.nn.functional.softmin, *args, **kwargs),
+    normalize=lambda *args, **kwargs: apply_masked_normalization_along_dim(
+        torch.nn.functional.normalize, *args, **dict(kwargs, dim_position=1)),
 )
 
 masked_ops = [op for op in op_db if op.name.startswith('_masked.')]
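For illustration only (assuming the helper above is in scope): the dim_position hook is needed because torch.nn.functional.normalize takes (input, p, dim, ...), so dim sits at positional index 1 of *args, while the softmax-style ops take dim as their first extra argument. A sketch of how the reference lines up with the masked op:

import torch

x = torch.randn(3, 4)
mask = torch.rand(3, 4) > 0.5

# Reference: normalize each row over its unmasked elements only.
expected = apply_masked_normalization_along_dim(
    torch.nn.functional.normalize, x, 2.0, 1, dim_position=1, mask=mask)
# Op under test.
actual = torch._masked.normalize(x, 2.0, 1, mask=mask)
torch.testing.assert_close(actual, expected)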


@@ -144,6 +144,7 @@ Example::
     softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
     log_softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
     softmin=(('dim__as_int',), ('dtype=None', 'mask=None')),
+    normalize=(('ord__required', 'dim__as_int',), ('eps=1e-12', 'dtype=None', 'mask=None')),
 )
 
 argument_declarations = dict(
@@ -154,12 +155,15 @@ dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
 dim (int): the dimension along which {operation name} is computed.''',
     ord='''\
 ord (int, float, optional): the order of vector norm. Default: 2.
-See :func:`torch.linalg.vector_norm` for a list of supported norms.
-''',
+See :func:`torch.linalg.vector_norm` for a list of supported norms.''',
+    ord__required='''\
+ord (int, float): the order of vector norm. Default: 2.
+See :func:`torch.linalg.vector_norm` for a list of supported norms.''',
+    eps='''\
+eps (float, optional): small value to avoid division by zero. Default: {default}.''',
     keepdim='''\
 keepdim (bool, optional): whether the output tensor has
-:attr:`dim` retained or not. Default: {default}.
-''',
+:attr:`dim` retained or not. Default: {default}.''',
     dtype='''\
 dtype (:class:`torch.dtype`, optional): the desired data type
 of returned tensor. If specified, the input tensor is
@@ -183,7 +187,11 @@ defined as ``log(exp(x[i])/sum(exp(x)))``.''',
     softmin='''\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
-defined as ``exp(-x[i])/sum(exp(-x))``.''')
+defined as ``exp(-x[i])/sum(exp(-x))``.''',
+    normalize='''\
+Let ``x`` be a sequence of unmasked elements of one-dimensional slice
+of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
+defined as ``x[i]/max(norm(x, p), eps)``.''')
 
 reduction_names = dict(
     sum='sum',
@@ -196,7 +204,8 @@ defined as ``exp(-x[i])/sum(exp(-x))``.''')
 normalization_names = dict(
     softmax='softmax',
     log_softmax='log_softmax',
-    softmin='softmin')
+    softmin='softmin',
+    normalize='normalize')
 
 operation_names = dict()
 operation_names.update(reduction_names)
@@ -207,7 +216,7 @@ defined as ``exp(-x[i])/sum(exp(-x))``.''')
     example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
     example_mask = torch.tensor([[True, False, True], [False, False, False]])
     example_args: Tuple[Any, ...]
-    if func.__name__ == 'norm':
+    if func.__name__ in {'norm', 'normalize'}:
         example_args = (2.0, example_dim)
         example_input = example_input.to(dtype=torch.float32)
     else:
@@ -375,11 +384,11 @@ def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
     """
     if callable(op):
         is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'mean', 'norm'}
-        is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin'}
+        is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize'}
         if is_reduction:
             if op.__name__ == 'norm':
                 if args:
-                    args = args[1:]  # lstrip p argument
+                    args = args[1:]  # lstrip ord argument
             dim = args[0] if args else kwargs.get('dim')
             outmask = _input_mask(input, *args, **kwargs)
             keepdim = kwargs.get('keepdim', False)
@@ -618,3 +627,27 @@ def softmin(input: Tensor,
         return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
     else:
         raise ValueError(f'masked softmin expects strided tensor (got {input.layout} tensor)')
+
+
+@_apply_docstring_templates
+def normalize(input: Tensor,
+              ord: float,
+              dim: int,
+              *,
+              eps: float = 1e-12,
+              dtype: Optional[DType] = None,
+              mask: Optional[Tensor] = None) -> Tensor:
+    if dtype is None:
+        dtype = input.dtype
+    dim_ = _canonical_dim(dim, input.ndim)[0]
+    if input.layout == torch.strided:
+        nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
+        # TODO: replace torch.maximum with masked maximum when available.
+        denom = torch.maximum(nrm_, nrm_.new_full([], eps))
+        # TODO: eliminate mask_input as unnecessary when using masked divide.
+        inmask = _input_mask(input, mask=mask)
+        mask_input = input if mask is None else torch.where(inmask, input, input.new_zeros([]))
+        # TODO: replace torch.divide with masked divide when available.
+        return torch.divide(mask_input, denom)
+    else:
+        raise ValueError(f'masked normalize expects strided tensor (got {input.layout} tensor)')
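A quick sanity check of the semantics above, x[i] / max(norm(x, ord), eps) over the unmasked elements with masked-out positions zeroed before the division; a hand-computed sketch with illustrative values:

import torch

x = torch.tensor([3.0, -4.0, 12.0])
mask = torch.tensor([True, True, False])

out = torch._masked.normalize(x, 2.0, 0, mask=mask)
# Only {3, -4} enter the L2 norm: sqrt(9 + 16) = 5, so the result is
# [3/5, -4/5, 0] == [0.6, -0.8, 0.0]; the masked slot stays zero because the
# input is zeroed there before torch.divide.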


@@ -1006,6 +1006,7 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
     return inputs
 
+
 def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
     """Sample inputs for masked norm.
     """
@@ -5822,6 +5823,17 @@ def sample_inputs_masked_softmax(op_info, device, dtype, requires_grad, with_dtype=False, **kwargs):
     return inputs
 
+def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked normalize.
+    """
+    inputs: List[SampleInput] = []
+    for ord in [2.0, 1, float('inf'), float('-inf'), 0]:
+        for sample_input in sample_inputs_softmax_variant(op_info, device, dtype, requires_grad, **kwargs):
+            sample_input_args, sample_input_kwargs = (ord,) + sample_input.args, sample_input.kwargs.copy()
+            inputs.append(SampleInput(sample_input.input.detach().clone().requires_grad_(requires_grad),
+                                      args=sample_input_args, kwargs=sample_input_kwargs))
+    return inputs
+
 def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs):
     low, high = op_info.domain
@@ -13785,6 +13797,22 @@ op_db: List[OpInfo] = [
         ),
         gradcheck_wrapper=gradcheck_wrapper_masked_operation,
         supports_out=False),
+    OpInfo(
+        '_masked.normalize',
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_normalize,
+        skips=(
+            # torch.jit.frontend.NotSupportedError: Compiled
+            # functions can't take variable number of arguments or
+            # use keyword-only arguments with defaults
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+            # RuntimeError: "clamp_min_cpu" not implemented for 'Half'
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMasked', 'test_reference_masked',
+                         device_type='cpu', dtypes=[torch.half]),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_out=False),
     OpInfo(
         "nn.functional.ctc_loss",
         ref=_NOTHING,