[Inductor] fix device error for NopKernelSchedulerNode (#141372)
This PR adds device guard support for NopKernelSchedulerNode, which may create a tensor. Prior to this PR, we did not codegen a device guard for NopKernelSchedulerNode, so the buffer it creates could be allocated on the wrong device, leading to errors.
Prior to the PR:
```python
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg1_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg2_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg3_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg4_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg5_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg6_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg7_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg8_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg9_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg10_1, (1, 1, 16, 16), (256, 256, 16, 1))
    buf0 = empty_strided_cuda((1, 1, 2048), (2048, 2048, 1), torch.float32)  # ERROR here: allocated outside the device guard, should be on cuda:1
    with torch.cuda._DeviceGuard(1):
        torch.cuda.set_device(1)
        buf1 = empty_strided_cuda((1, 1, 2048, 128), (262144, 262144, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
        stream1 = get_raw_stream(1)
        triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, arg3_1, arg4_1, arg5_1, arg6_1, buf1, grid=torch._inductor.kernel.flex_attention.flex_attention_grid(1, 1, 2048, 128, meta0), stream=stream1)
        del arg0_1
        del arg1_1
        del arg2_1
        del arg3_1
        del arg4_1
        del arg5_1
        del arg6_1
        del buf0
    return (buf1, )
```
After the PR:
```python
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg1_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg2_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg3_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg4_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg5_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg6_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg7_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg8_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg9_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg10_1, (1, 1, 16, 16), (256, 256, 16, 1))
    with torch.cuda._DeviceGuard(1):
        torch.cuda.set_device(1)
        buf0 = empty_strided_cuda((1, 1, 2048), (2048, 2048, 1), torch.float32)  # New: moved inside the device guard
        buf1 = empty_strided_cuda((1, 1, 2048, 128), (262144, 262144, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
        stream1 = get_raw_stream(1)
        triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, arg3_1, arg4_1, arg5_1, arg6_1, buf1, grid=torch._inductor.kernel.flex_attention.flex_attention_grid(1, 1, 2048, 128, meta0), stream=stream1)
        del arg0_1
        del arg1_1
        del arg2_1
        del arg3_1
        del arg4_1
        del arg5_1
        del arg6_1
        del buf0
    return (buf1, )
```
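For reference, the device error can be reproduced with a small standalone script along the lines of the new unit test in this PR (a sketch: it assumes a machine with at least two CUDA GPUs and a PyTorch build that ships flex_attention):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# All inputs and the block mask live on the second GPU.
q = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
k = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
v = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
block_mask = create_block_mask(
    lambda b, h, q_idx, kv_idx: q_idx >= kv_idx,
    B=None,
    H=None,
    Q_LEN=256,
    KV_LEN=256,
    device="cuda:1",
)

compiled = torch.compile(flex_attention)
out = compiled(q, k, v, block_mask=block_mask)
assert out.device == torch.device("cuda:1")
```

Before the fix, the generated wrapper allocated `buf0` outside the device guard, as shown above; with the fix, every buffer the kernel needs is created on `cuda:1`.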
Fixes #141010
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141372
Approved by: https://github.com/eellison
This commit is contained in:
parent 3fd51e079d
commit 61a7c83c64
```diff
@@ -3303,6 +3303,27 @@ class GraphModule(torch.nn.Module):
             """,  # noqa: B950
         )
 
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    def test_device_cuda_1(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, q, k, v, block_mask):
+                return flex_attention(q, k, v, block_mask=block_mask)
+
+        q = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
+        k = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
+        v = torch.randn(1, 1, 256, 32, device="cuda:1", dtype=torch.bfloat16)
+        mask = create_block_mask(
+            lambda b, h, q_idx, kv_idx: q_idx >= kv_idx,
+            B=None,
+            H=None,
+            Q_LEN=256,
+            KV_LEN=256,
+            device="cuda:1",
+        )
+        mod = torch.compile(TestModule())
+        attn_output = mod(q, k, v, mask)
+        self.assertEqual(attn_output.device, torch.device("cuda:1"))
+
 
 class TestBlockMask(InductorTestCase):
     @supported_platform
```
```diff
@@ -2494,6 +2494,13 @@ class CommonTemplate:
         self.common(fn, (torch.Tensor([]),))
 
+    @requires_multigpu()
+    def test_linspace4(self):
+        def fn(x):
+            return torch.linspace(0, 2, 0, device=f"{GPU_TYPE}:1")
+
+        self.common(fn, (torch.Tensor([]),))
+
     def test_tensor1(self):
         def fn(x):
             return torch.tensor([1], device=x.device) + x, torch.tensor(
```
```diff
@@ -850,8 +850,12 @@ class GraphLowering(torch.fx.Interpreter):
             device = buffer.get_device()
             if (
                 # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
-                not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements())
-                and device is not None
+                device is not None
+                and not (
+                    isinstance(buffer, ir.ComputedBuffer)
+                    and buffer.is_zero_elements()
+                    and device == torch.device("cpu")
+                )
             ):
                 self.add_device_info(device)
```
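To make the new condition concrete: device info is now recorded for every buffer that has a device, and only zero-element CPU ComputedBuffers are skipped (the CUDA-graphs exception referenced in the comment). A minimal sketch of that predicate, using a hypothetical helper rather than the real GraphLowering method:

```python
import torch


def should_record_device(device, is_computed_buffer: bool, is_zero_elements: bool) -> bool:
    # Hypothetical helper, not the real GraphLowering code. After the change,
    # only a zero-element CPU ComputedBuffer is exempt; an empty buffer on
    # cuda:1 still registers its device, so the wrapper emits a device guard.
    return device is not None and not (
        is_computed_buffer and is_zero_elements and device == torch.device("cpu")
    )


print(should_record_device(torch.device("cuda", 1), True, True))  # True: empty GPU buffer
print(should_record_device(torch.device("cpu"), True, True))      # False: empty CPU buffer
```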
```diff
@@ -3176,8 +3176,10 @@ class Layout(OutputSpec):
         offset = ""
         if self.offset != 0:
             offset = f", offset={self.offset}"
 
+        device_index_str = "" if self.device.index is None else f":{self.device.index}"
         return (
-            f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
+            f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, "
             f"size={self.size}, stride={self.stride}{offset})"
         )
```
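For illustration, this repr change makes the device index visible, so layouts on cuda:0 and cuda:1 no longer print identically. A rough standalone sketch of the new formatting (FixedLayout is just an example class name here; the real code uses type(self).__name__):

```python
import torch


def layout_str(device: torch.device, dtype: torch.dtype, size, stride, offset: int = 0) -> str:
    # Standalone version of the updated formatting: include the device index
    # whenever it is known.
    offset_str = f", offset={offset}" if offset != 0 else ""
    device_index_str = "" if device.index is None else f":{device.index}"
    return (
        f"FixedLayout('{device.type}{device_index_str}', {dtype}, "
        f"size={size}, stride={stride}{offset_str})"
    )


print(layout_str(torch.device("cuda", 1), torch.bfloat16, [1, 1, 2048, 128], [262144, 262144, 128, 1]))
# FixedLayout('cuda:1', torch.bfloat16, size=[1, 1, 2048, 128], stride=[262144, 262144, 128, 1])
```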
```diff
@@ -3033,7 +3033,7 @@ def new_constant(fill_value):
         dtype = decode_dtype(dtype) or x.get_dtype()
         device = device or x.get_device()
         size = [sympy.Integer(s) for s in size]
-        return _full(fill_value, device, dtype, size)
+        return _full(fill_value, decode_device(device), dtype, size)
 
     return _new_constant
 
@@ -3045,7 +3045,12 @@ def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None)
     if device is None:
         device = x.get_device()
     return empty_strided(
-        size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+        size,
+        None,
+        dtype=dtype,
+        layout=layout,
+        device=decode_device(device),
+        pin_memory=pin_memory,
     )
 
@@ -3059,6 +3064,7 @@ def empty_strided(
     assert_nyi(layout in (None, torch.strided), f"layout={layout}")
     dtype = decode_dtype(dtype) or torch.get_default_dtype()
     device = device or torch.tensor(0.0).device
+    device = decode_device(device)
     pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
     pointwise.realize()
     buffer = pointwise.data.data
 
@@ -3089,7 +3095,12 @@ def new_empty_strided(
     if device is None:
         device = x.get_device()
     return empty_strided(
-        size, stride, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
+        size,
+        stride,
+        dtype=dtype,
+        layout=layout,
+        device=decode_device(device),
+        pin_memory=pin_memory,
     )
```
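The lowering changes above all route the user-supplied device through decode_device before the buffer is created. A simplified sketch of the kind of normalization this performs (this is not the real torch._inductor.lowering.decode_device, just an illustration of accepting None / str / torch.device and returning a concrete torch.device):

```python
import torch


def normalize_device(device) -> torch.device:
    # Illustrative only: strings like "cuda:1" and a bare None are turned into
    # a concrete torch.device before the buffer's layout is created, instead of
    # being passed through as-is.
    if device is None:
        return torch.tensor(0.0).device  # default device, mirroring the fallback in the diff
    if isinstance(device, str):
        return torch.device(device)
    return device


print(normalize_device("cuda:1"))  # device(type='cuda', index=1)
print(normalize_device(None))      # e.g. device(type='cpu')
```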
```diff
@@ -3511,9 +3511,7 @@ class Scheduler:
 
             self.enter_context(node)
 
-            if not isinstance(node, NopKernelSchedulerNode) and (
-                device := node.get_device()
-            ):
+            if device := node.get_device():
                 if (
                     device != self.current_device
                     or node.is_extern()
```
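A rough, self-contained sketch of what this scheduler change means for the device-guard decision: a node is no longer exempt just because it is a NopKernelSchedulerNode, so a nop node that merely allocates a buffer on cuda:1 still switches the current device before its buffer is codegen'd. FakeNode and needs_device_guard below are hypothetical stand-ins, not Inductor APIs:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeNode:
    # Minimal stand-in for a scheduler node; only the reported device matters here.
    device: Optional[torch.device]

    def get_device(self) -> Optional[torch.device]:
        return self.device


def needs_device_guard(node: FakeNode, current_device: Optional[torch.device]) -> bool:
    # Collapses the two visible checks from the diff: any node that reports a
    # device is considered (no NopKernelSchedulerNode exemption), and a guard is
    # needed when that device differs from the current one. The extern-kernel
    # branch of the real scheduler is omitted.
    device = node.get_device()
    return device is not None and device != current_device


print(needs_device_guard(FakeNode(torch.device("cuda", 1)), torch.device("cuda", 0)))  # True
print(needs_device_guard(FakeNode(None), torch.device("cuda", 0)))                     # False
```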