diff --git a/build_variables.bzl b/build_variables.bzl
index 658aac7613a..288d0dd8546 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -592,6 +592,7 @@ libtorch_core_jit_sources = sorted(jit_sources_full)
 libtorch_nativert_sources = [
     "torch/nativert/graph/Graph.cpp",
+    "torch/nativert/graph/GraphPasses.cpp",
     "torch/nativert/graph/GraphSignature.cpp",
     "torch/nativert/graph/Serialization.cpp",
     "torch/nativert/graph/TensorMeta.cpp",
diff --git a/torch/nativert/graph/GraphPasses.cpp b/torch/nativert/graph/GraphPasses.cpp
new file mode 100644
index 00000000000..327f32185e9
--- /dev/null
+++ b/torch/nativert/graph/GraphPasses.cpp
@@ -0,0 +1,180 @@
+#include <torch/nativert/graph/GraphPasses.h>
+
+#include <unordered_set>
+
+#include <fmt/format.h>
+
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/util/StringUtil.h>
+
+#include <torch/nativert/graph/Graph.h>
+
+namespace torch::nativert {
+namespace {
+bool isScalar(const Constant& c) {
+  return std::holds_alternative<int64_t>(c) ||
+      std::holds_alternative<double>(c);
+}
+
+bool isScalar(const Value& v) {
+  return v.type() == Type::Kind::SymInt || v.type() == Type::Kind::SymFloat;
+}
+
+bool schemaTypeMatch(const c10::FunctionSchema& schema, const Node& node) {
+  std::unordered_set<std::string> inputNames;
+  for (const auto& input : node.inputs()) {
+    // The number of arguments is always O(10), so we can just do a linear
+    // scan.
+    for (const auto& schemaArg : schema.arguments()) {
+      if (schemaArg.name() == input.name) {
+        if (schemaArg.type() == c10::TensorType::get() && input.value &&
+            isScalar(*input.value)) {
+          return false;
+        }
+        break;
+      }
+    }
+    inputNames.insert(input.name);
+  }
+  for (const auto& constant : node.attributes()) {
+    for (const auto& schemaArg : schema.arguments()) {
+      if (schemaArg.name() == constant.name) {
+        if (schemaArg.type() == c10::TensorType::get() &&
+            isScalar(constant.value)) {
+          return false;
+        }
+        break;
+      }
+    }
+    inputNames.insert(constant.name);
+  }
+
+  // Make sure we have all the required arguments.
+  for (const auto& schemaArg : schema.arguments()) {
+    if (!schemaArg.default_value()) {
+      if (inputNames.find(schemaArg.name()) == inputNames.end()) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+} // namespace
+
+// PT2 intentionally broadcasts ops like aten.sub.Scalar
+// to aten.sub.Tensor. https://github.com/pytorch/pytorch/issues/90923.
+std::string selectScalarOverloadName(const Node& node) {
+  // Copied from torch/csrc/utils/python_arg_parser.cpp
+  // torch::should_allow_numbers_as_tensors() to work around
+  // some linking issues.
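+  //
+  // As a hypothetical illustration of the rewrite this allowlist enables:
+  // a node targeting torch.ops.aten.add.Tensor whose `other` argument is a
+  // constant scalar can be retargeted to torch.ops.aten.add.Scalar, because
+  // "add" appears below and the Scalar overload's schema matches the node.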
+  static std::unordered_set<std::string> allowed = {
+      "add",
+      "add_",
+      "add_out",
+      "div",
+      "div_",
+      "div_out",
+      "divide",
+      "divide_",
+      "divide_out", // alias of div
+      "mul",
+      "mul_",
+      "mul_out",
+      "multiply",
+      "multiply_",
+      "multiply_out", // alias of mul
+      "sub",
+      "sub_",
+      "sub_out",
+      "subtract",
+      "subtract_",
+      "subtract_out", // alias of sub
+      "true_divide",
+      "true_divide_",
+      "true_divide_out",
+      "to",
+      "_to_copy",
+      "copy_",
+      "copy",
+      "floor_divide",
+      "floor_divide_",
+      "floor_divide_out",
+      "_conj"};
+  std::vector<std::string_view> atoms = c10::split(node.target(), '.');
+  TORCH_CHECK_GE(atoms.size(), 3);
+
+  std::string ns = std::string{atoms[atoms.size() - 3]};
+  std::string opName = std::string{atoms[atoms.size() - 2]};
+  std::string overloadName = std::string{atoms[atoms.size() - 1]};
+  if (overloadName != "Tensor" && overloadName != "Tensor_Tensor" &&
+      overloadName != "Tensor_mode") {
+    return overloadName;
+  }
+  if (allowed.find(opName) == allowed.end()) {
+    return overloadName;
+  }
+  auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
+      fmt::format("{}::{}", ns, opName).c_str(), overloadName.c_str());
+  if (schemaTypeMatch(op.schema(), node)) {
+    return overloadName;
+  }
+  for (const auto& variant :
+       {"Scalar_mode", "Scalar", "Scalar_Tensor", "Tensor_Scalar"}) {
+    if (auto schema = c10::Dispatcher::singleton().findSchema(
+            {fmt::format("{}::{}", ns, opName), variant})) {
+      if (schemaTypeMatch(schema->schema(), node)) {
+        return variant;
+      }
+    }
+  }
+  return overloadName;
+}
+
+void selectScalarOverload(Graph* graph) {
+  for (auto& node : graph->nodes()) {
+    for (auto& attr : node.attributes()) {
+      if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
+        selectScalarOverload(
+            std::get<std::unique_ptr<Graph>>(attr.value).get());
+      }
+    }
+
+    auto target = node.target();
+    std::vector<std::string_view> atoms = c10::split(target, '.');
+
+    size_t numAtoms = atoms.size();
+    if (numAtoms != 5) {
+      continue;
+    }
+
+    const std::string_view ns = atoms[numAtoms - 3];
+    const std::string_view opName = atoms[numAtoms - 2];
+    if (atoms[0] != "torch" || atoms[1] != "ops" || ns != "aten") {
+      continue;
+    }
+
+    auto overloadName = selectScalarOverloadName(node);
+    if (overloadName != atoms[numAtoms - 1]) {
+      node.setTarget(
+          fmt::format("torch.ops.{}.{}.{}", ns, opName, overloadName));
+    } else if (ns == "aten" && opName == "sub" && overloadName == "Tensor") {
+      // Special case for aten.sub.Tensor: a scalar `self` means the node is
+      // really an rsub, so swap the operands and retarget to
+      // aten.rsub.Scalar.
+      if (auto i = node.tryGetInput("self")) {
+        if (isScalar(*i->value)) {
+          node.updateInputName("self", "other");
+          node.updateInputName("other", "self");
+          node.setTarget("torch.ops.aten.rsub.Scalar");
+        }
+      }
+      if (auto a = node.tryGetAttribute("self")) {
+        if (isScalar(a->value)) {
+          node.updateAttributeName("self", "other");
+          node.updateInputName("other", "self");
+          node.setTarget("torch.ops.aten.rsub.Scalar");
+        }
+      }
+    }
+  }
+}
+
+} // namespace torch::nativert
diff --git a/torch/nativert/graph/GraphPasses.h b/torch/nativert/graph/GraphPasses.h
new file mode 100644
index 00000000000..bbda9792a2a
--- /dev/null
+++ b/torch/nativert/graph/GraphPasses.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <torch/nativert/graph/Graph.h>
+
+namespace torch::nativert {
+
+void selectScalarOverload(torch::nativert::Graph* graph);
+
+std::string selectScalarOverloadName(const torch::nativert::Node& node);
+
+} // namespace torch::nativert
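
A minimal sketch of how a consumer could invoke the new pass (the call site
below is hypothetical; only selectScalarOverload() itself is added by this
diff):

    #include <torch/nativert/graph/GraphPasses.h>

    // After deserializing an exported graph, run the pass so that ops such
    // as torch.ops.aten.add.Tensor taking a constant scalar are retargeted
    // to their Scalar overloads before execution.
    void prepareGraph(torch::nativert::Graph* graph) {
      torch::nativert::selectScalarOverload(graph);
    }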