[Inductor] avoid CUDA__equal when constant tensors are from different device (#163529)

Summary:
Otherwise, we may hit:
```
Exception: Expected all tensors to be on the same device, but got other is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA__equal)
```

Test Plan: UTs

Reviewed By: yushangdi

Differential Revision: D82974062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163529
Approved by: https://github.com/yushangdi, https://github.com/Skylion007
This commit is contained in:
Chang Pan 2025-09-22 22:04:11 +00:00 committed by PyTorch MergeBot
parent 4fc271e559
commit e0cbab46ad
2 changed files with 6 additions and 4 deletions

View File

@ -1829,7 +1829,7 @@ class AOTInductorTestsTemplate:
self.user_float_feature_idx = user_float_feature_idx
self.register_buffer(
"_tensor_constant0",
torch.ones(1, device=device, dtype=torch.float32),
torch.ones(5, device=device, dtype=torch.float32),
persistent=True,
)
self.register_buffer(
@ -1840,6 +1840,7 @@ class AOTInductorTestsTemplate:
self.sub_mod = SubModule(device)
def forward(self, x):
self._tensor_constant0[1:2] = 1
return (
torch.index_select(
x, 1, torch.tensor(self.user_float_feature_idx, device=x.device)
@ -7291,8 +7292,6 @@ MPS_TEST_FAILURES = {
"test_index_put_with_none_index": fail_mps(),
# Error device may not be nil
"test_zero_size_weight": fail_mps(is_skip=True),
# RuntimeError: Cannot compare two tensors on different devices. Got: cpu and mps:0
"test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True),
# MPSGraph does not support tensor dims > INT_MAX
"test_upper_bound_i64": fail_mps(is_skip=True),
# MPS doesn't support triton

View File

@ -370,9 +370,12 @@ def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
):
continue
elif (
torch.equal(gm_target, model_target)
gm_target.device == model_target.device
and gm_target.dtype == model_target.dtype
and torch.equal(gm_target, model_target)
):
# If tensors with the same name from gm and model are indeed identical, we don't need to rename.
# Check device (and dtype) first, to avoid torch.equal (wrapper_CUDA__equal) raising when the tensors live on different devices.
continue
prefix = (