[precompile] Serialization for GlobalStateGuard (#150636)

Summary: To preserve global state guards across processes, the C++ GlobalStateGuard type needs to be serializable. JSON is used because it is simple to implement and the global state holds only a small amount of data.
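
As a quick orientation, a minimal sketch of the Python surface this adds (assuming a build that includes this change; the names below come from the bindings in the diff):

    import json
    import torch

    GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard

    guard = GlobalStateGuard()            # snapshot the current global state (existing behavior)
    blob = guard.dump()                   # new: serialize the snapshot to a JSON string
    print(json.loads(blob)["grad_mode"])  # plain JSON, so it can be stored with a compiled artifact

    guard.load(blob)                      # new: overwrite the snapshot from a JSON string
    assert guard.check()                  # still matches the live global state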

Test Plan: test_dynamo -k test_global_state_guard_serialization

Differential Revision: D72410611

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150636
Approved by: https://github.com/williamwen42
Author: Zhengxu Chen (committed by PyTorch MergeBot)
Date: 2025-04-07 03:10:03 +00:00
Commit: 24aadb40fb (parent: b6929aef08)
3 changed files with 131 additions and 1 deletion


@@ -477,7 +477,7 @@ class C10_API TypeMeta final {
   /**
    * convert TypeMeta handles to ScalarType enum values
    */
-  inline ScalarType toScalarType() {
+  inline ScalarType toScalarType() const {
     if (C10_LIKELY(isScalarType())) {
       return static_cast<ScalarType>(index_);
     }


@@ -11,6 +11,7 @@ import functools
 import gc
 import importlib
 import itertools
+import json
 import logging
 import math
 import operator
@@ -3185,6 +3186,58 @@ utils_device.CURRENT_DEVICE == None""".split(
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 4)
 
+    def test_global_state_guard_serialization(self):
+        GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
+        guards = GlobalStateGuard()
+        serialized_guards = guards.dump()
+        json_guards = json.loads(serialized_guards)
+        samples = []
+        # Test on non autocast state and autocast cache states.
+        self.assertIn("autocast_state", json_guards)
+        for key, value in json_guards.items():
+            if type(value) == int:
+                variant = value + 1
+            elif type(value) == bool:
+                variant = not value
+            elif isinstance(value, dict) and key == "autocast_state":
+                variant = value.copy()
+                variant["cached_enabled"] = not variant["cached_enabled"]
+                continue
+            else:
+                self.fail(f"Unknown global state type {key}: {value}")
+            new_dict = json_guards.copy()
+            new_dict[key] = variant
+            samples.append(new_dict)
+        for sample in samples:
+            guards.load(json.dumps(sample))
+            self.assertFalse(guards.check())
+        guards.load(json.dumps(json_guards))
+        self.assertTrue(guards.check())
+
+        # Test on autocast states.
+        def _test_autocast(dtype):
+            with torch.autocast("cpu", dtype):
+                guards = GlobalStateGuard()
+                serialized_guards = guards.dump()
+                json_guards = json.loads(serialized_guards)
+                for i, enabled in enumerate(json_guards["autocast_state"]["enabled"]):
+                    if enabled:
+                        self.assertEqual(
+                            type(json_guards["autocast_state"]["dtype"][i]), int
+                        )
+                        json_guards["autocast_state"]["dtype"][i] += 1
+                        guards.load(json.dumps(json_guards))
+                        self.assertFalse(guards.check())
+
+        _test_autocast(torch.float16)
+        _test_autocast(torch.float32)
+        _test_autocast(torch.float64)
+        _test_autocast(torch.bfloat16)
+
     def test_type_copy(self):
         def fn(seq):
             a, b = seq
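
For reference, the object this test round-trips is the JSON produced by GlobalStateGuard.dump(). The sketch below is illustrative only: the keys mirror the to_json serializers added in the C++ diff further down, but the values (list lengths, dtype codes, thread count) are hypothetical and depend on the process's global state:

    # Illustrative shape of json.loads(guards.dump()); all values are hypothetical.
    example_payload = {
        "grad_mode": True,
        "autocast_state": {
            "enabled": [False, False, False, False],  # one flag per autocast-tracked device type
            "dtype": [5, 5, 15, 15],                  # ScalarType codes (hypothetical)
            "cached_enabled": True,
        },
        "torch_function": True,
        "torch_function_all_disabled": False,
        "deterministic_algorithms": False,
        "deterministic_algorithms_warn_only": False,
        "allow_tf32": False,
        "allow_fp16_reduce": True,
        "allow_bf16_reduce": True,
        "num_threads": 8,
        "default_dtype": 6,  # ScalarType code, e.g. 6 == torch.float32
    }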


@@ -20,6 +20,8 @@
 #include <torch/csrc/dynamo/debug_macros.h>
 
+#include <nlohmann/json.hpp>
+
 #ifdef USE_CUDA
 #include <ATen/cuda/EmptyTensor.h>
 #endif
@@ -552,6 +554,20 @@ struct AutocastState {
     }
     return true;
   }
+
+  template <typename T>
+  friend void to_json(T& json_j, const AutocastState& json_t) {
+    json_j["enabled"] = json_t.enabled;
+    json_j["dtype"] = json_t.dtype;
+    json_j["cached_enabled"] = json_t.cache_enabled;
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, AutocastState& json_t) {
+    json_t.enabled = json_j.at("enabled");
+    json_t.dtype = json_j.at("dtype");
+    json_t.cache_enabled = json_j.at("cached_enabled");
+  }
 };
 
 // TODO (janimesh) - Remove the PyObject_HEAD part when C++ guard manager is
@@ -623,6 +639,40 @@ struct GlobalStateGuard {
     return os.str();
   }
 
+  template <typename T>
+  friend void to_json(T& json_j, const GlobalStateGuard& json_t) {
+    json_j["grad_mode"] = json_t._grad_mode;
+    json_j["autocast_state"] = json_t._autocast_state;
+    json_j["torch_function"] = json_t._torch_function;
+    json_j["torch_function_all_disabled"] = json_t._torch_function_all_disabled;
+    json_j["deterministic_algorithms"] = json_t._deterministic_algorithms;
+    json_j["deterministic_algorithms_warn_only"] =
+        json_t._deterministic_algorithms_warn_only;
+    json_j["allow_tf32"] = json_t._allow_tf32;
+    json_j["allow_fp16_reduce"] = json_t._allow_fp16_reduce;
+    json_j["allow_bf16_reduce"] = json_t._allow_bf16_reduce;
+    json_j["num_threads"] = json_t._num_threads;
+    json_j["default_dtype"] = json_t._default_dtype.toScalarType();
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, GlobalStateGuard& json_t) {
+    json_t._grad_mode = json_j.at("grad_mode");
+    json_t._autocast_state = json_j.at("autocast_state");
+    json_t._torch_function = json_j.at("torch_function");
+    json_t._torch_function_all_disabled =
+        json_j.at("torch_function_all_disabled");
+    json_t._deterministic_algorithms = json_j.at("deterministic_algorithms");
+    json_t._deterministic_algorithms_warn_only =
+        json_j.at("deterministic_algorithms_warn_only");
+    json_t._allow_tf32 = json_j.at("allow_tf32");
+    json_t._allow_fp16_reduce = json_j.at("allow_fp16_reduce");
+    json_t._allow_bf16_reduce = json_j.at("allow_bf16_reduce");
+    json_t._num_threads = json_j.at("num_threads");
+    json_t._default_dtype =
+        caffe2::TypeMeta::fromScalarType(json_j.at("default_dtype"));
+  }
+
   bool _grad_mode;
   AutocastState _autocast_state;
   bool _torch_function;
@@ -663,6 +713,25 @@ PyObject* GlobalStateGuard_reason(
   return PyUnicode_FromString(self->reason().c_str());
 }
 
+PyObject* GlobalStateGuard_dump(
+    GlobalStateGuard* self,
+    PyObject* args,
+    PyObject* kwargs) {
+  return PyUnicode_FromString(nlohmann::json(*self).dump().c_str());
+}
+
+PyObject* GlobalStateGuard_load(
+    GlobalStateGuard* self,
+    PyObject* args,
+    PyObject* kwargs) {
+  char* json;
+  if (!PyArg_ParseTuple(args, "s", &json)) {
+    throw std::runtime_error("Cannot parse as json string.");
+  }
+  nlohmann::json::parse(json).get_to(*self);
+  Py_RETURN_NONE;
+}
+
 // NOLINTNEXTLINE(*array*)
 static PyMethodDef GlobalStateGuard_methods[] = {
     {"check",
@@ -673,6 +742,14 @@ static PyMethodDef GlobalStateGuard_methods[] = {
      (PyCFunction)(void*)GlobalStateGuard_reason,
      METH_NOARGS,
      "Return string reason for guard check failing"},
+    {"dump",
+     (PyCFunction)(void*)GlobalStateGuard_dump,
+     METH_NOARGS,
+     "Return serialized json format"},
+    {"load",
+     (PyCFunction)(void*)GlobalStateGuard_load,
+     METH_VARARGS,
+     "Parse serialized json format"},
     {nullptr}};
 static PyTypeObject GlobalStateGuardType = { PyVarObject_HEAD_INIT(nullptr, 0)
 };
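
Taken together, a guard restored from its JSON form flags any divergence between the serialized snapshot and the live global state. A hedged end-to-end sketch (assuming a build that includes this change; check() and reason() predate this PR):

    import torch

    GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard

    guard = GlobalStateGuard()
    blob = guard.dump()

    # ... later, possibly in another process ...
    restored = GlobalStateGuard()
    restored.load(blob)
    assert restored.check()            # live global state still matches the snapshot

    with torch.no_grad():              # grad_mode now differs from the snapshot
        assert not restored.check()    # (assuming grad was enabled when dumped)
        print(restored.reason())       # names the mismatching field; exact text not guaranteed

    assert restored.check()            # matches again once the original state is restored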