From e0cbab46adee336f357d6e7a11c210c26b6745ef Mon Sep 17 00:00:00 2001 From: Chang Pan Date: Mon, 22 Sep 2025 22:04:11 +0000 Subject: [PATCH] [Inductor] avoid CUDA__equal when constant tensors are from different device (#163529) Summary: In _resolve_name_collision, check that the two same-named constant tensors are on the same device before calling torch.equal; otherwise, compilation may hit ``` Exception: Expected all tensors to be on the same device, but got other is on cuda:0, different from other tensors on cpu (when checking argument in method wrapper_CUDA__equal) ``` Test Plan: UTs Reviewed By: yushangdi Differential Revision: D82974062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163529 Approved by: https://github.com/yushangdi, https://github.com/Skylion007 --- test/inductor/test_aot_inductor.py | 5 ++--- torch/_inductor/compile_fx.py | 5 ++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 7bf1dc7d36c..6cc31b4ab23 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1829,7 +1829,7 @@ class AOTInductorTestsTemplate: self.user_float_feature_idx = user_float_feature_idx self.register_buffer( "_tensor_constant0", - torch.ones(1, device=device, dtype=torch.float32), + torch.ones(5, device=device, dtype=torch.float32), persistent=True, ) self.register_buffer( @@ -1840,6 +1840,7 @@ class AOTInductorTestsTemplate: self.sub_mod = SubModule(device) def forward(self, x): + self._tensor_constant0[1:2] = 1 return ( torch.index_select( x, 1, torch.tensor(self.user_float_feature_idx, device=x.device) @@ -7291,8 +7292,6 @@ MPS_TEST_FAILURES = { "test_index_put_with_none_index": fail_mps(), # Error device may not be nil "test_zero_size_weight": fail_mps(is_skip=True), - # RuntimeError: Cannot compare two tensors on different devices. 
Got: cpu and mps:0 - "test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True), # MPSGraph does not support tensor dims > INT_MAX "test_upper_bound_i64": fail_mps(is_skip=True), # MPS doesn't support triton diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index b99b37ff010..5306919ecf6 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -370,9 +370,12 @@ def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None: ): continue elif ( - torch.equal(gm_target, model_target) + gm_target.device == model_target.device and gm_target.dtype == model_target.dtype + and torch.equal(gm_target, model_target) ): + # If tensors with same name from gm and model are indeed the same, we don't need to rename + # Check device first, to avoid torch.equal(wrapper_CUDA__equal) raise when different device continue prefix = (