mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add some tests for 3d channels last (#118283)
Part of a multi-PR work to fix #59168. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118283 Approved by: https://github.com/albanD
This commit is contained in:
parent
bacbad5bc9
commit
aef820926c
|
|
@ -125,6 +125,15 @@ class Test(torch.jit.ScriptModule):
|
||||||
r = r.contiguous()
|
r = r.contiguous()
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||||
|
r = torch.nn.functional.conv3d(x, w)
|
||||||
|
if toChannelsLast:
|
||||||
|
r = r.contiguous(memory_format=torch.channels_last_3d)
|
||||||
|
else:
|
||||||
|
r = r.contiguous()
|
||||||
|
return r
|
||||||
|
|
||||||
@torch.jit.script_method
|
@torch.jit.script_method
|
||||||
def contiguous(self, x: Tensor) -> Tensor:
|
def contiguous(self, x: Tensor) -> Tensor:
|
||||||
return x.contiguous()
|
return x.contiguous()
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -348,15 +348,32 @@ public abstract class PytorchTestBase {
|
||||||
@Test
|
@Test
|
||||||
public void testChannelsLastConv2d() throws IOException {
|
public void testChannelsLastConv2d() throws IOException {
|
||||||
long[] inputShape = new long[] {1, 3, 2, 2};
|
long[] inputShape = new long[] {1, 3, 2, 2};
|
||||||
long[] dataNCHW = new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104};
|
long[] dataNCHW = new long[] {
|
||||||
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
|
111, 112,
|
||||||
long[] dataNHWC = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
|
121, 122,
|
||||||
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
|
|
||||||
|
|
||||||
|
211, 212,
|
||||||
|
221, 222,
|
||||||
|
|
||||||
|
311, 312,
|
||||||
|
321, 322};
|
||||||
|
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||||
|
long[] dataNHWC = new long[] {
|
||||||
|
111, 211, 311, 112, 212, 312,
|
||||||
|
|
||||||
|
121, 221, 321, 122, 222, 322};
|
||||||
|
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
|
||||||
long[] weightShape = new long[] {3, 3, 1, 1};
|
long[] weightShape = new long[] {3, 3, 1, 1};
|
||||||
long[] dataWeightOIHW = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
long[] dataWeightOIHW = new long[] {
|
||||||
|
2, 0, 0,
|
||||||
|
0, 1, 0,
|
||||||
|
0, 0, -1};
|
||||||
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
|
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||||
long[] dataWeightOHWI = new long[] {2, 0, 0, 0, 1, 0, 0, 0, -1};
|
long[] dataWeightOHWI = new long[] {
|
||||||
|
2, 0, 0,
|
||||||
|
0, 1, 0,
|
||||||
|
0, 0, -1};
|
||||||
|
|
||||||
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
|
||||||
|
|
||||||
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
|
|
@ -367,7 +384,15 @@ public abstract class PytorchTestBase {
|
||||||
outputNCHW,
|
outputNCHW,
|
||||||
MemoryFormat.CONTIGUOUS,
|
MemoryFormat.CONTIGUOUS,
|
||||||
new long[] {1, 3, 2, 2},
|
new long[] {1, 3, 2, 2},
|
||||||
new long[] {2, 4, 6, 8, 11, 12, 13, 14, -101, -102, -103, -104});
|
new long[] {
|
||||||
|
2*111, 2*112,
|
||||||
|
2*121, 2*122,
|
||||||
|
|
||||||
|
211, 212,
|
||||||
|
221, 222,
|
||||||
|
|
||||||
|
-311, -312,
|
||||||
|
-321, -322});
|
||||||
|
|
||||||
final IValue outputNHWC =
|
final IValue outputNHWC =
|
||||||
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
|
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
|
||||||
|
|
@ -375,7 +400,89 @@ public abstract class PytorchTestBase {
|
||||||
outputNHWC,
|
outputNHWC,
|
||||||
MemoryFormat.CHANNELS_LAST,
|
MemoryFormat.CHANNELS_LAST,
|
||||||
new long[] {1, 3, 2, 2},
|
new long[] {1, 3, 2, 2},
|
||||||
new long[] {2, 11, -101, 4, 12, -102, 6, 13, -103, 8, 14, -104});
|
new long[] {
|
||||||
|
2*111, 211, -311, 2*112, 212, -312,
|
||||||
|
2*121, 221, -321, 2*122, 222, -322});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testChannelsLastConv3d() throws IOException {
|
||||||
|
long[] inputShape = new long[] {1, 3, 2, 2, 2};
|
||||||
|
long[] dataNCDHW = new long[] {
|
||||||
|
1111, 1112,
|
||||||
|
1121, 1122,
|
||||||
|
1211, 1212,
|
||||||
|
1221, 1222,
|
||||||
|
|
||||||
|
2111, 2112,
|
||||||
|
2121, 2122,
|
||||||
|
2211, 2212,
|
||||||
|
2221, 2222,
|
||||||
|
|
||||||
|
3111, 3112,
|
||||||
|
3121, 3122,
|
||||||
|
3211, 3212,
|
||||||
|
3221, 3222};
|
||||||
|
Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
|
||||||
|
long[] dataNDHWC = new long[] {
|
||||||
|
1111, 2111, 3111,
|
||||||
|
1112, 2112, 3112,
|
||||||
|
|
||||||
|
1121, 2121, 3121,
|
||||||
|
1122, 2122, 3122,
|
||||||
|
|
||||||
|
1211, 2211, 3211,
|
||||||
|
1212, 2212, 3212,
|
||||||
|
|
||||||
|
1221, 2221, 3221,
|
||||||
|
1222, 2222, 3222};
|
||||||
|
|
||||||
|
Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);
|
||||||
|
|
||||||
|
long[] weightShape = new long[] {3, 3, 1, 1, 1};
|
||||||
|
long[] dataWeightOIDHW = new long[] {
|
||||||
|
2, 0, 0,
|
||||||
|
0, 1, 0,
|
||||||
|
0, 0, -1,
|
||||||
|
};
|
||||||
|
Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
|
||||||
|
long[] dataWeightODHWI = new long[] {
|
||||||
|
2, 0, 0,
|
||||||
|
0, 1, 0,
|
||||||
|
0, 0, -1,
|
||||||
|
};
|
||||||
|
Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);
|
||||||
|
|
||||||
|
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
|
||||||
|
|
||||||
|
final IValue outputNCDHW =
|
||||||
|
module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
|
||||||
|
assertIValueTensor(
|
||||||
|
outputNCDHW,
|
||||||
|
MemoryFormat.CONTIGUOUS,
|
||||||
|
new long[] {1, 3, 2, 2, 2},
|
||||||
|
new long[] {
|
||||||
|
2*1111, 2*1112, 2*1121, 2*1122,
|
||||||
|
2*1211, 2*1212, 2*1221, 2*1222,
|
||||||
|
|
||||||
|
2111, 2112, 2121, 2122,
|
||||||
|
2211, 2212, 2221, 2222,
|
||||||
|
|
||||||
|
-3111, -3112, -3121, -3122,
|
||||||
|
-3211, -3212, -3221, -3222});
|
||||||
|
|
||||||
|
final IValue outputNDHWC =
|
||||||
|
module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
|
||||||
|
assertIValueTensor(
|
||||||
|
outputNDHWC,
|
||||||
|
MemoryFormat.CHANNELS_LAST_3D,
|
||||||
|
new long[] {1, 3, 2, 2, 2},
|
||||||
|
new long[] {
|
||||||
|
2*1111, 2111, -3111, 2*1112, 2112, -3112,
|
||||||
|
2*1121, 2121, -3121, 2*1122, 2122, -3122,
|
||||||
|
|
||||||
|
2*1211, 2211, -3211, 2*1212, 2212, -3212,
|
||||||
|
2*1221, 2221, -3221, 2*1222, 2222, -3222});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,15 @@ def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||||
r = r.contiguous()
|
r = r.contiguous()
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||||
|
r = torch.conv3d(x, w)
|
||||||
|
if (toChannelsLast):
|
||||||
|
# memory_format=torch.channels_last_3d
|
||||||
|
r = r.contiguous(memory_format=2)
|
||||||
|
else:
|
||||||
|
r = r.contiguous()
|
||||||
|
return r
|
||||||
|
|
||||||
def contiguous(self, x: Tensor) -> Tensor:
|
def contiguous(self, x: Tensor) -> Tensor:
|
||||||
return x.contiguous()
|
return x.contiguous()
|
||||||
|
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -112,6 +112,15 @@ class AndroidAPIModule(torch.jit.ScriptModule):
|
||||||
r = r.contiguous()
|
r = r.contiguous()
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
@torch.jit.script_method
|
||||||
|
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||||
|
r = torch.nn.functional.conv3d(x, w)
|
||||||
|
if toChannelsLast:
|
||||||
|
r = r.contiguous(memory_format=torch.channels_last_3d)
|
||||||
|
else:
|
||||||
|
r = r.contiguous()
|
||||||
|
return r
|
||||||
|
|
||||||
@torch.jit.script_method
|
@torch.jit.script_method
|
||||||
def contiguous(self, x: Tensor) -> Tensor:
|
def contiguous(self, x: Tensor) -> Tensor:
|
||||||
return x.contiguous()
|
return x.contiguous()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user