From 09eefec627d7b54a3c4906dd9f16f8927e8fdbba Mon Sep 17 00:00:00 2001
From: Richard Barnes
Date: Thu, 7 Jan 2021 15:36:48 -0800
Subject: [PATCH] Clean up some type annotations in android (#49944)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49944

Upgrades type annotations from Python2 to Python3

Test Plan: Sandcastle tests

Reviewed By: xush6528

Differential Revision: D25717539

fbshipit-source-id: c621e2712e87eaed08cda48eb0fb224f6b0570c9
---
 .../generate_test_torchscripts.py      | 60 +++++++------------
 android/pytorch_android/test_asset.jit | 60 +++++++------------
 2 files changed, 40 insertions(+), 80 deletions(-)

diff --git a/android/pytorch_android/generate_test_torchscripts.py b/android/pytorch_android/generate_test_torchscripts.py
index 6384d588e9a..8b41fefc246 100644
--- a/android/pytorch_android/generate_test_torchscripts.py
+++ b/android/pytorch_android/generate_test_torchscripts.py
@@ -20,92 +20,77 @@ class Test(torch.jit.ScriptModule):
         return None
 
     @torch.jit.script_method
-    def eqBool(self, input):
-        # type: (bool) -> bool
+    def eqBool(self, input: bool) -> bool:
         return input
 
     @torch.jit.script_method
-    def eqInt(self, input):
-        # type: (int) -> int
+    def eqInt(self, input: int) -> int:
         return input
 
     @torch.jit.script_method
-    def eqFloat(self, input):
-        # type: (float) -> float
+    def eqFloat(self, input: float) -> float:
         return input
 
     @torch.jit.script_method
-    def eqStr(self, input):
-        # type: (str) -> str
+    def eqStr(self, input: str) -> str:
         return input
 
     @torch.jit.script_method
-    def eqTensor(self, input):
-        # type: (Tensor) -> Tensor
+    def eqTensor(self, input: Tensor) -> Tensor:
         return input
 
     @torch.jit.script_method
-    def eqDictStrKeyIntValue(self, input):
-        # type: (Dict[str, int]) -> Dict[str, int]
+    def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
         return input
 
     @torch.jit.script_method
-    def eqDictIntKeyIntValue(self, input):
-        # type: (Dict[int, int]) -> Dict[int, int]
+    def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
         return input
 
     @torch.jit.script_method
-    def eqDictFloatKeyIntValue(self, input):
-        # type: (Dict[float, int]) -> Dict[float, int]
+    def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
         return input
 
     @torch.jit.script_method
-    def listIntSumReturnTuple(self, input):
-        # type: (List[int]) -> Tuple[List[int], int]
+    def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
         sum = 0
         for x in input:
             sum += x
         return (input, sum)
 
     @torch.jit.script_method
-    def listBoolConjunction(self, input):
-        # type: (List[bool]) -> bool
+    def listBoolConjunction(self, input: List[bool]) -> bool:
         res = True
         for x in input:
             res = res and x
         return res
 
     @torch.jit.script_method
-    def listBoolDisjunction(self, input):
-        # type: (List[bool]) -> bool
+    def listBoolDisjunction(self, input: List[bool]) -> bool:
         res = False
         for x in input:
             res = res or x
         return res
 
     @torch.jit.script_method
-    def tupleIntSumReturnTuple(self, input):
-        # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]
+    def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
         sum = 0
         for x in input:
             sum += x
         return (input, sum)
 
     @torch.jit.script_method
-    def optionalIntIsNone(self, input):
-        # type: (Optional[int]) -> bool
+    def optionalIntIsNone(self, input: Optional[int]) -> bool:
         return input is None
 
     @torch.jit.script_method
-    def intEq0None(self, input):
-        # type: (int) -> Optional[int]
+    def intEq0None(self, input: int) -> Optional[int]:
         if input == 0:
             return None
         return input
 
     @torch.jit.script_method
-    def str3Concat(self, input):
-        # type: (str) -> str
+    def str3Concat(self, input: str) -> str:
         return input + input + input
 
     @torch.jit.script_method
@@ -113,8 +98,7 @@ class Test(torch.jit.ScriptModule):
         return torch.tensor([int(input.item())])[0]
 
     @torch.jit.script_method
-    def testAliasWithOffset(self):
-        # type: () -> List[Tensor]
+    def testAliasWithOffset(self) -> List[Tensor]:
         x = torch.tensor([100, 200])
         a = [x[0], x[1]]
         return a
@@ -128,8 +112,7 @@ class Test(torch.jit.ScriptModule):
         return x
 
     @torch.jit.script_method
-    def conv2d(self, x, w, toChannelsLast):
-        # type: (Tensor, Tensor, bool) -> Tensor
+    def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
         r = torch.nn.functional.conv2d(x, w)
         if (toChannelsLast):
             r = r.contiguous(memory_format=torch.channels_last)
@@ -138,18 +121,15 @@ class Test(torch.jit.ScriptModule):
         return r
 
     @torch.jit.script_method
-    def contiguous(self, x):
-        # type: (Tensor) -> Tensor
+    def contiguous(self, x: Tensor) -> Tensor:
         return x.contiguous()
 
     @torch.jit.script_method
-    def contiguousChannelsLast(self, x):
-        # type: (Tensor) -> Tensor
+    def contiguousChannelsLast(self, x: Tensor) -> Tensor:
         return x.contiguous(memory_format=torch.channels_last)
 
     @torch.jit.script_method
-    def contiguousChannelsLast3d(self, x):
-        # type: (Tensor) -> Tensor
+    def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
         return x.contiguous(memory_format=torch.channels_last_3d)
 
 scriptAndSave(Test(), "test.pt")
diff --git a/android/pytorch_android/test_asset.jit b/android/pytorch_android/test_asset.jit
index 49a41eff36a..3bd9037da4e 100644
--- a/android/pytorch_android/test_asset.jit
+++ b/android/pytorch_android/test_asset.jit
@@ -1,85 +1,69 @@
 def forward(self, input):
     return None
 
-def eqBool(self, input):
-    # type: (bool) -> bool
+def eqBool(self, input: bool) -> bool:
     return input
 
-def eqInt(self, input):
-    # type: (int) -> int
+def eqInt(self, input: int) -> int:
     return input
 
-def eqFloat(self, input):
-    # type: (float) -> float
+def eqFloat(self, input: float) -> float:
     return input
 
-def eqStr(self, input):
-    # type: (str) -> str
+def eqStr(self, input: str) -> str:
     return input
 
-def eqTensor(self, input):
-    # type: (Tensor) -> Tensor
+def eqTensor(self, input: Tensor) -> Tensor:
     return input
 
-def eqDictStrKeyIntValue(self, input):
-    # type: (Dict[str, int]) -> Dict[str, int]
+def eqDictStrKeyIntValue(self, input: Dict[str, int]) -> Dict[str, int]:
     return input
 
-def eqDictIntKeyIntValue(self, input):
-    # type: (Dict[int, int]) -> Dict[int, int]
+def eqDictIntKeyIntValue(self, input: Dict[int, int]) -> Dict[int, int]:
     return input
 
-def eqDictFloatKeyIntValue(self, input):
-    # type: (Dict[float, int]) -> Dict[float, int]
+def eqDictFloatKeyIntValue(self, input: Dict[float, int]) -> Dict[float, int]:
     return input
 
-def listIntSumReturnTuple(self, input):
-    # type: (List[int]) -> Tuple[List[int], int]
+def listIntSumReturnTuple(self, input: List[int]) -> Tuple[List[int], int]:
     sum = 0
     for x in input:
         sum += x
     return (input, sum)
 
-def listBoolConjunction(self, input):
-    # type: (List[bool]) -> bool
+def listBoolConjunction(self, input: List[bool]) -> bool:
     res = True
     for x in input:
         res = res and x
     return res
 
-def listBoolDisjunction(self, input):
-    # type: (List[bool]) -> bool
+def listBoolDisjunction(self, input: List[bool]) -> bool:
     res = False
     for x in input:
         res = res or x
     return res
 
-def tupleIntSumReturnTuple(self, input):
-    # type: (Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]
+def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
     sum = 0
     for x in input:
         sum += x
     return (input, sum)
 
-def optionalIntIsNone(self, input):
-    # type: (Optional[int]) -> bool
+def optionalIntIsNone(self, input: Optional[int]) -> bool:
     return input is None
 
-def intEq0None(self, input):
-    # type: (int) -> Optional[int]
+def intEq0None(self, input: int) -> Optional[int]:
     if input == 0:
         return None
     return input
 
-def str3Concat(self, input):
-    # type: (str) -> str
+def str3Concat(self, input: str) -> str:
     return input + input + input
 
 def newEmptyShapeWithItem(self, input):
     return torch.tensor([int(input.item())])[0]
 
-def testAliasWithOffset(self):
-    # type: () -> List[Tensor]
+def testAliasWithOffset(self) -> List[Tensor]:
     x = torch.tensor([100, 200])
     a = [x[0], x[1]]
     return a
@@ -91,8 +75,7 @@ def testNonContiguous(self):
     assert x[1] == 300
     return x
 
-def conv2d(self, x, w, toChannelsLast):
-    # type: (Tensor, Tensor, bool) -> Tensor
+def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
     r = torch.conv2d(x, w)
     if (toChannelsLast):
         # memory_format=torch.channels_last
@@ -101,16 +84,13 @@ def conv2d(self, x, w, toChannelsLast):
         r = r.contiguous()
     return r
 
-def contiguous(self, x):
-    # type: (Tensor) -> Tensor
+def contiguous(self, x: Tensor) -> Tensor:
     return x.contiguous()
 
-def contiguousChannelsLast(self, x):
-    # type: (Tensor) -> Tensor
+def contiguousChannelsLast(self, x: Tensor) -> Tensor:
     # memory_format=torch.channels_last
     return x.contiguous(memory_format=2)
 
-def contiguousChannelsLast3d(self, x):
-    # type: (Tensor) -> Tensor
+def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
     # memory_format=torch.channels_last_3d
     return x.contiguous(memory_format=3)
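
Not part of the patch: a minimal standalone sketch of the annotation style the patch adopts, assuming a PyTorch build where torch.jit.ScriptModule and torch.jit.script_method are available. TorchScript reads Python 3 signature annotations directly, so the old "# type: (...) -> ..." comments become unnecessary. The class and method names below are illustrative, not from the diff.

from typing import List, Tuple

import torch


class AnnotatedExample(torch.jit.ScriptModule):
    @torch.jit.script_method
    def listIntSum(self, input: List[int]) -> Tuple[List[int], int]:
        # Python 3 annotations on the signature replace the old
        # "# type: (List[int]) -> Tuple[List[int], int]" comment.
        total = 0
        for x in input:
            total += x
        return (input, total)


# Example usage: compiles at class definition time and runs like a normal module.
m = AnnotatedExample()
print(m.listIntSum([1, 2, 3]))  # ([1, 2, 3], 6)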