diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index c2a7731fee0..24dc8de068e 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5198,6 +5198,31 @@ class CommonTemplate:
         if self.device != "cpu":
             assertGeneratedKernelCountEqual(self, 1)
 
+    def test_matmul_layer_norm(self):
+        batch_size = 32
+        seq_length = 50
+        hidden_size = 256
+
+        inp = torch.randn(
+            batch_size,
+            seq_length,
+            hidden_size,
+            requires_grad=True,
+            device=self.device,
+        )
+        weight = torch.randn(
+            hidden_size, hidden_size, requires_grad=True, device=self.device
+        )
+
+        layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
+
+        def foo(inp, weight):
+            matmul_output = inp @ weight
+            final_output = layer_norm(matmul_output)
+            return final_output
+
+        self.common(foo, (inp, weight), check_lowp=False)
+
     def test_transpose_add(self):
         def fn(a, b):
             return a.t() + b
@@ -12855,6 +12880,43 @@ if HAS_GPU and not TEST_WITH_ASAN:
             self.assertTrue(len(re.findall(r"in_out_ptr\d+", code)) > 0)
             self.assertEqual(fn_opt(*inps), fn(*inps))
 
+        def test_donated_buffer_inplace(self):
+            batch_size = 32
+            seq_length = 50
+            hidden_size = 256
+
+            inp = torch.randn(
+                batch_size,
+                seq_length,
+                hidden_size,
+                requires_grad=True,
+                device=self.device,
+            )
+            weight = torch.randn(
+                hidden_size, hidden_size, requires_grad=True, device=self.device
+            )
+
+            layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
+
+            def fn(inp, weight):
+                matmul_output = inp @ weight
+                final_output = layer_norm(matmul_output)
+                return final_output
+
+            fn_opt = torch.compile(fn)
+
+            def wrapper(inp, weight):
+                return fn_opt(inp, weight).sum().backward()
+
+            _, code = run_and_get_code(wrapper, inp, weight)
+
+            if config.cpp_wrapper:
+                # when using cpp_wrapper, backward triton code is in code[2]
+                self.assertTrue("in_out_ptr" in code[2])
+            else:
+                # when not using cpp_wrapper, backward triton code is in code[1]
+                self.assertTrue("in_out_ptr" in code[1])
+
 
     class RNNTest(TestCase):
         device_type = GPU_TYPE
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 16d093852ad..15fe504446b 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -2120,7 +2120,11 @@ class PythonWrapperCodegen(CodeGen):
 
     def codegen_allocation(self, buffer: ir.Buffer):
         name = buffer.get_name()
-        if name in V.graph.removed_buffers or name in self.allocated:
+        if (
+            name in V.graph.removed_buffers
+            or name in self.allocated
+            or isinstance(buffer, ir.DonatedBuffer)
+        ):
             return
         self.allocated.add(name)
         if isinstance(
@@ -2174,7 +2178,12 @@ class PythonWrapperCodegen(CodeGen):
         name = input_buffer.get_name()
         return not (
             name in V.graph.removed_buffers
-            or name in V.graph.graph_inputs
+            or (
+                name in V.graph.graph_inputs
+                and not isinstance(
+                    V.graph.graph_inputs_original[name], ir.DonatedBuffer
+                )
+            )
             or name in V.graph.constants
             or name in V.graph.torchbind_constants
             or name in V.graph.never_reuse_buffers
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index fed24b5d69e..26d11e767f9 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -832,6 +832,20 @@ class CUDAGraphNode:
             if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
         ]
 
+        # (depth, offset) of live tensors which are aliases of previous graph outputs
+        self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [
+            (
+                self._is_alias_of_live_recorded_tensor(t)
+                if isinstance(t, torch.Tensor)
+                else None
+            )
+            for t in inputs
+        ]
+
+        # when replaying, preserve the liveness of an input if it is an AliasesPriorGraphOutput
+        # and also aliases an output of the current CUDAGraphNode
+        self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
+
         self.static_input_idxs: List[int] = list(
             set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
         )
@@ -1038,11 +1052,11 @@ class CUDAGraphNode:
         self.check_static_inputs_are_stable(new_inputs)
 
         self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
-        new_inputs.clear()
 
         self.run_graph()
 
         outputs = self.reconstruct_outputs()
+        new_inputs.clear()
 
         if config.triton.fast_path_cudagraph_asserts:
             self.debug_check_invariants_after_invocation()
@@ -1261,6 +1275,12 @@ class CUDAGraphNode:
                 path_ref = self._is_alias_of_live_recorded_tensor(o)
                 if path_ref is not None:
                     self._mark_prior_graph_output_as_aliased(path_ref)
+
+                    for idx, inp_path_ref in enumerate(
+                        self.live_cudagraph_managed_path_refs
+                    ):
+                        if path_ref == inp_path_ref:
+                            self.preserved_aliased_inputs[idx] = True
                     self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
                     continue
 
@@ -1667,7 +1687,8 @@ class CUDAGraphNode:
         # this invocation. it is too late to check after we've replayed the graph,
         # because we would have already written over their memory.
         for idx in self.cudagraph_managed_idxs:
-            inputs[idx] = None  # type: ignore[call-overload]
+            if not self.preserved_aliased_inputs[idx]:
+                inputs[idx] = None  # type: ignore[call-overload]
 
         torch._check(
             self._check_liveness(
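
Note: for context on the graph.py changes that follow, the donated indices are
computed by AOTAutograd and stored on its forward metadata; Inductor reads them
back through the ambient TracingContext (via the get_donated_idxs helper added
to torch/_inductor/utils.py at the end of this patch). Below is a minimal,
hedged sketch of that lookup, not part of the patch; it only returns something
useful inside an active compilation where the context is populated:

    import torch

    ctx = torch._guards.TracingContext.try_get()
    if ctx is not None and ctx.fw_metadata is not None:
        # Indices of backward-graph placeholders that are saved tensors not
        # aliased to any fwd input, fwd user output, or bwd output.
        print(ctx.fw_metadata.bw_donated_idxs)
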
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index d1f3d34eda4..b3d4e2c0c33 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -74,6 +74,7 @@ from .exc import (
 )
 from .ir import (
     Constant,
+    DonatedBuffer,
     FixedLayout,
     get_device_type,
     InputBuffer,
@@ -103,6 +104,7 @@ from .utils import (
     convert_shape_to_inductor,
     gather_origins,
     get_cloned_parameter_buffer_name,
+    get_donated_idxs,
     get_sympy_Expr_dtype,
     is_same_tensor,
     maybe_get_suppress_shape_guards_ctx,
@@ -486,6 +488,11 @@ class GraphLowering(torch.fx.Interpreter):
         # state used by for Kernel.workspace
         self.workspace_id = itertools.count()
 
+        # track the current placeholder index that we are processing
+        self.placeholder_idx = -1
+
+        self.bw_donated_idxs = get_donated_idxs()
+
     def has_feature(
         self,
         device: Union[torch._inductor.ir.IRNode, device, None],
@@ -963,6 +970,7 @@ class GraphLowering(torch.fx.Interpreter):
     def placeholder(
         self, target: str, args: Tuple[object], kwargs: Dict[str, object]  # type: ignore[override]
     ) -> Union[Expr, TensorBox, None]:
+        self.placeholder_idx += 1
         example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
         target = self.qualify_name(target)
         if isinstance(example, SymTypes):
@@ -993,13 +1001,27 @@ class GraphLowering(torch.fx.Interpreter):
             sizes, strides = self.static_sizes_strides(example)
         else:
             sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]
-        # TODO(jansel): handle input aliasing
-        tensor = TensorBox.create(
-            InputBuffer(
-                name=target,
-                layout=FixedLayout(example.device, example.dtype, sizes, strides),
+
+        if (
+            self.is_backward
+            and self.bw_donated_idxs
+            and self.placeholder_idx in self.bw_donated_idxs
+        ):
+            tensor = TensorBox.create(
+                DonatedBuffer(
+                    name=target,
+                    layout=FixedLayout(example.device, example.dtype, sizes, strides),
+                )
             )
-        )
+        else:
+            # TODO(jansel): handle input aliasing
+            tensor = TensorBox.create(
+                InputBuffer(
+                    name=target,
+                    layout=FixedLayout(example.device, example.dtype, sizes, strides),
+                )
+            )
+
         self.graph_inputs[target] = tensor
         self.graph_input_names.append(target)
         self.graph_inputs_original[target] = tensor.data.data
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 6379d5c99d3..e279f6534bb 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3832,6 +3832,16 @@ class InputBuffer(Buffer):
         return 1
 
 
+class DonatedBuffer(InputBuffer):
+    """
+    Represents a donated buffer: a saved tensor that is not aliased to any
+    fwd input, fwd user output, or bwd output. We generally cannot reuse an
+    input tensor's memory in-place during backward, since the tensor might
+    be used in another function. A donated buffer, however, can be reused
+    in-place during backward to save memory.
+    """
+
+
 class ConstantBuffer(InputBuffer):
     override_device: Optional[torch.device] = None
 
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 90b561379a5..c69ed657455 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -125,10 +125,16 @@ class SchedulerBuffer:
             hasattr(V.kernel, "args")
             and self.get_name() in V.kernel.inplace_update_buffers
         ):
+            input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
+            input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
+            if input_buffer_name in self.scheduler.name_to_donated_buffer:
+                input_buffer = self.scheduler.name_to_donated_buffer[
+                    input_buffer_name
+                ].node
+            else:
+                input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
             V.graph.wrapper_code.codegen_inplace_reuse(
-                self.scheduler.name_to_buf[
-                    V.kernel.inplace_update_buffers[self.get_name()]
-                ].node,
+                input_buffer,
                 self.node,
             )
         else:
@@ -163,6 +169,11 @@ class SchedulerBuffer:
         return self.node.get_mutation_names()
 
 
+@dataclasses.dataclass
+class SchedulerDonatedBuffer(SchedulerBuffer):
+    defining_op: Optional[BaseSchedulerNode] = None  # type: ignore[assignment]
+
+
 class BaseSchedulerNode:
     group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
     read_writes: dependencies.ReadWrites
@@ -442,9 +453,12 @@ class BaseSchedulerNode:
                 continue
 
             for read in self.read_writes.reads:
-                input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
-                    read.name
-                )
+                input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
+                if read.name in self.scheduler.name_to_donated_buffer:
+                    input_buf = self.scheduler.name_to_donated_buffer[read.name]
+                else:
+                    input_buf = self.scheduler.name_to_buf.get(read.name)
+
                 if (
                     input_buf
                     and V.graph.wrapper_code.can_reuse(input_buf, self)
@@ -470,7 +484,8 @@ class BaseSchedulerNode:
                         ),
                     )
                     and not (
-                        isinstance(
+                        input_buf.defining_op
+                        and isinstance(
                             input_buf.defining_op.node,
                             (ir.FallbackKernel, ir.MultiOutput),
                         )
@@ -1801,6 +1816,9 @@ class Scheduler:
         for node in self.nodes:
             node.prune_deps()
 
+        self.name_to_donated_buffer: Dict[
+            str, SchedulerDonatedBuffer
+        ] = self.get_donated_buffers()
         self.name_to_node: Dict[str, BaseSchedulerNode] = {
             n.get_name(): n for n in self.nodes
         }
@@ -1884,6 +1902,17 @@ class Scheduler:
             }
         )
 
+    def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
+        name_to_donated_buf = {}
+        for name in V.graph.graph_inputs_original:
+            if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
+                name_to_donated_buf[name] = SchedulerDonatedBuffer(
+                    self,
+                    V.graph.graph_inputs_original[name],
+                    defining_op=None,
+                )
+        return name_to_donated_buf
+
     @property
     def current_device(self) -> Optional[torch.device]:
         return V.graph.current_device
@@ -2160,6 +2189,9 @@ class Scheduler:
             for buf in node.get_outputs():
                 buf.set_users(name_to_users[buf.get_name()].items)
 
+        for name in self.name_to_donated_buffer:
+            self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
+
     def dead_node_elimination(self) -> None:
         """
         Remove any nodes without users
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 95ee3b74dfa..74ff87f4fa6 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -2200,3 +2200,10 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True):
     if cls is None:
         return wrap
     return wrap(cls)
+
+
+def get_donated_idxs() -> Optional[List[int]]:
+    tracing_context = torch._guards.TracingContext.try_get()
+    if tracing_context is not None and tracing_context.fw_metadata:
+        return tracing_context.fw_metadata.bw_donated_idxs
+    return None
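
Note: taken together, an end-to-end check of the new behavior mirrors
test_donated_buffer_inplace above: compile a matmul + LayerNorm function, run
backward, and look for in_out_ptr arguments in the generated backward kernel,
which indicate that a donated buffer was reused in place. A hedged standalone
sketch, not part of the patch; it assumes a GPU build and that AOTAutograd's
donated-buffer analysis is enabled (the torch._functorch.config.donated_buffer
flag, where present):

    import torch
    from torch._inductor.utils import run_and_get_code

    device = "cuda"
    layer_norm = torch.nn.LayerNorm(256, device=device)

    def fn(inp, weight):
        return layer_norm(inp @ weight)

    fn_opt = torch.compile(fn)

    inp = torch.randn(32, 50, 256, device=device, requires_grad=True)
    weight = torch.randn(256, 256, device=device, requires_grad=True)

    def wrapper(inp, weight):
        return fn_opt(inp, weight).sum().backward()

    # The backward Triton code is typically in code[1] (code[2] under
    # cpp_wrapper); an in_out_ptr argument there means in-place reuse.
    _, code = run_and_get_code(wrapper, inp, weight)
    assert any("in_out_ptr" in c for c in code)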