Back out "[AOTI] Always use oss schema for ExternKernelNodes serialization" (#151026)

Summary: Revert because the original change breaks forward compatibility (FC).

Test Plan: CI

Differential Revision: D72802075

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151026
Approved by: https://github.com/hl475
This commit is contained in:
Yiming Zhou 2025-04-10 22:36:31 +00:00 committed by PyTorch MergeBot
parent f304483e95
commit dbcd0b571d
13 changed files with 65 additions and 70 deletions

View File

@ -58,6 +58,15 @@ class AOTIRunnerUtil:
restore_fqn=False,
)
if IS_FBCODE:
from deeplearning.aot_inductor.extern_node_thrift_serializer import (
thrift_serializer,
)
if options is None:
options = {}
options["extern_node_serializer"] = thrift_serializer
with torch.no_grad():
so_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type]

View File

@ -275,8 +275,7 @@ class TestTorchbind(TestCase):
"is_hop_single_tensor_return": None,
},
},
],
"protocol": "json",
]
},
)

View File

@ -0,0 +1,14 @@
from dataclasses import dataclass
from torch._export.serde.schema import Node
@dataclass
class ExternKernelNode:
name: str
node: Node
@dataclass
class ExternKernelNodes:
nodes: list[ExternKernelNode]

View File

@ -1,5 +1,5 @@
// @generated by update_schema.py
// checksum<<3a8a6be8158821263b71ad9018c921664cd32c2f9b4deeac119e2292d186a02b>>
// checksum<<f36968728ea96d9629b7c5269f5303e5cf23fba341d0221cb364aaf571b94dd6>>
namespace py3 torch._export
namespace cpp2 torch._export.schema
@ -358,5 +358,4 @@ struct ExternKernelNode {
struct ExternKernelNodes {
10: list<ExternKernelNode> nodes;
20: optional string protocol;
}

View File

@ -8,7 +8,7 @@ from typing import Annotated, Optional
from torch._export.serde.union import _Union
# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (8, 8)
SCHEMA_VERSION = (8, 7)
TREESPEC_VERSION = 1
@ -484,4 +484,3 @@ class ExternKernelNode:
@dataclass
class ExternKernelNodes:
nodes: Annotated[list[ExternKernelNode], 10]
protocol: Annotated[Optional[str], 20] = None

View File

@ -1,5 +1,5 @@
# @generated by update_schema.py
# checksum<<9ce65dfb56cd253e43e4f529501c8158869aaf36048f8849fde36713c2039a57>>
# checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>>
AOTInductorModelPickleData:
kind: struct
fields:
@ -141,9 +141,6 @@ ExternKernelNodes:
fields:
nodes:
type: List[ExternKernelNode]
protocol:
type: Optional[str]
default: None
GradientToParameterSpec:
kind: struct
fields:
@ -533,5 +530,5 @@ UserOutputSpec:
type: Argument
SCHEMA_VERSION:
- 8
- 8
- 7
TREESPEC_VERSION: 1

View File

@ -1,6 +1,6 @@
import json
from torch._export.serde.schema import ExternKernelNode, ExternKernelNodes, Node
from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode
@ -19,7 +19,6 @@ def extern_node_json_serializer(
extern_kernel_nodes: list[inductor_ExternKernelNode],
) -> str:
serialized_nodes = ExternKernelNodes(
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes],
protocol="json",
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]
)
return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder, indent=2)
return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder)

View File

@ -366,7 +366,9 @@ class GraphLowering(torch.fx.Interpreter):
from torch._inductor.extern_node_serializer import extern_node_json_serializer
self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = (
extern_node_json_serializer
extern_node_serializer
if config.is_fbcode() and extern_node_serializer
else extern_node_json_serializer
)
self.current_node: torch.fx.Node = None # type: ignore[assignment]

View File

@ -112,7 +112,7 @@ consider rebuild your model with the latest AOTInductor.");
if (file_exists(json_filename)) {
proxy_executor_ = std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
json_filename, device_str);
json_filename, device_str == "cpu");
proxy_executor_handle_ =
reinterpret_cast<AOTIProxyExecutorHandle>(proxy_executor_.get());
} else {

View File

@ -18,19 +18,6 @@ bool has_key(
return map.find(key) != map.end();
}
c10::Device normalize_device(const c10::Device& device) {
// cpu device doesn't have an index
// cuda device must have an index
if (device.is_cpu()) {
return c10::Device(c10::DeviceType::CPU);
} else if (device.is_cuda()) {
return c10::Device(
c10::DeviceType::CUDA, device.has_index() ? device.index() : 0);
} else {
TORCH_CHECK(false, "Unsupported device type", device);
}
}
#ifdef _WIN32
const std::string k_separator = "\\";
#else
@ -224,11 +211,12 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
serialized_arg_val["index"].is_number()) {
auto index = serialized_arg_val["index"].get<int>();
device_string += ":" + std::to_string(index);
device_->set_index(static_cast<int8_t>(index));
}
c10::Device device(device_string);
if (device != *device_) {
if (device.type() != device_->type()) {
VLOG(1) << "ProxyExecutor is using " << *device_ << " for "
<< op_kernel->target_ << " argument #" << index
<< ", which is different from the one serialized in thrift: "
@ -591,12 +579,15 @@ std::unique_ptr<OSSCallTorchBindKernel> OSSProxyExecutor::
OSSProxyExecutor::OSSProxyExecutor(
const std::string& json_path,
const std::string& device_str,
bool is_cpu,
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs) {
// CUDA device must have an index as a kernel may require
// an explicit device index. e.g., merge_pooled_embeddings
c10::Device normalized_device = normalize_device(c10::Device(device_str));
device_ = std::make_unique<c10::Device>(normalized_device);
if (is_cpu) {
device_ = std::make_unique<c10::Device>(c10::DeviceType::CPU);
} else {
int device_idx = -1;
device_ = std::make_unique<c10::Device>(c10::DeviceType::CUDA, device_idx);
}
// If custom_objs is provided, use it instead of loading from
// custom_objs_config.json If custom_objs is not provided, try to load from
// custom_objs_config.json
@ -626,7 +617,7 @@ OSSProxyExecutor::OSSProxyExecutor(
for (auto& [customObjName, file_name] : custom_objs_json.items()) {
std::string customObjPath =
folder_path + k_separator + file_name.get<std::string>();
LOG(INFO) << "Loading custom object to OSSProxyExecutor from: "
LOG(INFO) << "Loading custom object to FbProxyExecutor from: "
<< customObjPath;
std::ifstream custom_obj_file(customObjPath, std::ios::binary);

View File

@ -12,11 +12,26 @@
namespace torch::aot_inductor {
enum class DynamicArgType : int {
TensorType = 0,
ListTensorType = 1,
ListOptionalTensorType = 2,
IntType = 3,
ListIntType = 4,
NoneType = 5,
};
inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) {
os << static_cast<int>(arg_type);
return os;
}
inline bool isTensorType(DynamicArgType arg_type) {
return arg_type == DynamicArgType::TensorType ||
arg_type == DynamicArgType::ListTensorType ||
arg_type == DynamicArgType::ListOptionalTensorType;
}
struct OSSDynamicArg {
OSSDynamicArg(
int arg_index,
@ -103,7 +118,7 @@ class OSSProxyExecutor : public ProxyExecutor {
public:
explicit OSSProxyExecutor(
const std::string& json_path,
const std::string& device_str,
bool is_cpu,
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs =
std::nullopt);

View File

@ -6,21 +6,6 @@
namespace torch::aot_inductor {
enum DynamicArgType : int {
TensorType = 0,
ListTensorType = 1,
ListOptionalTensorType = 2,
IntType = 3,
ListIntType = 4,
NoneType = 5,
};
inline bool isTensorType(DynamicArgType arg_type) {
return arg_type == DynamicArgType::TensorType ||
arg_type == DynamicArgType::ListTensorType ||
arg_type == DynamicArgType::ListOptionalTensorType;
}
class ProxyExecutor {
public:
ProxyExecutor() = default;

View File

@ -1,5 +1,5 @@
// @generated by update_schema.py
// checksum<<9ce65dfb56cd253e43e4f529501c8158869aaf36048f8849fde36713c2039a57>>
// checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>>
// clang-format off
#pragma once
@ -54,9 +54,9 @@ class ForwardRef {
public:
ForwardRef(): ptr_(std::make_unique<T>()) {}
ForwardRef(ForwardRef<T>&&);
ForwardRef(ForwardRef<T>&&) = default;
ForwardRef(const ForwardRef<T>& other): ptr_(std::make_unique<T>(*other.ptr_)) {}
ForwardRef<T>& operator=(ForwardRef<T>&&);
ForwardRef<T>& operator=(ForwardRef<T>&&) = default;
ForwardRef<T>& operator=(const ForwardRef<T>& other) {
ptr_ = std::make_unique<T>(*other.ptr_);
return *this;
@ -3216,7 +3216,6 @@ class ExternKernelNode {
class ExternKernelNodes {
private:
std::vector<ExternKernelNode> nodes;
std::optional<std::string> protocol = std::nullopt;
public:
@ -3228,14 +3227,6 @@ class ExternKernelNodes {
nodes = std::move(def);
}
const std::optional<std::string>& get_protocol() const {
return protocol;
}
void set_protocol(std::optional<std::string> def) {
protocol = std::move(def);
}
friend void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t);
friend void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t);
};
@ -3324,13 +3315,11 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNode& n
inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t) {
nlohmann_json_j["nodes"] = nlohmann_json_t.nodes;
nlohmann_json_j["protocol"] = nlohmann_json_t.protocol;
}
inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t) {
ExternKernelNodes nlohmann_json_default_obj;
nlohmann_json_t.nodes = nlohmann_json_j.value("nodes", nlohmann_json_default_obj.nodes);
nlohmann_json_t.protocol = nlohmann_json_j.value("protocol", nlohmann_json_default_obj.protocol);
}
inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t) {
@ -3699,9 +3688,6 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlo
nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg);
}
template <typename T> ForwardRef<T>::ForwardRef(ForwardRef<T>&&) = default;
template <typename T> ForwardRef<T>& ForwardRef<T>::operator=(ForwardRef<T>&&) = default;
} // namespace _export
} // namespace torch