mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[precompile] Preserve default arguments for dynamo capture (#166654)
Summary: Handle the case where there's default arguments on function signature. Test Plan: pytest test/export/test_experimental.py -k test_dynamo_graph_capture_default_args Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/166654 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
parent
8d599045cf
commit
83cc38d9c1
|
|
@ -582,6 +582,23 @@ from user code:
|
|||
actual = compiled_fn(fn, *inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_aot_compile_with_default_args(self):
|
||||
def fn(x, y=1):
|
||||
return x + x
|
||||
|
||||
compiled_fn = torch.compile(fn, fullgraph=True).aot_compile(
|
||||
((torch.randn(3, 4),), {})
|
||||
)
|
||||
inputs = (torch.randn(3, 4),)
|
||||
expected = fn(*inputs)
|
||||
actual = compiled_fn(*inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
compiled_fn.save_compiled_function(self.path())
|
||||
with open(self.path(), "rb") as f:
|
||||
compiled_fn = torch.compiler.load_compiled_function(f)
|
||||
actual = compiled_fn(*inputs)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -585,6 +585,16 @@ def forward(self, args_0):
|
|||
test_inputs = input_fn()
|
||||
self.assertEqual(gm(*test_inputs), model(*test_inputs))
|
||||
|
||||
def test_dynamo_graph_capture_default_args(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x, y=1):
|
||||
return x + y
|
||||
|
||||
m = Module()
|
||||
ep = dynamo_graph_capture_for_export(m)(torch.randn(2, 3))
|
||||
test_inputs = (torch.randn(2, 3),)
|
||||
self.assertEqual(ep(*test_inputs), m(*test_inputs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ class CompileArtifacts:
|
|||
compiled_fn: SerializableCallable
|
||||
original_code: types.CodeType
|
||||
closure: Optional[tuple[Any, ...]]
|
||||
argdefs: Optional[tuple[Any, ...]]
|
||||
source_info: "SourceInfo"
|
||||
device_type: str
|
||||
system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
|
||||
|
|
@ -111,7 +112,10 @@ class AOTCompiledFunction:
|
|||
}
|
||||
# pyrefly: ignore [read-only]
|
||||
self.fn = types.FunctionType(
|
||||
self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
|
||||
self._artifacts.bytecode,
|
||||
f_globals,
|
||||
closure=self._artifacts.closure,
|
||||
argdefs=self._artifacts.argdefs,
|
||||
)
|
||||
|
||||
if self._artifacts.guard_manager is None:
|
||||
|
|
@ -266,6 +270,7 @@ def aot_compile_fullgraph(
|
|||
compiled_fn=compiled_fn,
|
||||
original_code=fn.__code__,
|
||||
closure=fn.__closure__,
|
||||
argdefs=fn.__defaults__,
|
||||
source_info=source_info,
|
||||
device_type=device_type,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -882,7 +882,9 @@ class DynamoOutput:
|
|||
strict_error=strict_error,
|
||||
)
|
||||
|
||||
def graph_capture_output(self) -> GraphCaptureOutput:
|
||||
def graph_capture_output(
|
||||
self, argdefs: Optional[tuple[Any, ...]] = None
|
||||
) -> GraphCaptureOutput:
|
||||
output_graph = self.tracer_output.output_graph
|
||||
assert output_graph is not None
|
||||
return GraphCaptureOutput(
|
||||
|
|
@ -897,6 +899,7 @@ class DynamoOutput:
|
|||
output_graph.traced_code,
|
||||
self.bytecode,
|
||||
self.tracer_output.closure,
|
||||
argdefs,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -929,6 +932,7 @@ class GraphCaptureOutput:
|
|||
traced_code: list[CodeType]
|
||||
bytecode: CodeType
|
||||
closure: Optional[tuple[Any, ...]]
|
||||
argdefs: Optional[tuple[Any, ...]]
|
||||
|
||||
def build_guards(
|
||||
self,
|
||||
|
|
@ -984,6 +988,7 @@ class CaptureOutput:
|
|||
self.graph_capture_output.bytecode,
|
||||
f_globals,
|
||||
closure=self.graph_capture_output.closure,
|
||||
argdefs=self.graph_capture_output.argdefs,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1044,6 +1049,7 @@ def _get_frame(
|
|||
f_locals,
|
||||
builtins.__dict__,
|
||||
closure=fn.__closure__ or (), # type: ignore[arg-type]
|
||||
argdefs=fn.__defaults__,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1093,6 +1099,7 @@ class FrameInfo:
|
|||
locals: dict[str, object]
|
||||
builtins: dict[str, object]
|
||||
closure: tuple[CellType]
|
||||
argdefs: Optional[tuple[Any, ...]]
|
||||
|
||||
|
||||
def _fullgraph_capture_frame(
|
||||
|
|
@ -1146,7 +1153,7 @@ def _fullgraph_capture_frame(
|
|||
raise e.with_traceback(None) from e.__cause__ # User compiler error
|
||||
|
||||
return CaptureOutput(
|
||||
dynamo_output.graph_capture_output(),
|
||||
dynamo_output.graph_capture_output(frame.argdefs),
|
||||
backend_input,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user