Fix NJT frexp() to handle both outputs (#144585)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. Before this PR, `frexp()` for NJT was handled via the unary pointwise fallback. However, the op returns a tuple, and the fallback doesn't handle that. This PR defines an explicit impl for `frexp()` that wraps both returned `(mantissa, exponent)` tensors as NJTs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144585
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583, #144584
parent 3ee531f8b9
commit b63b81410c
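A minimal usage sketch of the fixed behavior (mine, not from the PR), assuming a PyTorch build that includes this patch:

```python
import torch

# A jagged-layout nested tensor: two components with different lengths.
nt = torch.nested.nested_tensor(
    [torch.randn(3), torch.randn(5)], layout=torch.jagged
)

# frexp() decomposes each element x into mantissa * 2**exponent.
# With this fix, both outputs come back as jagged nested tensors
# instead of the fallback failing on the tuple return.
mantissa, exponent = torch.frexp(nt)
assert mantissa.is_nested and exponent.is_nested
```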
```diff
@@ -8150,14 +8150,6 @@ FORWARD_SKIPS_AND_XFAILS = [
         ),
         name="binary_noncontig_holes_broadcasting_1_over_ragged",
     ),
-    # Bug: this op returns a tuple of Tensors so it doesn't work with NJT's unary
-    # pointwise logic
-    XFailRule(
-        error_type=AttributeError,
-        error_msg="'tuple' object has no attribute 'device'",
-        op_match_fn=lambda device, op: op.full_name == "frexp",
-        name="frexp_tuple_return",
-    ),
     # Bug: fill doesn't work with NJTs at all for some reason
     XFailRule(
         error_type=TypeError,
```
```diff
@@ -2467,6 +2467,21 @@ def fill__Scalar(func, *args, **kwargs):
     return inp
 
 
+@register_jagged_func(torch.ops.aten.frexp.Tensor, "self: jt_all")
+def frexp_Tensor(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(  # type: ignore[misc]
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    output_kwargs = extract_kwargs(inp)
+
+    mantissa, exponent = func(inp._values)
+    return NestedTensor(mantissa, **output_kwargs), NestedTensor(
+        exponent, **output_kwargs
+    )
+
+
 from torch._higher_order_ops.flex_attention import (
     flex_attention as flex_attention_hop,
     flex_attention_backward as flex_attention_backward_hop,
```
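Both outputs reuse the same `output_kwargs` extracted from the input, which is sound because `frexp()` is elementwise and preserves the ragged structure. As a quick sanity check (my sketch, not part of the PR's test suite), the NJT outputs can be compared component-wise against dense `frexp()` via `unbind()`:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2), torch.randn(4)], layout=torch.jagged
)
mantissa, exponent = torch.frexp(nt)

# Each NJT component should match dense frexp() on the matching input.
for m, e, t in zip(mantissa.unbind(), exponent.unbind(), nt.unbind()):
    dense_m, dense_e = torch.frexp(t)
    torch.testing.assert_close(m, dense_m)
    torch.testing.assert_close(e, dense_e)
```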