[Inductor] Inplacing with Donated Buffer (#140113)

Currently, Inductor does not in-place update a buffer if it is an input buffer, because we don't know whether the input will be used by other functions.

A donated buffer provides the additional guarantee that an input buffer will not be used by other functions, so we can in-place update a donated buffer when possible.
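
As a concrete illustration, here is a minimal sketch adapted from the `test_donated_buffer_inplace` test added in this PR (the `cuda` device and the tensor shapes are arbitrary choices, not requirements): the saved matmul activation is a donated buffer of the backward graph, so Inductor may reuse its storage in place, which shows up as an `in_out_ptr` argument in the generated backward kernel.

```python
import torch
from torch._inductor.utils import run_and_get_code

device = "cuda"  # assumes a GPU; the inplacing shows up in generated Triton kernels
layer_norm = torch.nn.LayerNorm(256, device=device)

def fn(inp, weight):
    # the matmul output is saved for backward and does not alias any fwd
    # input/output, so it is marked as a donated buffer of the backward graph
    return layer_norm(inp @ weight)

inp = torch.randn(32, 50, 256, requires_grad=True, device=device)
weight = torch.randn(256, 256, requires_grad=True, device=device)
fn_opt = torch.compile(fn)

_, code = run_and_get_code(lambda a, b: fn_opt(a, b).sum().backward(), inp, weight)
# the backward kernel updates the donated buffer in place instead of allocating
assert any("in_out_ptr" in kernel for kernel in code)
```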

[Dashboard](https://hud.pytorch.org/benchmark/torchbench/inductor_dynamic?dashboard=torchinductor&startTime=Mon,%2011%20Nov%202024%2018:14:36%20GMT&stopTime=Mon,%2018%20Nov%202024%2018:14:36%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(a100)&lBranch=bf/donated-buffer-inplace&lCommit=5df0769c00e6f9000caeb10fd5cbf0b165f69c2a&rBranch=main&rCommit=2b39a8db7741b816b03677a9c6fec1af05640dee)

![image](https://github.com/user-attachments/assets/f19d961f-7973-418e-9de8-5c2a97950478)
![image](https://github.com/user-attachments/assets/df3bd6a9-58b8-4e8a-8397-9e3b1de9adfe)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140113
Approved by: https://github.com/eellison
Boyuan Feng 2024-11-26 17:19:50 +00:00 committed by PyTorch MergeBot
parent 3ef031909f
commit eecc8e362c
7 changed files with 180 additions and 17 deletions


@@ -5198,6 +5198,31 @@ class CommonTemplate:
         if self.device != "cpu":
             assertGeneratedKernelCountEqual(self, 1)

+    def test_matmul_layer_norm(self):
+        batch_size = 32
+        seq_length = 50
+        hidden_size = 256
+        inp = torch.randn(
+            batch_size,
+            seq_length,
+            hidden_size,
+            requires_grad=True,
+            device=self.device,
+        )
+        weight = torch.randn(
+            hidden_size, hidden_size, requires_grad=True, device=self.device
+        )
+        layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
+
+        def foo(inp, weight):
+            matmul_output = inp @ weight
+            final_output = layer_norm(matmul_output)
+            return final_output
+
+        self.common(foo, (inp, weight), check_lowp=False)
+
     def test_transpose_add(self):
         def fn(a, b):
             return a.t() + b
@@ -12855,6 +12880,43 @@ if HAS_GPU and not TEST_WITH_ASAN:
             self.assertTrue(len(re.findall(r"in_out_ptr\d+", code)) > 0)
             self.assertEqual(fn_opt(*inps), fn(*inps))

+        def test_donated_buffer_inplace(self):
+            batch_size = 32
+            seq_length = 50
+            hidden_size = 256
+            inp = torch.randn(
+                batch_size,
+                seq_length,
+                hidden_size,
+                requires_grad=True,
+                device=self.device,
+            )
+            weight = torch.randn(
+                hidden_size, hidden_size, requires_grad=True, device=self.device
+            )
+            layer_norm = torch.nn.LayerNorm(hidden_size, device=self.device)
+
+            def fn(inp, weight):
+                matmul_output = inp @ weight
+                final_output = layer_norm(matmul_output)
+                return final_output
+
+            fn_opt = torch.compile(fn)
+
+            def wrapper(inp, weight):
+                return fn_opt(inp, weight).sum().backward()
+
+            _, code = run_and_get_code(wrapper, inp, weight)
+
+            if config.cpp_wrapper:
+                # when using cpp_wrapper, backward triton code is in code[2]
+                self.assertTrue("in_out_ptr" in code[2])
+            else:
+                # when not using cpp_wrapper, backward triton code is in code[1]
+                self.assertTrue("in_out_ptr" in code[1])
+
     class RNNTest(TestCase):
         device_type = GPU_TYPE


@@ -2120,7 +2120,11 @@ class PythonWrapperCodegen(CodeGen):
     def codegen_allocation(self, buffer: ir.Buffer):
         name = buffer.get_name()

-        if name in V.graph.removed_buffers or name in self.allocated:
+        if (
+            name in V.graph.removed_buffers
+            or name in self.allocated
+            or isinstance(buffer, ir.DonatedBuffer)
+        ):
             return
         self.allocated.add(name)
         if isinstance(
@@ -2174,7 +2178,12 @@ class PythonWrapperCodegen(CodeGen):
         name = input_buffer.get_name()
         return not (
             name in V.graph.removed_buffers
-            or name in V.graph.graph_inputs
+            or (
+                name in V.graph.graph_inputs
+                and not isinstance(
+                    V.graph.graph_inputs_original[name], ir.DonatedBuffer
+                )
+            )
             or name in V.graph.constants
             or name in V.graph.torchbind_constants
             or name in V.graph.never_reuse_buffers


@@ -832,6 +832,20 @@ class CUDAGraphNode:
             if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
         ]

+        # (depth, offset) of live tensors which are alias of previous graph outputs
+        self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [
+            (
+                self._is_alias_of_live_recorded_tensor(t)
+                if isinstance(t, torch.Tensor)
+                else None
+            )
+            for t in inputs
+        ]
+
+        # when replay, preserve the liveness of an input if it AliasesPriorGraphOutput
+        # and also aliases an output of the current CUDAGraphNode
+        self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
+
         self.static_input_idxs: List[int] = list(
             set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
         )
@@ -1038,11 +1052,11 @@ class CUDAGraphNode:
         self.check_static_inputs_are_stable(new_inputs)

         self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
-        new_inputs.clear()

         self.run_graph()

         outputs = self.reconstruct_outputs()
+        new_inputs.clear()

         if config.triton.fast_path_cudagraph_asserts:
             self.debug_check_invariants_after_invocation()
@@ -1261,6 +1275,12 @@ class CUDAGraphNode:
             path_ref = self._is_alias_of_live_recorded_tensor(o)
             if path_ref is not None:
                 self._mark_prior_graph_output_as_aliased(path_ref)
+
+                for idx, inp_path_ref in enumerate(
+                    self.live_cudagraph_managed_path_refs
+                ):
+                    if path_ref == inp_path_ref:
+                        self.preserved_aliased_inputs[idx] = True
                 self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref))
                 continue
@@ -1667,7 +1687,8 @@ class CUDAGraphNode:
         # this invocation. it is too late to check after we've replayed the graph,
         # because we would have already written over their memory.
         for idx in self.cudagraph_managed_idxs:
-            inputs[idx] = None  # type: ignore[call-overload]
+            if not self.preserved_aliased_inputs[idx]:
+                inputs[idx] = None  # type: ignore[call-overload]

         torch._check(
             self._check_liveness(


@@ -74,6 +74,7 @@ from .exc import (
 )
 from .ir import (
     Constant,
+    DonatedBuffer,
     FixedLayout,
     get_device_type,
     InputBuffer,
@@ -103,6 +104,7 @@ from .utils import (
     convert_shape_to_inductor,
     gather_origins,
     get_cloned_parameter_buffer_name,
+    get_donated_idxs,
     get_sympy_Expr_dtype,
     is_same_tensor,
     maybe_get_suppress_shape_guards_ctx,
@@ -486,6 +488,11 @@ class GraphLowering(torch.fx.Interpreter):
         # state used by for Kernel.workspace
         self.workspace_id = itertools.count()

+        # track the current placeholder index that we are processing
+        self.placeholder_idx = -1
+
+        self.bw_donated_idxs = get_donated_idxs()
+
     def has_feature(
         self,
         device: Union[torch._inductor.ir.IRNode, device, None],
@@ -963,6 +970,7 @@ class GraphLowering(torch.fx.Interpreter):
     def placeholder(
         self, target: str, args: Tuple[object], kwargs: Dict[str, object]  # type: ignore[override]
     ) -> Union[Expr, TensorBox, None]:
+        self.placeholder_idx += 1
         example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
         target = self.qualify_name(target)
         if isinstance(example, SymTypes):
@@ -993,13 +1001,27 @@ class GraphLowering(torch.fx.Interpreter):
             sizes, strides = self.static_sizes_strides(example)
         else:
             sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]
-        # TODO(jansel): handle input aliasing
-        tensor = TensorBox.create(
-            InputBuffer(
-                name=target,
-                layout=FixedLayout(example.device, example.dtype, sizes, strides),
-            )
-        )
+
+        if (
+            self.is_backward
+            and self.bw_donated_idxs
+            and self.placeholder_idx in self.bw_donated_idxs
+        ):
+            tensor = TensorBox.create(
+                DonatedBuffer(
+                    name=target,
+                    layout=FixedLayout(example.device, example.dtype, sizes, strides),
+                )
+            )
+        else:
+            # TODO(jansel): handle input aliasing
+            tensor = TensorBox.create(
+                InputBuffer(
+                    name=target,
+                    layout=FixedLayout(example.device, example.dtype, sizes, strides),
+                )
+            )
+
         self.graph_inputs[target] = tensor
         self.graph_input_names.append(target)
         self.graph_inputs_original[target] = tensor.data.data


@@ -3832,6 +3832,16 @@ class InputBuffer(Buffer):
         return 1


+class DonatedBuffer(InputBuffer):
+    """
+    Represents a donated buffer which is a saved tensor that is not alias to any
+    fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace
+    reuse the input tensor memory during backward since it might be used in another
+    function. However, donated buffer can be inplace reused during backward
+    to save memory.
+    """
+
+
 class ConstantBuffer(InputBuffer):
     override_device: Optional[torch.device] = None


@@ -125,10 +125,16 @@ class SchedulerBuffer:
             hasattr(V.kernel, "args")
             and self.get_name() in V.kernel.inplace_update_buffers
         ):
+            input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
+            input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
+            if input_buffer_name in self.scheduler.name_to_donated_buffer:
+                input_buffer = self.scheduler.name_to_donated_buffer[
+                    input_buffer_name
+                ].node
+            else:
+                input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
             V.graph.wrapper_code.codegen_inplace_reuse(
-                self.scheduler.name_to_buf[
-                    V.kernel.inplace_update_buffers[self.get_name()]
-                ].node,
+                input_buffer,
                 self.node,
             )
         else:
@@ -163,6 +169,11 @@ class SchedulerBuffer:
         return self.node.get_mutation_names()


+@dataclasses.dataclass
+class SchedulerDonatedBuffer(SchedulerBuffer):
+    defining_op: Optional[BaseSchedulerNode] = None  # type: ignore[assignment]
+
+
 class BaseSchedulerNode:
     group: Tuple[torch.device, Tuple[Tuple[sympy.Expr, ...], ...]]
     read_writes: dependencies.ReadWrites
@@ -442,9 +453,12 @@ class BaseSchedulerNode:
                 continue

             for read in self.read_writes.reads:
-                input_buf: Optional[SchedulerBuffer] = self.scheduler.name_to_buf.get(
-                    read.name
-                )
+                input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
+                if read.name in self.scheduler.name_to_donated_buffer:
+                    input_buf = self.scheduler.name_to_donated_buffer[read.name]
+                else:
+                    input_buf = self.scheduler.name_to_buf.get(read.name)
+
                 if (
                     input_buf
                     and V.graph.wrapper_code.can_reuse(input_buf, self)
@@ -470,7 +484,8 @@ class BaseSchedulerNode:
                     ),
                 )
                 and not (
-                    isinstance(
+                    input_buf.defining_op
+                    and isinstance(
                         input_buf.defining_op.node,
                         (ir.FallbackKernel, ir.MultiOutput),
                     )
@@ -1801,6 +1816,9 @@ class Scheduler:
         for node in self.nodes:
             node.prune_deps()

+        self.name_to_donated_buffer: Dict[
+            str, SchedulerDonatedBuffer
+        ] = self.get_donated_buffers()
         self.name_to_node: Dict[str, BaseSchedulerNode] = {
             n.get_name(): n for n in self.nodes
         }
@@ -1884,6 +1902,17 @@ class Scheduler:
             }
         )

+    def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
+        name_to_donated_buf = {}
+        for name in V.graph.graph_inputs_original:
+            if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
+                name_to_donated_buf[name] = SchedulerDonatedBuffer(
+                    self,
+                    V.graph.graph_inputs_original[name],
+                    defining_op=None,
+                )
+        return name_to_donated_buf
+
     @property
     def current_device(self) -> Optional[torch.device]:
         return V.graph.current_device
@@ -2160,6 +2189,9 @@ class Scheduler:
             for buf in node.get_outputs():
                 buf.set_users(name_to_users[buf.get_name()].items)

+        for name in self.name_to_donated_buffer:
+            self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
+
     def dead_node_elimination(self) -> None:
         """
         Remove any nodes without users


@@ -2200,3 +2200,10 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True):
     if cls is None:
         return wrap
     return wrap(cls)
+
+
+def get_donated_idxs() -> Optional[List[int]]:
+    tracing_context = torch._guards.TracingContext.try_get()
+    if tracing_context is not None and tracing_context.fw_metadata:
+        return tracing_context.fw_metadata.bw_donated_idxs
+    return None