mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/76089 Approved by: https://github.com/albanD
230 lines
8.1 KiB
Python
230 lines
8.1 KiB
Python
# This code should be shared between cwrap and ATen preprocessing.
# For now it lives in one place, but it is currently a copy of the code in cwrap.
|
|
|
|
import copy
|
|
from typing import Any, Dict, Iterable, List, Union
|
|
|
|
Arg = Dict[str, Any]


def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
    """Normalize argument declarations into dicts with "type" and "name" keys.

    Each element of ``args`` is either:

    * a string of the form ``"<type> <name>"``, split on the first space into
      a new ``{"type": ..., "name": ...}`` dict; or
    * a dict.  If it has an ``"arg"`` key holding such a string, that key is
      replaced by ``"type"`` and ``"name"`` entries.  The dict is modified in
      place and the same object is returned in the result list.

    Args:
        args: Argument declarations, as strings and/or dicts.

    Returns:
        A new list with one dict per input element, in input order.

    Raises:
        AssertionError: if an element is neither a str nor a dict.
    """
    new_args: List[Arg] = []
    for arg in args:
        if isinstance(arg, str):
            # Simple arg declaration of form "<type> <name>"
            t, _, name = arg.partition(" ")
            new_args.append({"type": t, "name": name})
        elif isinstance(arg, dict):
            if "arg" in arg:
                # NB: rewrites the caller's dict in place rather than copying it.
                arg["type"], _, arg["name"] = arg["arg"].partition(" ")
                del arg["arg"]
            new_args.append(arg)
        else:
            raise AssertionError(
                "expected an argument declaration of type str or dict, "
                f"got {type(arg).__name__}: {arg!r}"
            )
    return new_args
|
|
|
|
|
|
Declaration = Dict[str, Any]


def set_declaration_defaults(declaration: Declaration) -> None:
    """Fill in default fields of ``declaration`` in place and expand its options.

    Steps, in order:

    1. Default ``schema_string`` to ``""``, ``arguments`` to ``[]``,
       ``return`` to ``"void"``, ``cname`` to ``name`` and ``backends`` to
       ``["CPU", "CUDA"]`` when they are absent.
    2. Set ``api_name`` to ``name``.  Set ``type_wrapper_name`` to
       ``"<name>_<overload_name>"``, or just ``name`` when there is no
       (truthy) ``overload_name``.
    3. Derive ``operator_name_with_overload`` (the schema string up to the
       first ``"("``).  Also derive its unqualified counterparts
       ``unqual_schema_string`` / ``unqual_operator_name_with_overload``
       (everything after the first ``"::"``).  Both are ``""`` for an empty
       schema string.
    4. If there is no ``options`` key, wrap deep copies of ``arguments`` and
       ``schema_order_arguments`` into a single option and remove those keys
       from the declaration.
    5. Parse each option's ``arguments`` and ``schema_order_arguments`` with
       ``parse_arguments``.
    6. Copy every top-level key except ``options`` into each option that does
       not already define it.  Values are shared by reference, not copied.

    Args:
        declaration: The declaration to update.  It must contain ``name``.
            Without ``options`` it must also contain ``schema_order_arguments``.
            A non-empty ``schema_string`` must be namespace-qualified,
            e.g. ``"aten::foo(...)"``.

    Raises:
        AssertionError: if ``api_name`` is already present.
        KeyError: if a required key is missing.
        IndexError: if a non-empty ``schema_string`` contains no ``"::"``.
    """
    if "schema_string" not in declaration:
        # This happens for legacy TH bindings like
        # _thnn_conv_depthwise2d_backward
        declaration["schema_string"] = ""
    declaration.setdefault("arguments", [])
    declaration.setdefault("return", "void")
    if "cname" not in declaration:
        declaration["cname"] = declaration["name"]
    if "backends" not in declaration:
        declaration["backends"] = ["CPU", "CUDA"]

    # An explicit raise (unlike a bare assert) is not stripped under `python -O`.
    if "api_name" in declaration:
        raise AssertionError(
            f"declaration {declaration.get('name')!r} unexpectedly already "
            "has an 'api_name'"
        )
    declaration["api_name"] = declaration["name"]

    # NB: keep this in sync with gen_autograd.py
    if declaration.get("overload_name"):
        declaration["type_wrapper_name"] = "{}_{}".format(
            declaration["name"], declaration["overload_name"]
        )
    else:
        declaration["type_wrapper_name"] = declaration["name"]

    # TODO: Uggggh, parsing the schema string here, really???
    schema_string = declaration["schema_string"]
    operator_name_with_overload = schema_string.split("(")[0]
    declaration["operator_name_with_overload"] = operator_name_with_overload
    if schema_string:
        # Strip only the leading namespace qualifier (e.g. "aten::").  Split at
        # most once so that any later "::" (e.g. inside a default value) is
        # preserved instead of truncating the string.
        declaration["unqual_schema_string"] = schema_string.split("::", 1)[1]
        declaration[
            "unqual_operator_name_with_overload"
        ] = operator_name_with_overload.split("::", 1)[1]
    else:
        declaration["unqual_schema_string"] = ""
        declaration["unqual_operator_name_with_overload"] = ""

    # Simulate multiple dispatch, even if it's not necessary
    if "options" not in declaration:
        declaration["options"] = [
            {
                "arguments": copy.deepcopy(declaration["arguments"]),
                "schema_order_arguments": copy.deepcopy(
                    declaration["schema_order_arguments"]
                ),
            }
        ]
        del declaration["arguments"]
        del declaration["schema_order_arguments"]

    # Parse arguments (some of them can be strings)
    for option in declaration["options"]:
        option["arguments"] = parse_arguments(option["arguments"])
        option["schema_order_arguments"] = parse_arguments(
            option["schema_order_arguments"]
        )

    # Propagate defaults from declaration to options
    for option in declaration["options"]:
        for k, v in declaration.items():
            # TODO(zach): why does cwrap not propagate 'name'? I need it
            # propagated for ATen
            if k != "options":
                option.setdefault(k, v)
|
|
|
|
|
# TODO(zach): the `allow_kwarg` switch below exists so that C++, which cannot
# support keyword arguments, can turn keyword handling off.

Option = Dict[str, Any]


def filter_unique_options(
    options: Iterable[Option],
    allow_kwarg: bool,
    type_to_signature: Dict[str, str],
    remove_self: bool,
) -> List[Option]:
    """Keep only the options whose call signatures are pairwise distinct.

    Options are visited in order.  For each one, candidate signatures are
    tried with 0, 1, ..., N trailing keyword-only arguments (only 0 when
    ``allow_kwarg`` is false).  The first candidate not seen before wins: the
    option is kept and, if that candidate used keyword-only arguments, those
    argument dicts are marked in place with ``"kwarg_only": True``.  Options
    whose every candidate was already seen are dropped.

    A signature is made of:

    * positional part: argument types, mapped through ``type_to_signature``
      when a mapping exists; ``CONSTANT`` arguments and, if ``remove_self``,
      arguments named ``self`` are skipped;
    * keyword-only part (only when non-empty count): ``name#type`` for each
      trailing argument, skipping only ``CONSTANT`` arguments.

    Returns:
        The retained options, in their original order.
    """

    def _is_constant(arg: "Arg") -> bool:
        return arg["type"] == "CONSTANT"  # type: ignore[no-any-return]

    def _skip_in_positional(arg: "Arg") -> bool:
        return _is_constant(arg) or (remove_self and arg["name"] == "self")

    def _signature(option: Option, num_kwarg_only: int) -> str:
        arguments = option["arguments"]
        if num_kwarg_only == 0:
            # Everything is positional; no "#-#" separator is emitted.
            return "#".join(
                type_to_signature.get(a["type"], a["type"])
                for a in arguments
                if not _skip_in_positional(a)
            )
        positional_part = "#".join(
            type_to_signature.get(a["type"], a["type"])
            for a in arguments[:-num_kwarg_only]
            if not _skip_in_positional(a)
        )
        keyword_part = "#".join(
            a["name"] + "#" + a["type"]
            for a in arguments[-num_kwarg_only:]
            if not _is_constant(a)
        )
        return positional_part + "#-#" + keyword_part

    seen_signatures = set()
    kept = []
    for option in options:
        # Without kwargs only the all-positional form (0 kwarg-only) is tried.
        max_kwarg_only = len(option["arguments"]) if allow_kwarg else 0
        for num_kwarg_only in range(max_kwarg_only + 1):
            candidate = _signature(option, num_kwarg_only)
            if candidate in seen_signatures:
                continue
            if num_kwarg_only > 0:
                for a in option["arguments"][-num_kwarg_only:]:
                    a["kwarg_only"] = True
            seen_signatures.add(candidate)
            kept.append(option)
            break
    return kept
|
|
|
|
|
|
def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
    """Sort ``declaration["options"]`` in place by how many arguments each has.

    By default (``reverse=True``) options with the most arguments come first;
    pass ``reverse=False`` for ascending order.  ``list.sort`` is stable, so
    options with the same argument count keep their relative order.
    """
    declaration["options"].sort(
        key=lambda option: len(option["arguments"]),
        reverse=reverse,
    )
|
|
|
|
|
|
class Function:
    """A function parsed from a header: a name plus an ordered argument list."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Filled in one Argument at a time via add_argument() (parse_header
        # does this for each argument line following a function declaration).
        self.arguments: List["Argument"] = []

    def add_argument(self, arg: "Argument") -> None:
        """Append ``arg``, which must be an ``Argument``, to the argument list."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self) -> str:
        # Renders as "name(type1 arg1, type2 arg2, ...)".
        rendered_args = ", ".join(map(repr, self.arguments))
        return "".join((self.name, "(", rendered_args, ")"))
|
|
|
|
|
|
class Argument:
    """A single typed argument of a parsed ``Function``."""

    def __init__(self, _type: str, name: str, is_optional: bool) -> None:
        self.type = _type
        self.name = name
        # parse_header sets this when the argument's line comment contains
        # the marker "[OPTIONAL]".
        self.is_optional = is_optional

    def __repr__(self) -> str:
        # Renders as "<type> <name>", e.g. "THTensor* input".
        return " ".join((self.type, self.name))
|
|
|
|
|
|
def parse_header(path: str) -> List[Function]:
    """Parse a THNN-style C header into a list of ``Function`` objects.

    The parser is line based and deliberately simple:

    * empty lines and lines starting with ``#`` (preprocessor directives) are
      skipped;
    * anything after ``//`` is the line comment; if it contains
      ``[OPTIONAL]``, every argument on that line is marked optional;
    * a trailing ``);`` / ``)`` / ``;`` and then a trailing ``,`` are removed,
      and the rest of the line is split on commas;
    * a fragment beginning with ``TH_API void THNN_``,
      ``TORCH_CUDA_CPP_API void THNN_`` or ``TORCH_CUDA_CU_API void THNN_``
      starts a new function.  The name is taken from ``(name)(`` or
      ``name(``;
    * every other fragment is parsed as ``<type> <name>``, split on the last
      whitespace run so multi-word types (``const THTensor *x``) are kept.
      Leading ``*`` on the name are moved onto the type, so ``THTensor **x``
      becomes type ``THTensor**`` and name ``x``.  The argument is appended
      to the most recently started function.

    Args:
        path: Path of the header file to read (decoded as UTF-8).

    Returns:
        The functions in the order they are declared in the header.

    Raises:
        ValueError: if an argument fragment has no whitespace-separated name.
        IndexError: if an argument fragment appears before any function.
    """
    # Declaration prefixes that introduce a new THNN function.
    function_prefixes = (
        "TH_API void THNN_",
        "TORCH_CUDA_CPP_API void THNN_",
        "TORCH_CUDA_CU_API void THNN_",
    )

    with open(path, "r", encoding="utf-8") as f:
        raw_lines = f.read().split("\n")

    # Flatten the header into (declaration fragment, line comment) pairs: one
    # pair per comma-separated fragment, all sharing their line's comment.
    fragments = []
    for raw in raw_lines:
        # Remove empty lines and preprocessor directives
        if not raw or raw.startswith("#"):
            continue
        code, _, comment = raw.partition("//")
        # Remove trailing special signs: the ")"/";" ending a declaration,
        # then a "," separating arguments.
        code = code.strip().rstrip(");").rstrip(",")
        comment = comment.strip()
        for piece in code.split(","):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    generic_functions: List[Function] = []
    for fragment, comment in fragments:
        prefix = next((p for p in function_prefixes if fragment.startswith(p)), None)
        if prefix is not None:
            fn_name = fragment[len(prefix) :]
            if fn_name.startswith("(") and fn_name[-2:-1] == ")":
                # Macro-wrapped form "(name)(" -> "name"
                fn_name = fn_name[1:-2]
            else:
                # Plain form "name(" -> "name"
                fn_name = fn_name[:-1]
            generic_functions.append(Function(fn_name))
            continue

        arg_type, name = fragment.rsplit(None, 1)
        # Move every leading pointer star from the name onto the type.
        bare_name = name.lstrip("*")
        arg_type += "*" * (len(name) - len(bare_name))
        generic_functions[-1].add_argument(
            Argument(arg_type, bare_name, "[OPTIONAL]" in comment)
        )
    return generic_functions
|