[MPS] Add API to query GPU core count (#160414)
Using good old IOKit to get the `gpu-core-count` property from the device implementing the `AGXAccelerator` service.
Expose it as `torch.backends.mps.get_core_count()` and make it accessible to the inductor via `MpsInterface`.
Test Plan: Run `python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"` and compare its output against `system_profiler SPDisplaysDataType|head -n10`:
```
% python3 -c "import torch;print(torch.backends.mps.get_name(), torch.backends.mps.get_core_count())"
Apple M1 Pro 16
% system_profiler SPDisplaysDataType|head -n10
Graphics/Displays:

    Apple M1 Pro:

      Chipset Model: Apple M1 Pro
      Type: GPU
      Bus: Built-In
      Total Number of Cores: 16
      Vendor: Apple (0x106b)
      Metal Support: Metal 3
```
This should significantly improve occupancy for torch.compile-generated kernels.
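For intuition, here is a minimal sketch (not code from this PR) of how a launch-size heuristic could consume the queried core count; the `block` size of 1024 and the 4x oversubscription factor are illustrative assumptions, not values taken from the inductor:

```python
import torch

def num_programs(numel: int, block: int = 1024) -> int:
    # Programs needed to cover every element with one block each.
    full_cover = (numel + block - 1) // block
    # Cap at a small multiple of the physical core count so each core stays
    # busy without oversaturating the scheduler; a grid-stride loop inside
    # the kernel would pick up the remaining elements.
    cores = torch.backends.mps.get_core_count()
    return min(full_cover, cores * 4)
```

With the hard-coded `multi_processor_count = 8` removed below, a 16-core M1 Pro is no longer treated as half its actual width by heuristics of this shape.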
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160414
Approved by: https://github.com/dcci
parent 50a8c11875
commit a06ec54d40
```diff
@@ -1196,7 +1196,7 @@ if(APPLE)
   string(
     APPEND
     CMAKE_SHARED_LINKER_FLAGS
-    " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal"
+    " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal -weak_framework IOKit"
   )
   # To suppress MPSGraph availability warnings
   append_cxx_flag_if_supported("-Wno-unguarded-availability-new"
```
```diff
@@ -55,6 +55,17 @@ class TORCH_API MPSDevice {
    */
   bool isMacOS13Plus(MacOSVersion version) const;
 
+  /**
+   * Returns device name
+   */
+  std::string getName() const;
+
+  /**
+   * Returns number of GPU cores.
+   * 1 Core = 16 ExecutionUnit x 8 ALU x 24 threads
+   */
+  unsigned getCoreCount() const;
+
   ~MPSDevice();
 
  private:
```
```diff
@@ -85,10 +85,36 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
   }
 }
 
+std::string MPSDevice::getName() const {
+  @autoreleasepool {
+    return [[_mtl_device name] UTF8String];
+  }
+}
+
+unsigned MPSDevice::getCoreCount() const {
+  io_iterator_t iterator = 0;
+  io_registry_entry_t entry = 0;
+  int core_count = 0;
+  auto matchingDict = IOServiceMatching("AGXAccelerator");
+  TORCH_INTERNAL_ASSERT(matchingDict, "Failed to create matching dict");
+  const auto status = IOServiceGetMatchingServices(kIOMainPortDefault, matchingDict, &iterator);
+  TORCH_INTERNAL_ASSERT(status == KERN_SUCCESS);
+  while ((entry = IOIteratorNext(iterator)) != 0) {
+    auto property = IORegistryEntryCreateCFProperty(entry, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
+    auto found = CFNumberGetValue(static_cast<CFNumberRef>(property), kCFNumberIntType, &core_count);
+    CFRelease(property);
+    IOObjectRelease(entry);
+    if (found) {
+      break;
+    }
+  }
+  IOObjectRelease(iterator);
+  return core_count;
+}
+
 at::Allocator* GetMPSAllocator(bool useSharedAllocator) {
   return getIMPSAllocator(useSharedAllocator);
 }
 
 bool is_available() {
   return MPSDevice::getInstance()->device() != nil;
 }
```
```diff
@@ -1979,7 +1979,9 @@ def _mtia_resetPeakMemoryStats(device: _int) -> None: ...
 
 # Defined in torch/csrc/mps/Module.cpp
 def _mps_deviceSynchronize() -> None: ...
+def _mps_get_core_count() -> _int: ...
 def _mps_get_default_generator() -> Generator: ...
+def _mps_get_name() -> _str: ...
 def _mps_emptyCache() -> None: ...
 def _mps_setMemoryFraction(fraction: _float) -> None: ...
 def _mps_currentAllocatedMemory() -> _int: ...
```
```diff
@@ -17,6 +17,7 @@ specialized implementations for each hardware backend's unique features.
 
 import inspect
 import time
+from collections import namedtuple
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, Union
```
```diff
@@ -544,8 +545,10 @@ class MpsInterface(DeviceInterface):
 
     class Worker:
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None) -> dict[str, Any]:
-            return {}
+        def get_device_properties(device: torch.types.Device = None) -> Any:
+            return namedtuple("MPSProperties", ["multi_processor_count"])(
+                torch.backends.mps.get_core_count()  # type: ignore[arg-type]
+            )
 
         @staticmethod
         def current_device() -> int:
```
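As a quick sanity check of the new plumbing, the properties can be read back through the worker; a usage sketch, assuming an MPS-enabled build and that `MpsInterface` is importable from `torch._dynamo.device_interface` (the module this hunk modifies):

```python
import torch
from torch._dynamo.device_interface import MpsInterface

# get_device_properties() now returns a namedtuple carrying the real
# multi_processor_count instead of an empty dict.
props = MpsInterface.Worker.get_device_properties()
print(props.multi_processor_count)  # e.g. 16 on an M1 Pro
```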
```diff
@@ -153,9 +153,6 @@ class DeviceProperties(typing.NamedTuple):
         except AttributeError:
             if device_type == "xpu":
                 multi_processor_count = props.gpu_subslice_count
-            elif device_type == "mps":
-                # TODO: Fetch the actual value from ioreg
-                multi_processor_count = 8
             elif device_type == "mtia":
                 multi_processor_count = 64
             else:
```
```diff
@@ -5,7 +5,14 @@ import torch
 from torch.library import Library as _Library
 
 
-__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]
+__all__ = [
+    "get_core_count",
+    "get_name",
+    "is_built",
+    "is_available",
+    "is_macos13_or_newer",
+    "is_macos_or_newer",
+]
 
 
 def is_built() -> bool:
```
```diff
@@ -36,6 +43,23 @@ def is_macos13_or_newer(minor: int = 0) -> bool:
     return torch._C._mps_is_on_macos_or_newer(13, minor)
 
 
+@_lru_cache
+def get_name() -> str:
+    r"""Return Metal device name"""
+    return torch._C._mps_get_name()
+
+
+@_lru_cache
+def get_core_count() -> int:
+    r"""Return GPU core count.
+
+    According to the documentation, one core is comprised of 16 Execution Units.
+    One Execution Unit has 8 ALUs.
+    And one ALU can run 24 threads, i.e. one core is capable of executing 3072 threads concurrently.
+    """
+    return torch._C._mps_get_core_count()
+
+
 _lib: Optional[_Library] = None
```
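Putting the docstring's arithmetic to work, a small usage sketch (assuming an MPS-enabled build):

```python
import torch

if torch.backends.mps.is_available():
    cores = torch.backends.mps.get_core_count()
    # Per the docstring above: 16 Execution Units x 8 ALUs x 24 threads
    # = 3072 threads per core.
    print(f"{torch.backends.mps.get_name()}: {cores} cores, "
          f"up to {cores * 3072} concurrent threads")
```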
```diff
@@ -20,7 +20,6 @@
 #include <ATen/Parallel.h>
 #include <ATen/Utils.h>
 #include <ATen/core/Vitals.h>
-#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/ForeachUtils.h>
```
```diff
@@ -501,6 +501,12 @@ void initModule(PyObject* module) {
     at::mps::getMPSProfiler().startCapture(fileName);
   });
   m.def("_mps_stopCapture", []() { at::mps::getMPSProfiler().stopCapture(); });
+  m.def("_mps_get_name", []() {
+    return at::mps::MPSDevice::getInstance()->getName();
+  });
+  m.def("_mps_get_core_count", []() {
+    return at::mps::MPSDevice::getInstance()->getCoreCount();
+  });
 }
 #endif /* USE_MPS */
```