mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Back out "[AOTI] Always use oss schema for ExternKernelNodes serialization" (#151026)
Summary: Revert for FC breaking Test Plan: CI Differential Revision: D72802075 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151026 Approved by: https://github.com/hl475
This commit is contained in:
parent
f304483e95
commit
dbcd0b571d
|
|
@ -58,6 +58,15 @@ class AOTIRunnerUtil:
|
|||
restore_fqn=False,
|
||||
)
|
||||
|
||||
if IS_FBCODE:
|
||||
from deeplearning.aot_inductor.extern_node_thrift_serializer import (
|
||||
thrift_serializer,
|
||||
)
|
||||
|
||||
if options is None:
|
||||
options = {}
|
||||
options["extern_node_serializer"] = thrift_serializer
|
||||
|
||||
with torch.no_grad():
|
||||
so_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type]
|
||||
|
||||
|
|
|
|||
|
|
@ -275,8 +275,7 @@ class TestTorchbind(TestCase):
|
|||
"is_hop_single_tensor_return": None,
|
||||
},
|
||||
},
|
||||
],
|
||||
"protocol": "json",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
14
torch/_export/serde/aoti_schema.py
Normal file
14
torch/_export/serde/aoti_schema.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
from torch._export.serde.schema import Node
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExternKernelNode:
|
||||
name: str
|
||||
node: Node
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExternKernelNodes:
|
||||
nodes: list[ExternKernelNode]
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
// @generated by update_schema.py
|
||||
// checksum<<3a8a6be8158821263b71ad9018c921664cd32c2f9b4deeac119e2292d186a02b>>
|
||||
// checksum<<f36968728ea96d9629b7c5269f5303e5cf23fba341d0221cb364aaf571b94dd6>>
|
||||
|
||||
namespace py3 torch._export
|
||||
namespace cpp2 torch._export.schema
|
||||
|
|
@ -358,5 +358,4 @@ struct ExternKernelNode {
|
|||
|
||||
struct ExternKernelNodes {
|
||||
10: list<ExternKernelNode> nodes;
|
||||
20: optional string protocol;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Annotated, Optional
|
|||
from torch._export.serde.union import _Union
|
||||
|
||||
# NOTE: Please update this value if any modifications are made to the schema
|
||||
SCHEMA_VERSION = (8, 8)
|
||||
SCHEMA_VERSION = (8, 7)
|
||||
TREESPEC_VERSION = 1
|
||||
|
||||
|
||||
|
|
@ -484,4 +484,3 @@ class ExternKernelNode:
|
|||
@dataclass
|
||||
class ExternKernelNodes:
|
||||
nodes: Annotated[list[ExternKernelNode], 10]
|
||||
protocol: Annotated[Optional[str], 20] = None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# @generated by update_schema.py
|
||||
# checksum<<9ce65dfb56cd253e43e4f529501c8158869aaf36048f8849fde36713c2039a57>>
|
||||
# checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>>
|
||||
AOTInductorModelPickleData:
|
||||
kind: struct
|
||||
fields:
|
||||
|
|
@ -141,9 +141,6 @@ ExternKernelNodes:
|
|||
fields:
|
||||
nodes:
|
||||
type: List[ExternKernelNode]
|
||||
protocol:
|
||||
type: Optional[str]
|
||||
default: None
|
||||
GradientToParameterSpec:
|
||||
kind: struct
|
||||
fields:
|
||||
|
|
@ -533,5 +530,5 @@ UserOutputSpec:
|
|||
type: Argument
|
||||
SCHEMA_VERSION:
|
||||
- 8
|
||||
- 8
|
||||
- 7
|
||||
TREESPEC_VERSION: 1
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
|
||||
from torch._export.serde.schema import ExternKernelNode, ExternKernelNodes, Node
|
||||
from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
|
||||
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
|
||||
from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode
|
||||
|
||||
|
|
@ -19,7 +19,6 @@ def extern_node_json_serializer(
|
|||
extern_kernel_nodes: list[inductor_ExternKernelNode],
|
||||
) -> str:
|
||||
serialized_nodes = ExternKernelNodes(
|
||||
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes],
|
||||
protocol="json",
|
||||
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]
|
||||
)
|
||||
return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder, indent=2)
|
||||
return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder)
|
||||
|
|
|
|||
|
|
@ -366,7 +366,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
from torch._inductor.extern_node_serializer import extern_node_json_serializer
|
||||
|
||||
self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = (
|
||||
extern_node_json_serializer
|
||||
extern_node_serializer
|
||||
if config.is_fbcode() and extern_node_serializer
|
||||
else extern_node_json_serializer
|
||||
)
|
||||
|
||||
self.current_node: torch.fx.Node = None # type: ignore[assignment]
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ consider rebuild your model with the latest AOTInductor.");
|
|||
|
||||
if (file_exists(json_filename)) {
|
||||
proxy_executor_ = std::make_unique<torch::aot_inductor::OSSProxyExecutor>(
|
||||
json_filename, device_str);
|
||||
json_filename, device_str == "cpu");
|
||||
proxy_executor_handle_ =
|
||||
reinterpret_cast<AOTIProxyExecutorHandle>(proxy_executor_.get());
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -18,19 +18,6 @@ bool has_key(
|
|||
return map.find(key) != map.end();
|
||||
}
|
||||
|
||||
c10::Device normalize_device(const c10::Device& device) {
|
||||
// cpu device doesn't have an index
|
||||
// cuda device must have an index
|
||||
if (device.is_cpu()) {
|
||||
return c10::Device(c10::DeviceType::CPU);
|
||||
} else if (device.is_cuda()) {
|
||||
return c10::Device(
|
||||
c10::DeviceType::CUDA, device.has_index() ? device.index() : 0);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported device type", device);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
const std::string k_separator = "\\";
|
||||
#else
|
||||
|
|
@ -224,11 +211,12 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments(
|
|||
serialized_arg_val["index"].is_number()) {
|
||||
auto index = serialized_arg_val["index"].get<int>();
|
||||
device_string += ":" + std::to_string(index);
|
||||
device_->set_index(static_cast<int8_t>(index));
|
||||
}
|
||||
|
||||
c10::Device device(device_string);
|
||||
|
||||
if (device != *device_) {
|
||||
if (device.type() != device_->type()) {
|
||||
VLOG(1) << "ProxyExecutor is using " << *device_ << " for "
|
||||
<< op_kernel->target_ << " argument #" << index
|
||||
<< ", which is different from the one serialized in thrift: "
|
||||
|
|
@ -591,12 +579,15 @@ std::unique_ptr<OSSCallTorchBindKernel> OSSProxyExecutor::
|
|||
|
||||
OSSProxyExecutor::OSSProxyExecutor(
|
||||
const std::string& json_path,
|
||||
const std::string& device_str,
|
||||
bool is_cpu,
|
||||
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs) {
|
||||
// CUDA device must have an index as a kernel may require
|
||||
// an explicit device index. e.g., merge_pooled_embeddings
|
||||
c10::Device normalized_device = normalize_device(c10::Device(device_str));
|
||||
device_ = std::make_unique<c10::Device>(normalized_device);
|
||||
if (is_cpu) {
|
||||
device_ = std::make_unique<c10::Device>(c10::DeviceType::CPU);
|
||||
} else {
|
||||
int device_idx = -1;
|
||||
device_ = std::make_unique<c10::Device>(c10::DeviceType::CUDA, device_idx);
|
||||
}
|
||||
|
||||
// If custom_objs is provided, use it instead of loading from
|
||||
// custom_objs_config.json If custom_objs is not provided, try to load from
|
||||
// custom_objs_config.json
|
||||
|
|
@ -626,7 +617,7 @@ OSSProxyExecutor::OSSProxyExecutor(
|
|||
for (auto& [customObjName, file_name] : custom_objs_json.items()) {
|
||||
std::string customObjPath =
|
||||
folder_path + k_separator + file_name.get<std::string>();
|
||||
LOG(INFO) << "Loading custom object to OSSProxyExecutor from: "
|
||||
LOG(INFO) << "Loading custom object to FbProxyExecutor from: "
|
||||
<< customObjPath;
|
||||
|
||||
std::ifstream custom_obj_file(customObjPath, std::ios::binary);
|
||||
|
|
|
|||
|
|
@ -12,11 +12,26 @@
|
|||
|
||||
namespace torch::aot_inductor {
|
||||
|
||||
enum class DynamicArgType : int {
|
||||
TensorType = 0,
|
||||
ListTensorType = 1,
|
||||
ListOptionalTensorType = 2,
|
||||
IntType = 3,
|
||||
ListIntType = 4,
|
||||
NoneType = 5,
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) {
|
||||
os << static_cast<int>(arg_type);
|
||||
return os;
|
||||
}
|
||||
|
||||
inline bool isTensorType(DynamicArgType arg_type) {
|
||||
return arg_type == DynamicArgType::TensorType ||
|
||||
arg_type == DynamicArgType::ListTensorType ||
|
||||
arg_type == DynamicArgType::ListOptionalTensorType;
|
||||
}
|
||||
|
||||
struct OSSDynamicArg {
|
||||
OSSDynamicArg(
|
||||
int arg_index,
|
||||
|
|
@ -103,7 +118,7 @@ class OSSProxyExecutor : public ProxyExecutor {
|
|||
public:
|
||||
explicit OSSProxyExecutor(
|
||||
const std::string& json_path,
|
||||
const std::string& device_str,
|
||||
bool is_cpu,
|
||||
std::optional<std::unordered_map<std::string, c10::IValue>> custom_objs =
|
||||
std::nullopt);
|
||||
|
||||
|
|
|
|||
|
|
@ -6,21 +6,6 @@
|
|||
|
||||
namespace torch::aot_inductor {
|
||||
|
||||
enum DynamicArgType : int {
|
||||
TensorType = 0,
|
||||
ListTensorType = 1,
|
||||
ListOptionalTensorType = 2,
|
||||
IntType = 3,
|
||||
ListIntType = 4,
|
||||
NoneType = 5,
|
||||
};
|
||||
|
||||
inline bool isTensorType(DynamicArgType arg_type) {
|
||||
return arg_type == DynamicArgType::TensorType ||
|
||||
arg_type == DynamicArgType::ListTensorType ||
|
||||
arg_type == DynamicArgType::ListOptionalTensorType;
|
||||
}
|
||||
|
||||
class ProxyExecutor {
|
||||
public:
|
||||
ProxyExecutor() = default;
|
||||
|
|
|
|||
20
torch/csrc/utils/generated_serialization_types.h
generated
20
torch/csrc/utils/generated_serialization_types.h
generated
|
|
@ -1,5 +1,5 @@
|
|||
// @generated by update_schema.py
|
||||
// checksum<<9ce65dfb56cd253e43e4f529501c8158869aaf36048f8849fde36713c2039a57>>
|
||||
// checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>>
|
||||
// clang-format off
|
||||
|
||||
#pragma once
|
||||
|
|
@ -54,9 +54,9 @@ class ForwardRef {
|
|||
|
||||
public:
|
||||
ForwardRef(): ptr_(std::make_unique<T>()) {}
|
||||
ForwardRef(ForwardRef<T>&&);
|
||||
ForwardRef(ForwardRef<T>&&) = default;
|
||||
ForwardRef(const ForwardRef<T>& other): ptr_(std::make_unique<T>(*other.ptr_)) {}
|
||||
ForwardRef<T>& operator=(ForwardRef<T>&&);
|
||||
ForwardRef<T>& operator=(ForwardRef<T>&&) = default;
|
||||
ForwardRef<T>& operator=(const ForwardRef<T>& other) {
|
||||
ptr_ = std::make_unique<T>(*other.ptr_);
|
||||
return *this;
|
||||
|
|
@ -3216,7 +3216,6 @@ class ExternKernelNode {
|
|||
class ExternKernelNodes {
|
||||
private:
|
||||
std::vector<ExternKernelNode> nodes;
|
||||
std::optional<std::string> protocol = std::nullopt;
|
||||
|
||||
public:
|
||||
|
||||
|
|
@ -3228,14 +3227,6 @@ class ExternKernelNodes {
|
|||
nodes = std::move(def);
|
||||
}
|
||||
|
||||
const std::optional<std::string>& get_protocol() const {
|
||||
return protocol;
|
||||
}
|
||||
|
||||
void set_protocol(std::optional<std::string> def) {
|
||||
protocol = std::move(def);
|
||||
}
|
||||
|
||||
friend void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t);
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t);
|
||||
};
|
||||
|
|
@ -3324,13 +3315,11 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNode& n
|
|||
|
||||
inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t) {
|
||||
nlohmann_json_j["nodes"] = nlohmann_json_t.nodes;
|
||||
nlohmann_json_j["protocol"] = nlohmann_json_t.protocol;
|
||||
}
|
||||
|
||||
inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t) {
|
||||
ExternKernelNodes nlohmann_json_default_obj;
|
||||
nlohmann_json_t.nodes = nlohmann_json_j.value("nodes", nlohmann_json_default_obj.nodes);
|
||||
nlohmann_json_t.protocol = nlohmann_json_j.value("protocol", nlohmann_json_default_obj.protocol);
|
||||
}
|
||||
|
||||
inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t) {
|
||||
|
|
@ -3699,9 +3688,6 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlo
|
|||
nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg);
|
||||
}
|
||||
|
||||
|
||||
template <typename T> ForwardRef<T>::ForwardRef(ForwardRef<T>&&) = default;
|
||||
template <typename T> ForwardRef<T>& ForwardRef<T>::operator=(ForwardRef<T>&&) = default;
|
||||
} // namespace _export
|
||||
} // namespace torch
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user