diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 7bf1dc7d36c..6cc31b4ab23 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1829,7 +1829,7 @@ class AOTInductorTestsTemplate:
                 self.user_float_feature_idx = user_float_feature_idx
                 self.register_buffer(
                     "_tensor_constant0",
-                    torch.ones(1, device=device, dtype=torch.float32),
+                    torch.ones(5, device=device, dtype=torch.float32),
                     persistent=True,
                 )
                 self.register_buffer(
@@ -1840,6 +1840,7 @@ class AOTInductorTestsTemplate:
                 self.sub_mod = SubModule(device)

             def forward(self, x):
+                self._tensor_constant0[1:2] = 1
                 return (
                     torch.index_select(
                         x, 1, torch.tensor(self.user_float_feature_idx, device=x.device)
@@ -7291,8 +7292,6 @@ MPS_TEST_FAILURES = {
     "test_index_put_with_none_index": fail_mps(),
     # Error device may not be nil
     "test_zero_size_weight": fail_mps(is_skip=True),
-    # RuntimeError: Cannot compare two tensors on different devices. Got: cpu and mps:0
-    "test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True),
     # MPSGraph does not support tensor dims > INT_MAX
     "test_upper_bound_i64": fail_mps(is_skip=True),
     # MPS doesn't support triton
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index b99b37ff010..5306919ecf6 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -370,9 +370,12 @@ def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
             ):
                 continue
             elif (
-                torch.equal(gm_target, model_target)
+                gm_target.device == model_target.device
                 and gm_target.dtype == model_target.dtype
+                and torch.equal(gm_target, model_target)
             ):
+                # If tensors with the same name from gm and model are indeed the same, we don't need to rename.
+                # Check device first, to avoid torch.equal (wrapper_CUDA__equal) raising when devices differ.
                 continue
             prefix = (