Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Revert "[inductor] add a threshold for membw saving during fusion (#136782)"
This reverts commit 6647320de2.
Reverted https://github.com/pytorch/pytorch/pull/136782 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but test_memory started to fail after this landed in trunk ([comment](https://github.com/pytorch/pytorch/pull/136782#issuecomment-2423549196))
This commit is contained in:
parent fecd370ea1
commit ac7f52b301
@@ -358,7 +358,6 @@ def make_test(
     device="cuda",
     **kwargs,
 ):
-    @config.patch("score_fusion_memory_threshold", 1)
     def test_fn(self):
         stack = ExitStack()
         try:
@@ -443,7 +442,6 @@ def make_test(


 def make_recompile_test(optim_cls, closure=None, kernel_count=2, **kwargs):
-    @config.patch("score_fusion_memory_threshold", 1)
     @requires_gpu
     def test_fn(self):
         torch._dynamo.reset()
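Both hunks above drop the `@config.patch("score_fusion_memory_threshold", 1)` decorator that #136782 had added to pin the new knob inside the compiled-optimizer tests. As a reminder of how that pattern works, here is a minimal sketch; the compiled function and inputs are placeholders of mine, and patching this key only succeeds while the option still exists in the inductor config (which is exactly why the revert also removes these decorators).

import torch
from torch._inductor import config

device = "cuda" if torch.cuda.is_available() else "cpu"


# Sketch: pin the (pre-revert) fusion threshold for the duration of one test.
@config.patch("score_fusion_memory_threshold", 1)
def run_with_low_threshold():
    compiled = torch.compile(lambda x: (x * 2).sum())
    return compiled(torch.randn(8, device=device))


def run_with_low_threshold_scoped():
    # config.patch also works as a context manager for a narrower scope.
    with config.patch("score_fusion_memory_threshold", 1):
        compiled = torch.compile(lambda x: (x * 2).sum())
        return compiled(torch.randn(8, device=device))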
@@ -412,46 +412,6 @@ class LoopOrderingTest(TestCase):
         self.do_acc_test(f, x, scale)
         self.assertEqual(1, metrics.generated_kernel_count)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+")
-    def test_fp8_pattern_2(self):
-        """
-        This test repros the fp8 fusion relation issue here:
-        https://github.com/pytorch/pytorch/issues/133242
-        """
-        ref_dtype = torch.bfloat16
-        M, K = 4096, 4096
-
-        input_tensor = torch.randn(
-            M, K, device="cuda", dtype=ref_dtype, requires_grad=False
-        )
-        scale = torch.Tensor([10.0]).to("cuda")
-
-        E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
-        E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
-
-        def test_pattern2(tensor_x_inp, scale_x):
-            tensor_x = tensor_x_inp * scale_x
-            tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-            tensor_fp8 = tensor_x.to(torch.float8_e4m3fn)
-
-            tensor_x_t = (tensor_x_inp * scale_x).t()
-            tensor_x_t = tensor_x_t.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-            tensor_fp8_t = tensor_x_t.to(torch.float8_e4m3fn)
-
-            tensor_fp8_t = tensor_fp8_t.contiguous().t()
-
-            return (tensor_fp8, tensor_fp8_t)
-
-        test_pattern = torch.compile(test_pattern2)
-        tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale)
-
-        self.assertEqual(1, metrics.generated_kernel_count)
-
-        expected_numbytes = scale.nbytes  # scalar
-        expected_numbytes += input_tensor.nbytes  # input
-        expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes  # output
-        self.assertEqual(expected_numbytes, metrics.num_bytes_accessed)
-

 if __name__ == "__main__":
     if HAS_GPU:
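The deleted test above asserts that the whole fp8 pattern compiles to a single kernel whose memory traffic is exactly the scale scalar plus the bf16 input plus the two fp8 outputs. A back-of-the-envelope check of that number for the shapes used in the test; the arithmetic below is illustrative and not part of the diff.

# Expected bytes for test_fp8_pattern_2 above (illustrative arithmetic only).
M, K = 4096, 4096
BF16_ITEMSIZE = 2  # torch.bfloat16
FP8_ITEMSIZE = 1   # torch.float8_e4m3fn
FP32_ITEMSIZE = 4  # the scale scalar created via torch.Tensor([10.0])

scale_bytes = 1 * FP32_ITEMSIZE          # 4
input_bytes = M * K * BF16_ITEMSIZE      # 33_554_432
output_bytes = 2 * M * K * FP8_ITEMSIZE  # 33_554_432 (tensor_fp8 + tensor_fp8_t)

expected_numbytes = scale_bytes + input_bytes + output_bytes
print(expected_numbytes)  # 67_108_868 when the fused kernel touches each buffer exactly once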
@@ -438,17 +438,6 @@ loop_ordering_after_fusion = (
     os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
 )

-# If fusing two nodes only saves less than score_fusion_memory_threshold memory,
-# we should not bother fusing the nodes.
-#
-# This is especially helpful to resolve https://github.com/pytorch/pytorch/issues/133242
-# Previously we fused two nodes because of a common read of a scalar tensor.
-# If we skip that fusion, the loop ordering after fusion mechanism kicks in and can
-# bring more savings.
-#
-# For the cases where loop ordering after fusion does not help, we don't lose much.
-score_fusion_memory_threshold = 10
-
 # For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel
 benchmark_epilogue_fusion = (
     os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
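The removed block is the knob itself plus its rationale: fusing two nodes is only worthwhile when it saves a nontrivial amount of memory traffic; otherwise skipping the fusion lets the loop-ordering-after-fusion pass find a better match. A minimal sketch of the decision this knob controlled, with hypothetical helper names (the real check is in the Scheduler hunks below):

# Sketch only: `saved_bytes` stands in for Scheduler.score_fusion_memory(node1, node2).
def should_fuse_with_threshold(saved_bytes: int, threshold: int = 10) -> bool:
    # Behaviour added by #136782: ignore fusions that save fewer than
    # `threshold` bytes (e.g. a shared read of a 4-byte scalar).
    return saved_bytes >= threshold


def should_fuse_after_revert(saved_bytes: int) -> bool:
    # Behaviour restored by this revert: any nonzero sharing justifies fusion.
    return saved_bytes > 0


assert should_fuse_after_revert(4) and not should_fuse_with_threshold(4)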
@@ -2922,10 +2922,7 @@ class Scheduler:
                 node2.get_name(),
             )

-        return (
-            self.score_fusion_memory(node1, node2)
-            >= config.score_fusion_memory_threshold
-        )
+        return self.score_fusion_memory(node1, node2) > 0

    def unfusable_node(self, node: BaseSchedulerNode) -> bool:
        """
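For context, score_fusion_memory in the hunk above is the quantity being compared: roughly, the bytes of memory accesses the two nodes have in common. A rough stand-in under an assumed {buffer_name: nbytes} interface; the real Scheduler derives this from its dependency tracking, not from plain dicts.

# Rough stand-in for the score compared above; the interface is an assumption.
def score_fusion_memory(node1_accesses: dict[str, int], node2_accesses: dict[str, int]) -> int:
    common = node1_accesses.keys() & node2_accesses.keys()
    return sum(node1_accesses[name] for name in common)


# Two nodes whose only common access is a 4-byte scalar score 4: that passes the
# post-revert `> 0` check but not the pre-revert `>= 10` threshold.
print(score_fusion_memory({"scale": 4, "buf0": 1 << 20}, {"scale": 4, "buf1": 1 << 20}))  # 4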
@@ -2993,10 +2990,7 @@ class Scheduler:
            return False
        del device2

-        no_shared_data = (
-            self.score_fusion_memory(node1, node2)
-            < config.score_fusion_memory_threshold
-        )
+        no_shared_data = self.score_fusion_memory(node1, node2) == 0
        if no_shared_data:
            no_shared_data = not self.has_shared_data_after_reordering_loop(
                node1, node2
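The second Scheduler hunk restores the pre-#136782 interplay with loop ordering: only when two nodes share no data at all under their current loop orders does the scheduler pay for the more expensive check that reorders loops to look for shared accesses. A simplified paraphrase of that control flow, using the method names visible in the diff but an invented signature:

# Simplified paraphrase of the restored logic (not the actual method body).
def nodes_share_data(scheduler, node1, node2) -> bool:
    if scheduler.score_fusion_memory(node1, node2) > 0:
        return True
    # No overlap under the current loop order; try reordering loops. This is the
    # mechanism that issue #133242 relies on for the transposed fp8 output.
    return scheduler.has_shared_data_after_reordering_loop(node1, node2)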