[dynamo] Guard serialization for FUNCTORCH_STACK_MATCH (#152616)
Make Functorch interpreters serializable most of the time, so that we can save the guards on functorch states.

## Test Cases:
0. torch.compile() without functorch layers present. The guard should fail once any layer is pushed.
1. torch.compile() nested in vmap.
2. torch.compile() nested in grad.
3. torch.compile() nested in jvp + vmap.
4. torch.compile() nested in functionalize.
5. torch.compile() nested in vmap + grad.

Differential Revision: [D74008787](https://our.internmc.facebook.com/intern/diff/D74008787/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152616
Approved by: https://github.com/zou3519
ghstack dependencies: #152615
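A hedged sketch (not code from this PR) of the behavior the FUNCTORCH_STACK_MATCH guard protects, using the helpers that the guard expression below references and assuming an initially empty functorch stack:

```python
import torch
from torch._functorch import pyfunctorch
from torch._functorch.vmap import vmap_increment_nesting

# States captured when the FX graph was traced (empty stack in this sketch).
states = [ci.get_state() for ci in pyfunctorch.retrieve_all_functorch_interpreters()]

# This is the kind of expression the FUNCTORCH_STACK_MATCH guard evaluates on each call.
print(pyfunctorch.compare_functorch_state(states))  # expected True: stack unchanged

with vmap_increment_nesting(2, "error"):
    # A vmap layer is now on the stack, so the recorded states should no longer match.
    print(pyfunctorch.compare_functorch_state(states))  # expected False
```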
This commit is contained in:
parent 1d1cbcd8a3
commit ffd58293f7
@@ -376,6 +376,7 @@ cc_library(
         ":torch_headers",
         "@fbgemm",
         "@ideep",
+        "@nlohmann",
     ],
     alwayslink = True,
 )

@@ -8,6 +8,8 @@
 #include <utility>
 #include <variant>
 
+#include <nlohmann/json.hpp>
+
 namespace at::functorch {
 
 // NOTE: [functorch interpreter stack]

@@ -91,24 +93,95 @@ std::ostream& operator<<(std::ostream& os, const TransformType& t);
 struct VmapInterpreterMeta {
   explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
     batchSize_(std::move(batchSize)), randomness_(randomness) {}
 
   c10::SymInt batchSize_;
   RandomnessType randomness_;
 
+  VmapInterpreterMeta() = default;
+  VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
+  VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
+  VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
+  VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
+  ~VmapInterpreterMeta() = default;
+
+  template <typename T>
+  friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
+    if (json_t.batchSize_.is_heap_allocated()) {
+      throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet");
+    }
+    json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
+    json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
+    json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
+    json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
+  }
 };
 
 struct GradInterpreterMeta {
   explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
+  GradInterpreterMeta() = default;
+  GradInterpreterMeta(const GradInterpreterMeta&) = default;
+  GradInterpreterMeta(GradInterpreterMeta&&) = default;
+  GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
+  GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
+  ~GradInterpreterMeta() = default;
+
   bool prevGradMode_;
+  template <typename T>
+  friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
+    json_j["prevGradMode"] = json_t.prevGradMode_;
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
+    json_t.prevGradMode_ = json_j["prevGradMode"];
+  }
 };
 
 struct JvpInterpreterMeta {
   explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
+  JvpInterpreterMeta() = default;
+  JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
+  JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
+  JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
+  JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
+  ~JvpInterpreterMeta() = default;
+
   bool prevFwdGradMode_;
+  template <typename T>
+  friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
+    json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
+    json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"];
+  }
 };
 
 struct FunctionalizeInterpreterMeta {
   explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
     functionalizeAddBackViews_(functionalizeAddBackViews) {}
+  FunctionalizeInterpreterMeta() = default;
+  FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
+  FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
+  FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
+  FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
+  ~FunctionalizeInterpreterMeta() = default;
+
   bool functionalizeAddBackViews_;
+  template <typename T>
+  friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
+    json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
+    json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"];
+  }
 };
 
 typedef std::variant<

@@ -172,6 +245,75 @@ struct Interpreter {
   // Please don't use this
   explicit Interpreter() = default;
 
+  template <typename T>
+  friend void to_json(T& json_j, const Interpreter& json_t) {
+    json_j["type"] = static_cast<int64_t>(json_t.type_);
+    json_j["level"] = json_t.level_;
+    if (json_t.savedLocalDispatchKeySet_) {
+      json_j["savedLocalDispatchKeySet"] = {
+        {"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
+        {"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
+      };
+    } else {
+      json_j["savedLocalDispatchKeySet"] = nlohmann::json();
+    }
+    json_j["is_alive"] = *json_t.is_alive_;
+    std::visit([&](auto&& arg) {
+      using V = std::decay_t<decltype(arg)>;
+      if constexpr (std::is_same_v<V, int64_t>) {
+        json_j["meta"] = {{"Torch", arg}};
+      } else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
+        json_j["meta"] = {{"Grad", arg}};
+      } else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
+        json_j["meta"] = {{"Jvp", arg}};
+      } else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
+        json_j["meta"] = {{"Vmap", arg}};
+      } else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
+        json_j["meta"] = {{"Functionalize", arg}};
+      } else {
+        static_assert(false && sizeof(V), "unknown variant case");
+      }
+    }, json_t.meta_);
+  }
+
+  template <typename T>
+  friend void from_json(const T& json_j, Interpreter& json_t) {
+    json_t.type_ = static_cast<TransformType>(json_j["type"]);
+    json_t.level_ = json_j["level"];
+    auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
+    if (savedLocalDispatchKeySet.is_null()) {
+      json_t.savedLocalDispatchKeySet_ = std::nullopt;
+    } else {
+      c10::impl::PODLocalDispatchKeySet pod;
+      pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
+      pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
+      json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
+    }
+    json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
+    auto meta = json_j["meta"];
+    if (meta.contains("Torch")) {
+      json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
+    } else if (meta.contains("Grad")) {
+      json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
+    } else if (meta.contains("Jvp")) {
+      json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
+    } else if (meta.contains("Vmap")) {
+      json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
+    } else if (meta.contains("Functionalize")) {
+      json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
+    } else {
+      throw std::runtime_error("unknown interpreter metadata type");
+    }
+  }
+
+  std::string serialize() const {
+    return nlohmann::json(*this).dump();
+  }
+
+  static Interpreter deserialize(const std::string& serialized) {
+    return nlohmann::json::parse(serialized).get<Interpreter>();
+  }
+
 private:
  explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
    type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}

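The serialize()/deserialize() pair above is what the Python-side guard pickler ultimately calls through the CInterpreter bindings added later in this diff. A minimal round-trip sketch, assuming one functorch layer is on the stack; the helper names come from torch._functorch as used elsewhere in this diff, and `_cdata` is the wrapped CInterpreter:

```python
import torch
from torch._functorch import pyfunctorch
from torch._functorch.vmap import vmap_increment_nesting

with vmap_increment_nesting(4, "error"):  # push one vmap layer
    interp = pyfunctorch.retrieve_all_functorch_interpreters()[0]
    blob = interp._cdata.serialize()  # JSON produced by the to_json overloads above
    restored = torch._C._functorch.CInterpreter.deserialize(blob)
    # key/level survive the round trip; the meta variant is restored from "meta".
    assert restored.key() == interp._cdata.key()
    assert restored.level() == interp._cdata.level()
```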
@@ -94,7 +94,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             fn.__closure__ or (),
             [],  # TODO tf_mode_stack,
             code_options,
-            lambda gm, *args, **kwargs: gm.forward,
+            torch._dynamo.lookup_backend("eager"),
             one_graph=False,
             export=False,
             export_constraints=None,

@@ -326,6 +326,126 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         with torch.autograd.forward_ad.dual_level():
             self._test_check_fn(ref, loaded, {"x": x}, False)
 
+    def test_functorch_stack_match(self):
+        # Test when functorch stack is empty.
+        def fn(x):
+            return torch.func.jvp(torch.sin, (x,), (x,))
+
+        x = torch.randn(3, 4)
+        ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
+
+        self._test_check_fn(ref, loaded, {"x": x}, True)
+        with torch._functorch.vmap.vmap_increment_nesting(2, "error"):
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        def fn(x):
+            def g(x):
+                return torch.vmap(torch.func.grad(torch.sin))(x)
+
+            return torch.vmap(g)(x)
+
+        x = torch.randn(4, 5)
+        ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
+        self._test_check_fn(ref, loaded, {"x": x}, True)
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Test when there are more than 0 functorch layers.
+        # Simulate the case where torch.compile is nested inside eager transforms.
+
+        # Case 1: vmap
+        def fn(x):
+            return x.sum()
+
+        ref = loaded = None
+
+        def run(x):
+            nonlocal ref, loaded
+            # Turn off automatic dynamic shapes so that functionalization
+            # doesn't produce extra SymInts to serialize.
+            with torch._dynamo.config.patch(automatic_dynamic_shapes=False):
+                ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
+            return fn(x)
+
+        torch.vmap(run)(x)
+
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            self._test_check_fn(ref, loaded, {"x": x}, True)
+            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+                self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Case 2: grad
+        x = torch.randn(3, 2)
+        ref = loaded = None
+        torch.func.grad(run)(x)
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, True)
+            with torch._functorch.eager_transforms.grad_increment_nesting():
+                self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Case 3: jvp + vmap
+        x = torch.randn(3, 4)
+        ref = loaded = None
+
+        def fn(x):
+            return torch.func.jvp(torch.sin, (x,), (x,))
+
+        torch.func.jvp(torch.vmap(run), (x,), (x,))
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.eager_transforms.jvp_increment_nesting():
+            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+                self._test_check_fn(ref, loaded, {"x": x}, True)
+
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            with torch._functorch.eager_transforms.jvp_increment_nesting():
+                self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Case 4: functionalize
+        x = torch.randn(3, 2)
+        ref = loaded = None
+        torch.func.functionalize(run)(x)
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        torch._C._functorch._func_increment_nesting(True)
+        try:
+            self._test_check_fn(ref, loaded, {"x": x}, True)
+        finally:
+            torch._C._functorch._func_decrement_nesting()
+
+        with torch._functorch.eager_transforms.jvp_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        # Case 5: vmap + grad
+        def fn(x):
+            return x.sum()
+
+        x = torch.randn(3, 2)
+        ref = loaded = None
+        torch.vmap(torch.func.grad(run))(x)
+        self._test_check_fn(ref, loaded, {"x": x}, False)
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            with torch._functorch.eager_transforms.grad_increment_nesting():
+                self._test_check_fn(ref, loaded, {"x": x}, True)
+
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+                self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
+        with torch._functorch.eager_transforms.grad_increment_nesting():
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

@@ -49,6 +49,9 @@ class RandomnessType(Enum):
 class CInterpreter:
     def key(self) -> TransformType: ...
     def level(self) -> int: ...
+    def serialize(self) -> bytes: ...
+    @staticmethod
+    def deserialize(bytes) -> CInterpreter: ...
 
 class CGradInterpreterPtr:
     def __init__(self, interpreter: CInterpreter) -> None: ...

@@ -475,7 +475,11 @@ def get_verbose_code_part(code_part: str, guard: Guard) -> str:
                 extra = f" # {format_frame(fs, line=True)}"
                 break
     elif guard.stack:
-        extra = f" # {format_frame(guard.stack.summary()[-1])}"
+        summary = guard.stack.summary()
+        if len(summary) > 0:
+            extra = f" # {format_frame(summary[-1])}"
+        else:
+            extra = " # <unknown>"
     return f"{code_part:<60}{extra}"

@@ -1591,7 +1595,7 @@ class GuardBuilder(GuardBuilderBase):
     def FUNCTORCH_STACK_MATCH(self, guard: Guard):
         # Invalidate functorch code if current level is different than
         # the one when FX graph was generated
-        cis = torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters()
+        cis = self.check_fn_manager.output_graph.functorch_layers
         states = [ci.get_state() for ci in cis]
         code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
         self._set_guard_export_info(guard, code)

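The functional change here: the guard is now built from the functorch layers recorded on the OutputGraph at trace time (see the output_graph hunks below) rather than from the live interpreter stack at guard-construction time, which is what makes the guard reproducible from a serialized guards state. A rough sketch of what gets emitted, with the layers taken from the live stack purely for illustration:

```python
import torch
from torch._functorch import pyfunctorch

# In the real code these come from output_graph.functorch_layers recorded at trace time.
layers = pyfunctorch.retrieve_all_functorch_interpreters()
states = [layer.get_state() for layer in layers]

# The guard is a plain source expression over the recorded states; it is compiled
# into the check function and re-evaluated on every call of the compiled function.
guard_source = f"torch._functorch.pyfunctorch.compare_functorch_state({states})"
print(guard_source)
```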
@@ -2522,7 +2526,13 @@ class GuardsStatePickler(pickle.Pickler):
     def _unpickle_dispatch_key_set(cls, raw_repr: int):
         return torch._C.DispatchKeySet.from_raw_repr(raw_repr)
 
+    @classmethod
+    def _unpickle_functorch_interpreter(cls, json: bytes):
+        return torch._C._functorch.CInterpreter.deserialize(json)
+
     def reducer_override(self, obj):
+        import sympy
+
         if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
             return type(self)._unpickle_tensor, (
                 torch.empty_like(obj, device="meta"),

@@ -2543,6 +2553,20 @@ class GuardsStatePickler(pickle.Pickler):
         elif isinstance(obj, torch._C.DispatchKeySet):
             return type(self)._unpickle_dispatch_key_set, (obj.raw_repr(),)
 
+        elif isinstance(obj, torch._C._functorch.CInterpreter):
+            return type(self)._unpickle_functorch_interpreter, (obj.serialize(),)
+
+        elif (
+            inspect.isclass(obj)
+            and issubclass(obj, sympy.Function)
+            and hasattr(obj, "_torch_handler_name")
+        ):
+            assert hasattr(obj, "_torch_unpickler")
+            return obj._torch_unpickler, (obj._torch_handler_name,)
+
+        elif isinstance(obj, torch.SymInt):
+            raise RuntimeError(f"Cannot serialize SymInt {obj} (node: {obj.node})")
+
         if type(obj).__qualname__ != type(obj).__name__:
             raise RuntimeError(
                 f"Type {type(obj)} for object {obj} cannot be saved "

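For readers unfamiliar with `reducer_override`: the pickler substitutes an otherwise unpicklable object with a `(reconstructor, args)` pair, which is how a `CInterpreter` is reduced to `(_unpickle_functorch_interpreter, (obj.serialize(),))` above. A generic, PyTorch-free sketch of the mechanism (the `Handle` class and helper names are hypothetical):

```python
import io
import pickle


class Handle:
    """Stands in for an object that cannot be pickled directly."""

    def __init__(self, payload: bytes):
        self.payload = payload


def _rebuild_handle(payload: bytes) -> Handle:
    return Handle(payload)


class MyPickler(pickle.Pickler):
    def reducer_override(self, obj):
        if isinstance(obj, Handle):
            # Reduce to a simple payload plus a module-level reconstructor.
            return _rebuild_handle, (obj.payload,)
        return NotImplemented  # fall back to default pickling


buf = io.BytesIO()
MyPickler(buf).dump(Handle(b"state"))
restored = pickle.loads(buf.getvalue())
assert restored.payload == b"state"
```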
@@ -301,6 +301,7 @@ class OutputGraphGuardsState:
     # Map from graph input's `Source` to sizes / strides metadata
     input_source_to_sizes_strides: dict[Source, dict[str, Any]]
     dual_level: int
+    functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
 
     export: bool = False
     export_constraints: bool = False

@@ -354,6 +355,7 @@ class OutputGraph(OutputGraphGuardsState):
             guard_on_key_order=set(),
             input_source_to_sizes_strides={},
             dual_level=torch.autograd.forward_ad._current_level,
+            functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
         )
         self.tracers = [SubgraphTracer(self, is_export=export)]
         # Map from graph input's `Source` to its `VariableTracker` to

@@ -590,6 +592,7 @@ class OutputGraph(OutputGraphGuardsState):
             guard_on_key_order=self.guard_on_key_order,
             input_source_to_sizes_strides=self.input_source_to_sizes_strides,
             dual_level=self.dual_level,
+            functorch_layers=self.functorch_layers,
             export=self.export,
             export_constraints=self.export_constraints,
             _guards=self.guards,

@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import contextlib
 from abc import ABC, abstractmethod
+from functools import cached_property
 from typing import Any
 
 import torch

@@ -79,6 +80,11 @@ class FuncTorchInterpreter(ABC):
     def check_state(self, state):
         return state == self.get_state()
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("_cptr", None)
+        return state
+
 
 @contextlib.contextmanager
 def temporarily_pop_interpreter_stack():

@@ -123,7 +129,10 @@ class VmapInterpreter(FuncTorchInterpreter):
         # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
         # so that we can access methods specific to the vmap interpreter
         self._cdata = cdata
-        self._cptr = CVmapInterpreterPtr(cdata)
+
+    @cached_property
+    def _cptr(self):
+        return CVmapInterpreterPtr(self._cdata)
 
     def process(self, op, args, kwargs):
         kernel = op.functorch_table[TransformType.Vmap]

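Together with the `__getstate__` added above, this is the pattern that makes the interpreter wrappers picklable: the C++ `_cptr` handle is dropped when pickling and lazily rebuilt from `_cdata` on first access. A generic illustration (not the PyTorch classes; the stand-in handle is hypothetical):

```python
import pickle
from functools import cached_property


class Wrapper:
    def __init__(self, data):
        self._cdata = data  # picklable state

    @cached_property
    def _cptr(self):
        # Stands in for CVmapInterpreterPtr(self._cdata): rebuilt on first access.
        return ("handle-for", self._cdata)

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_cptr", None)  # cached_property stores its value in __dict__
        return state


w = Wrapper([1, 2, 3])
_ = w._cptr  # populate the cache
w2 = pickle.loads(pickle.dumps(w))  # round trip without the cached handle
assert "_cptr" not in w2.__dict__
assert w2._cptr == ("handle-for", [1, 2, 3])  # rebuilt lazily
```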
@@ -159,7 +168,10 @@ class GradInterpreter(FuncTorchInterpreter):
         assert cdata.key() == TransformType.Grad
         # See NOTE: [Interpreter cdata vs cptr]
         self._cdata = cdata
-        self._cptr = CGradInterpreterPtr(cdata)
+
+    @cached_property
+    def _cptr(self):
+        return CGradInterpreterPtr(self._cdata)
 
     def lift(self, args, kwargs):
         args, kwargs = pytree.tree_map_only(

@@ -193,7 +205,10 @@ class JvpInterpreter(FuncTorchInterpreter):
         assert cdata.key() == TransformType.Jvp
         # See NOTE: [Interpreter cdata vs cptr]
         self._cdata = cdata
-        self._cptr = CJvpInterpreterPtr(cdata)
+
+    @cached_property
+    def _cptr(self):
+        return CJvpInterpreterPtr(self._cdata)
 
     def lift(self, args, kwargs):
         args, kwargs = pytree.tree_map_only(

@@ -226,7 +241,10 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
     def __init__(self, cdata: CInterpreter):
         assert cdata.key() == TransformType.Functionalize
         self._cdata = cdata
-        self._cptr = CFunctionalizeInterpreterPtr(cdata)
+
+    @cached_property
+    def _cptr(self):
+        return CFunctionalizeInterpreterPtr(self._cdata)
 
     def process(self, op, args, kwargs):
         kernel = op.functorch_table[TransformType.Functionalize]

@@ -575,7 +575,9 @@ void initFuncTorchBindings(PyObject* module) {
       .value("Different", RandomnessType::Different);
   py::class_<Interpreter>(m, "CInterpreter")
       .def("key", &Interpreter::key)
-      .def("level", &Interpreter::level);
+      .def("level", &Interpreter::level)
+      .def("serialize", &Interpreter::serialize)
+      .def_static("deserialize", &Interpreter::deserialize);
   py::class_<GradInterpreterPtr>(m, "CGradInterpreterPtr")
       .def(py::init<const Interpreter*>())
       .def("key", &GradInterpreterPtr::key)

@@ -1318,6 +1318,7 @@ def make_opaque_unary_fn(name):
         """
 
         _torch_handler_name = name
+        _torch_unpickler = make_opaque_unary_fn
 
         @classmethod
         def eval(cls, a):

@@ -1378,6 +1379,9 @@ def make_opaque_bitwise_fn(name, real_op_name):
     class BitwiseFn(sympy.Function):
         _torch_handler_name = name
         precedence: int = prec
+        _torch_unpickler = functools.partial(
+            make_opaque_bitwise_fn, real_op_name=real_op_name
+        )
 
         @classmethod
         def eval(cls, a, b):

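These sympy function classes are created inside factory functions, so pickle cannot look them up by module path; storing the factory in `_torch_unpickler` (with `_torch_handler_name` as its argument) lets the guard pickler's `reducer_override` branch above rebuild an equivalent class. A generic, PyTorch-free sketch of the idea (all names here are hypothetical stand-ins):

```python
import functools
import io
import pickle

_registry = {}


def make_fn_class(name, flavor="unary"):
    """Factory standing in for make_opaque_unary_fn / make_opaque_bitwise_fn."""
    key = (name, flavor)
    if key not in _registry:
        _registry[key] = type(
            f"Opaque_{name}",
            (),
            {
                "_torch_handler_name": name,
                # functools.partial mirrors the bitwise case above.
                "_torch_unpickler": staticmethod(functools.partial(make_fn_class, flavor=flavor)),
            },
        )
    return _registry[key]


class ClassPickler(pickle.Pickler):
    def reducer_override(self, obj):
        if isinstance(obj, type) and hasattr(obj, "_torch_handler_name"):
            return obj._torch_unpickler, (obj._torch_handler_name,)
        return NotImplemented


OpaqueSqrt = make_fn_class("sqrt")
buf = io.BytesIO()
ClassPickler(buf).dump(OpaqueSqrt)
assert pickle.loads(buf.getvalue()) is OpaqueSqrt  # rebuilt through the factory
```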