mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This reverts commit aa8ea1d787.
Reverted https://github.com/pytorch/pytorch/pull/107246 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/107246#issuecomment-1693838522))
17 lines · 397 B · Python
# Owner(s): ["module: nvfuser"]
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import set_default_dtype
|
|
|
|
try:
    # Pull in the TorchScript nvfuser test suite from the optional
    # ``_nvfuser`` package. The star import is intentional: it re-exports that
    # module's tests (and presumably its own ``run_tests`` entry point) from
    # this file.
    from _nvfuser.test_torchscript import *  # noqa: F403,F401
except ImportError:
    # ``_nvfuser`` is not importable (e.g. a build without nvfuser support).
    # Deliberately skip the suite by providing a no-op entry point, so that
    # calling ``run_tests()`` from the ``__main__`` guard still succeeds.
    def run_tests():
        """No-op fallback used when ``_nvfuser`` is unavailable; returns None."""


if __name__ == '__main__':
    # Execute the suite with float64 as torch's default dtype.
    # ``set_default_dtype`` is used as a context manager here, so the override
    # presumably lasts only for the ``with`` body — confirm in common_utils.
    # TODO: Update nvfuser to work with float default dtype
    _nvfuser_test_dtype = torch.double
    with set_default_dtype(_nvfuser_test_dtype):
        run_tests()