mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)
[Inductor] avoid CUDA__equal when constant tensors are from different device (#163529)
Summary: otherwise we may hit

```
Exception: Expected all tensors to be on the same device, but got other is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA__equal)
```

Test Plan: UTs

Reviewed By: yushangdi

Differential Revision: D82974062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163529
Approved by: https://github.com/yushangdi, https://github.com/Skylion007
This commit is contained in:
parent 4fc271e559
commit e0cbab46ad
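The failure mode quoted in the commit message is easy to reproduce: `torch.equal` dispatches on device, so comparing a CPU tensor against a CUDA tensor raises instead of returning `False`. A minimal sketch, assuming a CUDA device is available:

```python
# Minimal repro sketch (assumes CUDA is available): torch.equal raises on
# mixed-device inputs rather than returning False.
import torch

if torch.cuda.is_available():
    cpu_t = torch.ones(5)
    cuda_t = torch.ones(5, device="cuda")
    try:
        torch.equal(cpu_t, cuda_t)
    except RuntimeError as e:
        # "Expected all tensors to be on the same device, but got other is on
        #  cuda:0 ... (when checking argument in method wrapper_CUDA__equal)"
        print(e)
```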
```diff
@@ -1829,7 +1829,7 @@ class AOTInductorTestsTemplate:
                 self.user_float_feature_idx = user_float_feature_idx
                 self.register_buffer(
                     "_tensor_constant0",
-                    torch.ones(1, device=device, dtype=torch.float32),
+                    torch.ones(5, device=device, dtype=torch.float32),
                     persistent=True,
                 )
                 self.register_buffer(
```
```diff
@@ -1840,6 +1840,7 @@ class AOTInductorTestsTemplate:
                 self.sub_mod = SubModule(device)
 
             def forward(self, x):
+                self._tensor_constant0[1:2] = 1
                 return (
                     torch.index_select(
                         x, 1, torch.tensor(self.user_float_feature_idx, device=x.device)
```
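A note on why the test buffer grows from `ones(1)` to `ones(5)`: on a 1-element tensor the slice `[1:2]` is empty, so the in-place write added to `forward()` would be a silent no-op; with 5 elements it actually mutates the constant. That motivation is not stated in the commit, so treat it as a presumption. A quick sketch of the slicing behavior:

```python
# Sketch of the slicing behavior behind the ones(1) -> ones(5) change.
import torch

t1 = torch.zeros(1)
t1[1:2] = 1        # slice past the end is empty: no error, nothing written
print(t1)          # tensor([0.])

t5 = torch.zeros(5)
t5[1:2] = 1        # element 1 is actually written
print(t5)          # tensor([0., 1., 0., 0., 0.])
```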
```diff
@@ -7291,8 +7292,6 @@ MPS_TEST_FAILURES = {
     "test_index_put_with_none_index": fail_mps(),
     # Error device may not be nil
     "test_zero_size_weight": fail_mps(is_skip=True),
-    # RuntimeError: Cannot compare two tensors on different devices. Got: cpu and mps:0
-    "test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True),
     # MPSGraph does not support tensor dims > INT_MAX
     "test_upper_bound_i64": fail_mps(is_skip=True),
     # MPS doesn't support triton
```
```diff
@@ -370,9 +370,12 @@ def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
             ):
                 continue
             elif (
-                torch.equal(gm_target, model_target)
+                gm_target.device == model_target.device
+                and gm_target.dtype == model_target.dtype
+                and torch.equal(gm_target, model_target)
             ):
                 # If tensors with same name from gm and model are indeed the same, we don't need to rename
                 # Check device first, to avoid torch.equal(wrapper_CUDA__equal) raise when different device
                 continue
 
             prefix = (
```
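The guarded comparison above can be read as a standalone predicate. A minimal sketch of the pattern, with a hypothetical helper name (`constants_equal` is not part of the PR):

```python
# Hedged sketch of the fixed comparison order; constants_equal is a
# hypothetical name, the guard order mirrors the diff above.
import torch

def constants_equal(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Cheap metadata checks first: torch.equal on mixed cpu/cuda inputs
    # raises via wrapper_CUDA__equal instead of returning False.
    return (
        a.device == b.device
        and a.dtype == b.dtype
        and torch.equal(a, b)
    )

cpu_t = torch.ones(5)
assert constants_equal(cpu_t, cpu_t.clone())
if torch.cuda.is_available():
    # Different devices now short-circuit to False instead of raising.
    assert not constants_equal(cpu_t, cpu_t.to("cuda"))
```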