Revert "Forward / backward NJT support for several activation functions (#140736)"

This reverts commit daaecb96d6.

Reverted https://github.com/pytorch/pytorch/pull/140736 on behalf of https://github.com/malfet due to take 2 of the stack revert: this change's tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498479702))
PyTorch MergeBot 2024-11-25 16:27:00 +00:00
parent e0f9ec4a25
commit cffeb83f15
3 changed files with 13 additions and 63 deletions


@@ -5018,7 +5018,7 @@
 - func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   device_check: NoCheck # TensorIterator
-  tags: [pointwise, nondeterministic_seeded]
+  tags: nondeterministic_seeded

 - func: rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   tags: nondeterministic_seeded
@@ -5055,7 +5055,6 @@
 - func: relu6(Tensor self) -> Tensor
   python_module: nn
-  tags: pointwise

 - func: relu6_(Tensor(a!) self) -> Tensor(a!)
   python_module: nn
@@ -5140,7 +5139,6 @@
   structured_delegate: hardshrink.out
   device_check: NoCheck # TensorIterator
   variants: function, method
-  tags: pointwise

 - func: hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True
@@ -5205,7 +5203,6 @@
 - func: selu(Tensor self) -> Tensor
   device_check: NoCheck # TensorIterator
-  tags: pointwise

 - func: selu_(Tensor(a!) self) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
@@ -5214,7 +5211,6 @@
   device_check: NoCheck # TensorIterator
   dispatch:
     CompositeExplicitAutograd: celu
-  tags: pointwise

 - func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
@@ -5265,7 +5261,6 @@
 - func: mish(Tensor self) -> Tensor
   structured_delegate: mish.out
   python_module: nn
-  tags: pointwise

 - func: mish_(Tensor(a!) self) -> Tensor(a!)
   structured_delegate: mish.out
@@ -6066,7 +6061,6 @@
   structured_delegate: threshold.out
   dispatch:
     QuantizedCPU: threshold_quantized_cpu
-  tags: pointwise

 - func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
@@ -11787,7 +11781,6 @@
   structured_delegate: elu.out
   device_check: NoCheck # TensorIterator
   python_module: nn
-  tags: pointwise

 - func: elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True
@@ -11861,7 +11854,6 @@
   python_module: nn
   dispatch:
     QuantizedCPU: hardsigmoid_quantized_cpu
-  tags: pointwise

 - func: hardsigmoid_(Tensor(a!) self) -> Tensor(a!)
   structured_delegate: hardsigmoid.out
@@ -11893,7 +11885,7 @@
   dispatch:
     CPU, CUDA, MPS: hardtanh
     QuantizedCPU: hardtanh_quantized_cpu
-  tags: [pointwise, core]
+  tags: core

 - func: hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!)
   python_module: nn
@@ -12057,7 +12049,6 @@
   structured_delegate: softplus.out
   device_check: NoCheck # TensorIterator
   python_module: nn
-  tags: pointwise

 - func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True
@@ -12084,7 +12075,6 @@
   structured_delegate: softshrink.out
   device_check: NoCheck # TensorIterator
   python_module: nn
-  tags: pointwise

 - func: softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True
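The hunks above all do the same thing: they drop the `pointwise` tag (or shrink a combined tag list back down) that #140736 had added to these activation entries, which is what toggles direct jagged-layout NestedTensor (NJT) support for them off again. As a rough illustration of what pointwise semantics mean for a jagged tensor, here is a minimal sketch using only the public nested-tensor API; the shapes and the choice of `hardtanh` are arbitrary, and whether the direct call works depends on the installed PyTorch build:

```python
import torch
import torch.nn.functional as F

# A jagged nested tensor: two components with different lengths.
nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)],
    layout=torch.jagged,
)

# A pointwise op never mixes values across components, so applying it to
# the whole NJT is equivalent to applying it per component. This unbind()
# fallback is always available.
out = torch.nested.nested_tensor(
    [F.hardtanh(t) for t in nt.unbind()],
    layout=torch.jagged,
)

# Direct dispatch, i.e. F.hardtanh(nt), only works on builds where an NJT
# kernel for the op is registered; that is what this revert removes for
# the ops listed above.
```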


@@ -7875,10 +7875,19 @@ FORWARD_SKIPS_AND_XFAILS = [
         op_match_fn=lambda device, op: op.full_name
         in {
             # unary
-            # needs log_sigmoid_forward, which returns a tuple
+            "nn.functional.celu",
+            "nn.functional.elu",
+            "nn.functional.hardshrink",
+            "nn.functional.hardsigmoid",
+            "nn.functional.hardtanh",
             "nn.functional.logsigmoid",
-            # needs rrelu_with_noise
+            "nn.functional.mish",
+            "nn.functional.relu6",
             "nn.functional.rrelu",
+            "nn.functional.selu",
+            "nn.functional.softplus",
+            "nn.functional.softshrink",
+            "nn.functional.threshold",
             # binary
             "__rsub__",
             "complex",
@@ -8269,20 +8278,6 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
         name="crazy_aot_autograd_bug2",
     ),
-    # Bug: Something is wrongly creating an empty tensor with the jagged layout on the C++ side
-    # for these activation ops
-    XFailRule(
-        error_type=torch._dynamo.exc.Unsupported,
-        error_msg="non-strided meta tensors not supported yet",
-        op_match_fn=lambda device, op: (
-            op.full_name
-            in {
-                "nn.functional.hardshrink",
-                "nn.functional.softshrink",
-            }
-        ),
-        name="empty_with_jagged_layout_activation",
-    ),
 ]

 COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
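On the test side, `FORWARD_SKIPS_AND_XFAILS` regains the activation ops that #140736 had made runnable, and the compile-specific `XFailRule` for hardshrink/softshrink is dropped because those ops no longer reach the compile path at all. The rule objects pair an expected error with matcher predicates over (device, op) and (device, sample). Below is a simplified stand-in for that structure, not the test suite's actual `XFailRule` implementation, whose fields and defaults may differ:

```python
from dataclasses import dataclass
from typing import Callable, Optional, Type


# Simplified stand-in: a rule matches an (op, sample) pair and, on a
# match, the test is expected to fail with the given error type/message.
@dataclass
class XFailRuleSketch:
    error_type: Type[Exception]
    error_msg: str
    name: str
    op_match_fn: Optional[Callable] = None
    sample_match_fn: Optional[Callable] = None

    def matches(self, device, op, sample) -> bool:
        op_ok = self.op_match_fn is None or self.op_match_fn(device, op)
        sample_ok = self.sample_match_fn is None or self.sample_match_fn(
            device, sample
        )
        return op_ok and sample_ok
```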


@@ -2224,41 +2224,6 @@ def new_empty_default(func, *args, **kwargs):
     raise RuntimeError("new_empty() not supported for NJT with shape != ()")


-@register_jagged_func(
-    [
-        torch.ops.aten.elu_backward.default,
-        torch.ops.aten.hardshrink_backward.default,
-        torch.ops.aten.hardsigmoid_backward.default,
-        torch.ops.aten.hardtanh_backward.default,
-        torch.ops.aten.softplus_backward.default,
-        torch.ops.aten.softshrink_backward.default,
-    ],
-    "self: jt_all, ...",
-)
-def activation_backward(func, *args, **kwargs):
-    # first NJT arg is expected to be grad_output
-    grad_output = next(arg for arg in args if isinstance(arg, NestedTensor))
-    return NestedTensor(
-        func(
-            *(arg._values if isinstance(arg, NestedTensor) else arg for arg in args),
-            **kwargs,
-        ),
-        **extract_kwargs(grad_output),
-    )
-
-
-@register_jagged_func(torch.ops.aten.fill_.Scalar, "self: jt_all, value: any")
-def fill__Scalar(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(  # type: ignore[misc]
-        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
-    )
-
-    inp = new_kwargs.pop("input")
-
-    func(inp._values, **new_kwargs)
-    return inp
-
-
 from torch._higher_order_ops.flex_attention import (
     flex_attention as flex_attention_hop,
     flex_attention_backward as flex_attention_backward_hop,
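The removed `activation_backward` wrapper followed the usual jagged pattern visible in the hunk: unwrap `_values` from every NestedTensor argument, call the dense backward kernel, and rewrap the result with the grad_output's nested metadata. From the user's side, what the revert takes away is autograd through these activations on jagged NJTs. A hedged end-to-end check is sketched below; `elu` is chosen arbitrarily, and whether it succeeds depends entirely on the installed PyTorch build:

```python
import torch
import torch.nn.functional as F

# A small jagged NJT that requires grad.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)],
    layout=torch.jagged,
    requires_grad=True,
)

try:
    # Forward through the activation, then reduce via unbind() so the
    # backward pass exercises the per-op NJT backward kernel.
    out = F.elu(nt)
    loss = sum(comp.sum() for comp in out.unbind())
    loss.backward()
    print("elu forward/backward on NJT supported in this build")
except (RuntimeError, NotImplementedError) as err:
    # Expected on builds where this revert is in effect.
    print(f"elu on NJT not supported here: {err}")
```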