diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 4a0ad6fafb9..dea04cd54b6 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1670,7 +1670,7 @@ class aot_inductor: # If link_libtorch is False and cross_target_platform is windows, # a library needs to be provided to provide the shim implementations. - aoti_shim_library: Optional[str] = None + aoti_shim_library: Optional[str | list[str]] = None aoti_shim_library_path: Optional[str] = None diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 948089f3cc5..90b06a1b161 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1147,10 +1147,16 @@ def _get_torch_related_args( else: libraries_dirs = [] if config.aot_inductor.cross_target_platform == "windows": - assert config.aot_inductor.aoti_shim_library, ( + aoti_shim_library = config.aot_inductor.aoti_shim_library + + assert aoti_shim_library, ( "'config.aot_inductor.aoti_shim_library' must be set when 'cross_target_platform' is 'windows'." ) - libraries.append(config.aot_inductor.aoti_shim_library) + if isinstance(aoti_shim_library, str): + libraries.append(aoti_shim_library) + else: + assert isinstance(aoti_shim_library, list) + libraries.extend(aoti_shim_library) if config.aot_inductor.cross_target_platform == "windows": assert config.aot_inductor.aoti_shim_library_path, (