[dynamo][cpp-guards] Optimize tensor.grad accessor (#123226)

For LayoutLM model, reduces C++ guard overhead by 1.48x. These are the numbers

![image](https://github.com/pytorch/pytorch/assets/13822661/25cfc35b-b67d-4903-8403-71fa931dacdd)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123226
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain 2024-04-02 16:59:04 -07:00 committed by PyTorch MergeBot
parent 9288b27461
commit d91db70295
4 changed files with 91 additions and 11 deletions

View File

@ -66,6 +66,7 @@ from .source import (
GlobalSource,
GlobalStateSource,
GlobalWeakRefSource,
GradSource,
LocalSource,
NNModuleSource,
NotNNModuleSource,
@ -372,16 +373,9 @@ def getitem_on_dict_manager(
)
def is_grad_source(source):
if isinstance(source, AttrSource):
return source.member == "grad"
return False
def match_on_id_for_tensor(guard):
return guard.originating_source.is_dict_key() and not is_grad_source(
guard.originating_source
)
source = guard.originating_source
return source.is_dict_key() and not isinstance(source, GradSource)
# The ready to eval generated code (possibly multiple parts) for a guard, plus
@ -543,6 +537,11 @@ class GuardBuilder(GuardBuilderBase):
):
assert base_guard_manager # to make mypy happy
return base_guard_manager
elif istype(source, GradSource):
assert base_guard_manager # to make mypy happy
return base_guard_manager.grad_manager(
source=source_name, example_value=example_value
)
elif istype(source, AttrSource):
assert base_guard_manager # to make mypy happy
return base_guard_manager.getattr_manager(

View File

@ -170,6 +170,25 @@ class AttrSource(ChainedSource):
return f"{self.base.name()}.{self.member}"
# Represents tensor.grad source. It could be represented by AttrSource as well.
# But, we could access grad field on tensor directly in C++ without going
# through the Python bytecodes. Therefore, we use a separate source for grad
# field.
@dataclasses.dataclass(frozen=True)
class GradSource(ChainedSource):
member: str = "grad"
def reconstruct(self, codegen):
self.base.reconstruct(codegen)
codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self):
return self.base.guard_source()
def name(self):
return f"{self.base.name()}.{self.member}"
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
def guard_source(self):

View File

@ -9,7 +9,13 @@ from torch.utils._pytree import tree_map_only
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstDictKeySource, GetItemSource, GlobalWeakRefSource
from ..source import (
AttrSource,
ConstDictKeySource,
GetItemSource,
GlobalWeakRefSource,
GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .base import VariableTracker
@ -202,7 +208,7 @@ class OptimizerVariable(UserDefinedObjectVariable):
):
param_source = p_vt.source
self.tensor_to_source[p] = param_source
grad_source = AttrSource(
grad_source = GradSource(
param_source,
"grad",
)

View File

@ -2300,6 +2300,47 @@ class DictGetItemGuardAccessor : public GuardAccessor {
PyObject* _attr_name;
};
/**
* Represents tensor.grad acccessor.
*/
class GradGuardAccessor : public GuardAccessor {
public:
GradGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value)
: GuardAccessor(root, std::move(name), std::move(source), example_value) {
}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj) override { // borrowed ref
// check that its a tensor
if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
return false;
}
PyObject* grad = THPVariable_Wrap(THPVariable_Unpack(obj).grad());
return _guard_manager->check_nopybind(grad);
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
// check that its a tensor
if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
return GuardDebugInfo(
false, "not a tensor - grad field is accessed " + get_source(), 0);
}
PyObject* grad = THPVariable_Wrap(THPVariable_Unpack(obj).grad());
return _guard_manager->check_verbose_nopybind(grad);
}
std::string repr() const override {
// Helpful when priting GuardManager tree structure.
return "GradGuardAccessor(grad)";
}
};
/**
* Represents func.__defaults__ accessor.
*/
@ -3169,6 +3210,21 @@ PyObject* torch_c_dynamo_guards_init() {
py::arg("source"),
py::arg("example_value"),
py::return_value_policy::reference)
// return by reference because GuardManager has the ownership of accessors
// and guard managers
.def(
"grad_manager",
[](GuardManager& self,
std::string source,
py::handle example_value) -> GuardManager* {
// A unique key is used to save as the accessor key.
py::str unique_key("__grad_accessor__");
return self.get_child_manager<GradGuardAccessor>(
std::move(unique_key), std::move(source), example_value);
},
py::arg("source"),
py::arg("example_value"),
py::return_value_policy::reference)
// return by reference because C++ GuardManager has the ownership of
// accessors and guard managers
.def(