mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Action following https://github.com/pytorch/pytorch/issues/66232

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66808

Reviewed By: mrshenli

Differential Revision: D31761414

Pulled By: janeyx99

fbshipit-source-id: baf8c49ff9c4bcda7b0ea0f6aafd26380586e72d
21 lines
711 B
Python
21 lines
711 B
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import torch
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
if __name__ == "__main__":
    # This module only defines test cases; per the message below they are
    # meant to be run through test/test_jit.py, so direct execution is an error.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
|
|
|
|
class TestPythonIr(JitTestCase):
|
|
def test_param_strides(self):
|
|
def trace_me(arg):
|
|
return arg
|
|
t = torch.zeros(1, 3, 16, 16)
|
|
traced = torch.jit.trace(trace_me, t)
|
|
value = list(traced.graph.param_node().outputs())[0]
|
|
real_strides = list(t.stride())
|
|
type_strides = value.type().strides()
|
|
self.assertEqual(real_strides, type_strides)
|