[MPS] Error out when BatchNorm is called for Complex (#166215)

Or BatchNorm or LayerNorm for Long types

Discovered while trying to enable `test_ops.py` for MPS
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166215
Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #166214, #166687
This commit is contained in:
Nikita Shulga 2025-10-31 14:07:24 -07:00 committed by PyTorch MergeBot
parent d80ae738c9
commit 9261a1fb12

View File

@ -84,6 +84,9 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out(const Tensor& self,
Tensor& output,
Tensor& save_mean,
Tensor& save_var) {
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Long batch norm is not supported with MPS");
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
"Batch norm for complex is not supported for MPS");
using namespace at::native::mps;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
@ -918,6 +921,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int axis = input_ndim - normalized_ndim;
MPSStream* stream = getCurrentMPSStream();
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS");
@autoreleasepool {
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
// which kernel variant to use based on the normalized axis N size