pytorch/torch/_dynamo/variables
Joel Schlosser 07b618e2d4 Graph break cleanly in Dynamo for module parametrization (#121041)
Fixes #118795

This is a graph-breaking partial fix for #120914. We still need *actual* module parametrization tracing support, but at least Dynamo no longer blows up hard.

**Background**: Module parametrization replaces a module's parameter attribute with an injected property that calls an `nn.Module` whose `forward()` takes the original parameter and returns the reparametrized parameter.
Example:
```
import torch.nn as nn
from torch.nn.utils.parametrize import register_parametrization

class MyParametrization(nn.Module):
    def forward(self, X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original weight and return the output as the new weight.
# m.weight here is now an injected property that does the above instead of an actual Parameter.
# This property is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type (e.g. torch.nn.utils.parametrize.ParametrizedLinear)
print(type(m))
```
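The injection mechanism can be sketched in plain, torch-free Python (a simplified model of what `torch/nn/utils/parametrize.py` does; the `Linear`, `negate`, and `register_parametrization` definitions below are illustrative stand-ins, not the real implementation):

```python
class Linear:
    def __init__(self):
        self.weight = 1.0  # stand-in for an nn.Parameter

def negate(x):  # stand-in for MyParametrization.forward
    return -x

def register_parametrization(module, name, parametrization):
    # Stash the original value and remove the instance attribute so the
    # class-level property below takes over attribute lookup.
    original = module.__dict__.pop(name)

    def getter(self):
        # Every access recomputes the parametrized value from the original.
        return parametrization(original)

    # Swap the module's class for a fresh subclass carrying the property,
    # mirroring how the real code yields e.g. ParametrizedLinear.
    parametrized_cls = type(
        f"Parametrized{type(module).__name__}",
        (type(module),),
        {name: property(getter)},
    )
    module.__class__ = parametrized_cls

m = Linear()
register_parametrization(m, "weight", negate)
print(m.weight)          # -1.0, recomputed on each access
print(type(m).__name__)  # ParametrizedLinear
```

This also shows why the module's type changes: the property has to live on a class, so a new subclass is created per parametrized module.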

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes both the module's type and the parametrized attribute, so these rules now wrongly affect how the parametrized module is traced. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo sees a dynamically constructed tensor when `m.weight` is accessed and adds it to its `tensor_weakref_to_sizes_strides` cache during fake-ification. Since it is a module parameter, this tensor also becomes a graph input. When guards are later created for this input, the guard logic accesses `m.weight` again and tries to look the result up in the cache, but the property returns a different tensor this time, producing the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with a `input_source_to_sizes_strides` cache.
    * This cache was originally introduced in #100128.
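The core of the `KeyError` can be reproduced with a torch-free sketch: because the injected property recomputes the parameter on every access, each access yields a distinct object, so a cache keyed on the tensor's identity (which is what a weakref-keyed cache amounts to) misses on the second lookup, while a cache keyed on the input's source is stable. The `Module` class, `id()`-based cache, and `"L['m'].weight"` source string below are illustrative simplifications, not Dynamo's actual data structures:

```python
class Module:
    @property
    def weight(self):
        # The injected property: a fresh object on every access.
        return [-1.0]

m = Module()
first, second = m.weight, m.weight
assert first == second      # equal values...
assert first is not second  # ...but distinct objects

# Keying on object identity (a simplification of a weakref-keyed cache)
# records metadata for `first` but cannot find it via `second`:
identity_cache = {id(first): (len(first),)}  # stand-in for sizes/strides
assert id(second) not in identity_cache      # -> the KeyError symptom

# Keying on the input's *source* is stable across accesses:
source_cache = {"L['m'].weight": (len(first),)}
assert "L['m'].weight" in source_cache       # guard lookup succeeds
```

The source describes *where* the input came from rather than *which* object it currently is, which is exactly the invariant guard creation needs here.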

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121041
Approved by: https://github.com/anijain2305
2024-03-26 23:44:51 +00:00
..
__init__.py Teach dynamo about torch.func.jvp (#119926) 2024-03-22 20:25:47 +00:00
base.py [dynamo] Remove uses of raise unimplemented (#122136) 2024-03-22 19:29:58 +00:00
builder.py Graph break cleanly in Dynamo for module parametrization (#121041) 2024-03-26 23:44:51 +00:00
builtin.py dynamo: handle DTensor.device_mesh.device_type (#118803) 2024-03-22 14:42:22 +00:00
constant.py [dynamo] Remove VariableTracker.parents_tracker (#122058) 2024-03-19 04:23:24 +00:00
ctx_manager.py Teach dynamo about torch.func.jvp (#119926) 2024-03-22 20:25:47 +00:00
dicts.py [dynamo] Add missing _nonvar_fields annotations (#122219) 2024-03-20 07:53:18 +00:00
distributed.py Don't create world pg variable out of thin air when rewriting c10d collectives (#122561) 2024-03-26 20:12:08 +00:00
functions.py Don't create world pg variable out of thin air when rewriting c10d collectives (#122561) 2024-03-26 20:12:08 +00:00
higher_order_ops.py [dynamo] Remove uses of raise unimplemented (#122136) 2024-03-22 19:29:58 +00:00
iter.py [dynamo] Remove uses of raise unimplemented (#122136) 2024-03-22 19:29:58 +00:00
lazy.py [dynamo] Replace VariableTracker.apply with visit/realize_all (#122218) 2024-03-20 07:53:18 +00:00
lists.py [dynamo] Fix list comparison ops (#122559) 2024-03-25 07:03:23 +00:00
misc.py [dynamo] Add missing _nonvar_fields annotations (#122219) 2024-03-20 07:53:18 +00:00
nn_module.py Graph break cleanly in Dynamo for module parametrization (#121041) 2024-03-26 23:44:51 +00:00
optimizer.py [dynamo] Add missing _nonvar_fields annotations (#122219) 2024-03-20 07:53:18 +00:00
sdpa.py [dynamo] Refactor reconstruct() not to return anything (#120150) 2024-02-17 17:13:41 +00:00
tensor.py dynamo: support placement kwargs for DTensor.to_local() (#119947) 2024-03-22 14:42:27 +00:00
torch_function.py [dynamo] Optimize SourcelessBuilder (#122063) 2024-03-19 04:23:30 +00:00
torch.py Teach dynamo about torch.func.jvp (#119926) 2024-03-22 20:25:47 +00:00
user_defined.py [dynamo] Add HASATTR guard for UserDefinedObject attrs (#122555) 2024-03-24 03:41:58 +00:00