Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Fixes #118795

This is a graph-breaking partial fix for #120914. We still need *actual* module parametrization tracing support, but at least it doesn't blow up hard now.

**Background**: Module parametrization injects a property as the module parameter attribute. The property calls an `nn.Module` whose `forward()` takes in the original module parameter and returns a reparametrized module parameter. Example:

```python
class MyParametrization(nn.Module):
    def forward(self, X):
        # This reparametrization just negates the original parameter value
        return -X

m = nn.Linear(...)
p = MyParametrization()
register_parametrization(m, "weight", p)

# Accessing the "weight" attribute will invoke p's forward() on m's original
# weight and return the output as the new weight. m.weight is now an injected
# property that does the above instead of an actual Parameter. This property
# is defined in torch/nn/utils/parametrize.py.
m.weight

# NB: Parametrization changes the module type
# (e.g. torch.nn.utils.parametrize.ParametrizedLinear).
print(type(m))
```

**Problem 1**: Dynamo has special tracing rules for things in `torch.nn`. Parametrizing a module changes the type of the module and of the parametrized attribute, so these rules now wrongly affect tracing here. To fix this:
* For parametrized modules, call `convert_to_unspecialized()` to restart analysis where Dynamo starts inlining the module.

**Problem 2**: The issue seen in #118795 is that Dynamo sees a dynamically constructed tensor when `m.weight` is accessed and adds it to its `tensor_weakref_to_sizes_strides` cache during fake-ification. This tensor is also made a graph input, since it's a module parameter. When guards are created for this module parameter input, the guard logic calls `m.weight` again and tries to look the result up in the cache, but this is a different tensor now, giving the `KeyError` symptom. To fix this:
* Replace Dynamo's `tensor_weakref_to_sizes_strides` cache with an `input_source_to_sizes_strides` cache.
  * This cache was originally introduced in #100128.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121041
Approved by: https://github.com/anijain2305
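The cache-keying fix above can be sketched in plain Python. This is a minimal, hedged illustration (the class and variable names here are hypothetical stand-ins, not Dynamo's real internals): a parametrized attribute is a property that recomputes a fresh object on every access, so a cache keyed on the identity of the returned object misses on the next lookup, while a cache keyed on the input *source* (the stable access path) still hits.

```python
# Hypothetical sketch: names below are illustrative, not Dynamo's actual code.

class Negate:
    """Stand-in for a parametrization module: forward(X) -> -X."""
    def forward(self, x):
        # Builds and returns a brand-new object on every call.
        return [-v for v in x]

class ParametrizedModule:
    """Stand-in for a module whose `weight` is an injected property."""
    def __init__(self, weight):
        self._original_weight = weight
        self._parametrization = Negate()

    @property
    def weight(self):
        # Like the property injected by torch/nn/utils/parametrize.py:
        # recomputed on every access, yielding a different object each time.
        return self._parametrization.forward(self._original_weight)

m = ParametrizedModule([1.0, 2.0, 3.0])

# Identity-keyed cache (analogue of tensor_weakref_to_sizes_strides):
w_first = m.weight
identity_cache = {id(w_first): "sizes/strides"}
w_second = m.weight                      # a fresh object, different identity
assert id(w_second) not in identity_cache  # cache miss: the KeyError symptom

# Source-keyed cache (analogue of input_source_to_sizes_strides), keyed by
# the access path rather than the object produced by that access:
source_cache = {"m.weight": "sizes/strides"}
assert "m.weight" in source_cache          # stable key: hit on every re-access
```

The design point is that the *source* ("how the input was reached") is stable across re-tracing, whereas the tensor produced by a parametrization property is not, so the source is the only safe cache key for guard creation.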
Files in this directory:

| File |
|---|
| __init__.py |
| base.py |
| builder.py |
| builtin.py |
| constant.py |
| ctx_manager.py |
| dicts.py |
| distributed.py |
| functions.py |
| higher_order_ops.py |
| iter.py |
| lazy.py |
| lists.py |
| misc.py |
| nn_module.py |
| optimizer.py |
| sdpa.py |
| tensor.py |
| torch_function.py |
| torch.py |
| user_defined.py |