[NJT] Fix inference mode for composite implicit ops without nested-specific kernel (#146633)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146633
Approved by: https://github.com/jbschlosser
This commit is contained in:
soulitzer 2025-02-06 16:02:31 -05:00 committed by PyTorch MergeBot
parent dfe3b64282
commit 3cadce7af2
2 changed files with 27 additions and 4 deletions

View File

@ -7679,6 +7679,22 @@ torch.cuda.synchronize()
for dynamic in [False, True, None]:
self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
def test_dropout_inference_mode(self, device):
seq_len = 32
embed_dim = 128
nt = torch.nested.nested_tensor(
[
torch.randn(11, seq_len, embed_dim, device=device),
torch.randn(11, seq_len, embed_dim, device=device),
],
layout=torch.jagged,
device=device,
)
with torch.inference_mode():
torch.nn.functional.dropout(nt, p=0.05)
@dtypes(torch.float32, torch.double, torch.half)
def test_unbind_backward(self, device, dtype):
nt = torch.nested.nested_tensor(

View File

@ -325,10 +325,17 @@ class NestedTensor(torch.Tensor):
# Poor man's redispatch for composite ops. This becomes relevant under inference
# mode, where disabling autograd key dispatch prevents decomposition.
dk = torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor
if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
with torch.overrides.enable_reentrant_dispatch():
return func._op_dk(dk, *args, **kwargs)
all_dks = (
# We want to handle both the case where NestedTensor overrides the
# composite implicit autograd kernel and the case where it doesn't.
# Prioritize calling into NestedTensor's kernel if it exists.
torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
torch._C.DispatchKey.CompositeImplicitAutograd,
)
for dk in all_dks:
if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
with torch.overrides.enable_reentrant_dispatch():
return func._op_dk(dk, *args, **kwargs)
raise NotImplementedError(func)