Fixes #102375

`sequence_nr` increments during the forward pass and decrements during the backward pass, so backward ops that carry the same `sequence_nr` as a forward op are the backward implementation of that op. The long-term goal is to make this information available to the profiler so users can observe which ops are fused into Inductor's generated Triton kernels.

Added a test for this feature: **test/dynamo/test_aot_autograd.py::AotAutogradFallbackTests::test_aot_sequence_nr**. The test uses **aot_export_module()** to create a joint forward/backward fx graph, then walks all of its nodes via fx_graph.graph.nodes. The `seq_nr` of each node is recorded in node.meta. Because `seq_nr` increments during the forward pass and decrements during the backward pass, the user can map each forward op to its corresponding backward ops, which is useful for performance analysis.

Expected output from the test case:

| SeqNr | OrigAten | SrcFn |
|---|---|---|
| 0 | aten.convolution.default | l__self___conv1 |
| 0 | aten.add.Tensor | l__self___bn1 |
| 1 | aten._native_batch_norm_legit_functional.default | l__self___bn1 |
| 2 | aten.relu.default | l__self___relu1 |
| 3 | aten.add.Tensor | add |
| 4 | aten.view.default | flatten |
| 5 | aten.t.default | l__self___fc1 |
| 6 | aten.unsqueeze.default | l__self___fc1 |
| 7 | aten.mm.default | l__self___fc1 |
| 8 | aten.squeeze.dim | l__self___fc1 |
| 9 | aten.add.Tensor | l__self___fc1 |
| 10 | aten.sub.Tensor | l__self___loss_fn |
| 11 | aten.abs.default | l__self___loss_fn |
| 12 | aten.mean.default | l__self___loss_fn |
| 12 | aten.ones_like.default | |
| 12 | aten.expand.default | |
| 12 | aten.div.Scalar | |
| 11 | aten.sgn.default | |
| 11 | aten.mul.Tensor | |
| 8 | aten.unsqueeze.default | |
| 7 | aten.t.default | |
| 7 | aten.mm.default | |
| 7 | aten.t.default | |
| 7 | aten.t.default | |
| 7 | aten.mm.default | |
| 6 | aten.squeeze.dim | |
| 5 | aten.t.default | |
| 4 | aten.view.default | |
| 2 | aten.threshold_backward.default | |
| 1 | aten.native_batch_norm_backward.default | |
| 0 | aten.convolution_backward.default | |
| 0 | aten.add.Tensor | |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103129
Approved by: https://github.com/soulitzer
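For orientation, a minimal sketch (not the test itself) of how such a joint graph can be inspected. The toy module, input shape, and the fallback between the `source_fn_stack`/`source_fn` meta keys are assumptions; `aot_export_module()`, `fx_graph.graph.nodes`, and the `seq_nr` meta entry come from the description above:

```python
import torch
from torch._functorch.aot_autograd import aot_export_module

# Hypothetical toy module; the real test uses a conv/bn/relu/linear model.
class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        loss = self.fc(x).abs().mean()
        return (loss,)

mod = Toy()
x = torch.randn(2, 4)

# Export a joint forward/backward graph; trace_joint=True asks for both passes
# and output_loss_index says which output the backward is taken of.
fx_graph, _ = aot_export_module(mod, [x], trace_joint=True, output_loss_index=0)

# Each call_function node's meta carries the autograd sequence number and the
# original aten op; the source-fn key name has varied across versions.
for node in fx_graph.graph.nodes:
    if node.op == "call_function":
        src_fn = node.meta.get("source_fn_stack", node.meta.get("source_fn"))
        print(node.meta.get("seq_nr"), node.meta.get("original_aten"), src_fn)
```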
103 lines · 4.1 KiB · C++
#pragma once

#include <torch/csrc/python_headers.h>

#include <memory>
#include <typeinfo>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/utils/object_ptr.h>

namespace torch {
namespace autograd {
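// Python wrapper around a C++ autograd Node; `cdata` shares ownership of the
// underlying Node with the autograd graph.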
struct THPCppFunction {
  PyObject_HEAD std::shared_ptr<Node> cdata;
};
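// tp_new implementation for C++-function wrapper types: allocates the Python
// object, then placement-constructs `cdata` from the Node built by Ctor.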
template <typename Ctor>
PyObject* CppFunction_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  THPObjectPtr obj(type->tp_alloc(type, 0));
  if (!obj)
    return nullptr;
  THPCppFunction* f = (THPCppFunction*)obj.get();
  HANDLE_TH_ERRORS
  new (&f->cdata) std::shared_ptr<Node>(Ctor()(args));
  END_HANDLE_TH_ERRORS
  if (!f->cdata) {
    return nullptr;
  }
  return obj.release();
}
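// Default PyMethodDef entries shared by C++-function wrapper types, including
// the _sequence_nr accessor referenced in the PR description above.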
#define THP_FUNCTION_DEFAULT_METHODS                                           \
  {(char*)"_register_hook_dict",                                               \
   THPCppFunction_register_hook_dict,                                          \
   METH_O,                                                                     \
   nullptr},                                                                   \
      {(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr}, \
      {(char*)"register_prehook",                                              \
       THPCppFunction_register_prehook,                                        \
       METH_O,                                                                 \
       nullptr},                                                               \
      {(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr}, {            \
    (char*)"_sequence_nr", THPCppFunction_sequence_nr, METH_NOARGS, nullptr    \
  }
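// Default PyGetSetDef entries shared by C++-function wrapper types.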
#define THP_FUNCTION_DEFAULT_PROPERTIES                                   \
  {(char*)"next_functions",                                               \
   THPCppFunction_next_functions,                                         \
   nullptr,                                                               \
   nullptr,                                                               \
   nullptr},                                                              \
      {(char*)"requires_grad",                                            \
       THPCppFunction_requires_grad,                                      \
       nullptr,                                                           \
       nullptr,                                                           \
       nullptr},                                                          \
      {                                                                   \
    (char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr \
  }
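// Getter and method implementations referenced by the tables above.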
PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused);
PyObject* THPCppFunction_metadata(PyObject* self, void* _unused);
PyObject* THPCppFunction_requires_grad(PyObject* self, void* _unused);
PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);

PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs);

PyTypeObject* _initFunctionPyTypeObject(
    PyTypeObject& type,
    const char* name,
    PyGetSetDef* function_properties,
    PyMethodDef* function_methods);

PyObject* registerFunctionHook(Node& fn, PyObject* hook);

PyObject* registerFunctionPreHook(Node& fn, PyObject* hook);
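// Initializes `type` as a Python type whose tp_new constructs a Node via Ctor,
// then installs the standard properties and methods.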
template <typename Ctor>
PyTypeObject* createForwardFunctionPyTypeObject(
    PyTypeObject& type,
    const char* name,
    PyGetSetDef* function_properties = nullptr,
    PyMethodDef* function_methods = nullptr) {
  type.tp_new = &CppFunction_pynew<Ctor>;
  return _initFunctionPyTypeObject(
      type, name, function_properties, function_methods);
}
void registerCppFunction(const std::type_info& type, PyTypeObject* pytype);
PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);

bool THPCppFunction_Check(PyObject* obj);

} // namespace autograd
} // namespace torch
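For orientation, a minimal, hypothetical sketch of how a binding might use this header (assumed here to be `torch/csrc/autograd/python_cpp_function.h`). `MyNode`, `MyNodeCtor`, `MyNodeType`, and `initMyNodeBinding` are made-up names, and the sketch assumes `Node`'s only pure-virtual member is `apply()`:

```cpp
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/python_cpp_function.h>

namespace torch {
namespace autograd {

// Hypothetical Node subclass: passes its inputs through unchanged.
struct MyNode : public Node {
  variable_list apply(variable_list&& inputs) override {
    return std::move(inputs);
  }
};

// Ctor functor consumed by CppFunction_pynew<Ctor>: builds the Node from the
// Python constructor args (ignored in this sketch).
struct MyNodeCtor {
  std::shared_ptr<Node> operator()(PyObject* /*args*/) {
    return std::make_shared<MyNode>();
  }
};

// Static type object filled in during module initialization.
static PyTypeObject MyNodeType;

void initMyNodeBinding() {
  // Wires tp_new to CppFunction_pynew<MyNodeCtor> and installs the default
  // methods/properties declared in this header.
  createForwardFunctionPyTypeObject<MyNodeCtor>(MyNodeType, "MyNode");
}

} // namespace autograd
} // namespace torch
```

The `Ctor` functor is what ties the two templates together: `CppFunction_pynew<Ctor>` calls `Ctor()(args)` to build the `Node` that the wrapper's `cdata` will own.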