diff --git a/aten/src/ATen/native/mps/kernels/EmbeddingBag.h b/aten/src/ATen/native/mps/kernels/EmbeddingBag.h
index 47bec81bc11..6b2e702d377 100644
--- a/aten/src/ATen/native/mps/kernels/EmbeddingBag.h
+++ b/aten/src/ATen/native/mps/kernels/EmbeddingBag.h
@@ -14,7 +14,7 @@ struct EmbeddingBagParams {
   ::c10::metal::array<idx_type_t, 2> output_strides;
   ::c10::metal::array<idx_type_t, 2> max_indices_strides;

-  idx_type_t per_sample_weights_strides;
+  idx_type_t per_sample_weights_stride;

   idx_type_t num_indices;
   idx_type_t num_bags;
diff --git a/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal b/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal
index 861a093d41a..28f6aa09897 100644
--- a/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal
+++ b/aten/src/ATen/native/mps/kernels/EmbeddingBag.metal
@@ -23,54 +23,72 @@ struct ReductionOpInit {
 template <EmbeddingBagMode M, typename T>
 struct ReductionOp {
   inline opmath_t<T> operator()(
-      T weight_val,
+      opmath_t<T> weight_val,
       opmath_t<T> out_val,
-      uint32_t per_sample_weights_index,
-      constant T* per_sample_weights,
-      uint32_t per_sample_weights_strides);
-};
-
-template <typename T>
-struct ReductionOp<EmbeddingBagMode::SUM, T> {
-  inline opmath_t<T> operator()(
-      T weight_val,
-      opmath_t<T> out_val,
-      uint32_t per_sample_weights_index,
-      constant T* per_sample_weights,
-      uint32_t per_sample_weights_strides) {
-    if (per_sample_weights_strides) {
-      T per_sample_weight = per_sample_weights
-          [per_sample_weights_strides * per_sample_weights_index];
-      return static_cast<opmath_t<T>>(per_sample_weight) *
-          static_cast<opmath_t<T>>(weight_val) +
-          out_val;
-    } else {
-      return static_cast<opmath_t<T>>(weight_val) + out_val;
-    }
-  }
-};
-
-template <typename T>
-struct ReductionOp<EmbeddingBagMode::MEAN, T> {
-  inline opmath_t<T> operator()(
-      T weight_val,
-      opmath_t<T> out_val,
-      uint32_t,
-      constant T*,
-      uint32_t) {
-    return static_cast<opmath_t<T>>(weight_val) + out_val;
+      bool is_first) {
+    return weight_val + out_val;
   }
 };

 template <typename T>
 struct ReductionOp<EmbeddingBagMode::MAX, T> {
   inline opmath_t<T> operator()(
-      T weight_val,
+      opmath_t<T> weight_val,
       opmath_t<T> out_val,
-      uint32_t,
-      constant T*,
-      uint32_t) {
-    return max(static_cast<opmath_t<T>>(weight_val), out_val);
+      bool is_first) {
+    return (is_first || weight_val > out_val) ? weight_val : out_val;
+  }
+};
+
+template <EmbeddingBagMode M, typename T>
+struct MaybeApplyPerSampleWeight {
+  inline opmath_t<T> operator()(
+      opmath_t<T> weight_val,
+      uint32_t per_sample_weights_index,
+      constant T* per_sample_weights,
+      uint32_t per_sample_weights_stride) {
+    return weight_val;
+  }
+};
+
+template <typename T>
+struct MaybeApplyPerSampleWeight<EmbeddingBagMode::SUM, T> {
+  inline opmath_t<T> operator()(
+      opmath_t<T> weight_val,
+      uint32_t per_sample_weights_index,
+      constant T* per_sample_weights,
+      uint32_t per_sample_weights_stride) {
+    if (per_sample_weights_stride) {
+      T per_sample_weight = per_sample_weights
+          [per_sample_weights_stride * per_sample_weights_index];
+      return static_cast<opmath_t<T>>(per_sample_weight) * weight_val;
+    } else {
+      return weight_val;
+    }
+  }
+};
+
+template <EmbeddingBagMode M, typename T, typename I>
+struct MaybeCalcMaxIndex {
+  inline void operator()(
+      opmath_t<T> weight_val,
+      opmath_t<T> out_val,
+      bool is_first,
+      thread I& max_idx,
+      I weight_idx,
+      bool pad) {}
+};
+
+template <typename T, typename I>
+struct MaybeCalcMaxIndex<EmbeddingBagMode::MAX, T, I> {
+  inline void operator()(
+      opmath_t<T> weight_val,
+      opmath_t<T> out_val,
+      bool is_first,
+      thread I& max_idx,
+      I weight_idx,
+      bool pad) {
+    max_idx = !pad && (is_first || weight_val > out_val) ?
+        weight_idx : max_idx;
   }
 };
@@ -96,6 +114,30 @@ struct ReductionOpFinal {
   }
 };

+template <EmbeddingBagMode M, typename I>
+struct MaybeWriteMaxIndex {
+  inline void operator()(
+      device I*,
+      const constant ::c10::metal::array<uint32_t, 2>&,
+      uint32_t,
+      uint32_t,
+      I) {}
+};
+
+template <typename I>
+struct MaybeWriteMaxIndex<EmbeddingBagMode::MAX, I> {
+  inline void operator()(
+      device I* max_indices,
+      const constant ::c10::metal::array<uint32_t, 2>& max_indices_strides,
+      uint32_t bag_idx,
+      uint32_t feature_idx,
+      I max_idx) {
+    max_indices
+        [bag_idx * max_indices_strides[0] +
+         feature_idx * max_indices_strides[1]] = max_idx;
+  }
+};
+
 template <EmbeddingBagMode M, typename T, typename I>
 void embedding_bag_impl(
     constant T* weight,
@@ -112,7 +154,7 @@ void embedding_bag_impl(
   auto num_bags = params.num_bags;
   auto feature_size = params.feature_size;
   auto padding_idx = params.padding_idx;
-  auto per_sample_weights_strides = params.per_sample_weights_strides;
+  auto per_sample_weights_stride = params.per_sample_weights_stride;
   constant auto& output_strides = params.output_strides;
   constant auto& weight_strides = params.weight_strides;
   constant auto& max_indices_strides = params.max_indices_strides;
@@ -120,8 +162,6 @@ void embedding_bag_impl(
   auto bag_idx = tid / feature_size;
   auto feature_idx = tid % feature_size;

-  output += bag_idx * output_strides[0] + feature_idx * output_strides[1];
-
   uint32_t offsets_end = min(bag_idx + 1, num_bags - 1);
   bool is_last_bag = bag_idx + 1 == num_bags;
   uint32_t indices_start = static_cast<uint32_t>(offsets[bag_idx]);
@@ -131,28 +171,37 @@ void embedding_bag_impl(
   auto out_val = ReductionOpInit<M, T>()();
   uint32_t bag_size_ = 0;
+  I max_idx = 0;

   for (uint32_t indices_idx = indices_start; indices_idx < indices_end;
        indices_idx++) {
     I weight_idx = indices[indices_idx];
     bool pad = (weight_idx == padding_idx);
-    T weight_val = weight
-        [static_cast<uint32_t>(weight_idx) * weight_strides[0] +
-         feature_idx * weight_strides[1]];
+    auto weight_val = static_cast<opmath_t<T>>(
+        weight
+            [static_cast<uint32_t>(weight_idx) * weight_strides[0] +
+             feature_idx * weight_strides[1]]);
+
+    weight_val = MaybeApplyPerSampleWeight<M, T>()(
+        weight_val, indices_idx, per_sample_weights, per_sample_weights_stride);
+
+    auto new_out_val = ReductionOp<M, T>()(weight_val, out_val, bag_size_ == 0);
+
+    MaybeCalcMaxIndex<M, T, I>()(
+        weight_val, out_val, bag_size_ == 0, max_idx, weight_idx, pad);
+
+    out_val = pad ? out_val : new_out_val;

     offset2bag[indices_idx] = bag_idx;
     bag_size_ += static_cast<uint32_t>(!pad);
-
-    auto tmp_val = ReductionOp<M, T>()(
-        weight_val,
-        out_val,
-        indices_idx,
-        per_sample_weights,
-        per_sample_weights_strides);
-
-    out_val = pad ? out_val : tmp_val;
   }

-  *output = ReductionOpFinal<M, T>()(out_val, bag_size_);
+  output[bag_idx * output_strides[0] + feature_idx * output_strides[1]] =
+      ReductionOpFinal<M, T>()(out_val, bag_size_);
+
+  bag_size[bag_idx] = bag_size_;
+
+  MaybeWriteMaxIndex<M, I>()(
+      max_indices, max_indices_strides, bag_idx, feature_idx, max_idx);
 }

 #define DISPATCH_IMPL(MODE) \
diff --git a/aten/src/ATen/native/mps/operations/EmbeddingBag.mm b/aten/src/ATen/native/mps/operations/EmbeddingBag.mm
index d593fe2190d..b936500886e 100644
--- a/aten/src/ATen/native/mps/operations/EmbeddingBag.mm
+++ b/aten/src/ATen/native/mps/operations/EmbeddingBag.mm
@@ -66,11 +66,12 @@ static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
   int64_t num_indices = indices.size(0);
   int64_t num_bags = offsets.size(0);
   if (include_last_offset) {
+    TORCH_CHECK(num_bags >= 1, "include_last_offset: number of offsets should be at least 1");
     num_bags -= 1;
   }
   int64_t feature_size = weight.size(1);

-  auto bag_size = at::empty(offsets.sizes(), indices.options());
+  auto bag_size = at::empty({num_bags}, indices.options());
   auto offset2bag = at::empty({indices.size(0)}, indices.options());
   auto output = at::empty({num_bags, feature_size}, weight.options());

@@ -94,7 +95,7 @@ static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
   }

   bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
-  params.per_sample_weights_strides = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
+  params.per_sample_weights_stride = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;

   params.num_indices = num_indices;
   params.num_bags = num_bags;
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 6cc31b4ab23..bfee21b54c7 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -7274,8 +7274,6 @@ GPU_TEST_FAILURES = {
 }

 MPS_TEST_FAILURES = {
-    # aten::_embedding_bag backward is not currently implemented for the MPS device.
-    "test_embedding_bag": fail_mps(),
     # aten::_scaled_dot_product_efficient_attention is not currently implemented for the MPS device.
     "test_scaled_dot_product_efficient_attention": fail_mps(),
     # aten::_int_mm is not implemented for MPS backend
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 6c89918bfce..2b3ff993cd4 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5675,7 +5675,6 @@ class CommonTemplate:
             (torch.randn([2, 4, 4, 8]),),
         )

-    @xfail_if_mps_unimplemented
     def test_embedding_bag(self):
         def fn(w, i, o):
             return aten._embedding_bag(w, i, o, False, 0, False, None)
diff --git a/test/test_mps.py b/test/test_mps.py
index 03cc4fe8b21..1a8f7af83e3 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6940,6 +6940,70 @@ class TestMPS(TestCaseMPS):
         with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
             helper(22, 0, [])

+    # TODO: This test can be removed once the backward pass of embedding_bag is
+    # implemented and tested
+    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+    @parametrize("idx_dtype", [torch.long, torch.int])
+    @parametrize("padding_idx", [-1, 1])
+    @parametrize("include_last_offset", [True, False])
+    @parametrize("mode", ['sum', 'mean', 'max'])
+    def test__embedding_bag(self, dtype, idx_dtype, padding_idx, include_last_offset, mode):
+        import time
+        torch.manual_seed(time.time() * 1000)
+        mode_num = {'sum': 0, 'mean': 1, 'max': 2}[mode]
+        num_words = 10
+        feature_size = 7
+        num_indices = 40
+        num_bags = 5
+
+        weight_cpu = torch.randn(num_words, feature_size, dtype=dtype)
+
+        # Test nan value behavior.
+        # Set second element of each word to nan.
+        weight_cpu[:, 1] = float('nan')
+        # Set third element of a randomized half of the words to nan.
+        weight_cpu[torch.randperm(num_words)[:num_words // 2], 2] = float('nan')
+        # Set fourth element of one randomized word to nan.
+        weight_cpu[torch.randint(0, num_words, ()), 3] = float('nan')
+
+        input_cpu = torch.randint(0, num_words, (num_indices,), dtype=idx_dtype)
+        offsets_cpu = torch.tensor(
+            [0] + (torch.randperm(num_indices - 1)[:num_bags - 1].sort()[0] + 1).tolist(),
+            dtype=idx_dtype)
+
+        if include_last_offset:
+            offsets_cpu[-1] = input_cpu.numel()
+
+        per_sample_weights_cpu = torch.randn(num_indices, dtype=dtype) if mode == 'sum' else None
+
+        r_cpu, offset2bag_cpu, bag_size_cpu, max_indices_cpu = torch._embedding_bag(
+            weight_cpu,
+            input_cpu,
+            offsets_cpu,
+            per_sample_weights=per_sample_weights_cpu,
+            mode=mode_num,
+            padding_idx=padding_idx,
+            include_last_offset=include_last_offset,
+        )
+        r_mps, offset2bag_mps, bag_size_mps, max_indices_mps = torch._embedding_bag(
+            weight_cpu.to('mps'),
+            input_cpu.to('mps'),
+            offsets_cpu.to('mps'),
+            per_sample_weights=per_sample_weights_cpu.to('mps') if per_sample_weights_cpu is not None else None,
+            mode=mode_num,
+            padding_idx=padding_idx,
+            include_last_offset=include_last_offset,
+        )
+
+        self.assertEqual(r_cpu, r_mps)
+
+        if mode != 'sum':
+            self.assertEqual(offset2bag_cpu, offset2bag_mps)
+        self.assertEqual(bag_size_cpu, bag_size_mps)
+
+        if mode == 'max':
+            self.assertEqual(max_indices_cpu, max_indices_mps)
+
     def test_embedding_dense_backward(self):
         def helper(n, d, m, idx):
             embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')