Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[nativert] Move graph_passes to nativert (#155411)
Summary: Move graph_passes to nativert

Test Plan: CI

Rollback Plan:

Differential Revision: D76205048

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155411
Approved by: https://github.com/zhxchen17
This commit is contained in:
parent
338a8c7853
commit
9462106b7e
build_variables.bzl

@@ -592,6 +592,7 @@ libtorch_core_jit_sources = sorted(jit_sources_full)
 libtorch_nativert_sources = [
     "torch/nativert/graph/Graph.cpp",
+    "torch/nativert/graph/GraphPasses.cpp",
     "torch/nativert/graph/GraphSignature.cpp",
     "torch/nativert/graph/Serialization.cpp",
     "torch/nativert/graph/TensorMeta.cpp",
torch/nativert/graph/GraphPasses.cpp (180 lines, new file)
@@ -0,0 +1,180 @@
#include <torch/nativert/graph/GraphPasses.h>

#include <unordered_set>

#include <fmt/format.h>

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/function_schema.h>

#include <c10/util/StringUtil.h>

namespace torch::nativert {
namespace {
bool isScalar(const Constant& c) {
  return std::holds_alternative<int64_t>(c) ||
      std::holds_alternative<double>(c);
}

bool isScalar(const Value& v) {
  return v.type() == Type::Kind::SymInt || v.type() == Type::Kind::SymFloat;
}
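
// Returns true iff `node` can bind to `schema`: every node input or
// attribute whose matching schema argument is Tensor-typed must not be a
// scalar, and every schema argument without a default value must be
// supplied by the node.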
bool schemaTypeMatch(const c10::FunctionSchema& schema, const Node& node) {
  std::unordered_set<std::string> inputNames;
  for (const auto& input : node.inputs()) {
    // The number of arguments is always O(10), so we can just do a linear scan.
    for (const auto& schemaArg : schema.arguments()) {
      if (schemaArg.name() == input.name) {
        if (schemaArg.type() == c10::TensorType::get() && input.value &&
            isScalar(*input.value)) {
          return false;
        }
        break;
      }
    }
    inputNames.insert(input.name);
  }
  for (const auto& constant : node.attributes()) {
    for (const auto& schemaArg : schema.arguments()) {
      if (schemaArg.name() == constant.name) {
        if (schemaArg.type() == c10::TensorType::get() &&
            isScalar(constant.value)) {
          return false;
        }
        break;
      }
    }
    inputNames.insert(constant.name);
  }

  // Make sure we have all the required arguments.
  for (const auto& schemaArg : schema.arguments()) {
    if (!schemaArg.default_value()) {
      if (inputNames.find(schemaArg.name()) == inputNames.end()) {
        return false;
      }
    }
  }
  return true;
}

} // namespace

// PT2 intentionally broadcasts things like aten.sub.Scalar
// to aten.sub.Tensor. See https://github.com/pytorch/pytorch/issues/90923.
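// For example, PT2 export may emit a node like
//   %y = torch.ops.aten.mul.Tensor(self=%x, other=2.0)
// with a scalar constant bound to a Tensor-typed argument; this function
// returns "Scalar" for such a node so the caller can retarget it to
// torch.ops.aten.mul.Scalar. (The snippet above is illustrative, not an
// exact dump of the nativert textual IR.)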
std::string selectScalarOverloadName(const Node& node) {
  // Copied from torch::should_allow_numbers_as_tensors() in
  // torch/csrc/utils/python_arg_parser.cpp to work around some
  // linking issues.
  static std::unordered_set<std::string> allowed = {
      "add",
      "add_",
      "add_out",
      "div",
      "div_",
      "div_out",
      "divide",
      "divide_",
      "divide_out", // alias of div
      "mul",
      "mul_",
      "mul_out",
      "multiply",
      "multiply_",
      "multiply_out", // alias of mul
      "sub",
      "sub_",
      "sub_out",
      "subtract",
      "subtract_",
      "subtract_out", // alias of sub
      "true_divide",
      "true_divide_",
      "true_divide_out",
      "to",
      "_to_copy",
      "copy_",
      "copy",
      "floor_divide",
      "floor_divide_",
      "floor_divide_out",
      "_conj"};
  std::vector<std::string_view> atoms = c10::split(node.target(), '.');
  TORCH_CHECK_GE(atoms.size(), 3);

  std::string ns = std::string{atoms[atoms.size() - 3]};
  std::string opName = std::string{atoms[atoms.size() - 2]};
  std::string overloadName = std::string{atoms[atoms.size() - 1]};
  if (overloadName != "Tensor" && overloadName != "Tensor_Tensor" &&
      overloadName != "Tensor_mode") {
    return overloadName;
  }
  if (allowed.find(std::string{opName}) == allowed.end()) {
    return overloadName;
  }
  auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
      fmt::format("{}::{}", ns, opName.c_str()).c_str(), overloadName.c_str());
  if (schemaTypeMatch(op.schema(), node)) {
    return overloadName;
  }
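  // The node's scalar arguments don't type-check against the current
  // overload; probe the scalar variants and take the first whose schema
  // matches.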
  for (const auto& variant :
       {"Scalar_mode", "Scalar", "Scalar_Tensor", "Tensor_Scalar"}) {
    if (auto schema = c10::Dispatcher::singleton().findSchema(
            {fmt::format("{}::{}", ns, opName.c_str()).c_str(), variant})) {
      if (schemaTypeMatch(schema->schema(), node)) {
        return variant;
      }
    }
  }
  return overloadName;
}
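
// Graph pass: rewrite aten nodes to use scalar overloads where their
// arguments are scalars, recursing into subgraph attributes.
// aten.sub.Tensor with a scalar `self` is special-cased to
// aten.rsub.Scalar (with `self` and `other` swapped), since subtraction
// is not commutative.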
void selectScalarOverload(Graph* graph) {
  for (auto& node : graph->nodes()) {
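    // First rewrite any nested graphs held as attributes.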
    for (auto& attr : node.attributes()) {
      if (std::holds_alternative<std::unique_ptr<Graph>>(attr.value)) {
        selectScalarOverload(
            std::get<std::unique_ptr<Graph>>(attr.value).get());
      }
    }

    auto target = node.target();
    std::vector<std::string_view> atoms = c10::split(target, '.');

    size_t numAtoms = atoms.size();
    if (numAtoms != 5) {
      continue;
    }

    const std::string_view ns = atoms[numAtoms - 3];
    const std::string_view opName = atoms[numAtoms - 2];
    if (atoms[0] != "torch" || atoms[1] != "ops" || ns != "aten") {
      continue;
    }

    auto overloadName = selectScalarOverloadName(node);
    if (overloadName != atoms[numAtoms - 1]) {
      node.setTarget(
          fmt::format("torch.ops.{}.{}.{}", ns, opName, overloadName));
    } else if (ns == "aten" && opName == "sub" && overloadName == "Tensor") {
      // Special case for aten.sub.Tensor.
      if (auto i = node.tryGetInput("self")) {
        if (isScalar(*i->value)) {
          node.updateInputName("self", "other");
          node.updateInputName("other", "self");
          node.setTarget("torch.ops.aten.rsub.Scalar");
        }
      }
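      // `self` may also be bound as a constant attribute rather than a
      // graph input; handle that case symmetrically.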
      if (auto a = node.tryGetAttribute("self")) {
        if (isScalar(a->value)) {
          node.updateAttributeName("self", "other");
          node.updateInputName("other", "self");
          node.setTarget("torch.ops.aten.rsub.Scalar");
        }
      }
    }
  }
}

} // namespace torch::nativert
torch/nativert/graph/GraphPasses.h (11 lines, new file)

@@ -0,0 +1,11 @@
#pragma once

#include <torch/nativert/graph/Graph.h>

namespace torch::nativert {

void selectScalarOverload(torch::nativert::Graph* graph);

std::string selectScalarOverloadName(const torch::nativert::Node& node);

} // namespace torch::nativert
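
For context, a minimal sketch of a call site for the moved pass; the `prepareForExecution` wrapper below is hypothetical and not part of this diff:

#include <torch/nativert/graph/GraphPasses.h>

// Hypothetical pipeline step: run the overload-selection pass in place on a
// deserialized graph before execution, retargeting e.g.
// torch.ops.aten.mul.Tensor with a scalar operand to
// torch.ops.aten.mul.Scalar.
void prepareForExecution(torch::nativert::Graph* graph) {
  torch::nativert::selectScalarOverload(graph);
}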