mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107940 Approved by: https://github.com/tugsbayasgalan
143 lines
5.6 KiB
Python
143 lines
5.6 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch._dynamo as torchdynamo
|
|
from torch._export import export
|
|
from torch._export.serde.serialize import GraphModuleOpUpgrader
|
|
from torch._export.serde.upgrade import get_target_version, get_upgraders
|
|
from torch.testing._internal.common_utils import (
|
|
run_tests,
|
|
TestCase,
|
|
)
|
|
|
|
TEST_UPGRADERS = {
|
|
"aten::div__Scalar_mode_0_3": (
|
|
"div.Scalar_mode(Tensor self, Scalar other, *, str rounding_mode) -> Tensor",
|
|
"""
|
|
from typing import Any, Optional
|
|
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
|
|
return self.divide(other, rounding_mode=rounding_mode)
|
|
""",
|
|
),
|
|
"aten::gelu_0_9": (
|
|
"gelu(Tensor self) -> Tensor",
|
|
"""
|
|
def gelu_0_9(self: Tensor) -> Tensor:
|
|
return torch.gelu(self, approximate='none')
|
|
""",
|
|
),
|
|
}
|
|
|
|
TEST_UPGRADERS_ENTRY_MAP = {
|
|
"div__Scalar_mode_0_3":
|
|
"""
|
|
from typing import Any, Optional
|
|
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
|
|
return self.divide_(other, rounding_mode=rounding_mode)"""
|
|
}
|
|
|
|
TEST_OP_VERSION_MAP = {
|
|
"aten::div_.Scalar_mode": [
|
|
torch._C._UpgraderEntry(
|
|
4,
|
|
"div__Scalar_mode_0_3",
|
|
"aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)"
|
|
)
|
|
]
|
|
}
|
|
|
|
|
|
def count_op(graph, target_str):
|
|
return len(
|
|
[n for n in graph.nodes if isinstance(n.target, torch._ops.OpOverload) and n.target.name() == target_str])
|
|
|
|
|
|
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
|
class TestUpgrade(TestCase):
|
|
def test_get_upgraders(self):
|
|
with patch.object(torch._C, "_get_upgraders_entry_map", return_value=TEST_UPGRADERS_ENTRY_MAP), \
|
|
patch.object(torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP):
|
|
op_upgraders = get_upgraders()
|
|
self.assertEqual(op_upgraders, {
|
|
"div__Scalar_mode_0_3": (
|
|
"aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)",
|
|
"""
|
|
from typing import Any, Optional
|
|
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
|
|
return self.divide_(other, rounding_mode=rounding_mode)""",
|
|
)})
|
|
|
|
def test_get_upgraders_missing_from_entry_map_raises(self):
|
|
with patch.object(torch._C, "_get_upgraders_entry_map", return_value={}), \
|
|
patch.object(torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP):
|
|
with self.assertRaises(RuntimeError):
|
|
get_upgraders()
|
|
|
|
def test_upgrader_with_invalid_format_throws_exception(self):
|
|
"""Invalid upgrader function string should throw exception"""
|
|
upgraders = [("div(Tensor a, Tensor b) -> Tensor", "TEST")]
|
|
with self.assertRaises(RuntimeError):
|
|
GraphModuleOpUpgrader._populate_passes(upgraders)
|
|
|
|
def test_get_target_version_invalid_format_throws_exception(self):
|
|
with self.assertRaises(RuntimeError):
|
|
get_target_version("div_0")
|
|
with self.assertRaises(RuntimeError):
|
|
get_target_version("div_0_")
|
|
with self.assertRaises(RuntimeError):
|
|
get_target_version("div")
|
|
|
|
def test_creates_upgrader_pass(self):
|
|
compiler_opset_version = {"aten": 4}
|
|
model_opset_version = {"aten": 3}
|
|
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
|
|
self.assertEqual(len(upgrader.upgrader_passes), 1)
|
|
|
|
def test_div_upgrader_replaces_op_with_old_version(self):
|
|
def fn(a: torch.Tensor, b):
|
|
return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc')
|
|
|
|
inputs = (torch.ones([2, 3]) * 4, 2.)
|
|
ep = export(fn, inputs, [])
|
|
compiler_opset_version = {"aten": 4}
|
|
model_opset_version = {"aten": 3}
|
|
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
|
|
upgraded = ep._transform(*upgrader.upgrader_passes)
|
|
upgraded.graph_module.print_readable()
|
|
|
|
count = count_op(upgraded.graph, "aten::div.Scalar_mode")
|
|
self.assertEqual(count, 0)
|
|
custom_op_count = count_op(upgraded.graph, "aten::div__Scalar_mode_0_3")
|
|
self.assertEqual(custom_op_count, 1)
|
|
|
|
def test_div_upgrader_pass_return_new_op_after_retrace(self):
|
|
def fn(a: torch.Tensor, b):
|
|
return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc')
|
|
|
|
inputs = (torch.ones([2, 3]) * 4, 2.)
|
|
ep = export(fn, inputs, {}, [])
|
|
compiler_opset_version = {"aten": 4}
|
|
model_opset_version = {"aten": 3}
|
|
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
|
|
|
|
count = count_op(ep.graph, "aten::div.Scalar_mode")
|
|
self.assertEqual(count, 1)
|
|
|
|
# upgrade: replace op (div.Scalar_mode -> div__Scalar_mode_0_3) then retrace
|
|
upgraded_ep = upgrader.upgrade(ep)
|
|
upgraded_ep.graph_module.print_readable()
|
|
|
|
# no old version of op (div__Scalar_mode_0_3) anymore.
|
|
custom_op_count = count_op(upgraded_ep.graph, "aten::div__Scalar_mode_0_3")
|
|
self.assertEqual(custom_op_count, 0)
|
|
|
|
# div__Scalar_mode_0_3 decomposes into div.Tensor.
|
|
decomposed_op_count = count_op(upgraded_ep.graph, "aten::div.Tensor_mode")
|
|
self.assertEqual(decomposed_op_count, 1)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|