[NJT] Fix inference mode for composite implicit ops without nested-specific kernel (#146633)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146633
Approved by: https://github.com/jbschlosser

This commit is contained in:
parent dfe3b64282
commit 3cadce7af2
@@ -7679,6 +7679,22 @@ torch.cuda.synchronize()
         for dynamic in [False, True, None]:
             self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
 
+    def test_dropout_inference_mode(self, device):
+        seq_len = 32
+        embed_dim = 128
+
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(11, seq_len, embed_dim, device=device),
+                torch.randn(11, seq_len, embed_dim, device=device),
+            ],
+            layout=torch.jagged,
+            device=device,
+        )
+
+        with torch.inference_mode():
+            torch.nn.functional.dropout(nt, p=0.05)
+
     @dtypes(torch.float32, torch.double, torch.half)
     def test_unbind_backward(self, device, dtype):
         nt = torch.nested.nested_tensor(
@@ -325,10 +325,17 @@ class NestedTensor(torch.Tensor):
 
         # Poor man's redispatch for composite ops. This becomes relevant under inference
        # mode, where disabling autograd key dispatch prevents decomposition.
-        dk = torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor
-        if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
-            with torch.overrides.enable_reentrant_dispatch():
-                return func._op_dk(dk, *args, **kwargs)
+        all_dks = (
+            # We want to handle both the cases where NestedTensor overrides the
+            # composite implicit autograd kernel, and the case where it doesn't.
+            # Prioritize calling into NestedTensor's kernel if it exists.
+            torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
+            torch._C.DispatchKey.CompositeImplicitAutograd,
+        )
+        for dk in all_dks:
+            if torch._C._dispatch_has_kernel_for_dispatch_key(func.name(), dk):
+                with torch.overrides.enable_reentrant_dispatch():
+                    return func._op_dk(dk, *args, **kwargs)
 
         raise NotImplementedError(func)
 
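For context on the hunk above: the old code probed only the NestedTensor-specific composite key, so a composite implicit op with no nested override (dropout is the case exercised by the new test) fell through to NotImplementedError under inference mode. The sketch below is illustrative and not part of the commit; it reuses the same private dispatch helpers as the hunk to probe the two keys in all_dks for a given op, with aten::dropout as an assumed example.

import torch

# Probe the same two dispatch keys the new fallback loop iterates over.
# For an op without a NestedTensor-specific kernel, only the generic
# CompositeImplicitAutograd entry is expected to report a registered kernel.
op = torch.ops.aten.dropout.default
for dk in (
    torch._C.DispatchKey.CompositeImplicitAutogradNestedTensor,
    torch._C.DispatchKey.CompositeImplicitAutograd,
):
    print(dk, torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), dk))

Since inference mode disables autograd key dispatch, decomposition of such ops has to happen through this explicit redispatch; falling back to CompositeImplicitAutograd is what lets the added test_dropout_inference_mode test pass.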