mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
I noticed after creating https://github.com/pytorch/pytorch/issues/71553 that the test ownership lint was not working properly.
This fixes my egregious mistake and fixes the broken lints.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71554
Reviewed By: malfet
Differential Revision: D33690732
Pulled By: janeyx99
fbshipit-source-id: ba4dfbcd98038e4afd63e326832ae40935d2501e
(cherry picked from commit 1bbc3d343a)
65 lines
2.4 KiB
Python
65 lines
2.4 KiB
Python
# Owner(s): ["oncall: mobile"]
|
|
|
|
import torch
|
|
import torch.utils.bundled_inputs
|
|
import io
|
|
|
|
from torch.jit.mobile import _load_for_lite_interpreter
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
from pathlib import Path
|
|
from itertools import product
|
|
|
|
# Resolve this file's absolute location first, then step up to its
# second-level ancestor directory (``parents[1]``).  The test below builds the
# path to its ``cpp/jit/upgrader_models`` fixture relative to this directory.
_this_file = Path(__file__).resolve()
pytorch_test_dir = _this_file.parents[1]
|
|
|
|
class TestLiteScriptModule(TestCase):
    """Lite-interpreter checks for the serialized ``div`` model fixture.

    The single test loads ``test_versioned_div_tensor_v2.ptl`` through both
    the lite interpreter and ``torch.jit.load`` and round-trips the JIT module
    through an in-memory lite-interpreter buffer.
    """

    def _save_load_mobile_module(self, script_module: torch.jit.ScriptModule):
        """Round-trip ``script_module`` through the lite-interpreter format.

        The module is serialized in memory (with mobile debug info enabled)
        and immediately loaded back with ``_load_for_lite_interpreter``.

        Args:
            script_module: Scripted module to serialize.

        Returns:
            Whatever ``_load_for_lite_interpreter`` returns for the
            serialized bytes.
        """
        serialized = script_module._save_to_buffer_for_lite_interpreter(
            _save_mobile_debug_info=True
        )
        # A freshly constructed BytesIO is already positioned at offset 0, so
        # no explicit rewind is needed before handing it to the loader.
        return _load_for_lite_interpreter(io.BytesIO(serialized))

    def _try_fn(self, fn, *args, **kwargs):
        """Call ``fn`` and return its result, or the exception it raised.

        The exception is returned instead of propagated so that a caller can
        compare the outcome of two callables even when they fail.

        Args:
            fn: Callable to invoke.
            *args: Positional arguments forwarded to ``fn``.
            **kwargs: Keyword arguments forwarded to ``fn``.

        Returns:
            ``fn(*args, **kwargs)`` on success, otherwise the caught
            ``Exception`` instance.
        """
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # deliberately broad: the exception IS the result
            return e

    def test_versioned_div_tensor(self):
        """Load the v2 ``div`` fixture with both loaders and round-trip it.

        NOTE: the result comparisons (the ``_helper`` calls inside the loop)
        are commented out in this file, so at present the test only verifies
        that loading and re-serializing the model does not raise.
        """

        def div_tensor_0_3(self, other):
            # Reference used for the old operator: true division when either
            # operand is floating point, truncating division otherwise.
            if self.is_floating_point() or other.is_floating_point():
                return self.true_divide(other)
            return self.divide(other, rounding_mode='trunc')

        def _helper(m, fn, a, b):
            # Defined once, with the operands passed explicitly, instead of
            # being re-created per iteration and closing over loop variables.
            m_results = self._try_fn(m, a, b)
            fn_result = self._try_fn(fn, a, b)

            if isinstance(m_results, Exception):
                # Both sides are expected to fail together.
                self.assertIsInstance(fn_result, Exception)
                return

            # The module succeeded, so the reference must have succeeded too;
            # fail with a clear assertion rather than a TypeError from eq().
            self.assertNotIsInstance(fn_result, Exception)
            for result in m_results:
                # .all() keeps the check valid for multi-element comparisons,
                # where bool() on the comparison tensor would raise.
                self.assertTrue(result.eq(fn_result).all())

        model_path = pytorch_test_dir / "cpp" / "jit" / "upgrader_models" / "test_versioned_div_tensor_v2.ptl"
        mobile_module_v2 = _load_for_lite_interpreter(str(model_path))
        jit_module_v2 = torch.jit.load(str(model_path))
        current_mobile_module = self._save_load_mobile_module(jit_module_v2)

        # Every ordered pairing of float and int operands.
        vals = (2., 3., 2, 3)
        for val_a, val_b in product(vals, vals):
            a = torch.tensor((val_a,))
            b = torch.tensor((val_b,))

            # The comparisons below are disabled in this file; they are kept
            # (updated to the explicit-operand signature) for re-enabling.
            # old operator should produce the same result as applying upgrader of torch.div op
            # _helper(mobile_module_v2, div_tensor_0_3, a, b)
            # latest operator should produce the same result as applying torch.div op
            # _helper(current_mobile_module, torch.div, a, b)
|
|
|
|
# Script entry point: hand control to the test runner imported from
# torch.testing._internal.common_utils when this file is executed directly.
if __name__ == '__main__':
    run_tests()
|