mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Part 1 of plan in https://docs.google.com/document/d/1MaX51H5aEQE5XnOlnZIpf9oCYwzGrTWkgBACxNzsmWE/edit?usp=sharing - Upgrade `aoti_torch_call_dispatcher` to v2 with an `extension_build_version` - Allow registration of StableIValue stack --> IValue stack adapters for schema changes #### Note: This PR does not include a linter that tells the user to add the upgrader if the schema changes, which is an important piece that will be added in a separate PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/163683 Approved by: https://github.com/janeyx99 ghstack dependencies: #164356, #166373
47 lines
1.5 KiB
C
47 lines
1.5 KiB
C
#ifndef STABLE_TORCH_SHIM
|
|
#define STABLE_TORCH_SHIM
|
|
|
|
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
|
|
|
#include <torch/csrc/stable/version.h>
|
|
|
|
// This header defines stable C API extensions for backward/forward
// compatibility when calling ATen operations through the dispatcher.
//
// It is kept separate from the main AOTI shim so that it can provide
// versioning capabilities for schema changes in native ATen functions.
|
|
|
|
#ifdef __cplusplus
|
|
extern "C" {
|
|
#endif
|
|
|
|
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
|
// 64-bit slot type used for the `stack` argument of the dispatcher entry
// points declared below. Declared with `typedef` rather than a C++
// alias-declaration so that this header also parses when __cplusplus is not
// defined (the surrounding extern "C" guards imply C consumers). Redeclaring
// the name with the same underlying type is permitted in both C11 and C++.
typedef uint64_t StableIValue;
|
|
|
|
// Has the same semantic as aoti_torch_call_dispatcher, but takes an
|
|
// additional argument for the extension build version. This is
|
|
// needed for backward compatibility when calling native functions via
|
|
// the dispatcher. The caller should pass in the libtorch version the
|
|
// extension is building with (NOT target version).
|
|
AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher(
|
|
const char* opName,
|
|
const char* overloadName,
|
|
StableIValue* stack,
|
|
uint64_t extension_build_version);
|
|
|
|
// Version-aware variant of aoti_torch_library_impl that takes an
|
|
// extension_build_version parameter for backward compatibility
|
|
AOTI_TORCH_EXPORT AOTITorchError torch_library_impl(
|
|
TorchLibraryHandle self,
|
|
const char* name,
|
|
void (*fn)(StableIValue*, uint64_t, uint64_t),
|
|
uint64_t extension_build_version);
|
|
|
|
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
|
|
|
|
#ifdef __cplusplus
|
|
} // extern "C"
|
|
#endif
|
|
|
|
#endif // STABLE_TORCH_SHIM
|