Fix incorrect stride handling in adaptive_avg_pool3d (#157326)
Fixes #157248

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157326
Approved by: https://github.com/eqy
ghstack dependencies: #157242
parent b5ce77c1f5
commit b40981c630
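For context, a minimal sketch of the failure mode (illustrative, not part of the patch; shapes and strides are assumptions): a 5D input whose size-1 dim carries stride 0, as inductor can produce, made the old kernel read every batch from batch 0, because the plane offset was derived from istrideD alone.

```python
import torch
import torch.nn.functional as F

# Hypothetical repro for issue #157248; requires a CUDA device.
base = torch.randn(2, 8, 8, 8, device="cuda")               # (B, T, H, W)
bad = base.as_strided((2, 1, 8, 8, 8), (512, 0, 64, 8, 1))  # stride 0 on the size-1 dim
good = base.unsqueeze(1)                                    # same values, nonzero stride

out = F.adaptive_avg_pool3d(bad, (4, 4, 4))
ref = F.adaptive_avg_pool3d(good, (4, 4, 4))
# Before this fix, out[1] silently equaled out[0]; after it, the two agree.
print(torch.allclose(out, ref))
```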
@@ -53,7 +53,7 @@ __global__ void adaptiveaveragepool(
     const scalar_t *input, scalar_t *output,
     int isizeT, int isizeH, int isizeW,
     int osizeT, int osizeH, int osizeW,
-    int64_t istrideD,
+    int64_t sizeD, int64_t istrideB, int64_t istrideD,
     int64_t istrideT, int64_t istrideH, int64_t istrideW,
     int64_t offsetZ) {
   // iterates on output pixels
@@ -70,15 +70,17 @@ __global__ void adaptiveaveragepool(
   // select output plane
   int64_t o_plane = blockIdx.x + offsetZ;
   ot = o_plane % osizeT; // output frame/time
-  int d = o_plane / osizeT; // slice/feature
+  int d = o_plane / osizeT; // flattened (batch, channel) index
+
+  // Decompose d into batch and channel indices
+  int batch_idx = d / sizeD;
+  int channel_idx = d % sizeD;
 
   // input frame/time range is fixed.
   int istartT = start_index(ot, osizeT, isizeT);
   int iendT = end_index(ot, osizeT, isizeT);
   int kT = iendT - istartT;
 
-  // input offset by slice/feature and earliest relevant frame/time
-  const scalar_t *input_dt = input + d*istrideD + istartT*istrideT;
   // output offset by slice/feature and frame/time
   scalar_t *output_dt = output + o_plane*osizeH*osizeW;
 
@@ -93,8 +95,6 @@ __global__ void adaptiveaveragepool(
       int iendW = end_index(ow, osizeW, isizeW);
       int kW = iendW - istartW;
 
-      // Compute the average pooling from corresponding input pixels
-      const scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
       scalar_t *ptr_output = output_dt + oh*osizeW + ow;
       accscalar_t sum = static_cast<accscalar_t>(0);
 
@@ -102,11 +102,13 @@ __global__ void adaptiveaveragepool(
       for (it = 0; it < kT; ++it) {
         for (ih = 0; ih < kH; ++ih) {
           for (iw = 0; iw < kW; ++iw) {
-            scalar_t val = ptr_input[ih*istrideH + iw*istrideW];
+            int64_t input_offset = batch_idx * istrideB + channel_idx * istrideD +
+                                   (istartT + it) * istrideT +
+                                   (istartH + ih) * istrideH + (istartW + iw) * istrideW;
+            scalar_t val = input[input_offset];
            sum += static_cast<accscalar_t>(val);
           }
         }
-        ptr_input += istrideT; // next input frame
       }
       // Update output
       const accscalar_t divide_factor = static_cast<accscalar_t>(kT * kH * kW);
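The heart of the fix: d now decomposes into (batch_idx, channel_idx), and the input offset is rebuilt from istrideB and istrideD separately instead of assuming the flattened plane stride istrideD covers both. A worked sketch of the same arithmetic (a Python mirror of the kernel's math, not the PR's code; the sample strides are illustrative):

```python
# Strides correspond to a (B, D, T, H, W) = (2, 1, 8, 8, 8) input whose
# size-1 dim has stride 0, as in issue #157248.
sizeD = 1
istrideB, istrideD, istrideT, istrideH, istrideW = 512, 0, 64, 8, 1

def input_offset(d, it, ih, iw):
    batch_idx, channel_idx = divmod(d, sizeD)   # same as d / sizeD, d % sizeD
    return (batch_idx * istrideB + channel_idx * istrideD
            + it * istrideT + ih * istrideH + iw * istrideW)

# The old kernel offset the plane by d * istrideD only; with istrideD == 0
# both batches collapsed onto offset 0. The decomposition separates them:
print(input_offset(0, 0, 0, 0))  # 0   -> batch 0
print(input_offset(1, 0, 0, 0))  # 512 -> batch 1 (old code computed 0)
```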
@@ -121,7 +123,7 @@ void adaptiveaveragepool_loop(
     int64_t totalZ,
     int isizeT, int isizeH, int isizeW,
     int osizeT, int osizeH, int osizeW,
-    int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) {
+    int64_t sizeD, int64_t istrideB, int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) {
   int64_t offsetZ = 0;
   dim3 threads(32, 8);
   // each H*W plane is processed by blocksH thread blocks
@@ -133,7 +135,7 @@ void adaptiveaveragepool_loop(
         input_data, output_data,
         isizeT, isizeH, isizeW,
         osizeT, osizeH, osizeW,
-        istrideD,
+        sizeD, istrideB, istrideD,
         istrideT, istrideH, istrideW,
         offsetZ);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -364,7 +366,7 @@ void adaptive_avg_pool3d_out_cuda_template(
   int64_t osizeW = output_size[2];
 
   int64_t sizeD, isizeT, isizeH, isizeW;
-  int64_t istrideD, istrideT, istrideH, istrideW;
+  int64_t istrideB, istrideD, istrideT, istrideH, istrideW;
   int64_t totalZ;
 
   const Tensor& input = input_.ndimension() == 4 ? input_ : input_.contiguous();
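Note that the `.contiguous()` guard above does not defend against the problematic layout: contiguity checks skip size-1 dims, so a zero-stride size-1 dim still reports contiguous and passes through uncopied. A small sketch of that layout behavior (CPU suffices; shapes are illustrative):

```python
import torch

base = torch.randn(2, 8, 8, 8)
bad = base.as_strided((2, 1, 8, 8, 8), (512, 0, 64, 8, 1))
print(bad.is_contiguous())                             # True: size-1 strides are ignored
print(bad.contiguous().data_ptr() == bad.data_ptr())   # True: no copy is made
```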
@@ -375,6 +377,7 @@ void adaptive_avg_pool3d_out_cuda_template(
     isizeH = input.size(2);
     isizeW = input.size(3);
 
+    istrideB = 0;
     istrideD = input.stride(0);
     istrideT = input.stride(1);
     istrideH = input.stride(2);
@@ -390,6 +393,7 @@ void adaptive_avg_pool3d_out_cuda_template(
     isizeH = input.size(3);
     isizeW = input.size(4);
 
+    istrideB = input.stride(0);
     istrideD = input.stride(1);
     istrideT = input.stride(2);
     istrideH = input.stride(3);
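For reference, what the two branches extract, sketched in Python (shapes are assumptions, not from the PR): the 4D path has no batch dim, so istrideB is pinned to 0 and batch_idx stays 0 in the kernel; the 5D path takes the real batch stride from dim 0.

```python
import torch

x4 = torch.randn(16, 8, 8, 8)                    # (D, T, H, W): unbatched
istrideB, istrideD = 0, x4.stride(0)             # batch stride unused

x5 = torch.randn(2, 16, 8, 8, 8)                 # (B, D, T, H, W)
istrideB, istrideD = x5.stride(0), x5.stride(1)  # 8192 and 512 here
```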
@@ -415,7 +419,7 @@ void adaptive_avg_pool3d_out_cuda_template(
         totalZ,
         isizeT, isizeH, isizeW,
         osizeT, osizeH, osizeW,
-        istrideD, istrideT, istrideH, istrideW);
+        sizeD, istrideB, istrideD, istrideT, istrideH, istrideW);
   });
 }
 
@@ -2165,6 +2165,52 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
         self.assertEqual(default_output, max_autotune_output)
 
+    def test_adaptive_avg_pool3d_issue_157248(self):
+        """Test for GitHub issue #157248: Conv2d-unsqueeze-AdaptiveAvgPool3d produces incorrect results"""
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+                self.adaptive_pool = torch.nn.AdaptiveAvgPool3d((4, 4, 4))
+
+            def forward(self, x):
+                x = self.conv(x)
+                # This specific unsqueeze position was problematic due to zero strides
+                x = x.unsqueeze(1)
+                x = self.adaptive_pool(x)
+                return x
+
+        model = Model().cuda()
+        model.eval()
+        test_cases = [
+            (1, 3, 8, 8),
+            (2, 3, 16, 16),
+            (1, 3, 32, 32),
+            (1, 3, 15, 15),
+            (2, 3, 13, 13),
+        ]
+
+        for batch, channels, h, w in test_cases:
+            with self.subTest(input_shape=(batch, channels, h, w)):
+                input_tensor = torch.randn(batch, channels, h, w, device="cuda")
+
+                # Test eager mode
+                with torch.no_grad():
+                    eager_output = model(input_tensor)
+
+                # Test compiled mode with inductor
+                compiled_model = torch.compile(model, backend="inductor")
+                with torch.no_grad():
+                    compiled_output = compiled_model(input_tensor)
+
+                # They should be identical (or very close)
+                self.assertTrue(
+                    torch.allclose(eager_output, compiled_output, rtol=1e-5, atol=1e-5),
+                    f"Results differ for input shape {(batch, channels, h, w)}. "
+                    f"Max diff: {torch.max(torch.abs(eager_output - compiled_output)):.6f}",
+                )
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
 