mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Get tensor subclasses and torch.library.triton_op to dispatch correctly (#160341)
Short-term fix for https://github.com/pytorch/pytorch/issues/160333 The problem is: 1) `triton_op` adds a decomposition for FunctionalTensorMode for this operation 2) Tensor Subclasses rely on FunctionalTensorMode's `__torch_dispatch__` returning NotImplemented. 3) `triton_op`'s FunctionalTensorMode decomposition takes precedence over FunctionalTensorMode's decomposition. The easy fix is to copy-paste the FunctionalTensorMode's NotImplemented return logic into the decomposition. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160341 Approved by: https://github.com/drisspg
This commit is contained in:
parent
32e5e2f596
commit
10bc36fe84
|
|
@ -3583,6 +3583,40 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
|
|||
self.assertNotIn(libname, code)
|
||||
self.assertNotIn(opname, code)
|
||||
|
||||
@requires_gpu
|
||||
def test_subclass(self):
|
||||
libname = "my_cool_namespace"
|
||||
opname = "my_triton_operator"
|
||||
|
||||
@torch.library.triton_op(f"{libname}::{opname}", mutates_args={})
|
||||
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
output = torch.empty_like(x)
|
||||
n_elements = output.numel()
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
|
||||
|
||||
return output
|
||||
|
||||
def f(x, y):
|
||||
return add(x, y)
|
||||
|
||||
x0 = torch.randn(3, device=GPU_TYPE)
|
||||
y0 = torch.randn(3, device=GPU_TYPE)
|
||||
x1 = torch.randn(3, device=GPU_TYPE)
|
||||
y1 = torch.randn(3, device=GPU_TYPE)
|
||||
|
||||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
|
||||
x = TwoTensor(x0, x1)
|
||||
y = TwoTensor(y0, y1)
|
||||
|
||||
out = torch.compile(f, fullgraph=True)(x, y)
|
||||
expected = f(x, y)
|
||||
self.assertEqual(out.a, expected.a)
|
||||
self.assertEqual(out.b, expected.b)
|
||||
|
||||
@requires_gpu
|
||||
@dynamo_config.patch("recompile_limit", 1)
|
||||
def test_triton_dynamic_grid_no_recompile(self):
|
||||
|
|
|
|||
|
|
@ -155,6 +155,23 @@ def triton_op(
|
|||
if custom_triton_ops_decomposition_disabled():
|
||||
return mode.__torch_dispatch__(op, types, args, kwargs)
|
||||
else:
|
||||
# TODO: https://github.com/pytorch/pytorch/issues/160333
|
||||
# We should deduplicate the unrecognized_types logic.
|
||||
import torch._subclasses
|
||||
|
||||
unrecognized_types = [
|
||||
t
|
||||
for t in types
|
||||
if not issubclass(t, torch._subclasses.FakeTensor)
|
||||
and t
|
||||
not in [
|
||||
torch.Tensor,
|
||||
torch._subclasses.functional_tensor.FunctionalTensor,
|
||||
]
|
||||
]
|
||||
|
||||
if unrecognized_types:
|
||||
return NotImplemented
|
||||
with mode:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user