Move static from_ivalue/to_ivalue to new shim_common.cpp (#166373)
Move `from_ivalue` and `to_ivalue`, together with their dependents `StableIValueBoxedKernel`, `aoti_torch_library_impl`, and `aoti_torch_call_dispatcher`, into the new (non-AOTI) shim_common.cpp. This is in prep for the follow-up PRs in this stack, which add versioning-aware v2s (`torch_call_dispatcher` and `torch_library_impl`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166373
Approved by: https://github.com/janeyx99
ghstack dependencies: #164356
parent fefb546b91
commit c0bbda37e8
@@ -482,6 +482,7 @@ inductor_core_resources = [
     "torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp",
     "torch/csrc/inductor/inductor_ops.cpp",
     "torch/csrc/jit/serialization/pickle.cpp",
+    "torch/csrc/shim_common.cpp",
 ]

 libtorch_core_sources = sorted(
@@ -1406,169 +1406,6 @@ AOTITorchError aoti_torch_zero_(AtenTensorHandle tensor) {
   });
 }
 
-static StableIValue from_ivalue(
-    const c10::TypePtr& type,
-    const c10::IValue& ivalue) {
-  switch (type->kind()) {
-    case c10::TypeKind::TensorType: {
-      AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
-          std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
-      return torch::stable::detail::from(ath);
-    }
-    case c10::TypeKind::IntType: {
-      return torch::stable::detail::from(ivalue.toInt());
-    }
-    case c10::TypeKind::FloatType: {
-      return torch::stable::detail::from(ivalue.toDouble());
-    }
-    case c10::TypeKind::BoolType: {
-      return torch::stable::detail::from(ivalue.toBool());
-    }
-    case c10::TypeKind::ScalarTypeType: {
-      return torch::stable::detail::from(ivalue.toScalarType());
-    }
-    case c10::TypeKind::DeviceObjType: {
-      return torch::stable::detail::from(ivalue.toDevice());
-    }
-    case c10::TypeKind::LayoutType: {
-      return torch::stable::detail::from(ivalue.toLayout());
-    }
-    case c10::TypeKind::MemoryFormatType: {
-      return torch::stable::detail::from(ivalue.toMemoryFormat());
-    }
-    case c10::TypeKind::OptionalType: {
-      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
-
-      // ideally, if we had the C++ type corresponding to inner_type, which we
-      // will denote as inner_type::t (does not actually exist), we would be
-      // able to follow the patterned semantic of every other case here in one
-      // line:
-      //
-      // return
-      // torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
-      //
-      // BUT we do NOT have that type inner_type::t readily available, so we
-      // will manually unwrap and recursively call. This implementation MUST
-      // be kept in sync with torch::stable::detail::from<std::optional<T>>
-      // function in torch/csrc/stable/stableivalue_conversions.h
-      if (ivalue.isNone()) {
-        return torch::stable::detail::from(std::nullopt);
-      }
-      StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
-      return torch::stable::detail::from(sivp);
-    }
-    default: {
-      TORCH_CHECK(
-          false,
-          "Not yet supported conversion from IValue to StableIValue for schema type: ",
-          type->str());
-    }
-  }
-}
-
-static c10::IValue to_ivalue(
-    const c10::TypePtr& type,
-    const StableIValue stable_ivalue) {
-  switch (type->kind()) {
-    case c10::TypeKind::TensorType: {
-      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
-          torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
-      return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
-          ret_raiiath.get())));
-    }
-    case c10::TypeKind::IntType: {
-      return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
-    }
-    case c10::TypeKind::FloatType: {
-      return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
-    }
-    case c10::TypeKind::BoolType: {
-      return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
-    }
-    case c10::TypeKind::ScalarTypeType: {
-      return c10::IValue(
-          torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
-    }
-    case c10::TypeKind::DeviceObjType: {
-      return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
-    }
-    case c10::TypeKind::LayoutType: {
-      return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
-    }
-    case c10::TypeKind::MemoryFormatType: {
-      return c10::IValue(
-          torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
-    }
-    case c10::TypeKind::OptionalType: {
-      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
-
-      // ideally, if we had the C++ type corresponding to inner_type, which we
-      // will denote as inner_type::t (does not actually exist), we would be
-      // able to follow the patterned semantic of every other case here in one
-      // line:
-      //
-      // return
-      // c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
-      //
-      // BUT we do NOT have that type inner_type::t readily available, so we
-      // will manually unwrap and recursively call. This implementation MUST
-      // be kept in sync with the torch::stable::detail::to<T> function in
-      // torch/csrc/stable/stableivalue_conversions.h
-      if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
-        return c10::IValue();
-      }
-      auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
-      auto ival = to_ivalue(inner_type, *sivp);
-      delete sivp;
-      return ival;
-    }
-    default: {
-      TORCH_CHECK(
-          false,
-          "Not yet supported conversion from StableIValue to IValue for schema type: ",
-          type->str());
-    }
-  }
-}
-
-class StableIValueBoxedKernel : public c10::OperatorKernel {
- public:
-  StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
-      : fn_(fn) {}
-
-  void operator()(
-      const c10::OperatorHandle& op,
-      c10::DispatchKeySet keyset,
-      torch::jit::Stack* stack) {
-    const auto& schema = op.schema();
-    const auto num_returns = schema.returns().size();
-    const auto num_arguments = schema.arguments().size();
-
-    auto ministack =
-        std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
-
-    for (const auto idx : c10::irange(num_arguments)) {
-      const auto ministack_idx = num_arguments - idx - 1;
-      const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
-      ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
-    }
-
-    // boxed function is going to take a stack of StableIValues, cast them to
-    // our schema values, and run the function and modify the StableIValue stack
-    fn_(ministack.get(), num_arguments, num_returns);
-
-    // read the output from the end of the stack and wrap that back into
-    // IValue from StableIValue
-    for (size_t idx = 0; idx < num_returns; idx++) {
-      const c10::TypePtr& ret_type = schema.returns()[idx].type();
-      torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
-    }
-  }
-
- private:
-  void (*fn_)(StableIValue*, uint64_t, uint64_t);
-};
-
 AOTITorchError aoti_torch_library_init_impl(
     const char* ns,
     const char* k,
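For context on the one non-trivial case above: `std::optional` values cross the stable ABI boundary either as the sentinel `from(std::nullopt)` (for None) or as a heap-allocated `StableIValue*` that owns the recursively converted inner value, which the matching `to_ivalue` later frees. A minimal round-trip sketch under those semantics (illustrative only, not code from this PR):

  // Encode an optional int: from_ivalue heap-allocates the inner StableIValue.
  c10::TypePtr opt_int = c10::OptionalType::create(c10::IntType::get());
  StableIValue packed = from_ivalue(opt_int, c10::IValue(int64_t{42}));
  // Decode: to_ivalue unpacks the inner value and deletes the heap cell.
  c10::IValue back = to_ivalue(opt_int, packed);  // back.toInt() == 42
  // A None input (c10::IValue()) would instead round-trip via the sentinel
  // and come back as an empty c10::IValue.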
@@ -1618,18 +1455,6 @@ AOTITorchError aoti_torch_library_init_fragment(
   });
 }
 
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
-    TorchLibraryHandle self,
-    const char* name,
-    void (*fn)(StableIValue*, uint64_t, uint64_t)) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    reinterpret_cast<torch::Library*>(self)->impl(
-        name,
-        torch::CppFunction::makeFromBoxedFunctor(
-            std::make_unique<StableIValueBoxedKernel>(fn)));
-  });
-}
-
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_library_def(TorchLibraryHandle self, const char* name) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
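The function removed here (and re-added in shim_common.cpp below) is the entry point through which extensions register stable-ABI kernels. The registered callback receives its arguments as `StableIValue`s at indices `[0, num_args)` of the given buffer and must overwrite the front of that same buffer with its `num_returns` results. A hedged caller sketch (the `myops::identity` schema and `my_identity_kernel` are made up for illustration):

  // Hypothetical kernel for a schema like "myops::identity(Tensor t) -> Tensor".
  static void my_identity_kernel(
      StableIValue* stack, uint64_t num_args, uint64_t num_returns) {
    // stack[0] already holds the packed AtenTensorHandle argument; since the
    // return is the same tensor, the single return slot needs no change.
    (void)num_args;
    (void)num_returns;
  }

  // lib is a TorchLibraryHandle previously produced by aoti_torch_library_init_impl.
  aoti_torch_library_impl(lib, "identity", &my_identity_kernel);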
@@ -1642,40 +1467,6 @@ aoti_torch_delete_library_object(TorchLibraryHandle tlh) {
       { delete reinterpret_cast<torch::Library*>(tlh); });
 }
 
-AOTITorchError aoti_torch_call_dispatcher(
-    const char* opName,
-    const char* overloadName,
-    StableIValue* stack) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    const auto op =
-        c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
-    const auto& schema = op.schema();
-    const auto num_returns = schema.returns().size();
-    const auto num_arguments = schema.arguments().size();
-
-    torch::jit::Stack ivalue_stack;
-    // we will only need max(num_args, num_returns)
-    ivalue_stack.reserve(std::max(num_arguments, num_returns));
-
-    // convert StableIValue stack to c10::IValue stack
-    for (const auto idx : c10::irange(num_arguments)) {
-      auto stable_ivalue = stack[idx];
-      auto arg_type = schema.arguments()[idx].type();
-      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
-    }
-
-    op.callBoxed(ivalue_stack);
-
-    // there should then be num_returns IValues on the stack, which
-    // we will convert to StableIValue and repopulate user input stack
-    for (const auto idx : c10::irange(num_returns)) {
-      const auto stack_idx = num_returns - idx - 1;
-      const c10::TypePtr& ret_type = schema.returns()[idx].type();
-      stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
-    }
-  });
-}
-
 AOTITorchError aoti_torch_create_device_guard(
     int32_t device_index,
     DeviceGuardHandle* ret_guard // returns new reference
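`aoti_torch_call_dispatcher` follows the same in-place buffer protocol from the caller's side: pack the arguments into a `StableIValue` buffer sized `max(num_args, num_returns)`, call, then read the returns back out of the front of the buffer. A hedged usage sketch (assuming `self` is a valid `AtenTensorHandle`; note that the shim's `to_ivalue` adopts tensor handles, so the call consumes ownership of the input handle):

  // Call aten::abs (one argument, one return) through the stable shim.
  StableIValue stack[1];
  stack[0] = torch::stable::detail::from(self);
  AOTITorchError err = aoti_torch_call_dispatcher("aten::abs", "", stack);
  // On success, stack[0] now packs a new AtenTensorHandle for the result.
  AtenTensorHandle result = torch::stable::detail::to<AtenTensorHandle>(stack[0]);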
torch/csrc/shim_common.cpp (new file, 218 lines)
@@ -0,0 +1,218 @@
+#include <c10/core/DispatchKey.h>
+#include <c10/util/Exception.h>
+#include <torch/csrc/inductor/aoti_runtime/utils.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
+#include <torch/csrc/inductor/aoti_torch/utils.h>
+#include <torch/csrc/jit/serialization/pickle.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/library.h>
+
+static StableIValue from_ivalue(
+    const c10::TypePtr& type,
+    const c10::IValue& ivalue) {
+  switch (type->kind()) {
+    case c10::TypeKind::TensorType: {
+      AtenTensorHandle ath = torch::aot_inductor::new_tensor_handle(
+          std::move(const_cast<at::Tensor&>(ivalue.toTensor())));
+      return torch::stable::detail::from(ath);
+    }
+    case c10::TypeKind::IntType: {
+      return torch::stable::detail::from(ivalue.toInt());
+    }
+    case c10::TypeKind::FloatType: {
+      return torch::stable::detail::from(ivalue.toDouble());
+    }
+    case c10::TypeKind::BoolType: {
+      return torch::stable::detail::from(ivalue.toBool());
+    }
+    case c10::TypeKind::ScalarTypeType: {
+      return torch::stable::detail::from(ivalue.toScalarType());
+    }
+    case c10::TypeKind::DeviceObjType: {
+      return torch::stable::detail::from(ivalue.toDevice());
+    }
+    case c10::TypeKind::LayoutType: {
+      return torch::stable::detail::from(ivalue.toLayout());
+    }
+    case c10::TypeKind::MemoryFormatType: {
+      return torch::stable::detail::from(ivalue.toMemoryFormat());
+    }
+    case c10::TypeKind::OptionalType: {
+      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
+
+      // ideally, if we had the C++ type corresponding to inner_type, which we
+      // will denote as inner_type::t (does not actually exist), we would be
+      // able to follow the patterned semantic of every other case here in one
+      // line:
+      //
+      // return
+      // torch::stable::detail::from<std::optional<inner_type::t>>(ivalue.toInnerTypeT()));
+      //
+      // BUT we do NOT have that type inner_type::t readily available, so we
+      // will manually unwrap and recursively call. This implementation MUST
+      // be kept in sync with torch::stable::detail::from<std::optional<T>>
+      // function in torch/csrc/stable/stableivalue_conversions.h
+      if (ivalue.isNone()) {
+        return torch::stable::detail::from(std::nullopt);
+      }
+      StableIValue* sivp = new StableIValue(from_ivalue(inner_type, ivalue));
+      return torch::stable::detail::from(sivp);
+    }
+    default: {
+      TORCH_CHECK(
+          false,
+          "Not yet supported conversion from IValue to StableIValue for schema type: ",
+          type->str());
+    }
+  }
+}
+
+static c10::IValue to_ivalue(
+    const c10::TypePtr& type,
+    const StableIValue stable_ivalue) {
+  switch (type->kind()) {
+    case c10::TypeKind::TensorType: {
+      auto ret_raiiath = torch::aot_inductor::RAIIAtenTensorHandle(
+          torch::stable::detail::to<AtenTensorHandle>(stable_ivalue));
+      return (c10::IValue(*torch::aot_inductor::tensor_handle_to_tensor_pointer(
+          ret_raiiath.get())));
+    }
+    case c10::TypeKind::IntType: {
+      return c10::IValue(torch::stable::detail::to<int64_t>(stable_ivalue));
+    }
+    case c10::TypeKind::FloatType: {
+      return c10::IValue(torch::stable::detail::to<double>(stable_ivalue));
+    }
+    case c10::TypeKind::BoolType: {
+      return c10::IValue(torch::stable::detail::to<bool>(stable_ivalue));
+    }
+    case c10::TypeKind::ScalarTypeType: {
+      return c10::IValue(
+          torch::stable::detail::to<c10::ScalarType>(stable_ivalue));
+    }
+    case c10::TypeKind::DeviceObjType: {
+      return c10::IValue(torch::stable::detail::to<c10::Device>(stable_ivalue));
+    }
+    case c10::TypeKind::LayoutType: {
+      return c10::IValue(torch::stable::detail::to<c10::Layout>(stable_ivalue));
+    }
+    case c10::TypeKind::MemoryFormatType: {
+      return c10::IValue(
+          torch::stable::detail::to<c10::MemoryFormat>(stable_ivalue));
+    }
+    case c10::TypeKind::OptionalType: {
+      auto inner_type = type->castRaw<at::OptionalType>()->getElementType();
+
+      // ideally, if we had the C++ type corresponding to inner_type, which we
+      // will denote as inner_type::t (does not actually exist), we would be
+      // able to follow the patterned semantic of every other case here in one
+      // line:
+      //
+      // return
+      // c10::IValue(torch::stable::detail::to<std::optional<inner_type::t>>(stable_ivalue));
+      //
+      // BUT we do NOT have that type inner_type::t readily available, so we
+      // will manually unwrap and recursively call. This implementation MUST
+      // be kept in sync with the torch::stable::detail::to<T> function in
+      // torch/csrc/stable/stableivalue_conversions.h
+      if (stable_ivalue == torch::stable::detail::from(std::nullopt)) {
+        return c10::IValue();
+      }
+      auto sivp = torch::stable::detail::to<StableIValue*>(stable_ivalue);
+      auto ival = to_ivalue(inner_type, *sivp);
+      delete sivp;
+      return ival;
+    }
+    default: {
+      TORCH_CHECK(
+          false,
+          "Not yet supported conversion from StableIValue to IValue for schema type: ",
+          type->str());
+    }
+  }
+}
+
+class StableIValueBoxedKernel : public c10::OperatorKernel {
+ public:
+  StableIValueBoxedKernel(void (*fn)(StableIValue*, uint64_t, uint64_t))
+      : fn_(fn) {}
+
+  void operator()(
+      const c10::OperatorHandle& op,
+      c10::DispatchKeySet keyset,
+      torch::jit::Stack* stack) {
+    const auto& schema = op.schema();
+    const auto num_returns = schema.returns().size();
+    const auto num_arguments = schema.arguments().size();
+
+    auto ministack =
+        std::make_unique<StableIValue[]>(std::max(num_arguments, num_returns));
+
+    for (const auto idx : c10::irange(num_arguments)) {
+      const auto ministack_idx = num_arguments - idx - 1;
+      const c10::TypePtr& arg_type = schema.arguments()[ministack_idx].type();
+      ministack[ministack_idx] = from_ivalue(arg_type, torch::jit::pop(stack));
+    }
+
+    // boxed function is going to take a stack of StableIValues, cast them to
+    // our schema values, and run the function and modify the StableIValue stack
+    fn_(ministack.get(), num_arguments, num_returns);
+
+    // read the output from the end of the stack and wrap that back into
+    // IValue from StableIValue
+    for (size_t idx = 0; idx < num_returns; idx++) {
+      const c10::TypePtr& ret_type = schema.returns()[idx].type();
+      torch::jit::push(stack, to_ivalue(ret_type, ministack[idx]));
+    }
+  }
+
+ private:
+  void (*fn_)(StableIValue*, uint64_t, uint64_t);
+};
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl(
+    TorchLibraryHandle self,
+    const char* name,
+    void (*fn)(StableIValue*, uint64_t, uint64_t)) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    reinterpret_cast<torch::Library*>(self)->impl(
+        name,
+        torch::CppFunction::makeFromBoxedFunctor(
+            std::make_unique<StableIValueBoxedKernel>(fn)));
+  });
+}
+
+AOTITorchError aoti_torch_call_dispatcher(
+    const char* opName,
+    const char* overloadName,
+    StableIValue* stack) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    const auto op =
+        c10::Dispatcher::singleton().findSchemaOrThrow(opName, overloadName);
+    const auto& schema = op.schema();
+    const auto num_returns = schema.returns().size();
+    const auto num_arguments = schema.arguments().size();
+
+    torch::jit::Stack ivalue_stack;
+    // we will only need max(num_args, num_returns)
+    ivalue_stack.reserve(std::max(num_arguments, num_returns));
+
+    // convert StableIValue stack to c10::IValue stack
+    for (const auto idx : c10::irange(num_arguments)) {
+      auto stable_ivalue = stack[idx];
+      auto arg_type = schema.arguments()[idx].type();
+      torch::jit::push(ivalue_stack, to_ivalue(arg_type, stable_ivalue));
+    }
+
+    op.callBoxed(ivalue_stack);
+
+    // there should then be num_returns IValues on the stack, which
+    // we will convert to StableIValue and repopulate user input stack
+    for (const auto idx : c10::irange(num_returns)) {
+      const auto stack_idx = num_returns - idx - 1;
+      const c10::TypePtr& ret_type = schema.returns()[idx].type();
+      stack[stack_idx] = from_ivalue(ret_type, torch::jit::pop(ivalue_stack));
+    }
+  });
+}
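One detail worth tracing in `StableIValueBoxedKernel::operator()` above: the dispatcher stack is popped from the top, so the first loop fills `ministack` back-to-front to recover schema order before handing it to the user kernel. For a schema such as `op(Tensor a, int b) -> Tensor` (a hypothetical example), the flow is, in effect:

  // incoming torch::jit::Stack (bottom -> top): [a, b]
  // loop 1: pop b into ministack[1], then pop a into ministack[0] (schema order)
  // fn_ overwrites ministack[0] with the packed return value
  // loop 2: push to_ivalue(ret_type, ministack[0]), leaving [result] on the stack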