[MPS] Add backward pass for embedding_bag (#163931)
Fixes #162270
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163931
Approved by: https://github.com/malfet
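At the user level, this commit makes backward through embedding_bag work on the MPS device. A minimal sketch (not part of the diff; shapes and values are arbitrary, and it assumes a PyTorch build on Apple silicon that contains this change):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 7, device="mps", requires_grad=True)
indices = torch.randint(0, 10, (40,), device="mps")
offsets = torch.tensor([0, 5, 12, 20, 33], device="mps")

out = F.embedding_bag(indices, weight, offsets, mode="mean")
out.sum().backward()      # previously unsupported on MPS (#162270)
print(weight.grad.shape)  # torch.Size([10, 7])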
parent: 86474ce996
commit: ef50c6e3e3
@@ -14,6 +14,7 @@ struct EmbeddingBagParams {
   ::c10::metal::array<idx_type_t, 2> output_strides;
   ::c10::metal::array<idx_type_t, 2> max_indices_strides;
 
+  bool use_per_sample_weights;
   idx_type_t per_sample_weights_stride;
 
   idx_type_t num_indices;
@@ -23,3 +24,24 @@ struct EmbeddingBagParams {
   EmbeddingBagMode mode;
   int64_t padding_idx;
 };
+
+template <typename idx_type_t = uint32_t>
+struct EmbeddingBagBackwardParams {
+  ::c10::metal::array<idx_type_t, 2> weight_grad_strides;
+  ::c10::metal::array<idx_type_t, 2> output_grad_strides;
+  ::c10::metal::array<idx_type_t, 2> max_indices_strides;
+  bool use_per_sample_weights;
+  idx_type_t per_sample_weights_stride;
+  idx_type_t feature_size;
+  EmbeddingBagMode mode;
+  int64_t padding_idx;
+};
+
+template <typename idx_type_t = uint32_t>
+struct EmbeddingBagPerSampleWeightsBackwardParams {
+  ::c10::metal::array<idx_type_t, 2> output_grad_strides;
+  ::c10::metal::array<idx_type_t, 2> weight_strides;
+  idx_type_t per_sample_weights_grad_stride;
+  idx_type_t feature_size;
+  int64_t padding_idx;
+};
@@ -1,4 +1,5 @@
 #include <ATen/native/mps/kernels/EmbeddingBag.h>
+#include <c10/metal/atomic.h>
 #include <c10/metal/utils.h>
 #include <metal_array>
 #include <metal_stdlib>
@@ -44,6 +45,7 @@ template <EmbeddingBagMode M, typename T>
 struct MaybeApplyPerSampleWeight {
   inline opmath_t<T> operator()(
       opmath_t<T> weight_val,
+      bool /*use_per_sample_weights*/,
       uint32_t /*per_sample_weights_index*/,
       constant T* /*per_sample_weights*/,
       uint32_t /*per_sample_weights_stride*/) {
@@ -55,10 +57,11 @@ template <typename T>
 struct MaybeApplyPerSampleWeight<EmbeddingBagMode::SUM, T> {
   inline opmath_t<T> operator()(
       opmath_t<T> weight_val,
+      bool use_per_sample_weights,
       uint32_t per_sample_weights_index,
       constant T* per_sample_weights,
       uint32_t per_sample_weights_stride) {
-    if (per_sample_weights_stride) {
+    if (use_per_sample_weights) {
       T per_sample_weight = per_sample_weights
           [per_sample_weights_stride * per_sample_weights_index];
       return static_cast<opmath_t<T>>(per_sample_weight) * weight_val;
@@ -154,6 +157,7 @@ void embedding_bag_impl(
   auto num_bags = params.num_bags;
   auto feature_size = params.feature_size;
   auto padding_idx = params.padding_idx;
+  auto use_per_sample_weights = params.use_per_sample_weights;
   auto per_sample_weights_stride = params.per_sample_weights_stride;
   constant auto& output_strides = params.output_strides;
   constant auto& weight_strides = params.weight_strides;
@@ -183,7 +187,11 @@ void embedding_bag_impl(
           feature_idx * weight_strides[1]]);
 
   weight_val = MaybeApplyPerSampleWeight<M, T>()(
-      weight_val, indices_idx, per_sample_weights, per_sample_weights_stride);
+      weight_val,
+      use_per_sample_weights,
+      indices_idx,
+      per_sample_weights,
+      per_sample_weights_stride);
 
   auto new_out_val = ReductionOp<M, T>()(weight_val, out_val, bag_size_ == 0);
 
@@ -239,19 +247,208 @@ kernel void embedding_bag(
   }
 }
 
-#define REGISTER_EMBEDDING_BAG_OP(T, I)                              \
-  template [[host_name("embedding_bag_" #T "_" #I)]]                \
-  kernel void embedding_bag<T, I>(                                  \
-      constant T * weight [[buffer(0)]],                            \
-      constant I * indices [[buffer(1)]],                           \
-      constant I * offsets [[buffer(2)]],                           \
-      constant T * per_sample_weights [[buffer(3)]],                \
-      device T * output [[buffer(4)]],                              \
-      device I * offset2bag [[buffer(5)]],                          \
-      device I * bag_size [[buffer(6)]],                            \
-      device I * max_indices [[buffer(7)]],                         \
-      constant EmbeddingBagParams<uint32_t> & params [[buffer(8)]], \
-      uint tid [[thread_position_in_grid]]);
+template <EmbeddingBagMode M, typename T>
+struct MaybeDivBagSize {
+  inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
+    return val;
+  }
+};
+
+template <typename T>
+struct MaybeDivBagSize<EmbeddingBagMode::MEAN, T> {
+  inline opmath_t<T> operator()(opmath_t<T> val, opmath_t<T> bag_size) {
+    return val / bag_size;
+  }
+};
+
+template <EmbeddingBagMode M, typename T, typename I>
+void embedding_bag_backward_sum_mean_impl(
+    constant T* output_grad,
+    constant I* indices,
+    constant I* offset2bag,
+    constant I* bag_size,
+    constant T* per_sample_weights,
+    device AtomicType_t<T>* weight_grad,
+    constant EmbeddingBagBackwardParams<uint32_t>& params,
+    uint tid) {
+  auto feature_size = params.feature_size;
+  auto indices_idx = tid / feature_size;
+  auto bag_idx = static_cast<uint32_t>(offset2bag[indices_idx]);
+  auto bag_size_val = bag_size[bag_idx];
+  auto weight_idx = indices[indices_idx];
+  auto padding_idx = params.padding_idx;
+
+  if (bag_size_val && weight_idx != padding_idx) {
+    auto feature_idx = tid % feature_size;
+    constant auto& weight_grad_strides = params.weight_grad_strides;
+    constant auto& output_grad_strides = params.output_grad_strides;
+    auto use_per_sample_weights = params.use_per_sample_weights;
+    auto per_sample_weights_stride = params.per_sample_weights_stride;
+
+    auto output_grad_val =
+        static_cast<opmath_t<T>>(output_grad
+                                     [bag_idx * output_grad_strides[0] +
+                                      feature_idx * output_grad_strides[1]]);
+
+    opmath_t<T> weight_grad_val = MaybeDivBagSize<M, T>()(
+        MaybeApplyPerSampleWeight<M, T>()(
+            output_grad_val,
+            use_per_sample_weights,
+            indices_idx,
+            per_sample_weights,
+            per_sample_weights_stride),
+        static_cast<opmath_t<T>>(bag_size_val));
+
+    AtomicType<T>::atomic_add(
+        weight_grad,
+        static_cast<int32_t>(weight_idx) * weight_grad_strides[0] +
+            feature_idx * weight_grad_strides[1],
+        static_cast<T>(weight_grad_val));
+  }
+}
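In eager-PyTorch terms, the SUM/MEAN kernel above is a parallel version of the following reference (a hedged sketch; the function name and serial loop are illustrative, not the actual CPU implementation). Each Metal thread handles one (index, feature) element, so several threads can target the same weight row, which is why the accumulation goes through AtomicType<T>::atomic_add:

import torch

def dense_backward_sum_mean_reference(output_grad, indices, offset2bag, bag_size,
                                      per_sample_weights, num_weights, mode,
                                      padding_idx=-1):
    # mode: 0 = SUM, 1 = MEAN, mirroring EmbeddingBagMode.
    weight_grad = torch.zeros(num_weights, output_grad.size(1))
    for i, widx in enumerate(indices.tolist()):
        bag = int(offset2bag[i])
        # Same guard as the kernel: empty bags and the padding index get no grad.
        if bag_size[bag] == 0 or widx == padding_idx:
            continue
        g = output_grad[bag].clone()
        if mode == 0 and per_sample_weights is not None:
            g = g * per_sample_weights[i]  # MaybeApplyPerSampleWeight (SUM only)
        if mode == 1:
            g = g / bag_size[bag]          # MaybeDivBagSize (MEAN only)
        weight_grad[widx] += g             # atomic scatter-add in the Metal kernel
    return weight_grad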
+
+template <typename T, typename I>
+void embedding_bag_backward_max_impl(
+    constant T* output_grad,
+    constant I* bag_size,
+    constant I* max_indices,
+    device AtomicType_t<T>* weight_grad,
+    constant EmbeddingBagBackwardParams<uint32_t>& params,
+    uint tid) {
+  auto feature_size = params.feature_size;
+  auto bag_idx = tid / feature_size;
+  auto bag_size_val = bag_size[bag_idx];
+
+  if (bag_size_val) {
+    auto feature_idx = tid % feature_size;
+    constant auto& weight_grad_strides = params.weight_grad_strides;
+    constant auto& output_grad_strides = params.output_grad_strides;
+    constant auto& max_indices_strides = params.max_indices_strides;
+
+    auto output_grad_val = output_grad
+        [bag_idx * output_grad_strides[0] +
+         feature_idx * output_grad_strides[1]];
+    auto max_index =
+        static_cast<uint32_t>(max_indices
+                                  [bag_idx * max_indices_strides[0] +
+                                   feature_idx * max_indices_strides[1]]);
+
+    AtomicType<T>::atomic_add(
+        weight_grad,
+        max_index * weight_grad_strides[0] +
+            feature_idx * weight_grad_strides[1],
+        output_grad_val);
+  }
+}
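The MAX path above is simpler: the forward pass already recorded, per (bag, feature) output element, which row won the max in max_indices, so the gradient flows only to that row. Roughly (a hedged sketch with illustrative names):

import torch

def dense_backward_max_reference(output_grad, bag_size, max_indices, num_weights):
    weight_grad = torch.zeros(num_weights, output_grad.size(1))
    for bag in range(output_grad.size(0)):
        if bag_size[bag] == 0:  # empty bags contribute nothing
            continue
        for f in range(output_grad.size(1)):
            weight_grad[max_indices[bag, f], f] += output_grad[bag, f]
    return weight_grad

The atomic add is still needed on the GPU because different bags can resolve to the same winning weight row.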
+
+#define DISPATCH_BACKWARD_SUM_MEAN_IMPL(MODE)        \
+  return embedding_bag_backward_sum_mean_impl<MODE>( \
+      output_grad,                                   \
+      indices,                                       \
+      offset2bag,                                    \
+      bag_size,                                      \
+      per_sample_weights,                            \
+      weight_grad,                                   \
+      params,                                        \
+      tid)
+
+template <typename T, typename I>
+kernel void embedding_bag_backward(
+    constant T* output_grad [[buffer(0)]],
+    constant I* indices [[buffer(1)]],
+    constant I* offset2bag [[buffer(2)]],
+    constant I* bag_size [[buffer(3)]],
+    constant I* max_indices [[buffer(4)]],
+    constant T* per_sample_weights [[buffer(5)]],
+    device AtomicType_t<T>* weight_grad [[buffer(6)]],
+    constant EmbeddingBagBackwardParams<uint32_t>& params [[buffer(7)]],
+    uint tid [[thread_position_in_grid]]) {
+  switch (params.mode) {
+    case EmbeddingBagMode::SUM:
+      DISPATCH_BACKWARD_SUM_MEAN_IMPL(EmbeddingBagMode::SUM);
+    case EmbeddingBagMode::MEAN:
+      DISPATCH_BACKWARD_SUM_MEAN_IMPL(EmbeddingBagMode::MEAN);
+    case EmbeddingBagMode::MAX:
+      return embedding_bag_backward_max_impl(
+          output_grad, bag_size, max_indices, weight_grad, params, tid);
+  }
+}
+
+template <typename T, typename I>
+kernel void embedding_bag_per_sample_weights_backward(
+    constant T* output_grad [[buffer(0)]],
+    constant T* weight [[buffer(1)]],
+    constant I* indices [[buffer(2)]],
+    constant I* offset2bag [[buffer(3)]],
+    device AtomicType_t<T>* per_sample_weights_grad [[buffer(4)]],
+    constant EmbeddingBagPerSampleWeightsBackwardParams<uint32_t>& params
+        [[buffer(5)]],
+    uint tid [[thread_position_in_grid]]) {
+  auto feature_size = params.feature_size;
+  auto padding_idx = params.padding_idx;
+  auto indices_idx = tid / feature_size;
+  auto weight_idx = indices[indices_idx];
+
+  if (weight_idx != padding_idx) {
+    auto feature_idx = tid % feature_size;
+    auto bag_idx = static_cast<uint32_t>(offset2bag[indices_idx]);
+    constant auto& output_grad_strides = params.output_grad_strides;
+    constant auto& weight_strides = params.weight_strides;
+    auto per_sample_weights_grad_stride = params.per_sample_weights_grad_stride;
+
+    auto weight_val = weight
+        [static_cast<uint32_t>(weight_idx) * weight_strides[0] +
+         feature_idx * weight_strides[1]];
+    auto output_grad_val = output_grad
+        [bag_idx * output_grad_strides[0] +
+         feature_idx * output_grad_strides[1]];
+    auto per_sample_weights_grad_val = static_cast<opmath_t<T>>(weight_val) *
+        static_cast<opmath_t<T>>(output_grad_val);
+
+    AtomicType<T>::atomic_add(
+        per_sample_weights_grad,
+        indices_idx * per_sample_weights_grad_stride,
+        static_cast<T>(per_sample_weights_grad_val));
+  }
+}
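For the per-sample-weights gradient, d(out[bag])/d(psw[i]) is just the selected weight row, so the gradient for each per-sample weight is a dot product between the bag's output gradient and weight[indices[i]]. The kernel above computes one feature-wise product per thread and accumulates them into the same scalar slot with atomic adds; a hedged serial sketch (illustrative names, not the actual CPU implementation):

import torch

def per_sample_weights_backward_reference(output_grad, weight, indices,
                                          offset2bag, padding_idx=-1):
    grad = torch.zeros(indices.numel())
    for i, widx in enumerate(indices.tolist()):
        if widx == padding_idx:
            continue
        bag = int(offset2bag[i])
        # The kernel builds this dot product one feature at a time via atomic_add.
        grad[i] = torch.dot(output_grad[bag].float(), weight[widx].float())
    return grad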
+
+#define REGISTER_EMBEDDING_BAG_OP(T, I)                                      \
+  template [[host_name("embedding_bag_" #T "_" #I)]]                         \
+  kernel void embedding_bag<T, I>(                                           \
+      constant T * weight [[buffer(0)]],                                     \
+      constant I * indices [[buffer(1)]],                                    \
+      constant I * offsets [[buffer(2)]],                                    \
+      constant T * per_sample_weights [[buffer(3)]],                         \
+      device T * output [[buffer(4)]],                                       \
+      device I * offset2bag [[buffer(5)]],                                   \
+      device I * bag_size [[buffer(6)]],                                     \
+      device I * max_indices [[buffer(7)]],                                  \
+      constant EmbeddingBagParams<uint32_t> & params [[buffer(8)]],          \
+      uint tid [[thread_position_in_grid]]);                                 \
+                                                                             \
+  template [[host_name("embedding_bag_backward_" #T "_" #I)]]                \
+  kernel void embedding_bag_backward<T, I>(                                  \
+      constant T * output_grad [[buffer(0)]],                                \
+      constant I * indices [[buffer(1)]],                                    \
+      constant I * offset2bag [[buffer(2)]],                                 \
+      constant I * bag_size [[buffer(3)]],                                   \
+      constant I * max_indices [[buffer(4)]],                                \
+      constant T * per_sample_weights [[buffer(5)]],                         \
+      device AtomicType_t<T> * weight_grad [[buffer(6)]],                    \
+      constant EmbeddingBagBackwardParams<uint32_t> & params [[buffer(7)]],  \
+      uint tid [[thread_position_in_grid]]);                                 \
+                                                                             \
+  template                                                                   \
+      [[host_name("embedding_bag_per_sample_weights_backward_" #T "_" #I)]]  \
+  kernel void embedding_bag_per_sample_weights_backward<T, I>(               \
+      constant T * output_grad [[buffer(0)]],                                \
+      constant T * weight [[buffer(1)]],                                     \
+      constant I * indices [[buffer(2)]],                                    \
+      constant I * offset2bag [[buffer(3)]],                                 \
+      device AtomicType_t<T> * per_sample_weights_grad [[buffer(4)]],        \
+      constant EmbeddingBagPerSampleWeightsBackwardParams<uint32_t> &        \
+          params [[buffer(5)]],                                              \
+      uint tid [[thread_position_in_grid]]);
 
 REGISTER_EMBEDDING_BAG_OP(float, int);
 REGISTER_EMBEDDING_BAG_OP(float, long);
@@ -13,8 +13,10 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_embedding_bag_dense_backward_native.h>
 #include <ATen/ops/_embedding_bag_forward_only_native.h>
 #include <ATen/ops/_embedding_bag_native.h>
+#include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
 #include <ATen/ops/empty.h>
 #endif
 
@@ -95,6 +97,7 @@ static std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_mps_impl(
   }
 
   bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
+  params.use_per_sample_weights = use_per_sample_weights;
   params.per_sample_weights_stride = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
 
   params.num_indices = num_indices;
@@ -177,4 +180,117 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_forward_only_mps(
       padding_idx);
 }
 
+Tensor _embedding_bag_dense_backward_mps(const Tensor& output_grad,
+                                         const Tensor& indices,
+                                         const Tensor& offset2bag,
+                                         const Tensor& bag_size,
+                                         const Tensor& max_indices,
+                                         int64_t num_weights,
+                                         bool scale_grad_by_freq,
+                                         int64_t mode,
+                                         const std::optional<Tensor>& per_sample_weights_opt,
+                                         int64_t padding_idx) {
+  // indices and offset2bag are assumed to have the correct dtypes and to be
+  // contiguous here, due to the checks in _embedding_bag_backward in
+  // EmbeddingBag.cpp.
+  // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
+  // for more details.
+
+  int64_t feature_size = output_grad.size(1);
+  auto weight_grad = at::zeros({num_weights, feature_size}, output_grad.options());
+  EmbeddingBagBackwardParams<uint32_t> params;
+
+  for (const auto dim : c10::irange(2)) {
+    params.output_grad_strides[dim] = output_grad.stride(dim);
+    params.weight_grad_strides[dim] = weight_grad.stride(dim);
+
+    if (mode == EmbeddingBagMode::MAX) {
+      params.max_indices_strides[dim] = safe_downcast<uint32_t, int64_t>(max_indices.stride(dim));
+    }
+  }
+
+  bool use_per_sample_weights = per_sample_weights_opt.has_value() && per_sample_weights_opt->defined();
+  params.use_per_sample_weights = use_per_sample_weights;
+  params.per_sample_weights_stride = use_per_sample_weights ? per_sample_weights_opt->stride(0) : 0;
+  params.feature_size = output_grad.size(1);
+  params.mode = static_cast<EmbeddingBagMode>(mode);
+  params.padding_idx = padding_idx;
+
+  auto num_indices = offset2bag.numel();
+  auto num_threads = (params.mode == EmbeddingBagMode::MAX) ? output_grad.numel() : num_indices * params.feature_size;
+  MPSStream* stream = getCurrentMPSStream();
+
+  mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
+      auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_backward_{}_{}",
+                                                                    mps::scalarToMetalTypeString(output_grad),
+                                                                    mps::scalarToMetalTypeString(indices)));
+
+      getMPSProfiler().beginProfileKernel(
+          pipeline_state, "embedding_bag", {output_grad, indices, offset2bag, bag_size});
+      [computeEncoder setComputePipelineState:pipeline_state];
+      mps::mtl_setArgs(computeEncoder,
+                       output_grad,
+                       indices,
+                       offset2bag,
+                       bag_size,
+                       max_indices,
+                       use_per_sample_weights ? per_sample_weights_opt : std::nullopt,
+                       weight_grad,
+                       params);
+
+      mps::mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
+      getMPSProfiler().endProfileKernel(pipeline_state);
+    }
+  });
+
+  return std::move(weight_grad);
+}
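Note the thread-count choice above: for SUM/MEAN the host launches one thread per (index, feature) pair (num_indices * feature_size), because every contributing index does its own scatter-add, while for MAX it launches one thread per output element (output_grad.numel() = num_bags * feature_size), since each output element has exactly one winning row. A hedged sanity check, not taken from the test suite, that the new path agrees with CPU autograd:

import torch

weight = torch.randn(10, 7)
indices = torch.randint(0, 10, (40,))
offsets = torch.tensor([0, 5, 12, 20, 33])

grads = {}
for dev in ("cpu", "mps"):
    w = weight.to(dev).detach().requires_grad_(True)
    out, *_ = torch._embedding_bag(w, indices.to(dev), offsets.to(dev), mode=1)
    out.sum().backward()
    grads[dev] = w.grad.cpu()

torch.testing.assert_close(grads["cpu"], grads["mps"])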
+
+Tensor _embedding_bag_per_sample_weights_backward_mps(const Tensor& output_grad,
+                                                      const Tensor& weight,
+                                                      const Tensor& indices,
+                                                      const Tensor& offsets,
+                                                      const Tensor& offset2bag,
+                                                      int64_t mode,
+                                                      int64_t padding_idx) {
+  TORCH_INTERNAL_ASSERT(static_cast<EmbeddingBagMode>(mode) == EmbeddingBagMode::SUM);
+  int64_t num_indices = indices.size(0);
+  int64_t feature_size = output_grad.size(1);
+  auto per_sample_weights_grad = at::zeros({num_indices}, output_grad.options());
+  EmbeddingBagPerSampleWeightsBackwardParams params;
+
+  for (const auto dim : c10::irange(2)) {
+    params.output_grad_strides[dim] = output_grad.stride(dim);
+    params.weight_strides[dim] = weight.stride(dim);
+  }
+
+  params.per_sample_weights_grad_stride = per_sample_weights_grad.stride(0);
+  params.feature_size = feature_size;
+  params.padding_idx = padding_idx;
+
+  auto num_threads = num_indices * feature_size;
+  MPSStream* stream = getCurrentMPSStream();
+
+  mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
+      auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_per_sample_weights_backward_{}_{}",
+                                                                    mps::scalarToMetalTypeString(output_grad),
+                                                                    mps::scalarToMetalTypeString(indices)));
+
+      getMPSProfiler().beginProfileKernel(
+          pipeline_state, "embedding_bag_per_sample_weights_backward", {output_grad, weight, indices, offset2bag});
+      [computeEncoder setComputePipelineState:pipeline_state];
+      mps::mtl_setArgs(computeEncoder, output_grad, weight, indices, offset2bag, per_sample_weights_grad, params);
+
+      mps::mtl_dispatch1DJob(computeEncoder, pipeline_state, num_threads);
+      getMPSProfiler().endProfileKernel(pipeline_state);
+    }
+  });
+
+  return std::move(per_sample_weights_grad);
+}
+
 } // namespace at::native
@@ -2379,7 +2379,7 @@
 
 - func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
   dispatch:
-    CPU, CUDA: _embedding_bag_backward_symint
+    CPU, CUDA, MPS: _embedding_bag_backward_symint
 
 - func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
   dispatch:
@@ -2389,12 +2389,14 @@
   dispatch:
     CPU: _embedding_bag_dense_backward_cpu
     CUDA: _embedding_bag_dense_backward_cuda
+    MPS: _embedding_bag_dense_backward_mps
   autogen: _embedding_bag_dense_backward.out
 
 - func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor
   dispatch:
     CPU: _embedding_bag_per_sample_weights_backward_cpu
     CUDA: _embedding_bag_per_sample_weights_backward_cuda
+    MPS: _embedding_bag_per_sample_weights_backward_mps
   autogen: _embedding_bag_per_sample_weights_backward.out
 
 - func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
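With these dispatch entries registered, the new kernels are reachable directly through the dispatcher. A hedged sketch (argument order follows the _embedding_bag_dense_backward schema, which matches the shim declaration further below; the auxiliary tensors come from the forward call):

import torch

weight = torch.randn(10, 7, device="mps")
indices = torch.randint(0, 10, (40,), device="mps")
offsets = torch.tensor([0, 5, 12, 20, 33], device="mps")

# Forward produces the auxiliary tensors the dense backward consumes.
out, offset2bag, bag_size, max_indices = torch._embedding_bag(
    weight, indices, offsets, mode=1)  # 1 = mean

# (grad, indices, offset2bag, bag_size, maximum_indices, num_weights,
#  scale_grad_by_freq, mode, per_sample_weights, padding_idx)
grad_weight = torch.ops.aten._embedding_bag_dense_backward(
    torch.ones_like(out), indices, offset2bag, bag_size, max_indices,
    weight.size(0), False, 1, None, -1)
print(grad_weight.shape)  # torch.Size([10, 7])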
@@ -6940,70 +6940,6 @@ class TestMPS(TestCaseMPS):
         with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
             helper(22, 0, [])
 
-    # TODO: This test can be removed once the backward pass of embedding_bag is
-    # implemented and tested
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
-    @parametrize("idx_dtype", [torch.long, torch.int])
-    @parametrize("padding_idx", [-1, 1])
-    @parametrize("include_last_offset", [True, False])
-    @parametrize("mode", ['sum', 'mean', 'max'])
-    def test__embedding_bag(self, dtype, idx_dtype, padding_idx, include_last_offset, mode):
-        import time
-        torch.manual_seed(time.time() * 1000)
-        mode_num = {'sum': 0, 'mean': 1, 'max': 2}[mode]
-        num_words = 10
-        feature_size = 7
-        num_indices = 40
-        num_bags = 5
-
-        weight_cpu = torch.randn(num_words, feature_size, dtype=dtype)
-
-        # Test nan value behavior.
-        # Set second element of each word to nan.
-        weight_cpu[:, 1] = float('nan')
-        # Set third element of a randomized half of the words to nan.
-        weight_cpu[torch.randperm(num_words)[:num_words // 2], 2] = float('nan')
-        # Set fourth element of one randomized word to nan.
-        weight_cpu[torch.randint(0, num_words, ()), 3] = float('nan')
-
-        input_cpu = torch.randint(0, num_words, (num_indices,), dtype=idx_dtype)
-        offsets_cpu = torch.tensor(
-            [0] + (torch.randperm(num_indices - 1)[:num_bags - 1].sort()[0] + 1).tolist(),
-            dtype=idx_dtype)
-
-        if include_last_offset:
-            offsets_cpu[-1] = input_cpu.numel()
-
-        per_sample_weights_cpu = torch.randn(num_indices, dtype=dtype) if mode == 'sum' else None
-
-        r_cpu, offset2bag_cpu, bag_size_cpu, max_indices_cpu = torch._embedding_bag(
-            weight_cpu,
-            input_cpu,
-            offsets_cpu,
-            per_sample_weights=per_sample_weights_cpu,
-            mode=mode_num,
-            padding_idx=padding_idx,
-            include_last_offset=include_last_offset,
-        )
-        r_mps, offset2bag_mps, bag_size_mps, max_indices_mps = torch._embedding_bag(
-            weight_cpu.to('mps'),
-            input_cpu.to('mps'),
-            offsets_cpu.to('mps'),
-            per_sample_weights=per_sample_weights_cpu.to('mps') if per_sample_weights_cpu is not None else None,
-            mode=mode_num,
-            padding_idx=padding_idx,
-            include_last_offset=include_last_offset,
-        )
-
-        self.assertEqual(r_cpu, r_mps)
-
-        if mode != 'sum':
-            self.assertEqual(offset2bag_cpu, offset2bag_mps)
-        self.assertEqual(bag_size_cpu, bag_size_mps)
-
-        if mode == 'max':
-            self.assertEqual(max_indices_cpu, max_indices_mps)
-
     def test_embedding_dense_backward(self):
         def helper(n, d, m, idx):
             embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
@@ -12530,6 +12466,8 @@ class TestConsistency(TestCaseMPS):
             # several output grad elements of similar magnitudes get summed
             # together, introducing significant error for float16.
             atol, rtol = 5e-3, 5e-3
+        if op.name == "nn.functional.embedding_bag" and dtype == torch.float16:
+            atol, rtol = 5e-3, 5e-3
         self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
 
         # The CPU impl of grid_sampler_3d gives a large amount of error for half
@@ -16,7 +16,9 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__adaptive_avg_pool2d_backward(At
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag_forward_only(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
@@ -737,7 +737,6 @@ if torch.backends.mps.is_available():
         "equal": [torch.float16, torch.float32],
         # 'float' object is not iterable
         "item": [torch.float16, torch.float32],
-        "nn.functional.embedding_bag": None,
         # "smooth_l1_backward_cpu_out" not implemented for 'Half'
         "nn.functional.smooth_l1_loss": [torch.float16],
         # cpu error: grad requires non-empty inputs