Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "Forward / backward NJT support for several activation functions (#140736)"
This reverts commit daaecb96d6.
Reverted https://github.com/pytorch/pytorch/pull/140736 on behalf of https://github.com/malfet due to: take 2 of the stack revert; your change's tests are failing in trunk ([comment](https://github.com/pytorch/pytorch/pull/140736#issuecomment-2498479702))
This commit is contained in:
parent e0f9ec4a25
commit cffeb83f15
@@ -5018,7 +5018,7 @@
 - func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
   device_check: NoCheck # TensorIterator
-  tags: [pointwise, nondeterministic_seeded]
+  tags: nondeterministic_seeded
 
 - func: rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
   tags: nondeterministic_seeded

@@ -5055,7 +5055,6 @@
 - func: relu6(Tensor self) -> Tensor
   python_module: nn
-  tags: pointwise
 
 - func: relu6_(Tensor(a!) self) -> Tensor(a!)
   python_module: nn

@@ -5140,7 +5139,6 @@
   structured_delegate: hardshrink.out
   device_check: NoCheck # TensorIterator
   variants: function, method
-  tags: pointwise
 
 - func: hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True

@@ -5205,7 +5203,6 @@
 - func: selu(Tensor self) -> Tensor
   device_check: NoCheck # TensorIterator
-  tags: pointwise
 
 - func: selu_(Tensor(a!) self) -> Tensor(a!)
   device_check: NoCheck # TensorIterator

@@ -5214,7 +5211,6 @@
   device_check: NoCheck # TensorIterator
   dispatch:
     CompositeExplicitAutograd: celu
-  tags: pointwise
 
 - func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
   device_check: NoCheck # TensorIterator

@@ -5265,7 +5261,6 @@
 - func: mish(Tensor self) -> Tensor
   structured_delegate: mish.out
   python_module: nn
-  tags: pointwise
 
 - func: mish_(Tensor(a!) self) -> Tensor(a!)
   structured_delegate: mish.out

@@ -6066,7 +6061,6 @@
   structured_delegate: threshold.out
   dispatch:
     QuantizedCPU: threshold_quantized_cpu
-  tags: pointwise
 
 - func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)
   device_check: NoCheck # TensorIterator

@@ -11787,7 +11781,6 @@
   structured_delegate: elu.out
   device_check: NoCheck # TensorIterator
   python_module: nn
-  tags: pointwise
 
 - func: elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True

@@ -11861,7 +11854,6 @@
   python_module: nn
   dispatch:
     QuantizedCPU: hardsigmoid_quantized_cpu
-  tags: pointwise
 
 - func: hardsigmoid_(Tensor(a!) self) -> Tensor(a!)
   structured_delegate: hardsigmoid.out

@@ -11893,7 +11885,7 @@
   dispatch:
     CPU, CUDA, MPS: hardtanh
     QuantizedCPU: hardtanh_quantized_cpu
-  tags: [pointwise, core]
+  tags: core
 
 - func: hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!)
   python_module: nn

@@ -12057,7 +12049,6 @@
   structured_delegate: softplus.out
   device_check: NoCheck # TensorIterator
   python_module: nn
-  tags: pointwise
 
 - func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True

@@ -12084,7 +12075,6 @@
   structured_delegate: softshrink.out
   device_check: NoCheck # TensorIterator
   python_module: nn
-  tags: pointwise
 
 - func: softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
   structured: True
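Background for the schema hunks above (these entries live in PyTorch's native_functions.yaml): the pointwise tag is what the jagged NestedTensor (NJT) layout keys on for its generic forward path, since a pointwise op can be applied to the flat values buffer and rewrapped with the original offsets; dropping the tag, as this revert does, also drops that generic path for these activations. The sketch below illustrates the idea using only public nested-tensor APIs; apply_pointwise is an illustrative helper, not the internal dispatch code this commit touches.

import torch

# Build a jagged (NJT) nested tensor: two rows of lengths 2 and 3, feature dim 4.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)

def apply_pointwise(op, nt):
    # A pointwise op only needs the flat values; the offsets are reused unchanged.
    return torch.nested.nested_tensor_from_jagged(op(nt.values()), nt.offsets())

out = apply_pointwise(torch.nn.functional.relu6, nt)
print(out.is_nested, out.size(0))  # True 2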
@@ -7875,10 +7875,19 @@ FORWARD_SKIPS_AND_XFAILS = [
         op_match_fn=lambda device, op: op.full_name
         in {
             # unary
-            # needs log_sigmoid_forward, which returns a tuple
+            "nn.functional.celu",
+            "nn.functional.elu",
+            "nn.functional.hardshrink",
+            "nn.functional.hardsigmoid",
+            "nn.functional.hardtanh",
             "nn.functional.logsigmoid",
-            # needs rrelu_with_noise
+            "nn.functional.mish",
+            "nn.functional.relu6",
             "nn.functional.rrelu",
+            "nn.functional.selu",
+            "nn.functional.softplus",
+            "nn.functional.softshrink",
+            "nn.functional.threshold",
             # binary
             "__rsub__",
             "complex",
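The hunk above re-adds these activations to the forward skip/xfail list in the NJT test suite (presumably test/test_nestedtensor.py); entries are matched purely by op.full_name. Illustrative sketch only: FakeOp and the reduced name set below are stand-ins, not test-suite code, and just show how such an op_match_fn filter decides whether a forward test is expected to fail.

# Illustrative only: FakeOp and this reduced name set are stand-ins, not test-suite code.
UNSUPPORTED_FORWARD = {
    "nn.functional.relu6",
    "nn.functional.selu",
    "nn.functional.threshold",
}

def op_match_fn(device, op):
    # mirrors the lambda in the diff: match purely on the OpInfo-style full_name
    return op.full_name in UNSUPPORTED_FORWARD

class FakeOp:
    full_name = "nn.functional.relu6"

assert op_match_fn("cpu", FakeOp())          # would be marked as expected-to-fail
assert not op_match_fn("cpu", type("Op", (), {"full_name": "add"})())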
@@ -8269,20 +8278,6 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
         name="crazy_aot_autograd_bug2",
     ),
-    # Bug: Something is wrongly creating an empty tensor with the jagged layout on the C++ side
-    # for these activation ops
-    XFailRule(
-        error_type=torch._dynamo.exc.Unsupported,
-        error_msg="non-strided meta tensors not supported yet",
-        op_match_fn=lambda device, op: (
-            op.full_name
-            in {
-                "nn.functional.hardshrink",
-                "nn.functional.softshrink",
-            }
-        ),
-        name="empty_with_jagged_layout_activation",
-    ),
 ]
 
 COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
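For readers unfamiliar with the harness: XFailRule pairs an expected error (error_type, error_msg) with op and sample filters, and the rule removed above covered the hardshrink/softshrink compile failures. Below is a hedged sketch of how such a rule might be evaluated; XFailRuleSketch is a hypothetical stand-in, and the real class in the NJT tests may differ in fields and logic.

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class XFailRuleSketch:
    # Hypothetical stand-in for the test suite's XFailRule; field names follow
    # the call sites in the diff, but the real implementation may differ.
    error_type: type = Exception
    error_msg: str = ""
    op_match_fn: Optional[Callable] = None
    sample_match_fn: Optional[Callable] = None
    name: str = ""

    def applies(self, device, op, sample) -> bool:
        # A rule is in effect when every provided filter matches the test case.
        op_ok = self.op_match_fn is None or self.op_match_fn(device, op)
        sample_ok = self.sample_match_fn is None or self.sample_match_fn(device, sample)
        return op_ok and sample_ok

Presumably the suite walks these rules and marks a test as an expected failure with the given error_type/error_msg when one applies.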
@@ -2224,41 +2224,6 @@ def new_empty_default(func, *args, **kwargs):
     raise RuntimeError("new_empty() not supported for NJT with shape != ()")
 
 
-@register_jagged_func(
-    [
-        torch.ops.aten.elu_backward.default,
-        torch.ops.aten.hardshrink_backward.default,
-        torch.ops.aten.hardsigmoid_backward.default,
-        torch.ops.aten.hardtanh_backward.default,
-        torch.ops.aten.softplus_backward.default,
-        torch.ops.aten.softshrink_backward.default,
-    ],
-    "self: jt_all, ...",
-)
-def activation_backward(func, *args, **kwargs):
-    # first NJT arg is expected to be grad_output
-    grad_output = next(arg for arg in args if isinstance(arg, NestedTensor))
-    return NestedTensor(
-        func(
-            *(arg._values if isinstance(arg, NestedTensor) else arg for arg in args),
-            **kwargs,
-        ),
-        **extract_kwargs(grad_output),
-    )
-
-
-@register_jagged_func(torch.ops.aten.fill_.Scalar, "self: jt_all, value: any")
-def fill__Scalar(func, *args, **kwargs):
-    _, new_kwargs = normalize_function(  # type: ignore[misc]
-        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
-    )
-
-    inp = new_kwargs.pop("input")
-
-    func(inp._values, **new_kwargs)
-    return inp
-
-
 from torch._higher_order_ops.flex_attention import (
     flex_attention as flex_attention_hop,
     flex_attention_backward as flex_attention_backward_hop,
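The removed activation_backward helper (registered via register_jagged_func in the NJT jagged-op table, presumably torch/nested/_internal/ops.py) unwraps each NestedTensor argument to its flat _values, runs the dense backward kernel, and rewraps the result with grad_output's nested metadata. The sketch below reproduces that unwrap/rewrap pattern using public NJT APIs; the helper name and the direct aten call are illustrative, not the registered dispatch path this commit deletes.

import torch

def njt_activation_backward_sketch(backward_op, grad_output, inp, *extra_scalars):
    # Unwrap the jagged values, run the dense backward kernel, and rewrap the
    # result with grad_output's offsets so the nested structure is preserved.
    grad_values = backward_op(grad_output.values(), inp.values(), *extra_scalars)
    return torch.nested.nested_tensor_from_jagged(grad_values, grad_output.offsets())

nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
)
grad = torch.nested.nested_tensor(
    [torch.ones(2, 3), torch.ones(4, 3)], layout=torch.jagged
)
# hardshrink_backward(grad_out, self, lambd) is one of the ops the removed helper handled.
out = njt_activation_backward_sketch(
    torch.ops.aten.hardshrink_backward.default, grad, nt, 0.5
)
print(out.is_nested)  # True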