[torchgen] Refactor torchgen.utils.FileManager to accept pathlib.Path (#150726)

This PR allows `FileManager` to accept `pathlib.Path` objects as arguments while keeping the original `str` path support.
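
A minimal sketch of the resulting call pattern (the directory names below are hypothetical; per the new signature, `install_dir` and `template_dir` accept either type and are normalized with `Path(...)` internally):

```python
from pathlib import Path

from torchgen.utils import FileManager

# str paths keep working as before
fm = FileManager(install_dir="build/generated", template_dir="templates", dry_run=True)

# pathlib.Path is now accepted as well
fm = FileManager(
    install_dir=Path("build") / "generated",
    template_dir=Path("templates"),
    dry_run=True,
)
```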

This allows us to simplify code patterns such as the following (a combined sketch appears after this list):

1. `os.path.join(..., ...)` with the `/` operator (`Path.__truediv__`).

95a5958db4/torchgen/utils.py (L155)

95a5958db4/torchgen/utils.py (L176)

2. `os.path.basename(...)` with `Path(...).name`.

95a5958db4/torchgen/utils.py (L161)

3. Manual file-extension splitting with `Path(...).with_stem(new_stem)`.

95a5958db4/torchgen/utils.py (L241-L256)
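
A quick sketch of these three replacements in isolation (illustrative file names only, not the actual torchgen call sites; `Path.with_stem` requires Python 3.9+):

```python
import os.path
from pathlib import Path

template_dir = Path("templates")  # hypothetical directory
fname = "ViewFuncs.cpp"           # hypothetical file name

# 1. os.path.join(...)  ->  the / operator (Path.__truediv__)
assert template_dir / fname == Path(os.path.join(template_dir, fname))

# 2. os.path.basename(...)  ->  Path(...).name
assert Path("a/b/ViewFuncs.cpp").name == os.path.basename("a/b/ViewFuncs.cpp")

# 3. manual "stem + extension" splitting  ->  Path(...).with_stem(new_stem)
assert Path("ViewFuncs.cpp").with_stem("ViewFuncs_0") == Path("ViewFuncs_0.cpp")
```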

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150726
Approved by: https://github.com/aorenste
commit 014726d9d3 (parent 881a598a1e)
Author: Xuehai Pan
Date: 2025-05-15 01:40:19 +08:00
Committed by: PyTorch MergeBot
12 changed files with 110 additions and 74 deletions

View File

@@ -555,8 +555,7 @@ def gen_autograd_functions_lib(
             fname,
             lambda: {
                 "generated_comment": "@"
-                + f"generated from {fm.template_dir_for_comments()}/"
-                + fname,
+                + f"generated from {fm.template_dir_for_comments()}/{fname}",
                 "autograd_function_declarations": declarations,
                 "autograd_function_definitions": definitions,
             },

View File

@@ -331,8 +331,7 @@ def gen_view_funcs(
             fname,
             lambda: {
                 "generated_comment": "@"
-                + f"generated from {fm.template_dir_for_comments()}/"
-                + fname,
+                + f"generated from {fm.template_dir_for_comments()}/{fname}",
                 "view_func_declarations": declarations,
                 "view_func_definitions": definitions,
                 "ops_headers": ops_headers,

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from typing_extensions import assert_never
 
 from torchgen import local
 from torchgen.api.types import (
@@ -48,7 +49,6 @@ from torchgen.model import (
     TensorOptionsArguments,
     Type,
 )
-from torchgen.utils import assert_never
 
 
 if TYPE_CHECKING:

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import itertools
 from typing import TYPE_CHECKING
+from typing_extensions import assert_never
 
 from torchgen.api import cpp
 from torchgen.api.types import ArgName, Binding, CType, NamedCType
@@ -13,7 +14,7 @@ from torchgen.model import (
     TensorOptionsArguments,
     Type,
 )
-from torchgen.utils import assert_never, concatMap
+from torchgen.utils import concatMap
 
 
 if TYPE_CHECKING:

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from typing_extensions import assert_never
 
 from torchgen import local
 from torchgen.api import cpp
@@ -29,7 +30,6 @@ from torchgen.model import (
     TensorOptionsArguments,
     Type,
 )
-from torchgen.utils import assert_never
 
 
 if TYPE_CHECKING:

View File

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing_extensions import assert_never
+
 from torchgen.api import cpp
 from torchgen.api.types import (
     ArgName,
@@ -30,7 +32,6 @@ from torchgen.model import (
     TensorOptionsArguments,
     Type,
 )
-from torchgen.utils import assert_never
 
 
 # This file describes the translation of JIT schema to the structured functions API.

View File

@@ -4,6 +4,7 @@ import itertools
 import textwrap
 from dataclasses import dataclass
 from typing import Literal, TYPE_CHECKING
+from typing_extensions import assert_never
 
 import torchgen.api.cpp as cpp
 import torchgen.api.meta as meta
@@ -36,7 +37,7 @@ from torchgen.model import (
     SchemaKind,
     TensorOptionsArguments,
 )
-from torchgen.utils import assert_never, mapMaybe, Target
+from torchgen.utils import mapMaybe, Target
 
 
 if TYPE_CHECKING:

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from typing_extensions import assert_never
 
 from torchgen import local
 from torchgen.api.types import (
@@ -37,7 +38,6 @@ from torchgen.model import (
     TensorOptionsArguments,
     Type,
 )
-from torchgen.utils import assert_never
 
 
 if TYPE_CHECKING:

View File

@@ -7,6 +7,7 @@ import itertools
 from collections import defaultdict, namedtuple
 from dataclasses import dataclass
 from enum import IntEnum
+from typing_extensions import assert_never
 
 from torchgen.model import (
     BackendIndex,
@@ -16,7 +17,6 @@ from torchgen.model import (
     NativeFunctionsGroup,
     OperatorName,
 )
-from torchgen.utils import assert_never
 
 
 KERNEL_KEY_VERSION = 1

View File

@@ -9,6 +9,7 @@ from collections import defaultdict, namedtuple, OrderedDict
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
+from typing_extensions import assert_never
 
 import yaml
 
@@ -84,7 +85,6 @@ from torchgen.native_function_generation import (
 )
 from torchgen.selective_build.selector import SelectiveBuilder
 from torchgen.utils import (
-    assert_never,
     concatMap,
     context,
     FileManager,

View File

@@ -6,8 +6,9 @@ import re
 from dataclasses import dataclass
 from enum import auto, Enum
 from typing import Callable, Optional, TYPE_CHECKING
+from typing_extensions import assert_never
 
-from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
+from torchgen.utils import NamespaceHelper, OrderedSet
 
 
 if TYPE_CHECKING:

View File

@@ -11,7 +11,7 @@ from dataclasses import fields, is_dataclass
 from enum import auto, Enum
 from pathlib import Path
 from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
-from typing_extensions import Self
+from typing_extensions import assert_never, deprecated, Self
 
 from torchgen.code_template import CodeTemplate
 
@@ -21,7 +21,8 @@ if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Sequence
 
 
-REPO_ROOT = Path(__file__).absolute().parent.parent
+TORCHGEN_ROOT = Path(__file__).absolute().parent
+REPO_ROOT = TORCHGEN_ROOT.parent
 
 
 # Many of these functions share logic for defining both the definition
@@ -96,11 +97,13 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]:
         raise
 
 
-# A little trick from https://github.com/python/mypy/issues/6366
-# for getting mypy to do exhaustiveness checking
-# TODO: put this somewhere else, maybe
-def assert_never(x: NoReturn) -> NoReturn:
-    raise AssertionError(f"Unhandled type: {type(x).__name__}")
+if TYPE_CHECKING:
+    # A little trick from https://github.com/python/mypy/issues/6366
+    # for getting mypy to do exhaustiveness checking
+    # TODO: put this somewhere else, maybe
+    @deprecated("Use typing_extensions.assert_never instead")
+    def assert_never(x: NoReturn) -> NoReturn:  # type: ignore[misc] # noqa: F811
+        raise AssertionError(f"Unhandled type: {type(x).__name__}")
 
 
 @functools.cache
@@ -118,39 +121,47 @@ def string_stable_hash(s: str) -> int:
 # of what files have been written (so you can write out a list of output
 # files)
 class FileManager:
-    install_dir: str
-    template_dir: str
-    dry_run: bool
-    filenames: set[str]
-
-    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
-        self.install_dir = install_dir
-        self.template_dir = template_dir
-        self.filenames = set()
+    def __init__(
+        self,
+        install_dir: str | Path,
+        template_dir: str | Path,
+        dry_run: bool,
+    ) -> None:
+        self.install_dir = Path(install_dir)
+        self.template_dir = Path(template_dir)
+        self.files: set[Path] = set()
         self.dry_run = dry_run
 
-    def _write_if_changed(self, filename: str, contents: str) -> None:
-        old_contents: str | None
+    @property
+    def filenames(self) -> frozenset[str]:
+        return frozenset({file.as_posix() for file in self.files})
+
+    def _write_if_changed(self, filename: str | Path, contents: str) -> None:
+        file = Path(filename)
+        old_contents: str | None = None
         try:
-            with open(filename) as f:
-                old_contents = f.read()
+            old_contents = file.read_text(encoding="utf-8")
         except OSError:
-            old_contents = None
+            pass
         if contents != old_contents:
             # Create output directory if it doesn't exist
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            with open(filename, "w") as f:
-                f.write(contents)
+            file.parent.mkdir(parents=True, exist_ok=True)
+            file.write_text(contents, encoding="utf-8")
 
     # Read from template file and replace pattern with callable (type could be dict or str).
     def substitute_with_template(
-        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
+        self,
+        template_fn: str | Path,
+        env_callable: Callable[[], str | dict[str, Any]],
     ) -> str:
-        template_path = os.path.join(self.template_dir, template_fn)
+        assert not Path(template_fn).is_absolute(), (
+            f"template_fn must be relative: {template_fn}"
+        )
+        template_path = self.template_dir / template_fn
         env = env_callable()
         if isinstance(env, dict):
             if "generated_comment" not in env:
-                generator_default = REPO_ROOT / "torchgen" / "gen.py"
+                generator_default = TORCHGEN_ROOT / "gen.py"
                 try:
                     generator = Path(
                         sys.modules["__main__"].__file__ or generator_default
@@ -170,38 +181,56 @@ class FileManager:
                     ),
                 }
             template = _read_template(template_path)
-            return template.substitute(env)
-        elif isinstance(env, str):
+            substitute_out = template.substitute(env)
+            # Ensure an extra blank line between the class/function definition
+            # and the docstring of the previous class/function definition.
+            # NB: It is generally not recommended to have docstrings in pyi stub
+            # files. But if there are any, we need to ensure that the file
+            # is properly formatted.
+            return re.sub(
+                r'''
+                (""")\n+  # match triple quotes
+                (
+                    (\s*@.+\n)*  # match decorators if any
+                    \s*(class|def)  # match class/function definition
+                )
+                ''',
+                r"\g<1>\n\n\g<2>",
+                substitute_out,
+                flags=re.VERBOSE,
+            )
+        if isinstance(env, str):
             return env
-        else:
-            assert_never(env)
+        assert_never(env)
 
     def write_with_template(
         self,
-        filename: str,
-        template_fn: str,
+        filename: str | Path,
+        template_fn: str | Path,
         env_callable: Callable[[], str | dict[str, Any]],
     ) -> None:
-        filename = f"{self.install_dir}/{filename}"
-        assert filename not in self.filenames, "duplicate file write {filename}"
-        self.filenames.add(filename)
+        filename = Path(filename)
+        assert not filename.is_absolute(), f"filename must be relative: {filename}"
+        file = self.install_dir / filename
+        assert file not in self.files, f"duplicate file write {file}"
+        self.files.add(file)
         if not self.dry_run:
             substitute_out = self.substitute_with_template(
                 template_fn=template_fn,
                 env_callable=env_callable,
             )
-            self._write_if_changed(filename=filename, contents=substitute_out)
+            self._write_if_changed(filename=file, contents=substitute_out)
 
     def write(
         self,
-        filename: str,
+        filename: str | Path,
         env_callable: Callable[[], str | dict[str, Any]],
     ) -> None:
         self.write_with_template(filename, filename, env_callable)
 
     def write_sharded(
         self,
-        filename: str,
+        filename: str | Path,
         items: Iterable[T],
         *,
         key_fn: Callable[[T], str],
@@ -223,8 +252,8 @@ class FileManager:
 
     def write_sharded_with_template(
         self,
-        filename: str,
-        template_fn: str,
+        filename: str | Path,
+        template_fn: str | Path,
         items: Iterable[T],
         *,
         key_fn: Callable[[T], str],
@@ -233,6 +262,8 @@ class FileManager:
         base_env: dict[str, Any] | None = None,
         sharded_keys: set[str],
     ) -> None:
+        file = Path(filename)
+        assert not file.is_absolute(), f"filename must be relative: {filename}"
         everything: dict[str, Any] = {"shard_id": "Everything"}
         shards: list[dict[str, Any]] = [
             {"shard_id": f"_{i}"} for i in range(num_shards)
@@ -270,31 +301,27 @@ class FileManager:
             merge_env(shards[sid], env)
             merge_env(everything, env)
 
-        dot_pos = filename.rfind(".")
-        if dot_pos == -1:
-            dot_pos = len(filename)
-        base_filename = filename[:dot_pos]
-        extension = filename[dot_pos:]
-
         for shard in all_shards:
             shard_id = shard["shard_id"]
             self.write_with_template(
-                f"{base_filename}{shard_id}{extension}",
+                file.with_stem(f"{file.stem}{shard_id}"),
                 template_fn,
                 lambda: shard,
             )
 
         # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
-        self.filenames.discard(
-            f"{self.install_dir}/{base_filename}Everything{extension}"
-        )
+        self.files.discard(self.install_dir / file.with_stem(f"{file.stem}Everything"))
 
-    def write_outputs(self, variable_name: str, filename: str) -> None:
-        """Write a file containing the list of all outputs which are
-        generated by this script."""
-        content = "set({}\n    {})".format(
-            variable_name,
-            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
-        )
+    def write_outputs(self, variable_name: str, filename: str | Path) -> None:
+        """Write a file containing the list of all outputs which are generated by this script."""
+        content = "\n".join(
+            (
+                "set(",
+                variable_name,
+                # Use POSIX paths to avoid invalid escape sequences on Windows
+                *(f'    "{file.as_posix()}"' for file in sorted(self.files)),
+                ")",
+            )
+        )
         self._write_if_changed(filename, content)
 
@@ -309,12 +336,15 @@
 
 # Helper function to generate file manager
 def make_file_manager(
-    options: Namespace, install_dir: str | None = None
+    options: Namespace,
+    install_dir: str | Path | None = None,
 ) -> FileManager:
     template_dir = os.path.join(options.source_path, "templates")
     install_dir = install_dir if install_dir else options.install_dir
     return FileManager(
-        install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
+        install_dir=install_dir,
+        template_dir=template_dir,
+        dry_run=options.dry_run,
     )
 
@@ -437,7 +467,10 @@
     """
 
     def __init__(
-        self, namespace_str: str, entity_name: str = "", max_level: int = 2
+        self,
+        namespace_str: str,
+        entity_name: str = "",
+        max_level: int = 2,
     ) -> None:
         # cpp_namespace can be a colon joined string such as torch::lazy
         cpp_namespaces = namespace_str.split("::")
@@ -454,7 +487,8 @@
 
     @staticmethod
     def from_namespaced_entity(
-        namespaced_entity: str, max_level: int = 2
+        namespaced_entity: str,
+        max_level: int = 2,
    ) -> NamespaceHelper:
         """
         Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"