Add proper shape checking to torch.cat (#4087)

* Fix catArray in THTensor

Asserts that the inputs have the same size except in the
cat dimension or are empty (or a mix of both).

* Fix catArray for THCTensor

* Document torch.cat shape checks

* Fix types
Richard Zou, 2017-12-18 02:05:58 -05:00 (committed by Soumith Chintala)
parent 2c71b679d2
commit 9394e65b44
5 changed files with 192 additions and 141 deletions
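
To make the new contract concrete, here is a usage sketch (not part of the diff; shapes are arbitrary, and in this era of the codebase an "empty" tensor means one with no dimensions, e.g. torch.Tensor()):

    import torch

    a = torch.randn(2, 3)
    b = torch.randn(4, 3)
    # Same shape except in the cat dimension: allowed.
    print(torch.cat([a, b], dim=0).shape)  # torch.Size([6, 3])

    # A differing number of dimensions used to be accepted silently
    # (trailing dims were treated as size 1); it now raises.
    d = torch.randn(2, 3, 1)
    try:
        torch.cat([a, d], dim=0)
    except RuntimeError as e:
        print(e)  # Tensors must have same number of dimensions: got 2 and 3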


@@ -2809,116 +2809,114 @@ void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension)
   THTensor_(catArray)(r_, inputs, 2, dimension);
 }
 
+void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension);
+inline void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension)
+{
+  int first_dims = first->nDimension;
+  int second_dims = second->nDimension;
+  THArgCheck(first_dims == second_dims, 0,
+      "Tensors must have same number of dimensions: got %d and %d",
+      first_dims, second_dims);
+  for (int dim = 0; dim < first_dims; dim++) {
+    if (dim == dimension) {
+      continue;
+    }
+    int64_t first_dim_size = first->size[dim];
+    int64_t second_dim_size = second->size[dim];
+    THArgCheck(first_dim_size == second_dim_size, 0,
+        "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d",
+        dimension, (long long)first_dim_size, (long long)second_dim_size, dim);
+  }
+}
+
 void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension)
 {
-  THLongStorage *size;
-  int i, j;
-  int64_t offset;
-  int maxDim = dimension + 1;
+  // Find a non-empty tensor to record nDims
   int allEmpty = 1;
-  int allContiguous = 1;
-
-  // cat_dimension is the actual dimension we cat along
-  int cat_dimension = dimension;
-
-  for (i = 0; i < numInputs; i++)
-  {
-    maxDim = THMax(maxDim, inputs[i]->nDimension);
+  int nDims = 0;
+  THTensor *notEmptyTensor;
+  for (int i = 0; i < numInputs; i++) {
+    int input_dims = inputs[i]->nDimension;
+    if (input_dims == 0) {
+      continue;
+    }
+    // We've found a non-empty tensor
+    allEmpty = 0;
+    notEmptyTensor = inputs[i];
+    nDims = input_dims;
+    break;
+  }
+  if (allEmpty) {
+    return;
   }
 
+  // Compute cat_dimension based on the non-empty tensor
+  THArgCheck(dimension >= -1 && dimension < nDims, 4, "invalid dimension %d", dimension);
+
   // When the user input dimension is -1 (i.e. -2 in C)
-  // Then we pick the maximum last dimension across all tensors.
-  if ( dimension + TH_INDEX_BASE == -1 )
-  {
-    cat_dimension = maxDim?(maxDim-1):0;
+  // Then we pick the last dimension across non-empty tensors.
+  int cat_dimension = dimension;
+  if (dimension + TH_INDEX_BASE == -1) {
+    cat_dimension = nDims ? nDims - 1 : 0;
   }
 
   THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
   THArgCheck(cat_dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
 
-  size = THLongStorage_newWithSize(maxDim);
-
-  for(i = 0; i < maxDim; i++)
-  {
-    // dimSize is either the size of the dim if it exists, either 1 if #dim > 0, otherwise 0
-    int64_t dimSize = i < inputs[0]->nDimension ? inputs[0]->size[i] : THMin(inputs[0]->nDimension, 1);
-    if (i == cat_dimension)
-    {
-      for (j = 1; j < numInputs; j++)
-      {
-        // accumulate the size over the dimension we want to cat on.
-        // Empty tensors are allowed
-        dimSize += i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1);
-      }
-    }
-    else
-    {
-      for (j = 1; j < numInputs; j++)
-      {
-        int64_t sz = (i < inputs[j]->nDimension ? inputs[j]->size[i] : THMin(inputs[j]->nDimension, 1));
-        // If it's a dimension we're not catting on
-        // Then fail if sizes are different AND > 0
-        if (dimSize != sz && dimSize && sz)
-        {
-          THLongStorage_free(size);
-          THError("inconsistent tensor sizes");
-        }
-        else if(!dimSize)
-        {
-          dimSize = sz;
-        }
-      }
-    }
-    allEmpty = allEmpty && !dimSize;
-    size->data[i] = dimSize;
+  // Compute size of the result in the cat dimension
+  int64_t cat_dim_size = 0;
+  for (int i = 0; i < numInputs; i++) {
+    THTensor *tensor = inputs[i];
+    if (tensor->nDimension == 0) {
+      continue;
+    }
+    THTensor_(check_shape_except_dim)(notEmptyTensor, tensor, cat_dimension);
+    cat_dim_size += tensor->size[cat_dimension];
   }
 
-  // Initiate catting and resizing
-  // If at least one of the input is not empty
-  if (!allEmpty)
-  {
-    THTensor_(resize)(result, size, NULL);
+  // Compute the size of the result
+  THLongStorage *size = THLongStorage_newWithSize(nDims);
+  for (int dim = 0; dim < nDims; dim++) {
+    int64_t result_dim_size = notEmptyTensor->size[dim];
+    if (dim == cat_dimension) {
+      result_dim_size = cat_dim_size;
+    }
+    size->data[dim] = result_dim_size;
+  }
+  THTensor_(resize)(result, size, NULL);
 
-    // Check contiguity of all inputs and result
-    for (i = 0; i < numInputs; i++) {
-      if(inputs[i]->nDimension) {
-        allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]);
-      }
+  // Check contiguity of all inputs and result
+  int allContiguous = 1;
+  for (int i = 0; i < numInputs; i++) {
+    if (inputs[i]->nDimension) {
+      allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]);
     }
-    allContiguous = allContiguous && THTensor_(isContiguous)(result);
+  }
+  allContiguous = allContiguous && THTensor_(isContiguous)(result);
 
-    // First path is for contiguous inputs along dim 1
-    // Second path for non-contiguous
-    if (cat_dimension == 0 && allContiguous)
-    {
-      real* result_data = result->storage->data + result->storageOffset;
-      offset = 0;
-      for (j = 0; j < numInputs; j++)
-      {
-        if (inputs[j]->nDimension)
-        {
-          THTensor* input0 = inputs[j];
-          real* input0_data = input0->storage->data + input0->storageOffset;
-          int64_t input0_size = THTensor_(nElement)(input0);
-          memcpy(result_data + offset, input0_data, input0_size*sizeof(real));
-          offset += input0_size;
-        }
-      }
-    }
-    else
-    {
-      offset = 0;
-      for (j = 0; j < numInputs; j++)
-      {
-        if (inputs[j]->nDimension)
-        {
-          int64_t dimSize = cat_dimension < inputs[j]->nDimension ? inputs[j]->size[cat_dimension] : 1;
-          THTensor *nt = THTensor_(newWithTensor)(result);
-          THTensor_(narrow)(nt, NULL, cat_dimension, offset, dimSize);
-          THTensor_(copy)(nt, inputs[j]);
-          THTensor_(free)(nt);
-          offset += dimSize;
-        }
-      }
-    }
-  }
+  // First path is for contiguous inputs along dim 0
+  // Second path for non-contiguous
+  int64_t offset;
+  if (cat_dimension == 0 && allContiguous) {
+    real* result_data = result->storage->data + result->storageOffset;
+    offset = 0;
+    for (int j = 0; j < numInputs; j++) {
+      if (inputs[j]->nDimension) {
+        THTensor* input0 = inputs[j];
+        real* input0_data = input0->storage->data + input0->storageOffset;
+        int64_t input0_size = THTensor_(nElement)(input0);
+        memcpy(result_data + offset, input0_data, input0_size*sizeof(real));
+        offset += input0_size;
+      }
+    }
+  } else {
+    offset = 0;
+    for (int j = 0; j < numInputs; j++) {
+      if (inputs[j]->nDimension) {
+        int64_t dimSize = cat_dimension < inputs[j]->nDimension ? inputs[j]->size[cat_dimension] : 1;
+        THTensor *nt = THTensor_(newWithTensor)(result);
+        THTensor_(narrow)(nt, NULL, cat_dimension, offset, dimSize);
+        THTensor_(copy)(nt, inputs[j]);
+        THTensor_(free)(nt);
+        offset += dimSize;
+      }
+    }
+  }
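
The shape logic above is easier to follow outside of the TH macros. A minimal Python mirror of the new algorithm (cat_result_shape is a hypothetical helper written for this note, not code from the commit):

    def cat_result_shape(shapes, cat_dim):
        """Mirror of the new catArray logic: shapes is a list of per-input
        size tuples, with () standing for an empty (0-dim) tensor."""
        non_empty = [s for s in shapes if len(s) > 0]
        if not non_empty:
            return None  # all inputs empty: catArray returns without resizing
        first = non_empty[0]
        cat_dim_size = 0
        for s in non_empty:
            # check_shape_except_dim: same ndim, same sizes off the cat dim
            if len(s) != len(first):
                raise RuntimeError(
                    "Tensors must have same number of dimensions: got %d and %d"
                    % (len(first), len(s)))
            for d, (a, b) in enumerate(zip(first, s)):
                if d != cat_dim and a != b:
                    raise RuntimeError(
                        "Sizes of tensors must match except in dimension %d. "
                        "Got %d and %d in dimension %d" % (cat_dim, a, b, d))
            cat_dim_size += s[cat_dim]
        result = list(first)
        result[cat_dim] = cat_dim_size  # sizes add up along the cat dimension
        return tuple(result)

    assert cat_result_shape([(2, 3), (4, 3)], 0) == (6, 3)
    assert cat_result_shape([(2, 3), (), (4, 3)], 0) == (6, 3)  # empties skipped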


@@ -90,6 +90,28 @@ void THCTensor_(cat)(THCState *state, THCTensor *result,
   THCTensor_(catArray)(state, result, inputs, 2, dimension);
 }
 
+void THCTensor_(check_shape_except_dim)(THCState *state,
+    THCTensor *first, THCTensor *second, int dimension);
+inline void THCTensor_(check_shape_except_dim)(THCState *state,
+    THCTensor *first, THCTensor *second, int dimension)
+{
+  int first_dims = THCTensor_(nDimension)(state, first);
+  int second_dims = THCTensor_(nDimension)(state, second);
+  THArgCheck(first_dims == second_dims, 0,
+      "Tensors must have same number of dimensions: got %d and %d",
+      first_dims, second_dims);
+  for (int dim = 0; dim < first_dims; dim++) {
+    if (dim == dimension) {
+      continue;
+    }
+    int64_t first_dim_size = THCTensor_(size)(state, first, dim);
+    int64_t second_dim_size = THCTensor_(size)(state, second, dim);
+    THArgCheck(first_dim_size == second_dim_size, 0,
+        "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d",
+        dimension, (long long)first_dim_size, (long long)second_dim_size, dim);
+  }
+}
+
 void THCTensor_(catArray)(THCState *state, THCTensor *result,
                           THCTensor **inputs, int numInputs, int dimension)
 {
@@ -97,11 +119,12 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
   int i, j, cohortMax;
   int64_t offset;
   bool hasEmptyInput = false;
+  THCTensor *notEmptyTensor = NULL;
 
   // Even in the case where dimension is negative (i.e. when we want
   // to cat along the last dimension), this logic still works, as the
   // loop below will overwrite the value
-  int maxDim = dimension + 1;
+  int nDims = dimension + 1;
 
   // cat_dimension is the actual dimension we cat along
   int cat_dimension = dimension;
@@ -110,61 +133,44 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
   {
     int inputDim = THCTensor_(nDimension)(state, inputs[i]);
     hasEmptyInput |= !inputDim;
-    maxDim = THMax(maxDim, inputDim);
+    if (inputDim > 0) {
+      nDims = inputDim;
+      notEmptyTensor = inputs[i];
+    }
   }
 
   // In the event that the user specified -1 as the concat dimension, then
-  // we want to pick the maxDim as dimension to cat along (and thus maxDim - 1 as the
-  // value due to 0-based indexing). If the maxDim is 0 (i.e. we are catting all
+  // we want to pick the nDims as dimension to cat along (and thus nDims - 1 as the
+  // value due to 0-based indexing). If the nDims is 0 (i.e. we are catting all
   // empty tensors), then we set cat_dimension to be 0
   if (dimension + TH_INDEX_BASE == -1) {
-    cat_dimension = maxDim ? (maxDim - 1) : 0;
+    cat_dimension = nDims ? (nDims - 1) : 0;
   }
 
   THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs);
   THArgCheck(cat_dimension >= 0, 4, "invalid dimension %d", dimension + TH_INDEX_BASE);
 
-  size = THLongStorage_newWithSize(maxDim);
-  for(i = 0; i < maxDim; i++)
-  {
-    // dimSize is either the size of the dim if it exists, either 1 if #dim > 0, otherwise 0
-    int64_t dimSize = i < THCTensor_(nDimension)(state, inputs[0])
-                          ? THCTensor_(size)(state, inputs[0], i)
-                          : THMin(THCTensor_(nDimension)(state, inputs[0]), 1);
-    if (i == cat_dimension)
-    {
-      for (j = 1; j < numInputs; j++)
-      {
-        // accumulate the size over the dimension we want to cat on.
-        // Empty tensors are allowed
-        dimSize += i < THCTensor_(nDimension)(state, inputs[j])
-                       ? THCTensor_(size)(state, inputs[j], i)
-                       : THMin(THCTensor_(nDimension)(state, inputs[j]), 1);
-      }
-    }
-    else
-    {
-      for (j = 1; j < numInputs; j++)
-      {
-        int64_t sz = i < THCTensor_(nDimension)(state, inputs[j])
-                         ? THCTensor_(size)(state, inputs[j], i)
-                         : THMin(THCTensor_(nDimension)(state, inputs[j]), 1);
-        // If it's a dimension we're not catting on
-        // Then fail if sizes are different AND > 0
-        if (dimSize != sz && dimSize && sz) {
-          THLongStorage_free(size);
-          THError("inconsistent tensor sizes");
-        }
-        else if(!dimSize)
-        {
-          dimSize = sz;
-        }
-      }
-    }
-    size->data[i] = dimSize;
+  size = THLongStorage_newWithSize(nDims);
+
+  // Compute size of the result in the cat dimension
+  int64_t cat_dim_size = 0;
+  for (int i = 0; i < numInputs; i++) {
+    THCTensor *tensor = inputs[i];
+    if (THCTensor_(nDimension)(state, tensor) == 0) {
+      continue;
+    }
+    THCTensor_(check_shape_except_dim)(state, notEmptyTensor, tensor, cat_dimension);
+    cat_dim_size += THCTensor_(size)(state, tensor, cat_dimension);
+  }
+
+  // Compute the size of the result
+  for (int dim = 0; dim < nDims; dim++) {
+    int64_t result_dim_size = THCTensor_(size)(state, notEmptyTensor, dim);
+    if (dim == cat_dimension) {
+      result_dim_size = cat_dim_size;
+    }
+    size->data[dim] = result_dim_size;
   }
 
   THCTensor_(resize)(state, result, size, NULL);
   THLongStorage_free(size);
@@ -198,7 +204,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
   OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;
 
   // Next, let's initialize the size, stride arrays for the output Tensor.
-  for (i = 0; i < maxDim; ++i) {
+  for (i = 0; i < nDims; ++i) {
     param.outputSize[i] = THCTensor_(size)(state, result, i);
     param.outputStride[i] = THCTensor_(stride)(state, result, i);
   }
@@ -250,7 +256,7 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
       getCatGrid(state, j, catGrid);
 
-      switch (maxDim) {
+      switch (nDims) {
         case 1:
           HANDLE_CASE(1);
           break;
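
The nDims bookkeeping is also what keeps negative-dimension handling working on the CUDA path; for example (a sketch, assuming a CUDA build):

    import torch

    x = torch.randn(2, 3).cuda()
    y = torch.randn(2, 5).cuda()
    # dim=-1 resolves to the last dimension of the non-empty inputs (here, 1);
    # all other dimensions must still match.
    z = torch.cat([x, y], dim=-1)
    print(z.shape)  # torch.Size([2, 8])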


@@ -734,6 +734,38 @@ class TestCuda(TestCase):
         z = torch.cat([x, y], 0)
         self.assertEqual(z.get_device(), x.get_device())
 
+    def test_cat(self):
+        SIZE = 10
+        for dim in range(-3, 3):
+            pos_dim = dim if dim >= 0 else 3 + dim
+            x = torch.rand(13, SIZE, SIZE).transpose(0, pos_dim).cuda()
+            y = torch.rand(17, SIZE, SIZE).transpose(0, pos_dim).cuda()
+            z = torch.rand(19, SIZE, SIZE).transpose(0, pos_dim).cuda()
+
+            res1 = torch.cat((x, y, z), dim)
+            self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0)
+            self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0)
+            self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0)
+
+        x = torch.randn(20, SIZE, SIZE).cuda()
+        self.assertEqual(torch.cat(torch.split(x, 7)), x)
+        self.assertEqual(torch.cat(torch.chunk(x, 7)), x)
+
+        y = torch.randn(1, SIZE, SIZE).cuda()
+        z = torch.cat([x, y])
+        self.assertEqual(z.size(), (21, SIZE, SIZE))
+
+    def test_cat_bad_input_sizes(self):
+        x = torch.randn(2, 1).cuda()
+        y = torch.randn(2, 1, 1).cuda()
+        z = torch.randn(2, 1, 1).cuda()
+        self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z]))
+
+        x = torch.randn(2, 1, 2).cuda()
+        y = torch.randn(2, 1, 1).cuda()
+        z = torch.randn(2, 2, 1).cuda()
+        self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))
+
     def test_serialization(self):
         x = torch.randn(4, 4).cuda()
         with tempfile.NamedTemporaryFile() as f:


@@ -1917,6 +1917,17 @@ class TestTorch(TestCase):
         self.assertRaises(RuntimeError, lambda: torch.cat([]))
 
+    def test_cat_bad_input_sizes(self):
+        x = torch.randn(2, 1)
+        y = torch.randn(2, 1, 1)
+        z = torch.randn(2, 1, 1)
+        self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z]))
+
+        x = torch.randn(2, 1, 2)
+        y = torch.randn(2, 1, 1)
+        z = torch.randn(2, 2, 1)
+        self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1))
+
     def test_stack(self):
         x = torch.rand(2, 3, 4)
         y = torch.rand(2, 3, 4)


@@ -594,6 +594,8 @@ add_docstr(torch._C.cat,
 cat(seq, dim=0, out=None) -> Tensor
 
 Concatenates the given sequence of :attr:`seq` tensors in the given dimension.
+All tensors must either have the same shape (except in the cat dimension) or be
+empty.
 
 :func:`torch.cat` can be seen as an inverse operation for :func:`torch.split`
 and :func:`torch.chunk`
@@ -601,7 +603,9 @@ and :func:`torch.chunk`
 :func:`cat` can be best understood via examples.
 
 Args:
-    seq (sequence of tensors): any python sequence of tensors of the same type
+    seq (sequence of tensors): any python sequence of tensors of the same type.
+        Non-empty tensors provided must have the same shape, except in the
+        cat dimension.
     dim (int, optional): the dimension over which the tensors are concatenated
     out (Tensor, optional): the output tensor
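
A usage sketch consistent with the documented contract (illustrative only, not part of the diff):

    >>> import torch
    >>> x = torch.randn(2, 3)
    >>> torch.cat((x, x, x), 0).size()
    torch.Size([6, 3])
    >>> torch.cat((x, x, x), 1).size()
    torch.Size([2, 9])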