mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72602
This change adds a script to auto-generate out variant dispatchers from `caffe2/aten/src/ATen/native/native_functions.yaml` for Static Runtime.
Test Plan: The following diff, D33373928, adds out variant dispatchers generated by this diff.
Reviewed By: mikeiovine
Differential Revision: D33373919
fbshipit-source-id: 1f6d0766c561c405a2d61d0171064f39e3d15a79
(cherry picked from commit b3a1923331)
311 lines
12 KiB
Python
311 lines
12 KiB
Python
import tools.codegen.api.cpp as cpp
|
|
from tools.codegen.context import native_function_manager
|
|
from tools.codegen.model import (Argument, BaseTy, FunctionSchema, OptionalType,
|
|
SelfArgument,
|
|
BaseType, NativeFunctionsGroup, TensorOptionsArguments, Type)
|
|
from tools.codegen.static_runtime import config
|
|
|
|
import math
|
|
from typing import List, Optional, Sequence, Tuple, Union
|
|
|
|
|
|
def has_alias(arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]) -> bool:
|
|
for arg in arguments:
|
|
annotation = getattr(arg, "annotation", None)
|
|
if not annotation:
|
|
continue
|
|
alias_set = getattr(annotation, "alias_set", ())
|
|
if alias_set:
|
|
return True
|
|
return False
|
|
|
|
def is_supported(g: NativeFunctionsGroup) -> bool:
|
|
if not g.structured:
|
|
return False
|
|
if config.is_hand_written(g):
|
|
return False
|
|
if has_alias(g.out.func.arguments.non_out):
|
|
# This op may create an alias of inputs.
|
|
return False
|
|
if len(g.out.func.arguments.out) > 1:
|
|
# More than 1 output values.
|
|
return False
|
|
if "at::Tensor &" != cpp.returns_type(g.out.func.returns).cpp_type():
|
|
# Returns a non-Tensor value.
|
|
return False
|
|
for arg in g.out.func.schema_order_arguments():
|
|
maybe_method = ivalue_type_conversion_method(arg.type)
|
|
if not maybe_method:
|
|
# Type converting is unsupported yet.
|
|
return False
|
|
return True
|
|
|
|
def ivalue_type_conversion_method(arg_type: Union[BaseType, OptionalType, Type]) -> Optional[Tuple[bool, str]]:
|
|
"""
|
|
Return the method call expression of `c10::ivalue' to convert its contained value to
|
|
the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
|
|
this function returns ".toTensor()", so that it can be appended to the ivalue's
|
|
variable name to get the value of the expected type.
|
|
"""
|
|
type_conversion_methods = {
|
|
BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
|
|
BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
|
|
BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
|
|
BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
|
|
BaseTy.ScalarType: ((False, "toScalarType()"), (False, "toOptional<at::ScalarType>()")),
|
|
BaseTy.str: ((False, "toStringView()"), (False, "toOptional<c10::string_view>()"))}
|
|
|
|
base_ty_object = None
|
|
if isinstance(arg_type, BaseType):
|
|
base_ty_object = arg_type.name
|
|
elif isinstance(arg_type, OptionalType):
|
|
assert isinstance(arg_type.elem, BaseType)
|
|
base_ty_object = arg_type.elem.name
|
|
else:
|
|
return None
|
|
|
|
if base_ty_object not in type_conversion_methods:
|
|
return None
|
|
methods = type_conversion_methods[base_ty_object]
|
|
if isinstance(arg_type, BaseType):
|
|
return methods[0]
|
|
return methods[1]
|
|
|
|
should_use_int_tensor_ops_ = frozenset(("bitwise_not", "bitwise_and", "bitwise_or", "bitwise_xor", "gcd",
|
|
"lcm", "scatter", "gather", "_convert_indices_from_coo_to_csr",
|
|
"_convert_indices_from_csr_to_coo"))
|
|
|
|
def should_use_int_tensor(op_name: str) -> bool:
|
|
return op_name in should_use_int_tensor_ops_
|
|
|
|
test_tensor_dim_ops_1_ = frozenset(("addmv", "index_add", "_convert_indices_from_coo_to_csr",
|
|
"_convert_indices_from_csr_to_coo", "nll_loss_backward"))
|
|
test_tensor_dim_ops_2_ = frozenset(("addmm", "mm"))
|
|
|
|
def test_tensor_dim(op_name: str) -> int:
|
|
if op_name in test_tensor_dim_ops_1_:
|
|
return 1
|
|
if op_name in test_tensor_dim_ops_2_:
|
|
return 2
|
|
return 3
|
|
|
|
def test_value_expression(arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str) -> str:
|
|
num_tensors = 16 if index == 0 else 64
|
|
num_dim = test_tensor_dim(op_name)
|
|
size_per_dim = math.ceil(num_tensors / float(num_dim))
|
|
size_per_dim += size_per_dim % 2
|
|
tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim))
|
|
if should_use_int_tensor(op_name):
|
|
tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
|
|
else:
|
|
tensor_expression = f"at::rand({tensor_size_ex})"
|
|
|
|
value_expressions = {
|
|
BaseTy.Tensor: tensor_expression,
|
|
BaseTy.int: "1",
|
|
BaseTy.bool: "false",
|
|
BaseTy.Scalar: "2",
|
|
BaseTy.ScalarType: "at::ScalarType::Float",
|
|
BaseTy.str: "\"floor\""}
|
|
|
|
base_ty_object = None
|
|
if isinstance(arg_type, BaseType):
|
|
base_ty_object = arg_type.name
|
|
else:
|
|
assert isinstance(arg_type, OptionalType) and isinstance(arg_type.elem, BaseType)
|
|
base_ty_object = arg_type.elem.name
|
|
assert base_ty_object in value_expressions, "not expected type"
|
|
value_expression = value_expressions[base_ty_object]
|
|
return value_expression
|
|
|
|
def generate_test_value_definitions(g: NativeFunctionsGroup, index: int) -> str:
|
|
schema = g.functional.func
|
|
assert not schema.is_out_fn()
|
|
schema_name = schema.name.name.base
|
|
arg_map = {}
|
|
for arg in schema.schema_order_arguments():
|
|
test_value_exp = test_value_expression(arg.type, index, schema_name)
|
|
arg_map[arg.name] = test_value_exp
|
|
config.override_test_values(arg_map, schema_name, index)
|
|
arg_populations = []
|
|
for arg_name, arg_value in arg_map.items():
|
|
arg_populations.append(f'auto {arg_name}{index} = {arg_value}')
|
|
return ";\n ".join(arg_populations) + ";"
|
|
|
|
def generate_test_value_names(g: NativeFunctionsGroup, index: int) -> str:
|
|
schema = g.functional.func
|
|
assert not schema.is_out_fn()
|
|
return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
|
|
|
|
generate_test_ir_arguments_base_ty_to_type_str_ = {
|
|
BaseTy.Tensor: 'Tensor', BaseTy.int: 'int', BaseTy.float: 'float',
|
|
BaseTy.str: 'str', BaseTy.Scalar: 'int', BaseTy.ScalarType: 'int',
|
|
BaseTy.bool: 'bool'}
|
|
|
|
def generate_test_ir_arguments(g: NativeFunctionsGroup) -> List[Tuple[str, Optional[str]]]:
|
|
def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
|
|
t = arg.type
|
|
add_optional = False
|
|
if isinstance(t, OptionalType):
|
|
t = t.elem
|
|
add_optional = True
|
|
assert isinstance(t, BaseType)
|
|
type_str = None
|
|
if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
|
|
type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
|
|
if type_str and add_optional:
|
|
type_str = f'{type_str}?'
|
|
return ("%" + arg.name, type_str)
|
|
|
|
schema = g.functional.func
|
|
assert not schema.is_out_fn()
|
|
return [ir_argument(arg) for arg in schema.schema_order_arguments()]
|
|
|
|
def generate_arg_extraction(g: NativeFunctionsGroup) -> str:
|
|
schema = g.functional.func
|
|
assert not schema.is_out_fn()
|
|
arg_populations = []
|
|
for i, arg in enumerate(schema.schema_order_arguments()):
|
|
maybe_method = ivalue_type_conversion_method(arg.type)
|
|
assert maybe_method
|
|
is_reference, type_conversion_method = maybe_method
|
|
reference = "&" if is_reference else ""
|
|
arg_populations.append(f'const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}')
|
|
return ";\n ".join(arg_populations) + ";"
|
|
|
|
def generate_non_out_variant_call(g: NativeFunctionsGroup) -> str:
|
|
schema = g.functional.func
|
|
assert not schema.is_out_fn()
|
|
arg_names = (arg.name for arg in schema.schema_order_arguments())
|
|
return f'at::cpu::{cpp.name(schema)}({",".join(arg_names)})'
|
|
|
|
def generate_out_variant_call(g: NativeFunctionsGroup) -> str:
|
|
schema = g.out.func
|
|
assert schema.is_out_fn()
|
|
arg_names = [out_arg.name for out_arg in schema.arguments.out]
|
|
for arg in schema.arguments.non_out:
|
|
if isinstance(arg, SelfArgument):
|
|
arg_names.append(arg.argument.name)
|
|
else:
|
|
assert isinstance(arg, Argument)
|
|
arg_names.append(arg.name)
|
|
cpp_func_name = cpp.name(schema)
|
|
cpp_arg_names = ",".join(arg_names)
|
|
return f'at::cpu::{cpp_func_name}({cpp_arg_names})'
|
|
|
|
|
|
def should_check_resize(schema: FunctionSchema) -> bool:
|
|
schema_str = str(schema)
|
|
type_variant_op_name = schema_str[: schema_str.find("(")]
|
|
return type_variant_op_name not in ("isin.Scalar_Tensor", "index_add")
|
|
|
|
|
|
def op_name_from_group(g: NativeFunctionsGroup) -> str:
|
|
return g.functional.func.name.name.base
|
|
|
|
|
|
class GenOutVariantDispatcher:
|
|
def __call__(self, groups: Sequence[NativeFunctionsGroup]) -> str:
|
|
if not groups:
|
|
return ""
|
|
generated_type_variants = []
|
|
for g in groups:
|
|
with native_function_manager(g):
|
|
assert is_supported(g)
|
|
assert isinstance(g, NativeFunctionsGroup)
|
|
generated_type_variant = self.gen_structured(g)
|
|
generated_type_variants.append(generated_type_variant)
|
|
op_name = op_name_from_group(groups[0])
|
|
body = "\n".join(generated_type_variants)
|
|
generated = f"""
|
|
REGISTER_OPERATOR_FUNCTOR(
|
|
aten::{op_name},
|
|
aten_{op_name},
|
|
[](Node* n) -> SROperator {{
|
|
{body}
|
|
LogAndDumpSchema(n);
|
|
return nullptr;
|
|
}});
|
|
"""
|
|
return generated
|
|
|
|
def gen_structured(self, g: NativeFunctionsGroup) -> str:
|
|
functional = g.functional
|
|
schema = str(functional.func)
|
|
op_name = op_name_from_group(g)
|
|
populated_argument = generate_arg_extraction(g)
|
|
functional_variant_call = generate_non_out_variant_call(g)
|
|
assert len(g.out.func.arguments.out) == 1
|
|
out_variable_name = str(g.out.func.arguments.out[0].name)
|
|
out_variant_call = generate_out_variant_call(g)
|
|
generated = f"""
|
|
if (n->matches(torch::schema("aten::{schema}"))) {{
|
|
return [](ProcessedNode* p_node) {{
|
|
{populated_argument}
|
|
if (p_node->Output(0).isNone()) {{
|
|
p_node->Output(0) = {functional_variant_call};
|
|
return;
|
|
}}
|
|
auto& {out_variable_name} = p_node->Output(0).toTensor();
|
|
fastResizeToZero({out_variable_name});
|
|
{out_variant_call};
|
|
}};
|
|
}}"""
|
|
return generated
|
|
|
|
|
|
class GenOutVariantDispatcherTestCase:
|
|
def __call__(self, groups: Sequence[NativeFunctionsGroup]) -> str:
|
|
if not groups:
|
|
return ""
|
|
generated_type_variants = []
|
|
for g in groups:
|
|
with native_function_manager(g):
|
|
assert is_supported(g)
|
|
assert isinstance(g, NativeFunctionsGroup)
|
|
generated_type_variant = self.gen_structured_test_case(g)
|
|
generated_type_variants.append(generated_type_variant)
|
|
return "\n".join(generated_type_variants)
|
|
|
|
def gen_structured_test_case(self, g: NativeFunctionsGroup) -> str:
|
|
functional = g.functional
|
|
schema = str(functional.func)
|
|
assert schema.find("(") > 0
|
|
type_variant_op_name = schema[: schema.find("(")].replace(".", "_")
|
|
op_name = op_name_from_group(g)
|
|
assert type_variant_op_name.startswith(op_name)
|
|
|
|
arg_types = generate_test_ir_arguments(g)
|
|
arg_declarations = ", ".join((arg_name if arg_type is None
|
|
else f"{arg_name}: {arg_type}"
|
|
for arg_name, arg_type in arg_types))
|
|
arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
|
|
assert (len(functional.func.returns) == 1 and isinstance(functional.func.returns[0].type, BaseType) and
|
|
functional.func.returns[0].type.name is BaseTy.Tensor)
|
|
test_value_definitions = generate_test_value_definitions(g, 0)
|
|
test_value_names = generate_test_value_names(g, 0)
|
|
test_value_definitions2 = generate_test_value_definitions(g, 1)
|
|
test_value_names2 = generate_test_value_names(g, 1)
|
|
check_resize = "true" if should_check_resize(functional.func) else "false"
|
|
generated = f"""
|
|
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
|
|
const std::string script = R"IR(
|
|
graph({arg_declarations}):
|
|
%bias: None = prim::Constant()
|
|
%ret = aten::{op_name}({arg_names})
|
|
%cloned = aten::clone(%ret, %bias)
|
|
return (%cloned)
|
|
)IR";
|
|
|
|
{test_value_definitions}
|
|
std::vector<IValue> args{{{test_value_names}}};
|
|
testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
|
|
|
|
{test_value_definitions2}
|
|
std::vector<IValue> args2{{{test_value_names2}}};
|
|
testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
|
|
|
|
}}
|
|
"""
|
|
return generated
|