Commit Graph

66 Commits

Author SHA1 Message Date
Shivam Raikundalia
8486d3df69 [Profiler] Hide ProfilerStep Alignment behind Experimental Config (#137668)
Summary: Aligning the ProfilerStep# annotation can be useful for visual purposes, but it causes downstream tools like HTA to misreport how long each step took. For this reason, let's give users the option to turn on this alignment manually, and turn it off by default.
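A minimal usage sketch of opting back in (assuming the experimental flag introduced here is named `adjust_profiler_step`; treat the name as an assumption):

```python
from torch.profiler import profile, ProfilerActivity
from torch._C._profiler import _ExperimentalConfig

# Alignment is off by default after this change; opt in explicitly.
# Assumption: the experimental flag is named `adjust_profiler_step`.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    experimental_config=_ExperimentalConfig(adjust_profiler_step=True),
) as prof:
    train_step()  # hypothetical workload
```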

Test Plan:
Alignment off:

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devvm2185.cco0.facebook.com/rank-0.Oct_09_16_11_48.2543945.pt.trace.json.gz&bucket=gpu_traces

Alignment on:

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devvm2185.cco0.facebook.com/rank-0.Oct_09_16_08_27.2518391.pt.trace.json.gz&bucket=gpu_traces

Differential Revision: D64146115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137668
Approved by: https://github.com/aaronenyeshi
2024-10-11 22:57:05 +00:00
Xuehai Pan
8962610247 [BE][clang-format] make macro PyObject_HEAD_INIT(type) and PyVarObject_HEAD_INIT(type, size) have its own line (#136949)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136949
Approved by: https://github.com/albanD, https://github.com/eqy
ghstack dependencies: #136945
2024-10-02 18:39:22 +00:00
Xuehai Pan
89c37be6b7 [BE][clang-format] make macro PyObject_HEAD have its own line (#136945)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136945
Approved by: https://github.com/albanD
2024-10-02 18:39:21 +00:00
Shivam Raikundalia
9ffcca7060 [Profiler] Handle Tensor Sizes/Strides Parsing Error (#134862)
Summary:
Currently some jobs are encountering the following trace, P1539415198. This suggests that when we are parsing through tensors, the path is prone to encountering an invalid address. This is possibly occurring because the sizes() and strides() of a Tensor are somehow not of the same dimensions; we assume they are when iterating through the shapes to get the IValue generator. When browsing some of the tensor implementations, I found that some of the size and stride paths differ, which could be the cause of this issue. Regardless, the profiler should be flexible enough to handle such issues without bringing down the whole main thread.

If the crashes still persist, this will at least give us a data point as to where they occur, and we can rule out sizes/strides as the culprit.

Test Plan: This change doesn't break anything in the happy path; it just makes sure the bad path does not exit abruptly. We can use this to debug which events have mismatched dimensions between sizes and strides.
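A hedged Python sketch of the guard this change adds on the C++ side (illustrative only, not the actual implementation):

```python
# Illustrative only: the real fix lives in the C++ profiler collection path.
# The idea is to refuse to zip sizes with strides when their dimensions
# disagree, instead of reading an invalid address.
def safe_shape_record(sizes: list[int], strides: list[int]) -> list[tuple[int, int]]:
    if len(sizes) != len(strides):
        return []  # record nothing rather than crash the main thread
    return list(zip(sizes, strides))
```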

Differential Revision: D62008788

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134862
Approved by: https://github.com/aaronenyeshi
2024-09-03 23:46:38 +00:00
cyy
f4dcf2ae93 [1/N] Change #include <c10/util/Optional.h> to #include <optional> (#128301)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128301
Approved by: https://github.com/ezyang, https://github.com/r-barnes
2024-07-08 07:03:53 +00:00
PyTorch MergeBot
846bb30e13 Revert "[1/N] Change #include <c10/util/Optional.h> to #include <optional> (#128301)"
This reverts commit bd72e28314.

Reverted https://github.com/pytorch/pytorch/pull/128301 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it fails XLA build bd72e28314. Please rebase your PR before relanding because I think the failure is hidden by an unrelated broken trunk XLA failure from your current base commit ([comment](https://github.com/pytorch/pytorch/pull/128301#issuecomment-2169035822))
2024-06-15 01:58:20 +00:00
cyy
bd72e28314 [1/N] Change #include <c10/util/Optional.h> to #include <optional> (#128301)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128301
Approved by: https://github.com/ezyang
2024-06-14 23:21:01 +00:00
cyy
e2a72313e8 Concat namespaces of torch/csrc/profiler code and other fixes (#128606)
Improve namespaces and modernize codebase of torch/csrc/profiler code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128606
Approved by: https://github.com/Skylion007, https://github.com/aaronenyeshi
2024-06-13 16:46:34 +00:00
FEI
b950217f19 Support third-party devices emit a range for each autograd operator (#125822)
Fixes #125752

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125822
Approved by: https://github.com/aaronenyeshi
2024-05-15 05:06:24 +00:00
zdevito
352a893b0c Fast standalone symbolize for unwinding (#123966)
We've had issues using addr2line. Certain versions of CentOS ship a version with a performance regression that makes it very slow, and even normally it is not that fast, taking several seconds even when parallelized for a typical memory trace dump.

Folly Symbolize or LLVMSymbolize are fast, but they would require PyTorch to take a dependency on those libraries, and given the number of environments we run stuff in, we end up hitting cases where we fall back to slow addr2line behavior.

This adds a standalone symbolizer to PyTorch similar to the unwinder which has
no external dependencies and is ~20x faster than addr2line for unwinding PyTorch frames.

I've tested this on some memory profiling runs using all combinations of {gcc, clang} x {dwarf4, dwarf5} and it seems to do a good job of getting line numbers and function names right. It is also careful to route all reads of library data through the `CheckedLexer` object, which ensures it is not reading out of bounds of the section. Errors are routed through UnwindError so that those exceptions get caught and we produce a ?? frame rather than crash. I also added a fuzz test which gives all our symbolizer options random addresses in the process to make sure they do not crash.
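A sketch of driving the symbolizer from a memory profiling run (assuming the backend is selected via a `TORCH_SYMBOLIZE_MODE` environment variable; treat that name and the "fast" value as assumptions):

```python
import os
import torch

# Assumption: TORCH_SYMBOLIZE_MODE selects the symbolizer backend
# ("fast" here, vs. falling back to addr2line).
os.environ["TORCH_SYMBOLIZE_MODE"] = "fast"

torch.cuda.memory._record_memory_history(max_entries=100_000)
x = torch.randn(1024, 1024, device="cuda")
del x
snapshot = torch.cuda.memory._snapshot()  # frames carry file:line from the symbolizer
```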

Differential Revision: [D56828968](https://our.internmc.facebook.com/intern/diff/D56828968)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123966
Approved by: https://github.com/ezyang, https://github.com/aaronenyeshi
2024-05-14 19:39:17 +00:00
albanD
b119e1bcc2 Fix refcount handling for dtype, layout and memory format (#125271)
Finish fixing https://github.com/pytorch/pytorch/issues/124868
re-use our wrap() utils as much as possible and NewRef in other places.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125271
Approved by: https://github.com/colesbury
2024-05-02 02:34:34 +00:00
PyTorch MergeBot
c0fd7894cc Revert "Fast standalone symbolize for unwinding (#123966)"
This reverts commit 772ae6da1e.

Reverted https://github.com/pytorch/pytorch/pull/123966 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, check D56522678 ([comment](https://github.com/pytorch/pytorch/pull/123966#issuecomment-2076821043))
2024-04-25 10:04:48 +00:00
Florian
7ad6dc2cf3 [Profiler][PrivateUse1] Profiler support PrivateUse1 key (#124818)
Summary:
1. Package public headers of kineto if USE_KINETO so that they can be used by PrivateUse1 users.
2. Add PrivateUse1 key to ActivityType.
3. Support PrivateUse1 key in function deviceTypeFromActivity and _supported_activities.
4. Fix some bugs when processing profiler results. (See the sketch below.)
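Roughly, a PrivateUse1 backend then shows up through the standard profiler API (a sketch; assumes the backend registers a Kineto plugin and that the activity enum is exposed as `ProfilerActivity.PrivateUse1`):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Assumption: an out-of-tree device has been registered under PrivateUse1.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.PrivateUse1]) as prof:
    run_workload()  # hypothetical workload on the custom device
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
```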

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124818
Approved by: https://github.com/aaronenyeshi
2024-04-24 18:52:08 +00:00
zdevito
772ae6da1e Fast standalone symbolize for unwinding (#123966)
We've had issues using addr2line. Certain versions of CentOS ship a version with a performance regression that makes it very slow, and even normally it is not that fast, taking several seconds even when parallelized for a typical memory trace dump.

Folly Symbolize or LLVMSymbolize are fast, but they would require PyTorch to take a dependency on those libraries, and given the number of environments we run stuff in, we end up hitting cases where we fall back to slow addr2line behavior.

This adds a standalone symbolizer to PyTorch similar to the unwinder which has
no external dependencies and is ~20x faster than addr2line for unwinding PyTorch frames.

I've tested this on some memory profiling runs using all combinations of {gcc, clang} x {dwarf4, dwarf5} and it seems to do a good job of getting line numbers and function names right. It is also careful to route all reads of library data through the `CheckedLexer` object, which ensures it is not reading out of bounds of the section. Errors are routed through UnwindError so that those exceptions get caught and we produce a ?? frame rather than crash. I also added a fuzz test which gives all our symbolizer options random addresses in the process to make sure they do not crash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123966
Approved by: https://github.com/ezyang
2024-04-23 15:27:18 +00:00
PyTorch MergeBot
36f6928a37 Revert "[Profiler][PrivateUse1] Profiler support PrivateUse1 key (#120556)"
This reverts commit 41613a0803.

Reverted https://github.com/pytorch/pytorch/pull/120556 on behalf of https://github.com/aaronenyeshi due to Breaks GPU Chrome trace UI ([comment](https://github.com/pytorch/pytorch/pull/120556#issuecomment-2061578951))
2024-04-17 15:38:14 +00:00
Florian
41613a0803 [Profiler][PrivateUse1] Profiler support PrivateUse1 key (#120556)
Summary:
1. Package public headers of kineto if USE_KINETO so that they can be used by PrivateUse1 users.
2. Add PrivateUse1 key to ActivityType.
3. Support PrivateUse1 key in function deviceTypeFromActivity and _supported_activities.
4. Fix some bugs when processing profiler results.
Co-authored-by: albanD <desmaison.alban@gmail.com>
Co-authored-by: Aaron Shi <enye.shi@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120556
Approved by: https://github.com/aaronenyeshi
2024-04-12 14:28:19 +00:00
Shivam Raikundalia
c9c099b271 Add kwargs to RecordFunctionFast (#123600)
Differential Revision: [D55897888](https://our.internmc.facebook.com/intern/diff/D55897888/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123600
Approved by: https://github.com/davidberard98
2024-04-10 18:17:50 +00:00
sraikund16
6fa72480d3 Enhance RecordFunctionFast input args and use input args in triton_heuristics.py (#123459)
Summary: Now that we can pass shapes as input args to RecordFunctionFast, let's add that to the triton heuristics. Also, let's add the ability to pass a tuple into the RecordFunctionFast constructor.
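A hedged sketch of the guarded call pattern in the heuristics (using the `torch._C._profiler_manual._RecordFunctionFast` binding named elsewhere in this log; the tuple-of-inputs constructor argument is an assumption based on this summary):

```python
import torch

def launch_kernel(grid_x: int, grid_y: int) -> None:
    ...  # hypothetical kernel launch

grid = (8, 16)
if torch.autograd.profiler._is_profiler_enabled:
    # Assumption: the constructor takes an input-args tuple as its second argument.
    with torch._C._profiler_manual._RecordFunctionFast("triton_kernel", grid):
        launch_kernel(*grid)
else:
    launch_kernel(*grid)
```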

Test Plan:
Ran both the _inductor/test_profile.py and profiler/test_profiler.py unit tests. Also added tuple based unit test to profiler/test_profiler.py

Ran record_function_fast.py from the following branch
https://github.com/pytorch/pytorch/compare/sraikund/record_funct_test?expand=1

- No shape or args: tests RecordFunctionFast with no args and profile without record_shapes
- With shape: tests RecordFunctionFast with args and profile with record_shapes true
- Args no shape: tests RecordFunctionFast with args inputted but record_shapes set to false
- Args shape tuple: tests RecordFunctionFast with args inputted in the form of a tuple and record_shapes true

Stdout:

```
No shape or args:: 1.8491458892822266 us
With shape:: 2.211381196975708 us
Args no shape:: 1.9212646484375 us
With shape tuple:: 2.245788335800171 us
```

Differential Revision: D55809967

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123459
Approved by: https://github.com/davidberard98
2024-04-06 02:44:06 +00:00
Shivam Raikundalia
4732375042 make RecordFunctionFast take inputs (#123208)
Summary: RECORD_FUNCTION in C++ and torch.profiler.record_function already support recording inputs. Let's do the same for RecordFunctionFast.

Test Plan: Add tests in test_profiler.py that both pass args and omit them, so we verify the inputs behave as an optional parameter.

Differential Revision: D55648870

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123208
Approved by: https://github.com/davidberard98
2024-04-03 21:58:09 +00:00
cyy
ff82dcd8fa [2/N] Enable clang-tidy checks in torch/csrc/profiler (#113439)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113439
Approved by: https://github.com/Skylion007
2023-11-14 00:39:54 +00:00
cyy
41e8632ca4 [1/N] Fix clang-tidy warnings in torch/csrc/profiler (#112360)
This PR fixes some clang-tidy warnings in torch/csrc/profiler

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112360
Approved by: https://github.com/ezyang
2023-11-10 07:37:23 +00:00
cyy
168f516fae [3/N] Move c10::variant to std::variant (#110141)
This PR moves more c10::variant calls to std::variant

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110141
Approved by: https://github.com/Skylion007
2023-09-28 18:43:55 +00:00
David Berard
614b865721 [profiler] _RecordFunctionFast - faster python bindings for record_function (#107195)
torch.profiler.record_function is relatively slow; for example, in some benchmarks I was running, x.view_as(x) was ~2us, and ~16-17us when wrapped in a record_function context. The reasons for this are: dispatcher overhead from going through an op (the main source of overhead), python binding / python conversion overhead, and some overhead from the context manager.

This new implementation is faster, but it won't work with torchscript. Based on the benchmarks I was running, it adds 0.5-0.7us overhead per call when the profiler is turned off. To use it, you can just:

```python
with torch._C._profiler_manual._RecordFunctionFast("title"):
    torch.add(x, y)
```

It implements a context manager in python which directly calls the record_function utilities, instead of calling through an op.
* The context manager is implemented directly in python because the overhead from calling a python function seems non-negligible
* All the record_function calls, python object conversions are guarded on checks for whether the profiler is enabled or not. It seems like this saves a few hundred nanoseconds.

For more details about the experiments I ran to choose this implementation, see [my record_functions experiments branch](https://github.com/pytorch/pytorch/compare/main...davidberard98:pytorch:record-function-fast-experiments?expand=1).

This also adds a `torch.autograd.profiler._is_profiler_enabled` global variable that can be used to check whether a profiler is currently enabled. It's useful for further reducing the overhead, like this:

```python
if torch.autograd.profiler._is_profiler_enabled:
    with torch._C._profiler_manual._RecordFunctionFast("title"):
        torch.add(x, y)
else:
    torch.add(x, y)
```

On BERT_pytorch (CPU-bound model), if we add a record_function inside CachedAutotuning.run:
* Naive torch.profiler.record_function() is a ~30% slowdown
* Always wrapping with RecordFunctionFast causes a regression of ~2-4%.
* Guarding with an if statement - any regression is within noise

**Selected benchmark results**: these come from a 2.20GHz machine, GPU build but only running CPU ops; running `x.view_as(x)`, with various record_functions applied (with profiling turned off). For more detailed results see the "record_functions experiments branch" linked above (those results are on a different machine, but show the same patterns). Note that the results are somewhat noisy; assume 0.05-0.1us variations.

```
Baseline:: 1.7825262546539307 us  # Just running x.view_as(x)
profiled_basic:: 13.600390434265137 us  # torch.profiler.record_function(x) + view_as
precompute_manual_cm_rf:: 2.317216396331787 us  # torch._C._profiler_manual._RecordFunctionFast(), if the context is pre-constructed + view_as
guard_manual_cm_rf:: 1.7994389533996582 us  # guard with _is_profiler_enabled + view_as
```

Differential Revision: [D48421198](https://our.internmc.facebook.com/intern/diff/D48421198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107195
Approved by: https://github.com/albanD, https://github.com/aaronenyeshi
2023-08-22 18:48:30 +00:00
Edward Z. Yang
d5f7df3b8a Hand bind CapturedTraceback (#107438)
I do this instead of pybind11 because I need a custom tp_dealloc to promptly free PyFrames. I also add GC traverse/clear support. This is required to avoid leaking memory from co_extra on code objects in some obscure situations. This is indirectly tested by #107388

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107438
Approved by: https://github.com/albanD
2023-08-18 19:05:52 +00:00
Brian Coutinho
8d9c8897ed [profiler] add option for kineto synchronization events in the trace (#105187)
Summary:
## About Sync Events
For CUDA profiling mode, we can enable tracing CUDA synchronization events.
* This feature captures synchronization events in CUDA, including 1) context/device sync, 2) stream sync, 3) CUDA event sync, and 4) CUDA stream wait event (inter-stream synchronization).
* We add this flag using the profiler's experimental config option.
* This PR relies on the 7b003638c6 change in pytorch/kineto.

## Usage
Just set the `enable_cuda_sync_events` option in `_ExperimentalConfig`
```
from torch.autograd.profiler import profile, _ExperimentalConfig
with profile(use_kineto=True, use_cuda=True,
   experimental_config=_ExperimentalConfig(enable_cuda_sync_events=True),
) as prof:
   workload()
```

**Please wait for the PyTorch GitHub repo to point to 7b003638c6 or a later commit in Kineto**

Test Plan:
## Unit Test
Added a unit test

  buck2 test mode/dev-nosan caffe2/test:profiler --local-only -- test_profiler_cuda_sync_events
  Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
https://www.internalfb.com/intern/testinfra/testrun/281475298097379

Reviewed By: davidberard98

Differential Revision: D46244591

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105187
Approved by: https://github.com/aaronenyeshi
2023-07-26 03:45:04 +00:00
Louis Feng
5847cb55e4 [PyPer][ET] Refactor EG to ET (#99694)
Summary:
Change execution graph to execution trace.
See post: https://fb.workplace.com/groups/873291503156329/permalink/1529496217535851/

Test Plan: Run a job.

Reviewed By: chaekit

Differential Revision: D44121392

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99694
Approved by: https://github.com/chaekit
2023-06-22 19:41:54 +00:00
dujinhang
2e8ce910bb [Profiler][1/N] add profiler support for custom device. (#101554)
1. `torch.autograd.profiler` interface parameters changed: use `self.use_device` instead of `self.use_cuda`, which facilitates access by other devices; it is integrated in a subsequent PR.
2. Modify `ProfilerEventStub` (aka `std::shared_ptr<CUevent_st>`) to `ProfilerVoidEventStub` (aka `std::shared_ptr<void>`) so that `ProfilerStubs` can be inherited by any `{device}Methods`. In addition, `cuda_event_start_` is renamed to `device_event_start_`; cuda and other devices can use this event pointer if needed.
3. Custom device support using legacy profiling (add a `ProfilerState::KINETO_PRIVATEUSE1_FALLBACK` option).
4. Add `privateuse1Stubs` register.

(Parse results and test cases are added in a subsequent PR; a usage sketch of the new parameter follows.)
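A minimal sketch of the changed entry point (the device string is a placeholder for whatever name a backend registers):

```python
import torch

# Assumption: a backend has been registered under the PrivateUse1 key,
# shown here with the placeholder name "custom_device".
with torch.autograd.profiler.profile(use_device="custom_device") as prof:
    run_workload()  # hypothetical workload
```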

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101554
Approved by: https://github.com/aaronenyeshi
2023-06-02 09:19:19 +00:00
David Berard
5324124eac [profiler] Reintroduce forward-backward links (#102424)
**TL;DR:** This re-introduces links between backward kernels and their corresponding forward kernels.

<img width="1020" alt="Screenshot 2023-05-26 at 7 25 22 PM" src="https://github.com/pytorch/pytorch/assets/5067123/02571b59-859c-4c9e-b3ef-121ef3159812">

In the example above, you can see there are two such flows - one for aten::add, and one for aten::binary_cross_entropy

### Details

Forward/backward links were added in https://github.com/pytorch/pytorch/pull/62553, but then disabled in https://github.com/pytorch/pytorch/pull/72904 due to segfaults (e.g. https://github.com/pytorch/pytorch/issues/69443).

Between now and when the fwd-bwd links were disabled, there's been a lot of refactoring; so this PR updates the implementation:
* Use a raw profiler::impl::Result instead of a KinetoEvent
* Move the implementation to collection.cpp, where the TraceWrapper is currently handled.
* Sort the events before processing, because they aren't always in chronological order
* There can now be more than one event in the backward pass that matches the sequenceNr-threadID pair. The implementation needed to be updated to avoid showing multiple endpoints for a given sequenceNr-threadID pair ([ptr to where the bwd sequenceNr-threadID pair is duplicated](6e3e3dd477/torch/csrc/profiler/collection.cpp (L398-L399))).

Next, we need to verify that https://github.com/pytorch/pytorch/issues/69443 is fixed. Running the repro no longer errors. Looking further into the details of the issue it seems like the handling of the [raw linkedActivity pointer (old code from 2021)](6089dcac48/libkineto/src/output_json.cpp (L283)) resulted in the segfault. Now, it doesn't look like the linked activity is used anywhere in output_json.cpp so the issue should be fixed.

### Testing

#### 1. unit test
`test_profiler_fwd_bwd_link` was un-skipped. It was modified to match the new implementation.

#### 2. https://github.com/pytorch/pytorch/issues/69443

I ran the repro in https://github.com/pytorch/pytorch/issues/69443 and verified there were no segfaults.

#### 3. Duplicate flow IDs

When forward-backward connections were first introduced, gpu-cpu async links had not been introduced. There's a possibility that gpu-cpu links and fwd-bwd links could interfere if their IDs overlap.

I manually tested this in chrome://tracing; I edited a file so that a gpu-cpu link had the same ID as one of the fwd-bwd connections. The chrome tracing UI continued showing both types of links.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102424
Approved by: https://github.com/aaronenyeshi
2023-05-31 02:50:38 +00:00
David Berard
935100cbde [profiler] When record_inputs=True, record scalar lists of length <= 30 (#100593)
Many ops take as inputs scalars or scalar lists which are important to understand the properties of the op. For example, convolution ops' behavior and output shapes often depend on padding and strides, which are provided as scalars or lists of scalars. This will record scalar lists when record_inputs=True.
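Roughly what this looks like from the Python side (a sketch; assumes the deserialized scalars surface on events as a `concrete_inputs`-style field next to `input_shapes`):

```python
import torch
from torch.profiler import profile

with profile(record_shapes=True) as prof:
    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    torch.nn.functional.conv2d(x, w, stride=2, padding=1)

for evt in prof.events():
    if "conv2d" in evt.name:
        print(evt.input_shapes)     # tensor dims (single scalars stay here)
        print(evt.concrete_inputs)  # assumption: scalar lists like stride/padding land here
```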

Details:
During collection (and this was true before this PR as well), we serialize values and tensor metadata into an InputOutputEncoder. After collection occurs, we deserialize these values to attach the information to each of the events.

This PR does this:
- Adds support for serializing scalar lists during collection / serialization
- Adds an extra field called "Concrete Args"
- Splits up the deserialization process into two steps - one for generating "input shapes" and one for generating "concrete args". We split up input shapes and concrete args to avoid interrupting any previous workflows that relied on the specific data in the input shapes category; additionally, it's just a better description. Note that single scalars will remain in the "input shapes" category as they were already in that category in the past.

Differential Revision: [D45798431](https://our.internmc.facebook.com/intern/diff/D45798431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100593
Approved by: https://github.com/aaronenyeshi
2023-05-16 07:58:46 +00:00
Richard Li
c523d7d899 Add a new hook (#99854)
Differential Revision: D45220984

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99854
Approved by: https://github.com/albanD
2023-04-26 23:00:38 +00:00
Aaron Enye Shi
237f917f5b [Profiler][Easy] Fix typo in Profiler report input shapes (#99430)
Summary:
There are two variables for profiler input shapes:
- In C++ interface: report_input_shapes
- In Python interface: record_shapes

Therefore record_input_shapes is a typo. We should also look into reducing the redundant naming between the two.

Test Plan: CI

Pulled By: aaronenyeshi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99430
Approved by: https://github.com/davidberard98
2023-04-19 21:50:52 +00:00
Zachary DeVito
1c83888be8 [memory profiling] show pre-existing memory in trace_plot (#97590)
Previously we only plotted memory if it was allocated or freed while trace recording was active. This change also adds any pre-existing blocks to the visualization. This helps because it is common to enable trace recording later and then not realize that a lot of allocated memory is missing from the trace even though it was allocated beforehand.
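A sketch of producing the plot from a snapshot (the `trace_plot` entry point lives in the private `torch.cuda._memory_viz` module; treat the exact signature as an assumption):

```python
import torch
from torch.cuda import _memory_viz

x = torch.randn(4096, 4096, device="cuda")   # pre-existing block, now included
torch.cuda.memory._record_memory_history()
y = torch.randn(1024, 1024, device="cuda")   # allocated while recording
snapshot = torch.cuda.memory._snapshot()

html = _memory_viz.trace_plot(snapshot)  # assumption: returns self-contained HTML
with open("trace_plot.html", "w") as f:
    f.write(html)
```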
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97590
Approved by: https://github.com/eellison
2023-03-28 16:31:10 +00:00
Zachary DeVito
e74f70d212 Revert "Revert "[memory profiling] add a facility to gather combined C++/Python/TorchScript stack traces. (#95541)"" (#96878)
This reverts commit e1ea584b1c.
Adds __has_include check to fix fbcode build.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96878
Approved by: https://github.com/ezyang
2023-03-16 04:12:54 +00:00
PyTorch MergeBot
e1ea584b1c Revert "[memory profiling] add a facility to gather combined C++/Python/TorchScript stack traces. (#95541)"
This reverts commit 4e1060c609.

Reverted https://github.com/pytorch/pytorch/pull/95541 on behalf of https://github.com/DanilBaibak due to breaking internal builds
2023-03-15 13:28:41 +00:00
Zachary DeVito
4e1060c609 [memory profiling] add a facility to gather combined C++/Python/TorchScript stack traces. (#95541)
This refactors the stack trace facility specific to memory profiling in python+cuda to make a generic facility for generating combined stack traces.

The generic facility (combined_traceback.h) does not require python to be around to work, but will return python stacks if it is present.

This facility is then used to add support for stack trace gathering in memory profiling that happens directly from C++.

It is also used to expose a python API for gathering and symbolizing combined stacks.
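The Python API looks roughly like this (a sketch; `gather_traceback` and `symbolize_tracebacks` are private `torch._C._profiler` entry points, so treat the exact names and signatures as assumptions):

```python
import torch

# Capture a combined Python/TorchScript/C++ stack at this point.
tb = torch._C._profiler.gather_traceback(python=True, script=True, cpp=True)

# Symbolizing is the expensive step, so it is deferred and batched.
frames, = torch._C._profiler.symbolize_tracebacks([tb])
for frame in frames:
    print(frame)  # e.g. mappings with filename / line / function name
```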

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95541
Approved by: https://github.com/ezyang
2023-03-14 18:26:05 +00:00
Xunsong, Huang
b053a0f2ba [XPU][Profiler] Add API support for XPU profiler to Kineto path (#94502)
This patch aims to add support for an XPU profiler that will co-work with Kineto. After this PR, Kineto will follow these APIs to fit itself. Also, the development of the Python interface is nearly done.

Signed-off-by: Huang, Xunsong <xunsong.huang@intel.com>

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94502
Approved by: https://github.com/ezyang
2023-03-10 12:17:14 +00:00
Salil Desai
193068cbcf [Vulkan + Profiler] Enable Processing Vulkan Events in Profiler (#90852)
@bypass-github-export-checks

This diff enables processing Vulkan events in the profiler. Passing the events from QueryPool, and making sure Vulkan events align correctly with parent CPU events, will be handled later in this diff stack.

This diff was made by forking Taylor's scaffolding diff, D39779878, with a few changes:
- Rebasing + resolving merge conflicts
- Fixing (i.e. removing) auto import of profiler/containers.h
- Changing the activity type to CPU_OP which makes the vulkan events appear on chrometrace
- Moving timestamp adjustment scaffolding to D39893109

Differential Revision: [D39834805](https://our.internmc.facebook.com/intern/diff/D39834805/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90852
Approved by: https://github.com/mcr229
2022-12-19 19:54:32 +00:00
Taylor Robie
8023c9dc64 [Profiler] Memory profiler part 3: Schema parsing and mutable arguments (#86854)
The appropriate annotation for a block of memory is a function of time: an input can be mutated in-place to become an activation, a clever kernel might steal the memory of a detached input (such as a mask) to use as output memory, etc.

We could pessimistically assume that all ops mutate all of their inputs; however, inspecting schemas allows us to significantly narrow that assumption with minimal effort. Checking schemas also allows us to distinguish between dispatcher ops (which have load-bearing semantics) and user annotations with reasonably high precision.
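The schema information in question is also visible from Python; a small sketch of the idea:

```python
import torch

# Dispatcher schemas record which arguments may be written to, so
# mutation can be narrowed per argument instead of pessimistically assumed.
schema = torch.ops.aten.add_.Tensor._schema
for arg in schema.arguments:
    may_write = arg.alias_info is not None and arg.alias_info.is_write
    print(f"{arg.name}: may mutate = {may_write}")
# For in-place add, only `self` is flagged as writable.
```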

Differential Revision: [D40220390](https://our.internmc.facebook.com/intern/diff/D40220390/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86854
Approved by: https://github.com/chaekit
2022-11-15 19:17:57 +00:00
Taylor Robie
cef13ebea0 [Profiler] Memory profiler part 1: Gradient identification (#86802)
There are multiple ways to identify that a Tensor is a gradient (a subset of which also give additional context), so to start off I've made a utility to handle that determination.

Differential Revision: [D39920730](https://our.internmc.facebook.com/intern/diff/D39920730/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86802
Approved by: https://github.com/chaekit
2022-11-08 23:53:13 +00:00
Taylor Robie
6e6f929b2c [Profiler] Restructure inputs and capture TensorLists. (#87825)
This PR unifies and rationalizes some of the input representation in Result. The current approach of storing separate types in separate vectors is tedious for two types (Tensors and scalars), but would be even more annoying with the addition of TensorLists. A similar disconnection exists with sizes and strides which the user is also expected to zip with tensor_metadata.

I simplified things by moving inputs to a variant and moving sizes and strides into TensorMetadata. This also forced collection of sizes and strides in python tracer which helps to bring it in line with op profiling. Collection of TensorLists is fairly straightforward; `InputOutputEncoder` already has a spot for them (I actually collected them in the original TorchTidy prototype) so it was just a matter of plumbing things through.

Differential Revision: [D40734451](https://our.internmc.facebook.com/intern/diff/D40734451/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87825
Approved by: https://github.com/slgong-fb, https://github.com/chaekit
2022-11-08 21:48:43 +00:00
Taylor Robie
e132c45fd0 [Profiler] Handle ABA for TensorImpl* when assigning IDs (#87133)
Part of the current ID assignment algorithm groups any Storages which are associated with the same TensorImpl*. This isn't sound (which I knew but deferred until it actually became a problem) because pointers can be reused by different objects (the ABA problem).

ABA is easy to handle for Storage because we see allocations and frees, but ~TensorImpl is very hot and cannot tolerate profiling code without significant increases in overhead.

This PR narrows the conditions under which ID assignment will join on TensorImpl*. Two storages which are associated with the same TensorImpl* are grouped IFF they were live at the same time. (Note that this still allows storages with disjoint lifetimes to be joined transitively through a third storage which overlaps with both.)

The need for this PR arose in memory profiling. The Python argument parser creates short lived Tensors for (some) scalar arguments which triggers this issue. (Which is stochastic and platform dependent since optimizations like reusing recently freed allocations is implementation defined.) Spurious connections can lead to confusing and long range interactions when building up the memory profile, so it makes sense to harden ID assignment to avoid any issues.

Differential Revision: [D40445121](https://our.internmc.facebook.com/intern/diff/D40445121/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87133
Approved by: https://github.com/slgong-fb, https://github.com/chaekit
2022-11-08 21:48:43 +00:00
Digant Desai
dcbcf5b90e [profiler] Expose experimental performance events to python (#87905)
Reports total counts (includes time spent in all children); self counts can be calculated manually.

Differential Revision: [D40282770](https://our.internmc.facebook.com/intern/diff/D40282770/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87905
Approved by: https://github.com/SS-JIA
2022-11-02 14:54:15 +00:00
Taylor Robie
b16b5fb802 [Profiler] Hold weak reference to prevent TensorImpl address reuse during profiling. (#87244)
A recurring problem with assigning Tensor IDs is that we want to preserve identity when storage changes, but we don't observe TensorImpl destruction, so identity assignment is not robust to the ABA problem with respect to TensorImpl*. ~TensorImpl is far too hot to instrument; even adding a call to a no-op function in a different compilation unit increases overhead by tens of percent. (OSS builds do not have any sort of LTO.)

Fortunately there is a solution. A PyTorch Tensor is a `c10::intrusive_ptr<c10::TensorImpl>`, which in turn holds a storage. (Which is a `c10::intrusive_ptr<c10::StorageImpl>`) `c10::intrusive_ptr` has a `c10::weak_intrusive_ptr` class for taking non-owning references to the underlying object. The implementation involves both a strong refcount and weak refcount in `c10::intrusive_ptr`. If the strong refcount of an intrusive_ptr goes to zero and there are no weak references then everything is deleted. However if there is a weak reference then the intrusive_ptr calls `release_resources()` but not delete.

This has the effect of freeing the underlying resources (ensuring that program semantics are unchanged) but leaves behind an empty shell of an `intrusive_ptr` that the `weak_intrusive_ptr`s use to check status. And herein lies the solution: as long as we hold a weak reference to a TensorImpl we will block deletion and prevent the `TensorImpl*` from being reused.

This PR uses a `c10::weak_intrusive_ptr<c10::TensorImpl>` to store the address of profiled TensorImpls and then converts it to a raw pointer (or rather, a `TensorImplAddress`) during post processing when we no longer care about blocking address reuse.

Differential Revision: [D40492848](https://our.internmc.facebook.com/intern/diff/D40492848/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87244
Approved by: https://github.com/slgong-fb, https://github.com/albanD
2022-10-27 06:38:11 +00:00
Taylor Robie
5ec03fc17a [Profiler][Trivial] Add Module cls and self bindings and type_caster macro (#86755)
Just a bit of clean up. We will need `self` and `cls` for memory profiling, and the type_caster specializations were getting quite verbose.

Differential Revision: [D39920728](https://our.internmc.facebook.com/intern/diff/D39920728/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86755
Approved by: https://github.com/slgong-fb, https://github.com/aaronenyeshi
2022-10-23 19:23:44 +00:00
Taylor Robie
be2d647ea6 [Profiler] Use parameter as key for optimizer state recording. (#86753)
While optimizer can store state however it likes, in practice most optimizer state corresponds to a particular parameter. (This is the case for all `torch.optim` optimizers.) Thus, it turns out to be ergonomic to collect using that structure. Note that this doesn't lock us into anything; we can always collect state with non Tensor keys if the use case arises.

One simplification that arises is that Module and Optimizer collection has very similar structure. So similar, in fact, that it is possible to use a common template for config. I also found that a lot of the `check_and_store` logic could be simplified and inlined by this joining of collected optimizer state.

Differential Revision: [D40210703](https://our.internmc.facebook.com/intern/diff/D40210703/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86753
Approved by: https://github.com/slgong-fb, https://github.com/aaronenyeshi
2022-10-23 19:23:39 +00:00
Taylor Robie
c16b7b41f7 [Profiler][Trivial] Small style and safety fixes (#86752)
I noticed a couple abbreviations in the new optimizer capture code that are worth expanding. I also made the RawTensorMetadata a bit safer.

Differential Revision: [D40210702](https://our.internmc.facebook.com/intern/diff/D40210702/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86752
Approved by: https://github.com/slgong-fb, https://github.com/aaronenyeshi
2022-10-20 17:34:16 +00:00
Taylor Robie
35fb007749 [Profiler][Minor] Separate standalone profilers from the main PyTorch profiler. (#85511)
There are a number of instrumentation utils which have been added to the profiler toolkit. They are generally small and self contained, often wrapping vendor APIs. (NVTX, ITT)

They don't really interact with the much more expansive machinery of the PyTorch profiler beyond registration / unregistration, minor util sharing, and reusing the profiler base class. Just as in the case of stubs, it makes sense to group them in a dedicated subfolder.
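For reference, these are the wrappers reachable through context managers such as `emit_nvtx` and `emit_itt` (a brief sketch of existing entry points, not new API):

```python
import torch

# NVTX ranges, for inspection under Nsight (requires CUDA).
with torch.autograd.profiler.emit_nvtx():
    y = torch.randn(8, 8).mm(torch.randn(8, 8))

# ITT ranges, for inspection under Intel VTune.
with torch.autograd.profiler.emit_itt():
    y = torch.randn(8, 8).mm(torch.randn(8, 8))
```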

Differential Revision: [D39108649](https://our.internmc.facebook.com/intern/diff/D39108649/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39108649/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85511
Approved by: https://github.com/albanD
2022-10-14 05:38:48 +00:00
Seonglyong Gong
dbea07b6aa [Profiler] record gradient from nnModule (#86355)
Summary:
- catch .grad tensor info
- update data type and `check_and_store`, etc
- update unit test case

Test Plan: buck run mode/opt //caffe2/test:profiler

Reviewed By: chaekit

Differential Revision: D39711295

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86355
Approved by: https://github.com/chaekit
2022-10-07 09:58:50 +00:00
Seonglyong Gong
27c3fb0386 [Profiler] trace verbose=false by default (#86263)
Summary:
- Added config option to remove 'Call stack' field from trace file (#84982)
- Change default value to `false`

Test Plan:
- `experimental_config=_ExperimentalConfig(verbose=true),` will add 'Call stack' field back in the trace file.
- CI tests

Differential Revision: D40092377

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86263
Approved by: https://github.com/aaronenyeshi
2022-10-06 06:32:25 +00:00
Seonglyong Gong
a117fde86f [Profiler] Apply TensorMetadata for Optimizer and nnModule (#86047)
Summary: - Use `TensorMetadata` struct in saving tensor info from Optimizer and nnModule.

Test Plan: buck run mode/opt //caffe2/test:profiler

Reviewed By: chaekit

Differential Revision: D39682205

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86047
Approved by: https://github.com/chaekit, https://github.com/robieta
2022-10-06 06:18:56 +00:00