Get tensor subclasses and torch.library.triton_op to dispatch correctly (#160341)

Short-term fix for https://github.com/pytorch/pytorch/issues/160333

The problem is:
1) `triton_op` adds a decomposition for FunctionalTensorMode for this operation.
2) Tensor subclasses rely on FunctionalTensorMode's `__torch_dispatch__` returning NotImplemented, which is what hands the op over to the subclass (see the sketch after this list).
3) `triton_op`'s FunctionalTensorMode decomposition takes precedence over FunctionalTensorMode's `__torch_dispatch__`, so NotImplemented is never returned and the subclass never gets a chance to handle the op.
The easy fix is to copy-paste FunctionalTensorMode's NotImplemented-return
logic into the decomposition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160341
Approved by: https://github.com/drisspg
Author: rzou, 2025-08-11 17:57:31 -07:00 (committed by PyTorch MergeBot)
Parent: 32e5e2f596
Commit: 10bc36fe84
2 changed files with 51 additions and 0 deletions

@@ -3583,6 +3583,40 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
        self.assertNotIn(libname, code)
        self.assertNotIn(opname, code)

    @requires_gpu
    def test_subclass(self):
        libname = "my_cool_namespace"
        opname = "my_triton_operator"

        @torch.library.triton_op(f"{libname}::{opname}", mutates_args={})
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
            return output

        def f(x, y):
            return add(x, y)

        x0 = torch.randn(3, device=GPU_TYPE)
        y0 = torch.randn(3, device=GPU_TYPE)
        x1 = torch.randn(3, device=GPU_TYPE)
        y1 = torch.randn(3, device=GPU_TYPE)

        from torch.testing._internal.two_tensor import TwoTensor

        x = TwoTensor(x0, x1)
        y = TwoTensor(y0, y1)
        out = torch.compile(f, fullgraph=True)(x, y)
        expected = f(x, y)
        self.assertEqual(out.a, expected.a)
        self.assertEqual(out.b, expected.b)

    @requires_gpu
    @dynamo_config.patch("recompile_limit", 1)
    def test_triton_dynamic_grid_no_recompile(self):
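
Note for readers of the test: `add_kernel` and `capture_triton` are imported from the test file's shared Triton helpers and are not part of this diff. Under that assumption, a typical elementwise-add kernel matching the call `add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE)` looks roughly like this sketch:

# Sketch of a Triton kernel compatible with the call above; the real
# add_kernel lives in the test suite's shared helpers, not in this diff.
import triton
import triton.language as tl

@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)  # one program per BLOCK_SIZE chunk
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail of the tensor
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)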

@@ -155,6 +155,23 @@ def triton_op(
        if custom_triton_ops_decomposition_disabled():
            return mode.__torch_dispatch__(op, types, args, kwargs)
        else:
            # TODO: https://github.com/pytorch/pytorch/issues/160333
            # We should deduplicate the unrecognized_types logic.
            import torch._subclasses

            unrecognized_types = [
                t
                for t in types
                if not issubclass(t, torch._subclasses.FakeTensor)
                and t
                not in [
                    torch.Tensor,
                    torch._subclasses.functional_tensor.FunctionalTensor,
                ]
            ]
            if unrecognized_types:
                return NotImplemented

            with mode:
                return fn(*args, **kwargs)