**Summary:** This commit adds the `prepare_qat_pt2e` API and the fusion logic for Conv + BN. We use the subgraph rewriter to match the pattern and replace it with the existing logic in `nniqat.ConvBn2d`. Note that this is not the end-to-end flow yet; in particular, the convert flow still needs to swap the new subgraph with another one that merges the batchnorm stats back into conv.

The Conv + BN fusion is implemented in the following steps:

1. Annotate all nodes in the pattern `[conv - bn - getitem]`.
2. Match and replace this pattern with the fused QAT pattern (note that this is a larger subgraph than the original one).
3. Copy over metadata from the original nodes to the corresponding nodes in the new subgraph, to ensure the stack traces and dtype annotations are preserved.
4. Prepare inserts fake quantizes in the right places based on the annotations.

**Test Plan:** `python test/test_quantization.py TestQuantizePT2E.test_qat_conv_bn_fusion`

**Reviewers:** jerryzh168, kimishpatel, yanboliang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98568
Approved by: https://github.com/kimishpatel
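For context, a minimal sketch of how this flow is typically driven end to end is below. It is illustrative only: the public import paths, `XNNPACKQuantizer`, and the export call reflect later releases rather than the internal `_pt2e` namespaces as of this PR, and the final convert step is shown even though (as noted above) folding the BN stats back into conv is still follow-up work here.

```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

example_inputs = (torch.randn(1, 3, 8, 8),)
# Capture an ATen-level graph; the capture entry point has changed names across releases.
m = torch.export.export_for_training(M(), example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
# prepare_qat_pt2e annotates the conv - bn - getitem pattern, swaps in the fused
# QAT subgraph via the subgraph rewriter, and inserts fake quantizes based on the annotations.
m = prepare_qat_pt2e(m, quantizer)
# ... run QAT fine-tuning on m ...
m = convert_pt2e(m)  # the convert flow is not yet complete as of this PR
```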
225 lines
5.1 KiB
Python
import _collections_abc
import _weakrefset
import abc
import collections
import contextlib
import copy
import copyreg
import dataclasses
import enum
import functools
import importlib
import inspect
import linecache
import logging
import multiprocessing
import operator
import os
import posixpath
import random
import re
import selectors
import signal
import tempfile
import threading
import tokenize
import traceback
import types
import typing
import unittest
import weakref

import torch
import torch._export.constraints as _export_constraints
import torch._inductor.test_operators
import torch.ao.quantization._pt2e.qat_utils

from . import comptime, config, external_utils

"""
A note on skipfiles:

Dynamo consults this file to determine whether code should be compiled or skipped.

A skip applies at the frame boundary, meaning dynamo either triggers a graph break
at the beginning of the frame or attempts to trace the whole frame. When skipping
a frame, recursively called frames are still traced by dynamo unless also skipped.

Skipfiles (skipped at the file level instead of function level) still apply on a
frame-by-frame boundary as dynamo traces, but apply to all functions in that file.

@skip is a helper decorator that can be applied to your function to cause it to be
included here.
"""


def _strip_init_py(s):
    return re.sub(r"__init__.py$", "", s)


def _module_dir(m: types.ModuleType):
    return _strip_init_py(m.__file__)


SKIP_DIRS = [
    # torch.*
    _module_dir(torch),
    # torchdynamo.*
    os.path.dirname(__file__) + "/",
    "<frozen importlib",
    "<__array_function__ internals>",
] + [
    # skip some standard libs
    _module_dir(m)
    for m in (
        abc,
        collections,
        contextlib,
        copy,
        copyreg,
        dataclasses,
        enum,
        functools,
        importlib,
        inspect,
        linecache,
        logging,
        multiprocessing,
        operator,
        os,
        posixpath,
        random,
        re,
        selectors,
        signal,
        tempfile,
        threading,
        tokenize,
        traceback,
        types,
        typing,
        unittest,
        weakref,
        _collections_abc,
        _weakrefset,
    )
]

FILENAME_ALLOWLIST = {
    torch.nn.Sequential.__init__.__code__.co_filename,
    torch.set_rng_state.__code__.co_filename,
    torch._inductor.test_operators.__file__,
    # These are dynamo files!
    external_utils.__file__,
    comptime.__file__,  # Want to inline these helpers
}

# Include optimizer code for tracing
FILENAME_ALLOWLIST |= {
    inspect.getfile(obj)
    for obj in torch.optim.__dict__.values()
    if inspect.isclass(obj)
}
FILENAME_ALLOWLIST |= {torch.optim._functional.__file__}
FILENAME_ALLOWLIST |= {_export_constraints.__file__}

# Do trace through match and replace patterns used in PT2E QAT
# Note: These patterns are comprised of torch ops and for internal use only.
# They are exported to aten graphs before being passed to the FX subgraph rewriter.
FILENAME_ALLOWLIST |= {torch.ao.quantization._pt2e.qat_utils.__file__}


SKIP_DIRS_RE = None

is_fbcode = importlib.import_module("torch._inductor.config").is_fbcode()
# Skip fbcode paths (including torch.package paths) containing
# one of the following strings.
FBCODE_SKIP_DIRS = {
    "torchrec/distributed",
    "torchrec/fb/distributed",
    "caffe2/torch/fb/sparsenn/pooled_embeddings_modules.py",
}
FBCODE_SKIP_DIRS_RE = re.compile(f".*({'|'.join(map(re.escape, FBCODE_SKIP_DIRS))})")


def _recompile_re():
    global SKIP_DIRS_RE
    SKIP_DIRS_RE = re.compile(f"^({'|'.join(map(re.escape, SKIP_DIRS))})")


def add(import_name: str):
    if isinstance(import_name, types.ModuleType):
        return add(import_name.__name__)
    assert isinstance(import_name, str)
    module_spec = importlib.util.find_spec(import_name)
    if not module_spec:
        return
    origin = module_spec.origin
    if origin is None:
        return
    global SKIP_DIRS_RE
    SKIP_DIRS.append(_strip_init_py(origin))
    _recompile_re()


def check(filename, allow_torch=False):
    """Should skip this file?"""
    if filename is None:
        return True
    if filename in FILENAME_ALLOWLIST:
        return False
    if allow_torch and is_torch(filename):
        return False
    if is_fbcode and bool(FBCODE_SKIP_DIRS_RE.match(filename)):
        return True
    return bool(SKIP_DIRS_RE.match(filename))


# skip common third party libs
for _name in (
    "einops",
    "einops_exts",
    "functorch",
    "fx2trt_oss",
    "intel_extension_for_pytorch",
    "networkx",
    "numpy",
    "omegaconf",
    "onnx",
    "onnxruntime",
    "onnx_tf",
    "pandas",
    "sklearn",
    "tabulate",
    "tensorflow",
    "tensorrt",
    "torch2trt",
    "tqdm",
    "tree",
    "tvm",
    "xarray",
):
    add(_name)

_recompile_re()


def is_torch_inline_allowed(filename):
    return any(
        filename.startswith(_module_dir(mod))
        for mod in config.skipfiles_inline_module_allowlist
    )


@functools.lru_cache(None)
def dynamo_dir():
    import torch._dynamo

    return _module_dir(torch._dynamo)


def is_torch(filename):
    if filename.startswith(dynamo_dir()):
        return False
    return filename.startswith(_module_dir(torch))