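"""
Shared helpers for TorchDynamo backends: an ``aot_autograd`` wrapper that
lowers captured graphs through AOT Autograd, the keyword arguments used by
functorch's ``memory_efficient_fusion``, the ``fake_tensor_unsupported``
decorator for backends that need real inputs, and small helpers for looking
up the device/dtype of example inputs.
"""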
import functools
import logging

import torch

from torch._dynamo import eval_frame
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_module_simplified
from torch._subclasses import FakeTensor
from torch.utils._python_dispatch import _disable_current_modes

log = logging.getLogger(__name__)


def aot_autograd(**kwargs):
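    """
    Return a TorchDynamo backend (``compiler_fn``) that lowers the captured
    ``torch.fx.GraphModule`` through AOT Autograd. ``kwargs`` (for example
    ``fw_compiler``, ``bw_compiler``, ``partition_fn``, ``decompositions``)
    are forwarded to ``aot_module_simplified``.
    """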
    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
        import functorch.compile

        # Hack to get around circular import problems with aot_eager_decomp_partition
        if callable(kwargs.get("decompositions")):
            kwargs["decompositions"] = kwargs["decompositions"]()

        # TODO: stop monkeypatching here (without even cleaning up, UGH!)
        functorch.compile.config.use_functionalize = True
        functorch.compile.config.use_fake_tensor = True

        counters["aot_autograd"]["total"] += 1
        use_fallback = False

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        # OK attempt to compile

        def _wrapped_bw_compiler(*args, **kwargs):
            # stop TorchDynamo from trying to compile our generated backwards pass
            return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))

        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
        kwargs["bw_compiler"] = _wrapped_bw_compiler

        from torch._inductor.debug import enable_aot_logging

        try:
            # NB: NOT cloned!
            with enable_aot_logging():
                cg = aot_module_simplified(gm, example_inputs, **kwargs)
                counters["aot_autograd"]["ok"] += 1
                return eval_frame.disable(cg)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise

    return compiler_fn
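
# Usage sketch: aot_autograd() is called with per-graph compilers and the
# result is handed to TorchDynamo as a backend. ``toy_compiler`` below is a
# hypothetical example; any callable taking (gm, example_inputs) and
# returning a callable should work:
#
#   def toy_compiler(gm, example_inputs):
#       return gm  # run the traced graph as-is
#
#   toy_backend = aot_autograd(fw_compiler=toy_compiler)
#   opt_model = torch.compile(model, backend=toy_backend)
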
def mem_efficient_fusion_kwargs(use_decomps):
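    """
    Keyword arguments taken from functorch's memory_efficient_fusion():
    TorchScript (``ts_compile``) forward/backward compilers and the min-cut
    rematerialization partitioner, plus the default decompositions when
    ``use_decomps`` is True.
    """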
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs
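
# Usage sketch: these kwargs plug directly into aot_autograd() above, e.g.
#
#   ts_backend = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
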
def fake_tensor_unsupported(fn):
"""
|
|
Decorator for backends that need real inputs. We swap out fake
|
|
tensors for zero tensors.
|
|
"""
|
|
|
|
def defake(x):
|
|
if not isinstance(x, FakeTensor):
|
|
return x
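        # Dynamic-shapes mode: sizes/strides are SymInts, so resolve them to
        # concrete hints via the ShapeEnv before allocating a real tensor.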
        if x._has_symbolic_sizes_strides:
            size = [s.node.shape_env.size_hint(s.node.expr) for s in x.size()]
            stride = [s.node.shape_env.size_hint(s.node.expr) for s in x.stride()]
        else:
            size = x.size()
            stride = x.stride()
        y = torch.empty_strided(
            size,
            stride,
            dtype=x.dtype,
            device=x.device,
            requires_grad=x.requires_grad,
        )
        y.zero_()
        return y

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper
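
# Usage sketch: a backend that cannot consume FakeTensor inputs (for example,
# one that inspects real data) can be declared with this decorator;
# ``my_backend`` is a hypothetical example:
#
#   @fake_tensor_unsupported
#   def my_backend(gm, example_inputs):
#       ...  # receives real, zero-filled tensors
#       return gm
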
def device_from_inputs(example_inputs) -> torch.device:
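    """Return the device of the first example input that has a ``device`` attribute."""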
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device
def dtype_from_inputs(example_inputs) -> torch.dtype:
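    """Return the dtype of the first example input that has a ``dtype`` attribute."""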
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype