Add support for Inductor + symbolic shapes + training (#93059)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93059
Approved by: https://github.com/ezyang
This commit is contained in:
Edward Z. Yang 2023-02-28 14:13:33 -05:00 committed by PyTorch MergeBot
parent 70029214f3
commit 20dfce591c
2 changed files with 18 additions and 7 deletions

View File

@ -314,7 +314,7 @@ def dump_compiler_graph_state(gm, args, compiler_name):
def save_graph_repro(fd, gm, args, compiler_name):
sync_line = ""
for arg in args:
if arg.is_cuda:
if isinstance(arg, torch.Tensor) and arg.is_cuda:
sync_line = "torch.cuda.synchronize() # Ensures that segfaults are surfaced"
break

View File

@ -2315,12 +2315,20 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
def call_compiled_backward():
if CompiledFunction.compiled_bw is None:
# TODO - pass in fake tensors ?
context = disable_autocast_manager if disable_amp else nullcontext
with context(), track_graph_compiling(aot_config, "backward"):
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, all_args
if config.use_dynamic_shapes:
all_args_list = list(all_args)
CompiledFunction.compiled_bw = create_aot_dispatcher_function(
bw_module, all_args_list, AOTConfig(
aot_config.bw_compiler, None, None,
aot_config.decompositions, 0, aot_config.aot_id, aot_config.keep_inference_input_mutations
)
)
else:
context = disable_autocast_manager if disable_amp else nullcontext
with context(), track_graph_compiling(aot_config, "backward"):
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, all_args
)
ctx.maybe_clear_saved_tensors()
out = call_func_with_args(
@ -2463,8 +2471,11 @@ def create_aot_dispatcher_function(
def process_inputs(flat_args):
if config.use_fake_tensor or isinstance(fake_mode, FakeTensorMode):
def convert(idx, x):
if shape_env is not None:
from torch._dynamo.source import ConstantSource
if isinstance(x, int):
return shape_env.create_symintnode(shape_env.create_symbol(x, ConstantSource(f"sym_{idx}")), hint=x)
if not isinstance(x, torch.Tensor):
return x
if isinstance(x, FakeTensor):