pytorch/test/test_jit_cuda_fuser.py
Kurt Mohler ffce2492af Remove set_default_dtype calls from jit and ops tests (#105072)
Part of #68972

This only attempts to avoid setting the default dtype for `test_jit.py` and `test_ops.py`. Other tests, like `test_nn.py`, will be addressed in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105072
Approved by: https://github.com/ezyang
2023-07-15 03:18:33 +00:00

# Owner(s): ["module: nvfuser"]

import torch
from torch.testing._internal.common_utils import set_default_dtype

try:
    from _nvfuser.test_torchscript import *  # noqa: F403,F401
except ImportError:
    # The nvfuser TorchScript tests are optional; without them, run_tests is a no-op.
    def run_tests():
        return
    pass

if __name__ == '__main__':
    # TODO: Update nvfuser to work with float default dtype
    with set_default_dtype(torch.double):
        run_tests()
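The entry point wraps `run_tests()` in `set_default_dtype(torch.double)`, so nvfuser's tests still see float64 as the default while the default is no longer changed for the whole module. As a hedged illustration of what such a context manager does, here is a minimal sketch built only on the public `torch.get_default_dtype` and `torch.set_default_dtype` functions. `default_dtype_sketch` is a hypothetical name, and this is not necessarily how `torch.testing._internal.common_utils.set_default_dtype` is actually written.

```python
import contextlib

import torch


@contextlib.contextmanager
def default_dtype_sketch(dtype):
    """Hypothetical helper: temporarily switch torch's default floating-point dtype."""
    saved = torch.get_default_dtype()  # remember whatever default is active now
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even if the body raises, so code after the block sees the old default.
        torch.set_default_dtype(saved)


if __name__ == "__main__":
    with default_dtype_sketch(torch.double):
        x = torch.tensor([1.0])  # Python floats become float64 inside the block
    y = torch.tensor([1.0])  # back to the previous default (float32 in a fresh process)
    print(x.dtype, y.dtype)
```

Restoring the saved dtype in `finally` keeps the change scoped to the `with` block, so even an exception raised inside `run_tests()` cannot leave float64 as the process-wide default.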