Refactor CppSignatureGroup to collect signatures as list. (#83667)

This makes it easier to add more signatures to the signature group,
as relevant logic which needs to run for each signature no longer
needs to be adjusted.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83667
Approved by: https://github.com/larryliu0820, https://github.com/bdhirsh
This commit is contained in:
Edward Z. Yang 2022-08-18 19:22:37 -07:00 committed by PyTorch MergeBot
parent 03e322c8d6
commit 0ec7fc13d6
2 changed files with 27 additions and 53 deletions

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Sequence, Set, TypeVar, Union from typing import Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union
from torchgen.model import ( from torchgen.model import (
Argument, Argument,
@ -504,29 +504,30 @@ class CppSignatureGroup:
else: else:
return self.signature return self.signature
def signatures(self) -> Iterator[CppSignature]:
yield self.signature
if self.faithful_signature:
yield self.faithful_signature
@staticmethod @staticmethod
def from_native_function( def from_native_function(
f: NativeFunction, *, method: bool, fallback_binding: bool = False f: NativeFunction, *, method: bool, fallback_binding: bool = False
) -> "CppSignatureGroup": ) -> "CppSignatureGroup":
func = f.func func = f.func
faithful_signature: Optional[CppSignature]
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: def make_sig(*, faithful: bool) -> CppSignature:
faithful_signature = CppSignature( return CppSignature(
func=func, func=func,
faithful=True, faithful=faithful,
method=method, method=method,
fallback_binding=fallback_binding, fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args, cpp_no_default_args=f.cpp_no_default_args,
) )
else:
faithful_signature = None faithful_signature: Optional[CppSignature] = None
signature = CppSignature( if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
func=func, faithful_signature = make_sig(faithful=True)
faithful=False, signature = make_sig(faithful=False)
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args,
)
return CppSignatureGroup( return CppSignatureGroup(
func=func, func=func,
signature=signature, signature=signature,

View File

@ -611,29 +611,19 @@ class ComputeFunction:
f, method=False, fallback_binding=f.manual_cpp_binding f, method=False, fallback_binding=f.manual_cpp_binding
) )
def generate_defn(faithful: bool) -> str: result = ""
if faithful: for sig in sig_group.signatures():
sig = sig_group.faithful_signature
assert sig is not None
else:
sig = sig_group.signature
# See Note [The ATen Operators API] # See Note [The ATen Operators API]
target_sig = DispatcherSignature.from_schema(f.func) target_sig = DispatcherSignature.from_schema(f.func)
exprs = translate(sig.arguments(), target_sig.arguments()) exprs = translate(sig.arguments(), target_sig.arguments())
exprs_str = ", ".join([e.expr for e in exprs]) exprs_str = ", ".join([e.expr for e in exprs])
return f""" result += f"""
// aten::{f.func} // aten::{f.func}
inline {sig.decl()} {{ inline {sig.decl()} {{
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}} }}
""" """
result = generate_defn(False)
if sig_group.faithful_signature is not None:
result += generate_defn(True)
return result return result
@ -657,36 +647,28 @@ class ComputeTensorMethod:
) )
if self.target is Target.DECLARATION: if self.target is Target.DECLARATION:
result = f"{sig_group.signature.decl()} const;\n" result = ""
if sig_group.faithful_signature is not None: for sig in sig_group.signatures():
result += f"{sig_group.faithful_signature.decl()} const;\n" result += f"{sig.decl()} const;\n"
return result return result
if self.target is not Target.DEFINITION: if self.target is not Target.DEFINITION:
assert_never(self.target) assert_never(self.target)
def generate_defn(faithful: bool) -> str: result = ""
if faithful:
sig = sig_group.faithful_signature
assert sig is not None
else:
sig = sig_group.signature
for sig in sig_group.signatures():
target_sig = DispatcherSignature.from_schema(f.func) target_sig = DispatcherSignature.from_schema(f.func)
exprs = translate(sig.arguments(), target_sig.arguments(), method=True) exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
exprs_str = ", ".join([e.expr for e in exprs]) exprs_str = ", ".join([e.expr for e in exprs])
return f""" result += f"""
// aten::{f.func} // aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{ inline {sig.defn(prefix="Tensor::")} const {{
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str}); return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}} }}
""" """
result = generate_defn(faithful=False)
if sig_group.faithful_signature is not None:
result += generate_defn(faithful=True)
return result return result
@ -703,28 +685,19 @@ class ComputeRedispatchFunction:
f, method=False, fallback_binding=f.manual_cpp_binding f, method=False, fallback_binding=f.manual_cpp_binding
) )
def generate_defn(faithful: bool) -> str: result = ""
if faithful: for sig in sig_group.signatures():
sig = sig_group.faithful_signature
assert sig is not None
else:
sig = sig_group.signature
target_sig = DispatcherSignature.from_schema(f.func) target_sig = DispatcherSignature.from_schema(f.func)
exprs = translate(sig.arguments(), target_sig.arguments()) exprs = translate(sig.arguments(), target_sig.arguments())
exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs]) exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
return f""" result += f"""
// aten::{f.func} // aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{ inline {sig.decl(is_redispatching_fn=True)} {{
return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str}); return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}} }}
""" """
result = generate_defn(False)
if sig_group.faithful_signature is not None:
result += generate_defn(True)
return result return result