Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Fix channels-last dimension mapping in CUDA parallel_cat (#165023)
Fixes #164849

`dimension` was updated in-place, so for more than one batch of channels-last tensors the concat `dimension` passed to the second kernel launch was wrong.

## Testing

- python -m compileall test/test_tensor_creation_ops.py

------

https://chatgpt.com/codex/tasks/task_e_68e708879b30832f89b10ae55faa68e8

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165023
Approved by: https://github.com/ezyang
This commit is contained in:
parent ed2d514ad8
commit 228973df7f
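Before the diff, here is a minimal, self-contained C++ sketch of the failure mode described above — it is not the PyTorch source. `remap_for_channels_last`, `launch_batch`, the 128-tensor batch size, and the loop structure are illustrative placeholders for `parallel_cat` and its `CatArrayBatchedCopy*` launches; only the remap rules mirror the switch shown in the diff below.

```cpp
#include <cstdio>
#include <vector>

// Same remap as the switch in parallel_cat: map a logical NCHW cat dimension
// to the memory-order dimension used by the channels-last kernels.
int remap_for_channels_last(int dim, int nDims) {
  switch (dim) {
    case 0:  return dim;          // batch dimension stays put
    case 1:  return nDims - dim;  // channels become the innermost dimension
    default: return dim - 1;      // spatial dimensions shift down by one
  }
}

// Stand-in for one CatArrayBatchedCopy launch: records the dim it was given.
void launch_batch(std::vector<int>& seen, int kernel_dim) {
  seen.push_back(kernel_dim);
}

int main() {
  const int nDims = 4;        // NCHW input tensors
  const int batch_size = 128; // tensors handled per kernel launch (illustrative)
  const int num_inputs = 130; // needs two launches, as in the regression test

  // Buggy pattern: the remap mutates `dimension` inside the per-batch loop,
  // so the second launch remaps an already-remapped value (1 -> 3 -> 2).
  int dimension = 1;
  std::vector<int> buggy;
  for (int first = 0; first < num_inputs; first += batch_size) {
    dimension = remap_for_channels_last(dimension, nDims);
    launch_batch(buggy, dimension);
  }

  // Fixed pattern: remap into a local `cat_dim` and leave `dimension` alone,
  // so every launch sees the same value.
  const int logical_dim = 1;
  std::vector<int> fixed;
  for (int first = 0; first < num_inputs; first += batch_size) {
    int cat_dim = remap_for_channels_last(logical_dim, nDims);
    launch_batch(fixed, cat_dim);
  }

  std::printf("buggy launches used dims %d, %d\n", buggy[0], buggy[1]); // 3, 2
  std::printf("fixed launches used dims %d, %d\n", fixed[0], fixed[1]); // 3, 3
  return 0;
}
```

Compiled with any C++11 compiler, the buggy loop reports dims 3 and 2 for the two launches while the fixed loop reports 3 for both, which matches the behavior the commit message describes.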
```diff
@@ -488,15 +488,16 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       }
     }

+    int cat_dim = dimension;
     if (memory_format != c10::MemoryFormat::Contiguous) {
-      switch (dimension) {
+      switch (cat_dim) {
         case 0:
           break;
         case 1:
-          dimension = nDims - dimension;
+          cat_dim = nDims - cat_dim;
           break;
         default:
-          dimension--;
+          cat_dim--;
       }
     }
     // Template Declarations for dim = 1, 2, 3, 4
@@ -505,23 +506,23 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
       constexpr auto elems_per_vec = alignment / sizeof(scalar_t); \
       CatArrayBatchedCopy_vectorized<scalar_t, unsigned int, DIMS, batch_size, stride_size, alignment, elems_per_vec><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-              (char*)data, catMetaData, kernelOutputParam, dimension, trailingSize);\
+              (char*)data, catMetaData, kernelOutputParam, cat_dim, trailingSize);\
     } else if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\
       CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_16><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-              data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
+              data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     } else if (isContig && isAligned && sizeof(scalar_t) == 2) { \
      CatArrayBatchedCopy_alignedK_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size, ALIGNED_VEC_LOAD_BYTES_8><<<\
          catGrid, applyBlock, 0, stream.stream()>>>(\
-             data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
+             data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     } else if (isContig) {\
       CatArrayBatchedCopy_contig<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-              data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
+              data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     } else {\
       CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size><<<\
           catGrid, applyBlock, 0, stream.stream()>>>(\
-              data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\
+              data, catMetaData, outputParam, cat_dim, outputParam.tensorStride[cat_dim]);\
     }\
     C10_CUDA_KERNEL_LAUNCH_CHECK();
     switch (nDims) {
```
```diff
@@ -688,6 +688,21 @@ class TestTensorCreation(TestCase):
         self.assertEqual(res1, res2)
         self.assertTrue(res1.is_contiguous(memory_format=torch.channels_last))

+    @onlyCUDA
+    def test_cat_channels_last_large_inputs(self, device):
+        num_tensors = 130
+        inputs_cuda = [
+            torch.randn((2, 3, 4, 4), device=device).contiguous(memory_format=torch.channels_last)
+            for _ in range(num_tensors)
+        ]
+        inputs_cpu = [t.cpu() for t in inputs_cuda]
+
+        result = torch.cat(inputs_cuda, dim=1)
+        expected = torch.cat(inputs_cpu, dim=1)
+
+        self.assertEqual(result.cpu(), expected)
+        self.assertTrue(result.is_contiguous(memory_format=torch.channels_last))
+
     @onlyCUDA
     def test_cat_out_memory_format(self, device):
         inp_size = (4, 4, 4, 4)
```