mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] avoid CUDA__equal when constant tensors are from different device (#163529)
Summary: otherwise, may hit ``` Exception: Expected all tensors to be on the same device, but got other is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA__equal) ``` Test Plan: UTs Reviewed By: yushangdi Differential Revision: D82974062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163529 Approved by: https://github.com/yushangdi, https://github.com/Skylion007
This commit is contained in:
parent
4fc271e559
commit
e0cbab46ad
|
|
@ -1829,7 +1829,7 @@ class AOTInductorTestsTemplate:
|
||||||
self.user_float_feature_idx = user_float_feature_idx
|
self.user_float_feature_idx = user_float_feature_idx
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"_tensor_constant0",
|
"_tensor_constant0",
|
||||||
torch.ones(1, device=device, dtype=torch.float32),
|
torch.ones(5, device=device, dtype=torch.float32),
|
||||||
persistent=True,
|
persistent=True,
|
||||||
)
|
)
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
|
|
@ -1840,6 +1840,7 @@ class AOTInductorTestsTemplate:
|
||||||
self.sub_mod = SubModule(device)
|
self.sub_mod = SubModule(device)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
self._tensor_constant0[1:2] = 1
|
||||||
return (
|
return (
|
||||||
torch.index_select(
|
torch.index_select(
|
||||||
x, 1, torch.tensor(self.user_float_feature_idx, device=x.device)
|
x, 1, torch.tensor(self.user_float_feature_idx, device=x.device)
|
||||||
|
|
@ -7291,8 +7292,6 @@ MPS_TEST_FAILURES = {
|
||||||
"test_index_put_with_none_index": fail_mps(),
|
"test_index_put_with_none_index": fail_mps(),
|
||||||
# Error device may not be nil
|
# Error device may not be nil
|
||||||
"test_zero_size_weight": fail_mps(is_skip=True),
|
"test_zero_size_weight": fail_mps(is_skip=True),
|
||||||
# RuntimeError: Cannot compare two tensors on different devices. Got: cpu and mps:0
|
|
||||||
"test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True),
|
|
||||||
# MPSGraph does not support tensor dims > INT_MAX
|
# MPSGraph does not support tensor dims > INT_MAX
|
||||||
"test_upper_bound_i64": fail_mps(is_skip=True),
|
"test_upper_bound_i64": fail_mps(is_skip=True),
|
||||||
# MPS doesn't support triton
|
# MPS doesn't support triton
|
||||||
|
|
|
||||||
|
|
@ -370,9 +370,12 @@ def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
elif (
|
elif (
|
||||||
torch.equal(gm_target, model_target)
|
gm_target.device == model_target.device
|
||||||
and gm_target.dtype == model_target.dtype
|
and gm_target.dtype == model_target.dtype
|
||||||
|
and torch.equal(gm_target, model_target)
|
||||||
):
|
):
|
||||||
|
# If tensors with same name from gm and model are indeed the same, we don't need to rename
|
||||||
|
# Check device first, to avoid torch.equal(wrapper_CUDA__equal) raise when different device
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prefix = (
|
prefix = (
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user