mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Refactor CppSignatureGroup to collect signatures as list. (#83667)
This makes it easier to add more signatures to the signature group, as relevant logic which needs to run for each signature no longer needs to be adjusted. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/83667 Approved by: https://github.com/larryliu0820, https://github.com/bdhirsh
This commit is contained in:
parent
03e322c8d6
commit
0ec7fc13d6
|
|
@ -1,6 +1,6 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Sequence, Set, TypeVar, Union
|
from typing import Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union
|
||||||
|
|
||||||
from torchgen.model import (
|
from torchgen.model import (
|
||||||
Argument,
|
Argument,
|
||||||
|
|
@ -504,29 +504,30 @@ class CppSignatureGroup:
|
||||||
else:
|
else:
|
||||||
return self.signature
|
return self.signature
|
||||||
|
|
||||||
|
def signatures(self) -> Iterator[CppSignature]:
|
||||||
|
yield self.signature
|
||||||
|
if self.faithful_signature:
|
||||||
|
yield self.faithful_signature
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_function(
|
def from_native_function(
|
||||||
f: NativeFunction, *, method: bool, fallback_binding: bool = False
|
f: NativeFunction, *, method: bool, fallback_binding: bool = False
|
||||||
) -> "CppSignatureGroup":
|
) -> "CppSignatureGroup":
|
||||||
func = f.func
|
func = f.func
|
||||||
faithful_signature: Optional[CppSignature]
|
|
||||||
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
|
def make_sig(*, faithful: bool) -> CppSignature:
|
||||||
faithful_signature = CppSignature(
|
return CppSignature(
|
||||||
func=func,
|
func=func,
|
||||||
faithful=True,
|
faithful=faithful,
|
||||||
method=method,
|
method=method,
|
||||||
fallback_binding=fallback_binding,
|
fallback_binding=fallback_binding,
|
||||||
cpp_no_default_args=f.cpp_no_default_args,
|
cpp_no_default_args=f.cpp_no_default_args,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
faithful_signature = None
|
faithful_signature: Optional[CppSignature] = None
|
||||||
signature = CppSignature(
|
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
|
||||||
func=func,
|
faithful_signature = make_sig(faithful=True)
|
||||||
faithful=False,
|
signature = make_sig(faithful=False)
|
||||||
method=method,
|
|
||||||
fallback_binding=fallback_binding,
|
|
||||||
cpp_no_default_args=f.cpp_no_default_args,
|
|
||||||
)
|
|
||||||
return CppSignatureGroup(
|
return CppSignatureGroup(
|
||||||
func=func,
|
func=func,
|
||||||
signature=signature,
|
signature=signature,
|
||||||
|
|
|
||||||
|
|
@ -611,29 +611,19 @@ class ComputeFunction:
|
||||||
f, method=False, fallback_binding=f.manual_cpp_binding
|
f, method=False, fallback_binding=f.manual_cpp_binding
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_defn(faithful: bool) -> str:
|
result = ""
|
||||||
if faithful:
|
for sig in sig_group.signatures():
|
||||||
sig = sig_group.faithful_signature
|
|
||||||
assert sig is not None
|
|
||||||
else:
|
|
||||||
sig = sig_group.signature
|
|
||||||
|
|
||||||
# See Note [The ATen Operators API]
|
# See Note [The ATen Operators API]
|
||||||
target_sig = DispatcherSignature.from_schema(f.func)
|
target_sig = DispatcherSignature.from_schema(f.func)
|
||||||
exprs = translate(sig.arguments(), target_sig.arguments())
|
exprs = translate(sig.arguments(), target_sig.arguments())
|
||||||
exprs_str = ", ".join([e.expr for e in exprs])
|
exprs_str = ", ".join([e.expr for e in exprs])
|
||||||
|
|
||||||
return f"""
|
result += f"""
|
||||||
// aten::{f.func}
|
// aten::{f.func}
|
||||||
inline {sig.decl()} {{
|
inline {sig.decl()} {{
|
||||||
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
|
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = generate_defn(False)
|
|
||||||
if sig_group.faithful_signature is not None:
|
|
||||||
result += generate_defn(True)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -657,36 +647,28 @@ class ComputeTensorMethod:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.target is Target.DECLARATION:
|
if self.target is Target.DECLARATION:
|
||||||
result = f"{sig_group.signature.decl()} const;\n"
|
result = ""
|
||||||
if sig_group.faithful_signature is not None:
|
for sig in sig_group.signatures():
|
||||||
result += f"{sig_group.faithful_signature.decl()} const;\n"
|
result += f"{sig.decl()} const;\n"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if self.target is not Target.DEFINITION:
|
if self.target is not Target.DEFINITION:
|
||||||
assert_never(self.target)
|
assert_never(self.target)
|
||||||
|
|
||||||
def generate_defn(faithful: bool) -> str:
|
result = ""
|
||||||
if faithful:
|
|
||||||
sig = sig_group.faithful_signature
|
|
||||||
assert sig is not None
|
|
||||||
else:
|
|
||||||
sig = sig_group.signature
|
|
||||||
|
|
||||||
|
for sig in sig_group.signatures():
|
||||||
target_sig = DispatcherSignature.from_schema(f.func)
|
target_sig = DispatcherSignature.from_schema(f.func)
|
||||||
exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
|
exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
|
||||||
exprs_str = ", ".join([e.expr for e in exprs])
|
exprs_str = ", ".join([e.expr for e in exprs])
|
||||||
|
|
||||||
return f"""
|
result += f"""
|
||||||
// aten::{f.func}
|
// aten::{f.func}
|
||||||
inline {sig.defn(prefix="Tensor::")} const {{
|
inline {sig.defn(prefix="Tensor::")} const {{
|
||||||
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
|
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = generate_defn(faithful=False)
|
|
||||||
if sig_group.faithful_signature is not None:
|
|
||||||
result += generate_defn(faithful=True)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -703,28 +685,19 @@ class ComputeRedispatchFunction:
|
||||||
f, method=False, fallback_binding=f.manual_cpp_binding
|
f, method=False, fallback_binding=f.manual_cpp_binding
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_defn(faithful: bool) -> str:
|
result = ""
|
||||||
if faithful:
|
for sig in sig_group.signatures():
|
||||||
sig = sig_group.faithful_signature
|
|
||||||
assert sig is not None
|
|
||||||
else:
|
|
||||||
sig = sig_group.signature
|
|
||||||
|
|
||||||
target_sig = DispatcherSignature.from_schema(f.func)
|
target_sig = DispatcherSignature.from_schema(f.func)
|
||||||
exprs = translate(sig.arguments(), target_sig.arguments())
|
exprs = translate(sig.arguments(), target_sig.arguments())
|
||||||
exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
|
exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
|
||||||
|
|
||||||
return f"""
|
result += f"""
|
||||||
// aten::{f.func}
|
// aten::{f.func}
|
||||||
inline {sig.decl(is_redispatching_fn=True)} {{
|
inline {sig.decl(is_redispatching_fn=True)} {{
|
||||||
return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
|
return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
|
||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = generate_defn(False)
|
|
||||||
if sig_group.faithful_signature is not None:
|
|
||||||
result += generate_defn(True)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user