From b51d8037858f11e46371f62720c4fbd79864a65a Mon Sep 17 00:00:00 2001
From: "Yu, Guangye"
Date: Fri, 13 Jun 2025 14:15:23 +0000
Subject: [PATCH] [Dynamo] Add XPU API to trace_rules (#155788)

# Motivation
- Add binding APIs and non-binding APIs to the trace rules for XPU;
- Add some XPU APIs to the constant fold functions for Dynamo capture.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155788
Approved by: https://github.com/jansel, https://github.com/EikanWang
ghstack dependencies: #155787
---
 torch/_dynamo/trace_rules.py     | 55 ++++++++++++++++++++++++++++++++
 torch/_dynamo/variables/torch.py |  4 +++
 2 files changed, 59 insertions(+)

diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 99bdea6b5e1..4b4d927496a 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -350,6 +350,8 @@ manual_torch_name_rule_map: dict[str, Any] = {
     "torch.sparse_csr_tensor": SkipFunctionVariable,
     "torch.sparse_compressed_tensor": SkipFunctionVariable,
     "torch._C._autograd._unsafe_set_version_counter": TorchInGraphFunctionVariable,
+    "torch.xpu.get_rng_state": SkipFunctionVariable,
+    "torch.xpu.set_rng_state": SkipFunctionVariable,
     # avoid skipping user defined modules in distributed unit tests
     "torch/testing/_internal/common_fsdp.py#forward": UserFunctionVariable,
     f"torch/testing/_internal/common_fsdp.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable,
@@ -1343,6 +1345,21 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C._warn",
         "torch._C._will_engine_execute_node",
         "torch._C._wrap_tensor_impl",
+        "torch._C._xpu_emptyCache",
+        "torch._C._xpu_getArchFlags",
+        "torch._C._xpu_getCurrentStream",
+        "torch._C._xpu_getCurrentRawStream",
+        "torch._C._xpu_getDeviceCount",
+        "torch._C._xpu_getDevice",
+        "torch._C._xpu_getMemoryInfo",
+        "torch._C._xpu_getStreamFromExternal",
+        "torch._C._xpu_isInBadFork",
+        "torch._C._xpu_init",
+        "torch._C._xpu_memoryStats",
+        "torch._C._xpu_resetAccumulatedMemoryStats",
+        "torch._C._xpu_resetPeakMemoryStats",
+        "torch._C._xpu_setStream",
+        "torch._C._xpu_synchronize",
         "torch._C.fork",
         "torch._C.get_autocast_cpu_dtype",
         "torch._C.get_autocast_dtype",
@@ -2265,6 +2282,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch.slice_inverse",
         "torch._assert_scalar",
         "torch._functional_assert_scalar",
+        "torch.xpu._get_device_properties",
     ],
     TorchInGraphFunctionVariable,
 )
@@ -2880,6 +2898,43 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
         "torch.tensordot",
         "torch.unique_consecutive",
         "torch.use_deterministic_algorithms",
+        "torch.xpu._get_device",
+        "torch.xpu._get_generator",
+        "torch.xpu._get_rng_state_offset",
+        "torch.xpu._is_compiled",
+        "torch.xpu._lazy_call",
+        "torch.xpu._lazy_init",
+        "torch.xpu._set_rng_state_offset",
+        "torch.xpu._set_stream_by_id",
+        "torch.xpu._utils._get_device_index",
+        "torch.xpu.current_device",
+        "torch.xpu.current_stream",
+        "torch.xpu.device_count",
+        "torch.xpu.get_arch_list",
+        "torch.xpu.get_device_capability",
+        "torch.xpu.get_device_name",
+        "torch.xpu.get_device_properties",
+        "torch.xpu.get_gencode_flags",
+        "torch.xpu.get_stream_from_external",
+        "torch.xpu.init",
+        "torch.xpu.is_available",
+        "torch.xpu.is_bf16_supported",
+        "torch.xpu.is_initialized",
+        "torch.xpu.memory.empty_cache",
+        "torch.xpu.memory.max_memory_allocated",
+        "torch.xpu.memory.max_memory_reserved",
+        "torch.xpu.memory.mem_get_info",
+        "torch.xpu.memory.memory_allocated",
+        "torch.xpu.memory.memory_reserved",
+        "torch.xpu.memory.memory_stats_as_nested_dict",
+        "torch.xpu.memory.memory_stats",
+        "torch.xpu.memory.reset_accumulated_memory_stats",
+        "torch.xpu.memory.reset_peak_memory_stats",
+        "torch.xpu.random.initial_seed",
+        "torch.xpu.random.seed_all",
+        "torch.xpu.random.seed",
+        "torch.xpu.set_stream",
+        "torch.xpu.synchronize",
     ],
     TorchInGraphFunctionVariable,
 )
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 7ac0c224978..b47e76e5933 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -134,6 +134,8 @@ REWRITE_OPS_TO_TENSOR_SIZE_METHOD = dict.fromkeys(
 constant_fold_functions_need_guards = [
     torch.cuda.current_device,
     torch.cuda.is_initialized,
+    torch.xpu.current_device,
+    torch.xpu.is_initialized,
 ]

 constant_fold_functions = [
@@ -156,6 +158,8 @@ constant_fold_functions = [
     torch.promote_types,
     torch._C._get_privateuse1_backend_name,
     torch.autograd._is_checkpoint_valid,
+    torch.xpu.get_device_properties,
+    torch.xpu.is_available,
 ] + constant_fold_functions_need_guards
 if torch.distributed.is_available():
     constant_fold_functions.extend(