mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This is my suggestion for resolving #152087.

This PR extends the constructor of `AOTIModelPackageLoader` with an optional device index. The device type is still determined by `metadata_["AOTI_DEVICE_KEY"]`, but the `device_index` argument can be used to load an AOTI model package onto a specific device such as `cuda:0` or `cuda:1` in a convenient way. AFAIK, this has not been possible so far using `AOTIModelPackageLoader` alone.

In the default case (no device index specified) with `metadata_["AOTI_DEVICE_KEY"] == "cuda"`, the current behavior is preserved, i.e., the model is loaded to device `cuda`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152093
Approved by: https://github.com/desertfire

| | | |
|---|---|---|
| .. | ||
| aoti_custom_class.cpp | ||
| aoti_custom_class.h | ||
| CMakeLists.txt | ||
| compile_model.py | ||
| generate_lowered_cpu.py | ||
| standalone_compile.sh | ||
| standalone_test.cpp | ||
| test.cpp | ||
| test.py | ||
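
The device selection described in the commit message above can be illustrated with a small, self-contained sketch. This is not the loader's actual code: `resolve_device` is a hypothetical helper, and using a negative value to mean "no device index specified" is an assumption made for illustration. It only mirrors the stated rule that the device type comes from `AOTI_DEVICE_KEY` in the metadata, while the optional index picks a specific device.

```python
from typing import Mapping


def resolve_device(metadata: Mapping[str, str], device_index: int = -1) -> str:
    """Build a device string from package metadata and an optional index.

    Hypothetical helper for illustration only. A negative device_index is
    assumed to mean "no index specified".
    """
    # The device type always comes from the package metadata.
    device_type = metadata["AOTI_DEVICE_KEY"]

    # Default case: keep the bare device type, e.g. "cuda".
    if device_index < 0:
        return device_type

    # Explicit index: target one specific device, e.g. "cuda:1".
    return f"{device_type}:{device_index}"


if __name__ == "__main__":
    metadata = {"AOTI_DEVICE_KEY": "cuda"}
    print(resolve_device(metadata))
    print(resolve_device(metadata, device_index=1))
```

Keeping the device type in the metadata and passing only the index means a package compiled for one device type cannot be silently redirected to another; only the choice among devices of that type is left to the caller.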