Fix incorrect stride handling in adaptive_avg_pool3d (#157326)

Fixes #157248

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157326
Approved by: https://github.com/eqy
ghstack dependencies: #157242
This commit is contained in:
Jason Ansel 2025-06-30 16:41:39 -07:00 committed by PyTorch MergeBot
parent b5ce77c1f5
commit b40981c630
2 changed files with 62 additions and 12 deletions

View File

@ -53,7 +53,7 @@ __global__ void adaptiveaveragepool(
const scalar_t *input, scalar_t *output,
int isizeT, int isizeH, int isizeW,
int osizeT, int osizeH, int osizeW,
int64_t istrideD,
int64_t sizeD, int64_t istrideB, int64_t istrideD,
int64_t istrideT, int64_t istrideH, int64_t istrideW,
int64_t offsetZ) {
// iterates on output pixels
@ -70,15 +70,17 @@ __global__ void adaptiveaveragepool(
// select output plane
int64_t o_plane = blockIdx.x + offsetZ;
ot = o_plane % osizeT; // output frame/time
int d = o_plane / osizeT; // slice/feature
int d = o_plane / osizeT; // flattened (batch, channel) index
// Decompose d into batch and channel indices
int batch_idx = d / sizeD;
int channel_idx = d % sizeD;
// input frame/time range is fixed.
int istartT = start_index(ot, osizeT, isizeT);
int iendT = end_index(ot, osizeT, isizeT);
int kT = iendT - istartT;
// input offset by slice/feature and earliest relevant frame/time
const scalar_t *input_dt = input + d*istrideD + istartT*istrideT;
// output offset by slice/feature and frame/time
scalar_t *output_dt = output + o_plane*osizeH*osizeW;
@ -93,8 +95,6 @@ __global__ void adaptiveaveragepool(
int iendW = end_index(ow, osizeW, isizeW);
int kW = iendW - istartW;
// Compute the average pooling from corresponding input pixels
const scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
scalar_t *ptr_output = output_dt + oh*osizeW + ow;
accscalar_t sum = static_cast<accscalar_t>(0);
@ -102,11 +102,13 @@ __global__ void adaptiveaveragepool(
for (it = 0; it < kT; ++it) {
for (ih = 0; ih < kH; ++ih) {
for (iw = 0; iw < kW; ++iw) {
scalar_t val = ptr_input[ih*istrideH + iw*istrideW];
int64_t input_offset = batch_idx * istrideB + channel_idx * istrideD +
(istartT + it) * istrideT +
(istartH + ih) * istrideH + (istartW + iw) * istrideW;
scalar_t val = input[input_offset];
sum += static_cast<accscalar_t>(val);
}
}
ptr_input += istrideT; // next input frame
}
// Update output
const accscalar_t divide_factor = static_cast<accscalar_t>(kT * kH * kW);
@ -121,7 +123,7 @@ void adaptiveaveragepool_loop(
int64_t totalZ,
int isizeT, int isizeH, int isizeW,
int osizeT, int osizeH, int osizeW,
int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) {
int64_t sizeD, int64_t istrideB, int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) {
int64_t offsetZ = 0;
dim3 threads(32, 8);
// each H*W plane is processed by blocksH thread blocks
@ -133,7 +135,7 @@ void adaptiveaveragepool_loop(
input_data, output_data,
isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW,
istrideD,
sizeD, istrideB, istrideD,
istrideT, istrideH, istrideW,
offsetZ);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@ -364,7 +366,7 @@ void adaptive_avg_pool3d_out_cuda_template(
int64_t osizeW = output_size[2];
int64_t sizeD, isizeT, isizeH, isizeW;
int64_t istrideD, istrideT, istrideH, istrideW;
int64_t istrideB, istrideD, istrideT, istrideH, istrideW;
int64_t totalZ;
const Tensor& input = input_.ndimension() == 4 ? input_ : input_.contiguous();
@ -375,6 +377,7 @@ void adaptive_avg_pool3d_out_cuda_template(
isizeH = input.size(2);
isizeW = input.size(3);
istrideB = 0;
istrideD = input.stride(0);
istrideT = input.stride(1);
istrideH = input.stride(2);
@ -390,6 +393,7 @@ void adaptive_avg_pool3d_out_cuda_template(
isizeH = input.size(3);
isizeW = input.size(4);
istrideB = input.stride(0);
istrideD = input.stride(1);
istrideT = input.stride(2);
istrideH = input.stride(3);
@ -415,7 +419,7 @@ void adaptive_avg_pool3d_out_cuda_template(
totalZ,
isizeT, isizeH, isizeW,
osizeT, osizeH, osizeW,
istrideD, istrideT, istrideH, istrideW);
sizeD, istrideB, istrideD, istrideT, istrideH, istrideW);
});
}

View File

@ -2165,6 +2165,52 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
self.assertEqual(default_output, max_autotune_output)
def test_adaptive_avg_pool3d_issue_157248(self):
"""Test for GitHub issue #157248: Conv2d-unsqueeze-AdaptiveAvgPool3d produces incorrect results"""
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.adaptive_pool = torch.nn.AdaptiveAvgPool3d((4, 4, 4))
def forward(self, x):
x = self.conv(x)
# This specific unsqueeze position was problematic due to zero strides
x = x.unsqueeze(1)
x = self.adaptive_pool(x)
return x
model = Model().cuda()
model.eval()
test_cases = [
(1, 3, 8, 8),
(2, 3, 16, 16),
(1, 3, 32, 32),
(1, 3, 15, 15),
(2, 3, 13, 13),
]
for batch, channels, h, w in test_cases:
with self.subTest(input_shape=(batch, channels, h, w)):
input_tensor = torch.randn(batch, channels, h, w, device="cuda")
# Test eager mode
with torch.no_grad():
eager_output = model(input_tensor)
# Test compiled mode with inductor
compiled_model = torch.compile(model, backend="inductor")
with torch.no_grad():
compiled_output = compiled_model(input_tensor)
# They should be identical (or very close)
self.assertTrue(
torch.allclose(eager_output, compiled_output, rtol=1e-5, atol=1e-5),
f"Results differ for input shape {(batch, channels, h, w)}. "
f"Max diff: {torch.max(torch.abs(eager_output - compiled_output)):.6f}",
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests