mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Part of #68972. This only attempts to avoid setting the default dtype for `test_jit.py` and `test_ops.py`. There are other tests, like `test_nn.py`, which will be addressed in follow-up PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/105072. Approved by: https://github.com/ezyang
17 lines
397 B
Python
17 lines
397 B
Python
# Owner(s): ["module: nvfuser"]
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import set_default_dtype
|
|
|
|
try:
    # Pull in the nvfuser TorchScript test suite when the optional
    # ``_nvfuser`` package is importable. The star import is expected to
    # provide ``run_tests``, which the ``__main__`` guard below calls.
    from _nvfuser.test_torchscript import *  # noqa: F403,F401
except ImportError:
    def run_tests():
        """Fallback entry point used when ``_nvfuser`` cannot be imported.

        Intentionally does nothing: without this definition, running the
        file as a script would fail with ``NameError`` on ``run_tests``
        because the name would otherwise only come from the star import.
        """
        return
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the test suite with torch.double as the
    # default dtype for the duration of the run.
    # TODO: Update nvfuser to work with float default dtype
    double_default = set_default_dtype(torch.double)
    with double_default:
        run_tests()
|