[MPS] Add backward pass for embedding_bag (#163931)

Fixes #162270
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163931
Approved by: https://github.com/malfet
Kurt Mohler 2025-10-02 13:03:47 -05:00 committed by PyTorch MergeBot
parent 86474ce996
commit ef50c6e3e3
7 changed files with 357 additions and 81 deletions


@@ -14,6 +14,7 @@ struct EmbeddingBagParams {
::c10::metal::array<idx_type_t, 2> output_strides;
::c10::metal::array<idx_type_t, 2> max_indices_strides;
bool use_per_sample_weights;
idx_type_t per_sample_weights_stride;
idx_type_t num_indices;
@@ -23,3 +24,24 @@ struct EmbeddingBagParams {
EmbeddingBagMode mode;
int64_t padding_idx;
};
template <typename idx_type_t = uint32_t>
struct EmbeddingBagBackwardParams {
::c10::metal::array<idx_type_t, 2> weight_grad_strides;
::c10::metal::array<idx_type_t, 2> output_grad_strides;
::c10::metal::array<idx_type_t, 2> max_indices_strides;
bool use_per_sample_weights;
idx_type_t per_sample_weights_stride;
idx_type_t feature_size;
EmbeddingBagMode mode;
int64_t padding_idx;
};
template <typename idx_type_t = uint32_t>
struct EmbeddingBagPerSampleWeightsBackwardParams {
::c10::metal::array<idx_type_t, 2> output_grad_strides;
::c10::metal::array<idx_type_t, 2> weight_strides;
idx_type_t per_sample_weights_grad_stride;
idx_type_t feature_size;
int64_t padding_idx;
};
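These backward parameter structs mirror the forward-pass EmbeddingBagParams: the host side fills in strides and sizes, and the Metal kernels use them to index raw buffers. A quick end-to-end way to exercise the new backward path (a minimal PyTorch sketch, assuming an MPS-capable build):

import torch

weight = torch.randn(10, 7, requires_grad=True)
weight_mps = weight.detach().to("mps").requires_grad_()
indices = torch.randint(0, 10, (12,))
offsets = torch.tensor([0, 4, 8])

# Forward + backward on CPU and MPS; the MPS backward now dispatches to the
# new Metal kernels (see #162270).
for w, dev in ((weight, "cpu"), (weight_mps, "mps")):
    out = torch.nn.functional.embedding_bag(
        indices.to(dev), w, offsets.to(dev), mode="mean")
    out.sum().backward()

torch.testing.assert_close(weight.grad, weight_mps.grad.cpu())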


@@ -1,4 +1,5 @@
#include <ATen/native/mps/kernels/EmbeddingBag.h>
#include <c10/metal/atomic.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>
@@ -44,6 +45,7 @@ template <EmbeddingBagMode M, typename T>
struct MaybeApplyPerSampleWeight {
inline opmath_t<T> operator()(
opmath_t<T> weight_val,
bool /*use_per_sample_weights*/,
uint32_t /*per_sample_weights_index*/,
constant T* /*per_sample_weights*/,
uint32_t /*per_sample_weights_stride*/) {
@@ -55,10 +57,11 @@ template <typename T>
struct MaybeApplyPerSampleWeight<EmbeddingBagMode::SUM, T> {
inline opmath_t<T> operator()(
opmath_t<T> weight_val,
bool use_per_sample_weights,
uint32_t per_sample_weights_index,
constant T* per_sample_weights,
uint32_t per_sample_weights_stride) {
- if (per_sample_weights_stride) {
+ if (use_per_sample_weights) {
T per_sample_weight = per_sample_weights
[per_sample_weights_stride * per_sample_weights_index];
return static_cast<opmath_t<T>>(per_sample_weight) * weight_val;
@@ -154,6 +157,7 @@ void embedding_bag_impl(
auto num_bags = params.num_bags;
auto feature_size = params.feature_size;
auto padding_idx = params.padding_idx;
auto use_per_sample_weights = params.use_per_sample_weights;
auto per_sample_weights_stride = params.per_sample_weights_stride;
constant auto& output_strides = params.output_strides;
constant auto& weight_strides = params.weight_strides;
@@ -183,7 +187,11 @@
feature_idx * weight_strides[1]]);
weight_val = MaybeApplyPerSampleWeight<M, T>()(
- weight_val, indices_idx, per_sample_weights, per_sample_weights_stride);
+ weight_val,
+ use_per_sample_weights,
+ indices_idx,
+ per_sample_weights,
+ per_sample_weights_stride);
auto new_out_val = ReductionOp<M, T>()(weight_val, out_val, bag_size_ == 0);
@@ -239,19 +247,208 @@ kernel void embedding_bag(
}
}
- #define REGISTER_EMBEDDING_BAG_OP(T, I) \
- template [[host_name("embedding_bag_" #T "_" #I)]] \
- kernel void embedding_bag<T, I>( \
- constant T * weight [[buffer(0)]], \
- constant I * indices [[buffer(1)]], \
- constant I * offsets [[buffer(2)]], \
- constant T * per_sample_weights [[buffer(3)]], \
- device T * output [[buffer(4)]], \
- device I * offset2bag [[buffer(5)]], \
- device I * bag_size [[buffer(6)]], \
- device I * max_indices [[buffer(7)]], \
- constant EmbeddingBagParams<uint32_t> & params [[buffer(8)]], \
- uint tid [[thread_position_in_grid]]);
template <EmbeddingBagMode M, typename T>
struct MaybeDivBagSize {
inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
return val;
}
};
template <typename T>
struct MaybeDivBagSize<EmbeddingBagMode::MEAN, T> {
inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
return val / bag_size;
}
};
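MaybeDivBagSize folds the reduction mode into the backward scale factor: MEAN splits each bag's incoming gradient evenly across the bag's members, while other modes pass it through. The same dispatch in scalar form (a Python sketch, not the Metal code):

def maybe_div_bag_size(mode: str, val: float, bag_size: float) -> float:
    # MEAN distributes d(loss)/d(bag output) evenly over the bag's members;
    # SUM leaves the gradient untouched.
    return val / bag_size if mode == "mean" else val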
template <EmbeddingBagMode M, typename T, typename I>
void embedding_bag_backward_sum_mean_impl(
constant T* output_grad,
constant I* indices,
constant I* offset2bag,
constant I* bag_size,
constant T* per_sample_weights,
device AtomicType_t<T>* weight_grad,
constant EmbeddingBagBackwardParams<uint32_t>& params,
uint tid) {
auto feature_size = params.feature_size;
auto indices_idx = tid / feature_size;
auto bag_idx = static_cast<uint32_t>(offset2bag[indices_idx]);
auto bag_size_val = bag_size[bag_idx];
auto weight_idx = indices[indices_idx];
auto padding_idx = params.padding_idx;
if (bag_size_val && weight_idx != padding_idx) {
auto feature_idx = tid % feature_size;
constant auto& weight_grad_strides = params.weight_grad_strides;
constant auto& output_grad_strides = params.output_grad_strides;
auto use_per_sample_weights = params.use_per_sample_weights;
auto per_sample_weights_stride = params.per_sample_weights_stride;
auto output_grad_val =
static_cast<opmath_t<T>>(output_grad
[bag_idx * output_grad_strides[0] +
feature_idx * output_grad_strides[1]]);
opmath_t<T> weight_grad_val = MaybeDivBagSize<M, T>()(
MaybeApplyPerSampleWeight<M, T>()(
output_grad_val,
use_per_sample_weights,
indices_idx,
per_sample_weights,
per_sample_weights_stride),
static_cast<opmath_t<T>>(bag_size_val));
AtomicType<T>::atomic_add(
weight_grad,
static_cast<int32_t>(weight_idx) * weight_grad_strides[0] +
feature_idx * weight_grad_strides[1],
static_cast<T>(weight_grad_val));
}
}
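The sum/mean backward launches one thread per (index, feature) pair: each thread reads its bag's output gradient, applies the mean scaling and any per-sample weight, and scatters the result into weight_grad with an atomic add, skipping empty bags and padding_idx. A vectorized CPU reference of the same computation (a sketch; index_add_ stands in for the atomic scatter):

import torch

def weight_grad_sum_mean(output_grad, indices, offset2bag, bag_size,
                         per_sample_weights, num_weights, mode, padding_idx):
    grad = output_grad[offset2bag].clone()       # one gradient row per index
    if mode == "mean":
        # split each bag's gradient evenly across its members
        grad /= bag_size[offset2bag].clamp(min=1).unsqueeze(1).to(grad.dtype)
    if per_sample_weights is not None:           # only allowed in SUM mode
        grad *= per_sample_weights.unsqueeze(1).to(grad.dtype)
    grad[indices == padding_idx] = 0             # padding rows get no gradient
    weight_grad = torch.zeros(num_weights, output_grad.size(1),
                              dtype=output_grad.dtype)
    weight_grad.index_add_(0, indices.long(), grad)
    return weight_grad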
template <typename T, typename I>
void embedding_bag_backward_max_impl(
constant T* output_grad,
constant I* bag_size,
constant I* max_indices,
device AtomicType_t<T>* weight_grad,
constant EmbeddingBagBackwardParams<uint32_t>& params,
uint tid) {
auto feature_size = params.feature_size;
auto bag_idx = tid / feature_size;
auto bag_size_val = bag_size[bag_idx];
if (bag_size_val) {
auto feature_idx = tid % feature_size;
constant auto& weight_grad_strides = params.weight_grad_strides;
constant auto& output_grad_strides = params.output_grad_strides;
constant auto& max_indices_strides = params.max_indices_strides;
auto output_grad_val = output_grad
[bag_idx * output_grad_strides[0] +
feature_idx * output_grad_strides[1]];
auto max_index =
static_cast<uint32_t>(max_indices
[bag_idx * max_indices_strides[0] +
feature_idx * max_indices_strides[1]]);
AtomicType<T>::atomic_add(
weight_grad,
max_index * weight_grad_strides[0] +
feature_idx * weight_grad_strides[1],
output_grad_val);
}
}
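MAX mode instead launches one thread per output_grad element: each (bag, feature) gradient is routed to the single weight row that won the forward max, as recorded in max_indices. A loop-form CPU reference (a sketch):

import torch

def weight_grad_max(output_grad, bag_size, max_indices, num_weights):
    weight_grad = torch.zeros(num_weights, output_grad.size(1),
                              dtype=output_grad.dtype)
    num_bags, feature_size = output_grad.shape
    for b in range(num_bags):
        if bag_size[b] == 0:      # empty bags produce no gradient
            continue
        for f in range(feature_size):
            # send the gradient to the argmax row found in the forward pass
            weight_grad[max_indices[b, f], f] += output_grad[b, f]
    return weight_grad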
#define DISPATCH_BACKWARD_SUM_MEAN_IMPL(MODE) \
return embedding_bag_backward_sum_mean_impl<MODE>( \
output_grad, \
indices, \
offset2bag, \
bag_size, \
per_sample_weights, \
weight_grad, \
params, \
tid)
template <typename T, typename I>
kernel void embedding_bag_backward(
constant T* output_grad [[buffer(0)]],
constant I* indices [[buffer(1)]],
constant I* offset2bag [[buffer(2)]],
constant I* bag_size [[buffer(3)]],
constant I* max_indices [[buffer(4)]],
constant T* per_sample_weights [[buffer(5)]],
device AtomicType_t<T>* weight_grad [[buffer(6)]],
constant EmbeddingBagBackwardParams<uint32_t>& params [[buffer(7)]],
uint tid [[thread_position_in_grid]]) {
switch (params.mode) {
case EmbeddingBagMode::SUM:
DISPATCH_BACKWARD_SUM_MEAN_IMPL(EmbeddingBagMode::SUM);
case EmbeddingBagMode::MEAN:
DISPATCH_BACKWARD_SUM_MEAN_IMPL(EmbeddingBagMode::MEAN);
case EmbeddingBagMode::MAX:
return embedding_bag_backward_max_impl(
output_grad, bag_size, max_indices, weight_grad, params, tid);
}
}
template <typename T, typename I>
kernel void embedding_bag_per_sample_weights_backward(
constant T* output_grad [[buffer(0)]],
constant T* weight [[buffer(1)]],
constant I* indices [[buffer(2)]],
constant I* offset2bag [[buffer(3)]],
device AtomicType_t<T>* per_sample_weights_grad [[buffer(4)]],
constant EmbeddingBagPerSampleWeightsBackwardParams<uint32_t>& params
[[buffer(5)]],
uint tid [[thread_position_in_grid]]) {
auto feature_size = params.feature_size;
auto padding_idx = params.padding_idx;
auto indices_idx = tid / feature_size;
auto weight_idx = indices[indices_idx];
if (weight_idx != padding_idx) {
auto feature_idx = tid % feature_size;
auto bag_idx = static_cast<uint32_t>(offset2bag[indices_idx]);
constant auto& output_grad_strides = params.output_grad_strides;
constant auto& weight_strides = params.weight_strides;
auto per_sample_weights_grad_stride = params.per_sample_weights_grad_stride;
auto weight_val = weight
[static_cast<uint32_t>(weight_idx) * weight_strides[0] +
feature_idx * weight_strides[1]];
auto output_grad_val = output_grad
[bag_idx * output_grad_strides[0] +
feature_idx * output_grad_strides[1]];
auto per_sample_weights_grad_val = static_cast<opmath_t<T>>(weight_val) *
static_cast<opmath_t<T>>(output_grad_val);
AtomicType<T>::atomic_add(
per_sample_weights_grad,
indices_idx * per_sample_weights_grad_stride,
static_cast<T>(per_sample_weights_grad_val));
}
}
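Since the SUM-mode output is out[b] = sum_i psw[i] * weight[indices[i]] over the indices i in bag b, the gradient of each per-sample weight is the dot product of its weight row with its bag's output gradient; the kernel computes one product term per (index, feature) thread and accumulates atomically. A vectorized reference (a sketch):

import torch

def per_sample_weights_grad(output_grad, weight, indices, offset2bag,
                            padding_idx=-1):
    rows = weight[indices]                 # [num_indices, feature_size]
    grads = output_grad[offset2bag]        # each index's bag gradient
    psw_grad = (rows * grads).sum(dim=1)   # per-index dot product
    psw_grad[indices == padding_idx] = 0   # padding contributes nothing
    return psw_grad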
#define REGISTER_EMBEDDING_BAG_OP(T, I) \
template [[host_name("embedding_bag_" #T "_" #I)]] \
kernel void embedding_bag<T, I>( \
constant T * weight [[buffer(0)]], \
constant I * indices [[buffer(1)]], \
constant I * offsets [[buffer(2)]], \
constant T * per_sample_weights [[buffer(3)]], \
device T * output [[buffer(4)]], \
device I * offset2bag [[buffer(5)]], \
device I * bag_size [[buffer(6)]], \
device I * max_indices [[buffer(7)]], \
constant EmbeddingBagParams<uint32_t> & params [[buffer(8)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("embedding_bag_backward_" #T "_" #I)]] \
kernel void embedding_bag_backward<T, I>( \
constant T * output_grad [[buffer(0)]], \
constant I * indices [[buffer(1)]], \
constant I * offset2bag [[buffer(2)]], \
constant I * bag_size [[buffer(3)]], \
constant I * max_indices [[buffer(4)]], \
constant T * per_sample_weights [[buffer(5)]], \
device AtomicType_t<T> * weight_grad [[buffer(6)]], \
constant EmbeddingBagBackwardParams<uint32_t> & params [[buffer(7)]], \
uint tid [[thread_position_in_grid]]); \
\
template \
[[host_name("embedding_bag_per_sample_weights_backward_" #T "_" #I)]] \
kernel void embedding_bag_per_sample_weights_backward<T, I>( \
constant T * output_grad [[buffer(0)]], \
constant T * weight [[buffer(1)]], \
constant I * indices [[buffer(2)]], \
constant I * offset2bag [[buffer(3)]], \
device AtomicType_t<T> * per_sample_weights_grad [[buffer(4)]], \
constant EmbeddingBagPerSampleWeightsBackwardParams<uint32_t> & \
params [[buffer(5)]], \
uint tid [[thread_position_in_grid]]);
REGISTER_EMBEDDING_BAG_OP(float, int);
REGISTER_EMBEDDING_BAG_OP(float, long);
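The host_name attributes fix the exported kernel names, which EmbeddingBag.mm reconstructs with fmt::format at dispatch time; each REGISTER_EMBEDDING_BAG_OP line now instantiates all three kernels for one (scalar, index) pair. The names generated by the registrations shown here (a sketch; half/bfloat variants, if registered, follow the same pattern):

for t in ("float",):
    for i in ("int", "long"):
        for base in ("embedding_bag", "embedding_bag_backward",
                     "embedding_bag_per_sample_weights_backward"):
            print(f"{base}_{t}_{i}")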


@@ -13,8 +13,10 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_embedding_bag_dense_backward_native.h>
#include <ATen/ops/_embedding_bag_forward_only_native.h>
#include <ATen/ops/_embedding_bag_native.h>
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
#include <ATen/ops/empty.h>
#endif
@@ -95,6 +97,7 @@ static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
}
bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
params.use_per_sample_weights = use_per_sample_weights;
params.per_sample_weights_stride = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
params.num_indices = num_indices;
@@ -177,4 +180,117 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_forward_only_mps(
padding_idx);
}
Tensor _embedding_bag_dense_backward_mps(const Tensor& output_grad,
const Tensor& indices,
const Tensor& offset2bag,
const Tensor& bag_size,
const Tensor& max_indices,
int64_t num_weights,
bool scale_grad_by_freq,
int64_t mode,
const std::optional<Tensor>& per_sample_weights_opt,
int64_t padding_idx) {
// indices and offset2bag are assumed to have the correct dtypes and to be
// contiguous here, due to the checks in _embedding_bag_backward in
// EmbeddingBag.cpp.
// Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
// for more details.
int64_t feature_size = output_grad.size(1);
auto weight_grad = at::zeros({num_weights, feature_size}, output_grad.options());
EmbeddingBagBackwardParams<uint32_t> params;
for (const auto dim : c10::irange(2)) {
params.output_grad_strides[dim] = output_grad.stride(dim);
params.weight_grad_strides[dim] = weight_grad.stride(dim);
if (mode == EmbeddingBagMode::MAX) {
params.max_indices_strides[dim] = safe_downcast<uint32_t, int64_t>(max_indices.stride(dim));
}
}
bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
params.use_per_sample_weights = use_per_sample_weights;
params.per_sample_weights_stride = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
params.feature_size = output_grad.size(1);
params.mode = static_cast<EmbeddingBagMode>(mode);
params.padding_idx = padding_idx;
auto num_indices = offset2bag.numel();
auto num_threads = (params.mode == EmbeddingBagMode::MAX) ? output_grad.numel() : num_indices * params.feature_size;
MPSStream* stream = getCurrentMPSStream();
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_backward_{}_{}",
mps::scalarToMetalTypeString(output_grad),
mps::scalarToMetalTypeString(indices)));
getMPSProfiler().beginProfileKernel(
pipeline_state, "embedding_bag", {output_grad, indices, offset2bag, bag_size});
[computeEncoder setComputePipelineState:pipeline_state];
mps::mtl_setArgs(computeEncoder,
output_grad,
indices,
offset2bag,
bag_size,
max_indices,
use_per_sample_weights ? per_sample_weights_opt : std::nullopt,
weight_grad,
params);
mps::mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
getMPSProfiler().endProfileKernel(pipeline_state);
}
});
return std::move(weight_grad);
}
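Note the thread-count split above: SUM/MEAN need one thread per (index, feature) contribution, while MAX needs only one per output element, since each output element scatters to exactly one weight row. A worked example with small, hypothetical sizes:

num_indices, num_bags, feature_size = 40, 5, 7
threads_sum_mean = num_indices * feature_size   # 280 scatter threads
threads_max = num_bags * feature_size           # 35, i.e. output_grad.numel()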
Tensor _embedding_bag_per_sample_weights_backward_mps(const Tensor& output_grad,
const Tensor& weight,
const Tensor& indices,
const Tensor& offsets,
const Tensor& offset2bag,
int64_t mode,
int64_t padding_idx) {
TORCH_INTERNAL_ASSERT(static_cast<EmbeddingBagMode>(mode) == EmbeddingBagMode::SUM);
int64_t num_indices = indices.size(0);
int64_t feature_size = output_grad.size(1);
auto per_sample_weights_grad = at::zeros({num_indices}, output_grad.options());
EmbeddingBagPerSampleWeightsBackwardParams<uint32_t> params;
for (const auto dim : c10::irange(2)) {
params.output_grad_strides[dim] = output_grad.stride(dim);
params.weight_strides[dim] = weight.stride(dim);
}
params.per_sample_weights_grad_stride = per_sample_weights_grad.stride(0);
params.feature_size = feature_size;
params.padding_idx = padding_idx;
auto num_threads = num_indices * feature_size;
MPSStream* stream = getCurrentMPSStream();
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_per_sample_weights_backward_{}_{}",
mps::scalarToMetalTypeString(output_grad),
mps::scalarToMetalTypeString(indices)));
getMPSProfiler().beginProfileKernel(
pipeline_state, "embedding_bag_per_sample_weights_backward", {output_grad, weight, indices, offset2bag});
[computeEncoder setComputePipelineState:pipeline_state];
mps::mtl_setArgs(computeEncoder, output_grad, weight, indices, offset2bag, per_sample_weights_grad, params);
mps::mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
getMPSProfiler().endProfileKernel(pipeline_state);
}
});
return std::move(per_sample_weights_grad);
}
} // namespace at::native
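The per-sample-weights gradient only exists for SUM mode (hence the TORCH_INTERNAL_ASSERT above). A CPU-vs-MPS sanity check of this path (a minimal sketch, assuming an MPS-capable build):

import torch

indices = torch.randint(0, 10, (12,))
offsets = torch.tensor([0, 4, 8])
weight = torch.randn(10, 7)
psw = torch.randn(12, requires_grad=True)
psw_mps = psw.detach().to("mps").requires_grad_()

for p, dev in ((psw, "cpu"), (psw_mps, "mps")):
    out = torch.nn.functional.embedding_bag(
        indices.to(dev), weight.to(dev), offsets.to(dev),
        mode="sum", per_sample_weights=p)
    out.sum().backward()

torch.testing.assert_close(psw.grad, psw_mps.grad.cpu())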


@@ -2379,7 +2379,7 @@
- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
dispatch:
- CPU, CUDA: _embedding_bag_backward_symint
+ CPU, CUDA, MPS: _embedding_bag_backward_symint
- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
dispatch:
@@ -2389,12 +2389,14 @@
dispatch:
CPU: _embedding_bag_dense_backward_cpu
CUDA: _embedding_bag_dense_backward_cuda
MPS: _embedding_bag_dense_backward_mps
autogen: _embedding_bag_dense_backward.out
- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor
dispatch:
CPU: _embedding_bag_per_sample_weights_backward_cpu
CUDA: _embedding_bag_per_sample_weights_backward_cuda
MPS: _embedding_bag_per_sample_weights_backward_mps
autogen: _embedding_bag_per_sample_weights_backward.out
- func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor


@@ -6940,70 +6940,6 @@ class TestMPS(TestCaseMPS):
with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
helper(22, 0, [])
- # TODO: This test can be removed once the backward pass of embedding_bag is
- # implemented and tested
- @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
- @parametrize("idx_dtype", [torch.long, torch.int])
- @parametrize("padding_idx", [-1, 1])
- @parametrize("include_last_offset", [True, False])
- @parametrize("mode", ['sum', 'mean', 'max'])
- def test__embedding_bag(self, dtype, idx_dtype, padding_idx, include_last_offset, mode):
- import time
- torch.manual_seed(time.time() * 1000)
- mode_num = {'sum': 0, 'mean': 1, 'max': 2}[mode]
- num_words = 10
- feature_size = 7
- num_indices = 40
- num_bags = 5
- weight_cpu = torch.randn(num_words, feature_size, dtype=dtype)
- # Test nan value behavior.
- # Set second element of each word to nan.
- weight_cpu[:, 1] = float('nan')
- # Set third element of a randomized half of the words to nan.
- weight_cpu[torch.randperm(num_words)[:num_words // 2], 2] = float('nan')
- # Set fourth element of one randomized word to nan.
- weight_cpu[torch.randint(0, num_words, ()), 3] = float('nan')
- input_cpu = torch.randint(0, num_words, (num_indices,), dtype=idx_dtype)
- offsets_cpu = torch.tensor(
- [0] + (torch.randperm(num_indices - 1)[:num_bags - 1].sort()[0] + 1).tolist(),
- dtype=idx_dtype)
- if include_last_offset:
- offsets_cpu[-1] = input_cpu.numel()
- per_sample_weights_cpu = torch.randn(num_indices, dtype=dtype) if mode == 'sum' else None
- r_cpu, offset2bag_cpu, bag_size_cpu, max_indices_cpu = torch._embedding_bag(
- weight_cpu,
- input_cpu,
- offsets_cpu,
- per_sample_weights=per_sample_weights_cpu,
- mode=mode_num,
- padding_idx=padding_idx,
- include_last_offset=include_last_offset,
- )
- r_mps, offset2bag_mps, bag_size_mps, max_indices_mps = torch._embedding_bag(
- weight_cpu.to('mps'),
- input_cpu.to('mps'),
- offsets_cpu.to('mps'),
- per_sample_weights=per_sample_weights_cpu.to('mps') if per_sample_weights_cpu is not None else None,
- mode=mode_num,
- padding_idx=padding_idx,
- include_last_offset=include_last_offset,
- )
- self.assertEqual(r_cpu, r_mps)
- if mode != 'sum':
- self.assertEqual(offset2bag_cpu, offset2bag_mps)
- self.assertEqual(bag_size_cpu, bag_size_mps)
- if mode == 'max':
- self.assertEqual(max_indices_cpu, max_indices_mps)
def test_embedding_dense_backward(self):
def helper(n, d, m, idx):
embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
@@ -12530,6 +12466,8 @@ class TestConsistency(TestCaseMPS):
# several output grad elements of similar magnitudes get summed
# together, introducing significant error for float16.
atol, rtol = 5e-3, 5e-3
if op.name == "nn.functional.embedding_bag" and dtype == torch.float16:
atol, rtol = 5e-3, 5e-3
self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
# The CPU impl of grid_sampler_3d gives a large amount of error for half


@@ -16,7 +16,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d_backward(At
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);


@@ -737,7 +737,6 @@ if torch.backends.mps.is_available():
"equal": [torch.float16, torch.float32],
# 'float' object is not iterable
"item": [torch.float16, torch.float32],
"nn.functional.embedding_bag": None,
# "smooth_l1_backward_cpu_out" not implemented for 'Half'
"nn.functional.smooth_l1_loss": [torch.float16],
# cpu error: grad requires non-empty inputs