Add torchcheck for replication_pad3d_backward (#151986)
Fixes #142833
Adds a check on the channel dimension; the logic matches the CUDA implementation at 78bbb468c6/aten/src/ATen/native/cuda/ReplicationPadding.cu (L347)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151986
Approved by: https://github.com/mikaylagawarecki
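
To illustrate what the new check guards against, here is a minimal hypothetical repro in C++ (assuming libtorch; the shapes and the call here are illustrative, not taken from the PR). The gradOutput deliberately carries the wrong channel count, which the added check now rejects with a readable error instead of letting the padding loops run on mismatched sizes:

#include <torch/torch.h>
#include <iostream>

int main() {
  // 5-D input: N=1, C=2, D=3, H=4, W=5; pad by 1 on every side
  // (padding order: {left, right, top, bottom, front, back}).
  auto input = torch::randn({1, 2, 3, 4, 5});
  std::vector<int64_t> padding{1, 1, 1, 1, 1, 1};

  // gradOutput with the WRONG channel count (3 instead of 2); its spatial
  // sizes (D=5, H=6, W=7) are the correct padded sizes.
  auto bad_grad = torch::randn({1, 3, 5, 6, 7});

  try {
    auto grad_input = at::replication_pad3d_backward(bad_grad, input, padding);
  } catch (const c10::Error& e) {
    std::cout << "rejected: " << e.what_without_backtrace() << "\n";
  }
  return 0;
}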
This commit is contained in:
parent 51cd6697cd
commit ee4c5c7cd2
@@ -229,17 +229,20 @@ void replication_pad3d_backward_out_cpu_template(
  int pbottom = paddingSize[3];
  int pfront = paddingSize[4];
  int pback = paddingSize[5];
  int dimc = 0;
  int dimw = 3;
  int dimh = 2;
  int dimd = 1;

  if (input.dim() == 5) {
    dimc++;
    dimw++;
    dimh++;
    dimd++;
  }

  /* sizes */
  int64_t ichannel = input.size(dimc);
  int64_t idepth = input.size(dimd);
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);

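The dimension bookkeeping above handles the two accepted input ranks: 4-D (C, D, H, W) and batched 5-D (N, C, D, H, W), where every logical index shifts by one when a batch dimension is present. A standalone sketch of the same pattern (pad3d_dims is a hypothetical helper, not the actual ATen code):

#include <array>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: map the logical pad3d dimensions to tensor dimension
// indices for a 4-D (C, D, H, W) or 5-D (N, C, D, H, W) input.
std::array<int64_t, 4> pad3d_dims(int64_t ndim) {
  const int64_t offset = (ndim == 5) ? 1 : 0;  // skip the leading batch dim
  return {0 + offset, 1 + offset, 2 + offset, 3 + offset};  // c, d, h, w
}

int main() {
  auto dims4 = pad3d_dims(4);
  auto dims5 = pad3d_dims(5);
  std::printf("4-D: c=%d d=%d h=%d w=%d\n",
              (int)dims4[0], (int)dims4[1], (int)dims4[2], (int)dims4[3]);
  std::printf("5-D: c=%d d=%d h=%d w=%d\n",
              (int)dims5[0], (int)dims5[1], (int)dims5[2], (int)dims5[3]);
  return 0;
}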
@@ -249,6 +252,9 @@ void replication_pad3d_backward_out_cpu_template(

  at::native::padding::check_valid_input<3>(input, paddingSize);

  TORCH_CHECK(ichannel == gradOutput.size(dimc),
      "gradOutput channels unexpected. Expected: ", ichannel, ", Got: ",
      gradOutput.size(dimc));
  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
      gradOutput.size(dimw));
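
On the mechanics of the check itself: TORCH_CHECK throws a c10::Error when its condition is false, stream-concatenating the remaining arguments into the message, which is how the new check reports both the expected and actual channel counts. A minimal standalone sketch (assuming only the c10 headers; the values are illustrative):

#include <c10/util/Exception.h>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t ichannel = 2;      // channels of the forward input
  const int64_t got_channels = 3;  // channels of a mismatched gradOutput

  try {
    TORCH_CHECK(ichannel == got_channels,
        "gradOutput channels unexpected. Expected: ", ichannel,
        ", Got: ", got_channels);
  } catch (const c10::Error& e) {
    // Message reads: gradOutput channels unexpected. Expected: 2, Got: 3
    std::cout << e.what_without_backtrace() << "\n";
  }
  return 0;
}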