Make torch_geometric models compatible with export (#123403)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123403
Approved by: https://github.com/angelayi
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2024-04-04 17:08:26 -07:00 committed by PyTorch MergeBot
parent 18c9d46068
commit 2ffab6e663
2 changed files with 20 additions and 3 deletions

View File

@ -1135,12 +1135,12 @@ class AOTInductorModelCache:
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
_register_dataclass_output_as_pytree(example_outputs)
# TODO(angelayi): change this to predispatch
gm = torch.export._trace._export_to_torch_ir(
gm = torch.export._trace._export(
model,
example_args,
example_kwargs,
)
pre_dispatch=True,
).module()
with torch.no_grad():
so_path = torch._inductor.aot_compile(
gm, example_args, example_kwargs

View File

@ -25,6 +25,20 @@ from torch._dynamo.utils import clone_inputs
torch.backends.cuda.matmul.allow_tf32 = True
def _reassign_parameters(model):
# torch_geometric models register parameter as tensors due to
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
# Since it is unusual thing to do, we just reassign them to parameters
def state_dict_hook(module, destination, prefix, local_metadata):
for name, param in module.named_parameters():
if isinstance(destination[name], torch.Tensor) and not isinstance(
destination[name], torch.nn.Parameter
):
destination[name] = torch.nn.Parameter(destination[name])
model._register_state_dict_hook(state_dict_hook)
def setup_torchbench_cwd():
original_dir = abspath(os.getcwd())
@ -265,6 +279,9 @@ class TorchBenchmarkRunner(BenchmarkRunner):
extra_args=extra_args,
)
model, example_inputs = benchmark.get_module()
if model_name in ["basic_gnn_edgecnn", "basic_gnn_gcn", "basic_gnn_sage"]:
_reassign_parameters(model)
# Models that must be in train mode while training
if is_training and (
not use_eval_mode or model_name in self._config["only_training"]