mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Get tensor subclasses and torch.library.triton_op to dispatch correctly (#160341)
Short-term fix for https://github.com/pytorch/pytorch/issues/160333.
The problem is:
1) `triton_op` adds a decomposition for FunctionalTensorMode for this operation.
2) Tensor subclasses rely on FunctionalTensorMode's `__torch_dispatch__` returning NotImplemented.
3) `triton_op`'s FunctionalTensorMode decomposition takes precedence over FunctionalTensorMode's decomposition.
The easy fix is to copy-paste FunctionalTensorMode's NotImplemented return logic into the decomposition.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160341
Approved by: https://github.com/drisspg
This commit is contained in:
parent
32e5e2f596
commit
10bc36fe84
|
|
@ -3583,6 +3583,40 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
|
||||||
self.assertNotIn(libname, code)
|
self.assertNotIn(libname, code)
|
||||||
self.assertNotIn(opname, code)
|
self.assertNotIn(opname, code)
|
||||||
|
|
||||||
|
@requires_gpu
def test_subclass(self):
    """Check that a ``triton_op`` works with tensor-subclass inputs.

    Registers a Triton-backed add op, calls it on ``TwoTensor`` inputs both
    through ``torch.compile(..., fullgraph=True)`` and eagerly, and asserts
    that the two inner tensors (``a`` and ``b``) of the results agree.
    """
    from torch.testing._internal.two_tensor import TwoTensor

    namespace = "my_cool_namespace"
    op_name = "my_triton_operator"

    @torch.library.triton_op(f"{namespace}::{op_name}", mutates_args={})
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        total = result.numel()

        def grid(meta):
            # One program per BLOCK_SIZE-sized chunk of the output.
            return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

        capture_triton(add_kernel)[grid](x, y, result, total, 16)
        return result

    def run_op(lhs, rhs):
        return add(lhs, rhs)

    # Same creation order as before (x0, y0, x1, y1), so the random stream
    # is consumed identically.
    x0, y0, x1, y1 = (torch.randn(3, device=GPU_TYPE) for _ in range(4))

    lhs = TwoTensor(x0, x1)
    rhs = TwoTensor(y0, y1)

    # Compiled run first, then the eager reference.
    compiled_out = torch.compile(run_op, fullgraph=True)(lhs, rhs)
    eager_out = run_op(lhs, rhs)

    for component in ("a", "b"):
        self.assertEqual(
            getattr(compiled_out, component),
            getattr(eager_out, component),
        )
|
||||||
|
|
||||||
@requires_gpu
|
@requires_gpu
|
||||||
@dynamo_config.patch("recompile_limit", 1)
|
@dynamo_config.patch("recompile_limit", 1)
|
||||||
def test_triton_dynamic_grid_no_recompile(self):
|
def test_triton_dynamic_grid_no_recompile(self):
|
||||||
|
|
|
||||||
|
|
@ -155,6 +155,23 @@ def triton_op(
|
||||||
if custom_triton_ops_decomposition_disabled():
|
if custom_triton_ops_decomposition_disabled():
|
||||||
return mode.__torch_dispatch__(op, types, args, kwargs)
|
return mode.__torch_dispatch__(op, types, args, kwargs)
|
||||||
else:
|
else:
|
||||||
|
# TODO: https://github.com/pytorch/pytorch/issues/160333
|
||||||
|
# We should deduplicate the unrecognized_types logic.
|
||||||
|
import torch._subclasses
|
||||||
|
|
||||||
|
unrecognized_types = [
|
||||||
|
t
|
||||||
|
for t in types
|
||||||
|
if not issubclass(t, torch._subclasses.FakeTensor)
|
||||||
|
and t
|
||||||
|
not in [
|
||||||
|
torch.Tensor,
|
||||||
|
torch._subclasses.functional_tensor.FunctionalTensor,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
if unrecognized_types:
|
||||||
|
return NotImplemented
|
||||||
with mode:
|
with mode:
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user