pytorch/android/pytorch_android/test_asset.jit
David Reiss 1b4d3d5748 Properly return data from non-contiguous tensors in Java
Summary:
These were returning incorrect data before.  Now we make a contiguous copy
before converting to Java.  Exposing raw data to the user might be faster in
some cases, but it's not clear that it's worth the complexity and code size.

Test Plan: New unit test.

Reviewed By: IvanKobzarev

Differential Revision: D19221361

fbshipit-source-id: 22ecdad252c8fd968f833a2be5897c5ae483700c
2020-01-07 16:33:31 -08:00

93 lines
2.0 KiB
Plaintext

# Placeholder entry point: ignores `input` and always returns None.
# No type comment is given here, so the argument type is whatever the
# compiler defaults to -- NOTE(review): presumably Tensor; confirm.
def forward(self, input):
    return None
# Identity on a bool.
#   input:   the bool to echo.
#   returns: `input`, unchanged.
# NOTE(review): presumably used by the Android tests to check that a bool
# survives the call boundary intact -- confirm against the test suite.
def eqBool(self, input):
    # type: (bool) -> bool
    return input
# Identity on an int.
#   input:   the int to echo.
#   returns: `input`, unchanged.
def eqInt(self, input):
    # type: (int) -> int
    return input
# Identity on a float.
#   input:   the float to echo.
#   returns: `input`, unchanged.
def eqFloat(self, input):
    # type: (float) -> float
    return input
# Identity on a string.
#   input:   the string to echo.
#   returns: `input`, unchanged.
def eqStr(self, input):
    # type: (str) -> str
    return input
# Identity on a tensor.
#   input:   the tensor to echo.
#   returns: `input` itself (no copy is made in this function).
def eqTensor(self, input):
    # type: (Tensor) -> Tensor
    return input
# Identity on a dict with string keys and int values.
#   input:   the Dict[str, int] to echo.
#   returns: `input`, unchanged.
def eqDictStrKeyIntValue(self, input):
    # type: (Dict[str, int]) -> Dict[str, int]
    return input
# Identity on a dict with int keys and int values.
#   input:   the Dict[int, int] to echo.
#   returns: `input`, unchanged.
def eqDictIntKeyIntValue(self, input):
    # type: (Dict[int, int]) -> Dict[int, int]
    return input
# Identity on a dict with float keys and int values.
#   input:   the Dict[float, int] to echo.
#   returns: `input`, unchanged.
def eqDictFloatKeyIntValue(self, input):
    # type: (Dict[float, int]) -> Dict[float, int]
    return input
# Sums a list of ints and returns it alongside the original list.
#   input:   the List[int] to add up; an empty list yields a sum of 0.
#   returns: a 2-tuple (input, total) -- the list as passed in, then the sum
#            of its elements.
def listIntSumReturnTuple(self, input):
    # type: (List[int]) -> Tuple[List[int], int]
    # Named `total` rather than `sum` so the builtin is not shadowed.
    total = 0
    for x in input:
        total += x
    return (input, total)
# Logical AND over a list of bools.
#   input:   the List[bool] to fold.
#   returns: True only if every element is True; an empty list yields True
#            because the accumulator starts at True.
def listBoolConjunction(self, input):
    # type: (List[bool]) -> bool
    res = True
    for x in input:
        res = res and x
    return res
# Logical OR over a list of bools.
#   input:   the List[bool] to fold.
#   returns: True if at least one element is True; an empty list yields
#            False because the accumulator starts at False.
def listBoolDisjunction(self, input):
    # type: (List[bool]) -> bool
    res = False
    for x in input:
        res = res or x
    return res
# Sums a 3-tuple of ints and returns it alongside the original tuple.
#   input:   the Tuple[int, int, int] to add up.
#   returns: a 2-tuple (input, total) -- the tuple as passed in, then the sum
#            of its three elements.
def tupleIntSumReturnTuple(self, input):
    # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]
    # Named `total` rather than `sum` so the builtin is not shadowed.
    total = 0
    for x in input:
        total += x
    return (input, total)
# Reports whether an optional int is absent.
#   input:   an Optional[int].
#   returns: True when `input` is None, False for any int (including 0).
def optionalIntIsNone(self, input):
    # type: (Optional[int]) -> bool
    return input is None
# Maps 0 to None and passes every other int through.
#   input:   the int to inspect.
#   returns: None when `input` equals 0, otherwise `input` unchanged.
def intEq0None(self, input):
    # type: (int) -> Optional[int]
    if input == 0:
        return None
    return input
# Repeats a string three times by concatenation.
#   input:   the string to repeat.
#   returns: `input` followed by itself twice more; "" stays "".
def str3Concat(self, input):
    # type: (str) -> str
    return input + input + input
# Builds a tensor with an empty shape (zero dimensions) from the scalar
# held by `input`.
#   input:   a tensor on which `.item()` is valid; its value is truncated to
#            an int by `int(...)`.
#   returns: element 0 of a fresh one-element tensor, i.e. a 0-dim tensor
#            holding that int.
def newEmptyShapeWithItem(self, input):
    value = int(input.item())
    # Indexing the 1-element tensor drops its only dimension.
    return torch.tensor([value])[0]
def testAliasWithOffset(self):
# type: () -> List[Tensor]
x = torch.tensor([100, 200])
a = [x[0], x[1]]
return a
def testNonContiguous(self):
x = torch.tensor([100, 200, 300])[::2]
assert not x.is_contiguous()
assert x[0] == 100
assert x[1] == 300
return x