mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add support for Inductor + symbolic shapes + training (#93059)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93059 Approved by: https://github.com/ezyang
This commit is contained in:
parent
70029214f3
commit
20dfce591c
|
|
@ -314,7 +314,7 @@ def dump_compiler_graph_state(gm, args, compiler_name):
|
|||
def save_graph_repro(fd, gm, args, compiler_name):
|
||||
sync_line = ""
|
||||
for arg in args:
|
||||
if arg.is_cuda:
|
||||
if isinstance(arg, torch.Tensor) and arg.is_cuda:
|
||||
sync_line = "torch.cuda.synchronize() # Ensures that segfaults are surfaced"
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -2315,12 +2315,20 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig):
|
|||
|
||||
def call_compiled_backward():
|
||||
if CompiledFunction.compiled_bw is None:
|
||||
# TODO - pass in fake tensors ?
|
||||
context = disable_autocast_manager if disable_amp else nullcontext
|
||||
with context(), track_graph_compiling(aot_config, "backward"):
|
||||
CompiledFunction.compiled_bw = aot_config.bw_compiler(
|
||||
bw_module, all_args
|
||||
if config.use_dynamic_shapes:
|
||||
all_args_list = list(all_args)
|
||||
CompiledFunction.compiled_bw = create_aot_dispatcher_function(
|
||||
bw_module, all_args_list, AOTConfig(
|
||||
aot_config.bw_compiler, None, None,
|
||||
aot_config.decompositions, 0, aot_config.aot_id, aot_config.keep_inference_input_mutations
|
||||
)
|
||||
)
|
||||
else:
|
||||
context = disable_autocast_manager if disable_amp else nullcontext
|
||||
with context(), track_graph_compiling(aot_config, "backward"):
|
||||
CompiledFunction.compiled_bw = aot_config.bw_compiler(
|
||||
bw_module, all_args
|
||||
)
|
||||
|
||||
ctx.maybe_clear_saved_tensors()
|
||||
out = call_func_with_args(
|
||||
|
|
@ -2463,8 +2471,11 @@ def create_aot_dispatcher_function(
|
|||
|
||||
def process_inputs(flat_args):
|
||||
if config.use_fake_tensor or isinstance(fake_mode, FakeTensorMode):
|
||||
|
||||
def convert(idx, x):
|
||||
if shape_env is not None:
|
||||
from torch._dynamo.source import ConstantSource
|
||||
if isinstance(x, int):
|
||||
return shape_env.create_symintnode(shape_env.create_symbol(x, ConstantSource(f"sym_{idx}")), hint=x)
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return x
|
||||
if isinstance(x, FakeTensor):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user