[precompile] Serialization for GlobalStateGuard (#150636)
Summary: To preserve global state guards we need to make the C++ type serializable. JSON is used because it is easy to implement and the global state does not carry much data.

Test Plan: test_dynamo -k test_global_state_guard_serialization

Differential Revision: D72410611

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150636
Approved by: https://github.com/williamwen42
parent b6929aef08
commit 24aadb40fb
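For context, here is a minimal usage sketch of the Python-facing API this change exposes (GlobalStateGuard.dump and GlobalStateGuard.load, alongside the existing check). It is based only on the bindings and the test in the diff below and is an illustration, not the PR's own example:

import json
import torch

GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard

guard = GlobalStateGuard()            # captures the current global state
payload = guard.dump()                # JSON string, e.g. {"grad_mode": true, ...}
state = json.loads(payload)
assert "autocast_state" in state      # same sanity check the new test performs

restored = GlobalStateGuard()
restored.load(payload)                # overwrite the guard with the serialized state
assert restored.check()               # still matches the live global state

state["grad_mode"] = not state["grad_mode"]
restored.load(json.dumps(state))      # load a tampered copy of the state
assert not restored.check()           # guard now reports a mismatch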
@@ -477,7 +477,7 @@ class C10_API TypeMeta final {
   /**
    * convert TypeMeta handles to ScalarType enum values
    */
-  inline ScalarType toScalarType() {
+  inline ScalarType toScalarType() const {
     if (C10_LIKELY(isScalarType())) {
       return static_cast<ScalarType>(index_);
     }
@@ -11,6 +11,7 @@ import functools
 import gc
 import importlib
 import itertools
+import json
 import logging
 import math
 import operator
@@ -3185,6 +3186,58 @@ utils_device.CURRENT_DEVICE == None""".split(
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 4)
 
+    def test_global_state_guard_serialization(self):
+        GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
+        guards = GlobalStateGuard()
+        serialized_guards = guards.dump()
+        json_guards = json.loads(serialized_guards)
+
+        samples = []
+        # Test on non autocast state and autocast cache states.
+        self.assertIn("autocast_state", json_guards)
+        for key, value in json_guards.items():
+            if type(value) == int:
+                variant = value + 1
+            elif type(value) == bool:
+                variant = not value
+            elif isinstance(value, dict) and key == "autocast_state":
+                variant = value.copy()
+                variant["cached_enabled"] = not variant["cached_enabled"]
+                continue
+            else:
+                self.fail(f"Unknown global state type {key}: {value}")
+            new_dict = json_guards.copy()
+            new_dict[key] = variant
+            samples.append(new_dict)
+
+        for sample in samples:
+            guards.load(json.dumps(sample))
+            self.assertFalse(guards.check())
+
+        guards.load(json.dumps(json_guards))
+        self.assertTrue(guards.check())
+
+        # Test on autocast states.
+        def _test_autocast(dtype):
+            with torch.autocast("cpu", dtype):
+                guards = GlobalStateGuard()
+                serialized_guards = guards.dump()
+                json_guards = json.loads(serialized_guards)
+
+                for i, enabled in enumerate(json_guards["autocast_state"]["enabled"]):
+                    if enabled:
+                        self.assertEqual(
+                            type(json_guards["autocast_state"]["dtype"][i]), int
+                        )
+                        json_guards["autocast_state"]["dtype"][i] += 1
+                        guards.load(json.dumps(json_guards))
+                        self.assertFalse(guards.check())
+
+        _test_autocast(torch.float16)
+        _test_autocast(torch.float32)
+        _test_autocast(torch.float64)
+        _test_autocast(torch.bfloat16)
+
     def test_type_copy(self):
         def fn(seq):
             a, b = seq
@@ -20,6 +20,8 @@
 
 #include <torch/csrc/dynamo/debug_macros.h>
 
+#include <nlohmann/json.hpp>
+
 #ifdef USE_CUDA
 #include <ATen/cuda/EmptyTensor.h>
 #endif
@@ -552,6 +554,20 @@ struct AutocastState {
     }
     return true;
   }
+
+  template <typename T>
+  friend void to_json(T& json_j, const AutocastState& json_t) {
+    json_j["enabled"] = json_t.enabled;
+    json_j["dtype"] = json_t.dtype;
+    json_j["cached_enabled"] = json_t.cache_enabled;
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, AutocastState& json_t) {
+    json_t.enabled = json_j.at("enabled");
+    json_t.dtype = json_j.at("dtype");
+    json_t.cache_enabled = json_j.at("cached_enabled");
+  }
 };
 
 // TODO (janimesh) - Remove the PyObject_HEAD part when C++ guard manager is
@@ -623,6 +639,40 @@ struct GlobalStateGuard {
     return os.str();
   }
 
+  template <typename T>
+  friend void to_json(T& json_j, const GlobalStateGuard& json_t) {
+    json_j["grad_mode"] = json_t._grad_mode;
+    json_j["autocast_state"] = json_t._autocast_state;
+    json_j["torch_function"] = json_t._torch_function;
+    json_j["torch_function_all_disabled"] = json_t._torch_function_all_disabled;
+    json_j["deterministic_algorithms"] = json_t._deterministic_algorithms;
+    json_j["deterministic_algorithms_warn_only"] =
+        json_t._deterministic_algorithms_warn_only;
+    json_j["allow_tf32"] = json_t._allow_tf32;
+    json_j["allow_fp16_reduce"] = json_t._allow_fp16_reduce;
+    json_j["allow_bf16_reduce"] = json_t._allow_bf16_reduce;
+    json_j["num_threads"] = json_t._num_threads;
+    json_j["default_dtype"] = json_t._default_dtype.toScalarType();
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, GlobalStateGuard& json_t) {
+    json_t._grad_mode = json_j.at("grad_mode");
+    json_t._autocast_state = json_j.at("autocast_state");
+    json_t._torch_function = json_j.at("torch_function");
+    json_t._torch_function_all_disabled =
+        json_j.at("torch_function_all_disabled");
+    json_t._deterministic_algorithms = json_j.at("deterministic_algorithms");
+    json_t._deterministic_algorithms_warn_only =
+        json_j.at("deterministic_algorithms_warn_only");
+    json_t._allow_tf32 = json_j.at("allow_tf32");
+    json_t._allow_fp16_reduce = json_j.at("allow_fp16_reduce");
+    json_t._allow_bf16_reduce = json_j.at("allow_bf16_reduce");
+    json_t._num_threads = json_j.at("num_threads");
+    json_t._default_dtype =
+        caffe2::TypeMeta::fromScalarType(json_j.at("default_dtype"));
+  }
+
   bool _grad_mode;
   AutocastState _autocast_state;
   bool _torch_function;
@@ -663,6 +713,25 @@ PyObject* GlobalStateGuard_reason(
   return PyUnicode_FromString(self->reason().c_str());
 }
 
+PyObject* GlobalStateGuard_dump(
+    GlobalStateGuard* self,
+    PyObject* args,
+    PyObject* kwargs) {
+  return PyUnicode_FromString(nlohmann::json(*self).dump().c_str());
+}
+
+PyObject* GlobalStateGuard_load(
+    GlobalStateGuard* self,
+    PyObject* args,
+    PyObject* kwargs) {
+  char* json;
+  if (!PyArg_ParseTuple(args, "s", &json)) {
+    throw std::runtime_error("Cannot parse as json string.");
+  }
+  nlohmann::json::parse(json).get_to(*self);
+  Py_RETURN_NONE;
+}
+
 // NOLINTNEXTLINE(*array*)
 static PyMethodDef GlobalStateGuard_methods[] = {
     {"check",
@@ -673,6 +742,14 @@ static PyMethodDef GlobalStateGuard_methods[] = {
      (PyCFunction)(void*)GlobalStateGuard_reason,
      METH_NOARGS,
      "Return string reason for guard check failing"},
+    {"dump",
+     (PyCFunction)(void*)GlobalStateGuard_dump,
+     METH_NOARGS,
+     "Return serialized json format"},
+    {"load",
+     (PyCFunction)(void*)GlobalStateGuard_load,
+     METH_VARARGS,
+     "Parse serialized json format"},
     {nullptr}};
 
 static PyTypeObject GlobalStateGuardType = { PyVarObject_HEAD_INIT(nullptr, 0)
 };
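As a closing illustration of the precompile motivation stated in the summary, here is a hedged sketch of how the serialized guard could be persisted and re-checked in a later run. Only GlobalStateGuard, dump, load, and check come from this PR; the file path and the save/verify split are assumptions made for illustration.

import torch

GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
GUARD_FILE = "global_state_guard.json"  # hypothetical artifact path, not part of the PR

def save_guard():
    # First run: capture the current global state and persist it as JSON.
    with open(GUARD_FILE, "w") as f:
        f.write(GlobalStateGuard().dump())

def verify_guard() -> bool:
    # Later run: reload the captured state and compare it to the live global state.
    guard = GlobalStateGuard()
    with open(GUARD_FILE) as f:
        guard.load(f.read())
    return guard.check()

save_guard()
assert verify_guard()              # same process, same state: guard holds
torch.set_grad_enabled(False)      # perturb one piece of global state
assert not verify_guard()          # guard detects the mismatch
torch.set_grad_enabled(True)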