Refactor CppSignatureGroup to collect signatures as list. (#83667)

This makes it easier to add more signatures to the signature group,
as relevant logic which needs to run for each signature no longer
needs to be adjusted.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83667
Approved by: https://github.com/larryliu0820, https://github.com/bdhirsh
This commit is contained in:
Edward Z. Yang 2022-08-18 19:22:37 -07:00 committed by PyTorch MergeBot
parent 03e322c8d6
commit 0ec7fc13d6
2 changed files with 27 additions and 53 deletions

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Sequence, Set, TypeVar, Union
from typing import Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union
from torchgen.model import (
Argument,
@ -504,29 +504,30 @@ class CppSignatureGroup:
else:
return self.signature
def signatures(self) -> Iterator[CppSignature]:
yield self.signature
if self.faithful_signature:
yield self.faithful_signature
@staticmethod
def from_native_function(
f: NativeFunction, *, method: bool, fallback_binding: bool = False
) -> "CppSignatureGroup":
func = f.func
faithful_signature: Optional[CppSignature]
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = CppSignature(
def make_sig(*, faithful: bool) -> CppSignature:
return CppSignature(
func=func,
faithful=True,
faithful=faithful,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args,
)
else:
faithful_signature = None
signature = CppSignature(
func=func,
faithful=False,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args,
)
faithful_signature: Optional[CppSignature] = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True)
signature = make_sig(faithful=False)
return CppSignatureGroup(
func=func,
signature=signature,

View File

@ -611,29 +611,19 @@ class ComputeFunction:
f, method=False, fallback_binding=f.manual_cpp_binding
)
def generate_defn(faithful: bool) -> str:
if faithful:
sig = sig_group.faithful_signature
assert sig is not None
else:
sig = sig_group.signature
result = ""
for sig in sig_group.signatures():
# See Note [The ATen Operators API]
target_sig = DispatcherSignature.from_schema(f.func)
exprs = translate(sig.arguments(), target_sig.arguments())
exprs_str = ", ".join([e.expr for e in exprs])
return f"""
result += f"""
// aten::{f.func}
inline {sig.decl()} {{
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
result = generate_defn(False)
if sig_group.faithful_signature is not None:
result += generate_defn(True)
return result
@ -657,36 +647,28 @@ class ComputeTensorMethod:
)
if self.target is Target.DECLARATION:
result = f"{sig_group.signature.decl()} const;\n"
if sig_group.faithful_signature is not None:
result += f"{sig_group.faithful_signature.decl()} const;\n"
result = ""
for sig in sig_group.signatures():
result += f"{sig.decl()} const;\n"
return result
if self.target is not Target.DEFINITION:
assert_never(self.target)
def generate_defn(faithful: bool) -> str:
if faithful:
sig = sig_group.faithful_signature
assert sig is not None
else:
sig = sig_group.signature
result = ""
for sig in sig_group.signatures():
target_sig = DispatcherSignature.from_schema(f.func)
exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
exprs_str = ", ".join([e.expr for e in exprs])
return f"""
result += f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""
result = generate_defn(faithful=False)
if sig_group.faithful_signature is not None:
result += generate_defn(faithful=True)
return result
@ -703,28 +685,19 @@ class ComputeRedispatchFunction:
f, method=False, fallback_binding=f.manual_cpp_binding
)
def generate_defn(faithful: bool) -> str:
if faithful:
sig = sig_group.faithful_signature
assert sig is not None
else:
sig = sig_group.signature
result = ""
for sig in sig_group.signatures():
target_sig = DispatcherSignature.from_schema(f.func)
exprs = translate(sig.arguments(), target_sig.arguments())
exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
return f"""
result += f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}}
"""
result = generate_defn(False)
if sig_group.faithful_signature is not None:
result += generate_defn(True)
return result