pytorch/torch/csrc/dynamo/guards.h
zhxchen17 a34c28e0d2 [dynamo] Add guard serialization for tensor matches. (#151318)
This is a proof of concept showing how we could serialize a guard and deserialize it back from bytes.

The main behavioral change introduced in this diff is to CheckFunctionManager:

```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")

guards_state: bytes = check_fn_manager.guards_state
```

Once `guards_serialization_mode` is set to `save`, CheckFunctionManager will expose an additional `bytes` object called `guards_state`, which contains all the information needed to deserialize the guards later.

When we load the guards state back, we set `guards_serialization_mode` to `load`:

```
output_graph_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```

# TENSOR_MATCH

Since there are many guard types to support, we will break the work into small diffs rather than supporting every guard in a single diff.

We kick off the work with TENSOR_MATCH in this diff.

# Testing

For each type of guard, we test it as follows (see the sketch after this list):
1. Use `guard_filter_fn` to select one type of guard at a time.
2. Call InstructionTranslator directly on an example function to get an OutputGraph and a CheckFunctionManager (the reference guard manager).
3. Serialize then deserialize the output graph state and rebuild the guards with a new CheckFunctionManager (the loaded guard manager).
4. Feed a set of example inputs to both the reference and the loaded guard manager and check that their behaviors match.
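
In rough pseudocode, one round-trip check might look like the sketch below. It reuses the save/load API shown above, but the `guard_manager.check(...)` entry point and the helper's shape are illustrative, not the exact test harness:

```
import pickle

from torch._dynamo.guards import CheckFunctionManager

def check_roundtrip(code, output_graph, example_inputs):
    # Reference: build guards directly and capture their serialized state.
    ref = CheckFunctionManager(
        code, output_graph, guards_serialization_mode="save")
    guards_state: bytes = ref.guards_state

    # Loaded: rebuild the guards from the serialized output graph state.
    output_graph_state = pickle.loads(guards_state)
    loaded = CheckFunctionManager(
        code, output_graph_state, guards_serialization_mode="load")

    # Both guard managers must accept and reject exactly the same inputs.
    for f_locals in example_inputs:
        assert (ref.guard_manager.check(f_locals) ==
                loaded.guard_manager.check(f_locals))
```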

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305
2025-04-25 14:16:23 +00:00

#pragma once

#include <c10/core/GradMode.h>
#include <torch/csrc/dynamo/framelocals_mapping.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::dynamo {

// Initializes the C++ guards extension and returns its Python module object.
PyObject* torch_c_dynamo_guards_init();

// Interfaces for extra_state and eval_frame.c, because the RootGuardManager
// class is not visible there.
void* convert_to_root_guard_manager(py::object root);
bool run_root_guard_manager(void* root, FrameLocalsMapping* f_locals);
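
// Captures the thread-local dispatch-key modifiers and grad mode in effect,
// so tensor guards can compute a tensor's effective dispatch key set the same
// way at guard-build time and at guard-check time.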
struct LocalState {
  // TLS state that changes operators
  c10::impl::LocalDispatchKeySet dispatch_modifier;
  c10::DispatchKeySet override_dispatch_key_set;
  bool grad_mode_enabled;

  // Folds the TLS include/exclude modifiers into `ks`, unless an explicit
  // override has been installed via overrideDispatchKeySet().
  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
    if (override_dispatch_key_set.empty()) {
      return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
    } else {
      return override_dispatch_key_set;
    }
  }

  LocalState()
      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
        override_dispatch_key_set(c10::BackendComponent::InvalidBit),
        grad_mode_enabled(at::GradMode::is_enabled()) {}

  void overrideDispatchKeySet(c10::DispatchKeySet ks) {
    override_dispatch_key_set = ks;
  }
};
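
// A per-tensor guard: records a tensor's Python type, dispatch keys, dtype,
// device index, requires_grad, and (for static dims) sizes/strides at compile
// time, then re-checks a candidate tensor against them at guard-eval time.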
class TensorCheck {
 public:
  // Construct from an example tensor.
  TensorCheck(
      const LocalState& state,
      PyTypeObject* pt,
      const at::Tensor& v,
      c10::DispatchKeySet dispatch_key_set,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);

  // Construct from explicitly supplied tensor properties.
  TensorCheck(
      const LocalState& state,
      PyTypeObject* pt,
      c10::DispatchKeySet dispatch_key_set,
      at::ScalarType dtype,
      at::DeviceIndex device_index,
      bool requires_grad,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);

  bool check(const LocalState& state, const at::Tensor& v);
  bool check(
      const LocalState& state,
      const c10::DispatchKeySet& dispatch_key_set,
      const at::ScalarType& dtype,
      const c10::Device& device,
      const c10::SymIntArrayRef& dynamic_dims_sizes,
      const c10::SymIntArrayRef& dynamic_dims_strides,
      const bool& requires_grad);
  // Like check(), but returns a human-readable description of the first
  // mismatching property, using tensor_name in the message.
  std::string check_verbose(
      const LocalState& state,
      const at::Tensor& v,
      const std::string& tensor_name);

  PyTypeObject* pytype;

 private:
  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
  at::ScalarType dtype_;
  // Note(voz): While dispatch_key_ is sufficiently representative of a device,
  // in that keys are more granular AND device specific, they do not
  // necessarily capture device indices correctly.
  at::DeviceIndex device_index_;
  bool requires_grad_;
  // NB: These are unset if dynamic shapes is enabled.
  std::vector<std::optional<c10::SymInt>> sizes_;
  std::vector<std::optional<c10::SymInt>> strides_;
  // Not strictly required for dense tensors, but nested tensors need it.
  int64_t dim_;
};
} // namespace torch::dynamo