mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159530 Approved by: https://github.com/eellison
This commit is contained in:
parent
47a1db823d
commit
3be70dc30e
261
test/inductor/test_segmented_tree.py
Normal file
261
test/inductor/test_segmented_tree.py
Normal file
|
|
@ -0,0 +1,261 @@
|
||||||
|
# Owner(s): ["module: inductor"]
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, strategies as st
|
||||||
|
|
||||||
|
from torch._inductor.codegen.segmented_tree import SegmentedTree
|
||||||
|
from torch.testing._internal.common_utils import run_tests
|
||||||
|
|
||||||
|
|
||||||
|
# Helper functions for operations
|
||||||
|
def max_op(a, b):
    """Summary operation for the tree under test: the larger of two values."""
    # Same tie-breaking as the builtin ``max``: keep ``a`` unless ``b`` is strictly greater.
    return b if b > a else a
|
||||||
|
|
||||||
|
|
||||||
|
def add_op(a, b):
    """Update operation for the tree under test: add ``b`` onto ``a``."""
    total = a + b
    return total
|
||||||
|
|
||||||
|
|
||||||
|
# Naive implementations for reference
|
||||||
|
def naive_range_max(arr, start, end):
    """Reference implementation: maximum of ``arr`` over the inclusive range [start, end]."""
    window = arr[start : end + 1]
    return max(window)
|
||||||
|
|
||||||
|
|
||||||
|
def naive_range_update(arr, start, end, value):
    """Reference implementation: add ``value`` in place to ``arr[start..end]`` (inclusive)."""
    position = start
    while position <= end:
        arr[position] += value
        position += 1
|
||||||
|
|
||||||
|
|
||||||
|
# Strategies for hypothesis testing
|
||||||
|
# Hypothesis strategy: lists of 1 to 50 integers, each drawn from [1, 100].
positive_integers = st.lists(
    st.integers(max_value=100, min_value=1),
    max_size=50,
    min_size=1,
)
|
||||||
|
|
||||||
|
|
||||||
|
def valid_range_indices(array_length):
    """Build a strategy yielding ``(start, end)`` with ``0 <= start <= end < array_length``."""
    last_index = array_length - 1
    index = st.integers(min_value=0, max_value=last_index)
    # Draw two independent indices and put them in ascending order.
    return st.tuples(index, index).map(lambda pair: tuple(sorted(pair)))
|
||||||
|
|
||||||
|
|
||||||
|
# Hypothesis strategy: increments applied by range updates, each in [1, 50].
update_values = st.integers(max_value=50, min_value=1)
|
||||||
|
|
||||||
|
|
||||||
|
# Basic construction and initialization tests
|
||||||
|
def test_basic_construction():
    """A freshly built max-tree reports the largest initial value over the whole range."""
    initial = [1, 3, 5, 7, 9]
    seg_tree = SegmentedTree(initial, add_op, max_op, 0)
    whole_range_max = seg_tree.summarize_range(0, 4)
    assert whole_range_max == 9
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_array():
    """Constructing a tree from an empty list is rejected with ValueError."""
    no_values = []
    with pytest.raises(ValueError):
        SegmentedTree(no_values, add_op, max_op, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# Property-based tests
|
||||||
|
@given(values=positive_integers)
def test_max_query_matches_naive(values):
    """Every inclusive sub-range query agrees with the naive slice-based maximum."""
    seg_tree = SegmentedTree(values, add_op, max_op, 0)
    count = len(values)

    # Enumerate all (start, end) pairs with start <= end.
    every_range = ((lo, hi) for lo in range(count) for hi in range(lo, count))
    for start, end in every_range:
        expected = naive_range_max(values, start, end)
        actual = seg_tree.summarize_range(start, end)
        assert actual == expected, (
            f"Range [{start}:{end}] expected {expected}, got {actual}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@given(values=positive_integers, range_indices=st.data(), update_value=update_values)
def test_range_update(values, range_indices, update_value):
    """After one range update, all sub-range maxima match a naively updated copy."""
    # The tree is built first; the copy is what the naive reference mutates.
    seg_tree = SegmentedTree(values, add_op, max_op, 0)
    reference = values.copy()
    count = len(values)

    # The range depends on the drawn list length, hence the interactive draw.
    start, end = range_indices.draw(valid_range_indices(count))

    seg_tree.update_range(start, end, update_value)
    naive_range_update(reference, start, end, update_value)

    # Cross-check every inclusive sub-range against the reference.
    for i, j in ((lo, hi) for lo in range(count) for hi in range(lo, count)):
        expected = naive_range_max(reference, i, j)
        actual = seg_tree.summarize_range(i, j)
        assert actual == expected, (
            f"After update, range [{i}:{j}] expected {expected}, got {actual}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@given(values=positive_integers, range_data=st.data())
def test_multiple_operations(values, range_data):
    """Interleave five randomly chosen queries/updates and compare with the naive model."""
    reference = values.copy()
    seg_tree = SegmentedTree(values, add_op, max_op, 0)

    operation_count = 5
    for _ in range(operation_count):
        # Draw order matters for hypothesis replay: operation kind, then range, then value.
        operation_type = range_data.draw(st.sampled_from(["query", "update"]))
        start, end = range_data.draw(valid_range_indices(len(values)))

        if operation_type != "query":
            update_value = range_data.draw(update_values)
            seg_tree.update_range(start, end, update_value)
            naive_range_update(reference, start, end, update_value)
            continue

        expected = naive_range_max(reference, start, end)
        actual = seg_tree.summarize_range(start, end)
        assert actual == expected, (
            f"Range query [{start}:{end}] expected {expected}, got {actual}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_element_ranges():
    """Querying the degenerate range [i, i] returns the i-th initial value."""
    values = [1, 3, 5, 7, 9]
    seg_tree = SegmentedTree(values, add_op, max_op, 0)

    for i, expected in enumerate(values):
        assert seg_tree.summarize_range(i, i) == expected, (
            f"Single element range at index {i} failed"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_array_range():
    """Whole-array query is correct both before and after a whole-array update."""
    values = [1, 3, 5, 7, 9]
    seg_tree = SegmentedTree(values, add_op, max_op, 0)
    last = len(values) - 1

    # Query spanning the entire array.
    assert seg_tree.summarize_range(0, last) == max(values)

    # Shift every element by the same amount and query again.
    update_value = 10
    seg_tree.update_range(0, last, update_value)
    expected = max(v + update_value for v in values)
    assert seg_tree.summarize_range(0, last) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_boundary_conditions():
    """Queries touching the first/last one or two elements return the right maximum."""
    values = [1, 3, 5, 7, 9]
    seg_tree = SegmentedTree(values, add_op, max_op, 0)
    last = len(values) - 1

    # First element only.
    assert seg_tree.summarize_range(0, 0) == values[0]

    # Last element only.
    assert seg_tree.summarize_range(last, last) == values[-1]

    # First two elements.
    assert seg_tree.summarize_range(0, 1) == max(values[0:2])

    # Last two elements.
    assert seg_tree.summarize_range(last - 1, last) == max(values[-2:])
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_ranges():
    """Both query and update reject a range whose start exceeds its end."""
    values = [1, 3, 5, 7, 9]
    seg_tree = SegmentedTree(values, add_op, max_op, 0)

    # Each callable passes start > end and must raise ValueError.
    reversed_range_calls = (
        lambda: seg_tree.summarize_range(3, 2),
        lambda: seg_tree.update_range(4, 2, 10),
    )
    for call in reversed_range_calls:
        with pytest.raises(ValueError):
            call()
|
||||||
|
|
||||||
|
|
||||||
|
def test_out_of_bounds():
    """Negative indices and indices >= len(values) raise ValueError for query and update."""
    values = [1, 3, 5, 7, 9]
    seg_tree = SegmentedTree(values, add_op, max_op, 0)
    n = len(values)

    # Queries: negative start, negative end, end == n, and both past the end.
    bad_queries = [(-1, 3), (0, -1), (0, n), (n, n + 1)]
    for start, end in bad_queries:
        with pytest.raises(ValueError):
            seg_tree.summarize_range(start, end)

    # Updates: negative start, and end == n.
    bad_updates = [(-1, 3, 10), (0, n, 10)]
    for start, end, value in bad_updates:
        with pytest.raises(ValueError):
            seg_tree.update_range(start, end, value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_overlapping_updates():
    """Two overlapping range updates compose exactly like the naive element-wise model."""
    values = [1, 3, 5, 7, 9]
    reference = values.copy()
    seg_tree = SegmentedTree(values, add_op, max_op, 0)

    # (start, end, delta): indices [0, 1, 2] get +5, then [1, 2, 3] get +3.
    overlapping = [(0, 2, 5), (1, 3, 3)]
    for start, end, delta in overlapping:
        seg_tree.update_range(start, end, delta)
        naive_range_update(reference, start, end, delta)

    # Cross-check every inclusive sub-range.
    count = len(values)
    for i, j in ((lo, hi) for lo in range(count) for hi in range(lo, count)):
        expected = naive_range_max(reference, i, j)
        actual = seg_tree.summarize_range(i, j)
        assert actual == expected, (
            f"After overlapping updates, range [{i}:{j}] expected {expected}, got {actual}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_updates_and_queries():
    """Run a fixed script of updates and queries, checking against the naive model throughout."""
    values = [2, 4, 6, 8, 10, 12, 14]
    naive_values = values.copy()
    tree = SegmentedTree(values, add_op, max_op, 0)
    count = len(values)

    # Scripted operations: ("update", start, end, delta) or ("query", start, end).
    operations = [
        ("update", 1, 3, 5),  # +5 on indices [1, 2, 3]
        ("query", 0, 4),  # indices [0, 1, 2, 3, 4]
        ("update", 2, 5, 3),  # +3 on indices [2, 3, 4, 5]
        ("query", 1, 3),  # indices [1, 2, 3]
        ("update", 0, 6, 2),  # +2 on the entire array
        ("query", 0, 6),  # the entire array
        ("query", 3, 5),  # indices [3, 4, 5]
    ]

    for kind, *params in operations:
        if kind != "update":
            # Query step: compare a single range with the reference.
            start, end = params
            expected = naive_range_max(naive_values, start, end)
            assert tree.summarize_range(start, end) == expected, (
                f"Query [{start}:{end}] expected {expected}, got {tree.summarize_range(start, end)}"
            )
            continue

        # Update step: apply to both models, then verify the whole tree state.
        start, end, value = params
        tree.update_range(start, end, value)
        naive_range_update(naive_values, start, end, value)

        for i, j in ((lo, hi) for lo in range(count) for hi in range(lo, count)):
            expected = naive_range_max(naive_values, i, j)
            actual = tree.summarize_range(i, j)
            assert actual == expected, (
                f"After update ({start}, {end}, {value}), query [{i}:{j}] expected {expected}, got {actual}"
            )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_tests()
|
||||||
|
|
@ -13754,6 +13754,45 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||||
has_lowered = not re.search(r"repeat_interleave.Tensor", code)
|
has_lowered = not re.search(r"repeat_interleave.Tensor", code)
|
||||||
self.assertEqual(has_lowered, can_lower)
|
self.assertEqual(has_lowered, can_lower)
|
||||||
|
|
||||||
|
@staticmethod
def _is_triggering_buffer_reuse(fn, *inputs):
    """Return True when toggling ``allow_buffer_reuse`` changes the generated code for ``fn``.

    ``fn`` is run once with buffer reuse allowed and once with it disallowed;
    the two generated sources are compared after normalizing the "AOT ID" line.
    """
    generated = {}
    for allow_reuse in (True, False):
        with config.patch(allow_buffer_reuse=allow_reuse):
            _, (code,) = run_and_get_code(fn, *inputs)
        # Overwrite the AOT ID text so it cannot make the two sources differ.
        generated[allow_reuse] = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code)
    return generated[True] != generated[False]
|
||||||
|
|
||||||
|
def test_allow_reuse_disable_if_exceed_peak(self):
    """Buffer reuse must not kick in when it would span the program's memory peak."""

    @torch.compile
    def fn(inp):  # live: 1*N^2
        row_mean = inp.mean(-1)  # live: 1*N^2 + N
        centered_sq = (inp - row_mean) ** 2  # live: 2*N^2 + N
        # live: 3*N^2 (!!) -- this matmul is the peak, so nothing can be reused across it
        product = centered_sq @ centered_sq
        # 2*N^2 + N while reducing, 1*N^2 + N once returned
        return product.mean(-1)

    sample = torch.randn(100, 100, device=self.device)
    reuse_changed_code = CommonTemplate._is_triggering_buffer_reuse(fn, sample)
    self.assertFalse(reuse_changed_code)
|
||||||
|
|
||||||
|
def test_allow_reuse_active_if_under_peak(self):
    """Buffer reuse should still change the generated code for a chain of five matmul steps."""

    def g(inp):
        return (inp - torch.logsumexp(inp, -1)) ** 2

    @torch.compile
    def fn(m, inp):
        # Five identical steps, each feeding the previous result back in.
        for _ in range(5):
            inp = m @ g(inp)
        return inp

    weight = torch.randn(100, 100, device=self.device)
    sample = torch.randn(100, 100, device=self.device)
    reuse_changed_code = CommonTemplate._is_triggering_buffer_reuse(fn, weight, sample)
    self.assertTrue(reuse_changed_code)
|
||||||
|
|
||||||
# end of class CommonTemplate - add new tests here
|
# end of class CommonTemplate - add new tests here
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
241
torch/_inductor/codegen/segmented_tree.py
Normal file
241
torch/_inductor/codegen/segmented_tree.py
Normal file
|
|
@ -0,0 +1,241 @@
|
||||||
|
from typing import Callable, Generic, Optional, TypeVar
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def _value_or(opt: Optional[T], default: T) -> T:
|
||||||
|
return opt if opt is not None else default
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentedTree(Generic[T]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
values: list[T],
|
||||||
|
update_op: Callable[[T, T], T],
|
||||||
|
summary_op: Callable[[T, T], T],
|
||||||
|
identity_element: T,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize a segment tree with the given values and operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: list of initial values
|
||||||
|
update_op: Function to apply when updating a value (e.g., addition)
|
||||||
|
summary_op: Function to summarize two values (e.g., min, max, sum)
|
||||||
|
identity_element: Identity element for the summary_op (e.g., 0 for sum, float('inf') for min)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the input values list is empty
|
||||||
|
"""
|
||||||
|
if not values:
|
||||||
|
raise ValueError("Cannot create a segment tree with empty values list")
|
||||||
|
|
||||||
|
self.n = len(values)
|
||||||
|
self.update_op = update_op
|
||||||
|
self.summary_op = summary_op
|
||||||
|
self.identity = identity_element
|
||||||
|
|
||||||
|
# Size of segment tree array (next power of 2 * 2)
|
||||||
|
# The tree follows a standard heap layout where
|
||||||
|
# node `n`'s children are at `2*n` and `2*n+1`.
|
||||||
|
# Index 0 is unused.
|
||||||
|
self.size = 1
|
||||||
|
while self.size < self.n:
|
||||||
|
self.size *= 2
|
||||||
|
self.size *= 2
|
||||||
|
|
||||||
|
# Initialize tree and lazy arrays
|
||||||
|
self.tree = [identity_element] * self.size
|
||||||
|
# The lazy array contains updates to the given node
|
||||||
|
# Upon update, we only push updates to the top-most
|
||||||
|
# nodes that fully receive the update. We then
|
||||||
|
# propagate the update down as required (i.e., when
|
||||||
|
# we receive an interval query that neither fully
|
||||||
|
# contains the node nor fully doesn't contain the
|
||||||
|
# node
|
||||||
|
self.lazy: list[Optional[T]] = [None] * self.size
|
||||||
|
|
||||||
|
# Build the tree
|
||||||
|
self._build(values, 1, 0, self.n - 1)
|
||||||
|
|
||||||
|
def _build(self, values: list[T], node: int, start: int, end: int) -> None:
|
||||||
|
"""
|
||||||
|
Build the segment tree recursively.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: Original array of values
|
||||||
|
node: Current node index in the segment tree
|
||||||
|
start: Start index of the segment
|
||||||
|
end: End index of the segment
|
||||||
|
"""
|
||||||
|
if start == end:
|
||||||
|
# Leaf node
|
||||||
|
if start < len(values):
|
||||||
|
self.tree[node] = values[start]
|
||||||
|
return
|
||||||
|
|
||||||
|
mid = (start + end) // 2
|
||||||
|
left_child = 2 * node
|
||||||
|
right_child = 2 * node + 1
|
||||||
|
|
||||||
|
# Recursively build left and right subtrees
|
||||||
|
self._build(values, left_child, start, mid)
|
||||||
|
self._build(values, right_child, mid + 1, end)
|
||||||
|
|
||||||
|
# Update current node with summary of children
|
||||||
|
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
|
||||||
|
|
||||||
|
def _children(self, node: int) -> list[int]:
|
||||||
|
return [2 * node, 2 * node + 1]
|
||||||
|
|
||||||
|
def _push_lazy(self, node: int, start: int, end: int) -> None:
|
||||||
|
"""
|
||||||
|
Push lazy updates down to children.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Current node index
|
||||||
|
start: Start index of the segment
|
||||||
|
end: End index of the segment
|
||||||
|
"""
|
||||||
|
lazy_node = self.lazy[node]
|
||||||
|
if lazy_node is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Apply lazy update to current node
|
||||||
|
self.tree[node] = self.update_op(self.tree[node], lazy_node)
|
||||||
|
|
||||||
|
if start != end: # Not a leaf node
|
||||||
|
# Propagate to children
|
||||||
|
for child in self._children(node):
|
||||||
|
self.lazy[child] = self.update_op(
|
||||||
|
_value_or(self.lazy[child], self.identity), lazy_node
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear the lazy value
|
||||||
|
self.lazy[node] = None
|
||||||
|
|
||||||
|
def _update_range_helper(
|
||||||
|
self, node: int, start: int, end: int, left: int, right: int, value: T
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Helper method to update a range of values in the segment tree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Current node index
|
||||||
|
start: Start index of the current segment
|
||||||
|
end: End index of the current segment
|
||||||
|
left: Start index of the range to update
|
||||||
|
right: End index of the range to update
|
||||||
|
value: Value to apply to the range
|
||||||
|
"""
|
||||||
|
# Push lazy updates before processing this node
|
||||||
|
self._push_lazy(node, start, end)
|
||||||
|
|
||||||
|
# No overlap
|
||||||
|
if start > right or end < left:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Complete overlap
|
||||||
|
if start >= left and end <= right:
|
||||||
|
# Apply update to current node
|
||||||
|
self.lazy[node] = value
|
||||||
|
self._push_lazy(node, start, end)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Partial overlap, recurse to children
|
||||||
|
mid = (start + end) // 2
|
||||||
|
left_child = 2 * node
|
||||||
|
right_child = 2 * node + 1
|
||||||
|
|
||||||
|
self._update_range_helper(left_child, start, mid, left, right, value)
|
||||||
|
self._update_range_helper(right_child, mid + 1, end, left, right, value)
|
||||||
|
|
||||||
|
# Update current node based on children
|
||||||
|
self.tree[node] = self.summary_op(self.tree[left_child], self.tree[right_child])
|
||||||
|
|
||||||
|
def _query_range_helper(
|
||||||
|
self, node: int, start: int, end: int, left: int, right: int
|
||||||
|
) -> T:
|
||||||
|
"""
|
||||||
|
Helper method to query a range of values in the segment tree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: Current node index
|
||||||
|
start: Start index of the current segment
|
||||||
|
end: End index of the current segment
|
||||||
|
left: Start index of the range to query
|
||||||
|
right: End index of the range to query
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary value for the range
|
||||||
|
"""
|
||||||
|
# No overlap
|
||||||
|
if start > right or end < left:
|
||||||
|
return self.identity
|
||||||
|
|
||||||
|
# Push lazy updates before processing this node
|
||||||
|
self._push_lazy(node, start, end)
|
||||||
|
|
||||||
|
# Complete overlap
|
||||||
|
if start >= left and end <= right:
|
||||||
|
return self.tree[node]
|
||||||
|
|
||||||
|
# Partial overlap, recurse to children
|
||||||
|
mid = (start + end) // 2
|
||||||
|
left_child = 2 * node
|
||||||
|
right_child = 2 * node + 1
|
||||||
|
|
||||||
|
left_result = self._query_range_helper(left_child, start, mid, left, right)
|
||||||
|
right_result = self._query_range_helper(right_child, mid + 1, end, left, right)
|
||||||
|
|
||||||
|
# Combine results from children
|
||||||
|
return self.summary_op(left_result, right_result)
|
||||||
|
|
||||||
|
def update_range(self, start: int, end: int, value: T) -> None:
|
||||||
|
"""
|
||||||
|
Update a range of values in the segment tree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start: Start index of the range to update (inclusive)
|
||||||
|
end: End index of the range to update (inclusive)
|
||||||
|
value: Value to apply to the range
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If start > end or indices are out of bounds
|
||||||
|
"""
|
||||||
|
if start > end:
|
||||||
|
raise ValueError("Start index must be less than or equal to end index")
|
||||||
|
|
||||||
|
if start < 0 or start >= self.n:
|
||||||
|
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
|
||||||
|
|
||||||
|
if end < 0 or end >= self.n:
|
||||||
|
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
|
||||||
|
|
||||||
|
self._update_range_helper(1, 0, self.n - 1, start, end, value)
|
||||||
|
|
||||||
|
def summarize_range(self, start: int, end: int) -> T:
|
||||||
|
"""
|
||||||
|
Query a range of values in the segment tree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start: Start index of the range to query (inclusive)
|
||||||
|
end: End index of the range to query (inclusive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary value for the range according to the summary operation
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If start > end or indices are out of bounds
|
||||||
|
"""
|
||||||
|
if start > end:
|
||||||
|
raise ValueError("Start index must be less than or equal to end index")
|
||||||
|
|
||||||
|
if start < 0 or start >= self.n:
|
||||||
|
raise ValueError(f"Start index {start} out of bounds [0, {self.n - 1}]")
|
||||||
|
|
||||||
|
if end < 0 or end >= self.n:
|
||||||
|
raise ValueError(f"End index {end} out of bounds [0, {self.n - 1}]")
|
||||||
|
|
||||||
|
return self._query_range_helper(1, 0, self.n - 1, start, end)
|
||||||
|
|
@ -48,6 +48,7 @@ from ..utils import (
|
||||||
cache_on_self,
|
cache_on_self,
|
||||||
DelayReplaceLine,
|
DelayReplaceLine,
|
||||||
get_benchmark_name,
|
get_benchmark_name,
|
||||||
|
get_dtype_size,
|
||||||
IndentedBuffer,
|
IndentedBuffer,
|
||||||
is_codegen_graph_partition_subgraph,
|
is_codegen_graph_partition_subgraph,
|
||||||
is_using_cudagraph_partition,
|
is_using_cudagraph_partition,
|
||||||
|
|
@ -587,10 +588,64 @@ class MemoryPlanningLine(WrapperLine):
|
||||||
return f"{type(self).__name__}({', '.join(args)})"
|
return f"{type(self).__name__}({', '.join(args)})"
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientPeakEstimate:
    """Per-scheduler-node peak-memory estimates, stored in a SegmentedTree.

    The tree is built with ``operator.add`` as the update and ``max`` as the
    summary, so a range query returns the largest estimate among the scheduler
    nodes in that index range, and a range update raises those estimates.
    """

    def __init__(self):
        from ..memory import estimate_peak_memory, get_freeable_input_buf
        from .segmented_tree import SegmentedTree

        nodes = V.graph.scheduler.nodes
        input_names = OrderedSet(V.graph.graph_inputs.keys())
        output_names = OrderedSet(V.graph.get_output_names())
        freeable_inputs = get_freeable_input_buf(nodes, input_names)

        # estimate_peak_memory returns (overall peak, per-node peaks).
        estimate = estimate_peak_memory(nodes, freeable_inputs, output_names)
        self.overall_peak_memory, per_node_peaks = estimate

        self.segmented_tree = SegmentedTree(per_node_peaks, operator.add, max, 0)

    def _get_size(self, node: BufferLike) -> int:
        """Return the size hint of ``node``'s allocation storage times its dtype size."""
        storage_size = V.graph.get_allocation_storage_size(node)
        size_hint = V.graph.sizevars.size_hint(storage_size, fallback=0)
        return size_hint * get_dtype_size(node.get_dtype())

    def peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
        """Return the largest estimate strictly between the two lines' scheduler nodes.

        NOTE(review): the underlying range query rejects start > end, so callers
        presumably ensure at least one node lies between the two lines - confirm.
        """
        first = line_a.scheduler_node_index + 1
        last = line_b.scheduler_node_index - 1
        return self.segmented_tree.summarize_range(first, last)

    def update_peak_between(self, line_a: FreeIfNotReusedLine, line_b: AllocateLine):
        """Add ``line_b``'s buffer size to every estimate strictly between the two lines.

        Does nothing when ``line_b``'s scheduler node directly follows ``line_a``'s.
        """
        first = line_a.scheduler_node_index + 1
        if first == line_b.scheduler_node_index:
            return
        last = line_b.scheduler_node_index - 1
        self.segmented_tree.update_range(first, last, self._get_size(line_b.node))
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class AllocateLine(MemoryPlanningLine):
|
class AllocateLine(MemoryPlanningLine):
|
||||||
node: BufferLike
|
node: BufferLike
|
||||||
|
|
||||||
|
def __post_init__(self):
    # Remember the position, within the scheduler's node list, of the node that
    # was current when this line was created.
    scheduler = V.graph.scheduler
    current = scheduler.current_node
    assert current is not None
    self.scheduler_node_index = scheduler.nodes.index(current)
|
||||||
|
|
||||||
|
def should_reuse_buffer(self, free_line: FreeIfNotReusedLine, size: int) -> bool:
    """Decide whether this allocation may take over the buffer released by ``free_line``.

    Reuse is always accepted when this line's scheduler node directly follows
    the freeing one. Otherwise it is accepted only if ``size`` plus the largest
    estimate between the two nodes does not exceed the overall peak estimate.
    """
    follows_directly = free_line.scheduler_node_index + 1 == self.scheduler_node_index
    if follows_directly:
        return True

    estimator = self.wrapper.estimate_peak
    overall_peak = estimator.overall_peak_memory
    peak_in_range = estimator.peak_between(free_line, self)
    return size + peak_in_range <= overall_peak
|
||||||
|
|
||||||
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
|
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
|
||||||
if self.node.get_name() in V.graph.removed_buffers:
|
if self.node.get_name() in V.graph.removed_buffers:
|
||||||
return NullLine(self.wrapper)
|
return NullLine(self.wrapper)
|
||||||
|
|
@ -599,8 +654,16 @@ class AllocateLine(MemoryPlanningLine):
|
||||||
key = buffer_reuse_key(self.node)
|
key = buffer_reuse_key(self.node)
|
||||||
if config.allow_buffer_reuse and key in state:
|
if config.allow_buffer_reuse and key in state:
|
||||||
free_line = state.pop(key)
|
free_line = state.pop(key)
|
||||||
free_line.is_reused = True
|
size = V.graph.sizevars.size_hint(
|
||||||
return ReuseLine(self.wrapper, free_line.node, self.node)
|
V.graph.get_allocation_storage_size(self.node), fallback=0
|
||||||
|
) * get_dtype_size(self.node.get_dtype())
|
||||||
|
if self.should_reuse_buffer(free_line, size):
|
||||||
|
free_line.is_reused = True
|
||||||
|
self.wrapper.estimate_peak.update_peak_between(free_line, self)
|
||||||
|
return ReuseLine(self.wrapper, free_line.node, self.node)
|
||||||
|
else:
|
||||||
|
state.push(key, free_line)
|
||||||
|
return self
|
||||||
|
|
||||||
if self.node.get_device_or_error().type == "cpu":
|
if self.node.get_device_or_error().type == "cpu":
|
||||||
static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
|
static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node)
|
||||||
|
|
@ -625,6 +688,12 @@ class FreeIfNotReusedLine(MemoryPlanningLine):
|
||||||
node: BufferLike
|
node: BufferLike
|
||||||
is_reused: bool = False
|
is_reused: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
    # Remember the position, within the scheduler's node list, of the node that
    # was current when this line was created.
    scheduler = V.graph.scheduler
    current = scheduler.current_node
    assert current is not None
    self.scheduler_node_index = scheduler.nodes.index(current)
|
||||||
|
|
||||||
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
|
def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine:
|
||||||
if len(self.node.get_inputs_that_alias_output()) > 0:
|
if len(self.node.get_inputs_that_alias_output()) > 0:
|
||||||
return self
|
return self
|
||||||
|
|
@ -1645,6 +1714,7 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
if is_inference and config.memory_planning:
|
if is_inference and config.memory_planning:
|
||||||
self.memory_plan()
|
self.memory_plan()
|
||||||
else:
|
else:
|
||||||
|
self.estimate_peak = EfficientPeakEstimate()
|
||||||
self.memory_plan_reuse()
|
self.memory_plan_reuse()
|
||||||
|
|
||||||
def codegen_input_symbol_assignment(
|
def codegen_input_symbol_assignment(
|
||||||
|
|
|
||||||
|
|
@ -2073,6 +2073,7 @@ class Scheduler:
|
||||||
)
|
)
|
||||||
|
|
||||||
self.nodes = [self.create_scheduler_node(n) for n in nodes]
|
self.nodes = [self.create_scheduler_node(n) for n in nodes]
|
||||||
|
self.current_node: Optional[BaseSchedulerNode] = None
|
||||||
self.update_zero_dim_cpu_tensor()
|
self.update_zero_dim_cpu_tensor()
|
||||||
# some new constants could have been created above
|
# some new constants could have been created above
|
||||||
self.available_buffer_names.update(V.graph.constants.keys())
|
self.available_buffer_names.update(V.graph.constants.keys())
|
||||||
|
|
@ -4989,6 +4990,7 @@ class Scheduler:
|
||||||
assert device.index is not None, "device should have an index"
|
assert device.index is not None, "device should have an index"
|
||||||
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
|
V.graph.wrapper_code.codegen_device_guard_enter(device.index)
|
||||||
|
|
||||||
|
self.current_node = node
|
||||||
self.buffer_names_to_free.update(node.last_usage)
|
self.buffer_names_to_free.update(node.last_usage)
|
||||||
|
|
||||||
if node.is_template():
|
if node.is_template():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user