[Bugfix][Dynamo] Fix Sparse tensors by graph break in Dynamo (#164873)

Fixes #164823 by making the lack of support for sparse tensors explicit in the fake tensor, Inductor, and lowering code (a user-facing behavior sketch follows below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164873
Approved by: https://github.com/williamwen42, https://github.com/eellison, https://github.com/mlazos
Lucas Kabela 2025-10-16 15:06:16 +00:00 committed by PyTorch MergeBot
parent 1a5b7eca7b
commit e6d9d68598
5 changed files with 40 additions and 0 deletions
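
For context (not part of the diff): a minimal sketch of the intended user-facing behavior after this change, assuming the graph break surfaces in the usual way. With default settings the sparse portion falls back to eager execution; with fullgraph=True the graph break is reported as an explicit error instead of a failure deep inside Inductor.

import torch

def forward(x):
    # Producing a sparse tensor inside the compiled region now triggers a graph break.
    return x.to_sparse() * 2

x = torch.randn(10, 10)

# Default mode: Dynamo graph-breaks at the sparse op and runs that part eagerly.
out = torch.compile(forward, backend="inductor")(x)

# fullgraph=True turns the graph break into an explicit error (the exact
# exception type is not guaranteed here; catching broadly for illustration).
try:
    torch.compile(forward, fullgraph=True)(x)
except Exception as exc:
    print(type(exc).__name__, exc)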

@@ -7702,6 +7702,19 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
        opt_fn = torch.compile(fn, backend="eager")
        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))

    def test_sparse_output_inductor_should_break(self) -> None:
        # See https://github.com/pytorch/pytorch/issues/164823
        # We want consistent semantics here
        def forward(x: torch.Tensor) -> torch.Tensor:
            x_sparse = x.to_sparse()
            return x_sparse * 2

        test_tensor = torch.randn(10, 10)
        pt = forward(test_tensor)
        aot_eager = torch.compile(forward, backend="aot_eager")(test_tensor)
        self.assertEqual(pt, aot_eager)
        inductor = torch.compile(forward, backend="inductor")(test_tensor)

    def test_nested_sequential_try_with(self):
        def fn(x):
            with torch.set_grad_enabled(True):

@@ -266,9 +266,12 @@ inductor_expected_failures_single_sample["cuda"] = {
    "torch.ops.aten._flash_attention_forward": {f16},
    "torch.ops.aten._efficient_attention_forward": {f16, f32},
    "to_sparse": {
        b8,
        f16,
        f32,
        f64,
        i32,
        i64,
    },  # NYI: could not find kernel for aten.view.default at dispatch key DispatchKey.SparseCUDA
}

@@ -2790,5 +2790,15 @@
            "Explanation": "Object does not allow us to make a weakref to it",
            "Hints": []
        }
    ],
    "GB0277": [
        {
            "Gb_type": "Attempted to wrap sparse Tensor with VariableTracker",
            "Context": "str(example_value)",
            "Explanation": "torch.compile does not support sparse Tensors with VariableTracker",
            "Hints": [
                "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
            ]
        }
    ]
}
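
Not part of the diff: if this graph break (GB0277) shows up frequently, one workaround consistent with the hint above is to keep the sparse computation outside the compiled region, for example with torch.compiler.disable. A minimal sketch, assuming the dense portion is what benefits from compilation:

import torch

@torch.compiler.disable  # run the sparse math eagerly, outside the traced graph
def sparse_part(y):
    return y.to_sparse() * 2

@torch.compile
def model(x):
    y = x.relu()           # dense work, compiled as usual
    return sparse_part(y)  # the compiled region ends at this deliberate boundary

out = model(torch.randn(8, 8))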

@@ -2881,6 +2881,17 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
    import torch._utils

    if isinstance(example_value, torch.Tensor):
        # Check if the result is a sparse tensor -
        # we generally don't support sparse tensors, so better to graph break here
        if is_sparse_any(example_value) and (
            not tx.export or not config.capture_sparse_compute
        ):
            unimplemented_v2(
                gb_type="Attempted to wrap sparse Tensor with VariableTracker",
                context=str(example_value),
                explanation="torch.compile does not support sparse Tensors with VariableTracker",
                hints=[*graph_break_hints.SUPPORTABLE],
            )
        var = construct_tensor_variable(
            target_cls, tx, proxy, example_value, subclass_type, options
        )
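
For reference (not part of the diff): is_sparse_any is an internal PyTorch helper; conceptually it matches any sparse layout, not just COO. A rough equivalent in terms of public tensor attributes (a sketch, not the actual implementation):

import torch

_SPARSE_LAYOUTS = {
    torch.sparse_coo,
    torch.sparse_csr,
    torch.sparse_csc,
    torch.sparse_bsr,
    torch.sparse_bsc,
}

def is_sparse_any_sketch(t: torch.Tensor) -> bool:
    # t.is_sparse is True only for COO tensors, so checking the layout
    # is what also covers the compressed (CSR/CSC/BSR/BSC) formats.
    return t.layout in _SPARSE_LAYOUTS

print(is_sparse_any_sketch(torch.randn(2, 2)))              # False
print(is_sparse_any_sketch(torch.randn(2, 2).to_sparse()))  # True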

@@ -2141,6 +2141,9 @@ def unsupported_input_tensor(t: torch.Tensor, node=None):
    if t.is_meta:
        return True

    if t.is_sparse:
        return True

    if t.dtype == torch.float8_e8m0fnu:
        if not node:
            return True
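
Not part of the diff: a small illustration of the inputs the new t.is_sparse check covers. A sparse COO tensor reaching Inductor is now reported as unsupported up front instead of failing later with a missing-kernel error; plain eager sparse math is unaffected.

import torch

indices = torch.tensor([[0, 1, 2], [2, 0, 1]])
values = torch.tensor([3.0, 4.0, 5.0])
s = torch.sparse_coo_tensor(indices, values, size=(3, 3))

print(s.is_sparse)         # True -> matched by the new check above
print(s.layout)            # torch.sparse_coo
print((s * 2).to_dense())  # eager sparse arithmetic still works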