mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add sequoia runner to mac-mps (#132190)
Adds MacOS 15 runners to GitHub actions for Mac-mps test suite Co-authored-by: Joona Havukainen <jhavukainen@apple.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/132190 Approved by: https://github.com/malfet
This commit is contained in:
parent
d72e863b3e
commit
f8b6e91840
1
.github/workflows/mac-mps.yml
vendored
1
.github/workflows/mac-mps.yml
vendored
|
|
@ -29,6 +29,7 @@ jobs:
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" },
|
{ config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" },
|
||||||
{ config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" },
|
{ config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" },
|
||||||
|
{ config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-15" },
|
||||||
]}
|
]}
|
||||||
|
|
||||||
macos-py3-arm64-mps-test:
|
macos-py3-arm64-mps-test:
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ enum class MacOSVersion : uint32_t {
|
||||||
MACOS_VER_13_3_PLUS,
|
MACOS_VER_13_3_PLUS,
|
||||||
MACOS_VER_14_0_PLUS,
|
MACOS_VER_14_0_PLUS,
|
||||||
MACOS_VER_14_4_PLUS,
|
MACOS_VER_14_4_PLUS,
|
||||||
|
MACOS_VER_15_0_PLUS,
|
||||||
};
|
};
|
||||||
|
|
||||||
//-----------------------------------------------------------------
|
//-----------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,7 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
|
||||||
static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
|
static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
|
||||||
static bool _macos_14_0_plus = is_os_version_at_least(14, 0);
|
static bool _macos_14_0_plus = is_os_version_at_least(14, 0);
|
||||||
static bool _macos_14_4_plus = is_os_version_at_least(14, 0);
|
static bool _macos_14_4_plus = is_os_version_at_least(14, 0);
|
||||||
|
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
|
||||||
|
|
||||||
switch (version) {
|
switch (version) {
|
||||||
case MacOSVersion::MACOS_VER_13_0_PLUS:
|
case MacOSVersion::MACOS_VER_13_0_PLUS:
|
||||||
|
|
@ -132,6 +133,8 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
|
||||||
return _macos_14_0_plus;
|
return _macos_14_0_plus;
|
||||||
case MacOSVersion::MACOS_VER_14_4_PLUS:
|
case MacOSVersion::MACOS_VER_14_4_PLUS:
|
||||||
return _macos_14_4_plus;
|
return _macos_14_4_plus;
|
||||||
|
case MacOSVersion::MACOS_VER_15_0_PLUS:
|
||||||
|
return _macos_15_0_plus;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,14 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
|
||||||
|
|
||||||
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
|
auto maskedMM = [mpsGraph matrixMultiplicationWithPrimaryTensor:qTensor secondaryTensor:kT name:nil];
|
||||||
|
|
||||||
|
bool macOS15_0_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
||||||
|
if (macOS15_0_plus && [maskedMM dataType] == MPSDataTypeFloat32) {
|
||||||
|
// TODO: In MacOS15 beta, there is a MPSGraph issue when the SDPA sequence gets remapped to use
|
||||||
|
// an improved kernel for the computation, causing NaNs in the result. This identity prevents the remapping.
|
||||||
|
// Limit the availability check once a fix lands.
|
||||||
|
maskedMM = [mpsGraph identityWithTensor:maskedMM name:nil];
|
||||||
|
}
|
||||||
|
|
||||||
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
|
// upcasting to float32 if needed to improve precision when multiplying by the scale factor
|
||||||
if ([maskedMM dataType] != MPSDataTypeFloat32) {
|
if ([maskedMM dataType] != MPSDataTypeFloat32) {
|
||||||
maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
|
maskedMM = [mpsGraph castTensor:maskedMM toType:MPSDataTypeFloat32 name:nil];
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user