[precompile] Preserve default arguments for dynamo capture (#166654)

Summary:
Handle the case where there's default arguments on function signature.

Test Plan:
pytest test/export/test_experimental.py -k test_dynamo_graph_capture_default_args

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166654
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Zhengxu Chen 2025-11-01 00:12:10 +00:00 committed by PyTorch MergeBot
parent 8d599045cf
commit 83cc38d9c1
4 changed files with 42 additions and 3 deletions

View File

@ -582,6 +582,23 @@ from user code:
actual = compiled_fn(fn, *inputs)
self.assertEqual(expected, actual)
def test_aot_compile_with_default_args(self):
def fn(x, y=1):
return x + x
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
((torch.randn(3, 4),), {})
)
inputs = (torch.randn(3, 4),)
expected = fn(*inputs)
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
compiled_fn.save_compiled_function(self.path())
with open(self.path(), "rb") as f:
compiled_fn = torch.compiler.load_compiled_function(f)
actual = compiled_fn(*inputs)
self.assertEqual(expected, actual)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -585,6 +585,16 @@ def forward(self, args_0):
test_inputs = input_fn()
self.assertEqual(gm(*test_inputs), model(*test_inputs))
def test_dynamo_graph_capture_default_args(self):
class Module(torch.nn.Module):
def forward(self, x, y=1):
return x + y
m = Module()
ep = dynamo_graph_capture_for_export(m)(torch.randn(2, 3))
test_inputs = (torch.randn(2, 3),)
self.assertEqual(ep(*test_inputs), m(*test_inputs))
if __name__ == "__main__":
run_tests()

View File

@ -50,6 +50,7 @@ class CompileArtifacts:
compiled_fn: SerializableCallable
original_code: types.CodeType
closure: Optional[tuple[Any, ...]]
argdefs: Optional[tuple[Any, ...]]
source_info: "SourceInfo"
device_type: str
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
@ -111,7 +112,10 @@ class AOTCompiledFunction:
}
# pyrefly: ignore [read-only]
self.fn = types.FunctionType(
self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
self._artifacts.bytecode,
f_globals,
closure=self._artifacts.closure,
argdefs=self._artifacts.argdefs,
)
if self._artifacts.guard_manager is None:
@ -266,6 +270,7 @@ def aot_compile_fullgraph(
compiled_fn=compiled_fn,
original_code=fn.__code__,
closure=fn.__closure__,
argdefs=fn.__defaults__,
source_info=source_info,
device_type=device_type,
)

View File

@ -882,7 +882,9 @@ class DynamoOutput:
strict_error=strict_error,
)
def graph_capture_output(self) -> GraphCaptureOutput:
def graph_capture_output(
self, argdefs: Optional[tuple[Any, ...]] = None
) -> GraphCaptureOutput:
output_graph = self.tracer_output.output_graph
assert output_graph is not None
return GraphCaptureOutput(
@ -897,6 +899,7 @@ class DynamoOutput:
output_graph.traced_code,
self.bytecode,
self.tracer_output.closure,
argdefs,
)
@ -929,6 +932,7 @@ class GraphCaptureOutput:
traced_code: list[CodeType]
bytecode: CodeType
closure: Optional[tuple[Any, ...]]
argdefs: Optional[tuple[Any, ...]]
def build_guards(
self,
@ -984,6 +988,7 @@ class CaptureOutput:
self.graph_capture_output.bytecode,
f_globals,
closure=self.graph_capture_output.closure,
argdefs=self.graph_capture_output.argdefs,
)
@ -1044,6 +1049,7 @@ def _get_frame(
f_locals,
builtins.__dict__,
closure=fn.__closure__ or (), # type: ignore[arg-type]
argdefs=fn.__defaults__,
)
@ -1093,6 +1099,7 @@ class FrameInfo:
locals: dict[str, object]
builtins: dict[str, object]
closure: tuple[CellType]
argdefs: Optional[tuple[Any, ...]]
def _fullgraph_capture_frame(
@ -1146,7 +1153,7 @@ def _fullgraph_capture_frame(
raise e.with_traceback(None) from e.__cause__ # User compiler error
return CaptureOutput(
dynamo_output.graph_capture_output(),
dynamo_output.graph_capture_output(frame.argdefs),
backend_input,
)