mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make torch_geometric models compatible with export (#123403)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123403 Approved by: https://github.com/angelayi
This commit is contained in:
parent
18c9d46068
commit
2ffab6e663
|
|
@ -1135,12 +1135,12 @@ class AOTInductorModelCache:
|
|||
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
|
||||
_register_dataclass_output_as_pytree(example_outputs)
|
||||
|
||||
# TODO(angelayi): change this to predispatch
|
||||
gm = torch.export._trace._export_to_torch_ir(
|
||||
gm = torch.export._trace._export(
|
||||
model,
|
||||
example_args,
|
||||
example_kwargs,
|
||||
)
|
||||
pre_dispatch=True,
|
||||
).module()
|
||||
with torch.no_grad():
|
||||
so_path = torch._inductor.aot_compile(
|
||||
gm, example_args, example_kwargs
|
||||
|
|
|
|||
|
|
@ -25,6 +25,20 @@ from torch._dynamo.utils import clone_inputs
|
|||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
|
||||
def _reassign_parameters(model):
|
||||
# torch_geometric models register parameter as tensors due to
|
||||
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
|
||||
# Since it is unusual thing to do, we just reassign them to parameters
|
||||
def state_dict_hook(module, destination, prefix, local_metadata):
|
||||
for name, param in module.named_parameters():
|
||||
if isinstance(destination[name], torch.Tensor) and not isinstance(
|
||||
destination[name], torch.nn.Parameter
|
||||
):
|
||||
destination[name] = torch.nn.Parameter(destination[name])
|
||||
|
||||
model._register_state_dict_hook(state_dict_hook)
|
||||
|
||||
|
||||
def setup_torchbench_cwd():
|
||||
original_dir = abspath(os.getcwd())
|
||||
|
||||
|
|
@ -265,6 +279,9 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
|||
extra_args=extra_args,
|
||||
)
|
||||
model, example_inputs = benchmark.get_module()
|
||||
if model_name in ["basic_gnn_edgecnn", "basic_gnn_gcn", "basic_gnn_sage"]:
|
||||
_reassign_parameters(model)
|
||||
|
||||
# Models that must be in train mode while training
|
||||
if is_training and (
|
||||
not use_eval_mode or model_name in self._config["only_training"]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user