mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add additional concat test.
PiperOrigin-RevId: 157844113
This commit is contained in:
parent
f661128dbf
commit
d5421cf58e
|
|
@ -442,6 +442,39 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
|
||||||
ComputeAndCompareR1<int32>(&builder, expected, {});
|
ComputeAndCompareR1<int32>(&builder, expected, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
|
||||||
|
ComputationBuilder builder(client_, TestName());
|
||||||
|
|
||||||
|
Array3D<float> arr0(9, 17, 1);
|
||||||
|
arr0.Fill(1);
|
||||||
|
|
||||||
|
Array3D<float> arr1(9, 17, 256);
|
||||||
|
arr1.Fill(2);
|
||||||
|
|
||||||
|
Array3D<float> expected(9, 17, arr0.n3() + arr1.n3());
|
||||||
|
for (int64 i = 0; i < expected.n1(); ++i) {
|
||||||
|
for (int64 j = 0; j < expected.n2(); ++j) {
|
||||||
|
int64 kk = 0;
|
||||||
|
for (const Array3D<float>& arr : {arr0, arr1}) {
|
||||||
|
for (int64 k = 0; k < arr.n3(); ++k, ++kk) {
|
||||||
|
expected(i, j, kk) = arr(i, j, k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputationDataHandle h0;
|
||||||
|
auto p0 = CreateR3Parameter<float>(arr0, /*parameter_number=*/0, "p0",
|
||||||
|
&builder, &h0);
|
||||||
|
ComputationDataHandle h1;
|
||||||
|
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
|
||||||
|
&builder, &h1);
|
||||||
|
|
||||||
|
auto concatenated = builder.ConcatInDim({h0, h1}, 2);
|
||||||
|
|
||||||
|
ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
|
||||||
|
}
|
||||||
|
|
||||||
// Describes a binary rank-2 concatenation test.
|
// Describes a binary rank-2 concatenation test.
|
||||||
struct R2BinarySpec {
|
struct R2BinarySpec {
|
||||||
int64 lhs_dim0;
|
int64 lhs_dim0;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user