pytorch/torch/_export/serde/upgrade.py
Tugsbayasgalan Manlaibaatar a7b75f586a [RELAND] Disallow skipping dynamo (#110222)
Previous discussion: https://github.com/pytorch/pytorch/pull/109476

In this PR, I made the following additions to the original PR:
1) The unlifted graph module now runs the runtime assertions in its forward call.
2) When we retrace, we run the assertions to make sure the user is tracing the module with inputs that satisfy the assumptions made during the first tracing. The way I do this is by creating a new graph module type with a modified call method. The runtime assertions run under torchdynamo.disable so that they execute directly in eager mode; we don't want them to become a traced part of the graph (see the sketch after this list).
3) Both ep.module and capture_pre_autograd now return _UnliftedGraphModule.
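
A minimal sketch of the mechanism in (2); the wrapper class and the check helper below are hypothetical illustrations, not the actual types added by this PR:

    import torch

    class _AssertingModule(torch.nn.Module):  # hypothetical stand-in for the unlifted module type
        def __init__(self, gm, check_inputs):
            super().__init__()
            self.gm = gm
            # torch._dynamo.disable keeps the input checks out of any traced graph; they always run eagerly.
            self._check_inputs = torch._dynamo.disable(check_inputs)

        def forward(self, *args, **kwargs):
            self._check_inputs(*args, **kwargs)  # runtime assertions run in eager, never traced
            return self.gm(*args, **kwargs)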

Differential Revision: [D51078056](https://our.internmc.facebook.com/intern/diff/D51078056)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110222
Approved by: https://github.com/zhxchen17
2023-11-14 16:02:01 +00:00

import logging
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
import torch._export.exported_program as ep
from torch._export import export
from torch._export.pass_base import _ExportPassBase
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor
from torch.fx.node import Argument, Target
from torch.library import Library
from torch.utils._pytree import tree_unflatten

lib = Library("aten", "FRAGMENT")
impl_lib = Library("aten", "IMPL")

log = logging.getLogger(__name__)


def get_target_version(versioned_upgrader_name: str) -> int:
    """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of version 0 to 3 and is
    upgrading to version 4."""
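    # Worked example of the naming scheme above: "div_Scalar_0_3" covers versions 0 through 3,
    # so get_target_version("div_Scalar_0_3") returns 4.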
    if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name):
        raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid")

    return int(versioned_upgrader_name.split('_')[-1]) + 1


def get_upgraders() -> Dict[str, Tuple[str, str]]:
    """Get the upgraders entry map and the operator version map and merge them into one dict."""
    upgraders = torch._C._get_upgraders_entry_map()
    op_version_map = torch._C._get_operator_version_map()
    output: Dict[str, Tuple[str, str]] = defaultdict(tuple)  # type: ignore[arg-type]
    for opname, entry_list in op_version_map.items():
        if not entry_list:
            raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
        entry = entry_list[0]
        old_schema = entry.old_schema
        upgrader_name = entry.upgrader_name
        upgrader_str = upgraders.get(upgrader_name, None)
        if not upgrader_str:
            raise RuntimeError(f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}")
        output[upgrader_name] = (old_schema, upgrader_str)
    return output


class GraphModuleOpUpgrader:
    """This upgrader is able to upgrade old versions of ops in a given GraphModule, if all upgraders are available.
    To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass them into this upgrader. In
    __init__() it does the following:

        1. parse the upgrader list and reorder it for upgrading purposes.
        2. register old versions of operators as custom ops.
        3. prepare upgrader passes.

    The `upgrade()` API then runs these upgrader passes.

    An example of op_upgraders input:

    {
        "aten::div__Scalar_0_3": (  # versioned op name
            "div._Scalar(self: Tensor, other: Scalar)",  # old schema
            '''
            def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor:  # upgrader in literal string
                if (self.is_floating_point() or isinstance(other, float)):
                    return self.true_divide_(other)
                return self.divide_(other, rounding_mode='trunc')
            ''',
        ),
    },

    Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the
    original TorchScript upgrader).
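
    Example usage (a sketch; the opset version numbers and the `exported_program` variable are illustrative, not
    defined in this module):

        upgrader = GraphModuleOpUpgrader(
            compiler_opset_version={"aten": 4},
            model_opset_version={"aten": 3},
        )
        upgraded_ep = upgrader.upgrade(exported_program)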
"""

    class UpgraderPass(_ExportPassBase):
        def __init__(self, old_target: Target, new_target: Target):
            super().__init__()
            self.old_target = old_target
            self.new_target = new_target

        def call_operator(
            self,
            op,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
            meta: NodeMetadata,
        ) -> ProxyValue:
            if op == self.old_target:
                return super().call_operator(self.new_target, args, kwargs, meta)
            return super().call_operator(op, args, kwargs, meta)

    def __init__(
        self,
        compiler_opset_version: Optional[Dict[str, int]] = None,
        model_opset_version: Optional[Dict[str, int]] = None,
        op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None,
    ):
        self.op_upgraders: Dict[str, Tuple[str, str]] = get_upgraders() if not op_upgraders else op_upgraders
        self.compiler_opset_version = compiler_opset_version if compiler_opset_version else {}
        self.model_opset_version = model_opset_version if model_opset_version else {}
        self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = GraphModuleOpUpgrader._populate_passes(
            self._parse_upgraders(self.op_upgraders)
        )

    def _parse_upgraders(self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None) -> List[Tuple[str, str]]:
        """Reorder op_upgraders by version number and return an ordered list of tuples, each containing the old op
        schema as well as the upgrader function string literal."""
        # TODO(larryliu0820): Add support for custom ops
        op_namespace = "aten"
        if (
            not op_upgraders
            or op_namespace not in self.model_opset_version
            or op_namespace not in self.compiler_opset_version
        ):
            return []
        model_ver = self.model_opset_version[op_namespace]
        curr_ver = self.compiler_opset_version[op_namespace]

        # key is the target version. div__Scalar_0_3 should have a key of 4.
        versioned_upgraders: Dict[int, Tuple[str, str]] = {
            get_target_version(name): v for name, v in op_upgraders.items()
        }
        target_upgraders: List[Tuple[str, str]] = []
        # we need all upgraders from model_ver + 1 to curr_ver, inclusive
        for ver in range(model_ver + 1, curr_ver + 1):
            if ver in versioned_upgraders:
                target_upgraders.append(versioned_upgraders[ver])
            else:
                # We may be able to get away with missing upgraders, if that operator is missing from the given
                # graph module.
                log.warning("Missing an upgrader to upgrade to version %s.", ver)

        return target_upgraders

    @staticmethod
    def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]:
        """Given a list of upgraders, loop through it from lower version to higher version and create passes for all
        upgraders. Use the torch.Library API to register old ops. The op name will be
        <name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example:

            lib = Library("aten", "FRAGMENT")
            lib.define(old_schema)

            impl_lib = Library("aten", "IMPL")
            impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd")

        @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the
            upgrader function literal text.
        @:return upgrader passes, order matters
        """
        upgrader_passes = []

        def register_old_op(name: str, schema: str, impl_str: str):
            """Registers an old version of an operator, using `name` as the old op name."""
            lib.define(schema)
            try:
                # Evaluate the upgrader source; this defines a function named `name` in the local scope.
                exec(impl_str)
            except Exception as e:
                raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e
            impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd")

        for (schema, upgrader_str) in upgraders:
            upgrader_name = upgrader_str.split('(')[0].split(' ')[-1]
            op_name = schema.split('(')[0].split("::")[-1]
            schema = schema.replace(op_name, upgrader_name)
            try:
                register_old_op(name=upgrader_name, schema=schema, impl_str=upgrader_str)
            except RuntimeError as e:
                if "with the same name and overload name multiple times" in str(e):
                    print(f"Registering {upgrader_name} multiple times")
                else:
                    raise RuntimeError from e
            old_op_target = getattr(torch.ops.aten, upgrader_name).default
            # For example, the operator instance of "aten::div" is torch.ops.aten.div.default. We need to append
            # "default" at the end.
            op_name, overload_name = (op_name, "default") if "." not in op_name else tuple(op_name.split(".")[:2])
            new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name)
            # Note that the graph contains the current op names, but those calls actually carry old-version
            # semantics, so the pass replaces them with the registered old ops.
            upgrader_passes.append(
                GraphModuleOpUpgrader.UpgraderPass(old_target=new_op_target, new_target=old_op_target)
            )

        return upgrader_passes

    def upgrade(self, exported_program: ep.ExportedProgram) -> ep.ExportedProgram:
        """Run each upgrader pass and then retrace to decompose it. Each upgrader pass replaces the old version of
        operators with a custom operator. The custom operator contains a CompositeImplicitAutograd kernel (the
        upgrading function itself). After retracing, this custom operator will be decomposed into the ops used in the
        upgrader. After all passes are applied, the exported program will be upgraded to the target version."""
        if not self.upgrader_passes:
            return exported_program

        args = [n.meta.get("val", None) for n in exported_program.graph.nodes if n.op == "placeholder"]
        args_real_tensors = [
            torch.ones(tuple(arg.size()), dtype=arg.dtype) if isinstance(arg, FakeTensor) else arg
            for arg in args
        ]

        assert exported_program.call_spec.in_spec is not None
        args, kwargs = tree_unflatten(args_real_tensors, exported_program.call_spec.in_spec)
        assert kwargs == {}

        for _pass in self.upgrader_passes:
            upgraded_program = exported_program._transform(_pass)
            # NB: we have to retrace the graph_module instead of the ExportedProgram because of some failure.
            exported_program = export(upgraded_program.module(), args, kwargs)

        return exported_program