mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #99046 from stevemcgregory:fix/gpu-shared-mem-alignment
PiperOrigin-RevId: 803065740
This commit is contained in:
commit
604991290b
|
|
@ -70,7 +70,10 @@ __global__ void concat_variable_kernel(
|
|||
IntType num_inputs = input_ptr_data.size;
|
||||
|
||||
// verbose declaration needed due to template
|
||||
GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(T), unsigned char, smem);
|
||||
constexpr size_t kAlignTI =
|
||||
(alignof(T) > alignof(IntType)) ? alignof(T) : alignof(IntType);
|
||||
constexpr size_t kAlign = (kAlignTI < 16) ? 16 : kAlignTI;
|
||||
GPU_DYNAMIC_SHARED_MEM_DECL(kAlign, unsigned char, smem);
|
||||
IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);
|
||||
|
||||
if (useSmem) {
|
||||
|
|
|
|||
|
|
@ -120,7 +120,10 @@ __global__ void split_v_kernel(const T* __restrict__ input_ptr,
|
|||
int num_outputs = output_ptr_data.size;
|
||||
|
||||
// verbose declaration needed due to template
|
||||
GPU_DYNAMIC_SHARED_MEM_DECL(sizeof(T), unsigned char, smem);
|
||||
constexpr size_t kAlignTI =
|
||||
(alignof(T) > alignof(IntType)) ? alignof(T) : alignof(IntType);
|
||||
constexpr size_t kAlign = (kAlignTI < 16) ? 16 : kAlignTI;
|
||||
GPU_DYNAMIC_SHARED_MEM_DECL(kAlign, unsigned char, smem);
|
||||
IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);
|
||||
|
||||
if (useSmem) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user