mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[JIT SSA] Allow updating shape functions without recompilation (#83629)
In order to avoid extra round trips, and avoid confusion in places such as this to manually pull in the latest copy of the shape_functions.py file This also fixes the cases where people pull in the wrong version of the file. This can happen in cases such as when developers run `python setup.py install` instead of `python setup.py develop` to generate their current copy of Pytorch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/83629 Approved by: https://github.com/davidberard98
This commit is contained in:
parent
53cda905be
commit
eff28d61c9
|
|
@ -1,12 +1,34 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from torch.jit._shape_functions import (
|
|
||||||
bounded_compute_graph_mapping,
|
# Manually importing the shape function module based on current directory
|
||||||
shape_compute_graph_mapping,
|
# instead of torch imports to avoid needing to recompile Pytorch before
|
||||||
)
|
# running the script
|
||||||
|
|
||||||
|
file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py"
|
||||||
|
module_name = "torch.jit._shape_functions"
|
||||||
|
|
||||||
|
err_msg = """Could not find shape functions file, please make sure
|
||||||
|
you are in the root directory of the Pytorch git repo"""
|
||||||
|
if not file_path.exists():
|
||||||
|
raise Exception(err_msg)
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
|
assert spec is not None
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
assert spec.loader is not None
|
||||||
|
assert module is not None
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
|
||||||
|
shape_compute_graph_mapping = module.shape_compute_graph_mapping
|
||||||
|
|
||||||
|
|
||||||
SHAPE_HEADER = r"""
|
SHAPE_HEADER = r"""
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user