mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Bugfix][Dynamo] Fix Sparse tensors by graph break in Dynamo (#164873)
Fixes #164823 by making lack of support for sparse tensors very explicit (in fake tensor, inductor, and lowering code) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164873 Approved by: https://github.com/williamwen42, https://github.com/eellison, https://github.com/mlazos
This commit is contained in:
parent
1a5b7eca7b
commit
e6d9d68598
|
|
@ -7702,6 +7702,19 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
opt_fn = torch.compile(fn, backend="eager")
|
||||
self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
|
||||
|
||||
def test_sparse_output_inductor_should_break(self) -> None:
|
||||
# See https://github.com/pytorch/pytorch/issues/164823
|
||||
# We want consistent semantics here
|
||||
def forward(x: torch.Tensor) -> torch.Tensor:
|
||||
x_sparse = x.to_sparse()
|
||||
return x_sparse * 2
|
||||
|
||||
test_tensor = torch.randn(10, 10)
|
||||
pt = forward(test_tensor)
|
||||
aot_eager = torch.compile(forward, backend="aot_eager")(test_tensor)
|
||||
self.assertEqual(pt, aot_eager)
|
||||
inductor = torch.compile(forward, backend="inductor")(test_tensor)
|
||||
|
||||
def test_nested_sequential_try_with(self):
|
||||
def fn(x):
|
||||
with torch.set_grad_enabled(True):
|
||||
|
|
|
|||
|
|
@ -266,9 +266,12 @@ inductor_expected_failures_single_sample["cuda"] = {
|
|||
"torch.ops.aten._flash_attention_forward": {f16},
|
||||
"torch.ops.aten._efficient_attention_forward": {f16, f32},
|
||||
"to_sparse": {
|
||||
b8,
|
||||
f16,
|
||||
f32,
|
||||
f64,
|
||||
i32,
|
||||
i64,
|
||||
}, # NYI: could not find kernel for aten.view.default at dispatch key DispatchKey.SparseCUDA
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2790,5 +2790,15 @@
|
|||
"Explanation": "Object does not allow us to make a weakref to it",
|
||||
"Hints": []
|
||||
}
|
||||
],
|
||||
"GB0277": [
|
||||
{
|
||||
"Gb_type": "Attempted to wrap sparse Tensor with VariableTracker",
|
||||
"Context": "str(example_value)",
|
||||
"Explanation": "torch.compile does not support sparse Tensors with VariableTracker",
|
||||
"Hints": [
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2881,6 +2881,17 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
|||
import torch._utils
|
||||
|
||||
if isinstance(example_value, torch.Tensor):
|
||||
# Check if the result is a sparse tensor -
|
||||
# We generally don't support sparse tensor so better to graph break here
|
||||
if is_sparse_any(example_value) and (
|
||||
not tx.export or not config.capture_sparse_compute
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Attempted to wrap sparse Tensor with VariableTracker",
|
||||
context=str(example_value),
|
||||
explanation="torch.compile does not support sparse Tensors with VariableTracker",
|
||||
hints=[*graph_break_hints.SUPPORTABLE],
|
||||
)
|
||||
var = construct_tensor_variable(
|
||||
target_cls, tx, proxy, example_value, subclass_type, options
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2141,6 +2141,9 @@ def unsupported_input_tensor(t: torch.Tensor, node=None):
|
|||
if t.is_meta:
|
||||
return True
|
||||
|
||||
if t.is_sparse:
|
||||
return True
|
||||
|
||||
if t.dtype == torch.float8_e8m0fnu:
|
||||
if not node:
|
||||
return True
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user