# Owner(s): ["oncall: jit"] from typing import Sequence import torch import functools from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests import torch._lazy import torch._lazy.config import torch._lazy.metrics import torch._lazy.ir_cache import torch._lazy.ts_backend import itertools import yaml import os import pathlib torch._lazy.ts_backend.init() def get_test_device(): return 'cuda' if 'LTC_TS_CUDA' in os.environ else 'cpu' def remove_suffixes(l): return [x.split(".")[0] for x in l] def init_lists(): path_to_script = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) TS_NATIVE_FUNCTIONS_PATH = path_to_script.parent.parent / "aten/src/ATen/native/ts_native_functions.yaml" with open(TS_NATIVE_FUNCTIONS_PATH) as f: yaml_ts = yaml.load(f, yaml.Loader) LAZY_OPS_LIST = set(remove_suffixes(itertools.chain(yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"]))) FALLBACK_LIST = set(["clamp"]) SKIP_RUNTIME_ERROR_LIST = set([ 'index_select', # Empty output_sizes is not supported 'clone', # is clone decomposed? # General ASAN Failure due to related to generating bool values. # https://github.com/pytorch/pytorch/issues/74519 # https://github.com/pytorch/pytorch/issues/63034 'nonzero', # ASAN failure (paste: P501906539) 'all', # ASAN failure 'any', # ASAN failure 'logdet', # ASAN failure ]) SKIP_INCORRECT_RESULTS_LIST = set([ 'squeeze', # Value out of range 't', # Value out of range 'transpose', # Value out of range 'bernoulli', # incorrect results 'pow', # incorrect results 'addcdiv', # incorrect results (on CI not locally?) 
    ])

    return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST)


(LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST) = init_lists()

torch.manual_seed(42)


def clone_move(t):
    dev = 'lazy'
    copy_t = t.detach().clone().requires_grad_(True).to(device=dev)
    return copy_t


class TestLazyTensor(JitTestCase):
    def testConvolutionBackward(self):
        test_device = get_test_device()
        inp = torch.rand(1, 3, 128, 128, device=test_device, requires_grad=True)
        inp_copy = clone_move(inp)
        grad = torch.rand(1, 32, 121, 121, device=test_device)  # no requires_grad
        grad_copy = clone_move(grad)
        weight = torch.rand(32, 3, 8, 8, device=test_device, requires_grad=True)
        weight_copy = clone_move(weight)
        bias = torch.rand(32, device=test_device, requires_grad=True)
        bias_copy = clone_move(bias)

        # run eager
        conv_out = torch.nn.functional.conv2d(inp, weight, bias)
        (inp_grad, weight_grad, bias_grad) = torch.autograd.grad([conv_out], [inp, weight, bias], [grad])

        # run lazy
        conv_copy_out = torch.nn.functional.conv2d(inp_copy, weight_copy, bias_copy)
        (inp_copy_grad, weight_copy_grad, bias_copy_grad) = torch.autograd.grad(
            [conv_copy_out], [inp_copy, weight_copy, bias_copy], [grad_copy])

        # check numerics
        torch.testing.assert_close(bias_copy_grad.cpu(), bias_grad.cpu())
        torch.testing.assert_close(weight_copy_grad.cpu(), weight_grad.cpu())
        torch.testing.assert_close(inp_copy_grad.cpu(), inp_grad.cpu())


class TestLazyOpInfo(TestCase):

    @ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST], allowed_dtypes=(torch.float,))  # noqa: B950
    def test_dispatched_to_lazy(self, device, dtype, op):
        def get_name(op):
            l = [op.name]
            if op.variant_test_name != '':
                l.append(op.variant_test_name)
            return '.'.join(l)

        global FALLBACK_LIST
        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        sample = list(samples)[0]
        args = [sample.input] + list(sample.args)
        kwargs = sample.kwargs
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
        torch._lazy.metrics.reset()

        r = op(*args, **kwargs)
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
        prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
        found = f"{prefix}::{op.name}" in remove_suffixes(torch._lazy.metrics.counter_names())
        # check aliases
        if not found:
            for alias in op.aliases:
                alias_found = f"{prefix}::{alias.name}" in remove_suffixes(torch._lazy.metrics.counter_names())
                found = found or alias_found
                if found:
                    break
        self.assertTrue(found)

    @ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST], allowed_dtypes=(torch.float,))  # noqa: B950
    def test_correctness(self, device, dtype, op):
        test_device = get_test_device()

        def clone_to_device(input, dev):
            if isinstance(input, torch.Tensor):
                return input.detach().clone().to(device=dev)
            if isinstance(input, Sequence) and not isinstance(input, str):
                return tuple(map(functools.partial(clone_to_device, dev=dev), input))
            return input

        def assert_allclose_rec(t):
            a, b = t
            self.assertEqual(type(a), type(b))
            if isinstance(a, torch.Tensor):
                self.assertTrue(torch.allclose(clone_to_device(a, test_device), b, atol=1e-4))
            if isinstance(a, Sequence):
                # use an explicit loop: a bare map() is never consumed, so it would skip the checks
                for pair in zip(a, b):
                    assert_allclose_rec(pair)

        samples = op.sample_inputs("lazy", dtype, requires_grad=False)

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            copy_args = clone_to_device(args, test_device)

            r_exp = op(*copy_args, **kwargs)
            r_actual = op(*args, **kwargs)

            assert_allclose_rec((r_actual, r_exp))

    @ops([op for op in op_db if
          op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST], allowed_dtypes=(torch.float,))  # noqa: B950
    def test_correctness_with_reusing_ir(self, device, dtype, op):
        torch._lazy.config.set_reuse_ir(True)
        test_device = get_test_device()

        def clone_to_device(input, dev):
            if isinstance(input, torch.Tensor):
                return input.detach().clone().to(device=dev)
            if isinstance(input, Sequence) and not isinstance(input, str):
                return tuple(map(functools.partial(clone_to_device, dev=dev), input))
            return input

        def assert_allclose_rec(t):
            a, b = t
            self.assertEqual(type(a), type(b))
            if isinstance(a, torch.Tensor):
                self.assertTrue(torch.allclose(clone_to_device(a, test_device), b, atol=1e-4))
            if isinstance(a, Sequence):
                # use an explicit loop: a bare map() is never consumed, so it would skip the checks
                for pair in zip(a, b):
                    assert_allclose_rec(pair)

        samples = op.sample_inputs("lazy", dtype, requires_grad=False)

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            copy_args = clone_to_device(args, test_device)

            r_exp = op(*copy_args, **kwargs)
            r_actual = op(*args, **kwargs)
            torch._lazy.mark_step()

            assert_allclose_rec((r_actual, r_exp))

        torch._lazy.ir_cache.reset()
        torch._lazy.config.set_reuse_ir(False)


# TODO: after we move to master, add Lazy as a new Device here:
# https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_device_type.py#L532
instantiate_device_type_tests(TestLazyOpInfo, globals(), only_for="cpu")


class TestLazyDynamicOps(TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # Set up the dynamic shape mode
        cls.old_ssa_mode = torch._C._lazy._get_symbolic_shape_mode()
        torch._C._lazy._set_symbolic_shape_mode(True)
        return super().setUpClass()

    @classmethod
    def tearDownClass(cls) -> None:
        torch._C._lazy._set_symbolic_shape_mode(cls.old_ssa_mode)
        return super().tearDownClass()

    def test_nonzero_dynamic(self):
        # Test that nonzero gives upper-bound sizes when symbolic shape mode is enabled
        test_device = get_test_device()
        x1 = torch.tensor([[0, 1.0, 2.0], [3.0, 0, 0]], device=test_device, requires_grad=True)
        x1_lazy = clone_move(x1)
        x2_lazy = torch.nonzero(x1_lazy)

        self.assertEqual(tuple(x2_lazy.size()), (6, 2))

        # We should still be able to instantiate it and get the actual result
        x2_eager = x2_lazy.cpu()
        self.assertEqual(tuple(x2_eager.size()), (3, 2))


if __name__ == '__main__':
    run_tests()