mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix incorrect stride handling in adaptive_avg_pool3d (#157326)
Fixes #157248 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157326 Approved by: https://github.com/eqy ghstack dependencies: #157242
This commit is contained in:
parent
b5ce77c1f5
commit
b40981c630
|
|
@ -53,7 +53,7 @@ __global__ void adaptiveaveragepool(
|
||||||
const scalar_t *input, scalar_t *output,
|
const scalar_t *input, scalar_t *output,
|
||||||
int isizeT, int isizeH, int isizeW,
|
int isizeT, int isizeH, int isizeW,
|
||||||
int osizeT, int osizeH, int osizeW,
|
int osizeT, int osizeH, int osizeW,
|
||||||
int64_t istrideD,
|
int64_t sizeD, int64_t istrideB, int64_t istrideD,
|
||||||
int64_t istrideT, int64_t istrideH, int64_t istrideW,
|
int64_t istrideT, int64_t istrideH, int64_t istrideW,
|
||||||
int64_t offsetZ) {
|
int64_t offsetZ) {
|
||||||
// iterates on output pixels
|
// iterates on output pixels
|
||||||
|
|
@ -70,15 +70,17 @@ __global__ void adaptiveaveragepool(
|
||||||
// select output plane
|
// select output plane
|
||||||
int64_t o_plane = blockIdx.x + offsetZ;
|
int64_t o_plane = blockIdx.x + offsetZ;
|
||||||
ot = o_plane % osizeT; // output frame/time
|
ot = o_plane % osizeT; // output frame/time
|
||||||
int d = o_plane / osizeT; // slice/feature
|
int d = o_plane / osizeT; // flattened (batch, channel) index
|
||||||
|
|
||||||
|
// Decompose d into batch and channel indices
|
||||||
|
int batch_idx = d / sizeD;
|
||||||
|
int channel_idx = d % sizeD;
|
||||||
|
|
||||||
// input frame/time range is fixed.
|
// input frame/time range is fixed.
|
||||||
int istartT = start_index(ot, osizeT, isizeT);
|
int istartT = start_index(ot, osizeT, isizeT);
|
||||||
int iendT = end_index(ot, osizeT, isizeT);
|
int iendT = end_index(ot, osizeT, isizeT);
|
||||||
int kT = iendT - istartT;
|
int kT = iendT - istartT;
|
||||||
|
|
||||||
// input offset by slice/feature and earliest relevant frame/time
|
|
||||||
const scalar_t *input_dt = input + d*istrideD + istartT*istrideT;
|
|
||||||
// output offset by slice/feature and frame/time
|
// output offset by slice/feature and frame/time
|
||||||
scalar_t *output_dt = output + o_plane*osizeH*osizeW;
|
scalar_t *output_dt = output + o_plane*osizeH*osizeW;
|
||||||
|
|
||||||
|
|
@ -93,8 +95,6 @@ __global__ void adaptiveaveragepool(
|
||||||
int iendW = end_index(ow, osizeW, isizeW);
|
int iendW = end_index(ow, osizeW, isizeW);
|
||||||
int kW = iendW - istartW;
|
int kW = iendW - istartW;
|
||||||
|
|
||||||
// Compute the average pooling from corresponding input pixels
|
|
||||||
const scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
|
|
||||||
scalar_t *ptr_output = output_dt + oh*osizeW + ow;
|
scalar_t *ptr_output = output_dt + oh*osizeW + ow;
|
||||||
accscalar_t sum = static_cast<accscalar_t>(0);
|
accscalar_t sum = static_cast<accscalar_t>(0);
|
||||||
|
|
||||||
|
|
@ -102,11 +102,13 @@ __global__ void adaptiveaveragepool(
|
||||||
for (it = 0; it < kT; ++it) {
|
for (it = 0; it < kT; ++it) {
|
||||||
for (ih = 0; ih < kH; ++ih) {
|
for (ih = 0; ih < kH; ++ih) {
|
||||||
for (iw = 0; iw < kW; ++iw) {
|
for (iw = 0; iw < kW; ++iw) {
|
||||||
scalar_t val = ptr_input[ih*istrideH + iw*istrideW];
|
int64_t input_offset = batch_idx * istrideB + channel_idx * istrideD +
|
||||||
|
(istartT + it) * istrideT +
|
||||||
|
(istartH + ih) * istrideH + (istartW + iw) * istrideW;
|
||||||
|
scalar_t val = input[input_offset];
|
||||||
sum += static_cast<accscalar_t>(val);
|
sum += static_cast<accscalar_t>(val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ptr_input += istrideT; // next input frame
|
|
||||||
}
|
}
|
||||||
// Update output
|
// Update output
|
||||||
const accscalar_t divide_factor = static_cast<accscalar_t>(kT * kH * kW);
|
const accscalar_t divide_factor = static_cast<accscalar_t>(kT * kH * kW);
|
||||||
|
|
@ -121,7 +123,7 @@ void adaptiveaveragepool_loop(
|
||||||
int64_t totalZ,
|
int64_t totalZ,
|
||||||
int isizeT, int isizeH, int isizeW,
|
int isizeT, int isizeH, int isizeW,
|
||||||
int osizeT, int osizeH, int osizeW,
|
int osizeT, int osizeH, int osizeW,
|
||||||
int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) {
|
int64_t sizeD, int64_t istrideB, int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) {
|
||||||
int64_t offsetZ = 0;
|
int64_t offsetZ = 0;
|
||||||
dim3 threads(32, 8);
|
dim3 threads(32, 8);
|
||||||
// each H*W plane is processed by blocksH thread blocks
|
// each H*W plane is processed by blocksH thread blocks
|
||||||
|
|
@ -133,7 +135,7 @@ void adaptiveaveragepool_loop(
|
||||||
input_data, output_data,
|
input_data, output_data,
|
||||||
isizeT, isizeH, isizeW,
|
isizeT, isizeH, isizeW,
|
||||||
osizeT, osizeH, osizeW,
|
osizeT, osizeH, osizeW,
|
||||||
istrideD,
|
sizeD, istrideB, istrideD,
|
||||||
istrideT, istrideH, istrideW,
|
istrideT, istrideH, istrideW,
|
||||||
offsetZ);
|
offsetZ);
|
||||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||||
|
|
@ -364,7 +366,7 @@ void adaptive_avg_pool3d_out_cuda_template(
|
||||||
int64_t osizeW = output_size[2];
|
int64_t osizeW = output_size[2];
|
||||||
|
|
||||||
int64_t sizeD, isizeT, isizeH, isizeW;
|
int64_t sizeD, isizeT, isizeH, isizeW;
|
||||||
int64_t istrideD, istrideT, istrideH, istrideW;
|
int64_t istrideB, istrideD, istrideT, istrideH, istrideW;
|
||||||
int64_t totalZ;
|
int64_t totalZ;
|
||||||
|
|
||||||
const Tensor& input = input_.ndimension() == 4 ? input_ : input_.contiguous();
|
const Tensor& input = input_.ndimension() == 4 ? input_ : input_.contiguous();
|
||||||
|
|
@ -375,6 +377,7 @@ void adaptive_avg_pool3d_out_cuda_template(
|
||||||
isizeH = input.size(2);
|
isizeH = input.size(2);
|
||||||
isizeW = input.size(3);
|
isizeW = input.size(3);
|
||||||
|
|
||||||
|
istrideB = 0;
|
||||||
istrideD = input.stride(0);
|
istrideD = input.stride(0);
|
||||||
istrideT = input.stride(1);
|
istrideT = input.stride(1);
|
||||||
istrideH = input.stride(2);
|
istrideH = input.stride(2);
|
||||||
|
|
@ -390,6 +393,7 @@ void adaptive_avg_pool3d_out_cuda_template(
|
||||||
isizeH = input.size(3);
|
isizeH = input.size(3);
|
||||||
isizeW = input.size(4);
|
isizeW = input.size(4);
|
||||||
|
|
||||||
|
istrideB = input.stride(0);
|
||||||
istrideD = input.stride(1);
|
istrideD = input.stride(1);
|
||||||
istrideT = input.stride(2);
|
istrideT = input.stride(2);
|
||||||
istrideH = input.stride(3);
|
istrideH = input.stride(3);
|
||||||
|
|
@ -415,7 +419,7 @@ void adaptive_avg_pool3d_out_cuda_template(
|
||||||
totalZ,
|
totalZ,
|
||||||
isizeT, isizeH, isizeW,
|
isizeT, isizeH, isizeW,
|
||||||
osizeT, osizeH, osizeW,
|
osizeT, osizeH, osizeW,
|
||||||
istrideD, istrideT, istrideH, istrideW);
|
sizeD, istrideB, istrideD, istrideT, istrideH, istrideW);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2165,6 +2165,52 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
|
||||||
|
|
||||||
self.assertEqual(default_output, max_autotune_output)
|
self.assertEqual(default_output, max_autotune_output)
|
||||||
|
|
||||||
|
def test_adaptive_avg_pool3d_issue_157248(self):
|
||||||
|
"""Test for GitHub issue #157248: Conv2d-unsqueeze-AdaptiveAvgPool3d produces incorrect results"""
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
|
||||||
|
self.adaptive_pool = torch.nn.AdaptiveAvgPool3d((4, 4, 4))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
# This specific unsqueeze position was problematic due to zero strides
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
x = self.adaptive_pool(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
model = Model().cuda()
|
||||||
|
model.eval()
|
||||||
|
test_cases = [
|
||||||
|
(1, 3, 8, 8),
|
||||||
|
(2, 3, 16, 16),
|
||||||
|
(1, 3, 32, 32),
|
||||||
|
(1, 3, 15, 15),
|
||||||
|
(2, 3, 13, 13),
|
||||||
|
]
|
||||||
|
|
||||||
|
for batch, channels, h, w in test_cases:
|
||||||
|
with self.subTest(input_shape=(batch, channels, h, w)):
|
||||||
|
input_tensor = torch.randn(batch, channels, h, w, device="cuda")
|
||||||
|
|
||||||
|
# Test eager mode
|
||||||
|
with torch.no_grad():
|
||||||
|
eager_output = model(input_tensor)
|
||||||
|
|
||||||
|
# Test compiled mode with inductor
|
||||||
|
compiled_model = torch.compile(model, backend="inductor")
|
||||||
|
with torch.no_grad():
|
||||||
|
compiled_output = compiled_model(input_tensor)
|
||||||
|
|
||||||
|
# They should be identical (or very close)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.allclose(eager_output, compiled_output, rtol=1e-5, atol=1e-5),
|
||||||
|
f"Results differ for input shape {(batch, channels, h, w)}. "
|
||||||
|
f"Max diff: {torch.max(torch.abs(eager_output - compiled_output)):.6f}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._inductor.test_case import run_tests
|
from torch._inductor.test_case import run_tests
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user