This is a proof of concept of how we could serialize a guard and deserialize it back from bytes. The main behavioral change introduced in this diff is in CheckFunctionManager:

```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")
guards_state: bytes = check_fn_manager.guards_state
```

Once `guards_serialization_mode` is set to `save`, CheckFunctionManager will return an additional `bytes` object called `guards_state`, which should contain all the information needed to deserialize the guards later.

When we load the guards state back, we set `guards_serialization_mode` to `load`:

```
output_graph_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```

# TENSOR_MATCH

Since we have many types of guards to support, we will break the work into small diffs instead of supporting every guard type in a single diff. This diff kicks off the work with TENSOR_MATCH.

# Testing

For each type of guard we test it as follows (a sketch of the procedure appears below):

1. Use guard_filter_fn to select one type of guard at a time.
2. Call InstructionTranslator directly on an example function to get an OutputGraph and a CheckFunctionManager (the reference guard manager).
3. Serialize and deserialize the output graph state, then rebuild the guards with a new CheckFunctionManager (the loaded guard manager).
4. Feed a set of example inputs to both the reference and the loaded guard manager and check that their behavior matches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305
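A minimal sketch of that round-trip test, using the `CheckFunctionManager` calls from the snippets above. The `run_to_get_output_graph` helper (step 2) and the `guard_manager.check(f_locals)` accessor (step 4) are illustrative assumptions, not confirmed APIs:

```
import pickle

import torch
from torch._dynamo.guards import CheckFunctionManager

def f(x):
    return x + 1

# Step 2 (hypothetical helper): run InstructionTranslator on f's code object
# and return the code object plus the resulting OutputGraph.
code, output_graph = run_to_get_output_graph(f)

# Step 3a: build the reference guard manager and capture its serialized state.
ref = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")
guards_state: bytes = ref.guards_state

# Step 3b: round-trip the state through pickle and rebuild the guards.
output_graph_state = pickle.loads(guards_state)
loaded = CheckFunctionManager(
    code, output_graph_state, guards_serialization_mode="load"
)

# Step 4: both managers should accept/reject the same example inputs.
# (guard_manager.check is an assumed accessor for running the root guards.)
for example in [torch.randn(3), torch.randn(4, 5), torch.ones(3, dtype=torch.float64)]:
    f_locals = {"x": example}
    assert ref.guard_manager.check(f_locals) == loaded.guard_manager.check(f_locals)
```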
#pragma once
#include <c10/core/GradMode.h>
#include <torch/csrc/dynamo/framelocals_mapping.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::dynamo {

PyObject* torch_c_dynamo_guards_init();

// Interfaces for extra_state and eval_frame.c, because the RootGuardManager
// class is not visible there.
void* convert_to_root_guard_manager(py::object root);
bool run_root_guard_manager(void* root, FrameLocalsMapping* f_locals);
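
// Snapshot of the thread-local dispatch state (TLS include/exclude key sets
// and grad mode) taken when a guard is built, so apply() can normalize a
// tensor's dispatch key set the same way at build time and at check time.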
struct LocalState {
  // TLS state that changes operators
  c10::impl::LocalDispatchKeySet dispatch_modifier;
  c10::DispatchKeySet override_dispatch_key_set;
  bool grad_mode_enabled;

  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
    if (override_dispatch_key_set.empty()) {
      return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
    } else {
      return override_dispatch_key_set;
    }
  }

  LocalState()
      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
        override_dispatch_key_set(c10::BackendComponent::InvalidBit),
        grad_mode_enabled(at::GradMode::is_enabled()) {}

  void overrideDispatchKeySet(c10::DispatchKeySet ks) {
    override_dispatch_key_set = ks;
  }
};
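
// TensorCheck records the guarded properties of a single tensor (Python type,
// dispatch keys, dtype, device index, requires_grad, sizes/strides/dim); it
// backs the TENSOR_MATCH guard described in the commit message above.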
class TensorCheck {
 public:
  TensorCheck(
      const LocalState& state,
      PyTypeObject* pt,
      const at::Tensor& v,
      c10::DispatchKeySet dispatch_key_set,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);

  TensorCheck(
      const LocalState& state,
      PyTypeObject* pt,
      c10::DispatchKeySet dispatch_key_set,
      at::ScalarType dtype,
      at::DeviceIndex device_index,
      bool requires_grad,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);
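
  // check() returns true if the tensor (first overload) or its already
  // unpacked properties (second overload) still match the recorded state;
  // check_verbose() reports a mismatch as a human-readable string that
  // references tensor_name.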
  bool check(const LocalState& state, const at::Tensor& v);
  bool check(
      const LocalState& state,
      const c10::DispatchKeySet& dispatch_key_set,
      const at::ScalarType& dtype,
      const c10::Device& device,
      const c10::SymIntArrayRef& dynamic_dims_sizes,
      const c10::SymIntArrayRef& dynamic_dims_strides,
      const bool& requires_grad);
  std::string check_verbose(
      const LocalState& state,
      const at::Tensor& v,
      const std::string& tensor_name);

  PyTypeObject* pytype;

 private:
  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
  at::ScalarType dtype_;
  // Note(voz): While dispatch_key_ is sufficiently representative of a device
  // (keys are more granular AND device-specific), it does not necessarily
  // capture device indices correctly, so the index is tracked separately.
  at::DeviceIndex device_index_;
  bool requires_grad_;
  // NB: These are unset if dynamic shapes are enabled.
  std::vector<std::optional<c10::SymInt>> sizes_;
  std::vector<std::optional<c10::SymInt>> strides_;
  // Not strictly required for dense tensors, but nested tensors need it.
  int64_t dim_;
};

} // namespace torch::dynamo