mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Tensor construction codemod(ResizeLike) - 1/7 (#15073)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15073 Codemod generated with clangr shard mode, 25 files per diff, motivation: https://github.com/pytorch/pytorch/pull/12407 Reviewed By: dzhulgakov Differential Revision: D13419563 fbshipit-source-id: 8c284405fa3a867303216df876ee6b20d8a46551
This commit is contained in:
parent
2db742fc95
commit
3fc889e976
|
|
@ -257,8 +257,8 @@ class MaxPoolGradientRTCOp final : public ConvPoolOpBase<CUDAContext> {
|
|||
auto& Y = Input(1);
|
||||
auto& dY = Input(2);
|
||||
CAFFE_ENFORCE_EQ(dY.dim(), 4);
|
||||
auto* dX = Output(0);
|
||||
dX->ResizeLike(X);
|
||||
|
||||
auto* dX = Output(0, X.sizes(), at::dtype<float>());
|
||||
ConvPoolOpBase<CUDAContext>::ComputePads({X.dim32(2), X.dim32(3)});
|
||||
if (input_dims_ != X.sizes()) {
|
||||
VLOG(1) << "MaxPoolGradient RTC recompiling";
|
||||
|
|
|
|||
|
|
@ -138,11 +138,9 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
|
|||
DCHECK_EQ(X.dim(), 1);
|
||||
DCHECK_EQ(N, dY.numel());
|
||||
}
|
||||
auto* dU = Output(0);
|
||||
auto* dV = Output(1);
|
||||
|
||||
dU->ResizeLike(U);
|
||||
dV->ResizeLike(V);
|
||||
auto* dU = Output(0, U.sizes(), at::dtype<T>());
|
||||
auto* dV = Output(1, V.sizes(), at::dtype<T>());
|
||||
auto* db = Output(2, {N}, at::dtype<T>());
|
||||
|
||||
// Compute dU
|
||||
|
|
@ -189,8 +187,7 @@ class FullyConnectedDecompGradientOp : public Operator<Context> {
|
|||
&context_);
|
||||
// Compute dX if necessary.
|
||||
if (OutputSize() == 4) {
|
||||
auto* dX = Output(3);
|
||||
dX->ResizeLike(X);
|
||||
auto* dX = Output(3, X.sizes(), at::dtype<T>());
|
||||
dx_buffer_.Resize(M, middle);
|
||||
T* dx_buffer_data = dx_buffer_.template mutable_data<T>();
|
||||
math::Gemm<T, Context, Engine>(
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ namespace caffe2 {
|
|||
auto* Ag_dW_ptr = Output(4);
|
||||
auto& Ag_dW = *Ag_dW_ptr;
|
||||
// it is also the Input(5)
|
||||
auto* mask_seq_auto = Output(5);
|
||||
|
||||
// how about get threshold
|
||||
auto& thres = Input(6);
|
||||
//TODO(wyiming): check comp_lb is a float
|
||||
|
|
@ -251,9 +251,8 @@ namespace caffe2 {
|
|||
DCHECK_EQ(X.dim(), 1);
|
||||
DCHECK_EQ(N, dY.numel());
|
||||
}
|
||||
auto* dW = Output(0);
|
||||
|
||||
dW->ResizeLike(W);
|
||||
auto* dW = Output(0, W.sizes(), at::dtype<T>());
|
||||
auto* db = Output(1, {N}, at::dtype<T>());
|
||||
|
||||
// Compute dW
|
||||
|
|
@ -292,7 +291,7 @@ namespace caffe2 {
|
|||
Ag_dW.template mutable_data<T>(),
|
||||
sum_buffer_.template mutable_data<T>(),
|
||||
&context_);
|
||||
mask_seq_auto->ResizeLike(W);
|
||||
auto* mask_seq_auto = Output(5, W.sizes(), at::dtype<T>());
|
||||
T* mask_seq = mask_seq_auto->template mutable_data<T>();
|
||||
math::Set<T, Context>(N*K, static_cast<T>(0),
|
||||
mask_seq_auto->template mutable_data<T>(), &context_);
|
||||
|
|
@ -338,8 +337,7 @@ namespace caffe2 {
|
|||
&context_);
|
||||
// Compute dX if necessary.
|
||||
if (OutputSize() == 7) {
|
||||
auto* dX = Output(6);
|
||||
dX->ResizeLike(X);
|
||||
auto* dX = Output(6, X.sizes(), at::dtype<T>());
|
||||
math::Gemm<T, Context, Engine>(
|
||||
CblasNoTrans, CblasNoTrans, M, K, N, 1,
|
||||
dY.template data<T>(), W.template data<T>(),
|
||||
|
|
|
|||
|
|
@ -164,8 +164,8 @@ class FunHashGradientOp : public Operator<Context> {
|
|||
if (adaptive_) {
|
||||
const auto& alpha = Input(5);
|
||||
num_alpha = alpha.size(0);
|
||||
auto* grad_alpha = Output(1);
|
||||
grad_alpha->ResizeLike(alpha);
|
||||
|
||||
auto* grad_alpha = Output(1, alpha.sizes(), at::dtype<T>());
|
||||
grad_alpha_data = grad_alpha->template mutable_data<T>();
|
||||
memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
|
||||
}
|
||||
|
|
@ -175,8 +175,7 @@ class FunHashGradientOp : public Operator<Context> {
|
|||
int64_t num_weight = weight.size(0);
|
||||
int64_t num_nz_ent = seg.size(0);
|
||||
|
||||
auto* grad_weight = Output(0);
|
||||
grad_weight->ResizeLike(weight);
|
||||
auto* grad_weight = Output(0, weight.sizes(), at::dtype<T>());
|
||||
T* grad_weight_data = grad_weight->template mutable_data<T>();
|
||||
|
||||
const auto* grad_out_data = grad_out.template data<T>();
|
||||
|
|
|
|||
|
|
@ -163,8 +163,8 @@ class SparseFunHashGradientOp : public Operator<Context> {
|
|||
if (adaptive_) {
|
||||
const auto& alpha = Input(5);
|
||||
num_alpha = alpha.size(0);
|
||||
auto* grad_alpha = Output(2);
|
||||
grad_alpha->ResizeLike(alpha);
|
||||
|
||||
auto* grad_alpha = Output(2, alpha.sizes(), at::dtype<T>());
|
||||
grad_alpha_data = grad_alpha->template mutable_data<T>();
|
||||
memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user