Document dynamo (#146736)
Many files in dynamo are currently lacking file/module-level documentation, which makes it hard to know what they do at a glance and without digging into the code. This fixes that.

Note: documentation was AI-generated and could be incorrect, please review carefully.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146736
Approved by: https://github.com/jansel, https://github.com/StrongerXi, https://github.com/anijain2305, https://github.com/zou3519
This commit is contained in:
parent
0344bf8a5a
commit
21c2565f35
@@ -1,3 +1,13 @@
"""
TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster.
TorchDynamo hooks into the frame evaluation API in CPython (PEP 523) to dynamically modify Python
bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of
PyTorch operations into an FX Graph which is then just-in-time compiled with a customizable backend.
It creates this FX Graph through bytecode analysis and is designed to mix Python execution with
compiled backends to get the best of both worlds: usability and performance. This allows it to
seamlessly optimize PyTorch programs, including those using modern Python features.
"""

import torch

from . import convert_frame, eval_frame, resume_execution

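The capture flow this docstring describes is driven through the public `torch.compile` entry point; a minimal usage sketch (illustrative, not part of this diff):

```
import torch

def fn(x):
    return torch.sin(x) + torch.cos(x)

compiled_fn = torch.compile(fn)  # Dynamo intercepts fn's bytecode on first call
print(compiled_fn(torch.randn(8)))
```
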
@@ -1,3 +1,33 @@
"""trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist:
if you make_fx trace through this call, we will not actually trace into fn; instead,
we will directly insert it as a call_function to fn in the graph.
(Unlike make_fx, Dynamo WILL inline into fn.)
You can think of this as a one-off allow_in_graph equivalent for proxy tensor tracing.

Because proxy tensor tracing does not actually run the function, there are
requirements on the behavior of fn. We are still figuring it out, but here is the current state:

1) fn SHOULD only take a single argument, which must be a tensor
2) fn MUST return a new tensor with the same metadata as the original tensor
   (e.g., zeros_like(input) is a permissible implementation of fn).
   This is verified via an extra assert that is inserted into the traced graph.
3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors
   participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state)

These requirements stem from the requirement that we need to continue performing proxy tensor tracing,
which assumes accurate fake tensor metadata, without actually running fn.
In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns.

Note that tensors / Python state are allowed to be mutated.
This relaxed constraint is not always sound, but it is sound for backward tracing with fake
tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete
tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python).

The intended use case for this function is to allow AOTAutograd to defer complex
backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves
the function call as is in the graph, and only when we Dynamo through the backward graph in
compiled autograd do we inline into the function.
"""

from typing import Any, Optional

import torch

@@ -19,36 +49,6 @@ Tensor = torch.Tensor
__all__ = ["trace_wrapped"]


# trace_wrapped(*args, fn) is equivalent to fn(*args), but with a twist:
# if you make_fx trace through this call, we will not actually trace into fn; instead,
# we will directly insert it as a call_function to fn in the graph.
# (Unlike make_fx, Dynamo WILL inline into fn.)
# You can think of this as a one-off allow_in_graph equivalent for proxy tensor tracing.
#
# Because proxy tensor tracing does not actually run the function, there are
# requirements on the behavior of fn. We are still figuring it out, but here is the current state:
#
# 1) fn SHOULD only take a single argument, which must be a tensor
# 2) fn MUST return a new tensor with the same metadata as the original tensor
#    (e.g., zeros_like(input) is a permissible implementation of fn).
#    This is verified via an extra assert that is inserted into the traced graph.
# 3) fn MAY have side effects, but it MAY NOT perform metadata mutation on other tensors
#    participating in proxy tensor tracing (it MAY mutate other tensors, it MAY mutate Python state)
# These requirements stem from the requirement that we need to continue performing proxy tensor tracing,
# which assumes accurate fake tensor metadata, without actually running fn.
# In the future, we may allow for a "meta" function associated with fn to allow for more interesting input-output patterns.
#
# Note that tensors / Python state are allowed to be mutated.
# This relaxed constraint is not always sound, but it is sound for backward tracing with fake
# tensors as it takes place in AOTAutograd, as the backward pass is guaranteed not to depend on concrete
# tensor values (via fake tensor) or Python state (because the autograd engine doesn't depend on Python).
#
# The intended use case for this function is to allow AOTAutograd to defer complex
# backward hooks to compiled autograd. AOTAutograd performs a make_fx trace which preserves
# the function call as is in the graph, and only when we Dynamo through the backward graph in
# compiled autograd do we inline into the function.


if not torch._running_with_deploy():
    # torch.library.custom_op does not work with torch.deploy/multipy

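The contract on `fn` spelled out above can be made concrete with a short sketch (hypothetical hook, not from this diff):

```
import torch

def fn(grad: torch.Tensor) -> torch.Tensor:
    # Rule 2: return a fresh tensor with the same metadata as the input;
    # zeros_like is explicitly named as a permissible implementation.
    return torch.zeros_like(grad)
```
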
@@ -1,5 +1,23 @@
# mypy: ignore-errors

"""
This module provides common utilities and base classes for TorchDynamo backends.

Key components:
- AotAutograd: Base class for implementing AOT (Ahead-of-Time) autograd backends
- Backend utilities for handling:
  - Fake tensor conversion
  - Device/dtype detection from inputs
  - Memory efficient fusion
  - Graph flattening
- Common compiler configurations

The utilities here are used by various backend implementations to handle
common operations and provide consistent behavior across different backends.
AOT autograd functionality is particularly important as it enables ahead-of-time
optimization of both forward and backward passes.
"""

import contextlib
import functools
import logging

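As a rough sketch of how these pieces are typically combined (the `aot_autograd` helper and its `fw_compiler` keyword are assumptions based on the description above):

```
from torch._dynamo.backends.common import aot_autograd

def nop_compiler(gm, example_inputs):
    # "compile" by returning the traced GraphModule unchanged
    return gm

# Route both forward and backward graphs through AOT autograd.
my_backend = aot_autograd(fw_compiler=nop_compiler)
# usable as: torch.compile(model, backend=my_backend)
```
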
@@ -1,5 +1,28 @@
# mypy: ignore-errors

"""
This module implements CUDA graphs support for TorchDynamo backends.

CUDA graphs allow for capturing and replaying GPU operations, which can significantly
reduce CPU overhead in GPU-accelerated PyTorch models. This module provides:

- CUDA graph creation and management for both forward and backward passes
- Input mutation detection and handling
- Device compatibility checking
- Stack trace management for debugging
- Integration with TorchInductor's cudagraph trees

The backend supports two main modes:
1. cudagraphs: Full CUDA graph support with both forward and backward pass optimization
2. cudagraphs_inner: Lower-level CUDA graph implementation used for benchmarking

Key components:
- CudagraphsBackend: Main backend class for CUDA graph integration
- Mutation detection utilities to ensure graph safety
- Device mapping and compatibility checks
- Stack trace collection for debugging
"""

import functools
from collections import defaultdict
from typing import Optional

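A short usage sketch for the mode described above (assumes a CUDA device is available):

```
import torch

model = torch.nn.Linear(8, 8).cuda()
compiled = torch.compile(model, backend="cudagraphs")
out = compiled(torch.randn(4, 8, device="cuda"))
```
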
@@ -1,5 +1,30 @@
# mypy: ignore-errors

"""
This module provides debugging backends for TorchDynamo to help diagnose and troubleshoot
compilation and execution issues. It includes:

Key Debugging Backends:
- eager: Simple pass-through backend that runs models in eager mode
- eager_noexcept: Similar to eager but with additional exception handling
- eager_debug: Adds schema validation checks for custom operators
- aot_eager: Uses AOT Autograd with nop compiler for debugging
- aot_eager_decomp_partition: Uses TorchInductor decompositions for debugging
- torchscript: Compiles using TorchScript for debugging JIT-related issues

Testing and Development Tools:
- Backends for inducing specific errors (compile/runtime/accuracy)
- ExplainOutput class for detailed graph compilation analysis
- Utilities for cross-referencing and mode management
- Tools for graph detail inspection and break reason analysis

These backends are primarily used for:
1. Debugging graph breaks and compilation failures
2. Testing error handling and recovery mechanisms
3. Analyzing performance bottlenecks
4. Validating operator schemas and decompositions
"""

import dataclasses
import functools
import logging

@@ -19,11 +44,6 @@ from .registry import register_debug_backend as register_backend
log = logging.getLogger(__name__)


"""
This file contains TorchDynamo backends intended for debugging uses.
"""


@register_backend
def eager(gm, fake_tensor_inputs, **kwargs):
    if kwargs:

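A common bisection workflow built on these backends: if a failure reproduces with `eager`, the bug is in Dynamo's capture; if it appears only with `aot_eager` or `inductor`, it lies further down the stack. Sketch:

```
import torch

@torch.compile(backend="eager")  # capture only, no backend compilation
def f(x):
    return x.relu() + 1

# swap in backend="aot_eager" to also exercise AOT autograd without Inductor
print(f(torch.randn(4)))
```
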
@@ -1,5 +1,23 @@
# mypy: ignore-errors

"""
This module implements distributed training optimizations for TorchDynamo backends.

It provides functionality to optimize models wrapped in DistributedDataParallel (DDP)
by intelligently splitting compiled graphs to align with DDP's gradient synchronization
boundaries. Key features include:

- Graph partitioning based on parameter bucket sizes
- Optimization of allreduce operations for distributed training
- Support for parameter ignoring and buffer handling
- Submodule compilation and management
- Debugging utilities for distributed training

The main component is the DDPOptimizer class, which handles graph splitting and
recompilation to enable efficient distributed training while maintaining the benefits
of compilation.
"""

import logging
import traceback
from dataclasses import dataclass, field

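A sketch of the setup DDPOptimizer targets (skeleton only; real use requires a process group initialized by a launcher such as torchrun):

```
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# dist.init_process_group("nccl")  # done by the launcher/bootstrap in practice
model = torch.nn.Linear(16, 16)
ddp_model = DDP(model)               # DDP buckets gradients for allreduce
compiled = torch.compile(ddp_model)  # DDPOptimizer splits graphs at bucket boundaries
```
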
@@ -1,5 +1,16 @@
# mypy: ignore-errors

"""
This module provides the TorchInductor backend integration for TorchDynamo.

TorchInductor is a compiler backend that generates optimized code for both CPU and GPU.
This module lazily imports and registers the TorchInductor compiler to avoid loading it
into memory when it is not being used. This helps reduce memory overhead when using
other backends.

The inductor backend can be used with torch.compile():
    model = torch.compile(model, backend="inductor")
"""

from torch._dynamo import register_backend

@@ -1,5 +1,65 @@
# mypy: ignore-errors

"""
This module implements TorchDynamo's backend registry system for managing compiler backends.

The registry provides a centralized way to register, discover and manage different compiler
backends that can be used with torch.compile(). It handles:

- Backend registration and discovery through decorators and entry points
- Lazy loading of backend implementations
- Lookup and validation of backend names
- Categorization of backends using tags (debug, experimental, etc.)

Key components:
- CompilerFn: Type for backend compiler functions that transform FX graphs
- _BACKENDS: Registry mapping backend names to entry points
- _COMPILER_FNS: Registry mapping backend names to loaded compiler functions

Example usage:
    @register_backend
    def my_compiler(fx_graph, example_inputs):
        # Transform FX graph into optimized implementation
        return compiled_fn

    # Use registered backend
    torch.compile(model, backend="my_compiler")

The registry also supports discovering backends through setuptools entry points
in the "torch_dynamo_backends" group. Example:
```
setup.py
---
from setuptools import setup

setup(
    name='my_torch_backend',
    version='0.1',
    packages=['my_torch_backend'],
    entry_points={
        'torch_dynamo_backends': [
            # name = path to entry point of backend implementation
            'my_compiler = my_torch_backend.compiler:my_compiler_function',
        ],
    },
)
```
```
my_torch_backend/compiler.py
---
def my_compiler_function(fx_graph, example_inputs):
    # Transform FX graph into optimized implementation
    return compiled_fn
```
Using `my_compiler` backend:
```
import torch

model = ...  # Your PyTorch model
optimized_model = torch.compile(model, backend="my_compiler")
```
"""

import functools
import logging
import sys

@@ -1,5 +1,27 @@
# mypy: ignore-errors

"""
This module provides TVM backend integration for TorchDynamo.

Apache TVM is a deep learning compiler framework that can optimize and execute
models on various hardware backends. This module enables:

- Compilation of PyTorch models to TVM's computation graphs
- Multiple scheduling options:
  - Default scheduler
  - Auto-scheduler for automatic optimization
  - Meta-schedule for evolutionary search-based tuning
- Hardware-specific optimizations:
  - CUDA GPU support
  - CPU support with LLVM targeting and architecture-specific tuning
  - Automatic detection of CPU capabilities (AVX2, AVX512)
- Tensor conversion utilities between PyTorch and TVM formats
- Configurable optimization levels and tuning trials

The backend can be used with torch.compile():
    model = torch.compile(model, backend="tvm")
"""

import functools
import importlib
import logging

@@ -1,4 +1,19 @@
# mypy: allow-untyped-defs

"""
This module provides utilities for analyzing and optimizing Python bytecode.
Key functionality includes:
- Dead code elimination
- Jump instruction optimization
- Stack size analysis and verification
- Live variable analysis
- Line number propagation and cleanup
- Exception table handling for Python 3.11+

The utilities in this module are used to analyze and transform bytecode
for better performance while maintaining correct semantics.
"""

import bisect
import dataclasses
import dis

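For readers unfamiliar with the raw material these passes operate on, the standard `dis` module shows the instruction stream and the stack-size metadata this file analyzes:

```
import dis

def f(x):
    return x + 1

dis.dis(f)                      # the instruction stream Dynamo rewrites
print(f.__code__.co_stacksize)  # stack depth that analysis must preserve
```
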
@@ -1,4 +1,20 @@
# mypy: allow-untyped-defs

"""
This module provides utilities for analyzing, transforming and manipulating Python bytecode.
It includes functionality for:
- Converting between different bytecode formats and versions
- Virtualizing jumps and managing jump targets
- Handling exception tables and their entries
- Managing instruction offsets and extended arguments
- Providing a clean API for bytecode modification and transformation
- Supporting Python version-specific bytecode features
- Generating bytecode from template functions

The module is designed to work across different Python versions (3.7+) and handles
version-specific bytecode differences transparently.
"""

import copy
import dataclasses
import dis

@@ -1,3 +1,30 @@
"""
This module provides callback management functionality for TorchDynamo's compilation process.

It implements a thread-safe system for registering, managing and executing callbacks that run
at the start and end of TorchDynamo compilations. Key features include:

- Registration and deregistration of compilation callbacks
- Thread-safe callback handling with proper locking mechanisms
- Prevention of duplicate callback execution when configured
- Decorator utilities for easy callback registration
- Context manager for controlled callback lifecycle

The module centers around the CompilationCallbackHandler class which maintains separate
lists for start and end callbacks, manages their execution order, and ensures thread-safety.
Utility decorators @on_compile_start and @on_compile_end provide a convenient way to
register compilation hooks.

Example usage:
    @on_compile_start
    def my_start_callback():
        print("Starting compilation")

    @on_compile_end
    def my_end_callback():
        print("Compilation complete")
"""

import threading
from collections.abc import Generator
from contextlib import contextmanager

@@ -1,3 +1,33 @@
"""
This module provides thread-safe code context management for TorchDynamo using weak references.

The CodeContextDict class maintains a mapping between Python code objects and their associated
context data, using weak references to automatically clean up entries when code objects are
garbage collected. This prevents memory leaks while allowing context data to be associated
with code objects throughout their lifecycle.

Key features:
- Thread-safe context storage and retrieval
- Automatic cleanup using weak references
- Safe context management for Python code objects
- Memory-leak prevention

Example usage:
    code_obj = compile('x = 1', '<string>', 'exec')

    # Store context
    context = code_context.get_context(code_obj)
    context['metadata'] = {'optimized': True}

    # Retrieve context
    if code_context.has_context(code_obj):
        ctx = code_context.get_context(code_obj)
        # Use context data...

    # Remove context
    ctx = code_context.pop_context(code_obj)
"""

import types
from typing import Any

@@ -1,4 +1,17 @@
# mypy: allow-untyped-defs

"""
This module provides utilities for generating Python bytecode in PyTorch's Dynamo system.
It includes functionality for:
- Constructing bytecode sequences for Python operations
- Managing stack operations and variable tracking
- Handling graph outputs and their conversions
- Supporting different Python versions (3.11+, 3.12+, 3.13+)
- Converting high-level operations to low-level bytecode instructions
- Managing constant loading and attribute access
- Supporting function creation and closure handling
"""

import collections
import dataclasses
import re

@@ -1,4 +1,21 @@
# mypy: allow-untyped-defs

"""
Provides functionality for compiling PyTorch's autograd (automatic differentiation) system.

This module implements compiled autograd, which traces and optimizes backward pass
computations at runtime. The key components are:

- AutogradCompilerInstance: Traces and compiles autograd graphs using FX
- Context managers (_enable/_disable): Control when compiled autograd is active
- Utility functions: Support graph manipulation, tensor operations, and hooks

Compiled autograd can significantly improve backward pass performance by removing
Python overhead and enabling additional optimizations. It works by capturing
backward computations into an FX graph that can be compiled and optimized,
while maintaining the same semantics as eager mode autograd.
"""

import contextlib
import functools
import itertools

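A minimal sketch of the context-manager API named above, assuming `_enable` accepts a compiler callable applied to the captured backward graph (the exact signature is not confirmed by this diff):

```
import torch
from torch._dynamo import compiled_autograd

x = torch.randn(4, requires_grad=True)
with compiled_autograd._enable(torch.compile):
    loss = (x * x).sum()
    loss.backward()  # the backward graph is captured and compiled
```
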
@@ -1,10 +1,41 @@
# mypy: allow-untyped-defs
# This file establishes the public comptime interface to Dynamo.
# This allows Dynamo users to execute arbitrary Python code while
# Dynamo is symbolically evaluating their original programs.
#
# The goal of the public API is to give users rope, without actually
# leaking private implementation details of Dynamo.

"""
This module provides the public comptime interface to TorchDynamo, enabling users to execute
arbitrary Python code during symbolic evaluation of their programs.

The comptime interface allows inspection and modification of TorchDynamo's compilation
process while it is running. This can be useful for:

- Debugging compilation issues
- Inspecting intermediate state
- Adding custom guards or graph breaks
- Analyzing symbolic shapes and values

Example usage:

    import torch
    from torch._dynamo.comptime import comptime

    def my_model(x):
        # Print the compile-time known information about x
        comptime.print(x)

        # Print the current FX graph being constructed
        comptime.print_graph()

        # Force a value to be treated as static
        if comptime(lambda ctx: ctx.get_local("x").is_dynamic()):
            comptime.force_static(x)

        # Add a manual graph break
        comptime.graph_break()

Note: While this API provides significant flexibility, it intentionally avoids
exposing internal implementation details of TorchDynamo to maintain compatibility
across versions.
"""


import builtins
import dis

@@ -1,4 +1,16 @@
# mypy: allow-untyped-decorators

"""
Configuration module for TorchDynamo compiler and optimization settings.

This module contains various configuration flags and settings that control TorchDynamo's
behavior, including:
- Runtime behavior flags (e.g., guard settings, specialization options)
- Debugging and development options
- Performance tuning parameters
- Feature toggles for experimental features
"""

import getpass
import os
import sys

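Flags are toggled by plain assignment on the module; two long-standing examples (semantics paraphrased):

```
import torch._dynamo.config as dynamo_config

dynamo_config.cache_size_limit = 16    # bound recompilations per code object
dynamo_config.suppress_errors = False  # surface errors instead of silently falling back
```
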
@@ -1,4 +1,24 @@
# mypy: allow-untyped-decorators

"""
This module implements TorchDynamo's core frame conversion functionality, transforming Python
frames into FX graphs. It handles:

- Frame analysis and bytecode transformation
- Guard creation and management for dynamic behaviors
- Cache management for recompilation
- Error handling and fallback mechanisms

Key classes:
- ConvertFrame: Main entry point for frame conversion with error handling
- ConvertFrameAssert: Implements core frame to graph conversion logic
- Tracker: Tracks input/output code objects during conversion
- CatchErrorsWrapper: Provides error handling and suppression logic

The conversion process preserves program semantics while enabling optimizations
through torch.compile() and related systems.
"""

from __future__ import annotations

import collections

@@ -1,3 +1,20 @@
"""
Provides thread-local scope identification for SubgraphTracer instances.

This module implements a thread-safe mechanism for tracking nested tracing contexts,
which is essential when multiple SubgraphTracer instances are active. The scope ID
helps identify which tracer context is currently active when direct access to the
InstructionTranslator is difficult.

Key components:
- Thread-local scope ID storage (_current_scope_id)
- Getter function (current_scope_id) to safely access the current scope
- Context manager (enter_new_scope) for managing nested scope transitions

The scope ID increments when entering a new context and decrements when exiting,
allowing proper tracking of nested tracing operations across different threads.
"""

import contextlib
import threading
from collections.abc import Generator

@@ -1,5 +1,24 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="method-assign"

"""
Debug utilities for TorchDynamo compilation and execution.

This module provides various debugging tools and utilities for TorchDynamo, including:

- Minification support for reducing test cases while preserving bugs
- Input/output handling via InputReader and InputWriter for reproducible testing
- Accuracy checking between original and compiled models
- Neural network module string conversion via NNModuleToString
- Profiling tools and system information collection
- Buck build system integration for Meta-internal testing

Key classes:
- InputReader/InputWriter: Handle serialization of model inputs/outputs
- NNModuleToString: Converts nn.Modules to string representations
- BuckTargetWriter: Manages Buck build system integration
"""

import atexit
import copy
import cProfile

@@ -1,5 +1,10 @@
# mypy: allow-untyped-defs
# ruff: noqa: TCH004

"""
This module provides decorators and utilities for controlling TorchDynamo's behavior during compilation.
"""

import functools
import inspect
import sys

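One representative control decorator, shown via its public alias `torch.compiler.disable` (this module houses the underlying implementations):

```
import torch

@torch.compiler.disable  # exclude this helper from Dynamo tracing
def log_stats(x):
    print(float(x.mean()))

@torch.compile
def f(x):
    log_stats(x)  # runs eagerly; Dynamo graph-breaks around the call
    return x * 2
```
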
@@ -1,4 +1,22 @@
# mypy: allow-untyped-defs

"""
Device abstraction layer for TorchDynamo and Inductor backends.

This module provides a unified interface for different hardware backends (CUDA, XPU,
CPU, MPS) through a common device interface. Key components include:

- DeviceInterface: Base class defining the common API for all device types
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface
- Device registration system for managing available backends
- Worker APIs for multi-processing scenarios
- Stream and event management across different devices
- Device property caching for worker processes

The abstraction layer enables device-agnostic code in TorchDynamo while allowing
specialized implementations for each hardware backend's unique features.
"""

import time
from collections.abc import Iterable
from dataclasses import dataclass

@@ -1,3 +1,19 @@
"""
Manages process groups for distributed compilation in TorchDynamo.

This module handles the initialization and management of process groups used for
distributed compilation. Key features:

- Lazy initialization of compilation process groups
- Only creates groups when distributed mode is enabled and available
- Integrates with compiler_collectives configuration setting
- Provides a single global process group for compilation coordination

The process group is created only when needed and if the distributed environment
is properly initialized, making it safe to import and use this module even in
non-distributed scenarios.
"""

from typing import Optional

import torch.distributed as dist

@@ -2,10 +2,25 @@
# mypy: disable-error-code="method-assign"

"""
Functions in this file are responsible for modifying the eval frame
handler at RUNTIME. Therefore, all functions in this file are hot.
Functions that only execute at compile time should be placed
in torch._dynamo.convert_frame.
This module implements the core frame evaluation handler for TorchDynamo's compilation system.
The eval frame handler intercepts Python bytecode execution at runtime to enable dynamic
compilation and optimization of PyTorch code.

Key components defined here:
- Frame evaluation handlers that intercept and analyze Python execution frames
- Guards management for tracking dependencies and invalidating compiled code
- Optimization contexts and decorators (optimize, run_once, disable, etc.)
- Export functionality for saving optimized graphs
- Backend compiler integrations and callback management

Functions in this file are responsible for modifying the eval frame handler at RUNTIME.
Therefore, all functions in this file are hot and performance-critical. Functions that
only execute at compile time should be placed in torch._dynamo.convert_frame.

The eval frame handler is the core mechanism that enables TorchDynamo to dynamically
intercept, analyze and optimize PyTorch code during execution. It works by registering
a custom frame evaluation function that gets called for every Python frame, allowing
us to detect PyTorch operations and trigger compilation as needed.
"""

from __future__ import annotations

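A sketch of the decorator form exposed from this module (`torch.compile(backend=...)` is the modern spelling of the same entry point):

```
import torch._dynamo as dynamo

@dynamo.optimize("eager")
def f(x):
    return x + 1
```
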
@@ -1,5 +1,31 @@
from __future__ import annotations


"""Exception handling and error reporting for TorchDynamo.

This module provides a comprehensive set of exception classes and utilities for error
handling in TorchDynamo. It includes:

Base Exceptions:
- TorchDynamoException: Base class for all TorchDynamo-specific exceptions
- Various specialized subclasses for different error scenarios

User Error Handling:
- UserError: Exceptions for user-facing errors in TorchDynamo usage
- UserErrorType: Enumeration of different categories of user errors
- Formatted error messages with debugging information

Observed Exceptions:
- Classes for handling exceptions observed during tracing
- Special handling for StopIteration, LookupError, etc.
- Exception state management during compilation

Error Formatting:
- Stack trace filtering and formatting
- Error message augmentation
- Debugging utilities for error reporting
"""

import logging
import os
import textwrap

@@ -1,5 +1,27 @@
# This module contains functions that *will be allowed* by dynamo

"""
This module contains utility functions that are explicitly allowed to be called during
TorchDynamo compilation. These functions are carefully vetted to ensure they work
correctly within the TorchDynamo tracing and compilation process.

Key functionality groups:

- Compilation State:
  Functions for checking compilation state (is_compiling)

- Function Wrapping:
  Utilities for wrapping functions (wrap_inline, wrap_numpy) to work with
  TorchDynamo compilation

- Autograd Hooks:
  Functions and classes for handling autograd hooks and backward passes
  (call_hook, FakeBackwardCFunction, etc.)

- Tensor Operations:
  Utility functions for tensor operations and transformations
"""

import functools
import warnings
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union

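The compilation-state check mentioned above, via its public alias:

```
import torch

def f(x):
    if torch.compiler.is_compiling():  # True while Dynamo traces this frame
        return x + 1
    return x - 1
```
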
@@ -1,3 +1,21 @@
"""
This module provides functionality for caching and looking up fully qualified function
and class names from Python source files by line number.

It uses Python's tokenize module to parse source files and tracks function/class
definitions along with their nesting to build fully qualified names (e.g. 'class.method'
or 'module.function'). The results are cached in a two-level dictionary mapping:

    filename -> (line_number -> fully_qualified_name)

Example usage:
    name = get_funcname("myfile.py", 42)  # Returns name of function/class at line 42
    clearcache()  # Clear the cache if file contents have changed

The parsing is done lazily when a file is first accessed. Invalid Python files or
IO errors are handled gracefully by returning empty cache entries.
"""

import tokenize
from typing import Optional

@@ -1,3 +1,12 @@
"""
This module implements graph deduplication functionality for TorchDynamo's optimization pipeline.
Graph deduplication identifies identical subgraphs in the computational graph and merges them
to reduce redundancy and improve performance. The process involves analyzing regions of the graph,
identifying structurally equivalent regions, and replacing them with a single shared implementation.
This optimization is particularly effective for models with repeated patterns or similar computational
structures across different parts of the network.
"""

import logging
import operator
from collections.abc import Iterable

@@ -1,3 +1,18 @@
"""
This module provides functionality for tracking and managing regions in computational graphs.
It supports graph optimization by identifying and grouping similar regions based on their
structure and behavior. The module implements algorithms for:

1. Tracking nodes and their relationships in the computational graph
2. Identifying identical or similar regions across the graph
3. Managing graph regions for optimization purposes
4. Supporting deduplication and other graph transformation passes

The core functionality revolves around the GraphRegionTracker class which maintains
mappings between nodes and their duplicates, enabling efficient graph analysis and
optimization operations.
"""

import copyreg
import io
import logging

@@ -1,5 +1,22 @@
# mypy: allow-untyped-defs

"""
Core guard system for Dynamo that detects when compiled code needs to be recompiled due to
changes in program state. Guards are conditions that must remain true for previously-compiled
code to be valid for reuse.

This module provides the infrastructure for creating, managing and checking guards, including:
- Guard creation and composition
- Guard state management and invalidation
- Guard checking and failure handling
- Utilities for guard optimization and debugging
- Integration with Dynamo's compilation caching

The guard system is critical for Dynamo's ability to efficiently reuse compiled code while
maintaining correctness by detecting when recompilation is necessary due to changes in
program state, tensor properties, or control flow.
"""

from __future__ import annotations

import ast

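A small illustration of guards in action (the exact guard set varies with version and dynamic-shape settings):

```
import torch

@torch.compile
def f(x):
    return x * 2

f(torch.randn(4))     # first compile; guards record properties such as ndim/dtype
f(torch.randn(4, 4))  # a shape guard fails, so this frame is recompiled
```
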
@@ -1,3 +1,15 @@
"""Hook system for Dynamo's guard functionality.

This module provides a way to register callback functions that are triggered during
guard-related operations.

The Hooks class manages two types of hook functions:
- guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input
- guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input

These hooks enable customization of guard export and failure handling behaviors.
"""

import dataclasses
from typing import Callable, Optional

@@ -1,3 +1,14 @@
"""Logging utilities for Dynamo and Inductor.

This module provides specialized logging functionality including:
- Step-based logging that prepends step numbers to log messages
- Progress bar management for compilation phases
- Centralized logger management for Dynamo and Inductor components

The logging system helps track the progress of compilation phases and provides structured
logging output for debugging and monitoring.
"""

import itertools
import logging
from typing import Any, Callable

@@ -1,3 +1,18 @@
"""Metrics collection and management system for Dynamo.

This module provides context managers for gathering and reporting metrics during
compilation and runtime.

It includes two main components:
- MetricsContext: A context manager for collecting metrics during compilation, supporting
  nested contexts and various metric types (counters, sets, key-value pairs)
- RuntimeMetricsContext: A specialized context for runtime metrics collection that doesn't
  require explicit context management

The metrics system enables comprehensive monitoring and analysis of both compilation and
execution performance.
"""

import time
from typing import Any, Callable, Optional
from typing_extensions import TypeAlias

@@ -1,3 +1,17 @@
"""Mutation tracking and dynamic module detection system for Dynamo.

This module provides mechanisms to track and respond to mutations in PyTorch modules
and detect dynamically created or modified modules.

Key components:
- MutationTracker: Tracks mutations to objects and invalidates associated cached code
- GenerationTracker: Tracks module creation timing to identify dynamic instances
- Patching system for nn.Module to detect mutations and dynamic creation

The system ensures that Dynamo's optimizations remain valid by detecting and responding
to runtime changes in module state and structure.
"""

import functools
import weakref
from collections.abc import MutableMapping

@@ -1,4 +1,26 @@
# mypy: allow-untyped-defs

"""
Core graph building functionality for PyTorch's Dynamo system. This module contains
the essential components for constructing and managing FX graphs during compilation:

- OutputGraph: Manages the overall graph construction and compilation process. It owns
  a SubgraphTracer and handles graph compilation, execution, and state management.
  OutputGraph also manages features like graph deduplication, symbolic shape handling,
  and tracking of side effects.

- SubgraphTracer: Handles the actual FX graph construction by tracing Python code.
  It supports advanced features like higher-order operators through nested tracers,
  lifting of free variables, and handling of symbolic shapes.

The module supports key Dynamo features including:
- Higher-order operators through nested SubgraphTracers
- Graph deduplication for optimization
- Symbolic shape handling and propagation
- Side effect tracking and management
- Guard insertion and management
"""

import collections
import contextlib
import copy

@@ -1,3 +1,14 @@
"""
Profile Guided Optimization (PGO) implementation for Dynamo.

This module provides functionality for caching and managing code state profiles
that guide optimization decisions in Dynamo. It implements both local and remote
caching mechanisms for storing profile information across runs, handles profile
merging across distributed ranks, and manages the lifecycle of profile data
during compilation. The profiles track dynamic vs static properties of tensors
and help Dynamo make better specialization decisions.
"""

from __future__ import annotations

import base64

@@ -1,3 +1,17 @@
"""
Dynamo profiling implementation.

This module provides profiling functionality for Dynamo, including:
- ProfileMetrics: Class for collecting and aggregating performance metrics like
  execution time, operator counts, and fusion statistics
- ProfileResult: Class for analyzing and reporting profiling results
- Utilities for tracking missed/uncaptured operations
- Functions for instrumenting FX graphs with profiling capabilities

The profiler helps measure and optimize the performance of Dynamo-compiled code
by tracking both captured and total operations, timing, and graph statistics.
"""

import dataclasses
import os
from typing import Any

@@ -1,3 +1,18 @@
"""
Python execution state recording and replay functionality.

This module provides mechanisms for capturing and replaying Python execution state:

- ModuleRecord: Tracks module access patterns and attribute usage
- DummyModule: Lightweight module substitute for replay
- ExecutionRecord: Manages execution context including globals, locals and builtins
- ExecutionRecorder: Records variable states and module access during execution

The module enables serialization and reproduction of Python execution environments,
particularly useful for debugging and testing frameworks that need to capture
and recreate specific program states.
"""

import dataclasses
from dataclasses import field
from types import CellType, CodeType, ModuleType

@@ -1,5 +1,24 @@
# mypy: allow-untyped-defs

"""
Utilities for reproducing and debugging issues in PyTorch's Dynamo AOT compilation.

This module provides tools and infrastructure for:
1. Generating minimal reproducible test cases ("repros") from failing compilations
2. Analyzing accuracy issues between eager and compiled execution
3. Minifying large models/inputs to isolate problematic patterns
4. Debugging compiler errors and accuracy divergences

The main components include:
- Repro generation: Creates standalone Python files that reproduce compiler issues
- Minification: Reduces large graphs to minimal failing examples
- Accuracy analysis: Compares compiled vs eager execution, with fp64 reference
- Debug tools: Dumps graph state, tracks intermediates, analyzes divergences

This is primarily used by PyTorch developers and researchers to debug issues in
the Dynamo AOT compilation pipeline, particularly for the Inductor backend.
"""

import argparse
import copy
import functools

@@ -1,4 +1,23 @@
# mypy: allow-untyped-defs

"""
Utilities for reproducing and debugging issues in Dynamo after graph capture.

This file provides tools and infrastructure for debugging problems that occur
after Dynamo has captured the graph but before/during backend compilation.
Key components include:

- Minification tools to reduce large graphs to minimal failing examples
- Accuracy testing to validate compiled graph outputs match eager mode
- Repro generation to create standalone reproduction scripts
- Debug backends for capturing and analyzing failures
- Utilities for saving/loading graph states and inputs

The tools here focus specifically on the post-graph-capture stage, making them
useful for debugging backend compilation issues, AOTAutograd problems, and
accuracy discrepancies between compiled and eager execution.
"""

import argparse
import copy
import functools

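The usual entry points into this tooling are the repro config knobs (or their TORCHDYNAMO_REPRO_AFTER / TORCHDYNAMO_REPRO_LEVEL environment variables); semantics paraphrased from the PyTorch troubleshooting docs:

```
import torch._dynamo.config as dynamo_config

dynamo_config.repro_after = "dynamo"  # or "aot" for the after-AOT utilities
dynamo_config.repro_level = 4         # level 4 requests accuracy minification
```
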
@@ -1,4 +1,22 @@
# mypy: allow-untyped-defs

"""
Utilities for debugging and reproducing issues in Ahead-of-Time Inductor (AOTI) compilation.

This file provides tools and utilities for:
- Generating minimal reproducible test cases (minification)
- Handling exported programs and graph modules
- Creating debug repros for AOTI compilation issues
- Supporting both accuracy testing and error reproduction
- Managing configuration and environment for repro cases

The main components include:
- Minification tools to reduce test cases while preserving errors
- Repro generation utilities for exported programs
- Error handling specific to AOTI compilation
- Command-line interface for running and managing repros
"""

import argparse
import functools
import io

@@ -1,4 +1,20 @@
# mypy: allow-untyped-defs

"""
This module provides functionality for resuming Python execution at specific points in code,
primarily used by PyTorch Dynamo for control flow handling and optimization. It implements
bytecode transformation and execution state management to enable:

- Resuming execution at arbitrary points in Python bytecode
- Managing context managers and their state across execution boundaries
- Transforming and generating new code objects with preserved execution state
- Supporting Python 3.11+ exception handling and block management
- Restoring torch function mode stacks and other execution context

The module is critical for PyTorch Dynamo's ability to optimize code while preserving
Python semantics and execution state.
"""

import copy
import dataclasses
import sys

@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs

import collections
import contextlib
import inspect

@@ -55,8 +56,22 @@ def _manual_list_update(list_from, list_to):

class SideEffects:
    """
    Track side effects (list mutation, setattr, etc) that need to be
    Maintain records of mutations and provide methods to apply them during code generation.

    Handles tracking and applying side effects during PyTorch Dynamo compilation,
    maintaining Python semantics by managing mutations, attribute modifications,
    and other side effects that occur during program execution.

    Key responsibilities:
    - Tracks mutations to Python objects, lists, and dictionaries that need to be
      applied after an FX graph is run.
    - Manages attribute modifications and deletions
    - Handles tensor hooks and backward pass state
    - Tracks cell variable mutations and global variable changes
    - Ensures correct ordering and application of side effects after graph execution

    This ensures that optimized code behaves identically to the original Python code with
    respect to object mutations and other side effects.
    """

    id_to_variable: dict[int, VariableTracker]

@@ -1,4 +1,24 @@
# mypy: allow-untyped-defs

"""
This module provides Source classes that track the origins of values in PyTorch Dynamo.
Sources represent where values come from (e.g. local variables, globals, attributes) and
are used for guard generation and code reconstruction during compilation.

The module includes specialized sources for:
- Local variables and synthetic locals
- Global variables and constants
- Object attributes and method calls
- NN module specialization (specialized vs unspecialized)
- Random values and tensor properties
- Default argument handling
- FSDP (Fully Sharded Data Parallel) modules

Sources play a key role in Dynamo's guard system by tracking value origins for
guard generation, and in code reconstruction by providing methods to rebuild
the code needed to recreate values.
"""

import dataclasses
import enum
from typing import Any, Optional, Union

@@ -1,4 +1,29 @@
# mypy: allow-untyped-defs

"""
Core module responsible for converting Python bytecode into TorchDynamo's symbolic execution format.

This module implements the bytecode-level tracing system that allows TorchDynamo to analyze
and transform Python code. It converts Python bytecode instructions into a symbolic format
that tracks the flow of tensors and other values through the program.

Key components:
- InstructionTranslatorBase: Base class for converting bytecode to symbolic execution
- InstructionTranslator: Main translator for function bytecode
- InliningInstructionTranslator: Handles inlining of called functions
- SpeculationLog: Manages state for speculative execution and rollback

The symbolic conversion process handles:
- Control flow (loops, conditionals, etc.)
- Function inlining and call stack management
- Tracking of program values and side effects
- Graph breaks and resumption points
- Exception handling and stack frame management

This is a core part of TorchDynamo's tracing system that enables ahead-of-time
optimization of PyTorch programs.
"""

import collections
import collections.abc
import contextlib

@@ -1,4 +1,23 @@
# mypy: allow-untyped-defs

"""This module implements tensor version operations for Dynamo tracing.

It provides primitives for handling tensor versioning during tracing, particularly in the
context of functionalization where version operations are handled eagerly on fake tensors.

When we functionalize _tensor_version + _unsafe_set_version_counter, the ops disappear from
the traced graph. We run them eagerly on the fake tensors used for tracing, in order to get
past asserts that would fail in autograd.

Why is this ok?
1) Versions on functional tensors do not make any sense since you cannot mutate a functional
   tensor.
2) The whole point of version munging is to trick autograd into doing what we want, and after
   AotAutograd there is no longer any need for these ops.

Note this is similar to how no_grad is handled.
"""

import torch
from torch._prims import _make_prim, RETURN_TYPE
from torch._subclasses import FakeTensorMode

@@ -34,21 +53,6 @@ _unsafe_set_version_counter = _make_prim(
torch.fx.node.has_side_effect(_unsafe_set_version_counter)


"""
When we functionalize _tensor_version + _unsafe_set_version_counter,
the ops disappear from the traced graph. We run them eagerly on the
fake tensors used for tracing, in order to get past asserts that would
fail in autograd.

Why is this ok?
1) Versions on functional tensors don't make any sense since you can't mutate a functional tensor.
2) The whole point of version munging is to trick autograd into doing what we want, and after
AOTAutograd there is no longer any need for these ops.

Note this is similar to how no_grad is handled.
"""


@_tensor_version.py_impl(FunctionalTensorMode)
def _tensor_version_functional(mode, self):
    return self._version

@@ -1,3 +1,13 @@
"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.

This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
It includes:
- A custom TestCase class that handles Dynamo-specific setup/teardown
- Test running utilities with dependency checking
- Automatic reset of Dynamo state between tests
- Proper handling of gradient mode state
"""

import contextlib
import importlib
import logging

@@ -1,4 +1,20 @@
# mypy: allow-untyped-defs

"""Common utilities for testing Dynamo's minifier functionality.

This module provides the base infrastructure for running minification tests in Dynamo.
It includes:
- MinifierTestResult: A dataclass for storing and processing minifier test results
- MinifierTestBase: A base test class with utilities for:
  - Running tests in isolated environments
  - Managing temporary directories and configurations
  - Executing minifier launcher scripts
  - Running and validating reproduction scripts
  - Supporting both compile-time and runtime error testing

The minifier helps reduce failing Dynamo compilations to minimal reproductions.
"""

import dataclasses
import io
import logging

@@ -1,3 +1,18 @@
"""Testing utilities and infrastructure for Dynamo.

This module provides a comprehensive set of testing utilities including:
- Test result collection and validation
- Graph manipulation and comparison tools
- Test case management and execution helpers
- Specialized test decorators for different Python versions and features
- RNG state management
- Compilation counting and monitoring
- Debug utilities for bytecode transformation

The utilities in this module are used across Dynamo's test suite to ensure
consistent testing patterns and proper test isolation.
"""

import contextlib
import dis
import functools

@@ -1,3 +1,16 @@
"""This module contains the core type definitions and protocols used throughout Dynamo.

The types defined here fall into several categories:
- Guard related types (GuardFn, GuardFail, GuardedCode): Used for tracking and managing guards that protect compiled code
- Frame and cache types (FrameState, CacheEntry): Used for managing interpreter frame state and caching
- Callback protocols (DynamoCallbackFn): Define the interface for frame evaluation callbacks
- Hook protocols (DynamoGuardHook, ProfilerStartHook, ProfilerEndHook, BytecodeHook): Define various hook points for
  instrumentation and customization

These types provide the foundational interfaces that enable Dynamo's dynamic compilation and optimization system,
ensuring type safety and clear contracts between different components of the system.
"""

import dataclasses
import types
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union

@@ -1,4 +1,19 @@
# mypy: allow-untyped-defs

"""
Utility functions and classes used throughout the TorchDynamo system.

This module contains a collection of helper utilities used by various parts of Dynamo for:
- Performance metrics collection and reporting
- Compilation timing and debugging
- Graph manipulation and tensor operations
- Runtime guards and checks
- Common data structure operations
- Testing and development tools

This is an internal module that provides shared functionality used across the Dynamo codebase.
"""

from __future__ import annotations

import atexit

@@ -1,3 +1,21 @@
"""
This package implements variable tracking and symbolic execution capabilities for Dynamo,
which are essential for converting Python code into FX graphs. It provides a comprehensive
set of variable types that handle different Python constructs during tracing.

Each variable type (like BuiltinVariable, TensorVariable, NNModuleVariable, etc.) is responsible
for tracking and symbolically executing operations on specific Python objects. This enables
Dynamo to:
- Track the flow of values through Python code
- Maintain correct semantics during graph conversion
- Handle complex Python features like context managers, iterators, and custom objects
- Support both eager and symbolic execution modes

The VariableTracker base class provides the foundation for all variable types, with each
subclass implementing specific behavior for different Python constructs. This modular design
allows Dynamo to accurately trace and optimize Python code while preserving its semantics.
"""

from .base import VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable

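To see this variable tracking in action, here is a small illustrative sketch (not from the patch; print_graph_backend is a made-up name) that uses a custom backend to print the FX graph Dynamo extracts:

import torch

def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo hands the extracted FX graph to the backend; print it and
    # return the unmodified forward so the graph runs as-is.
    gm.graph.print_tabular()
    return gm.forward

@torch.compile(backend=print_graph_backend)
def f(x):
    y = x.sin()              # tracked by a TensorVariable
    return y * len(x.shape)  # len() handled by a BuiltinVariable

f(torch.randn(3, 4))
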
@@ -1,5 +1,20 @@
# mypy: ignore-errors

"""
Core variable tracking functionality for Dynamo. This module defines the fundamental
classes and systems used to track and manage variables during Dynamo's operation.

The module provides:
1. VariableTracker - The base class for tracking variables during compilation
2. MutationType system - Classes for tracking and managing mutations to variables
3. Source type management - Utilities for tracking variable origins and scope
4. Variable state management - Tools for managing variable state and transformations

These components form the foundation of Dynamo's variable handling system,
enabling accurate tracking and transformation of Python code into optimized
computations.
"""

import collections
from enum import Enum
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING

@@ -1,5 +1,24 @@
# mypy: ignore-errors

"""
This module contains classes and utilities for building variable trackers in Dynamo.
Variable trackers are used to convert Python values into symbolic representations
that can be traced and transformed during graph capture.

The key classes are:

- VariableBuilder: Handles source-tracked objects that need guards and proper
  reconstruction in the output graph. Used for inputs, module attributes, etc.

- SourcelessBuilder: Handles ephemeral objects created during tracing that don't
  need source tracking or guards. Used for temporary lists, intermediate values, etc.

Variable trackers enable Dynamo to track the flow of values through the program,
maintain guards for dynamic properties, and reconstruct values in the output graph.
The builders in this module handle converting Python values into appropriate
VariableTracker instances based on their type and usage context.
"""

import abc
import collections
import contextlib

@@ -121,8 +121,13 @@ polyfill_fn_mapping = {

class BuiltinVariable(VariableTracker):
    """
    A VariableTracker that represents a built-in value. A lot of the code
    here assumes it will be a function object.
    A VariableTracker that represents a built-in value (functions and operators).
    A lot of the code here assumes it will be a function object.

    The BuiltinVariable class wraps Python built-in functions (like len, isinstance, etc.)
    and operators (like +, -, *, etc.) to enable symbolic execution during tracing. This allows
    Dynamo to properly handle these operations when converting Python code to FX graphs while
    maintaining correct semantics and enabling optimizations.
    """

    _SENTINEL = object()

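A quick illustrative sketch (not from the patch) of the kind of code this class covers:

import torch

@torch.compile(backend="eager")
def g(x):
    # isinstance, len, and the arithmetic operators below are dispatched
    # through BuiltinVariable rather than run as opaque Python calls.
    if isinstance(x, torch.Tensor) and len(x.shape) == 2:
        return x + 1
    return x - 1

g(torch.ones(2, 2))
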
@@ -1,5 +1,13 @@
# mypy: ignore-errors

"""
Constant and enum variable tracking in Dynamo.

This module is fundamental to Dynamo's ability to track and propagate constant
values during compilation, ensuring proper handling of Python literals and
maintaining type safety through the compilation process.
"""

import operator
from typing import TYPE_CHECKING

@@ -17,6 +25,14 @@ if TYPE_CHECKING:


class ConstantVariable(VariableTracker):
    """
    Variable tracker for Python literals and basic immutable types, with automatic
    routing support for collection types (lists, tuples, sets, etc.).

    The create() method intelligently constructs appropriate variable types for
    nested collections.
    """

    @staticmethod
    def create(value, **kwargs) -> VariableTracker:
        """

@@ -202,6 +218,12 @@ its type to `common_constant_types`.


class EnumVariable(VariableTracker):
    """VariableTracker for enum.Enum and enum.IntEnum instances.

    Provides specialized handling for Python enum types, supporting
    both standard Enum and IntEnum with proper value tracking and comparison.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

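A toy sketch of the routing idea behind create() (illustrative names only, not the real implementation):

# Hypothetical sketch: recurse into collection literals and wrap the
# leaves as constants, routing each level to a suitable variable type.
def create_sketch(value):
    if isinstance(value, (list, tuple)):
        items = [create_sketch(v) for v in value]
        return ("ListVariable" if isinstance(value, list) else "TupleVariable", items)
    return ("ConstantVariable", value)

print(create_sketch((1, [2, 3])))
# ('TupleVariable', [('ConstantVariable', 1),
#                    ('ListVariable', [('ConstantVariable', 2), ('ConstantVariable', 3)])])
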
@@ -1,4 +1,25 @@
# mypy: ignore-errors

"""
This file contains a collection of context manager classes used by Dynamo for tracking
and managing various PyTorch runtime states during graph compilation. These context
managers handle different aspects of PyTorch's execution environment, including:

- Autograd states (grad mode, inference mode)
- CUDA streams and events
- Profiling contexts
- Deterministic algorithms
- Forward/backward AD modes
- SDPA (Scaled Dot Product Attention) kernels
- FSDP (Fully Sharded Data Parallel) states
- AMP (Automatic Mixed Precision) autocast states

The context managers ensure proper state transitions during graph compilation by
tracking enter/exit points and managing cleanup operations. They help maintain
consistency between eager execution and compiled graph behavior by capturing and
restoring state changes.
"""

import dataclasses
import inspect
import sys

@@ -34,7 +55,7 @@ if TYPE_CHECKING:


@dataclasses.dataclass
class ContextMangerState:
class ContextManagerState:
    """
    Mutating `self` in VariableTracker is not allowed because we copy
    them. This is a mutable container pointed to by context managers

@@ -69,7 +90,7 @@ class ContextWrappingVariable(VariableTracker):
        super().__init__(**kwargs)
        self.target_values = target_values
        self.initial_values = initial_values
        self.state = ContextMangerState() if state is None else state
        self.state = ContextManagerState() if state is None else state

    def enter(self, tx):
        self._call_func(tx, self.target_values)

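An illustrative sketch (not from the patch) of a tracked context manager:

import torch

@torch.compile(backend="eager")
def h(x):
    # The no_grad block is tracked by a grad-mode context variable, so the
    # compiled code preserves the enter/exit state transitions.
    with torch.no_grad():
        y = x * 2
    return y + x

h(torch.randn(3, requires_grad=True))
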
@@ -1,5 +1,25 @@
# mypy: ignore-errors

"""
Dictionary-related variable tracking classes for PyTorch Dynamo.

This module implements variable tracking for different types of dictionary-like objects:
- Regular Python dictionaries (dict)
- Ordered dictionaries (collections.OrderedDict)
- Default dictionaries (collections.defaultdict)
- Dictionary views (keys and values)
- Sets and frozensets (implemented internally using dictionaries)

These classes are responsible for tracking dictionary operations during graph compilation,
maintaining proper guards for dictionary mutations and key existence checks. They handle
dictionary creation, modification, key/value access, and view operations while ensuring
correct behavior in the compiled code through appropriate guard installation.

The implementation uses a special _HashableTracker wrapper to handle dictionary keys
while preserving proper aliasing semantics. Sets are implemented as dictionaries with
None values for efficiency and code reuse.
"""

import collections
import functools
import types

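An illustrative sketch (not from the patch) of dict and set tracking:

import torch

@torch.compile(backend="eager")
def f(x):
    # The dict and set below are tracked by these dict-based variable
    # classes; key lookups install guards on key existence.
    cfg = {"scale": 2.0, "bias": 1.0}
    keys = set(cfg.keys())
    if "bias" in keys:
        return x * cfg["scale"] + cfg["bias"]
    return x * cfg["scale"]

f(torch.randn(3))
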
@@ -1,4 +1,25 @@
# mypy: ignore-errors

"""
Distributed computing variable tracking classes for PyTorch Dynamo.

This module implements variable tracking for distributed computing components:
- Process Groups (for collective communication)
- Device Meshes (for distributed tensor sharding)
- Placement Types (for specifying distribution strategies)
- Distributed Tensors and their operations
- Backward hooks for distributed module operations

These classes are responsible for tracking distributed operations during graph
compilation while maintaining proper guards and handling distributed-specific
behaviors. They ensure correct handling of distributed components like process
groups, device meshes, and placement strategies while preserving proper semantics
for distributed tensor operations in the compiled code.

The implementation provides special handling for distributed package availability
checks and proper tracking of distributed state and operations across processes.
"""

import functools
import inspect
from typing import TYPE_CHECKING

@@ -1,5 +1,28 @@
# mypy: ignore-errors

"""
Function-related variable tracking classes for Dynamo's symbolic execution.

This module contains classes that track different types of functions during graph
compilation, including:
- User-defined functions and methods
- Built-in functions and methods
- Wrapped functions (e.g. from decorators)
- Special function types (e.g. functools.partial)
- Triton kernels and related function types

These classes are responsible for:
- Tracking function calls and their arguments
- Managing function closures and cell variables
- Handling function attributes and special methods
- Maintaining guards for function identity and closure contents
- Supporting function inlining and specialization
- Enabling proper symbolic execution of different function types

The variable trackers here work together with the rest of Dynamo to enable
accurate graph capture while handling Python's various function-related behaviors.
"""

import builtins
import functools
import inspect

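An illustrative sketch (not from the patch) of function tracking and inlining:

import functools
import torch

def scale(x, factor):
    return x * factor

@torch.compile(backend="eager")
def f(x):
    # Both the functools.partial object and the lambda's closure are
    # tracked by these variable classes and inlined during tracing.
    double = functools.partial(scale, factor=2)
    add_one = lambda t: t + 1
    return add_one(double(x))

f(torch.randn(3))
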
@@ -1,5 +1,24 @@
# mypy: ignore-errors

"""
This module contains classes and utilities for handling higher-order operators in Dynamo.
It provides functionality for tracing and transforming control flow constructs like
conditions (torch.cond), loops (torch.while_loop), maps (torch.ops.higher_order.map),
and other higher-order operations.

The module includes specialized VariableTracker classes for different types of
higher-order operations, along with utilities for:
- Speculating and capturing subgraphs
- Managing control flow
- Handling autograd function applications
- Supporting function transformations
- Processing activation checkpoints

These classes work together to enable Dynamo to correctly trace and compile code
containing complex control flow patterns and higher-order functions while preserving
their semantic behavior.
"""

import contextlib
import copy
import functools

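An illustrative sketch, assuming a build where torch.cond is available:

import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

@torch.compile(backend="eager")
def f(x):
    # torch.cond is a higher-order operator: both branches are speculated
    # and captured as subgraphs instead of causing a graph break.
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

f(torch.randn(4))
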
@@ -1,5 +1,20 @@
# mypy: ignore-errors

"""
This module provides iterator-related variable tracking functionality for Dynamo.
It implements variable classes for handling Python iterators and itertools functions
during symbolic execution and tracing.

The module includes:
- Base iterator variable classes for tracking iterator state
- Implementations of built-in iterators (zip, map, filter)
- Support for itertools functions (product, accumulate, combinations, etc.)
- Mutation tracking and reconstruction capabilities for iterator operations

These classes integrate with Dynamo's variable tracking system to enable proper
handling of iterator operations during code transformation and optimization.
"""

import itertools
import operator
import sys

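An illustrative sketch (not from the patch) of iterator tracking:

import itertools
import torch

@torch.compile(backend="eager")
def f(xs):
    # zip and itertools.accumulate are modeled by iterator variables, so
    # their lazy state is tracked during tracing instead of graph breaking.
    total = torch.zeros(3)
    for a, b in zip(xs, itertools.accumulate(xs)):
        total = total + a * b
    return total

f([torch.randn(3) for _ in range(4)])
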
@@ -1,5 +1,21 @@
# mypy: ignore-errors

"""
Variable tracking implementations for list-like data structures in Dynamo.

This module provides specialized variable tracking for various collection types:
- Lists and list subclasses (including torch.nn.ModuleList, ParameterList)
- Tuples and named tuples
- Ranges and slices
- collections.deque
- torch.Size with special proxy handling

The implementations support both mutable and immutable collections, iteration,
and common sequence operations. Each collection type has a dedicated Variable
class that handles its unique behaviors while integrating with Dynamo's
variable tracking system.
"""

import collections
import inspect
import operator

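An illustrative sketch (not from the patch) of list and torch.Size tracking:

import torch

@torch.compile(backend="eager")
def f(x):
    b, n = x.shape      # x.shape is a torch.Size with special proxy handling
    xs = [x, x + 1]     # tracked as a mutable list variable
    xs.append(x * 2)
    return xs[-1].reshape(n, b)

f(torch.randn(2, 3))
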
@@ -1,4 +1,22 @@
# mypy: ignore-errors

"""
This module contains miscellaneous variable tracker implementations for various Python types
and features used in Dynamo's symbolic execution. These classes help track and propagate
information about different kinds of variables during graph capture.

Key classes include:
- SuperVariable: Handles super() calls and method resolution
- ExceptionVariable: Tracks exception objects
- RandomVariable: Manages random number generators
- GetAttrVariable: Tracks attribute access
- MethodWrapperVariable: Handles method wrappers
- PythonModuleVariable: Tracks Python modules
- NumpyVariable: Handles numpy functions and types
- StringFormatVariable: Manages string formatting
- DebuggingVariable: Handles print and logging
"""

import dataclasses
import functools
import inspect

@@ -1,5 +1,28 @@
# mypy: ignore-errors

"""
This module implements variable tracking for PyTorch nn.Module instances during Dynamo tracing.

It provides specialized handling for different types of nn.Module instances through several key classes:

- NNModuleVariable: Handles instance-specific module tracing, specializing on module id() and placing
  parameters directly on the torch.fx.GraphModule. This creates one graph per module instance.

- UnspecializedNNModuleVariable: Provides class-level module tracing, treating nn.Modules like other
  user-defined objects and passing parameters as inputs to the FX graph. This creates one graph per
  module class.

- UnspecializedBuiltinNNModuleVariable: Specifically handles built-in PyTorch modules (e.g. nn.Linear)
  with appropriate optimizations.

- FSDPManagedNNModuleVariable: Special handling for FSDP-wrapped modules with modified guarding behavior
  and parameter handling.

The module integrates with Dynamo's broader tracing functionality to handle module method calls,
parameter access, hooks, and other nn.Module behaviors while maintaining proper scoping and guarding
of module state.
"""

import functools
import inspect
import itertools

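An illustrative sketch (not from the patch); whether the module is specialized per instance or per class depends on Dynamo's configuration:

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))

# Compiling routes the module through the nn.Module variable classes above.
net = torch.compile(Net(), backend="eager")
net(torch.randn(2, 4))
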
@@ -1,5 +1,27 @@
# mypy: ignore-errors

"""
This module implements variable tracking for PyTorch optimizers during Dynamo tracing.

The OptimizerVariable class provides specialized handling for optimizer instances by:
- Optimizing the tracing of expensive optimizer initialization
- Managing optimizer state and parameter group tracking
- Handling tensor sources and guards for optimizer state tensors
- Supporting CUDA graph execution through static tensor address management
- Providing special handling for parameter gradients and optimizer state tensors

Key features include:
- Efficient initialization tracing via _init_group optimization
- Automatic marking of optimizer state tensors as static for CUDA graphs
- Proper source tracking for parameter groups, gradients, and state tensors
- Guard installation for optimizer state structure
- Support for both CPU and GPU tensor handling
- Cleanup of static tensor references via finalizers

The module integrates with Dynamo's broader tracing system while providing
optimizer-specific optimizations and safety guarantees.
"""

import logging
import weakref
from typing import TYPE_CHECKING

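An illustrative sketch (not from the patch) of a compiled optimizer step:

import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(2, 4)).sum().backward()

@torch.compile(backend="eager")
def step():
    # opt is tracked by OptimizerVariable, which specializes the expensive
    # _init_group path and guards the optimizer state structure.
    opt.step()

step()
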
@@ -1,5 +1,26 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs

"""
This module implements variable tracking for TorchScript objects during Dynamo tracing.

The TorchScriptObjectVariable class provides specialized handling for TorchScript
objects with strong safety guarantees by:
- Enforcing method-call-only access to prevent unsafe attribute manipulation
- Converting graph breaks into hard errors via _raise_hard_error_if_graph_break
- Proper proxy and source tracking for TorchScript method calls
- Integration with higher-order operators for method call handling

Key safety features:
- Strict validation that only method calls are allowed (no direct attribute access)
- Immediate error reporting for potentially unsafe operations
- Proper source tracking for debugging and guard installation
- Safe handling of TorchScript object method calls through torchbind

The module ensures that TorchScript objects are handled safely during tracing
by limiting operations to known-safe patterns and failing fast for unsafe usage.
"""

import functools

import torch

@@ -1,5 +1,22 @@
# mypy: ignore-errors

"""
This module contains variable tracker classes for handling tensors and tensor-related operations in Dynamo.

The main class is TensorVariable which represents torch.Tensor inputs and intermediate values in the FX graph.
It handles tensor operations, method calls, and maintains metadata about tensor properties like dtype, device, etc.

Other key classes include:
- SymNodeVariable: Represents symbolic scalars (int/float/bool) used for size computation and unspecialized values
- NumpyNdarrayVariable: Handles numpy array interop through torch._numpy
- UnspecializedPythonVariable: Represents unspecialized Python numeric values as 1-element tensors
- TensorSubclassVariable: Handles tensor subclasses with __torch_function__ overrides
- UntypedStorageVariable: Represents tensor storage objects
- DataPtrVariable: Handles tensor data pointer operations

These classes work together to track tensor operations and properties during Dynamo's tracing process.
"""

import functools
import inspect
import logging

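An illustrative sketch (not from the patch) of SymNodeVariable under dynamic shapes:

import torch

@torch.compile(backend="eager", dynamic=True)
def f(x):
    n = x.shape[0]  # a symbolic int, tracked as a SymNodeVariable
    return x * n

f(torch.randn(5))
f(torch.randn(7))   # the symbolic size generally avoids a recompile
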
@@ -1,5 +1,33 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs

"""
This module implements variable tracking for torch functions and operations during Dynamo tracing.

It provides classes to handle different types of torch operations:

TorchInGraphFunctionVariable: Handles torch.* functions that should be captured in the FX graph.
Provides special handling for constant folding, tensor methods, and torch function overrides.
Manages complex cases like out= variants and parameter construction.

TorchCtxManagerClassVariable: Handles torch context managers like torch.no_grad(), autocast, etc.
Provides implementations for entering/exiting these contexts during tracing.

DispatchKeySetVariable: Represents torch.DispatchKeySet for managing dispatch keys and
device-specific operations during tracing.

The module includes special handling for:
- Constant folding of pure functions
- Tensor method calls
- torch.nn.Parameter construction
- __torch_function__ overrides
- Context manager state tracking
- Device and dtype management

This is a core part of Dynamo's tracing system, translating torch operations into
traceable graph nodes while preserving correct semantics and handling edge cases.
"""

import functools
import inspect
import logging

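An illustrative sketch (not from the patch) of trace-time resolution of a pure torch query:

import torch

@torch.compile(backend="eager")
def f(x):
    # x's dtype is static metadata during tracing, so this check is
    # resolved at trace time and only the taken branch is captured.
    if torch.is_floating_point(x):
        return x / 2
    return x * 2

f(torch.randn(4))
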
@@ -1,5 +1,34 @@
# mypy: ignore-errors

"""TorchDynamo support for __torch_function__ tensor subclasses.

This module implements support for tensor subclasses with __torch_function__ overrides.
A tensor subclass instance is represented as a TensorWithTFOverrideVariable, which handles
dispatching __torch_function__ on attribute accesses, method calls, and torch API calls.

Unsupported features:
- Triggering __torch_function__ on tensor subclass non-tensor custom attributes
- Graph breaking on mutating guardable tensor properties within a __torch_function__ context
  (can cause excessive recompiles in certain cases)
- Matching exact eager behavior of ignoring __torch_function__ objects in non-tensor
  argument positions of Torch API calls

Supported features:
- Static method implementations of __torch_function__ on custom objects (triggers on torch
  API calls with the object as any argument)
- Triggering __torch_function__ on torch API calls with tensor subclass arguments
- __torch_function__ calls on base tensor attribute access and method calls for tensor
  subclass instances
- Matches dispatch ordering behavior of eager __torch_function__ with subclass/object
  arguments in any position

See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w
for more information on the design.

To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses
in torch/_dynamo/config.py
"""

import collections
import contextlib
import functools

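An illustrative sketch of the supported pattern, assuming torch._dynamo.config.traceable_tensor_subclasses is the registration point as the docstring states:

import torch
import torch._dynamo.config

class ScaledTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        out = super().__torch_function__(func, types, args, kwargs)
        # Scale only the result of torch.add; pass everything else through.
        return out * 2 if func is torch.add else out

# Register the subclass so Dynamo will trace its __torch_function__.
torch._dynamo.config.traceable_tensor_subclasses.add(ScaledTensor)

@torch.compile(backend="eager")
def f(x):
    return torch.add(x, 1)

f(torch.randn(3).as_subclass(ScaledTensor))
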
@@ -43,27 +72,6 @@ if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


# [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues):
# At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches
# __torch_function__ on attribute accesses, method calls, and torch API calls.
# The following is not supported:
# - triggering __torch_function__ on tensor subclass non-tensor custom attributes
# - graph breaking on mutating guardable tensor properties within a __torch_function__ context, this can cause
# excessive recompiles in certain degenerate cases
# - Matching the exact eager behavior of *ignoring* __torch_function__ objects in non-tensor argument positions of Torch API calls

# The following is supported:
# - static method impls of __torch_function__ on custom objects; this will trigger on torch API calls with the object as
# any argument
# - triggering __torch_function__ on torch API calls with tensor subclass arguments
# - __torch_function__ calls on base tensor attribute access and method calls for tensor subclass instances
# - matches the dispatch ordering behavior of eager __torch_function__ with subclass/object argumnents in any argument position

# See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w
# for more information on the design.

# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py

bin_ops = [
    operator.pow,
    operator.mul,

@@ -1,5 +1,26 @@
# mypy: ignore-errors

"""
This module contains variable classes for handling user-defined objects in Dynamo's tracing system.

The key classes are:
- UserDefinedVariable: Base class for representing custom Python objects
- UserDefinedClassVariable: Handles Python class objects/types
- UserDefinedObjectVariable: Fallback class for instance objects, with support for method calls,
  attribute access, and other Python object behaviors.
- Specialized subclasses for common patterns:
  - UserDefinedDictVariable: For dict subclasses
  - UserDefinedTupleVariable: For tuple subclasses
  - FrozenDataClassVariable: Special handling of frozen dataclasses
  - MutableMappingVariable: For collections.abc.MutableMapping subclasses

Dynamo specializes to VariableTracker subclasses like FrozenDataClassVariable if available; if no
subclass qualifies, it falls back to UserDefinedObjectVariable.

These classes help Dynamo track and handle arbitrary Python objects during tracing,
maintaining proper semantics while enabling optimizations where possible.
"""

import collections
import contextlib
import dataclasses

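An illustrative sketch (not from the patch) of frozen dataclass handling:

import dataclasses
import torch

@dataclasses.dataclass(frozen=True)
class Config:
    scale: float
    bias: float

@torch.compile(backend="eager")
def f(x, cfg):
    # cfg is handled by FrozenDataClassVariable; an object fitting no
    # specialized subclass falls back to UserDefinedObjectVariable.
    return x * cfg.scale + cfg.bias

f(torch.randn(3), Config(scale=2.0, bias=0.5))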