diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py index c822569f624..8f39435b922 100644 --- a/test/dynamo/test_aot_compile.py +++ b/test/dynamo/test_aot_compile.py @@ -582,6 +582,23 @@ from user code: actual = compiled_fn(fn, *inputs) self.assertEqual(expected, actual) + def test_aot_compile_with_default_args(self): + def fn(x, y=1): + return x + x + + compiled_fn = torch.compile(fn, fullgraph=True).aot_compile( + ((torch.randn(3, 4),), {}) + ) + inputs = (torch.randn(3, 4),) + expected = fn(*inputs) + actual = compiled_fn(*inputs) + self.assertEqual(expected, actual) + compiled_fn.save_compiled_function(self.path()) + with open(self.path(), "rb") as f: + compiled_fn = torch.compiler.load_compiled_function(f) + actual = compiled_fn(*inputs) + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index a1cc8856810..f78e2893df7 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -585,6 +585,16 @@ def forward(self, args_0): test_inputs = input_fn() self.assertEqual(gm(*test_inputs), model(*test_inputs)) + def test_dynamo_graph_capture_default_args(self): + class Module(torch.nn.Module): + def forward(self, x, y=1): + return x + y + + m = Module() + ep = dynamo_graph_capture_for_export(m)(torch.randn(2, 3)) + test_inputs = (torch.randn(2, 3),) + self.assertEqual(ep(*test_inputs), m(*test_inputs)) + if __name__ == "__main__": run_tests() diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py index 4b13d677f5a..000d977d29f 100644 --- a/torch/_dynamo/aot_compile.py +++ b/torch/_dynamo/aot_compile.py @@ -50,6 +50,7 @@ class CompileArtifacts: compiled_fn: SerializableCallable original_code: types.CodeType closure: Optional[tuple[Any, ...]] + argdefs: Optional[tuple[Any, ...]] source_info: "SourceInfo" device_type: str system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current) @@ -111,7 +112,10 @@ class AOTCompiledFunction: } # pyrefly: ignore [read-only] self.fn = types.FunctionType( - self._artifacts.bytecode, f_globals, closure=self._artifacts.closure + self._artifacts.bytecode, + f_globals, + closure=self._artifacts.closure, + argdefs=self._artifacts.argdefs, ) if self._artifacts.guard_manager is None: @@ -266,6 +270,7 @@ def aot_compile_fullgraph( compiled_fn=compiled_fn, original_code=fn.__code__, closure=fn.__closure__, + argdefs=fn.__defaults__, source_info=source_info, device_type=device_type, ) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 74a53c6d9c4..2b81a70eab5 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -882,7 +882,9 @@ class DynamoOutput: strict_error=strict_error, ) - def graph_capture_output(self) -> GraphCaptureOutput: + def graph_capture_output( + self, argdefs: Optional[tuple[Any, ...]] = None + ) -> GraphCaptureOutput: output_graph = self.tracer_output.output_graph assert output_graph is not None return GraphCaptureOutput( @@ -897,6 +899,7 @@ class DynamoOutput: output_graph.traced_code, self.bytecode, self.tracer_output.closure, + argdefs, ) @@ -929,6 +932,7 @@ class GraphCaptureOutput: traced_code: list[CodeType] bytecode: CodeType closure: Optional[tuple[Any, ...]] + argdefs: Optional[tuple[Any, ...]] def build_guards( self, @@ -984,6 +988,7 @@ class CaptureOutput: self.graph_capture_output.bytecode, f_globals, closure=self.graph_capture_output.closure, + argdefs=self.graph_capture_output.argdefs, ) @@ -1044,6 +1049,7 @@ def _get_frame( f_locals, builtins.__dict__, closure=fn.__closure__ or (), # type: ignore[arg-type] + argdefs=fn.__defaults__, ) @@ -1093,6 +1099,7 @@ class FrameInfo: locals: dict[str, object] builtins: dict[str, object] closure: tuple[CellType] + argdefs: Optional[tuple[Any, ...]] def _fullgraph_capture_frame( @@ -1146,7 +1153,7 @@ def _fullgraph_capture_frame( raise e.with_traceback(None) from e.__cause__ # User compiler error return CaptureOutput( - dynamo_output.graph_capture_output(), + dynamo_output.graph_capture_output(frame.argdefs), backend_input, )