Merge pull request #99046 from stevemcgregory:fix/gpu-shared-mem-alignment

PiperOrigin-RevId: 803065740
This commit is contained in:
TensorFlower Gardener 2025-09-04 10:10:55 -07:00
commit 604991290b
2 changed files with 8 additions and 2 deletions

View File

@@ -70,7 +70,10 @@ __global__ void concat_variable_kernel(
IntType num_inputs = input_ptr_data.size;
// verbose declaration needed due to template
-GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(T), unsigned char, smem);
+constexpr size_t kAlignTI =
+(alignof(T) > alignof(IntType)) ? alignof(T) : alignof(IntType);
+constexpr size_t kAlign = (kAlignTI < 16) ? 16 : kAlignTI;
+GPU_DYNAMIC_SHARED_MEM_DECL(kAlign, unsigned char, smem);
IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);
if (useSmem) {

View File

@@ -120,7 +120,10 @@ __global__ void split_v_kernel(const T* __restrict__ input_ptr,
int num_outputs = output_ptr_data.size;
// verbose declaration needed due to template
-GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(T), unsigned char, smem);
+constexpr size_t kAlignTI =
+(alignof(T) > alignof(IntType)) ? alignof(T) : alignof(IntType);
+constexpr size_t kAlign = (kAlignTI < 16) ? 16 : kAlignTI;
+GPU_DYNAMIC_SHARED_MEM_DECL(kAlign, unsigned char, smem);
IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);
if (useSmem) {