Populate self.export in InstructionTranslatorBase (#88508)

Summary:

This is a followup to https://github.com/pytorch/pytorch/pull/88354/files#diff-622913fdb49db90d6f3a8ab225b4badb7996023e6498e9f7c6d03fe9f32d0986R836

Reference to self.export got added to InstructionTranslatorBase (i.e. STORE_ATTR) but self.export is populated only for InstructionTranslators.

Here's an example failure

```
   File "/scratch/williamwen/work/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch/torch/_dynamo/symbolic_convert.py", line 844, in STORE_ATTR
    not self.export
AttributeError: 'InliningInstructionTranslator' object has no attribute 'export'
```

Let's populate with the base class with export flag.

Test Plan:

python test/dynamo/test_export_mutations.py
python test/dynamo/test_export.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88508
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Mergen Nachin 2022-11-04 13:03:00 -07:00 committed by PyTorch MergeBot
parent afdc2283ef
commit dde9affeaa
2 changed files with 20 additions and 0 deletions

View File

@ -54,6 +54,21 @@ class MutationExportTests(torch._dynamo.test_case.TestCase):
self.check_failure_on_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_positive_4(self):
# Mutating attribute with an inline function
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
def add(self, a, b):
return a + b
def forward(self, x):
self.a = self.add(1, 2) * self.add(3, 4)
return x.sum() + self.a
self.check_failure_on_export(Foo(), torch.Tensor(3, 2))
def test_module_attribute_mutation_violation_negative_1(self):
# Mutating attribute with a Tensor type inside __init__ but
# not in forward()

View File

@ -1328,6 +1328,7 @@ class InstructionTranslatorBase(object):
symbolic_locals: Dict[str, VariableTracker],
symbolic_globals: Dict[str, VariableTracker],
f_code: types.CodeType,
export: bool,
):
super(InstructionTranslatorBase, self).__init__()
@ -1357,6 +1358,8 @@ class InstructionTranslatorBase(object):
self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)
# Stack of module being parsed, current nn.module is at the end of ordered dict
self.nn_module_stack: Dict[str, str] = {}
# Flag to indicate whether tracing is used for export.
self.export = export
if fake_tensors_available:
with torch._subclasses.FakeTensorMode(
@ -1407,6 +1410,7 @@ class InstructionTranslator(InstructionTranslatorBase):
# A global var is inserted only after a STORE_GLOBAL happens to it
symbolic_globals=collections.OrderedDict(),
f_code=f_code,
export=export,
)
self.one_graph: bool = one_graph
self.export = export
@ -1634,6 +1638,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
instructions=cleaned_instructions(code),
code_options={k: getattr(code, k) for k in dir(code)},
f_code=code,
export=parent.export,
)
self.parent = parent
self.symbolic_result = None