Fix NJT frexp() to handle both outputs (#144585)

Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

Before this PR, `frexp()` for NJT was handled via the unary pointwise fallback. The op returns a tuple, however, and the fallback doesn't handle that. This PR defines an explicit impl for `frexp()` that wraps both returned `(mantissa, exponent)` as NJTs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144585
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583, #144584
This commit is contained in:
Joel Schlosser 2025-01-17 16:53:46 -05:00 committed by PyTorch MergeBot
parent 3ee531f8b9
commit b63b81410c
2 changed files with 15 additions and 8 deletions

View File

@ -8150,14 +8150,6 @@ FORWARD_SKIPS_AND_XFAILS = [
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
# Bug: this op returns a tuple of Tensors so it doesn't work with NJT's unary
# pointwise logic
XFailRule(
error_type=AttributeError,
error_msg="'tuple' object has no attribute 'device'",
op_match_fn=lambda device, op: op.full_name == "frexp",
name="frexp_tuple_return",
),
# Bug: fill doesn't work with NJTs at all for some reason
XFailRule(
error_type=TypeError,

View File

@ -2467,6 +2467,21 @@ def fill__Scalar(func, *args, **kwargs):
return inp
@register_jagged_func(torch.ops.aten.frexp.Tensor, "self: jt_all")
def frexp_Tensor(func, *args, **kwargs):
    """Decompose an NJT into mantissa/exponent, wrapping both outputs as NJTs.

    frexp returns a tuple of two tensors, so the generic unary pointwise
    fallback (which assumes a single tensor output) cannot handle it; this
    impl applies the op to the underlying values buffer and rewraps each
    result with the input's nested-tensor metadata.
    """
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    njt = new_kwargs.pop("input")
    # Both outputs share the input's offsets/ragged structure.
    njt_kwargs = extract_kwargs(njt)
    mantissa, exponent = func(njt._values)
    return tuple(NestedTensor(values, **njt_kwargs) for values in (mantissa, exponent))
from torch._higher_order_ops.flex_attention import (
flex_attention as flex_attention_hop,
flex_attention_backward as flex_attention_backward_hop,