[dynamo] Expand _nonvar_fields names (#111749)

This should be a small compile time optimization, since we won't need to
walk these fields in apply().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111749
Approved by: https://github.com/yanboliang
This commit is contained in:
Jason Ansel 2023-10-21 16:50:41 -07:00 committed by PyTorch MergeBot
parent 2b2b6caf8f
commit c65c0682b1
2 changed files with 10 additions and 3 deletions

View File

@ -107,7 +107,14 @@ class VariableTracker(metaclass=HasPostInit):
"""
# fields to leave unmodified in apply()
_nonvar_fields = ["value"]
_nonvar_fields = {
"value",
"guards",
"source",
"mutable_local",
"recursively_contains",
"user_code_variable_name",
}
@staticmethod
def propagate(*vars: List[List["VariableTracker"]]):

View File

@ -74,7 +74,7 @@ def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module):
class NNModuleVariable(VariableTracker):
_nonvar_fields = ["module_type", "module_key"]
_nonvar_fields = {"module_type", "module_key", *VariableTracker._nonvar_fields}
def __init__(self, module_type: type, module_key: str, **kwargs):
super().__init__(**kwargs)
@ -639,7 +639,7 @@ class NNModuleVariable(VariableTracker):
class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
_nonvar_fields = ["value_type"]
_nonvar_fields = {"value_type", *UserDefinedObjectVariable._nonvar_fields}
"""
The above class will specialize on the id() of a module and place