mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[dynamo] Expand _nonvar_fields names (#111749)
This should be a small compile time optimization, since we won't need to walk these fields in apply(). Pull Request resolved: https://github.com/pytorch/pytorch/pull/111749 Approved by: https://github.com/yanboliang
This commit is contained in:
parent
2b2b6caf8f
commit
c65c0682b1
|
|
@ -107,7 +107,14 @@ class VariableTracker(metaclass=HasPostInit):
|
|||
"""
|
||||
|
||||
# fields to leave unmodified in apply()
|
||||
_nonvar_fields = ["value"]
|
||||
_nonvar_fields = {
|
||||
"value",
|
||||
"guards",
|
||||
"source",
|
||||
"mutable_local",
|
||||
"recursively_contains",
|
||||
"user_code_variable_name",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def propagate(*vars: List[List["VariableTracker"]]):
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
|
|||
|
||||
|
||||
class NNModuleVariable(VariableTracker):
|
||||
_nonvar_fields = ["module_type", "module_key"]
|
||||
_nonvar_fields = {"module_type", "module_key", *VariableTracker._nonvar_fields}
|
||||
|
||||
def __init__(self, module_type: type, module_key: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -639,7 +639,7 @@ class NNModuleVariable(VariableTracker):
|
|||
|
||||
|
||||
class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
||||
_nonvar_fields = ["value_type"]
|
||||
_nonvar_fields = {"value_type", *UserDefinedObjectVariable._nonvar_fields}
|
||||
|
||||
"""
|
||||
The above class will specialize on the id() of a module and place
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user