diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 6cb65d35af7..e17e32621a3 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -517,6 +517,7 @@ void RecordFunction::before(const char* name, int64_t sequence_nr) { if (!isActive()) { return; } + state_->op_input_size = state_->inputs_.size(); state_->name_ = StringView(name); state_->sequence_nr_ = sequence_nr; state_->thread_id_ = currentThreadId(); @@ -529,6 +530,7 @@ void RecordFunction::before(std::string name, int64_t sequence_nr) { if (!isActive()) { return; } + state_->op_input_size = state_->inputs_.size(); state_->name_ = StringView(std::move(name)); state_->sequence_nr_ = sequence_nr; state_->thread_id_ = currentThreadId(); diff --git a/test/test_profiler.py b/test/test_profiler.py index 8b9428ec41f..c58884ebe62 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -13,6 +13,7 @@ from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import ( TestCase, run_tests, TEST_WITH_ASAN, TEST_WITH_ROCM, IS_WINDOWS, TemporaryFileName, TemporaryDirectoryName) +from torch.autograd import (_record_function_with_args_enter, _record_function_with_args_exit) from torch.autograd.profiler import profile as _profile from torch.profiler import ( kineto_available, profile, record_function, supported_activities, @@ -57,6 +58,29 @@ class TestProfilerCUDA(TestCase): self.assertTrue(not (is_increasing and max_diff > 100 * 1024), msg='memory usage is increasing, {}'.format(str(last_rss))) +class TestRecordFunction(TestCase): + def _record_function_with_param(self): + u = torch.randn(3, 4, 5, requires_grad=True) + with _profile(with_stack=True, use_kineto=kineto_available(), record_shapes=True) as prof: + with record_function("## TEST 1 ##", "1, 2, 3"): + rf_handle = _record_function_with_args_enter("## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u) + _record_function_with_args_exit(rf_handle) + return prof + + def test_record_function(self): + prof_result = self._record_function_with_param() + found_test_1 = False + found_test_2 = False + for e in prof_result.function_events: + if "## TEST 1 ##" == e.name: + found_test_1 = True + self.assertTrue(e.input_shapes == [[]]) + elif "## TEST 2 ##" == e.name: + found_test_2 = True + self.assertTrue(e.input_shapes == [[], [], [], [], [], [3, 4, 5]]) + self.assertTrue(found_test_1) + self.assertTrue(found_test_2) + class TestProfiler(TestCase): def test_source(self): """Checks that source code attribution works for eager, TS and autograd mode diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 7ffb618e3f0..e2a6039b281 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -1,6 +1,8 @@ -from typing import List, Set, Callable +from typing import List, Set, Callable, Any from enum import Enum +import torch + # Defined in tools/autograd/init.cpp class ProfilerState(Enum): @@ -86,6 +88,8 @@ def _disable_profiler() -> _ProfilerResult: ... def _profiler_enabled() -> bool: ... def _add_metadata_json(key: str, value: str) -> None: ... def kineto_available() -> bool: ... +def _record_function_with_args_enter(name: str, args: List[Any]) -> torch.Tensor: ... +def _record_function_with_args_exit(handle: torch.Tensor) -> None: ... def _supported_activities() -> Set[ProfilerActivity]: ... def _enable_record_function(enable: bool) -> None: ... def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ... diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 513333e5cd0..a1305da8169 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -305,6 +305,7 @@ if not torch._C._autograd_init(): from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, ProfilerConfig, ProfilerEvent, _enable_profiler_legacy, _disable_profiler_legacy, _profiler_enabled, _enable_record_function, _set_empty_test_observer, kineto_available, + _record_function_with_args_enter, _record_function_with_args_exit, _supported_activities, _add_metadata_json, SavedTensor, _register_saved_tensors_default_hooks, _reset_saved_tensors_default_hooks) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index c121b113635..91c8d40c0cd 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -423,8 +423,9 @@ class record_function(ContextDecorator): CUDA time total: 0.000us """ - def __init__(self, name: str): + def __init__(self, name: str, args: Optional[str] = None): self.name: str = name + self.args: Optional[str] = args # Whether or not we should run record function's end callbacks when exiting. self.run_callbacks_on_exit: bool = True # Stores underlying RecordFunction as a tensor. TODO: move to custom @@ -432,7 +433,7 @@ class record_function(ContextDecorator): self.handle: torch.Tensor = torch.zeros(1) def __enter__(self): - self.handle = torch.ops.profiler._record_function_enter(self.name) + self.handle = torch.ops.profiler._record_function_enter(self.name, self.args) return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 9d8550f50ec..396709af89c 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -6,7 +6,10 @@ #include #include #include +#include #include +#include +#include #include #include #include @@ -250,6 +253,35 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { #endif }); + // NOTICE: These record functions are not torch operators and may not show up + // in TorchScript tracing, FX transforms, or operator serialization. For these + // use cases, please use `torch.profiler.record_function`. + // Creates a new profiling scope using RecordFunction and invokes its starting + // callbacks. + m.def("_record_function_with_args_enter", [](const std::string& name, py::args args) { + auto rec = std::make_unique(at::RecordScope::USER_SCOPE); + if (rec->isActive()) { + if (rec->needsInputs()) { + auto iv_inputs = std::vector(); + for (const auto& arg : args) { + iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg)); + } + rec->before(name, iv_inputs); + } else { + rec->before(name); + } + } + return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions()); + }); + + // Ends the profiling scope created with record_function_with_param_enter. + m.def("_record_function_with_args_exit", [](const at::Tensor& handle) { + // We don't actually need to do anything with handle just need to persist the + // lifetime until now. + auto& rec = at::cpp_custom_type_hack::cast(handle); + rec.end(); + }); + m.def("_supported_activities", []() { std::set activities {ActivityType::CPU}; #if defined(USE_KINETO) && !defined(LIBKINETO_NOCUPTI) diff --git a/torch/csrc/autograd/record_function_ops.cpp b/torch/csrc/autograd/record_function_ops.cpp index 9650c354c58..2cf427e04f6 100644 --- a/torch/csrc/autograd/record_function_ops.cpp +++ b/torch/csrc/autograd/record_function_ops.cpp @@ -16,9 +16,17 @@ namespace profiler { // Creates a new profiling scope using RecordFunction and invokes its starting // callbacks. -at::Tensor record_function_enter(const std::string& name) { +at::Tensor record_function_enter( + const std::string& name, + const c10::optional& args) { auto rec = std::make_unique(at::RecordScope::USER_SCOPE); - rec->before(name); + if (rec->isActive()) { + if (rec->needsInputs() && args.has_value()) { + rec->before(name, std::vector{c10::IValue{args.value()}}); + } else { + rec->before(name); + } + } return at::cpp_custom_type_hack::create(std::move(rec), at::TensorOptions()); } @@ -67,7 +75,7 @@ c10::intrusive_ptr _call_end_callbacks_on_fut( // Internal only, do not use directly, use Python's record_function() TORCH_LIBRARY_FRAGMENT(profiler, m) { - m.def("_record_function_enter", &record_function_enter); + m.def("_record_function_enter(str name, str? args=None) -> Tensor", &record_function_enter); m.def("_record_function_exit", &record_function_exit); } diff --git a/torch/csrc/autograd/record_function_ops.h b/torch/csrc/autograd/record_function_ops.h index bc9f2c975fc..9042537aeab 100644 --- a/torch/csrc/autograd/record_function_ops.h +++ b/torch/csrc/autograd/record_function_ops.h @@ -1,12 +1,13 @@ #pragma once #include +#include namespace torch { namespace autograd { namespace profiler { // Creates a new profiling scope using RecordFunction and invokes its starting // callbacks. -TORCH_API at::Tensor record_function_enter(const std::string& name); +TORCH_API at::Tensor record_function_enter(const std::string& name, const c10::optional& args = c10::nullopt); // Schedules RecordFunction's end callbacks to be run on completion of a future. TORCH_API c10::intrusive_ptr _call_end_callbacks_on_fut(