pytorch/test/dynamo/test_debug_utils.py
Peter Bell 758735b739 [dynamo] Convert dtype arguments as well as inputs in cast_to_fp64 (#110232)
Generating reference outputs sometimes fails because of type mismatches in the graph,
an issue which was noticed previously for `prims.convert_element_type` and fixed in #92036
but the same issue happens with other functions such as tensor constructors.

This expands the fix from #92036 to all dtype keyword arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110232
Approved by: https://github.com/ezyang
2023-09-29 12:42:14 +00:00

57 lines
2.2 KiB
Python

# Owner(s): ["module: dynamo"]
import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.test_case import TestCase
class TestDebugUtils(TestCase):
def test_cast_model_to_fp64_dtype_args(self):
# Test that dtype arguments are converted to fp64
def fn(x):
return (
torch.ops.prims.convert_element_type(x, torch.float16),
x.to(torch.float16),
torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
x.new_empty(x.shape),
)
x = torch.randn(32, device="cpu")
decomps = torch._decomp.core_aten_decompositions()
fx = make_fx(fn, decomposition_table=decomps)(x)
self.assertExpectedInline(
fx.code.lstrip(),
"""\
def forward(self, x_1):
convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
_to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None
full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
return (convert_element_type, _to_copy, full, empty)
""", # NOQA: B950
)
fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
self.assertEqual(fp64_examples, (x.to(torch.float64),))
self.assertExpectedInline(
fx.code.lstrip(),
"""\
def forward(self, x_1):
convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
_to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64); x_1 = None
full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
return (convert_element_type, _to_copy, full, empty)
""", # NOQA: B950
)
if __name__ == "__main__":
    # Pull in the dynamo-aware test runner only when this file is executed
    # directly, then hand control to it.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()