# Source: pytorch/test/test_jit_cuda_fuser.py (17 lines, 397 B, Python)

# Owner(s): ["module: nvfuser"]
import torch
from torch.testing._internal.common_utils import set_default_dtype
try:
from _nvfuser.test_torchscript import * # noqa: F403,F401
except ImportError:
def run_tests():
return
pass
if __name__ == '__main__':
# TODO: Update nvfuser to work with float default dtype
with set_default_dtype(torch.double):
run_tests()