// this file can only have stable stuff! Akin to shim.h // but unlike shim.h, this file can contain header-only C++ // code for better UX. #include // use anonymous namespace to avoid collisions between differing // versions of this file that may be included by different sources namespace { // helpers for converting between StableIValue and actual IValues template StableIValue from(T val) { static_assert( sizeof(T) <= sizeof(StableIValue), "StableLibrary stack does not support parameter types larger than 64 bits."); return *reinterpret_cast(&val); } template T to(StableIValue val) { return *reinterpret_cast(&val); } // end to helpers for converting between StableIValue and actual IValues class StableLibrary final { private: TorchLibraryHandle lib_; public: enum class Kind { DEF, IMPL, FRAGMENT, }; // constructor /// \private /// /// Use STABLE_TORCH_LIBRARY or STABLE_TORCH_LIBRARY_IMPL() instead of using /// these constructors directly StableLibrary( Kind kind, const char* ns, const char* k, const char* file, uint32_t line) { if (kind == Kind::IMPL) { aoti_torch_library_init_impl(ns, k, file, line, &lib_); } else if (kind == Kind::DEF) { aoti_torch_library_init_def(ns, file, line, &lib_); } else { // kind == FRAGMENT aoti_torch_library_init_fragment(ns, file, line, &lib_); } } // do not permit copy StableLibrary(const StableLibrary&) = delete; StableLibrary& operator=(const StableLibrary&) = delete; // do not permit move StableLibrary(StableLibrary&& other) = delete; StableLibrary& operator=(StableLibrary&& other) = delete; ~StableLibrary() { aoti_torch_delete_library_object(lib_); } // corresponds to a limited, stable version of torch::library::impl() // Inputs: // name: the name of the function to implement // fn: a boxed function with schema // (StableIValue* stack, uint64_t num_inputs, uint64_t num_outputs) -> // void // fn should follow the calling convention of our boxed kernels that convert // to IValues. fn will be called with a StableIValue* array of length // max(num_inputs, num_outputs), where the first num_inputs entries are // populated with inputs. fn is responsible for stealing the memory of the // inputs, in effect "popping" them off the stack, and then populating the // stack with StableIValue outputs. Concretely, fn should: // 1. read StableIValue inputs from the given stack // 2. convert the inputs to the proper types // 3. call the function corresponding to name with the inputs // 4. convert the outputs to StableIValues // 5. populate the now empty stack with StableIValue outputs // If the operation corresponding to name takes in 4 inputs and returns 2 // outputs, fn should expect stack to contain 4 StableIValues: // [stable_arg1, stable_arg2, stable_arg3, stable_arg4] // to end, fn should fill the stack with 2 StableIValues representing outputs: // [stable_ret1, stable_ret2, -, -] StableLibrary& impl( const char* name, void (*fn)(StableIValue*, uint64_t, uint64_t)) { aoti_torch_library_impl(lib_, name, fn); return *this; } // corresponds to a limited, stable version of torch::library::def() StableLibrary& def(const char* schema) { aoti_torch_library_def(lib_, schema); return *this; } }; class StableTorchLibraryInit final { private: using InitFn = void(StableLibrary&); StableLibrary lib_; public: StableTorchLibraryInit( StableLibrary::Kind kind, InitFn* fn, const char* ns, const char* k, const char* file, uint32_t line) : lib_(kind, ns, k, file, line) { fn(lib_); } }; } // namespace // macros copied from c10/macros/Macros.h #ifdef __COUNTER__ #define STABLE_UID __COUNTER__ #else #define STABLE_UID __LINE__ #endif #define STABLE_CONCATENATE_IMPL(s1, s2) s1##s2 #define STABLE_CONCATENATE(s1, s2) STABLE_CONCATENATE_IMPL(s1, s2) // end of macros copied from c10/macros/Macros.h #define STABLE_TORCH_LIBRARY_IMPL(ns, k, m) \ _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, STABLE_UID) #define _STABLE_TORCH_LIBRARY_IMPL(ns, k, m, uid) \ static void STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary&); \ static const StableTorchLibraryInit STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \ StableLibrary::Kind::IMPL, \ &STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \ #ns, \ #k, \ __FILE__, \ __LINE__); \ void STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(StableLibrary & m) #define STABLE_TORCH_LIBRARY(ns, m) \ static void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary&); \ static const StableTorchLibraryInit STABLE_TORCH_LIBRARY_static_init_##ns( \ StableLibrary::Kind::DEF, \ &STABLE_TORCH_LIBRARY_init_##ns, \ #ns, \ nullptr, \ __FILE__, \ __LINE__); \ void STABLE_TORCH_LIBRARY_init_##ns(StableLibrary& m) #define STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) \ _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, STABLE_UID) #define _STABLE_TORCH_LIBRARY_FRAGMENT(ns, m, uid) \ static void STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary&); \ static const StableTorchLibraryInit STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \ StableLibrary::Kind::FRAGMENT, \ &STABLE_CONCATENATE(STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \ #ns, \ nullptr, \ __FILE__, \ __LINE__); \ void STABLE_CONCATENATE( \ STABLE_TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid)(StableLibrary & m)