mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] Error out when BatchNorm is called for Complex (#166215)
Or BatchNorm or LayerNorm for Long types Discovered while trying to enable `test_ops.py` for MPS Pull Request resolved: https://github.com/pytorch/pytorch/pull/166215 Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007 ghstack dependencies: #166214, #166687
This commit is contained in:
parent
d80ae738c9
commit
9261a1fb12
|
|
@ -84,6 +84,9 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out(const Tensor& self,
|
|||
Tensor& output,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Long batch norm is not supported with MPS");
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()),
|
||||
"Batch norm for complex is not supported for MPS");
|
||||
using namespace at::native::mps;
|
||||
struct CachedGraph : public MPSCachedGraph {
|
||||
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
|
||||
|
|
@ -918,6 +921,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
|
|||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||
const int axis = input_ndim - normalized_ndim;
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS");
|
||||
@autoreleasepool {
|
||||
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
// which kernel variant to use based on the normalized axis N size
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user