diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 6804a500fbd..fc9f92477c7 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -3583,6 +3583,40 @@ class CustomOpTests(torch._inductor.test_case.TestCase): self.assertNotIn(libname, code) self.assertNotIn(opname, code) + @requires_gpu + def test_subclass(self): + libname = "my_cool_namespace" + opname = "my_triton_operator" + + @torch.library.triton_op(f"{libname}::{opname}", mutates_args={}) + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.empty_like(x) + n_elements = output.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + capture_triton(add_kernel)[grid](x, y, output, n_elements, 16) + + return output + + def f(x, y): + return add(x, y) + + x0 = torch.randn(3, device=GPU_TYPE) + y0 = torch.randn(3, device=GPU_TYPE) + x1 = torch.randn(3, device=GPU_TYPE) + y1 = torch.randn(3, device=GPU_TYPE) + + from torch.testing._internal.two_tensor import TwoTensor + + x = TwoTensor(x0, x1) + y = TwoTensor(y0, y1) + + out = torch.compile(f, fullgraph=True)(x, y) + expected = f(x, y) + self.assertEqual(out.a, expected.a) + self.assertEqual(out.b, expected.b) + @requires_gpu @dynamo_config.patch("recompile_limit", 1) def test_triton_dynamic_grid_no_recompile(self): diff --git a/torch/_library/triton.py b/torch/_library/triton.py index 72805c765d8..17d02a99456 100644 --- a/torch/_library/triton.py +++ b/torch/_library/triton.py @@ -155,6 +155,23 @@ def triton_op( if custom_triton_ops_decomposition_disabled(): return mode.__torch_dispatch__(op, types, args, kwargs) else: + # TODO: https://github.com/pytorch/pytorch/issues/160333 + # We should deduplicate the unrecognized_types logic. + import torch._subclasses + + unrecognized_types = [ + t + for t in types + if not issubclass(t, torch._subclasses.FakeTensor) + and t + not in [ + torch.Tensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ] + ] + + if unrecognized_types: + return NotImplemented with mode: return fn(*args, **kwargs)