Modify pointwise cat heuristic to only apply when inputs are all pointwise and outputs are all pointwise (#114520)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114520
Approved by: https://github.com/eellison
This commit is contained in:
chilli 2023-12-01 00:42:43 -08:00 committed by PyTorch MergeBot
parent a5a1f0a6b1
commit 3d47b92dfb
3 changed files with 71 additions and 4 deletions

View File

@ -216,6 +216,54 @@ class NumBytesMetricTests(TestCase):
inp = [T(10, 10, 10), T(10, 10, 10)]
self.assertExpectedInline(count_numel(f, *inp), """2600""")
# NOTE(review): indentation below is flattened by the diff renderer; kept
# byte-identical. Presumably count_numel totals bytes read+written by the
# compiled graph (class is NumBytesMetricTests) — confirm against the helper.
# The cases pin down when cat lowers to masked pointwise codegen vs. a
# ConcatKernel, per the heuristic this commit changes.
def test_cat_pointwise(self):
# Both cat inputs are softmaxes (pointwise-fusable producers); cat itself
# has no further use, so pointwise lowering is profitable: 400 = 2 inputs
# read (200) + 1 output written (200), softmax reads/writes fused away.
def f(a, b):
return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)])
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """400""")
# Same, but the cat result is consumed by a pointwise cos(); still fuses.
def f(a, b):
return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)]).cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """680""")
# This one is a little bit tricky since in theory, fusing the `cos()`
# could result in saving a read.
# But in this case, using masked pointwise codegen for concat forces
# softmax to materialize extra values, so we don't want to.
def f(a, b):
out = torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)])
return out, out.cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """800""")
# Inputs are raw graph inputs (not pointwise producers), but the only use
# is pointwise, so the cat can still be folded into the consumer.
def f(a, b):
out = torch.cat([a, b])
return out.cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """400""")
# Should turn into pointwise even if only some of inputs are pointwise.
def f(a, b):
out = torch.cat([a.cos(), torch.mm(b, b)])
return out.cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """600""")
# Should not turn into pointwise if none of the inputs is pointwise
# (both come from matmuls here); a plain concat kernel is used instead.
def f(a, b):
out = torch.cat([torch.mm(a, a), torch.mm(b, b)])
return out.cos()
inp = (T(10, 10), T(10, 10))
self.assertExpectedInline(count_numel(f, *inp), """800""")
def test_index(self):
def f(a, b):
return a[b]
@ -799,6 +847,7 @@ class InplacingTests(TestCase):
# Test cases where we don't do the right thing yet.
# NOTE: These tests do not get run (and that's intentional)!
class WouldBeNiceIfItWorked:
def test_horizontal(self):
def f(a):

View File

@ -686,6 +686,22 @@ class CommonTemplate:
],
)
# NOTE(review): indentation is flattened by the diff renderer; rewritten with
# conventional formatting. Verifies index-put via advanced indexing on a
# clone, across int64 / bool / bfloat16 dtypes.
def test_index_put_bf16(self):
    def write_rows(base, values, rows):
        # Clone first so the graph input is not mutated, then scatter
        # whole rows with advanced indexing.
        result = base.clone()
        result[rows] = values
        return result

    for dt in (torch.int64, torch.bool, torch.bfloat16):
        self.common(
            write_rows,
            [
                torch.zeros(3, 5, dtype=dt),
                torch.ones((2, 5), dtype=dt),
                torch.tensor([0, 1]),
            ],
        )
def test_randn_generator(self):
def fn(a, generator):
torch.randn([20, 20], generator=generator, device=a.device)

View File

@ -1076,12 +1076,14 @@ def cat(inputs, dim=0):
return False
if len(inputs) <= config.max_pointwise_cat_inputs:
if (
len(inputs) <= config.max_pointwise_cat_inputs
and inputs[0].get_device().type != "cpu"
):
pointwise_uses = all(is_pointwise_use(use) for use in V.current_node.users)
all_pointwise_inputs = all(should_lower_cat_input(inp) for inp in inputs)
any_pointwise_inputs = any(should_lower_cat_input(inp) for inp in inputs)
all_pointwise_inputs = any(should_lower_cat_input(inp) for inp in inputs)
if all_pointwise_inputs or (any_pointwise_inputs and pointwise_uses):
if all_pointwise_inputs and pointwise_uses:
return pointwise_cat(inputs, dim)
return TensorBox(ir.ConcatKernel.create(inputs, dim))