Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Fix https://github.com/pytorch/pytorch/issues/99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`.

My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in https://github.com/pytorch/pytorch/pull/101357). But we need a way to push cells created by the inlined function that were not present in the caller — `InlinedClosureVariable` is used to differentiate these cells from other cells.

Adding this test causes an error though (EDIT: this test is not relevant to this PR and instead just reveals that `cond` with Python side effects is still broken):

```python
def test_closure_out_of_scope_cell_with_cond(self):
    from functorch.experimental.control_flow import cond

    cell1 = torch.rand(3, 3)
    cell2 = torch.rand(3, 3)
    orig3 = torch.rand(3, 3)

    def test(x):
        cell3 = orig3.clone()

        def then():
            nonlocal cell3
            cell3 += cell1
            return cell3

        def els():
            nonlocal cell3
            cell3 += cell2
            return cell3

        return cond(x > 0, then, els, [])

    opt_fn = torch._dynamo.optimize("eager")(test)
    result1 = opt_fn(1)
    self.assertTrue(torch.allclose(result1, orig3 + cell1))
    result2 = opt_fn(-1)
    self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2))
```

```
Traceback (most recent call last):
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond
    result1 = opt_fn(1)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn
    return fn(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn
    return fn(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
    return _compile(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform
    tracer.run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run
    super().run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper
    return inner_fn(self, inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function
    (false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch
    ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph
    output = f.call_function(tx, args, {})
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function
    return tx.inline_user_function_return(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_
    tracer.run()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
    and self.step()
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function
    proxy = tx.output.create_proxy(
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy
    return self.current_tracer.create_proxy(*args, **kwargs)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy
    new_arg = self.lift_tracked_freevar_to_input(arg)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input
    self.parent.lift_tracked_freevar_to_input(proxy)
  File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input
    assert (
AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer

from user code:
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test
    return cond(x > 0, then, els, [])
  File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els
    cell3 += cell2
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104222
Approved by: https://github.com/jansel

| Name | | |
|---|---|---|
| .. | ||
| __init__.py | ||
| base.py | ||
| builder.py | ||
| builtin.py | ||
| constant.py | ||
| ctx_manager.py | ||
| dicts.py | ||
| functions.py | ||
| lists.py | ||
| misc.py | ||
| nn_module.py | ||
| optimizer.py | ||
| tensor.py | ||
| torch.py | ||
| user_defined.py | ||