Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Fixes #112595 - `torch/autograd/profiler.py` </br> **Before: 37** ``` torch/autograd/profiler.py:1 at module level: D100: Missing docstring in public module torch/autograd/profiler.py:91 in public class `profile`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/profiler.py:175 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:261 in public method `config`: D102: Missing docstring in public method torch/autograd/profiler.py:272 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/profiler.py:290 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/profiler.py:308 in public method `__repr__`: D105: Missing docstring in magic method torch/autograd/profiler.py:313 in public method `__str__`: D105: Missing docstring in magic method torch/autograd/profiler.py:322 in public method `table`: D102: Missing docstring in public method torch/autograd/profiler.py:346 in public method `export_chrome_trace`: D102: Missing docstring in public method torch/autograd/profiler.py:355 in public method `export_stacks`: D102: Missing docstring in public method torch/autograd/profiler.py:361 in public method `key_averages`: D102: Missing docstring in public method torch/autograd/profiler.py:368 in public method `total_average`: D102: Missing docstring in public method torch/autograd/profiler.py:377 in public method `self_cpu_time_total`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/profiler.py:377 in public method `self_cpu_time_total`: D400: First line should end with a period (not 'f') torch/autograd/profiler.py:555 in public class `record_function`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/profiler.py:555 in public class `record_function`: D400: First line should end with a period (not 'f') torch/autograd/profiler.py:591 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:602 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/profiler.py:608 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/profiler.py:625 in private method `_call_end_callbacks_on_future`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/profiler.py:625 in private method `_call_end_callbacks_on_future`: D400: First line should end with a period (not 'c') torch/autograd/profiler.py:707 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:712 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/profiler.py:733 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/profiler.py:826 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:831 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/profiler.py:853 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/profiler.py:863 in public function `load_nvprof`: D401: First line should be in imperative mood (perhaps 'Open', not 'Opens') torch/autograd/profiler.py:874 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:877 in public method `see`: D102: Missing docstring in public method torch/autograd/profiler.py:883 in public function `parse_nvprof_trace`: D103: Missing docstring in 
public function torch/autograd/profiler.py:951 in public class `KinetoStepTracker`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/profiler.py:991 in public method `init_step_count`: D102: Missing docstring in public method torch/autograd/profiler.py:995 in public method `erase_step_count`: D102: Missing docstring in public method torch/autograd/profiler.py:1000 in public method `increment_step`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/profiler.py:1023 in public method `current_step`: D102: Missing docstring in public method 37 ``` **After: 27** ``` torch/autograd/profiler.py:1 at module level: D100: Missing docstring in public module torch/autograd/profiler.py:176 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:262 in public method `config`: D102: Missing docstring in public method torch/autograd/profiler.py:273 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/profiler.py:291 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/profiler.py:309 in public method `__repr__`: D105: Missing docstring in magic method torch/autograd/profiler.py:314 in public method `__str__`: D105: Missing docstring in magic method torch/autograd/profiler.py:323 in public method `table`: D102: Missing docstring in public method torch/autograd/profiler.py:347 in public method `export_chrome_trace`: D102: Missing docstring in public method torch/autograd/profiler.py:356 in public method `export_stacks`: D102: Missing docstring in public method torch/autograd/profiler.py:362 in public method `key_averages`: D102: Missing docstring in public method torch/autograd/profiler.py:369 in public method `total_average`: D102: Missing docstring in public method torch/autograd/profiler.py:593 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:604 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/profiler.py:610 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/profiler.py:708 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:713 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/profiler.py:734 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/profiler.py:827 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:832 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/profiler.py:854 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/profiler.py:875 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/profiler.py:878 in public method `see`: D102: Missing docstring in public method torch/autograd/profiler.py:884 in public function `parse_nvprof_trace`: D103: Missing docstring in public function torch/autograd/profiler.py:993 in public method `init_step_count`: D102: Missing docstring in public method torch/autograd/profiler.py:997 in public method `erase_step_count`: D102: Missing docstring in public method torch/autograd/profiler.py:1025 in public method `current_step`: D102: Missing docstring in public method 27 ``` - `torch/autograd/graph.py` </br> **Before: 22** ``` torch/autograd/graph.py:1 at module level: D100: Missing docstring in public module torch/autograd/graph.py:24 in 
public class `Node`: D101: Missing docstring in public class torch/autograd/graph.py:27 in public method `name`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') torch/autograd/graph.py:42 in public method `next_functions`: D102: Missing docstring in public method torch/autograd/graph.py:47 in public method `metadata`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') torch/autograd/graph.py:56 in public method `register_hook`: D401: First line should be in imperative mood (perhaps 'Register', not 'Registers') torch/autograd/graph.py:94 in public method `register_prehook`: D401: First line should be in imperative mood (perhaps 'Register', not 'Registers') torch/autograd/graph.py:129 in public method `__subclasshook__`: D105: Missing docstring in magic method torch/autograd/graph.py:147 in public function `get_gradient_edge`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/graph.py:147 in public function `get_gradient_edge`: D400: First line should end with a period (not 'f') torch/autograd/graph.py:147 in public function `get_gradient_edge`: D401: First line should be in imperative mood; try rephrasing (found 'This') torch/autograd/graph.py:166 in public function `increment_version`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/graph.py:166 in public function `increment_version`: D400: First line should end with a period (not 'd') torch/autograd/graph.py:166 in public function `increment_version`: D401: First line should be in imperative mood; try rephrasing (found 'This') torch/autograd/graph.py:243 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/graph.py:251 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/graph.py:256 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/graph.py:261 in public class `save_on_cpu`: D205: 1 blank line required between summary line and description (found 0) torch/autograd/graph.py:261 in public class `save_on_cpu`: D400: First line should end with a period (not 'e') torch/autograd/graph.py:303 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/graph.py:365 in public function `register_multi_grad_hook`: D401: First line should be in imperative mood (perhaps 'Register', not 'Registers') torch/autograd/graph.py:588 in public function `allow_mutation_on_saved_tensors`: D400: First line should end with a period (not 'd') 22 ``` **After: 8** ``` torch/autograd/graph.py:1 at module level: D100: Missing docstring in public module torch/autograd/graph.py:24 in public class `Node`: D101: Missing docstring in public class torch/autograd/graph.py:42 in public method `next_functions`: D102: Missing docstring in public method torch/autograd/graph.py:129 in public method `__subclasshook__`: D105: Missing docstring in magic method torch/autograd/graph.py:244 in public method `__init__`: D107: Missing docstring in __init__ torch/autograd/graph.py:252 in public method `__enter__`: D105: Missing docstring in magic method torch/autograd/graph.py:257 in public method `__exit__`: D105: Missing docstring in magic method torch/autograd/graph.py:303 in public method `__init__`: D107: Missing docstring in __init__ 8 ``` - `torch/multiprocessing/pool.py` </br> **Before: 6** ``` torch/multiprocessing/pool.py:1 at module level: D100: Missing docstring in public module torch/multiprocessing/pool.py:7 in public 
function `clean_worker`: D103: Missing docstring in public function torch/multiprocessing/pool.py:18 in public class `Pool`: D205: 1 blank line required between summary line and description (found 0) torch/multiprocessing/pool.py:18 in public class `Pool`: D209: Multi-line docstring closing quotes should be on a separate line torch/multiprocessing/pool.py:29 in private method `_repopulate_pool`: D205: 1 blank line required between summary line and description (found 0) torch/multiprocessing/pool.py:29 in private method `_repopulate_pool`: D400: First line should end with a period (not ',') 6 ``` **After: 2** ``` torch/multiprocessing/pool.py:1 at module level: D100: Missing docstring in public module torch/multiprocessing/pool.py:7 in public function `clean_worker`: D103: Missing docstring in public function 2 ``` - `torch/multiprocessing/queue.py` </br> **Before: 11** ``` torch/multiprocessing/queue.py:1 at module level: D100: Missing docstring in public module torch/multiprocessing/queue.py:8 in public class `ConnectionWrapper`: D205: 1 blank line required between summary line and description (found 0) torch/multiprocessing/queue.py:8 in public class `ConnectionWrapper`: D209: Multi-line docstring closing quotes should be on a separate line torch/multiprocessing/queue.py:8 in public class `ConnectionWrapper`: D400: First line should end with a period (not 'o') torch/multiprocessing/queue.py:11 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/queue.py:14 in public method `send`: D102: Missing docstring in public method torch/multiprocessing/queue.py:19 in public method `recv`: D102: Missing docstring in public method torch/multiprocessing/queue.py:23 in public method `__getattr__`: D105: Missing docstring in magic method torch/multiprocessing/queue.py:29 in public class `Queue`: D101: Missing docstring in public class torch/multiprocessing/queue.py:30 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/queue.py:38 in public class `SimpleQueue`: D101: Missing docstring in public class 11 ``` **After: 8** ``` torch/multiprocessing/queue.py:1 at module level: D100: Missing docstring in public module torch/multiprocessing/queue.py:10 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/queue.py:13 in public method `send`: D102: Missing docstring in public method torch/multiprocessing/queue.py:18 in public method `recv`: D102: Missing docstring in public method torch/multiprocessing/queue.py:22 in public method `__getattr__`: D105: Missing docstring in magic method torch/multiprocessing/queue.py:28 in public class `Queue`: D101: Missing docstring in public class torch/multiprocessing/queue.py:29 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/queue.py:37 in public class `SimpleQueue`: D101: Missing docstring in public class 8 ``` - `torch/multiprocessing/reductions.py` </br> **Before: 31** ``` torch/multiprocessing/reductions.py:1 at module level: D100: Missing docstring in public module torch/multiprocessing/reductions.py:24 in public class `StorageWeakRef`: D209: Multi-line docstring closing quotes should be on a separate line torch/multiprocessing/reductions.py:31 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/reductions.py:38 in public method `from_weakref`: D102: Missing docstring in public method torch/multiprocessing/reductions.py:44 in public method `expired`: D102: Missing docstring in public method 
torch/multiprocessing/reductions.py:47 in public method `__del__`: D105: Missing docstring in magic method torch/multiprocessing/reductions.py:50 in public method `__hash__`: D105: Missing docstring in magic method torch/multiprocessing/reductions.py:53 in public method `__eq__`: D105: Missing docstring in magic method torch/multiprocessing/reductions.py:60 in public class `SharedCache`: D400: First line should end with a period (not 'f') torch/multiprocessing/reductions.py:62 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/reductions.py:75 in public method `get`: D102: Missing docstring in public method torch/multiprocessing/reductions.py:79 in public method `__setitem__`: D105: Missing docstring in magic method torch/multiprocessing/reductions.py:85 in public method `free_dead_references`: D102: Missing docstring in public method torch/multiprocessing/reductions.py:99 in public function `rebuild_event`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:103 in public function `reduce_event`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:108 in public function `rebuild_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:121 in public function `rebuild_cuda_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:189 in public function `reduce_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:347 in public function `rebuild_nested_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:364 in public function `reduce_nested_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:389 in public function `fd_id`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:397 in public function `storage_from_cache`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:404 in public function `rebuild_storage_fd`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:417 in public function `rebuild_storage_filename`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:437 in public function `rebuild_storage_empty`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:441 in public function `rebuild_typed_storage`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:446 in public function `reduce_typed_storage`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:450 in public function `rebuild_typed_storage_child`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:455 in public function `reduce_typed_storage_child`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:459 in public function `reduce_storage`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:488 in public function `init_reductions`: D103: Missing docstring in public function 31 ``` **After: 29** ``` torch/multiprocessing/reductions.py:1 at module level: D100: Missing docstring in public module torch/multiprocessing/reductions.py:32 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/reductions.py:39 in public method `from_weakref`: D102: Missing docstring in public method torch/multiprocessing/reductions.py:45 in public method `expired`: D102: Missing docstring in public 
method torch/multiprocessing/reductions.py:48 in public method `__del__`: D105: Missing docstring in magic method torch/multiprocessing/reductions.py:51 in public method `__hash__`: D105: Missing docstring in magic method torch/multiprocessing/reductions.py:54 in public method `__eq__`: D105: Missing docstring in magic method torch/multiprocessing/reductions.py:63 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/reductions.py:76 in public method `get`: D102: Missing docstring in public method torch/multiprocessing/reductions.py:80 in public method `__setitem__`: D105: Missing docstring in magic method torch/multiprocessing/reductions.py:86 in public method `free_dead_references`: D102: Missing docstring in public method torch/multiprocessing/reductions.py:100 in public function `rebuild_event`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:104 in public function `reduce_event`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:109 in public function `rebuild_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:122 in public function `rebuild_cuda_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:190 in public function `reduce_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:348 in public function `rebuild_nested_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:365 in public function `reduce_nested_tensor`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:390 in public function `fd_id`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:398 in public function `storage_from_cache`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:405 in public function `rebuild_storage_fd`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:418 in public function `rebuild_storage_filename`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:438 in public function `rebuild_storage_empty`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:442 in public function `rebuild_typed_storage`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:447 in public function `reduce_typed_storage`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:451 in public function `rebuild_typed_storage_child`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:456 in public function `reduce_typed_storage_child`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:460 in public function `reduce_storage`: D103: Missing docstring in public function torch/multiprocessing/reductions.py:489 in public function `init_reductions`: D103: Missing docstring in public function 29 ``` - `torch/multiprocessing/spawn.py` </br> **Before: 19** ``` torch/multiprocessing/spawn.py:1 at module level: D100: Missing docstring in public module torch/multiprocessing/spawn.py:11 in public class `ProcessException`: D101: Missing docstring in public class torch/multiprocessing/spawn.py:14 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:20 in public method `__reduce__`: D105: Missing docstring in magic method torch/multiprocessing/spawn.py:25 in public class `ProcessRaisedException`: D205: 1 
blank line required between summary line and description (found 0) torch/multiprocessing/spawn.py:25 in public class `ProcessRaisedException`: D400: First line should end with a period (not 'n') torch/multiprocessing/spawn.py:30 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:40 in public class `ProcessExitedException`: D205: 1 blank line required between summary line and description (found 0) torch/multiprocessing/spawn.py:40 in public class `ProcessExitedException`: D400: First line should end with a period (not 'l') torch/multiprocessing/spawn.py:47 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:59 in public method `__reduce__`: D105: Missing docstring in magic method torch/multiprocessing/spawn.py:85 in public class `ProcessContext`: D101: Missing docstring in public class torch/multiprocessing/spawn.py:86 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:93 in public method `pids`: D102: Missing docstring in public method torch/multiprocessing/spawn.py:97 in public method `join`: D205: 1 blank line required between summary line and description (found 0) torch/multiprocessing/spawn.py:97 in public method `join`: D401: First line should be in imperative mood (perhaps 'Try', not 'Tries') torch/multiprocessing/spawn.py:166 in public class `SpawnContext`: D101: Missing docstring in public class torch/multiprocessing/spawn.py:167 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:180 in public function `start_processes`: D103: Missing docstring in public function 19 ``` **After: 13** ``` torch/multiprocessing/spawn.py:1 at module level: D100: Missing docstring in public module torch/multiprocessing/spawn.py:11 in public class `ProcessException`: D101: Missing docstring in public class torch/multiprocessing/spawn.py:14 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:20 in public method `__reduce__`: D105: Missing docstring in magic method torch/multiprocessing/spawn.py:27 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:41 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:53 in public method `__reduce__`: D105: Missing docstring in magic method torch/multiprocessing/spawn.py:79 in public class `ProcessContext`: D101: Missing docstring in public class torch/multiprocessing/spawn.py:80 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:87 in public method `pids`: D102: Missing docstring in public method torch/multiprocessing/spawn.py:161 in public class `SpawnContext`: D101: Missing docstring in public class torch/multiprocessing/spawn.py:162 in public method `__init__`: D107: Missing docstring in __init__ torch/multiprocessing/spawn.py:175 in public function `start_processes`: D103: Missing docstring in public function 13 ``` - `torch/multiprocessing/__init__.py` </br> **Before: 0** ``` torch/multiprocessing/__init__.py:1 at module level: D205: 1 blank line required between summary line and description (found 0) torch/multiprocessing/__init__.py:1 at module level: D400: First line should end with a period (not '`') torch/multiprocessing/__init__.py:57 in public function `set_sharing_strategy`: D401: First line should be in imperative mood (perhaps 'Set', not 'Sets') torch/multiprocessing/__init__.py:69 in public function 
`get_sharing_strategy`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') torch/multiprocessing/__init__.py:74 in public function `get_all_sharing_strategies`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') 5 ``` **After: 0** - `torch/nn/__init__.py` </br> **Before: 3** ``` torch/nn/__init__.py:1 at module level: D104: Missing docstring in public package torch/nn/__init__.py:14 in public function `factory_kwargs`: D205: 1 blank line required between summary line and description (found 0) torch/nn/__init__.py:14 in public function `factory_kwargs`: D400: First line should end with a period (not 'd') 3 ``` **After: 1** ``` torch/nn/__init__.py:1 at module level: D104: Missing docstring in public package 1 ``` - `torch/nn/cpp.py` </br> **Before: 16** ``` torch/nn/cpp.py:7 in public class `OrderedDictWrapper`: D205: 1 blank line required between summary line and description (found 0) torch/nn/cpp.py:7 in public class `OrderedDictWrapper`: D400: First line should end with a period (not 'e') torch/nn/cpp.py:16 in public method `__init__`: D107: Missing docstring in __init__ torch/nn/cpp.py:21 in public method `cpp_dict`: D102: Missing docstring in public method torch/nn/cpp.py:27 in public method `items`: D102: Missing docstring in public method torch/nn/cpp.py:30 in public method `keys`: D102: Missing docstring in public method torch/nn/cpp.py:33 in public method `values`: D102: Missing docstring in public method torch/nn/cpp.py:36 in public method `__iter__`: D105: Missing docstring in magic method torch/nn/cpp.py:39 in public method `__len__`: D105: Missing docstring in magic method torch/nn/cpp.py:42 in public method `__contains__`: D105: Missing docstring in magic method torch/nn/cpp.py:45 in public method `__getitem__`: D105: Missing docstring in magic method torch/nn/cpp.py:50 in public class `ModuleWrapper`: D205: 1 blank line required between summary line and description (found 0) torch/nn/cpp.py:50 in public class `ModuleWrapper`: D400: First line should end with a period (not 'd') torch/nn/cpp.py:55 in public method `__init__`: D107: Missing docstring in __init__ torch/nn/cpp.py:83 in public method `training`: D102: Missing docstring in public method torch/nn/cpp.py:90 in public method `__repr__`: D105: Missing docstring in magic method 16 ``` **After: 12** ``` torch/nn/cpp.py:16 in public method `__init__`: D107: Missing docstring in __init__ torch/nn/cpp.py:21 in public method `cpp_dict`: D102: Missing docstring in public method torch/nn/cpp.py:27 in public method `items`: D102: Missing docstring in public method torch/nn/cpp.py:30 in public method `keys`: D102: Missing docstring in public method torch/nn/cpp.py:33 in public method `values`: D102: Missing docstring in public method torch/nn/cpp.py:36 in public method `__iter__`: D105: Missing docstring in magic method torch/nn/cpp.py:39 in public method `__len__`: D105: Missing docstring in magic method torch/nn/cpp.py:42 in public method `__contains__`: D105: Missing docstring in magic method torch/nn/cpp.py:45 in public method `__getitem__`: D105: Missing docstring in magic method torch/nn/cpp.py:52 in public method `__init__`: D107: Missing docstring in __init__ torch/nn/cpp.py:80 in public method `training`: D102: Missing docstring in public method torch/nn/cpp.py:87 in public method `__repr__`: D105: Missing docstring in magic method 12 ``` - `torch/nn/grad.py` </br> **Before: 10** ``` torch/nn/grad.py:1 at module level: D400: First line should end with a period 
(not 'e') torch/nn/grad.py:8 in public function `conv1d_input`: D205: 1 blank line required between summary line and description (found 0) torch/nn/grad.py:8 in public function `conv1d_input`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes') torch/nn/grad.py:40 in public function `conv1d_weight`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes') torch/nn/grad.py:71 in public function `conv2d_input`: D205: 1 blank line required between summary line and description (found 0) torch/nn/grad.py:71 in public function `conv2d_input`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes') torch/nn/grad.py:103 in public function `conv2d_weight`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes') torch/nn/grad.py:134 in public function `conv3d_input`: D205: 1 blank line required between summary line and description (found 0) torch/nn/grad.py:134 in public function `conv3d_input`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes') torch/nn/grad.py:166 in public function `conv3d_weight`: D401: First line should be in imperative mood (perhaps 'Compute', not 'Computes') 10 ``` **After: 0** - `torch/nn/parameter.py` </br> **Before: 17** ``` torch/nn/parameter.py:1 at module level: D100: Missing docstring in public module torch/nn/parameter.py:14 in public class `Parameter`: D204: 1 blank line required after class docstring (found 0) torch/nn/parameter.py:33 in public method `__new__`: D102: Missing docstring in public method torch/nn/parameter.py:54 in public method `__deepcopy__`: D105: Missing docstring in magic method torch/nn/parameter.py:62 in public method `__repr__`: D105: Missing docstring in magic method torch/nn/parameter.py:65 in public method `__reduce_ex__`: D105: Missing docstring in magic method torch/nn/parameter.py:84 in public class `UninitializedTensorMixin`: D101: Missing docstring in public class torch/nn/parameter.py:105 in public method `materialize`: D205: 1 blank line required between summary line and description (found 0) torch/nn/parameter.py:125 in public method `shape`: D102: Missing docstring in public method torch/nn/parameter.py:132 in public method `share_memory_`: D102: Missing docstring in public method torch/nn/parameter.py:138 in public method `__repr__`: D105: Missing docstring in magic method torch/nn/parameter.py:141 in public method `__reduce_ex__`: D105: Missing docstring in magic method torch/nn/parameter.py:149 in public method `__torch_function__`: D105: Missing docstring in magic method torch/nn/parameter.py:164 in public function `is_lazy`: D103: Missing docstring in public function torch/nn/parameter.py:186 in public method `__new__`: D102: Missing docstring in public method torch/nn/parameter.py:191 in public method `__deepcopy__`: D105: Missing docstring in magic method torch/nn/parameter.py:217 in public method `__new__`: D102: Missing docstring in public method 17 ``` **After: 15** ``` torch/nn/parameter.py:1 at module level: D100: Missing docstring in public module torch/nn/parameter.py:34 in public method `__new__`: D102: Missing docstring in public method torch/nn/parameter.py:55 in public method `__deepcopy__`: D105: Missing docstring in magic method torch/nn/parameter.py:63 in public method `__repr__`: D105: Missing docstring in magic method torch/nn/parameter.py:66 in public method `__reduce_ex__`: D105: Missing docstring in magic method torch/nn/parameter.py:85 in public class 
`UninitializedTensorMixin`: D101: Missing docstring in public class torch/nn/parameter.py:127 in public method `shape`: D102: Missing docstring in public method torch/nn/parameter.py:134 in public method `share_memory_`: D102: Missing docstring in public method torch/nn/parameter.py:140 in public method `__repr__`: D105: Missing docstring in magic method torch/nn/parameter.py:143 in public method `__reduce_ex__`: D105: Missing docstring in magic method torch/nn/parameter.py:151 in public method `__torch_function__`: D105: Missing docstring in magic method torch/nn/parameter.py:166 in public function `is_lazy`: D103: Missing docstring in public function torch/nn/parameter.py:188 in public method `__new__`: D102: Missing docstring in public method torch/nn/parameter.py:193 in public method `__deepcopy__`: D105: Missing docstring in magic method torch/nn/parameter.py:219 in public method `__new__`: D102: Missing docstring in public method 15 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/113052 Approved by: https://github.com/mikaylagawarecki, https://github.com/soulitzer
`torch/autograd/graph.py` (632 lines, 23 KiB, Python)
import abc
import contextlib
import weakref
from collections import defaultdict, namedtuple
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None


        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the module.

        .. note::
            See :ref:`backward-hooks-execution` for more information on how when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove() # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
r"""Register a backward pre-hook.
|
|
|
|
The hook will be called every time a gradient with respect to the
|
|
Node is computed. The hook should have the following signature::
|
|
|
|
hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
|
|
|
|
The hook should not modify its argument, but it can optionally return
|
|
a new gradient which will be used in place of :attr:`grad_outputs`.
|
|
|
|
This function returns a handle with a method ``handle.remove()``
|
|
that removes the hook from the module.
|
|
|
|
.. note::
|
|
See :ref:`backward-hooks-execution` for more information on how when this hook
|
|
is executed, and how its execution is ordered relative to other hooks.
|
|
|
|
Example::
|
|
|
|
>>> a = torch.tensor([0., 0., 0.], requires_grad=True)
|
|
>>> b = a.clone()
|
|
>>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
|
|
>>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
|
|
>>> b.sum().backward(retain_graph=True)
|
|
>>> print(a.grad)
|
|
tensor([2., 2., 2.])
|
|
>>> handle.remove()
|
|
>>> a.grad = None
|
|
>>> b.sum().backward(retain_graph=True)
|
|
>>> print(a.grad)
|
|
tensor([1., 1., 1.])
|
|
"""
|
|
...
|
|
|
|
@classmethod
|
|
def __subclasshook__(cls, C):
|
|
if cls is Node:
|
|
if (
|
|
C is not None and C is getattr(torch._C._functions, C.__name__, None)
|
|
) or issubclass(C, torch.autograd.function.BackwardCFunction):
|
|
return True
|
|
return NotImplemented
|
|
|
|
|
|
GradientEdge = namedtuple("GradientEdge", ("node output_nr"))
|
|
GradientEdge.__doc__ = """\
|
|
Object representing a given gradient edge within the autograd graph.
|
|
To get the gradient edge where a given Tensor gradient will be computed,
|
|
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
|
|
"""
|
|
|
|
|
|
def get_gradient_edge(tensor):
|
|
"""Get the gradient edge for computing the gradient of the given Tensor.
|
|
|
|
In particular, it is equivalent to call
|
|
``g = autograd.grad(loss, input)`` and ``g = autograd.grad(loss, get_gradient_edge(input))``.
|
|
"""
|
|
if not tensor.requires_grad:
|
|
raise RuntimeError(
|
|
"It is not possible to get the gradient edge for a Tensor that does not require gradients"
|
|
)
|
|
grad_fn = tensor.grad_fn
|
|
if grad_fn is None:
|
|
# Do an op to force AccumulateGrad lazy creation and get it
|
|
grad_fn = tensor.view_as(tensor).grad_fn.next_functions[0][0]
|
|
|
|
# Note that output_nr default to 0 which is the right value
|
|
# for the AccumulateGrad node.
|
|
return GradientEdge(grad_fn, tensor.output_nr)
|
|
|
|
|
|
def increment_version(tensor):
|
|
"""Update autograd metadata tracking whether the given Tensor was modified in place.
|
|
|
|
This is to enable more accurate error checking within the autograd engine.
|
|
It is already done automatically by PyTorch functions and within custom Function
|
|
when mark_dirty() is called appropriately so you only need to call this explicitly
|
|
if you are doing inplace operation on the Tensor data in a way that Pytorch doesn't
|
|
know about. For example a custom kernel that reads the Tensor data_ptr and modifies
|
|
the memory inplace based on this pointer.
|
|
|
|
Note that incrementing the version counter multiple times for a single inplace operation
|
|
is not problematic.
|
|
"""
|
|
torch._C._increment_version(tensor)
|
|
|
|
|
|
class saved_tensors_hooks:
|
|
"""Context-manager that sets a pair of pack / unpack hooks for saved tensors.
|
|
|
|
Use this context-manager to define how intermediary results of an operation
|
|
should be packed before saving, and unpacked on retrieval.
|
|
|
|
In that context, the ``pack_hook`` function will be called everytime an
|
|
operation saves a tensor for backward (this includes intermediary results
|
|
saved using
|
|
:func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
|
|
also those recorded by a PyTorch-defined operation). The output of
|
|
``pack_hook`` is then stored in the computation graph instead of the
|
|
original tensor.
|
|
|
|
The ``unpack_hook`` is called when the saved tensor needs to be accessed,
|
|
namely when executing :func:`torch.Tensor.backward()` or
|
|
:func:`torch.autograd.grad()`. It takes as argument the *packed* object
|
|
returned by ``pack_hook`` and should return a tensor which has the same
|
|
content as the original tensor (passed as input to the corresponding
|
|
``pack_hook``).
|
|
|
|
The hooks should have the following signatures:
|
|
|
|
pack_hook(tensor: Tensor) -> Any
|
|
|
|
unpack_hook(Any) -> Tensor
|
|
|
|
where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.
|
|
|
|
In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
|
|
of value, size, dtype and device.
|
|
|
|
Example::
|
|
|
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
|
|
>>> def pack_hook(x):
|
|
... print("Packing", x)
|
|
... return x
|
|
>>>
|
|
>>> def unpack_hook(x):
|
|
... print("Unpacking", x)
|
|
... return x
|
|
>>>
|
|
>>> a = torch.ones(5, requires_grad=True)
|
|
>>> b = torch.ones(5, requires_grad=True) * 2
|
|
>>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
|
|
... y = a * b
|
|
Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
|
|
Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
|
|
>>> y.sum().backward()
|
|
Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
|
|
Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
|
|
|
|
.. warning ::
|
|
Performing an inplace operation on the input to either hooks may lead
|
|
to undefined behavior.
|
|
|
|
.. warning ::
|
|
Only one pair of hooks is allowed at a time. When recursively nesting this
|
|
context-manager, only the inner-most pair of hooks will be applied.
|
|
"""
|
|
|
|
def __init__(
|
|
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object):
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.


    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b  # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a  # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward

    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)


@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
"""Context-manager that disables the saved tensors default hooks feature.
|
|
|
|
Useful for if you are creating a feature that does not work with saved
|
|
tensors default hooks.
|
|
|
|
Args:
|
|
error_message (str): When saved tensors default hooks are used when they
|
|
have been are disabled, a RuntimeError with this
|
|
error message gets raised.
|
|
|
|
Example::
|
|
|
|
>>> # xdoctest: +SKIP(failing)
|
|
>>> message = "saved tensors default hooks are disabled"
|
|
>>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
|
|
... # Raises RuntimeError: saved tensors default hooks are disabled
|
|
... with torch.autograd.graph.save_on_cpu():
|
|
... pass
|
|
|
|
"""
|
|
try:
|
|
maybe_prev_message = (
|
|
torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
|
|
)
|
|
torch._C._autograd._saved_tensors_hooks_disable(error_message)
|
|
yield
|
|
finally:
|
|
# See NOTE: [disabled_error_message invariant]
|
|
if maybe_prev_message is None:
|
|
torch._C._autograd._saved_tensors_hooks_enable()
|
|
else:
|
|
torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
|
|
|
|
|
|
def register_multi_grad_hook(
|
|
tensors: Sequence[torch.Tensor],
|
|
fn: Callable[[Sequence[Optional[torch.Tensor]]], None],
|
|
):
|
|
r"""Register a multi-grad backward hook.
|
|
|
|
The hook will be called after gradients with respect to every tensor in
|
|
:attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
|
|
is not part of the graph, or if a tensor is not needed to compute the gradients
|
|
for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
|
|
this tensor will be ignored and the hook will not wait for its gradient to be
|
|
computed.
|
|
|
|
After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
|
|
called with those gradients. ``None`` will be passed for tensors that did not
|
|
have their gradients computed.
|
|
|
|
The hook should not modify its arguments.
|
|
|
|
This function returns a handle with a method ``handle.remove()`` that removes the hook.
|
|
|
|
.. note::
|
|
See :ref:`backward-hooks-execution` for more information on how when this hook
|
|
is executed, and how its execution is ordered relative to other hooks.
|
|
|
|
Example::
|
|
|
|
>>> import torch
|
|
>>>
|
|
>>> a = torch.rand(2, 3, requires_grad=True)
|
|
>>> b = torch.rand(2, 3, requires_grad=True)
|
|
>>> c = a * b
|
|
>>> d = a * b
|
|
>>>
|
|
>>> def fn(grads):
|
|
... print([g is not None for g in grads])
|
|
...
|
|
>>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
|
|
>>>
|
|
>>> c.sum().backward(retain_graph=True)
|
|
[True, True, True, False]
|
|
>>> c.sum().backward(inputs=(a,), retain_graph=True)
|
|
[True, False, True, False]
|
|
>>>
|
|
"""
|
|
count: Dict[int, int] = dict()
|
|
nb_calls = None
|
|
buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()
|
|
|
|
def get_grad_fn(t):
|
|
# or grad accumulator
|
|
if t.requires_grad and t.grad_fn is None:
|
|
return t.expand_as(t).grad_fn.next_functions[0][0]
|
|
else:
|
|
return t.grad_fn
|
|
|
|
grad_fns = list(map(get_grad_fn, tensors))
|
|
len_tensors = len(tensors)
|
|
|
|
def get_inner_hook(idx):
|
|
def inner_hook(grad: torch.Tensor):
|
|
nonlocal count, nb_calls, buffer
|
|
id = torch._C._current_graph_task_id()
|
|
assert id != -1, "expected this hook to be called inside a backward call"
|
|
count[id] = count.get(id, 0)
|
|
buffer[id] = buffer.get(id, [None] * len_tensors)
|
|
|
|
if count[id] == 0:
|
|
# On the first call, compute the actual nb_calls and buffer
|
|
nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns) # type: ignore[attr-defined]
|
|
|
|
buffer[id][idx] = grad
|
|
count[id] += 1
|
|
|
|
if count[id] == nb_calls:
|
|
fn(buffer[id])
|
|
del count[id]
|
|
del buffer[id]
|
|
|
|
return inner_hook
|
|
|
|
class Handle(RemovableHandle):
|
|
handles: Tuple[RemovableHandle, ...]
|
|
|
|
def __init__(self, handles: Tuple[RemovableHandle, ...]):
|
|
self.handles = handles
|
|
|
|
def remove(self):
|
|
for handle in self.handles:
|
|
handle.remove()
|
|
|
|
def __getstate__(self):
|
|
return self.handles
|
|
|
|
def __setstate__(self, state):
|
|
self.handles = state
|
|
|
|
handles: List[RemovableHandle] = []
|
|
for i, t in enumerate(tensors):
|
|
handles.append(t.register_hook(get_inner_hook(i)))
|
|
|
|
return Handle(tuple(handles))
|
|
|
|
|
|
# NOTE [Allow mutation on tensors saved for backward]
|
|
#
|
|
# 1. Tensor gets saved for backward
|
|
# - remember the python object id and the version of the tensor
|
|
# - remember aliasing information (data_ptr of base + version)
|
|
# - save the original so we control its lifetime
|
|
# 2. Any time a tensor gets in-placed
|
|
# - for each tensor aliased to it:
|
|
# - check using its object id and version to see if it has been saved
|
|
# - if it has been saved, clone it
|
|
# - delete the reference to the original
|
|
# 3. during backward
|
|
# - if the clone exists, the tensor must've been modified in-place
|
|
_allow_mutation_on_saved_tensors_enabled = False
|
|
|
|
|
|
def _get_tid(t) -> Tuple[int, int, int]:
|
|
    return (id(t), t.data_ptr(), t._version)


def _get_sid(t) -> Tuple[int, int]:
    return (t.data_ptr(), t._version)


class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(tup):
            handle = tup
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)


class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # We know that if tid is in sid_to_tid, then it must also be in
                            # tid_to_weakhandle. However, it is possible for the tensor to be
                            # saved at one point, but cleared by backward before it is modified
                            # in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        rs = func(*args, **kwargs)
        return rs


class _AllowMutationOnSavedContext:
    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()


@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
"""Context manager under which mutating tensors saved for backward is allowed.
|
|
|
|
Under this context manager, tensors saved for backward are cloned on mutation,
|
|
so the original version can still be used during backward. Normally, mutating a tensor
|
|
saved for backward will result in an error raised when it's used during backward.
|
|
|
|
To ensure the correct behavior, both the forward and backward should be run under
|
|
the same context manager.
|
|
|
|
returns:
|
|
An _AllowMutationOnSavedContext object storing the state managed by this
|
|
context manager. This object can be useful for debugging purposes. The state
|
|
managed by the context manager is automatically cleared upon exiting.
|
|
|
|
Example::
|
|
|
|
>>> import torch
|
|
>>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
|
|
... # forward
|
|
... a = torch.ones(2, 3, requires_grad=True)
|
|
... b = a.clone()
|
|
... out = (b**2).sum()
|
|
... b.sin_()
|
|
... # backward
|
|
... out.sum().backward()
|
|
...
|
|
tensor([[0.8415, 0.8415, 0.8415],
|
|
[0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
|
|
"""
|
|
global _allow_mutation_on_saved_tensors_enabled
|
|
|
|
ctx = _AllowMutationOnSavedContext()
|
|
|
|
with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
|
|
try:
|
|
if _allow_mutation_on_saved_tensors_enabled:
|
|
raise RuntimeError(
|
|
"allow_mutation_on_saved_tensors contexts cannot be nested"
|
|
)
|
|
_allow_mutation_on_saved_tensors_enabled = True
|
|
yield ctx
|
|
finally:
|
|
ctx.clear()
|
|
_allow_mutation_on_saved_tensors_enabled = False
|
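For context, below is a minimal usage sketch of the `get_gradient_edge` and `increment_version` helpers defined above. It is based on the equivalence stated in `get_gradient_edge`'s own docstring; whether `torch.autograd.grad` accepts a `GradientEdge` directly depends on the PyTorch version, so treat this as illustrative rather than guaranteed behavior.

```python
import torch
from torch.autograd.graph import get_gradient_edge, increment_version

x = torch.randn(3, requires_grad=True)
loss = (x * 3.0).sum()

# Per get_gradient_edge's docstring, differentiating with respect to the edge
# is equivalent to differentiating with respect to the tensor itself.
(g_tensor,) = torch.autograd.grad(loss, x, retain_graph=True)
(g_edge,) = torch.autograd.grad(loss, get_gradient_edge(x))
assert torch.equal(g_tensor, g_edge)

# increment_version only bumps the tensor's version counter; calling it more
# than once for a single out-of-band in-place write is harmless.
y = torch.randn(3)
before = y._version
increment_version(y)
assert y._version == before + 1
```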