mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][cpp-guards] Optimize tensor.grad accessor (#123226)
For LayoutLM model, reduces C++ guard overhead by 1.48x. These are the numbers  Pull Request resolved: https://github.com/pytorch/pytorch/pull/123226 Approved by: https://github.com/jansel
This commit is contained in:
parent
9288b27461
commit
d91db70295
|
|
@ -66,6 +66,7 @@ from .source import (
|
|||
GlobalSource,
|
||||
GlobalStateSource,
|
||||
GlobalWeakRefSource,
|
||||
GradSource,
|
||||
LocalSource,
|
||||
NNModuleSource,
|
||||
NotNNModuleSource,
|
||||
|
|
@ -372,16 +373,9 @@ def getitem_on_dict_manager(
|
|||
)
|
||||
|
||||
|
||||
def is_grad_source(source):
|
||||
if isinstance(source, AttrSource):
|
||||
return source.member == "grad"
|
||||
return False
|
||||
|
||||
|
||||
def match_on_id_for_tensor(guard):
|
||||
return guard.originating_source.is_dict_key() and not is_grad_source(
|
||||
guard.originating_source
|
||||
)
|
||||
source = guard.originating_source
|
||||
return source.is_dict_key() and not isinstance(source, GradSource)
|
||||
|
||||
|
||||
# The ready to eval generated code (possibly multiple parts) for a guard, plus
|
||||
|
|
@ -543,6 +537,11 @@ class GuardBuilder(GuardBuilderBase):
|
|||
):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
return base_guard_manager
|
||||
elif istype(source, GradSource):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
return base_guard_manager.grad_manager(
|
||||
source=source_name, example_value=example_value
|
||||
)
|
||||
elif istype(source, AttrSource):
|
||||
assert base_guard_manager # to make mypy happy
|
||||
return base_guard_manager.getattr_manager(
|
||||
|
|
|
|||
|
|
@ -170,6 +170,25 @@ class AttrSource(ChainedSource):
|
|||
return f"{self.base.name()}.{self.member}"
|
||||
|
||||
|
||||
# Represents tensor.grad source. It could be represented by AttrSource as well.
|
||||
# But, we could access grad field on tensor directly in C++ without going
|
||||
# through the Python bytecodes. Therefore, we use a separate source for grad
|
||||
# field.
|
||||
@dataclasses.dataclass(frozen=True)
class GradSource(ChainedSource):
    """Source for the ``grad`` member reached through ``self.base``.

    Reconstruction, guard-source classification and naming are all derived
    from the base source; ``member`` is the attribute name appended to it.
    """

    # Attribute name loaded on top of the base source; defaults to "grad".
    member: str = "grad"

    def reconstruct(self, codegen):
        """Reconstruct the base, then emit the attribute load(s) for ``member``."""
        self.base.reconstruct(codegen)
        load_instructions = codegen.create_load_attrs(self.member)
        codegen.extend_output(load_instructions)

    def guard_source(self):
        """Return the guard source reported by the base source."""
        base_source = self.base
        return base_source.guard_source()

    def name(self):
        """Return ``"<base name>.<member>"``."""
        return "{}.{}".format(self.base.name(), self.member)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ParamBufferSource(AttrSource):
|
||||
def guard_source(self):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,13 @@ from torch.utils._pytree import tree_map_only
|
|||
from ..exc import unimplemented, Unsupported
|
||||
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, ConstDictKeySource, GetItemSource, GlobalWeakRefSource
|
||||
from ..source import (
|
||||
AttrSource,
|
||||
ConstDictKeySource,
|
||||
GetItemSource,
|
||||
GlobalWeakRefSource,
|
||||
GradSource,
|
||||
)
|
||||
from ..utils import GLOBAL_KEY_PREFIX
|
||||
|
||||
from .base import VariableTracker
|
||||
|
|
@ -202,7 +208,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||
):
|
||||
param_source = p_vt.source
|
||||
self.tensor_to_source[p] = param_source
|
||||
grad_source = AttrSource(
|
||||
grad_source = GradSource(
|
||||
param_source,
|
||||
"grad",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2300,6 +2300,47 @@ class DictGetItemGuardAccessor : public GuardAccessor {
|
|||
PyObject* _attr_name;
|
||||
};
|
||||
|
||||
/**
|
||||
* Represents the tensor.grad accessor.
|
||||
*/
|
||||
class GradGuardAccessor : public GuardAccessor {
|
||||
public:
|
||||
GradGuardAccessor(
|
||||
RootGuardManager* root,
|
||||
py::str name,
|
||||
std::string source,
|
||||
py::handle example_value)
|
||||
: GuardAccessor(root, std::move(name), std::move(source), example_value) {
|
||||
}
|
||||
|
||||
// NB: Intentional duplication between check_nopybind and
|
||||
// check_verbose_nopybind.
|
||||
bool check_nopybind(PyObject* obj) override { // borrowed ref
|
||||
// check that its a tensor
|
||||
if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
|
||||
return false;
|
||||
}
|
||||
PyObject* grad = THPVariable_Wrap(THPVariable_Unpack(obj).grad());
|
||||
return _guard_manager->check_nopybind(grad);
|
||||
}
|
||||
|
||||
GuardDebugInfo check_verbose_nopybind(
|
||||
PyObject* obj) override { // borrowed ref
|
||||
// check that its a tensor
|
||||
if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
|
||||
return GuardDebugInfo(
|
||||
false, "not a tensor - grad field is accessed " + get_source(), 0);
|
||||
}
|
||||
PyObject* grad = THPVariable_Wrap(THPVariable_Unpack(obj).grad());
|
||||
return _guard_manager->check_verbose_nopybind(grad);
|
||||
}
|
||||
|
||||
std::string repr() const override {
|
||||
// Helpful when priting GuardManager tree structure.
|
||||
return "GradGuardAccessor(grad)";
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Represents func.__defaults__ accessor.
|
||||
*/
|
||||
|
|
@ -3169,6 +3210,21 @@ PyObject* torch_c_dynamo_guards_init() {
|
|||
py::arg("source"),
|
||||
py::arg("example_value"),
|
||||
py::return_value_policy::reference)
|
||||
// return by reference because GuardManager has the ownership of accessors
|
||||
// and guard managers
|
||||
.def(
|
||||
"grad_manager",
|
||||
[](GuardManager& self,
|
||||
std::string source,
|
||||
py::handle example_value) -> GuardManager* {
|
||||
// A unique key is used to save as the accessor key.
|
||||
py::str unique_key("__grad_accessor__");
|
||||
return self.get_child_manager<GradGuardAccessor>(
|
||||
std::move(unique_key), std::move(source), example_value);
|
||||
},
|
||||
py::arg("source"),
|
||||
py::arg("example_value"),
|
||||
py::return_value_policy::reference)
|
||||
// return by reference because C++ GuardManager has the ownership of
|
||||
// accessors and guard managers
|
||||
.def(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user