Previous discussion: https://github.com/pytorch/pytorch/pull/109476. In this PR, I made the following additions on top of the original PR: 1) The unlifted graph module now runs the runtime assertions in its forward call. 2) When we retrace, we run the assertions to make sure the user is tracing the module with inputs that satisfy the assumptions made during the first tracing. This is done by creating a new graph module type with a modified call method; the runtime assertions run under torchdynamo.disable so they execute directly in eager mode, because we don't want them to become a traced part of the graph. 3) Both ep.module and capture_pre_autograd now return _UnliftedGraphModule. Differential Revision: [D51078056](https://our.internmc.facebook.com/intern/diff/D51078056) Pull Request resolved: https://github.com/pytorch/pytorch/pull/110222 Approved by: https://github.com/zhxchen17
import logging
from collections import defaultdict
from typing import Tuple, Dict, Optional, List

import torch
from torch._export import export
from torch._export.pass_base import _ExportPassBase
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._subclasses import FakeTensor
from torch.fx.node import Target, Argument
from torch.library import Library
from torch.utils._pytree import tree_unflatten

import torch._export.exported_program as ep
import re

lib = Library("aten", "FRAGMENT")
impl_lib = Library("aten", "IMPL")

log = logging.getLogger(__name__)


def get_target_version(versioned_upgrader_name: str) -> int:
    """div_Scalar_0_3 is the name of the upgrader, meaning it applies to div.Scalar of versions 0 to 3 and
    upgrades them to version 4."""
    if not re.match("^.*_[0-9]+_[0-9]+$", versioned_upgrader_name):
        raise RuntimeError(f"Upgrader name {versioned_upgrader_name} is invalid")

    return int(versioned_upgrader_name.split('_')[-1]) + 1
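
# A brief illustration of the naming convention described in the docstring above (the
# second upgrader name below is made up): the trailing "_<valid_from>_<valid_till>" suffix
# encodes the version range the upgrader applies to, and the target version is <valid_till> + 1.
#
#   get_target_version("div_Scalar_0_3")   # -> 4
#   get_target_version("gelu_2_5")         # -> 6 (hypothetical upgrader name)
#   get_target_version("div_Scalar")       # raises RuntimeError: no version suffix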


def get_upgraders() -> Dict[str, Tuple[str, str]]:
    """Get the upgraders entry map and the operator version map, and merge them into one dict."""
    upgraders = torch._C._get_upgraders_entry_map()
    op_version_map = torch._C._get_operator_version_map()
    output: Dict[str, Tuple[str, str]] = defaultdict(tuple)  # type: ignore[arg-type]
    for opname, entry_list in op_version_map.items():
        if not entry_list:
            raise RuntimeError(f"Op version map has an empty entry for opname {opname}")
        entry = entry_list[0]
        old_schema = entry.old_schema
        upgrader_name = entry.upgrader_name
        upgrader_str = upgraders.get(upgrader_name, None)
        if not upgrader_str:
            raise RuntimeError(f"Can't find upgrader for op {opname} and upgrader name {upgrader_name}")
        output[upgrader_name] = (old_schema, upgrader_str)
    return output


class GraphModuleOpUpgrader:
    """This upgrader upgrades old versions of ops in a given GraphModule, provided all the needed upgraders are
    available. To use it, retrieve upgraders from somewhere (TorchScript API or new API) and pass them into this
    upgrader. In __init__() it does the following:
        1. parse the upgrader list and reorder it for upgrading.
        2. register old versions of operators as custom ops.
        3. prepare upgrader passes.

    The `upgrade()` API then runs these upgrader passes.

    An example of op_upgraders input:
    {
        "aten::div__Scalar_0_3": (  # versioned op name
            "div._Scalar(self: Tensor, other: Scalar)",  # old schema
            '''
            def div__Scalar_0_3(self: torch.Tensor, other) -> torch.Tensor:  # upgrader in literal string
                if (self.is_floating_point() or isinstance(other, float)):
                    return self.true_divide_(other)
                return self.divide_(other, rounding_mode='trunc')
            ''',
        ),
    },

    Note that we require the upgrader function to be runnable in Python (which is a stricter requirement than the
    original TorchScript upgrader).
    """

    class UpgraderPass(_ExportPassBase):
        def __init__(self, old_target: Target, new_target: Target):
            super().__init__()
            self.old_target = old_target
            self.new_target = new_target

        def call_operator(
            self,
            op,
            args: Tuple[Argument, ...],
            kwargs: Dict[str, Argument],
            meta: NodeMetadata,
        ) -> ProxyValue:
            if op == self.old_target:
                return super().call_operator(self.new_target, args, kwargs, meta)
            return super().call_operator(op, args, kwargs, meta)
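
    # Illustration (a hedged sketch, not executed here): a pass built as
    #   UpgraderPass(old_target=torch.ops.aten.div.Scalar,
    #                new_target=torch.ops.aten.div__Scalar_0_3.default)
    # rewrites every call to the first target in the traced graph into a call to the second
    # one, leaving all other operators untouched. The versioned target only exists once
    # _populate_passes() below has registered it as a custom op.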

    def __init__(
        self,
        compiler_opset_version: Optional[Dict[str, int]] = None,
        model_opset_version: Optional[Dict[str, int]] = None,
        op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None,
    ):
        self.op_upgraders: Dict[str, Tuple[str, str]] = get_upgraders() if not op_upgraders else op_upgraders
        self.compiler_opset_version = compiler_opset_version if compiler_opset_version else {}
        self.model_opset_version = model_opset_version if model_opset_version else {}
        self.upgrader_passes: List[GraphModuleOpUpgrader.UpgraderPass] = GraphModuleOpUpgrader._populate_passes(
            self._parse_upgraders(self.op_upgraders))

    def _parse_upgraders(self, op_upgraders: Optional[Dict[str, Tuple[str, str]]] = None) -> List[Tuple[str, str]]:
        """Reorder op_upgraders by version number and return an ordered list of tuples, each containing the old op
        schema as well as the upgrader function string literal."""
        # TODO(larryliu0820): Add support for custom ops
        op_namespace = "aten"
        if not op_upgraders or op_namespace not in self.model_opset_version or op_namespace not in self.compiler_opset_version:
            return []
        model_ver = self.model_opset_version[op_namespace]
        curr_ver = self.compiler_opset_version[op_namespace]

        # key is the target version. div__Scalar_0_3 should have a key of 4.
        versioned_upgraders: Dict[int, Tuple[str, str]] = {get_target_version(name): v for name, v in
                                                           op_upgraders.items()}
        target_upgraders: List[Tuple[str, str]] = []
        # we need all upgraders from model_ver + 1 to curr_ver, inclusive
        for ver in range(model_ver + 1, curr_ver + 1):
            if ver in versioned_upgraders:
                target_upgraders.append(versioned_upgraders[ver])
            else:
                # we may be able to get away with a missing upgrader, if that operator does not appear in the
                # given graph module.
                log.warning("Missing an upgrader to upgrade to version %s.", ver)

        return target_upgraders
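
    # Worked example for _parse_upgraders (the versions are hypothetical): with
    # model_opset_version={"aten": 3} and compiler_opset_version={"aten": 5}, the loop above
    # collects the upgraders keyed at target versions 4 and 5, in that order, so that
    # lower-version upgraders are applied first. If, say, the version-5 upgrader is missing,
    # a warning is logged and the remaining upgraders are still returned.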

    @staticmethod
    def _populate_passes(upgraders: List[Tuple[str, str]]) -> List[UpgraderPass]:
        """Given a list of upgraders, loop through it from lower version to higher version and create passes for all
        upgraders. Use the torch.Library API to register old ops. The op name will be
        <name>_<valid_from_ver>_<valid_till_ver>. Register upgraders as CompositeImplicitAutograd kernels. For example:

            lib = Library("aten", "FRAGMENT")
            lib.define(old_schema)

            impl_lib = Library("aten", "IMPL")
            impl_lib.impl("div__Scalar_0_3", div__Scalar_0_3, "CompositeImplicitAutograd")

        @:var upgraders: a list of tuples. The first element of the tuple is the old schema and the second is the
        upgrader function literal text.
        @:return upgrader passes, order matters
        """

        upgrader_passes = []

        def register_old_op(name: str, schema: str, impl_str: str):
            """Registers an old version operator using impl_name as old op name."""
            lib.define(schema)
            try:
                exec(impl_str)
            except Exception as e:
                raise RuntimeError(f"Invalid upgrader string: {impl_str}") from e
            impl_lib.impl(name, locals()[name], "CompositeImplicitAutograd")

        for (schema, upgrader_str) in upgraders:
            upgrader_name = upgrader_str.split('(')[0].split(' ')[-1]
            op_name = schema.split('(')[0].split("::")[-1]
            schema = schema.replace(op_name, upgrader_name)
            try:
                register_old_op(name=upgrader_name, schema=schema, impl_str=upgrader_str)
            except RuntimeError as e:
                if "with the same name and overload name multiple times" in str(e):
                    print(f"Registering {upgrader_name} multiple times")
                else:
                    raise RuntimeError from e
            old_op_target = getattr(torch.ops.aten, upgrader_name).default
            # For example, the operator instance of "aten::div" is torch.ops.aten.div.default; we need to append
            # "default" at the end.
            op_name, overload_name = (op_name, "default") if "." not in op_name else tuple(op_name.split(".")[:2])
            new_op_target = getattr(getattr(torch.ops.aten, op_name), overload_name)
            # Note that the graph contains the current op names, but semantically they are still the old versions
            # of those ops, so the pass swaps them for the registered versioned ops.
            upgrader_passes.append(
                GraphModuleOpUpgrader.UpgraderPass(old_target=new_op_target, new_target=old_op_target))

        return upgrader_passes
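
    # Illustration of the name handling in the loop above (the exact strings are illustrative,
    # following the div example from the class docstring): given the old schema
    # "div.Scalar(self: Tensor, other: Scalar)" and an upgrader source starting with
    # "def div__Scalar_0_3(...)":
    #   upgrader_name == "div__Scalar_0_3"                          # parsed from the upgrader source
    #   op_name       == "div.Scalar"                               # parsed from the old schema
    #   schema becomes "div__Scalar_0_3(self: Tensor, other: Scalar)" and is registered via lib.define
    #   old_op_target == torch.ops.aten.div__Scalar_0_3.default     # the freshly registered custom op
    #   new_op_target == torch.ops.aten.div.Scalar                  # the op name appearing in the graph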

    def upgrade(self, exported_program: ep.ExportedProgram) -> ep.ExportedProgram:
        """Run each upgrader pass and then retrace to decompose it. Each upgrader pass replaces the old version of
        operators with a custom operator. The custom operator contains a CompositeImplicitAutograd kernel (the
        upgrading function itself). After retrace, this custom operator will be decomposed into the ops used in the
        upgrader. After all passes are applied, the exported program will be upgraded to the target version."""
        if not self.upgrader_passes:
            return exported_program

        args = [n.meta.get("val", None) for n in exported_program.graph.nodes if n.op == "placeholder"]
        args_real_tensors = [torch.ones(tuple(arg.size()), dtype=arg.dtype) if isinstance(arg, FakeTensor) else arg for
                             arg in args]
        assert exported_program.call_spec.in_spec is not None
        args, kwargs = tree_unflatten(args_real_tensors, exported_program.call_spec.in_spec)
        assert kwargs == {}

        for _pass in self.upgrader_passes:
            upgraded_program = exported_program._transform(_pass)
            # NB: we have to retrace the graph_module instead of ep because of some failure.
            exported_program = export(upgraded_program.module(), args, kwargs)

        return exported_program
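

# A minimal usage sketch (the opset versions below are made up): construct the upgrader with
# the version the model was exported at and the version the current compiler expects, then
# upgrade an ExportedProgram. With no explicit op_upgraders, the upgraders bundled with
# PyTorch (via get_upgraders()) are used.
#
#   upgrader = GraphModuleOpUpgrader(
#       compiler_opset_version={"aten": 5},
#       model_opset_version={"aten": 3},
#   )
#   upgraded_ep = upgrader.upgrade(exported_program)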