Generating reference outputs sometimes fails because of type mismatches in the graph. This issue was noticed previously for `prims.convert_element_type` and fixed in #92036, but the same problem occurs with other functions such as tensor constructors. This expands the fix from #92036 to all dtype keyword arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110232
Approved by: https://github.com/ezyang
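The fix amounts to rewriting floating-point `dtype` keyword arguments on graph nodes to `torch.float64` before reference outputs are generated. A minimal sketch of that idea (hypothetical helper name and dtype set; not the actual `debug_utils` code) might look like:

# Sketch only: walk an FX graph and rewrite any floating-point ``dtype``
# keyword argument to torch.float64, which is the behaviour the test below
# exercises via debug_utils.cast_to_fp64.
import torch
from torch.fx import GraphModule

_FLOAT_DTYPES = {torch.float16, torch.bfloat16, torch.float32}


def cast_dtype_kwargs_to_fp64(gm: GraphModule) -> GraphModule:
    for node in gm.graph.nodes:
        if node.kwargs.get("dtype") in _FLOAT_DTYPES:
            # Node.kwargs is exposed as an immutable mapping, so replace it
            # wholesale rather than mutating it in place.
            node.kwargs = {**node.kwargs, "dtype": torch.float64}
    gm.recompile()
    return gm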
# Owner(s): ["module: dynamo"]

import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.test_case import TestCase


class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64
        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )

        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )

        fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64); x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()