pytorch/torch/utils/data/datapipes/gen_pyi.py
Aaron Gokaslan b7b2178204 [BE]: Remove useless lambdas (#113602)
Applies PLW0108, which removes useless lambda calls in Python. The rule is in preview, so it is not ready to be enabled by default just yet. These are the autofixes from the rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602
Approved by: https://github.com/albanD
2023-11-14 20:06:48 +00:00

247 lines
10 KiB
Python

import os
import pathlib
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple, Union
def materialize_lines(lines: List[str], indentation: int) -> str:
    """
    Join ``lines`` into one string, indenting every line after the first.

    Each entry is separated from the previous one by a newline followed by
    ``indentation`` spaces, and every newline embedded inside an entry is
    expanded the same way. The very first line receives no leading indent.
    """
    separator = "\n" + " " * indentation
    return separator.join(entry.replace('\n', separator) for entry in lines)
def gen_from_template(dir: str, template_name: str, output_name: str, replacements: List[Tuple[str, Any, int]]):
    """
    Render a template file by substituting placeholders and write the result.

    Args:
        dir: directory containing the template; the output is written there too.
            (The name shadows the builtin but is kept because callers pass it
            by keyword.)
        template_name: file name of the template inside ``dir``.
        output_name: file name of the generated file inside ``dir``.
        replacements: ``(placeholder, lines, indentation)`` triples. Every
            occurrence of ``placeholder`` in the template is replaced by
            ``materialize_lines(lines, indentation)``.

    All substitutions are applied in memory first and the output file is
    written exactly once. As a consequence the output is also produced (as a
    plain copy of the template) when ``replacements`` is empty, and a failure
    while rendering no longer leaves a truncated output file behind.
    """
    template_path = os.path.join(dir, template_name)
    output_path = os.path.join(dir, output_name)

    with open(template_path) as f:
        content = f.read()

    for placeholder, lines, indentation in replacements:
        content = content.replace(placeholder, materialize_lines(lines, indentation))

    # Single write after every placeholder has been expanded.
    with open(output_path, "w") as f:
        f.write(content)
def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str]:
    """
    Collect the paths of the ``.py`` entries directly inside each given directory.

    Entries whose name appears in ``files_to_exclude`` are skipped.
    Subdirectories are NOT traversed recursively; only the immediate
    listing of each directory is inspected.
    """
    collected: Set[str] = set()
    for directory in dir_paths:
        for entry in os.listdir(directory):
            if entry.endswith(".py") and entry not in files_to_exclude:
                collected.add(os.path.join(directory, entry))
    return collected
def extract_method_name(line: str) -> str:
    """
    Pull the method name out of a ``@functional_datapipe("name")`` decorator line.

    Both double-quoted and single-quoted names are accepted; the double-quoted
    form is looked for first. Raises ``RuntimeError`` when the line contains
    neither ``("`` nor ``('``.
    """
    for quote in ('"', "'"):
        opener = "(" + quote
        if opener in line:
            closer = quote + ")"
            begin = line.find(opener) + len(opener)
            finish = line.find(closer)
            return line[begin:finish]
    raise RuntimeError(f"Unable to find appropriate method name within line:\n{line}")
def extract_class_name(line: str) -> str:
    """
    Pull the class name out of a definition line shaped like ``class Name(Base):``.

    The name is the text between the first ``"class "`` and the first ``"("``.
    """
    keyword = "class "
    name_begin = line.find(keyword) + len(keyword)
    name_end = line.find("(")
    return line[name_begin:name_end]
def parse_datapipe_file(file_path: str) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    """
    Scan one source file line by line for ``@functional_datapipe`` registrations.

    Returns a 4-tuple of:
      * method name -> cleaned signature (via ``process_signature``) of the
        ``__init__``/``__new__`` found after the decorator,
      * method name -> name of the class found after the decorator,
      * the set of method names whose signature came from ``__new__``,
      * method name -> raw docstring lines collected while that method was
        the current one (lines seen before any decorator land under ``""``).
    """
    triple_quote = '"""'
    signatures, class_names, new_based_methods = {}, {}, set()
    doc_strings = defaultdict(list)
    with open(file_path) as source:
        paren_depth = 0
        current_method, current_class, partial_signature = "", "", ""
        inside_doc_string = False
        for line in source.readlines():
            # An odd number of triple quotes on a line opens or closes a docstring.
            if line.count(triple_quote) % 2 == 1:
                inside_doc_string = not inside_doc_string
            if inside_doc_string or triple_quote in line:  # Saving docstrings
                doc_strings[current_method].append(line)
                continue

            if "@functional_datapipe" in line:
                current_method = extract_method_name(line)
                doc_strings[current_method] = []
                continue

            if current_method and "class " in line:
                current_class = extract_class_name(line)
                continue

            if current_method:
                defines_new = "def __new__(" in line
                if defines_new or "def __init__(" in line:
                    if defines_new:
                        new_based_methods.add(current_method)
                    paren_depth += 1
                    # Keep only what follows the opening parenthesis of the def.
                    line = line[line.find("(") + len("("):]

            if paren_depth <= 0:
                continue

            # Inside a signature: track nesting until the parentheses balance.
            paren_depth += line.count('(') - line.count(')')
            if paren_depth < 0:
                raise RuntimeError("open parenthesis count < 0. This shouldn't be possible.")
            if paren_depth > 0:
                partial_signature += line.strip('\n').strip(' ')
                continue

            # Balanced: everything before the last ')' completes the signature.
            partial_signature += line[:line.rfind(')')]
            signatures[current_method] = process_signature(partial_signature)
            class_names[current_method] = current_class
            current_method, current_class, partial_signature = "", "", ""

    return signatures, class_names, new_based_methods, doc_strings
def parse_datapipe_files(file_paths: Set[str]) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    """
    Run ``parse_datapipe_file`` over every path and merge the per-file results.

    Returns the merged (signatures, class names, methods with special output
    types, docstring lines) in that order. When the same method name occurs in
    several files, the entry from the file processed last wins.
    """
    signatures: Dict[str, str] = {}
    class_names: Dict[str, str] = {}
    special_output_methods: Set[str] = set()
    doc_strings: Dict[str, List[str]] = {}

    for path in file_paths:
        file_signatures, file_class_names, file_special, file_doc_strings = parse_datapipe_file(path)
        signatures.update(file_signatures)
        class_names.update(file_class_names)
        special_output_methods.update(file_special)
        doc_strings.update(file_doc_strings)

    return signatures, class_names, special_output_methods, doc_strings
def split_outside_bracket(line: str, delimiter: str = ",") -> List[str]:
    """
    Split ``line`` on ``delimiter`` (a comma by default), ignoring delimiters inside '[]'.

    Bracket nesting is tracked with a counter, so a delimiter only splits when
    the counter is zero. The delimiter itself is dropped; all other characters
    (including surrounding spaces) are kept. Always returns at least one token.
    """
    bracket_count = 0
    curr_token: List[str] = []
    res = []
    for char in line:
        if char == "[":
            bracket_count += 1
        elif char == "]":
            bracket_count -= 1
        elif char == delimiter and bracket_count == 0:
            res.append("".join(curr_token))
            curr_token = []
            continue
        curr_token.append(char)
    res.append("".join(curr_token))
    return res


def process_signature(line: str) -> str:
    """
    Clean up a given raw function signature.

    This includes renaming ``cls`` to ``self``, removing the self-referential
    datapipe argument (the argument right after ``self``/``cls``, unless it
    starts with ``*``), replacing the default value of ``Callable`` arguments
    with ``...``, and dropping empty tokens and surrounding spaces.

    Args:
        line: the raw text between the parentheses of ``__init__``/``__new__``.

    Returns:
        The cleaned arguments joined by ``", "``.
    """
    tokens: List[str] = split_outside_bracket(line)
    for i, raw_token in enumerate(tokens):
        token = raw_token.strip(' ')
        tokens[i] = token
        if token == "cls":
            tokens[i] = "self"
        elif i > 0 and tokens[i - 1] == "self" and not token.startswith("*"):
            # Remove the datapipe after 'self' or 'cls' unless it has '*'.
            # `startswith` is safe on an empty token (e.g. after a trailing comma).
            tokens[i] = ""
        elif "Callable =" in token:  # Remove default argument if it is a function
            # Split on the FIRST '=' only: the default value itself may contain
            # '=' (for instance a lambda using '=='), which must not be unpacked.
            head = token.split("=", 1)[0]
            tokens[i] = head.strip(' ') + "= ..."
    return ', '.join(t for t in tokens if t != "")
def get_method_definitions(file_path: Union[str, List[str]],
                           files_to_exclude: Set[str],
                           deprecated_files: Set[str],
                           default_output_type: str,
                           method_to_special_output_type: Dict[str, str],
                           root: str = "") -> List[str]:
    """
    Build the stub text for every functional DataPipe method found under ``file_path``.

    Steps:
      1. Locate the source files to process, skipping excluded and deprecated ones.
      2. Parse each method's name, signature, owning class and docstring.
      3. Emit one ``def`` stub per method, sorted by its ``def`` line.

    ``file_path`` is one directory or a list of directories, resolved against
    ``root``; an empty ``root`` means the directory containing this file.
    Methods flagged as having a special output type take their return
    annotation from ``method_to_special_output_type``; all others use
    ``default_output_type``.
    """
    if root == "":
        root = str(pathlib.Path(__file__).parent.resolve())

    relative_dirs = [file_path] if isinstance(file_path, str) else file_path
    search_dirs = [os.path.join(root, relative) for relative in relative_dirs]
    source_files = find_file_paths(search_dirs,
                                   files_to_exclude=files_to_exclude.union(deprecated_files))

    signatures, class_names, special_methods, doc_strings = parse_datapipe_files(source_files)

    # Every explicitly configured method counts as "special" as well.
    special_methods.update(method_to_special_output_type)

    definitions = []
    for method_name, arguments in signatures.items():
        class_name = class_names[method_name]
        output_type = (
            method_to_special_output_type[method_name]
            if method_name in special_methods
            else default_output_type
        )
        # Fall back to an ellipsis body when no docstring was collected.
        doc_string = "".join(doc_strings[method_name]) or " ...\n"
        definitions.append(
            f"# Functional form of '{class_name}'\n"
            f"def {method_name}({arguments}) -> {output_type}:\n"
            f"{doc_string}"
        )

    # The second line of each entry is the `def ...` line, so this orders by method name.
    definitions.sort(key=lambda text: text.split('\n')[1])
    return definitions
# Defined outside of main() so they can be imported by TorchData
# Each group below supplies the arguments main() passes to get_method_definitions().

# Directory scanned for IterDataPipe sources; joined onto `root` by get_method_definitions.
iterDP_file_path: str = "iter"
# File names in that directory that are never parsed.
iterDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
# Additional file names to skip; merged with the exclusion set before scanning.
iterDP_deprecated_files: Set[str] = set()
# Method name -> return annotation emitted in the stub instead of the default output type.
iterDP_method_to_special_output_type: Dict[str, str] = {"demux": "List[IterDataPipe]", "fork": "List[IterDataPipe]"}
# Same four settings for the MapDataPipe sources.
mapDP_file_path: str = "map"
mapDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
mapDP_deprecated_files: Set[str] = set()
mapDP_method_to_special_output_type: Dict[str, str] = {"shuffle": "IterDataPipe"}
def main() -> None:
    """
    Generate ``datapipe.pyi`` from the ``datapipe.pyi.in`` template.

    The method stubs collected for IterDataPipe and MapDataPipe are injected
    into the ``${IterDataPipeMethods}`` and ``${MapDataPipeMethods}``
    placeholders, indented by four spaces. Both files live in the directory
    containing this script.

    TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
          interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
    """
    iter_stubs = get_method_definitions(
        iterDP_file_path,
        iterDP_files_to_exclude,
        iterDP_deprecated_files,
        "IterDataPipe",
        iterDP_method_to_special_output_type,
    )
    map_stubs = get_method_definitions(
        mapDP_file_path,
        mapDP_files_to_exclude,
        mapDP_deprecated_files,
        "MapDataPipe",
        mapDP_method_to_special_output_type,
    )

    script_dir = str(pathlib.Path(__file__).parent.resolve())
    gen_from_template(
        dir=script_dir,
        template_name="datapipe.pyi.in",
        output_name="datapipe.pyi",
        replacements=[
            ('${IterDataPipeMethods}', iter_stubs, 4),
            ('${MapDataPipeMethods}', map_stubs, 4),
        ],
    )
# Run the generator only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()