mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Talked to @zou3519 and @ezyang on what the right UX is: tentatively, adding a new dynamo backend is cheap and simple, so it seems worth doing. And longer term, we agreed (?) that it's worth seeing if we can get custom ops sanity asserts to run more automatically, instead of needing a separate backend. Side comment: that actually seems tough: the mode detects secret mutations by cloning every input to every op, running the op, and checking that the data matches between the real input and the cloned input. So I doubt we'll be able to make that behavior always-on? It would need some config at least. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99744 Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
71 lines
2.2 KiB
Python
import functools
|
|
from importlib import import_module
|
|
|
|
from functorch.compile import min_cut_rematerialization_partition, nop
|
|
|
|
import torch
|
|
from torch._functorch.compilers import ts_compile
|
|
from .common import aot_autograd
|
|
from .registry import register_debug_backend as register_backend
|
|
|
|
"""
|
|
This file contains TorchDynamo backends intended for debugging.
|
|
"""
|
|
|
|
|
|
@register_backend
def eager(gm, fake_tensor_inputs):
    """Debug backend that performs no compilation at all.

    The traced FX ``GraphModule`` is handed straight back as the compiled
    artifact, so the captured graph simply runs eagerly. ``fake_tensor_inputs``
    is accepted to match the backend signature but is not used.
    """
    return gm
|
|
|
|
|
|
@register_backend
def eager_debug(gm, fake_tensor_inputs):
    """Debug backend: interpret the FX graph eagerly under ``SchemaCheckMode``.

    The returned callable runs ``gm`` through ``torch.fx.Interpreter`` while
    ``SchemaCheckMode`` is active. The goal is to detect, and raise on, custom
    dispatcher ops whose registered schemas are wrong, e.g. ops that mutate an
    input the schema does not declare as mutable. More debugging checks could
    be added here later.

    NOTE(review): per the accompanying change description, the mode clones
    every op input to detect hidden mutations. Expect noticeable overhead, so
    this backend is for debugging only.
    """
    # Deferred import: only pulled in when this backend is actually selected.
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    def run_with_schema_check(*args):
        # A fresh mode and Interpreter are created for every invocation.
        with SchemaCheckMode():
            interpreter = torch.fx.Interpreter(gm)
            return interpreter.run(*args)

    return run_with_schema_check
|
|
|
|
|
|
@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    """Debug backend (registered as ``"ts"``) that compiles via TorchScript.

    ``torch.jit.script`` compiles the module from its code, so
    ``fake_tensor_inputs`` is not consulted.
    """
    scripted_module = torch.jit.script(gm)
    return scripted_module
|
|
|
|
|
|
# "aot_eager": debugging aid that routes the graph through AOT Autograd but
# uses the no-op ``nop`` forward compiler, so no backend code generation is
# involved. Handy for telling AOT Autograd problems apart from compiler ones.
aot_eager = aot_autograd(fw_compiler=nop)
register_backend(compiler_fn=aot_eager, name="aot_eager")
|
|
|
|
|
|
# "aot_eager_decomp_partition": reproduces TorchInductor's AOT Autograd setup
# (Inductor's decomposition table plus the min-cut rematerialization
# partitioner, the same choices memory_efficient_fusion() makes), but with
# ``nop`` standing in for the Inductor compiler on both forward and backward.
# Comparing it against aot_eager and inductor helps narrow a failure down to
# AOT Autograd, the decomps/partitioner, or Inductor itself.
aot_eager_decomp_partition = aot_autograd(
    fw_compiler=nop,
    bw_compiler=nop,
    # Wrapped in a lambda so torch._inductor is imported lazily, only when the
    # decomposition table is requested.
    decompositions=lambda: (
        import_module("torch._inductor.compile_fx").select_decomp_table()
    ),
    partition_fn=functools.partial(
        min_cut_rematerialization_partition,
        compiler="inductor",
    ),
)
register_backend(
    compiler_fn=aot_eager_decomp_partition,
    name="aot_eager_decomp_partition",
)
|
|
|
|
# "aot_ts": AOT Autograd with the default partitioner and ``ts_compile``
# (TorchScript) as the forward compiler. Either the NNC or the nvFuser fuser
# can be exercised by running under the matching ``torch.jit.fuser(...)``
# context.
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(compiler_fn=aot_ts, name="aot_ts")
|