pytorch/torch/csrc/stable/library.h
Jane Xu 971606befa Add a stable TORCH_LIBRARY to C shim (#148124)
This PR adds two main parts:
- shim.h stable C APIs into torch::Library APIs
- a higher level API in torch/csrc/stable/library.h that calls into this shim.h + otherwise is self contained

Goal: custom kernel writers should be able to call the apis in the directories above in order to register their library in a way that allows their custom extension to run with a different libtorch version than it was built with.

Subplots resolved:

- Do we want a whole separate StableLibrary, or do we want to freeze torch::Library and add `m.stable_impl(cstring, void (*fn)(void **, int64_t, int64_t))` to it?
    - Yes, we want a separate StableLibrary. We cannot freeze Library and it is NOT header only.
- Should I use uint64_t as the common denominator instead of void* to support 32-bit architectures better?
    - Yes, and done.
- Should I add a stable `def` and `fragment` when those can be done in python?
    - I think we do want these --- and now they're done
- Where should library_stable_impl.cpp live? -- no longer relevant
- I need some solid test cases to make sure everything's going ok. I've intentionally thrown in a bunch of random dtypes into the signature, but I still haven't tested returning multiple things, returning nothing, complex dtypes, etc.
    - Have since tested all the torch library endpoints. The others can be tested in a follow-up, to separate the components that need to be in shim.h from those that can be added later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148124
Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/atalman
2025-03-11 19:12:46 +00:00

180 lines
7.0 KiB
C++

// this file can only have stable stuff! Akin to shim.h
// but unlike shim.h, this file can contain header-only C++
// code for better UX.
#include <cstring>
#include <type_traits>

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
// use anonymous namespace to avoid collisions between differing
// versions of this file that may be included by different sources
namespace {
// helpers for converting between StableIValue and actual IValues
// Packs `val` into a StableIValue by copying its object representation.
//
// T must be no wider than StableIValue and trivially copyable; both are
// checked at compile time. Only sizeof(T) bytes are copied, into a
// value-initialized StableIValue, so when T is narrower than StableIValue the
// remaining bytes of the result are well defined instead of being whatever
// happened to sit next to `val` in memory.
template <typename T>
StableIValue from(T val) {
  static_assert(
      sizeof(T) <= sizeof(StableIValue),
      "StableLibrary stack does not support parameter types larger than 64 bits.");
  static_assert(
      std::is_trivially_copyable<T>::value,
      "StableLibrary stack only supports trivially copyable parameter types.");
  StableIValue result{};
  // memcpy instead of dereferencing a reinterpret_cast'ed pointer: the latter
  // reads sizeof(StableIValue) bytes from an object that may be smaller and
  // accesses it through an unrelated type.
  std::memcpy(&result, static_cast<const void*>(&val), sizeof(T));
  return result;
}
// Unpacks a value of type T from the leading sizeof(T) bytes of `val`; this
// is the inverse of from<T>().
//
// T must be no wider than StableIValue and trivially copyable; both are
// checked at compile time.
template <typename T>
T to(StableIValue val) {
  static_assert(
      sizeof(T) <= sizeof(StableIValue),
      "StableLibrary stack does not support parameter types larger than 64 bits.");
  static_assert(
      std::is_trivially_copyable<T>::value,
      "StableLibrary stack only supports trivially copyable parameter types.");
  using Value = typename std::remove_cv<T>::type;
  // T may lack a default constructor, so the destination lives in a union
  // whose own constructor leaves the member unconstructed; memcpy then
  // supplies the object representation.
  union Storage {
    Storage() {}
    Value value;
  };
  Storage storage;
  // memcpy instead of dereferencing a reinterpret_cast'ed pointer, which would
  // access `val` through an unrelated type.
  std::memcpy(static_cast<void*>(&storage.value), &val, sizeof(Value));
  return storage.value;
}
// end of helpers for converting between StableIValue and actual IValues
// Header-only wrapper that owns a TorchLibraryHandle obtained through the
// aoti_torch_library_init_* functions and releases it with
// aoti_torch_delete_library_object on destruction. Instances can be neither
// copied nor moved.
//
// NOTE(review): the aoti_torch_* calls below appear to return a status code
// (see shim.h) that is discarded here -- confirm whether failures should be
// surfaced to the caller.
class StableLibrary final {
 private:
  // Value-initialized so the destructor never passes an indeterminate handle
  // to aoti_torch_delete_library_object when the init call did not store one.
  // NOTE(review): presumably the delete function tolerates an empty handle --
  // confirm against shim.h.
  TorchLibraryHandle lib_{};

 public:
  // Selects which aoti_torch_library_init_* entry point the constructor uses.
  enum class Kind {
    DEF,
    IMPL,
    FRAGMENT,
  };

  // constructor
  /// \private
  ///
  /// Use STABLE_TORCH_LIBRARY or STABLE_TORCH_LIBRARY_IMPL() instead of using
  /// these constructors directly
  ///
  /// `k` is only forwarded when kind == Kind::IMPL; `file` and `line` are
  /// forwarded to the shim in every case.
  StableLibrary(
      Kind kind,
      const char* ns,
      const char* k,
      const char* file,
      uint32_t line) {
    if (kind == Kind::IMPL) {
      aoti_torch_library_init_impl(ns, k, file, line, &lib_);
    } else if (kind == Kind::DEF) {
      aoti_torch_library_init_def(ns, file, line, &lib_);
    } else { // kind == FRAGMENT
      aoti_torch_library_init_fragment(ns, file, line, &lib_);
    }
  }

  // do not permit copy
  StableLibrary(const StableLibrary&) = delete;
  StableLibrary& operator=(const StableLibrary&) = delete;

  // do not permit move
  StableLibrary(StableLibrary&& other) = delete;
  StableLibrary& operator=(StableLibrary&& other) = delete;

  ~StableLibrary() {
    aoti_torch_delete_library_object(lib_);
  }

  // corresponds to a limited, stable version of torch::library::impl()
  // Inputs:
  //   name: the name of the function to implement
  //   fn: a boxed function with schema
  //       (StableIValue* stack, uint64_t num_inputs, uint64_t num_outputs) ->
  //       void
  // fn should follow the calling convention of our boxed kernels that convert
  // to IValues. fn will be called with a StableIValue* array of length
  // max(num_inputs, num_outputs), where the first num_inputs entries are
  // populated with inputs. fn is responsible for stealing the memory of the
  // inputs, in effect "popping" them off the stack, and then populating the
  // stack with StableIValue outputs. Concretely, fn should:
  //    1. read StableIValue inputs from the given stack
  //    2. convert the inputs to the proper types
  //    3. call the function corresponding to name with the inputs
  //    4. convert the outputs to StableIValues
  //    5. populate the now empty stack with StableIValue outputs
  // If the operation corresponding to name takes in 4 inputs and returns 2
  // outputs, fn should expect stack to contain 4 StableIValues:
  //    [stable_arg1, stable_arg2, stable_arg3, stable_arg4]
  // to end, fn should fill the stack with 2 StableIValues representing outputs:
  //    [stable_ret1, stable_ret2, -, -]
  //
  // Returns *this so calls can be chained.
  StableLibrary& impl(
      const char* name,
      void (*fn)(StableIValue*, uint64_t, uint64_t)) {
    aoti_torch_library_impl(lib_, name, fn);
    return *this;
  }

  // corresponds to a limited, stable version of torch::library::def()
  // Forwards `schema` to aoti_torch_library_def for this library and returns
  // *this so calls can be chained.
  StableLibrary& def(const char* schema) {
    aoti_torch_library_def(lib_, schema);
    return *this;
  }
};
class StableTorchLibraryInit final {
private:
using InitFn = void(StableLibrary&);
StableLibrary lib_;
public:
StableTorchLibraryInit(
StableLibrary::Kind kind,
InitFn* fn,
const char* ns,
const char* k,
const char* file,
uint32_t line)
: lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
} // namespace
// macros copied from c10/macros/Macros.h
//
// STABLE_UID: a number used to make generated identifiers unique. It is
// __COUNTER__ when the preprocessor defines it, and the current line number
// (__LINE__) otherwise.
#if defined(__COUNTER__)
#define STABLE_UID __COUNTER__
#else
#define STABLE_UID __LINE__
#endif
// STABLE_CONCATENATE pastes two tokens together. It goes through
// STABLE_CONCATENATE_IMPL so that macro arguments (such as STABLE_UID) are
// expanded before the ## operator is applied.
#define STABLE_CONCATENATE_IMPL(lhs, rhs) lhs##rhs
#define STABLE_CONCATENATE(lhs, rhs) STABLE_CONCATENATE_IMPL(lhs, rhs)
// end of macros copied from c10/macros/Macros.h
// STABLE_TORCH_LIBRARY_IMPL(ns, k, m) { ...statements using m... }
//
// Entry point for the IMPL flavour of static registration. It only forwards
// to _STABLE_TORCH_LIBRARY_IMPL with STABLE_UID as an extra argument, so that
// one uid value is shared by every generated name of a single macro use.
#define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \
_STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID)
// Expands, in order, to:
//   1. a forward declaration of a static function
//      STABLE_TORCH_LIBRARY_IMPL_init_<ns>_<k>_<uid>(StableLibrary&);
//   2. a static const StableTorchLibraryInit object named
//      STABLE_TORCH_LIBRARY_IMPL_static_init_<ns>_<k>_<uid>, constructed with
//      StableLibrary::Kind::IMPL, a pointer to that function, the stringified
//      ns and k, and the __FILE__ / __LINE__ of the macro use;
//   3. the header of that function's definition, with its parameter named m,
//      so the brace block written after the macro becomes the function body.
// NOTE(review): identifiers that begin with an underscore followed by an
// uppercase letter are reserved in C++; the name is left unchanged here
// because renaming the macro would be an interface change.
#define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid) \
static void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary&); \
static const StableTorchLibraryInit STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \
StableLibrary::Kind::IMPL, \
&STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \
#ns, \
#k, \
__FILE__, \
__LINE__); \
void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary & m)
// STABLE_TORCH_LIBRARY(ns, m) { ...statements using m... }
//
// DEF flavour of static registration. Expands, in order, to:
//   1. a forward declaration of static function
//      STABLE_TORCH_LIBRARY_init_<ns>(StableLibrary&);
//   2. a static const StableTorchLibraryInit object named
//      STABLE_TORCH_LIBRARY_static_init_<ns>, constructed with
//      StableLibrary::Kind::DEF, a pointer to that function, the stringified
//      ns, nullptr for k, and the __FILE__ / __LINE__ of the macro use;
//   3. the header of that function's definition, with its parameter named m,
//      so the brace block written after the macro becomes the function body.
// Unlike the IMPL and FRAGMENT variants, no STABLE_UID is appended to the
// generated names, so using this macro twice with the same ns in one
// translation unit produces duplicate definitions.
#define STABLE_TORCH_LIBRARY(ns, m) \
static void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary&); \
static const StableTorchLibraryInit STABLE_TORCH_LIBRARY_static_init_##ns( \
StableLibrary::Kind::DEF, \
&STABLE_TORCH_LIBRARY_init_##ns, \
#ns, \
nullptr, \
__FILE__, \
__LINE__); \
void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary& m)
// STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) { ...statements using m... }
//
// Entry point for the FRAGMENT flavour of static registration. It only
// forwards to _STABLE_TORCH_LIBRARY_FRAGMENT with STABLE_UID as an extra
// argument, so that one uid value is shared by every generated name of a
// single macro use.
#define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \
_STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID)
// Expands, in order, to:
//   1. a forward declaration of a static function
//      STABLE_TORCH_LIBRARY_FRAGMENT_init_<ns>_<uid>(StableLibrary&);
//   2. a static const StableTorchLibraryInit object named
//      STABLE_TORCH_LIBRARY_FRAGMENT_static_init_<ns>_<uid>, constructed with
//      StableLibrary::Kind::FRAGMENT, a pointer to that function, the
//      stringified ns, nullptr for k, and the __FILE__ / __LINE__ of the
//      macro use;
//   3. the header of that function's definition, with its parameter named m,
//      so the brace block written after the macro becomes the function body.
// NOTE(review): identifiers that begin with an underscore followed by an
// uppercase letter are reserved in C++; the name is left unchanged here
// because renaming the macro would be an interface change.
#define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid) \
static void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary&); \
static const StableTorchLibraryInit STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \
StableLibrary::Kind::FRAGMENT, \
&STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \
#ns, \
nullptr, \
__FILE__, \
__LINE__); \
void STABLE_CONCATENATE( \
STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary & m)