mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add proper shape checking to torch.cat (#4087)
* Fix catArray in THTensor: assert that the inputs have the same size except in the cat dimension, or are empty (or a mix of both).
* Fix catArray for THCTensor.
* Document torch.cat shape checks.
* Fix types.
This commit is contained in:
parent
2c71b679d2
commit
9394e65b44
|
|
@ -2809,116 +2809,114 @@ void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
|
|||
THTensor_(catArray)(r_, inputs, 2, dimension);
|
||||
}
|
||||
|
||||
// Verifies that `first` and `second` are shape-compatible for concatenation
// along `dimension`: both must report the same nDimension, and every dimension
// other than `dimension` must have the same size in the two tensors.
// Mismatches are reported through THArgCheck with the messages below (both
// checks pass 0 as the argument index).
// NOTE(review): the plain prototype placed ahead of the `inline` definition
// presumably exists to obtain an external definition under C99 inline rules --
// confirm.
void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension);
|
||||
inline void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension)
|
||||
{
|
||||
int first_dims = first->nDimension;
|
||||
int second_dims = second->nDimension;
|
||||
// The ranks must agree before sizes are compared dimension by dimension.
THArgCheck(first_dims == second_dims, 0,
|
||||
"Tensors must have same number of dimensions: got %d and %d",
|
||||
first_dims, second_dims);
|
||||
for (int dim = 0; dim < first_dims; dim++) {
|
||||
// The cat dimension is the only one whose size is allowed to differ.
if (dim == dimension) {
|
||||
continue;
|
||||
}
|
||||
int64_t first_dim_size = first->size[dim];
|
||||
int64_t second_dim_size = second->size[dim];
|
||||
THArgCheck(first_dim_size == second_dim_size, 0,
|
||||
"Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d",
|
||||
// Sizes are cast to long long so that they match the %lld conversions.
dimension, (long long)first_dim_size, (long long)second_dim_size, dim);
|
||||
}
|
||||
}
|
||||
|
||||
void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension)
|
||||
{
|
||||
THLongStorage *size;
|
||||
int i, j;
|
||||
int64_t offset;
|
||||
int maxDim = dimension + 1;
|
||||
// Find a non-empty tensor to record nDims
|
||||
int allEmpty = 1;
|
||||
int allContiguous = 1;
|
||||
|
||||
// cat_dimension is the actual dimension we cat along
|
||||
int cat_dimension = dimension;
|
||||
|
||||
for (i = 0; i < numInputs; i++)
|
||||
{
|
||||
maxDim = THMax(maxDim, inputs[i]->nDimension);
|
||||
int nDims = 0;
|
||||
THTensor *notEmptyTensor;
|
||||
for (int i = 0; i < numInputs; i++) {
|
||||
int input_dims = inputs[i]->nDimension;
|
||||
if (input_dims == 0) {
|
||||
continue;
|
||||
}
|
||||
// We've found a non-empty tensor
|
||||
allEmpty = 0;
|
||||
notEmptyTensor = inputs[i];
|
||||
nDims = input_dims;
|
||||
break;
|
||||
}
|
||||
if (allEmpty) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute cat_dimension based on the non-empty tensor
|
||||
THArgCheck(dimension >= -1 && dimension < nDims, 4, "invalid dimension %d", dimension);
|
||||
// When the user input dimension is -1 (i.e. -2 in C)
|
||||
// Then we pick the maximum last dimension across all tensors.
|
||||
if ( dimension + TH_INDEX_BASE == -1 )
|
||||
{
|
||||
cat_dimension = maxDim?(maxDim-1):0;
|
||||
// Then we pick the last dimension across non-empty tensors.
|
||||
int cat_dimension = dimension;
|
||||
if (dimension + TH_INDEX_BASE == -1) {
|
||||
cat_dimension = nDims ? nDims - 1 : 0;
|
||||
}
|
||||
|
||||
THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
|
||||
THArgCheck(cat_dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
|
||||
|
||||
size = THLongStorage_newWithSize(maxDim);
|
||||
|
||||
for(i = 0; i < maxDim; i++)
|
||||
{
|
||||
// dimSize is either the size of the dim if it exists, either 1 if #dim > 0, otherwise 0
|
||||
int64_t dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : THMin(inputs[0]->nDimension, 1);
|
||||
if (i == cat_dimension)
|
||||
{
|
||||
for (j = 1; j < numInputs; j++)
|
||||
{
|
||||
// accumulate the size over the dimension we want to cat on.
|
||||
// Empty tensors are allowed
|
||||
dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1);
|
||||
}
|
||||
// Compute size of the result in the cat dimension
|
||||
int64_t cat_dim_size = 0;
|
||||
for (int i = 0; i < numInputs; i++) {
|
||||
THTensor *tensor = inputs[i];
|
||||
if (tensor->nDimension == 0) {
|
||||
continue;
|
||||
}
|
||||
else
|
||||
{
|
||||
for (j = 1; j < numInputs; j++)
|
||||
{
|
||||
int64_t sz = (i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1));
|
||||
// If it's a dimension we're not catting on
|
||||
// Then fail if sizes are different AND > 0
|
||||
if (dimSize != sz && dimSize && sz)
|
||||
{
|
||||
THLongStorage_free(size);
|
||||
THError("inconsistent tensor sizes");
|
||||
}
|
||||
else if(!dimSize)
|
||||
{
|
||||
dimSize = sz;
|
||||
}
|
||||
}
|
||||
}
|
||||
allEmpty = allEmpty && !dimSize;
|
||||
size->data[i] = dimSize;
|
||||
THTensor_(check_shape_except_dim)(notEmptyTensor, tensor, cat_dimension);
|
||||
cat_dim_size += tensor->size[cat_dimension];
|
||||
}
|
||||
|
||||
// Initiate catting and resizing
|
||||
// If at least one of the input is not empty
|
||||
if (!allEmpty)
|
||||
{
|
||||
THTensor_(resize)(result, size, NULL);
|
||||
// Compute the size of the result
|
||||
THLongStorage *size = THLongStorage_newWithSize(nDims);
|
||||
for (int dim = 0; dim < nDims; dim++) {
|
||||
int64_t result_dim_size = notEmptyTensor->size[dim];
|
||||
if (dim == cat_dimension) {
|
||||
result_dim_size = cat_dim_size;
|
||||
}
|
||||
size->data[dim] = result_dim_size;
|
||||
}
|
||||
THTensor_(resize)(result, size, NULL);
|
||||
|
||||
// Check contiguity of all inputs and result
|
||||
for (i = 0; i < numInputs; i++) {
|
||||
if(inputs[i]->nDimension) {
|
||||
allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]);
|
||||
// Check contiguity of all inputs and result
|
||||
int allContiguous = 1;
|
||||
for (int i = 0; i < numInputs; i++) {
|
||||
if(inputs[i]->nDimension) {
|
||||
allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]);
|
||||
}
|
||||
}
|
||||
allContiguous = allContiguous && THTensor_(isContiguous)(result);
|
||||
|
||||
// First path is for contiguous inputs along dim 0
|
||||
// Second path for non-contiguous
|
||||
int64_t offset;
|
||||
if (cat_dimension == 0 && allContiguous) {
|
||||
real* result_data = result->storage->data + result->storageOffset;
|
||||
offset = 0;
|
||||
for (int j = 0; j < numInputs; j++) {
|
||||
if (inputs[j]->nDimension) {
|
||||
THTensor* input0 = inputs[j];
|
||||
real* input0_data = input0->storage->data + input0->storageOffset;
|
||||
int64_t input0_size = THTensor_(nElement)(input0);
|
||||
memcpy(result_data + offset, input0_data, input0_size*sizeof(real));
|
||||
offset += input0_size;
|
||||
}
|
||||
}
|
||||
allContiguous = allContiguous && THTensor_(isContiguous)(result);
|
||||
|
||||
// First path is for contiguous inputs along dim 1
|
||||
// Second path for non-contiguous
|
||||
if (cat_dimension == 0 && allContiguous)
|
||||
{
|
||||
real* result_data = result->storage->data + result->storageOffset;
|
||||
offset = 0;
|
||||
for (j = 0; j < numInputs; j++)
|
||||
{
|
||||
if (inputs[j]->nDimension)
|
||||
{
|
||||
THTensor* input0 = inputs[j];
|
||||
real* input0_data = input0->storage->data + input0->storageOffset;
|
||||
int64_t input0_size = THTensor_(nElement)(input0);
|
||||
memcpy(result_data + offset, input0_data, input0_size*sizeof(real));
|
||||
offset += input0_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
offset = 0;
|
||||
for (j = 0; j < numInputs; j++)
|
||||
{
|
||||
if (inputs[j]->nDimension)
|
||||
{
|
||||
int64_t dimSize = cat_dimension < inputs[j]->nDimension ? inputs[j]->size[cat_dimension] : 1;
|
||||
THTensor *nt = THTensor_(newWithTensor)(result);
|
||||
THTensor_(narrow)(nt, NULL, cat_dimension, offset, dimSize);
|
||||
THTensor_(copy)(nt, inputs[j]);
|
||||
THTensor_(free)(nt);
|
||||
offset += dimSize;
|
||||
}
|
||||
} else {
|
||||
offset = 0;
|
||||
for (int j = 0; j < numInputs; j++) {
|
||||
if (inputs[j]->nDimension) {
|
||||
int64_t dimSize = cat_dimension < inputs[j]->nDimension ? inputs[j]->size[cat_dimension] : 1;
|
||||
THTensor *nt = THTensor_(newWithTensor)(result);
|
||||
THTensor_(narrow)(nt, NULL, cat_dimension, offset, dimSize);
|
||||
THTensor_(copy)(nt, inputs[j]);
|
||||
THTensor_(free)(nt);
|
||||
offset += dimSize;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,6 +90,28 @@ void THCTensor_(cat)(THCState *state, THCTensor *result,
|
|||
THCTensor_(catArray)(state, result, inputs, 2, dimension);
|
||||
}
|
||||
|
||||
// CUDA-tensor counterpart of the shape check used before concatenation:
// `first` and `second` must report the same number of dimensions, and every
// dimension other than `dimension` must have the same size in both tensors.
// Dimension counts and sizes are read through the THCTensor_(nDimension) and
// THCTensor_(size) accessors, which take `state`. Mismatches are reported
// through THArgCheck with the messages below (both checks pass 0 as the
// argument index).
// NOTE(review): the plain prototype placed ahead of the `inline` definition
// presumably exists to obtain an external definition -- confirm.
void THCTensor_(check_shape_except_dim)(THCState *state,
|
||||
THCTensor *first, THCTensor *second, int dimension);
|
||||
inline void THCTensor_(check_shape_except_dim)(THCState *state,
|
||||
THCTensor *first, THCTensor *second, int dimension)
|
||||
{
|
||||
int first_dims = THCTensor_(nDimension)(state, first);
|
||||
int second_dims = THCTensor_(nDimension)(state, second);
|
||||
// The ranks must agree before sizes are compared dimension by dimension.
THArgCheck(first_dims == second_dims, 0,
|
||||
"Tensors must have same number of dimensions: got %d and %d",
|
||||
first_dims, second_dims);
|
||||
for (int dim = 0; dim < first_dims; dim++) {
|
||||
// The cat dimension is the only one whose size is allowed to differ.
if (dim == dimension) {
|
||||
continue;
|
||||
}
|
||||
int64_t first_dim_size = THCTensor_(size)(state, first, dim);
|
||||
int64_t second_dim_size = THCTensor_(size)(state, second, dim);
|
||||
THArgCheck(first_dim_size == second_dim_size, 0,
|
||||
"Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d",
|
||||
// Sizes are cast to long long so that they match the %lld conversions.
dimension, (long long)first_dim_size, (long long)second_dim_size, dim);
|
||||
}
|
||||
}
|
||||
|
||||
void THCTensor_(catArray)(THCState *state, THCTensor *result,
|
||||
THCTensor **inputs, int numInputs, int dimension)
|
||||
{
|
||||
|
|
@ -97,11 +119,12 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
|
|||
int i, j, cohortMax;
|
||||
int64_t offset;
|
||||
bool hasEmptyInput = false;
|
||||
THCTensor *notEmptyTensor = NULL;
|
||||
|
||||
// Even in the case where dimension is negative (i.e. when we want
|
||||
// to cat along the last dimension), this logic still works, as the
|
||||
// loop below will overwrite the value
|
||||
int maxDim = dimension + 1;
|
||||
int nDims = dimension + 1;
|
||||
|
||||
// cat_dimension is the actual dimension we cat along
|
||||
int cat_dimension = dimension;
|
||||
|
|
@ -110,61 +133,44 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
|
|||
{
|
||||
int inputDim = THCTensor_(nDimension)(state, inputs[i]);
|
||||
hasEmptyInput |= !inputDim;
|
||||
maxDim = THMax(maxDim, inputDim);
|
||||
if (inputDim > 0) {
|
||||
nDims = inputDim;
|
||||
notEmptyTensor = inputs[i];
|
||||
}
|
||||
}
|
||||
|
||||
// In the event that the user specified -1 as the concat dimension, then
|
||||
// we want to pick the maxDim as dimension to cat along (and thus maxDim - 1 as the
|
||||
// value due to 0-based indexing). If the maxDim is // 0 (i.e. we are catting all
|
||||
// we want to pick the nDims as dimension to cat along (and thus nDims - 1 as the
|
||||
// value due to 0-based indexing). If nDims is 0 (i.e. we are catting all
|
||||
// empty tensors), then we set cat_dimension to be 0
|
||||
if (dimension + TH_INDEX_BASE == -1) {
|
||||
cat_dimension = maxDim ? (maxDim - 1) : 0;
|
||||
cat_dimension = nDims ? (nDims - 1) : 0;
|
||||
}
|
||||
|
||||
THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
|
||||
THArgCheck(cat_dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
|
||||
|
||||
size = THLongStorage_newWithSize(maxDim);
|
||||
for(i = 0; i < maxDim; i++)
|
||||
{
|
||||
// dimSize is either the size of the dim if it exists, either 1 if #dim > 0, otherwise 0
|
||||
int64_t dimSize = i < THCTensor_(nDimension)(state, inputs[0])
|
||||
? THCTensor_(size)(state, inputs[0], i)
|
||||
: THMin(THCTensor_(nDimension)(state, inputs[0]), 1);
|
||||
if (i == cat_dimension)
|
||||
{
|
||||
for (j = 1; j < numInputs; j++)
|
||||
{
|
||||
// accumulate the size over the dimension we want to cat on.
|
||||
// Empty tensors are allowed
|
||||
dimSize += i < THCTensor_(nDimension)(state, inputs[j])
|
||||
? THCTensor_(size)(state, inputs[j], i)
|
||||
: THMin(THCTensor_(nDimension)(state, inputs[j]), 1);
|
||||
}
|
||||
|
||||
size = THLongStorage_newWithSize(nDims);
|
||||
|
||||
// Compute size of the result in the cat dimension
|
||||
int64_t cat_dim_size = 0;
|
||||
for (int i = 0; i < numInputs; i++) {
|
||||
THCTensor *tensor = inputs[i];
|
||||
if (THCTensor_(nDimension)(state, tensor) == 0) {
|
||||
continue;
|
||||
}
|
||||
else
|
||||
{
|
||||
for (j = 1; j < numInputs; j++)
|
||||
{
|
||||
int64_t sz = i < THCTensor_(nDimension)(state, inputs[j])
|
||||
? THCTensor_(size)(state, inputs[j], i)
|
||||
: THMin(THCTensor_(nDimension)(state, inputs[j]), 1);
|
||||
|
||||
// If it's a dimension we're not catting on
|
||||
// Then fail if sizes are different AND > 0
|
||||
if (dimSize != sz && dimSize && sz) {
|
||||
THLongStorage_free(size);
|
||||
THError("inconsistent tensor sizes");
|
||||
}
|
||||
else if(!dimSize)
|
||||
{
|
||||
dimSize = sz;
|
||||
}
|
||||
}
|
||||
}
|
||||
size->data[i] = dimSize;
|
||||
THCTensor_(check_shape_except_dim)(state, notEmptyTensor, tensor, cat_dimension);
|
||||
cat_dim_size += THCTensor_(size)(state, tensor, cat_dimension);
|
||||
}
|
||||
|
||||
// Compute the size of the result
|
||||
for (int dim = 0; dim < nDims; dim++) {
|
||||
int64_t result_dim_size = THCTensor_(size)(state, notEmptyTensor, dim);
|
||||
if (dim == cat_dimension) {
|
||||
result_dim_size = cat_dim_size;
|
||||
}
|
||||
size->data[dim] = result_dim_size;
|
||||
}
|
||||
THCTensor_(resize)(state, result, size, NULL);
|
||||
THLongStorage_free(size);
|
||||
|
||||
|
|
@ -198,7 +204,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
|
|||
OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;
|
||||
|
||||
// Next, let's initialize the size, stride arrays for the output Tensor.
|
||||
for (i = 0; i < maxDim; ++i) {
|
||||
for (i = 0; i < nDims; ++i) {
|
||||
param.outputSize[i] = THCTensor_(size)(state, result, i);
|
||||
param.outputStride[i] = THCTensor_(stride)(state, result, i);
|
||||
}
|
||||
|
|
@ -250,7 +256,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
|
|||
getCatGrid(state, j, catGrid);
|
||||
|
||||
|
||||
switch (maxDim) {
|
||||
switch (nDims) {
|
||||
case 1:
|
||||
HANDLE_CASE(1);
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -734,6 +734,38 @@ class TestCuda(TestCase):
|
|||
z = torch.cat([x, y], 0)
|
||||
self.assertEqual(z.get_device(), x.get_device())
|
||||
|
||||
# Exercises torch.cat on CUDA tensors: concatenation along every dimension of
# 3-D inputs (negative and non-negative indices), round-tripping through
# torch.split / torch.chunk, and the result size of a default-dimension cat.
def test_cat(self):
|
||||
SIZE = 10
|
||||
for dim in range(-3, 3):
|
||||
# Map a negative dim onto its non-negative equivalent for a 3-D tensor.
pos_dim = dim if dim >= 0 else 3 + dim
|
||||
# After the transpose, x, y and z have sizes 13, 17 and 19 along pos_dim
# and SIZE along the other two dimensions.
x = torch.rand(13, SIZE, SIZE).transpose(0, pos_dim).cuda()
|
||||
y = torch.rand(17, SIZE, SIZE).transpose(0, pos_dim).cuda()
|
||||
z = torch.rand(19, SIZE, SIZE).transpose(0, pos_dim).cuda()
|
||||
|
||||
res1 = torch.cat((x, y, z), dim)
|
||||
# Each input must appear unchanged in its slice of the result
# (offsets 0, 13 and 13 + 17 = 30 along pos_dim).
self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0)
|
||||
self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0)
|
||||
self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0)
|
||||
|
||||
x = torch.randn(20, SIZE, SIZE).cuda()
|
||||
# cat with the default dim must reassemble the pieces produced by split and chunk.
self.assertEqual(torch.cat(torch.split(x, 7)), x)
|
||||
self.assertEqual(torch.cat(torch.chunk(x, 7)), x)
|
||||
|
||||
y = torch.randn(1, SIZE, SIZE).cuda()
|
||||
z = torch.cat([x, y])
|
||||
# 20 + 1 rows along the default cat dimension.
self.assertEqual(z.size(), (21, SIZE, SIZE))
|
||||
|
||||
# Checks that torch.cat on CUDA tensors raises RuntimeError for inputs whose
# shapes are incompatible.
def test_cat_bad_input_sizes(self):
|
||||
# Case 1: x is 2-D while y and z are 3-D.
x = torch.randn(2, 1).cuda()
|
||||
y = torch.randn(2, 1, 1).cuda()
|
||||
z = torch.randn(2, 1, 1).cuda()
|
||||
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z]))
|
||||
|
||||
# Case 2: all inputs are 3-D, but x has size 2 in the last dimension while
# y and z have size 1 there, and the cat dimension is 1.
x = torch.randn(2, 1, 2).cuda()
|
||||
y = torch.randn(2, 1, 1).cuda()
|
||||
z = torch.randn(2, 2, 1).cuda()
|
||||
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))
|
||||
|
||||
def test_serialization(self):
|
||||
x = torch.randn(4, 4).cuda()
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
|
|
|
|||
|
|
@ -1917,6 +1917,17 @@ class TestTorch(TestCase):
|
|||
|
||||
self.assertRaises(RuntimeError, lambda: torch.cat([]))
|
||||
|
||||
# Checks that torch.cat raises RuntimeError for inputs whose shapes are
# incompatible.
def test_cat_bad_input_sizes(self):
|
||||
# Case 1: x is 2-D while y and z are 3-D.
x = torch.randn(2, 1)
|
||||
y = torch.randn(2, 1, 1)
|
||||
z = torch.randn(2, 1, 1)
|
||||
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z]))
|
||||
|
||||
# Case 2: all inputs are 3-D, but x has size 2 in the last dimension while
# y and z have size 1 there, and the cat dimension is 1.
x = torch.randn(2, 1, 2)
|
||||
y = torch.randn(2, 1, 1)
|
||||
z = torch.randn(2, 2, 1)
|
||||
self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))
|
||||
|
||||
def test_stack(self):
|
||||
x = torch.rand(2, 3, 4)
|
||||
y = torch.rand(2, 3, 4)
|
||||
|
|
|
|||
|
|
@ -594,6 +594,8 @@ add_docstr(torch._C.cat,
|
|||
cat(seq, dim=0, out=None) -> Tensor
|
||||
|
||||
Concatenates the given sequence of :attr:`seq` tensors in the given dimension.
|
||||
All tensors must either have the same shape (except in the cat dimension) or be
|
||||
empty.
|
||||
|
||||
:func:`torch.cat` can be seen as an inverse operation for :func:`torch.split`
|
||||
and :func:`torch.chunk`
|
||||
|
|
@ -601,7 +603,9 @@ and :func:`torch.chunk`
|
|||
:func:`cat` can be best understood via examples.
|
||||
|
||||
Args:
|
||||
seq (sequence of tensors): any python sequence of tensors of the same type
|
||||
seq (sequence of tensors): any python sequence of tensors of the same type.
|
||||
Non-empty tensors provided must have the same shape, except in the
|
||||
cat dimension.
|
||||
dim (int, optional): the dimension over which the tensors are concatenated
|
||||
out (Tensor, optional): the output tensor
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user