mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Refactor CppSignatureGroup to collect signatures as list. (#83667)
This makes it easier to add more signatures to the signature group, as relevant logic which needs to run for each signature no longer needs to be adjusted. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/83667 Approved by: https://github.com/larryliu0820, https://github.com/bdhirsh
This commit is contained in:
parent
03e322c8d6
commit
0ec7fc13d6
|
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Sequence, Set, TypeVar, Union
|
||||
from typing import Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union
|
||||
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
|
|
@ -504,29 +504,30 @@ class CppSignatureGroup:
|
|||
else:
|
||||
return self.signature
|
||||
|
||||
def signatures(self) -> Iterator[CppSignature]:
|
||||
yield self.signature
|
||||
if self.faithful_signature:
|
||||
yield self.faithful_signature
|
||||
|
||||
@staticmethod
|
||||
def from_native_function(
|
||||
f: NativeFunction, *, method: bool, fallback_binding: bool = False
|
||||
) -> "CppSignatureGroup":
|
||||
func = f.func
|
||||
faithful_signature: Optional[CppSignature]
|
||||
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
|
||||
faithful_signature = CppSignature(
|
||||
|
||||
def make_sig(*, faithful: bool) -> CppSignature:
|
||||
return CppSignature(
|
||||
func=func,
|
||||
faithful=True,
|
||||
faithful=faithful,
|
||||
method=method,
|
||||
fallback_binding=fallback_binding,
|
||||
cpp_no_default_args=f.cpp_no_default_args,
|
||||
)
|
||||
else:
|
||||
faithful_signature = None
|
||||
signature = CppSignature(
|
||||
func=func,
|
||||
faithful=False,
|
||||
method=method,
|
||||
fallback_binding=fallback_binding,
|
||||
cpp_no_default_args=f.cpp_no_default_args,
|
||||
)
|
||||
|
||||
faithful_signature: Optional[CppSignature] = None
|
||||
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
|
||||
faithful_signature = make_sig(faithful=True)
|
||||
signature = make_sig(faithful=False)
|
||||
return CppSignatureGroup(
|
||||
func=func,
|
||||
signature=signature,
|
||||
|
|
|
|||
|
|
@ -611,29 +611,19 @@ class ComputeFunction:
|
|||
f, method=False, fallback_binding=f.manual_cpp_binding
|
||||
)
|
||||
|
||||
def generate_defn(faithful: bool) -> str:
|
||||
if faithful:
|
||||
sig = sig_group.faithful_signature
|
||||
assert sig is not None
|
||||
else:
|
||||
sig = sig_group.signature
|
||||
|
||||
result = ""
|
||||
for sig in sig_group.signatures():
|
||||
# See Note [The ATen Operators API]
|
||||
target_sig = DispatcherSignature.from_schema(f.func)
|
||||
exprs = translate(sig.arguments(), target_sig.arguments())
|
||||
exprs_str = ", ".join([e.expr for e in exprs])
|
||||
|
||||
return f"""
|
||||
result += f"""
|
||||
// aten::{f.func}
|
||||
inline {sig.decl()} {{
|
||||
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
|
||||
}}
|
||||
"""
|
||||
|
||||
result = generate_defn(False)
|
||||
if sig_group.faithful_signature is not None:
|
||||
result += generate_defn(True)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -657,36 +647,28 @@ class ComputeTensorMethod:
|
|||
)
|
||||
|
||||
if self.target is Target.DECLARATION:
|
||||
result = f"{sig_group.signature.decl()} const;\n"
|
||||
if sig_group.faithful_signature is not None:
|
||||
result += f"{sig_group.faithful_signature.decl()} const;\n"
|
||||
result = ""
|
||||
for sig in sig_group.signatures():
|
||||
result += f"{sig.decl()} const;\n"
|
||||
return result
|
||||
|
||||
if self.target is not Target.DEFINITION:
|
||||
assert_never(self.target)
|
||||
|
||||
def generate_defn(faithful: bool) -> str:
|
||||
if faithful:
|
||||
sig = sig_group.faithful_signature
|
||||
assert sig is not None
|
||||
else:
|
||||
sig = sig_group.signature
|
||||
result = ""
|
||||
|
||||
for sig in sig_group.signatures():
|
||||
target_sig = DispatcherSignature.from_schema(f.func)
|
||||
exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
|
||||
exprs_str = ", ".join([e.expr for e in exprs])
|
||||
|
||||
return f"""
|
||||
result += f"""
|
||||
// aten::{f.func}
|
||||
inline {sig.defn(prefix="Tensor::")} const {{
|
||||
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
|
||||
}}
|
||||
"""
|
||||
|
||||
result = generate_defn(faithful=False)
|
||||
if sig_group.faithful_signature is not None:
|
||||
result += generate_defn(faithful=True)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -703,28 +685,19 @@ class ComputeRedispatchFunction:
|
|||
f, method=False, fallback_binding=f.manual_cpp_binding
|
||||
)
|
||||
|
||||
def generate_defn(faithful: bool) -> str:
|
||||
if faithful:
|
||||
sig = sig_group.faithful_signature
|
||||
assert sig is not None
|
||||
else:
|
||||
sig = sig_group.signature
|
||||
|
||||
result = ""
|
||||
for sig in sig_group.signatures():
|
||||
target_sig = DispatcherSignature.from_schema(f.func)
|
||||
exprs = translate(sig.arguments(), target_sig.arguments())
|
||||
exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
|
||||
|
||||
return f"""
|
||||
result += f"""
|
||||
// aten::{f.func}
|
||||
inline {sig.decl(is_redispatching_fn=True)} {{
|
||||
return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
|
||||
}}
|
||||
"""
|
||||
|
||||
result = generate_defn(False)
|
||||
if sig_group.faithful_signature is not None:
|
||||
result += generate_defn(True)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user