Strided masked normalize. (#68694)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68694
Test Plan: Imported from OSS
Reviewed By: samdow
Differential Revision: D32724552
Pulled By: cpuhrsch
fbshipit-source-id: 82f579a86b0b265e0b9b3715a8a327b775dd55e1
This commit is contained in: parent 23633bdb5c, commit 1842364b30

@@ -126,21 +126,26 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
     return output


-def apply_masked_normalization_along_dim(op, x, dim, dtype=None, mask=None):
+def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
     """Applies normalization op along given dimension to strided x
     elements that are valid according to mask tensor.
     """
-    if x.ndim == 0:  # scalar input
-        return op(x, dim, dtype=dtype)
-    y = torch.zeros_like(x, dtype=dtype)
-    inpmask = torch._masked._input_mask(x, mask=mask)
-    dim_ = dim % x.ndim
-    left_ranges = tuple(map(range, x.shape[:dim_]))
-    right_ranges = tuple(map(range, x.shape[dim_ + 1:]))
+    mask = kwargs.pop('mask', None)
+    dim_pos = kwargs.pop('dim_position', 0)
+    if input.ndim == 0:  # scalar input
+        return op(input, *args, **kwargs)
+    dtype = kwargs.get('dtype', input.dtype)
+    dim = args[dim_pos]
+    args0 = args[:dim_pos] + (0,) + args[dim_pos + 1:]
+    output = torch.zeros_like(input, dtype=dtype)
+    inpmask = torch._masked._input_mask(input, mask=mask)
+    dim_ = dim % input.ndim
+    left_ranges = tuple(map(range, input.shape[:dim_]))
+    right_ranges = tuple(map(range, input.shape[dim_ + 1:]))
     for s in itertools.product(*(left_ranges + ((slice(None),),) + right_ranges)):
         indices = inpmask[s].argwhere()
-        y[s][indices] = op(x[s][indices], 0, dtype=dtype)
-    return y
+        output[s][indices] = op(input[s][indices], *args0, **kwargs)
+    return output


 reference_functions = dict(

@@ -148,6 +153,8 @@ reference_functions = dict(
     softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.softmax, *args, **kwargs),
     log_softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.log_softmax, *args, **kwargs),
     softmin=lambda *args, **kwargs: apply_masked_normalization_along_dim(torch.nn.functional.softmin, *args, **kwargs),
+    normalize=lambda *args, **kwargs: apply_masked_normalization_along_dim(
+        torch.nn.functional.normalize, *args, **dict(kwargs, dim_position=1)),
 )

 masked_ops = [op for op in op_db if op.name.startswith('_masked.')]

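The reference helper above now takes the reduced dimension positionally and accepts a dim_position hint because torch.nn.functional.normalize, unlike torch.softmax, takes the norm order before dim. A minimal sketch with plain PyTorch calls (not part of the diff) illustrating why normalize is registered with dim_position=1:

import torch

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

# torch.softmax(input, dim, ...): dim is the first positional argument (position 0),
# so the default dim_position=0 applies.
s = torch.softmax(x, 1)

# torch.nn.functional.normalize(input, p, dim, ...): the norm order p comes first and
# dim sits at position 1, hence the dim_position=1 entry in reference_functions above.
n = torch.nn.functional.normalize(x, 2.0, 1)
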
@@ -144,6 +144,7 @@ Example::
     softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
     log_softmax=(('dim__as_int',), ('dtype=None', 'mask=None')),
     softmin=(('dim__as_int',), ('dtype=None', 'mask=None')),
+    normalize=(('ord__required', 'dim__as_int',), ('eps=1e-12', 'dtype=None', 'mask=None')),
 )

 argument_declarations = dict(

@@ -154,12 +155,15 @@ dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
 dim (int): the dimension along which {operation name} is computed.''',
     ord='''\
 ord (int, float, optional): the order of vector norm. Default: 2.
-See :func:`torch.linalg.vector_norm` for a list of supported norms.
-''',
+See :func:`torch.linalg.vector_norm` for a list of supported norms.''',
+    ord__required='''\
+ord (int, float): the order of vector norm. Default: 2.
+See :func:`torch.linalg.vector_norm` for a list of supported norms.''',
+    eps='''\
+eps (float, optional): small value to avoid division by zero. Default: {default}.''',
     keepdim='''\
 keepdim (bool, optional): whether the output tensor has
-:attr:`dim` retained or not. Default: {default}.
-''',
+:attr:`dim` retained or not. Default: {default}.''',
     dtype='''\
 dtype (:class:`torch.dtype`, optional): the desired data type
   of returned tensor. If specified, the input tensor is

@@ -183,7 +187,11 @@ defined as ``log(exp(x[i])/sum(exp(x)))``.''',
     softmin='''\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
-defined as ``exp(-x[i])/sum(exp(-x))``.''')
+defined as ``exp(-x[i])/sum(exp(-x))``.''',
+    normalize='''\
+Let ``x`` be a sequence of unmasked elements of one-dimensional slice
+of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
+defined as ``x[i]/max(norm(x, p), eps)``.''')

 reduction_names = dict(
     sum='sum',

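The docstring formula added above, ``x[i]/max(norm(x, p), eps)``, only involves the unmasked elements of a slice. A small hand-worked sketch (assumed example values, plain PyTorch only, not part of the diff):

import torch

x = torch.tensor([3., 100., 4.])
mask = torch.tensor([True, False, True])   # element 1 is masked out
eps = 1e-12

unmasked = x[mask]                                   # tensor([3., 4.])
nrm = torch.linalg.vector_norm(unmasked, 2)          # 5.0 for p=2
normalized = unmasked / torch.clamp(nrm, min=eps)    # tensor([0.6000, 0.8000])

The eps clamp only matters when the participating elements are all zero (or the slice is fully masked); it keeps the division well defined.
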
@@ -196,7 +204,8 @@ defined as ``exp(-x[i])/sum(exp(-x))``.''')
 normalization_names = dict(
     softmax='softmax',
     log_softmax='log_softmax',
-    softmin='softmin')
+    softmin='softmin',
+    normalize='normalize')

 operation_names = dict()
 operation_names.update(reduction_names)

@@ -207,7 +216,7 @@ defined as ``exp(-x[i])/sum(exp(-x))``.''')
     example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
     example_mask = torch.tensor([[True, False, True], [False, False, False]])
     example_args: Tuple[Any, ...]
-    if func.__name__ == 'norm':
+    if func.__name__ in {'norm', 'normalize'}:
         example_args = (2.0, example_dim)
         example_input = example_input.to(dtype=torch.float32)
     else:

@@ -375,11 +384,11 @@ def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
     """
     if callable(op):
         is_reduction = op.__name__ in {'sum', 'prod', 'amax', 'amin', 'mean', 'norm'}
-        is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin'}
+        is_normalization = op.__name__ in {'softmax', 'log_softmax', 'softmin', 'normalize'}
         if is_reduction:
             if op.__name__ == 'norm':
                 if args:
-                    args = args[1:]  # lstrip p argument
+                    args = args[1:]  # lstrip ord argument
             dim = args[0] if args else kwargs.get('dim')
             outmask = _input_mask(input, *args, **kwargs)
             keepdim = kwargs.get('keepdim', False)

@@ -618,3 +627,27 @@ def softmin(input: Tensor,
         return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
     else:
         raise ValueError(f'masked softmin expects strided tensor (got {input.layout} tensor)')
+
+
+@_apply_docstring_templates
+def normalize(input: Tensor,
+              ord: float,
+              dim: int,
+              *,
+              eps: float = 1e-12,
+              dtype: Optional[DType] = None,
+              mask: Optional[Tensor] = None) -> Tensor:
+    if dtype is None:
+        dtype = input.dtype
+    dim_ = _canonical_dim(dim, input.ndim)[0]
+    if input.layout == torch.strided:
+        nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
+        # TODO: replace torch.maximum with masked maximum when available.
+        denom = torch.maximum(nrm_, nrm_.new_full([], eps))
+        # TODO: eliminate mask_input as unnecessary when using masked divide.
+        inmask = _input_mask(input, mask=mask)
+        mask_input = input if mask is None else torch.where(inmask, input, input.new_zeros([]))
+        # TODO: replace torch.divide with masked divide when available.
+        return torch.divide(mask_input, denom)
+    else:
+        raise ValueError(f'masked normalize expects strided tensor (got {input.layout} tensor)')

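A minimal usage sketch of the function added above, assuming the private torch._masked API with the signature introduced in this PR (positional ord and dim, keyword-only eps/dtype/mask); the expected values in the comments are hand-computed approximations, not output from the diff itself:

import torch

x = torch.tensor([[-3., -2., -1.], [0., 1., 2.]])
mask = torch.tensor([[True, False, True], [False, False, False]])

# ord=2.0, dim=1: masked-out elements are zeroed and each row is divided by the
# 2-norm of its unmasked elements, clamped from below by eps.
out = torch._masked.normalize(x, 2.0, 1, mask=mask)
# First row: unmasked elements are -3 and -1, norm = sqrt(10), giving roughly
# [-0.95, 0.00, -0.32]; the fully masked second row comes out as all zeros.
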
@@ -1006,6 +1006,7 @@ def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwar
     return inputs


 def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
     """Sample inputs for masked norm.
     """

@@ -5822,6 +5823,17 @@ def sample_inputs_masked_softmax(op_info, device, dtype, requires_grad, with_dty
     return inputs


+def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for masked normalize.
+    """
+    inputs: List[SampleInput] = []
+    for ord in [2.0, 1, float('inf'), float('-inf'), 0]:
+        for sample_input in sample_inputs_softmax_variant(op_info, device, dtype, requires_grad, **kwargs):
+            sample_input_args, sample_input_kwargs = (ord,) + sample_input.args, sample_input.kwargs.copy()
+            inputs.append(SampleInput(sample_input.input.detach().clone().requires_grad_(requires_grad),
+                                      args=sample_input_args, kwargs=sample_input_kwargs))
+    return inputs
+

 def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs):
     low, high = op_info.domain

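The sample-input generator added above reuses the masked softmax samples and simply prepends each tested norm order to their positional arguments. A rough, self-contained illustration of the resulting argument layout (the concrete values here are made up, not taken from the diff):

# A softmax-style sample typically carries args == (dim,); the normalize sample
# prepends the norm order so that args matches normalize(input, ord, dim, ...).
dim = 1
for ord in [2.0, 1, float('inf'), float('-inf'), 0]:
    normalize_args = (ord,) + (dim,)   # e.g. (2.0, 1), i.e. (ord, dim)
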
@@ -13785,6 +13797,22 @@ op_db: List[OpInfo] = [
         ),
         gradcheck_wrapper=gradcheck_wrapper_masked_operation,
         supports_out=False),
+    OpInfo(
+        '_masked.normalize',
+        method_variant=None,
+        dtypes=floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_masked_normalize,
+        skips=(
+            # torch.jit.frontend.NotSupportedError: Compiled
+            # functions can't take variable number of arguments or
+            # use keyword-only arguments with defaults
+            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+            # RuntimeError: "clamp_min_cpu" not implemented for 'Half'
+            DecorateInfo(unittest.skip("Skipped!"), 'TestMasked', 'test_reference_masked',
+                         device_type='cpu', dtypes=[torch.half]),
+        ),
+        gradcheck_wrapper=gradcheck_wrapper_masked_operation,
+        supports_out=False),
     OpInfo(
         "nn.functional.ctc_loss",
         ref=_NOTHING,