[2/N] Avoid copy in std::get (#141826)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141826
Approved by: https://github.com/Skylion007, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
parent b2fe1b9409
commit fc74ec4989
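The commit applies a handful of standard C++ copy-avoidance idioms around std::get, all of which recur in the hunks below. A minimal, self-contained sketch of those idioms; the Buf type and every name here are hypothetical stand-ins for illustration, not PyTorch code:

    #include <string>
    #include <tuple>
    #include <utility>

    // Stand-in for an expensively copied type such as at::Tensor.
    using Buf = std::string;

    static std::tuple<Buf, Buf> produce() {
      return {Buf(1 << 20, 'a'), Buf(1 << 20, 'b')};
    }

    int main() {
      auto t = produce();

      // Copies: std::get on an lvalue tuple yields an lvalue reference,
      // and initializing a fresh object from it copy-constructs.
      Buf a = std::get<0>(t);

      // No copy: bind a reference to the element instead.
      const auto& b = std::get<1>(t);

      // No copy: structured bindings name the elements in place.
      auto& [x, y] = t;

      // Move, not copy: std::get on an rvalue tuple yields an rvalue reference.
      Buf c = std::get<0>(std::move(t));

      (void)a; (void)b; (void)x; (void)y; (void)c;
    }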
@@ -205,18 +205,17 @@ convolution_backward_input_batch_rule(
     const auto result = at::convolution_backward_symint(
         grad_output, dummy_input, weight_, std::nullopt, stride, padding,
         dilation, transposed, output_padding, groups, mask);
-    const auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
-    return std::make_tuple(grad_input, 1);
+    auto grad_input = reshape_dim_outof(1, batch_size, std::get<0>(result));
+    return std::make_tuple(std::move(grad_input), 1);
   }
   Tensor grad_input;
   if (!transposed) {
     // N(GO), B(GO)I -> N(GO), (GO)(BI) -> N(GBI)
     const auto weight_ = reshape_dim_into(*weight_bdim, 1, weight);
     auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
-    const auto result = at::convolution_backward_symint(
+    grad_input = std::get<0>(at::convolution_backward_symint(
         grad_output, dummy_input, weight_, std::nullopt, stride, padding,
-        dilation, transposed, output_padding, groups, mask);
-    grad_input = std::get<0>(result); // N(GBI)
+        dilation, transposed, output_padding, groups, mask)); // N(GBI)
   } else {
     // N(GO), B(GI)O -> N(GO), (GBI)O -> N(GBI)
     auto weight_ = moveBatchDimToFront(weight, weight_bdim); // B(GI)O
@@ -224,24 +223,23 @@ convolution_backward_input_batch_rule(
     weight_ = weight_.transpose(0, 1); // GBIO
     weight_ = weight_.flatten(0, 2); // (GBI)O
     const auto dummy_input = make_dummy(input, input_bdim, 1, batch_size);
-    const auto result = at::convolution_backward_symint(
+    grad_input = std::get<0>(at::convolution_backward_symint(
         grad_output, dummy_input, weight_, std::nullopt, stride, padding,
-        dilation, transposed, output_padding, groups, mask);
-    grad_input = std::get<0>(result); // N(GBI)
+        dilation, transposed, output_padding, groups, mask)); // N(GBI)
   }
   // N(GBI) -> NG(BI) -> NGBI -> NBGI -> NB(GI)
   grad_input = reshape_dim_outof_symint(1, groups, grad_input);
   grad_input = reshape_dim_outof_symint(2, batch_size, grad_input);
   grad_input = grad_input.transpose(1, 2);
   grad_input = reshape_dim_into(2, 2, grad_input);
-  return std::make_tuple(grad_input, 1);
+  return std::make_tuple(std::move(grad_input), 1);
 } else {
   TORCH_INTERNAL_ASSERT(input_bdim);
   const auto dummy_input = make_dummy(input, input_bdim, 0, 1);
-  const auto result = at::convolution_backward_symint(
+  auto result = at::convolution_backward_symint(
       grad_output, dummy_input, weight, std::nullopt, stride, padding,
       dilation, transposed, output_padding, groups, mask);
-  return std::make_tuple(std::get<0>(result), std::nullopt);
+  return std::make_tuple(std::move(std::get<0>(result)), std::nullopt);
 }
 }

 static std::tuple<Tensor, std::optional<int64_t>>
@@ -258,12 +256,12 @@ convolution_backward_weight_batch_rule(
     const auto grad_output_ = reshape_dim_into(*grad_output_bdim, 1, grad_output);
     const auto input_ = reshape_dim_into(*input_bdim, 1, input);
     const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size);
-    const auto result = at::convolution_backward_symint(
+    auto result = at::convolution_backward_symint(
        grad_output_, input_, dummy_weight, std::nullopt, stride, padding,
        dilation, transposed, output_padding, groups * batch_size, mask);
-    auto grad_weight = std::get<1>(result);
+    auto& grad_weight = std::get<1>(result);
     grad_weight = reshape_dim_outof_symint(0, batch_size, grad_weight);
-    return std::make_tuple(grad_weight, 0);
+    return std::make_tuple(std::move(grad_weight), 0);
   } else if (grad_output_bdim && !input_bdim) {
     const auto batch_size = grad_output.size(*grad_output_bdim);
     if (groups == 1) {
@@ -327,10 +325,10 @@ convolution_backward_weight_batch_rule(
       if (!transposed) {
         // regular: N(GO), BN(GI) -> N(GO), N(GBI) -> (GO)(BI)
         const auto dummy_weight = make_dummy(weight, weight_bdim, 1, batch_size);
-        const auto result = at::convolution_backward_symint(
+        auto result = at::convolution_backward_symint(
            grad_output, input_, dummy_weight, std::nullopt, stride, padding,
            dilation, transposed, output_padding, groups, mask);
-        auto grad_weight = std::get<1>(result);
+        auto& grad_weight = std::get<1>(result);
         grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight);
         return std::make_tuple(grad_weight, 1);
       } else {
@@ -423,23 +421,23 @@ static std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
   Tensor grad_input;
   if (output_mask[0]) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    auto result = convolution_backward_input_batch_rule(
+    auto [tensor, bdim] = convolution_backward_input_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
-    grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
+    grad_input = makeBatched(tensor, bdim, cur_level);
   }

   Tensor grad_weight;
   if (output_mask[1]) {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = convolution_backward_weight_batch_rule(
+    auto [tensor, bdim] = convolution_backward_weight_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
-    grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
+    grad_weight = makeBatched(tensor, bdim, cur_level);
   }
   return std::make_tuple(grad_input, grad_weight, grad_bias);
@@ -161,16 +161,14 @@ grid_sample_backward_helper_in(

 static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
 grid_sample_backward_helper_out(
-    const std::tuple<Tensor, Tensor> & bw_out,
+    std::tuple<Tensor, Tensor> bw_out,
     std::optional<int64_t> grad_input_out_bdim,
     std::optional<int64_t> grad_grid_out_bdim,
     int64_t bdim_size) {
-  auto grad_input = std::get<0>(bw_out);
-  auto grad_grid = std::get<1>(bw_out);
+  auto& [grad_input, grad_grid] = bw_out;
   grad_input = reshape_dim_outof(*grad_input_out_bdim, bdim_size, grad_input);
   grad_grid = reshape_dim_outof(*grad_grid_out_bdim, bdim_size, grad_grid);
-  auto result = std::make_tuple(grad_input, grad_input_out_bdim, grad_grid, grad_grid_out_bdim);
-  return result;
+  return std::make_tuple(std::move(grad_input), grad_input_out_bdim, std::move(grad_grid), grad_grid_out_bdim);
 }

@@ -185,34 +183,26 @@ grid_sample_backward_batch_rule(
   auto new_bw_input = grid_sample_backward_helper_in(
       grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

-  auto new_grad_output = std::get<0>(new_bw_input);
-  auto new_input = std::get<1>(new_bw_input);
-  auto new_grid = std::get<2>(new_bw_input);
-  int64_t batch_size = std::get<3>(new_bw_input);
+  auto [new_grad_output, new_input, new_grid, batch_size] = new_bw_input;

-  auto bw_out = Func(new_grad_output, new_input, new_grid, std::forward<ExtraArgs>(extra_args)...);
+  auto bw_out = Func(std::move(new_grad_output), std::move(new_input), std::move(new_grid), std::forward<ExtraArgs>(extra_args)...);

-  return grid_sample_backward_helper_out(bw_out, 0, 0, batch_size);
+  return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, batch_size);
 }

 template<typename F, F Func>
-std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
+static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
 cudnn_grid_sample_backward_batch_rule(
     const Tensor& input, std::optional<int64_t> input_bdim,
     const Tensor& grid, std::optional<int64_t> grid_bdim,
     const Tensor& grad_output, std::optional<int64_t> grad_output_bdim) {

-  auto new_bw_input = grid_sample_backward_helper_in(
+  auto [new_grad_output, new_input, new_grid, bdim_size] = grid_sample_backward_helper_in(
       grad_output, grad_output_bdim, input, input_bdim, grid, grid_bdim);

-  auto new_grad_output = std::get<0>(new_bw_input);
-  auto new_input = std::get<1>(new_bw_input);
-  auto new_grid = std::get<2>(new_bw_input);
-  int64_t bdim_size = std::get<3>(new_bw_input);
-
-  auto bw_out = Func(new_input, new_grid, new_grad_output);
+  auto bw_out = Func(std::move(new_input), std::move(new_grid), std::move(new_grad_output));

-  return grid_sample_backward_helper_out(bw_out, 0, 0, bdim_size);
+  return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
 }

 // TODO: replace with targetable functionalization
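grid_sample_backward_helper_out now takes its tuple by value so callers can move into it and the body can move the elements back out. A minimal sketch of that sink-parameter pattern, with entirely hypothetical names:

    #include <string>
    #include <tuple>
    #include <utility>

    // Hypothetical sink: taking the tuple by value lets the caller choose
    // copy (pass an lvalue) or move (pass an rvalue), and the body can then
    // move the elements onward without touching caller state.
    static std::tuple<std::string, int> relabel(std::tuple<std::string, int> in) {
      auto& [name, id] = in;           // bind references into the by-value tuple
      name += "_out";                  // mutate our own value, not the caller's
      return std::make_tuple(std::move(name), id);  // move out, no extra copy
    }

    int main() {
      auto src = std::make_tuple(std::string(1 << 20, 'x'), 7);
      auto dst = relabel(std::move(src));  // the big string is moved, not copied
      (void)dst;
    }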
@@ -88,7 +88,7 @@ static std::vector<std::optional<Tensor>> batchIndices(
   bool indices_batched = any_has_value(indices_bdims);

   for (size_t i = 0; i < indices.size(); i++) {
-    auto index = indices[i];
+    auto const & index = indices[i];
     if (index.has_value() && index->sym_numel() != 0) {
       const auto idx_bdim = indices_bdims[i];
       indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
@@ -1476,15 +1476,11 @@ calc_i0(T _x) {
   T x = std::abs(_x);

   if (x <= T{8.0}) {
-    auto coeff_pair = chebyshev_coefficients_i0e_A<T>();
-    auto A = std::get<0>(coeff_pair);
-    auto len = std::get<1>(coeff_pair);
+    auto [A, len] = chebyshev_coefficients_i0e_A<T>();
     T y = (x / T{2.0}) - T{2.0};
     return static_cast<T>(std::exp(x) * chbevl(y, A, len));
   }
-  auto coeff_pair = chebyshev_coefficients_i0e_B<T>();
-  auto B = std::get<0>(coeff_pair);
-  auto len = std::get<1>(coeff_pair);
+  auto [B, len] = chebyshev_coefficients_i0e_B<T>();
   return std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x);
 }

@@ -1507,16 +1503,12 @@ calc_i1(T _x) {
   T x = std::abs(_x);

   if (x <= T{8.0}) {
-    auto coeff_pair = chebyshev_coefficients_i1e_A<T>();
-    auto A = std::get<0>(coeff_pair);
-    auto len = std::get<1>(coeff_pair);
+    auto [A, len] = chebyshev_coefficients_i1e_A<T>();
     T y = (x / T{2.0}) - T{2.0};
     const T out = std::exp(x) * x * chbevl(y, A, len);
     return (_x < T{0.0}) ? -out : out;
   }
-  auto coeff_pair = chebyshev_coefficients_i1e_B<T>();
-  auto B = std::get<0>(coeff_pair);
-  auto len = std::get<1>(coeff_pair);
+  auto [B, len] = chebyshev_coefficients_i1e_B<T>();
   const T out = (std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len)) / std::sqrt(x);
   return (_x < T{0.0}) ? -out : out;
 }
@@ -1541,16 +1533,12 @@ calc_i1e(T _x) {
   T x = std::abs(_x);

   if (x <= T{8.0}) {
-    auto coeff_pair = chebyshev_coefficients_i1e_A<T>();
-    auto A = std::get<0>(coeff_pair);
-    auto len = std::get<1>(coeff_pair);
+    auto [A, len] = chebyshev_coefficients_i1e_A<T>();
     T y = (x / T{2.0}) - T{2.0};
     const T out = chbevl(y, A, len) * x;
     return (_x < T{0.0}) ? -out : out;
   }
-  auto coeff_pair = chebyshev_coefficients_i1e_B<T>();
-  auto B = std::get<0>(coeff_pair);
-  auto len = std::get<1>(coeff_pair);
+  auto [B, len] = chebyshev_coefficients_i1e_B<T>();
   const auto out = chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x);
   return (_x < T{0.0}) ? -out : out;
 }
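In the calc_i0/i1/i1e changes, the named coeff_pair temporary and two std::get calls collapse into one structured binding. The pair elements here (a pointer into a static table plus a length) are cheap to copy, so the win is brevity rather than avoided copies. A small sketch under hypothetical names:

    #include <cstddef>
    #include <utility>

    // Hypothetical stand-in for chebyshev_coefficients_i0e_A<T>(): a pointer
    // into a static coefficient table plus its length.
    static std::pair<const double*, std::size_t> coeffs() {
      static const double table[] = {1.0, 0.5, 0.25};
      return {table, 3};
    }

    double sum_coeffs() {
      // One declaration instead of a named pair plus two std::get<>() reads.
      auto [c, len] = coeffs();
      double s = 0.0;
      for (std::size_t i = 0; i < len; ++i) s += c[i];
      return s;
    }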
@@ -3210,16 +3210,12 @@ static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) {
   scalar_t x = ::abs(_x);

   if (x <= scalar_t{8.0}) {
-    auto coeff_pair = chebyshev_coefficients_i0e_A<scalar_t>();
-    auto A = std::get<0>(coeff_pair);
-    auto len = std::get<1>(coeff_pair);
+    auto [A, len] = chebyshev_coefficients_i0e_A<scalar_t>();
     scalar_t y = (x / scalar_t{2.0}) - scalar_t{2.0};
     return (::exp(x) * chbevl(y, A, len));
   }

-  auto coeff_pair = chebyshev_coefficients_i0e_B<scalar_t>();
-  auto B = std::get<0>(coeff_pair);
-  auto len = std::get<1>(coeff_pair);
+  auto [B, len] = chebyshev_coefficients_i0e_B<scalar_t>();
   return (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x));
 }

@@ -3334,17 +3330,13 @@ template <typename scalar_t>
 static inline C10_HOST_DEVICE scalar_t calc_i1(scalar_t _x) {
   const auto x = ::abs(_x);
   if (x <= scalar_t{8.0}) {
-    auto coeff_pair = chebyshev_coefficients_i1e_A<scalar_t>();
-    auto A = std::get<0>(coeff_pair);
-    auto len = std::get<1>(coeff_pair);
+    auto [A, len] = chebyshev_coefficients_i1e_A<scalar_t>();
     scalar_t y = x / scalar_t{2.0} - scalar_t{2.0};
     const scalar_t out = ::exp(x) * x * chbevl(y, A, len);
     return (_x < scalar_t{0.0}) ? -out : out;
   }

-  auto coeff_pair = chebyshev_coefficients_i1e_B<scalar_t>();
-  auto B = std::get<0>(coeff_pair);
-  auto len = std::get<1>(coeff_pair);
+  auto [B, len] = chebyshev_coefficients_i1e_B<scalar_t>();
   const scalar_t out = (::exp(x) * chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len)) / ::sqrt(x);
   return (_x < scalar_t{0.0}) ? -out : out;
 }
@@ -3353,17 +3345,13 @@ template <typename scalar_t>
 static inline C10_HOST_DEVICE scalar_t calc_i1e(scalar_t _x) {
   const auto x = ::abs(_x);
   if (x <= scalar_t{8.0}) {
-    auto coeff_pair = chebyshev_coefficients_i1e_A<scalar_t>();
-    auto A = std::get<0>(coeff_pair);
-    auto len = std::get<1>(coeff_pair);
+    auto [A, len] = chebyshev_coefficients_i1e_A<scalar_t>();
     const scalar_t y = x / scalar_t{2.0} - scalar_t{2.0};
     const scalar_t out = chbevl(y, A, len) * x;
     return (_x < scalar_t{0.0}) ? -out : out;
   }

-  auto coeff_pair = chebyshev_coefficients_i1e_B<scalar_t>();
-  auto B = std::get<0>(coeff_pair);
-  auto len = std::get<1>(coeff_pair);
+  auto [B, len] = chebyshev_coefficients_i1e_B<scalar_t>();
   const scalar_t out = chbevl(scalar_t{32.0} / x - scalar_t{2.0}, B, len) / ::sqrt(x);
   return (_x < scalar_t{0.0}) ? -out : out;
 }

@@ -378,22 +378,15 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
     if (max_elements_per_tensor == 0)
       continue;

-    dim3 applyBlock, catGrid;
-
 #ifdef USE_ROCM
     // always base grid size on max_elements_per_tensor
-    {
-      std::tuple<dim3, dim3> launchParams = getCatGridRocm<scalar_t>(
+    auto [catGrid, applyBlock] = getCatGridRocm<scalar_t>(
         max_elements_per_tensor, batchCounter);
-      catGrid = std::get<0>(launchParams);
-      applyBlock = std::get<1>(launchParams);
-    }
 #else
+    dim3 applyBlock, catGrid;
     if (isContig && sizeof(scalar_t) > 2) {
-      std::tuple<dim3, dim3> launchParams = getCatGridContig<scalar_t>(
+      std::tie(catGrid, applyBlock) = getCatGridContig<scalar_t>(
           max_elements_per_tensor, batchCounter);
-      catGrid = std::get<0>(launchParams);
-      applyBlock = std::get<1>(launchParams);
     } else {
       applyBlock = dim3(32 * 16);
       getCatGrid(batchCounter, catGrid);

@@ -2709,8 +2709,7 @@ void lstm_cudnn(
       bidirectional,
       batch_first);
   output = result.first;
-  hy = std::get<0>(result.second);
-  cy = std::get<1>(result.second);
+  std::tie(hy, cy) = result.second;
 }

 void lstm_packed_cudnn(
@@ -2738,8 +2737,7 @@ void lstm_packed_cudnn(
       train,
       bidirectional);
   output = result.first;
-  hy = std::get<0>(result.second);
-  cy = std::get<1>(result.second);
+  std::tie(hy, cy) = result.second;
 }

 REGISTER_CUDA_DISPATCH(lstm_cudnn_stub, &lstm_cudnn)
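Structured bindings can only introduce new names, so where the targets already exist (the hy/cy out-parameters here, or catGrid/applyBlock declared outside the #ifdef above), the diff uses std::tie instead; assigning an rvalue pair through std::tie move-assigns each element. A sketch with hypothetical names:

    #include <string>
    #include <tuple>
    #include <utility>

    static std::pair<std::string, std::string> make_states() {
      return {std::string(1 << 20, 'h'), std::string(1 << 20, 'c')};
    }

    static void assign_out(std::string& hy, std::string& cy) {
      auto result = make_states();
      // std::tie builds a tuple of references to existing variables; assigning
      // an rvalue pair to it move-assigns element-wise, avoiding two copies.
      std::tie(hy, cy) = std::move(result);
    }

    int main() {
      std::string hy, cy;
      assign_out(hy, cy);
    }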
@@ -19,18 +19,17 @@
 #endif

 #include <array>
-#include <functional>
 #include <tuple>
 #include <vector>

 namespace at::native {

 template <typename T>
-void check_group_norm_inputs(
+static void check_group_norm_inputs(
     const Tensor& input,
     const Tensor& weight,
     const Tensor& bias,
-    T C,
+    const T& C,
     int64_t num_groups) {
   TORCH_CHECK(
       num_groups > 0,
@@ -237,8 +236,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> math_group_norm(
       /*training=*/true,
       /*momentum=*/0,
       eps);
-  at::Tensor out = std::get<0>(outputs);
-  out = out.view(input_shape);
+  auto out = std::get<0>(outputs).view(input_shape);
   std::vector<int64_t> affine_param_shape(input.dim(), 1);
   affine_param_shape[1] = C;
   if (weight.defined() && bias.defined()) {
@@ -253,6 +251,6 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> math_group_norm(
   // This follows the same behavior as the CPU and CUDA kernels.
   at::Tensor mean = std::get<1>(outputs).to(c10::TensorOptions().dtype(input.scalar_type())).view({N, group});
   at::Tensor rstd = std::get<2>(outputs).to(c10::TensorOptions().dtype(input.scalar_type())).view({N, group});
-  return std::make_tuple(out, mean, rstd);
+  return std::make_tuple(std::move(out), std::move(mean), std::move(rstd));
 }
 } // namespace at::native
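math_group_norm now moves its locals into the returned tuple. The std::move calls affect how the elements enter the tuple; the returned tuple object itself is still eligible for copy elision either way. A minimal sketch:

    #include <string>
    #include <tuple>
    #include <utility>

    std::tuple<std::string, std::string> build() {
      std::string out(1 << 20, 'o');
      std::string mean(1 << 20, 'm');
      // Without std::move, make_tuple copy-constructs both strings; with it,
      // they are moved in. The tuple returned is not copied (RVO applies).
      return std::make_tuple(std::move(out), std::move(mean));
    }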
@@ -241,8 +241,7 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
   auto outputs = at::native_batch_norm(
       input_reshaped, /*weight=*/{}, /*bias=*/{}, /*running_mean=*/{},
       /*running_var=*/{}, /*training=*/true, /*momentum=*/0, eps);
-  at::Tensor out = std::get<0>(outputs);
-  out = out.view(input_shape);
+  auto out = std::get<0>(outputs).view(input_shape);
   if (weight.defined() && bias.defined()) {
     out = bias.addcmul(out, weight, 1);
   } else if (weight.defined()) {
@@ -297,7 +296,7 @@ Tensor rms_norm_symint(
   c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
   Tensor upcasted_input = input.to(opmath_t);

-  Tensor rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keep_dim=*/true).add_(eps_val));
+  auto rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
   Tensor result = upcasted_input.mul(rqrst_input).type_as(input);

   if (weight_opt.has_value()) {
@@ -553,9 +553,8 @@ void lstm_mkldnn(Tensor& output, Tensor& hy, Tensor& cy,
     int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
   auto result = mkldnn_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
       ideep::rnn_kind::LSTM, num_layers, dropout_p, train, bidirectional, batch_first);
-  output = result.first;
-  hy = std::get<0>(result.second);
-  cy = std::get<1>(result.second);
+  output = std::move(result.first);
+  std::tie(hy, cy) = std::move(result.second);
 }
 } // anonymous namespace

@@ -200,7 +200,7 @@ bool ParameterMetadata::equal_to(const c10::Scalar& scalar) const {
     return false;
   }

-  auto self_scalar = std::get<c10::Scalar>(value_);
+  const auto& self_scalar = std::get<c10::Scalar>(value_);
   if (scalar.isFloatingPoint() && self_scalar.isFloatingPoint()) {
     return self_scalar.toDouble() == scalar.toDouble();
   } else if (scalar.isIntegral(true) && self_scalar.isIntegral(true)) {
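std::get applies to std::variant as well, and the same rule holds: on an lvalue variant it returns a reference, so binding const auto& avoids copying the alternative, which is what the ParameterMetadata change above does for its c10::Scalar. A sketch with a hypothetical variant type:

    #include <cassert>
    #include <string>
    #include <variant>

    using Value = std::variant<long, std::string>;

    static bool is_long_string(const Value& v) {
      // const auto& binds directly to the std::string alternative; dropping
      // the '&' (auto s = std::get<std::string>(v);) would copy the string.
      // Throws std::bad_variant_access if v holds the other alternative.
      const auto& s = std::get<std::string>(v);
      return s.size() > 16;
    }

    int main() {
      Value v = std::string(32, 'x');
      assert(is_long_string(v));
    }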
@@ -203,19 +203,15 @@ std::string compile_so(
   std::string compile_flags_path = filename + "_compile_flags.json";
   const nlohmann::json compile_flags = load_json_file(compile_flags_path);

-  auto compile_result =
+  auto [compile_cmd, output_o] =
       get_cpp_compile_command(filename, {cpp_filename}, compile_flags);
-  std::string compile_cmd = std::get<0>(compile_result);
-  std::string output_o = std::get<1>(compile_result);

   std::string linker_flags_path =
       cpp_filename.substr(0, lastindex) + "_linker_flags.json";
   const nlohmann::json linker_flags = load_json_file(linker_flags_path);

-  auto link_result = get_cpp_compile_command(
+  auto [link_cmd, output_so] = get_cpp_compile_command(
       filename, {output_o, consts_filename}, linker_flags);
-  std::string link_cmd = std::get<0>(link_result);
-  std::string output_so = std::get<1>(link_result);

   // Run the commands to generate a .so file
   int status = system(compile_cmd.c_str());