Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
## Context

> **Note:** `mark_traceable` got renamed to `nonstrict_trace` after
> offline discussion. The reasons are (1) it aligns with `torch.export`'s
> `nonstrict` notion, and (2) it suggests the behavior more definitively.

1. [Overall Design](https://docs.google.com/document/d/1O-dR2ZQaJQVt_v67AVcDCw2yJLtqgkZFwoXK0buEWRg/edit?tab=t.0)
2. [Dynamo graph representation with `torch._higher_order_ops.flat_apply`](https://docs.google.com/document/d/1YHl5nPTJvYeCPE5TO9uA18DPWNgUYGE4gCn6bFvXcBM/edit?tab=t.0#heading=h.xtw3hhbro4gn)

## Summary

This patch adds a `torch._dynamo.nonstrict_trace` decorator, which is currently an enhanced version of `torch._dynamo.allow_in_graph` (see the docstring for their differences). Specifically, this patch focuses on UI and functionality prototyping/plumbing.

The main enhancement is support for more input types. The implementation challenge lies in reconstructing the input objects from Dynamo `VariableTracker`s while accounting for buffered side effects and guards. This patch takes a middle ground (a simple implementation that asks for a bit of user labor) by:

1. asking the user to provide pytree registration for non-proxy-able input types,
2. letting Dynamo trace through `pytree_flatten` (which accounts for buffered side effects and guards automatically), and
3. passing the TreeSpec as a graph-attribute constant into `torch._higher_order_ops.flat_apply` (which unflattens the inputs and invokes the underlying function).

## Next Steps

In subsequent patches, we will try to support the following:

- annotating class methods
- reads of global tensors
- inputs that contain `pytree.register_constant`-ed instances
- functions as inputs
- more output types (e.g., any pytree-registered type)
- `torch.nn.Module` as inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146367
Approved by: https://github.com/zou3519
ghstack dependencies: #146714
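The three steps in the Summary amount to a round trip: flatten the inputs into leaves plus a structure descriptor, treat the descriptor as a constant, then unflatten and call the function. Below is a minimal pure-Python sketch of that pattern under stated assumptions. `register_node`, `tree_flatten`, `tree_unflatten`, `Spec`, `flat_apply_sketch`, and `Point` are hypothetical helpers written only for illustration. They are not the APIs of `torch.utils._pytree` or `torch._higher_order_ops.flat_apply`. Everything Dynamo itself contributes (tracing, guards, buffered side effects) is omitted.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry: user type -> (flatten_fn, unflatten_fn).
# It only mimics the idea of pytree registration; it is NOT torch.utils._pytree.
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}
_LEAF = object()  # sentinel marking a leaf position in a Spec


def register_node(cls: type, flatten_fn: Callable, unflatten_fn: Callable) -> None:
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


@dataclass(frozen=True)
class Spec:
    """Structure of a flattened value (plays the role of a TreeSpec here)."""
    kind: Any                     # _LEAF, tuple, list, or a registered class
    context: Any                  # extra data returned by flatten_fn
    children: Tuple["Spec", ...]


def tree_flatten(obj: Any) -> Tuple[List[Any], Spec]:
    """Split obj into a flat list of leaves plus a Spec describing its structure."""
    if type(obj) in (tuple, list):
        children, context = list(obj), None
    elif type(obj) in _REGISTRY:
        children, context = _REGISTRY[type(obj)][0](obj)
    else:
        return [obj], Spec(_LEAF, None, ())
    leaves: List[Any] = []
    child_specs = []
    for child in children:
        child_leaves, child_spec = tree_flatten(child)
        leaves.extend(child_leaves)
        child_specs.append(child_spec)
    return leaves, Spec(type(obj), context, tuple(child_specs))


def tree_unflatten(leaves: List[Any], spec: Spec) -> Any:
    """Rebuild the original object from its leaves and Spec."""
    it = iter(leaves)

    def build(s: Spec) -> Any:
        if s.kind is _LEAF:
            return next(it)
        children = [build(c) for c in s.children]
        if s.kind in (tuple, list):
            return s.kind(children)
        return _REGISTRY[s.kind][1](children, s.context)

    return build(spec)


def flat_apply_sketch(fn: Callable, in_spec: Spec, *flat_args: Any) -> Any:
    """Hypothetical stand-in for flat_apply: unflatten the positional args, then call fn."""
    args = tree_unflatten(list(flat_args), in_spec)
    return fn(*args)


# --- Usage: a hypothetical user-defined, non-tensor input type ---
@dataclass
class Point:
    x: float
    y: float


register_node(Point, lambda p: ([p.x, p.y], None), lambda children, ctx: Point(*children))


def scale(p: Point, factor: float) -> Point:
    return Point(p.x * factor, p.y * factor)


flat_args, in_spec = tree_flatten((Point(1.0, 2.0), 3.0))
print(flat_args)                                      # [1.0, 2.0, 3.0]
print(flat_apply_sketch(scale, in_spec, *flat_args))  # Point(x=3.0, y=6.0)
```

In this sketch only `flat_args` carry run-time values, while `in_spec` is plain, immutable structure. Loosely, this mirrors why the patch can pass the TreeSpec into `flat_apply` as a graph-attribute constant.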

| Name |
|---|
| `_strobelight` |
| `_sympy` |
| `backcompat` |
| `benchmark` |
| `bottleneck` |
| `data` |
| `hipify` |
| `jit` |
| `model_dump` |
| `serialization` |
| `tensorboard` |
| `viz` |
| `__init__.py` |
| `_backport_slots.py` |
| `_config_module.py` |
| `_config_typing.pyi` |
| `_content_store.py` |
| `_contextlib.py` |
| `_cpp_embed_headers.py` |
| `_cpp_extension_versioner.py` |
| `_cxx_pytree.py` |
| `_device.py` |
| `_exposed_in.py` |
| `_filelock.py` |
| `_foreach_utils.py` |
| `_freeze.py` |
| `_functools.py` |
| `_get_clean_triton.py` |
| `_import_utils.py` |
| `_mode_utils.py` |
| `_ordered_set.py` |
| `_python_dispatch.py` |
| `_pytree.py` |
| `_stats.py` |
| `_thunk.py` |
| `_traceback.py` |
| `_triton.py` |
| `_typing_utils.py` |
| `_zip.py` |
| `backend_registration.py` |
| `bundled_inputs.py` |
| `checkpoint.py` |
| `collect_env.py` |
| `cpp_backtrace.py` |
| `cpp_extension.py` |
| `deterministic.py` |
| `dlpack.py` |
| `file_baton.py` |
| `flop_counter.py` |
| `hooks.py` |
| `mkldnn.py` |
| `mobile_optimizer.py` |
| `model_zoo.py` |
| `module_tracker.py` |
| `show_pickle.py` |
| `throughput_benchmark.py` |
| `weak.py` |