Modify pointwise cat heuristic to only apply when inputs are all pointwise and outputs are all pointwise (#114520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114520 Approved by: https://github.com/eellison
This commit is contained in:
parent a5a1f0a6b1
commit 3d47b92dfb
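
For orientation before the diff: the heuristic decides when Inductor lowers torch.cat to a single masked pointwise kernel (fusing it with pointwise producers and consumers) instead of a copy-based ConcatKernel. Below is a minimal sketch of the two situations it distinguishes; it is not part of the commit, and the shapes, function names, and the torch.compile driver are illustrative assumptions (note also that, per the lowering hunk below, the pointwise path is skipped on CPU devices):

    import torch

    def good_candidate(a, b):
        # Pointwise inputs and a pointwise use of the result: a single
        # masked pointwise kernel can compute cos/sin and the concat at once.
        return torch.cat([a.cos(), b.sin()]).cos()

    def poor_candidate(a, b):
        # Neither input is pointwise (both come from mm), so copying the mm
        # outputs into the cat buffer is the better lowering.
        return torch.cat([torch.mm(a, a), torch.mm(b, b)]).cos()

    a, b = torch.randn(10, 10), torch.randn(10, 10)
    compiled = torch.compile(good_candidate)
    torch.testing.assert_close(compiled(a, b), good_candidate(a, b))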
@@ -216,6 +216,54 @@ class NumBytesMetricTests(TestCase):
         inp = [T(10, 10, 10), T(10, 10, 10)]
         self.assertExpectedInline(count_numel(f, *inp), """2600""")
 
+    def test_cat_pointwise(self):
+        def f(a, b):
+            return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)])
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """400""")
+
+        def f(a, b):
+            return torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)]).cos()
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """680""")
+
+        # This one is a little bit tricky since in theory, fusing the `cos()`
+        # could result in saving a read.
+        # But in this case, using masked pointwise codegen for concat forces
+        # softmax to materialize extra values, so we don't want to.
+        def f(a, b):
+            out = torch.cat([torch.softmax(a, dim=-1), torch.softmax(b, dim=-1)])
+            return out, out.cos()
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """800""")
+
+        def f(a, b):
+            out = torch.cat([a, b])
+            return out.cos()
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """400""")
+
+        # Should turn into pointwise even if only some of inputs are pointwise.
+        def f(a, b):
+            out = torch.cat([a.cos(), torch.mm(b, b)])
+            return out.cos()
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """600""")
+
+        # Should not turn into pointwise if all inputs are not pointwise
+        def f(a, b):
+            out = torch.cat([torch.mm(a, a), torch.mm(b, b)])
+            return out.cos()
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """800""")
+
     def test_index(self):
         def f(a, b):
             return a[b]
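
The "masked pointwise codegen" mentioned in the test comment above can be sketched in eager PyTorch. This is a hypothetical illustration of the indexing structure only, not Inductor's generated code; eager evaluation materializes both branches, whereas Inductor inlines the producers into one kernel:

    import torch

    def masked_pointwise_cat(fa, fb, a, b):
        # cat([fa(a), fb(b)], dim=0) as one masked expression:
        # out[i] = fa(a)[i] if i < n else fb(b)[i - n]
        n, m = a.shape[0], b.shape[0]
        i = torch.arange(n + m, device=a.device)
        mask = (i < n).unsqueeze(-1)
        left = fa(a)[i.clamp(max=n - 1)]     # clamped so every lane loads in-bounds
        right = fb(b)[(i - n).clamp(min=0)]
        # Both branches are evaluated for every output lane; a producer like
        # softmax therefore has to keep its reduction results around for lanes
        # the mask discards, which is the extra materialization the 800-numel
        # test case above is guarding against.
        return torch.where(mask, left, right)

    a, b = torch.randn(3, 4), torch.randn(5, 4)
    torch.testing.assert_close(
        masked_pointwise_cat(torch.cos, torch.sin, a, b),
        torch.cat([a.cos(), b.sin()]),
    )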
@@ -799,6 +847,7 @@ class InplacingTests(TestCase):
 
 
 # Test cases where we don't do the right thing yet.
 # NOTE: These tests do not get run (and that's intentional)!
 class WouldBeNiceIfItWorked:
     def test_horizontal(self):
         def f(a):
@@ -686,6 +686,22 @@ class CommonTemplate:
             ],
         )
 
+    def test_index_put_bf16(self):
+        def fn(inp, src, index):
+            inp2 = inp.clone()
+            inp2[index] = src
+            return inp2
+
+        for dtype in [torch.int64, torch.bool, torch.bfloat16]:
+            self.common(
+                fn,
+                [
+                    torch.zeros(3, 5, dtype=dtype),
+                    torch.ones((2, 5), dtype=dtype),
+                    torch.tensor([0, 1]),
+                ],
+            )
+
     def test_randn_generator(self):
         def fn(a, generator):
             torch.randn([20, 20], generator=generator, device=a.device)
@@ -1076,12 +1076,14 @@ def cat(inputs, dim=0):
             return False
 
-    if len(inputs) <= config.max_pointwise_cat_inputs:
+    if (
+        len(inputs) <= config.max_pointwise_cat_inputs
+        and inputs[0].get_device().type != "cpu"
+    ):
         pointwise_uses = all(is_pointwise_use(use) for use in V.current_node.users)
-        all_pointwise_inputs = all(should_lower_cat_input(inp) for inp in inputs)
-        any_pointwise_inputs = any(should_lower_cat_input(inp) for inp in inputs)
+        all_pointwise_inputs = any(should_lower_cat_input(inp) for inp in inputs)
 
-        if all_pointwise_inputs or (any_pointwise_inputs and pointwise_uses):
+        if all_pointwise_inputs and pointwise_uses:
             return pointwise_cat(inputs, dim)
 
     return TensorBox(ir.ConcatKernel.create(inputs, dim))
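
Distilled from the hunk above, the post-patch condition can be restated as a standalone predicate. This paraphrase is for clarity only and is not the code Inductor runs; max_inputs stands in for config.max_pointwise_cat_inputs:

    def should_use_pointwise_cat(n_inputs, max_inputs, device_type,
                                 inputs_lowerable, uses_pointwise):
        # inputs_lowerable: per-input results of should_lower_cat_input
        # uses_pointwise:   per-use results of is_pointwise_use
        if n_inputs > max_inputs or device_type == "cpu":
            return False
        # Fire only when some input benefits from pointwise lowering AND every
        # consumer of the cat result is itself pointwise; previously,
        # all-pointwise inputs alone were enough, regardless of the uses.
        return any(inputs_lowerable) and all(uses_pointwise)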