mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[torchgen] Refactor torchgen.utils.FileManager to accept pathlib.Path (#150726)
This PR allows `FileManager` to accept `pathlib.Path` as arguments while keeping the original `str` path support. This allows us to simplify the code such as: 1. `os.path.join(..., ...)` with `Path.__floordiv__(..., ...)`.95a5958db4/torchgen/utils.py (L155)95a5958db4/torchgen/utils.py (L176)2. `os.path.basename(...)` with `Path(...).name`.95a5958db4/torchgen/utils.py (L161)3. Manual file extension split with `Path(...).with_stem(new_stem)`95a5958db4/torchgen/utils.py (L241-L256)------ Pull Request resolved: https://github.com/pytorch/pytorch/pull/150726 Approved by: https://github.com/aorenste
This commit is contained in:
parent
881a598a1e
commit
014726d9d3
|
|
@ -555,8 +555,7 @@ def gen_autograd_functions_lib(
|
||||||
fname,
|
fname,
|
||||||
lambda: {
|
lambda: {
|
||||||
"generated_comment": "@"
|
"generated_comment": "@"
|
||||||
+ f"generated from {fm.template_dir_for_comments()}/"
|
+ f"generated from {fm.template_dir_for_comments()}/{fname}",
|
||||||
+ fname,
|
|
||||||
"autograd_function_declarations": declarations,
|
"autograd_function_declarations": declarations,
|
||||||
"autograd_function_definitions": definitions,
|
"autograd_function_definitions": definitions,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -331,8 +331,7 @@ def gen_view_funcs(
|
||||||
fname,
|
fname,
|
||||||
lambda: {
|
lambda: {
|
||||||
"generated_comment": "@"
|
"generated_comment": "@"
|
||||||
+ f"generated from {fm.template_dir_for_comments()}/"
|
+ f"generated from {fm.template_dir_for_comments()}/{fname}",
|
||||||
+ fname,
|
|
||||||
"view_func_declarations": declarations,
|
"view_func_declarations": declarations,
|
||||||
"view_func_definitions": definitions,
|
"view_func_definitions": definitions,
|
||||||
"ops_headers": ops_headers,
|
"ops_headers": ops_headers,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from torchgen import local
|
from torchgen import local
|
||||||
from torchgen.api.types import (
|
from torchgen.api.types import (
|
||||||
|
|
@ -48,7 +49,6 @@ from torchgen.model import (
|
||||||
TensorOptionsArguments,
|
TensorOptionsArguments,
|
||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
from torchgen.utils import assert_never
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from torchgen.api import cpp
|
from torchgen.api import cpp
|
||||||
from torchgen.api.types import ArgName, Binding, CType, NamedCType
|
from torchgen.api.types import ArgName, Binding, CType, NamedCType
|
||||||
|
|
@ -13,7 +14,7 @@ from torchgen.model import (
|
||||||
TensorOptionsArguments,
|
TensorOptionsArguments,
|
||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
from torchgen.utils import assert_never, concatMap
|
from torchgen.utils import concatMap
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from torchgen import local
|
from torchgen import local
|
||||||
from torchgen.api import cpp
|
from torchgen.api import cpp
|
||||||
|
|
@ -29,7 +30,6 @@ from torchgen.model import (
|
||||||
TensorOptionsArguments,
|
TensorOptionsArguments,
|
||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
from torchgen.utils import assert_never
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from torchgen.api import cpp
|
from torchgen.api import cpp
|
||||||
from torchgen.api.types import (
|
from torchgen.api.types import (
|
||||||
ArgName,
|
ArgName,
|
||||||
|
|
@ -30,7 +32,6 @@ from torchgen.model import (
|
||||||
TensorOptionsArguments,
|
TensorOptionsArguments,
|
||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
from torchgen.utils import assert_never
|
|
||||||
|
|
||||||
|
|
||||||
# This file describes the translation of JIT schema to the structured functions API.
|
# This file describes the translation of JIT schema to the structured functions API.
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import itertools
|
||||||
import textwrap
|
import textwrap
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal, TYPE_CHECKING
|
from typing import Literal, TYPE_CHECKING
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
import torchgen.api.cpp as cpp
|
import torchgen.api.cpp as cpp
|
||||||
import torchgen.api.meta as meta
|
import torchgen.api.meta as meta
|
||||||
|
|
@ -36,7 +37,7 @@ from torchgen.model import (
|
||||||
SchemaKind,
|
SchemaKind,
|
||||||
TensorOptionsArguments,
|
TensorOptionsArguments,
|
||||||
)
|
)
|
||||||
from torchgen.utils import assert_never, mapMaybe, Target
|
from torchgen.utils import mapMaybe, Target
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from torchgen import local
|
from torchgen import local
|
||||||
from torchgen.api.types import (
|
from torchgen.api.types import (
|
||||||
|
|
@ -37,7 +38,6 @@ from torchgen.model import (
|
||||||
TensorOptionsArguments,
|
TensorOptionsArguments,
|
||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
from torchgen.utils import assert_never
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import itertools
|
||||||
from collections import defaultdict, namedtuple
|
from collections import defaultdict, namedtuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from torchgen.model import (
|
from torchgen.model import (
|
||||||
BackendIndex,
|
BackendIndex,
|
||||||
|
|
@ -16,7 +17,6 @@ from torchgen.model import (
|
||||||
NativeFunctionsGroup,
|
NativeFunctionsGroup,
|
||||||
OperatorName,
|
OperatorName,
|
||||||
)
|
)
|
||||||
from torchgen.utils import assert_never
|
|
||||||
|
|
||||||
|
|
||||||
KERNEL_KEY_VERSION = 1
|
KERNEL_KEY_VERSION = 1
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from collections import defaultdict, namedtuple, OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
|
from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
@ -84,7 +85,6 @@ from torchgen.native_function_generation import (
|
||||||
)
|
)
|
||||||
from torchgen.selective_build.selector import SelectiveBuilder
|
from torchgen.selective_build.selector import SelectiveBuilder
|
||||||
from torchgen.utils import (
|
from torchgen.utils import (
|
||||||
assert_never,
|
|
||||||
concatMap,
|
concatMap,
|
||||||
context,
|
context,
|
||||||
FileManager,
|
FileManager,
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,9 @@ import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import Callable, Optional, TYPE_CHECKING
|
from typing import Callable, Optional, TYPE_CHECKING
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
|
from torchgen.utils import NamespaceHelper, OrderedSet
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from dataclasses import fields, is_dataclass
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
|
from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
|
||||||
from typing_extensions import Self
|
from typing_extensions import assert_never, deprecated, Self
|
||||||
|
|
||||||
from torchgen.code_template import CodeTemplate
|
from torchgen.code_template import CodeTemplate
|
||||||
|
|
||||||
|
|
@ -21,7 +21,8 @@ if TYPE_CHECKING:
|
||||||
from collections.abc import Iterable, Iterator, Sequence
|
from collections.abc import Iterable, Iterator, Sequence
|
||||||
|
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).absolute().parent.parent
|
TORCHGEN_ROOT = Path(__file__).absolute().parent
|
||||||
|
REPO_ROOT = TORCHGEN_ROOT.parent
|
||||||
|
|
||||||
|
|
||||||
# Many of these functions share logic for defining both the definition
|
# Many of these functions share logic for defining both the definition
|
||||||
|
|
@ -96,11 +97,13 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# A little trick from https://github.com/python/mypy/issues/6366
|
if TYPE_CHECKING:
|
||||||
# for getting mypy to do exhaustiveness checking
|
# A little trick from https://github.com/python/mypy/issues/6366
|
||||||
# TODO: put this somewhere else, maybe
|
# for getting mypy to do exhaustiveness checking
|
||||||
def assert_never(x: NoReturn) -> NoReturn:
|
# TODO: put this somewhere else, maybe
|
||||||
raise AssertionError(f"Unhandled type: {type(x).__name__}")
|
@deprecated("Use typing_extensions.assert_never instead")
|
||||||
|
def assert_never(x: NoReturn) -> NoReturn: # type: ignore[misc] # noqa: F811
|
||||||
|
raise AssertionError(f"Unhandled type: {type(x).__name__}")
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
|
|
@ -118,39 +121,47 @@ def string_stable_hash(s: str) -> int:
|
||||||
# of what files have been written (so you can write out a list of output
|
# of what files have been written (so you can write out a list of output
|
||||||
# files)
|
# files)
|
||||||
class FileManager:
|
class FileManager:
|
||||||
install_dir: str
|
def __init__(
|
||||||
template_dir: str
|
self,
|
||||||
dry_run: bool
|
install_dir: str | Path,
|
||||||
filenames: set[str]
|
template_dir: str | Path,
|
||||||
|
dry_run: bool,
|
||||||
def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
|
) -> None:
|
||||||
self.install_dir = install_dir
|
self.install_dir = Path(install_dir)
|
||||||
self.template_dir = template_dir
|
self.template_dir = Path(template_dir)
|
||||||
self.filenames = set()
|
self.files: set[Path] = set()
|
||||||
self.dry_run = dry_run
|
self.dry_run = dry_run
|
||||||
|
|
||||||
def _write_if_changed(self, filename: str, contents: str) -> None:
|
@property
|
||||||
old_contents: str | None
|
def filenames(self) -> frozenset[str]:
|
||||||
|
return frozenset({file.as_posix() for file in self.files})
|
||||||
|
|
||||||
|
def _write_if_changed(self, filename: str | Path, contents: str) -> None:
|
||||||
|
file = Path(filename)
|
||||||
|
old_contents: str | None = None
|
||||||
try:
|
try:
|
||||||
with open(filename) as f:
|
old_contents = file.read_text(encoding="utf-8")
|
||||||
old_contents = f.read()
|
|
||||||
except OSError:
|
except OSError:
|
||||||
old_contents = None
|
pass
|
||||||
if contents != old_contents:
|
if contents != old_contents:
|
||||||
# Create output directory if it doesn't exist
|
# Create output directory if it doesn't exist
|
||||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(filename, "w") as f:
|
file.write_text(contents, encoding="utf-8")
|
||||||
f.write(contents)
|
|
||||||
|
|
||||||
# Read from template file and replace pattern with callable (type could be dict or str).
|
# Read from template file and replace pattern with callable (type could be dict or str).
|
||||||
def substitute_with_template(
|
def substitute_with_template(
|
||||||
self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
|
self,
|
||||||
|
template_fn: str | Path,
|
||||||
|
env_callable: Callable[[], str | dict[str, Any]],
|
||||||
) -> str:
|
) -> str:
|
||||||
template_path = os.path.join(self.template_dir, template_fn)
|
assert not Path(template_fn).is_absolute(), (
|
||||||
|
f"template_fn must be relative: {template_fn}"
|
||||||
|
)
|
||||||
|
template_path = self.template_dir / template_fn
|
||||||
env = env_callable()
|
env = env_callable()
|
||||||
if isinstance(env, dict):
|
if isinstance(env, dict):
|
||||||
if "generated_comment" not in env:
|
if "generated_comment" not in env:
|
||||||
generator_default = REPO_ROOT / "torchgen" / "gen.py"
|
generator_default = TORCHGEN_ROOT / "gen.py"
|
||||||
try:
|
try:
|
||||||
generator = Path(
|
generator = Path(
|
||||||
sys.modules["__main__"].__file__ or generator_default
|
sys.modules["__main__"].__file__ or generator_default
|
||||||
|
|
@ -170,38 +181,56 @@ class FileManager:
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
template = _read_template(template_path)
|
template = _read_template(template_path)
|
||||||
return template.substitute(env)
|
substitute_out = template.substitute(env)
|
||||||
elif isinstance(env, str):
|
# Ensure an extra blank line between the class/function definition
|
||||||
|
# and the docstring of the previous class/function definition.
|
||||||
|
# NB: It is generally not recommended to have docstrings in pyi stub
|
||||||
|
# files. But if there are any, we need to ensure that the file
|
||||||
|
# is properly formatted.
|
||||||
|
return re.sub(
|
||||||
|
r'''
|
||||||
|
(""")\n+ # match triple quotes
|
||||||
|
(
|
||||||
|
(\s*@.+\n)* # match decorators if any
|
||||||
|
\s*(class|def) # match class/function definition
|
||||||
|
)
|
||||||
|
''',
|
||||||
|
r"\g<1>\n\n\g<2>",
|
||||||
|
substitute_out,
|
||||||
|
flags=re.VERBOSE,
|
||||||
|
)
|
||||||
|
if isinstance(env, str):
|
||||||
return env
|
return env
|
||||||
else:
|
assert_never(env)
|
||||||
assert_never(env)
|
|
||||||
|
|
||||||
def write_with_template(
|
def write_with_template(
|
||||||
self,
|
self,
|
||||||
filename: str,
|
filename: str | Path,
|
||||||
template_fn: str,
|
template_fn: str | Path,
|
||||||
env_callable: Callable[[], str | dict[str, Any]],
|
env_callable: Callable[[], str | dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
filename = f"{self.install_dir}/{filename}"
|
filename = Path(filename)
|
||||||
assert filename not in self.filenames, "duplicate file write {filename}"
|
assert not filename.is_absolute(), f"filename must be relative: {filename}"
|
||||||
self.filenames.add(filename)
|
file = self.install_dir / filename
|
||||||
|
assert file not in self.files, f"duplicate file write {file}"
|
||||||
|
self.files.add(file)
|
||||||
if not self.dry_run:
|
if not self.dry_run:
|
||||||
substitute_out = self.substitute_with_template(
|
substitute_out = self.substitute_with_template(
|
||||||
template_fn=template_fn,
|
template_fn=template_fn,
|
||||||
env_callable=env_callable,
|
env_callable=env_callable,
|
||||||
)
|
)
|
||||||
self._write_if_changed(filename=filename, contents=substitute_out)
|
self._write_if_changed(filename=file, contents=substitute_out)
|
||||||
|
|
||||||
def write(
|
def write(
|
||||||
self,
|
self,
|
||||||
filename: str,
|
filename: str | Path,
|
||||||
env_callable: Callable[[], str | dict[str, Any]],
|
env_callable: Callable[[], str | dict[str, Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.write_with_template(filename, filename, env_callable)
|
self.write_with_template(filename, filename, env_callable)
|
||||||
|
|
||||||
def write_sharded(
|
def write_sharded(
|
||||||
self,
|
self,
|
||||||
filename: str,
|
filename: str | Path,
|
||||||
items: Iterable[T],
|
items: Iterable[T],
|
||||||
*,
|
*,
|
||||||
key_fn: Callable[[T], str],
|
key_fn: Callable[[T], str],
|
||||||
|
|
@ -223,8 +252,8 @@ class FileManager:
|
||||||
|
|
||||||
def write_sharded_with_template(
|
def write_sharded_with_template(
|
||||||
self,
|
self,
|
||||||
filename: str,
|
filename: str | Path,
|
||||||
template_fn: str,
|
template_fn: str | Path,
|
||||||
items: Iterable[T],
|
items: Iterable[T],
|
||||||
*,
|
*,
|
||||||
key_fn: Callable[[T], str],
|
key_fn: Callable[[T], str],
|
||||||
|
|
@ -233,6 +262,8 @@ class FileManager:
|
||||||
base_env: dict[str, Any] | None = None,
|
base_env: dict[str, Any] | None = None,
|
||||||
sharded_keys: set[str],
|
sharded_keys: set[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
file = Path(filename)
|
||||||
|
assert not file.is_absolute(), f"filename must be relative: {filename}"
|
||||||
everything: dict[str, Any] = {"shard_id": "Everything"}
|
everything: dict[str, Any] = {"shard_id": "Everything"}
|
||||||
shards: list[dict[str, Any]] = [
|
shards: list[dict[str, Any]] = [
|
||||||
{"shard_id": f"_{i}"} for i in range(num_shards)
|
{"shard_id": f"_{i}"} for i in range(num_shards)
|
||||||
|
|
@ -270,31 +301,27 @@ class FileManager:
|
||||||
merge_env(shards[sid], env)
|
merge_env(shards[sid], env)
|
||||||
merge_env(everything, env)
|
merge_env(everything, env)
|
||||||
|
|
||||||
dot_pos = filename.rfind(".")
|
|
||||||
if dot_pos == -1:
|
|
||||||
dot_pos = len(filename)
|
|
||||||
base_filename = filename[:dot_pos]
|
|
||||||
extension = filename[dot_pos:]
|
|
||||||
|
|
||||||
for shard in all_shards:
|
for shard in all_shards:
|
||||||
shard_id = shard["shard_id"]
|
shard_id = shard["shard_id"]
|
||||||
self.write_with_template(
|
self.write_with_template(
|
||||||
f"{base_filename}{shard_id}{extension}",
|
file.with_stem(f"{file.stem}{shard_id}"),
|
||||||
template_fn,
|
template_fn,
|
||||||
lambda: shard,
|
lambda: shard,
|
||||||
)
|
)
|
||||||
|
|
||||||
# filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
|
# filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
|
||||||
self.filenames.discard(
|
self.files.discard(self.install_dir / file.with_stem(f"{file.stem}Everything"))
|
||||||
f"{self.install_dir}/{base_filename}Everything{extension}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def write_outputs(self, variable_name: str, filename: str) -> None:
|
def write_outputs(self, variable_name: str, filename: str | Path) -> None:
|
||||||
"""Write a file containing the list of all outputs which are
|
"""Write a file containing the list of all outputs which are generated by this script."""
|
||||||
generated by this script."""
|
content = "\n".join(
|
||||||
content = "set({}\n {})".format(
|
(
|
||||||
variable_name,
|
"set(",
|
||||||
"\n ".join('"' + name + '"' for name in sorted(self.filenames)),
|
variable_name,
|
||||||
|
# Use POSIX paths to avoid invalid escape sequences on Windows
|
||||||
|
*(f' "{file.as_posix()}"' for file in sorted(self.files)),
|
||||||
|
")",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self._write_if_changed(filename, content)
|
self._write_if_changed(filename, content)
|
||||||
|
|
||||||
|
|
@ -309,12 +336,15 @@ class FileManager:
|
||||||
|
|
||||||
# Helper function to generate file manager
|
# Helper function to generate file manager
|
||||||
def make_file_manager(
|
def make_file_manager(
|
||||||
options: Namespace, install_dir: str | None = None
|
options: Namespace,
|
||||||
|
install_dir: str | Path | None = None,
|
||||||
) -> FileManager:
|
) -> FileManager:
|
||||||
template_dir = os.path.join(options.source_path, "templates")
|
template_dir = os.path.join(options.source_path, "templates")
|
||||||
install_dir = install_dir if install_dir else options.install_dir
|
install_dir = install_dir if install_dir else options.install_dir
|
||||||
return FileManager(
|
return FileManager(
|
||||||
install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
|
install_dir=install_dir,
|
||||||
|
template_dir=template_dir,
|
||||||
|
dry_run=options.dry_run,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -437,7 +467,10 @@ class NamespaceHelper:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, namespace_str: str, entity_name: str = "", max_level: int = 2
|
self,
|
||||||
|
namespace_str: str,
|
||||||
|
entity_name: str = "",
|
||||||
|
max_level: int = 2,
|
||||||
) -> None:
|
) -> None:
|
||||||
# cpp_namespace can be a colon joined string such as torch::lazy
|
# cpp_namespace can be a colon joined string such as torch::lazy
|
||||||
cpp_namespaces = namespace_str.split("::")
|
cpp_namespaces = namespace_str.split("::")
|
||||||
|
|
@ -454,7 +487,8 @@ class NamespaceHelper:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_namespaced_entity(
|
def from_namespaced_entity(
|
||||||
namespaced_entity: str, max_level: int = 2
|
namespaced_entity: str,
|
||||||
|
max_level: int = 2,
|
||||||
) -> NamespaceHelper:
|
) -> NamespaceHelper:
|
||||||
"""
|
"""
|
||||||
Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
|
Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user