[aotd] Guess tangents stride as output strides (#144579)
AOTDispatch prepares the AOT backward graph ahead of time, so it does not know the real tangents that the user will pass when backward actually runs; AOTD has to guess them. Previously we guessed that the tangents would have the same memory format as the corresponding outputs, and if the tangents supplied at runtime did not match that guessed memory format, AOTD coerced (copied) them to it. But as Horace found, there are popular use cases where the outputs of the compiled region are in a specific strided layout, e.g. a 4D tensor with dims 1 and 2 transposed: https://github.com/karpathy/nanoGPT/blob/master/model.py#L57. This PR changes the logic so that AOTD expects tangents with the same strides as the corresponding outputs. As a result it avoids the coercion copy in the transposed-dims case.

Limitations: we keep guessing memory_format for 1/ dynamic shapes (needs more changes) and 2/ tensor subclasses (needs more changes).

Other changes: test_torchinductor always created contiguous tangents via `torch.randn()`; they are changed to `torch.randn_like()` so the comparison runs with tangents of the same strides (e.g. for CUDA float16, strides affect numerics for fft ops).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144579
Approved by: https://github.com/bdhirsh
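To illustrate the scenario the description refers to, here is a minimal sketch (my own example, not part of the PR; the function name, shapes, and comments are assumptions) of a compiled region whose output is a transposed, non-contiguous view, so its natural tangent is also non-contiguous:

# Minimal sketch (assumed example, not from the PR): the compiled region
# returns a transposed 4D view, as in the nanoGPT attention pattern.
import torch

def f(x):
    # (B, T, H, D) -> (B, H, T, D): the output is a non-contiguous view
    return (x * 2).transpose(1, 2)

compiled_f = torch.compile(f, backend="aot_eager")

x = torch.randn(2, 8, 4, 16, requires_grad=True)
out = compiled_f(x)

# A typical upstream gradient shares the output's (transposed) strides;
# randn_like with the default preserve_format keeps those strides.
tangent = torch.randn_like(out)

# Before this PR, AOTD guessed only the tangent's memory format and coerced
# (copied) this non-contiguous tangent to contiguous before running the
# compiled backward. After this PR (static shapes, plain tensors), the guessed
# tangent keeps the output's strides, so no coercion copy is needed here.
out.backward(tangent)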
parent 9b1127437e
commit 2c4bc65366
@@ -131,23 +131,23 @@ class _multiply_invoke(torch.nn.Module):
 actual,
 """\
 class GraphModule(torch.nn.Module):
-def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)"):
+def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)"):
 l_inputs_ = L_inputs_
 l_sizes_0_ = L_sizes_0_

-getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
+getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None

 validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None
-getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
+getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None

 call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
-aot1_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
+aot1_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

-new_grad: "f32[s0]" = torch.clone(aot1_tangents_1)
+new_grad: "f32[2]" = torch.clone(aot1_tangents_1)

-result: "f32[s0]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None
+result: "f32[2]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None

-new_grad_1: "f32[s0]" = torch.clone(result); result = None
+new_grad_1: "f32[2]" = torch.clone(result); result = None
 return (new_grad, new_grad_1)
 """,
 )
@@ -156,23 +156,23 @@ class GraphModule(torch.nn.Module):
 actual,
 """\
 class GraphModule(torch.nn.Module):
-def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)"):
+def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)"):
 l_inputs_ = L_inputs_
 l_sizes_0_ = L_sizes_0_

-getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
+getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None

 validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None
-getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
+getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None

 call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
-aot3_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
+aot3_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

-new_grad: "f32[s0]" = torch.clone(aot3_tangents_1)
+new_grad: "f32[2]" = torch.clone(aot3_tangents_1)

-result: "f32[s0]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None
+result: "f32[2]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None

-new_grad_1: "f32[s0]" = torch.clone(result); result = None
+new_grad_1: "f32[2]" = torch.clone(result); result = None
 return (new_grad, new_grad_1)
 """,
 )
@@ -233,26 +233,26 @@ class GraphModule(torch.nn.Module):
 actual,
 """\
 class GraphModule(torch.nn.Module):
-def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
+def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(2)", L_hooks_1_keywords_fn_keywords_obj_counter: "Sym(s7)"):
 l_inputs_ = L_inputs_
 l_sizes_0_ = L_sizes_0_
 l_hooks_1_keywords_fn_keywords_obj_counter = L_hooks_1_keywords_fn_keywords_obj_counter

-getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
+getitem: "f32[2]" = l_inputs_[0]; l_inputs_ = None

 validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None
-getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
+getitem_9: "f32[2]" = validate_outputs[0]; validate_outputs = None

 call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
-aot0_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
+aot0_tangents_1: "f32[2]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

-new_grad: "f32[s0]" = torch.clone(aot0_tangents_1)
+new_grad: "f32[2]" = torch.clone(aot0_tangents_1)

 add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None

-result: "f32[s0]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None
+result: "f32[2]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None

-new_grad_1: "f32[s0]" = torch.clone(result); result = None
+new_grad_1: "f32[2]" = torch.clone(result); result = None
 return (new_grad, new_grad_1, add)
 """,
 )
@@ -52,6 +52,7 @@ from torch._inductor.output_code import MockFXGraphCacheOutput
 from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import is_sym_node
 from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
+from torch.nn.attention.flex_attention import flex_attention
 from torch.nn.utils.rnn import PackedSequence
 from torch.testing._internal.common_device_type import (
 instantiate_device_type_tests,
@@ -5802,45 +5803,69 @@ metadata incorrectly.
 class GradsNoForceContiguousContextManager(ContextDecorator):
 def __enter__(self):
 # flake8: noqa: TOR901
-self.lib = torch.library.Library("_mylib", "FRAGMENT")
+self.lib = torch.library.Library("_test_aotdispatch_lib", "FRAGMENT")
 self.d = {
 torch.channels_last: 0,
 torch.contiguous_format: 0,
 }
+self.tangent_strides = []

-self.lib.define("foo(Tensor x) -> Tensor")
-self.lib.define("foo2(Tensor x) -> Tensor")
+self.lib.define("log_tangents_memory_format(Tensor x) -> Tensor")
+self.lib.define("log_tangents_memory_format_log(Tensor x) -> Tensor")

-def foo_impl(a):
+def log_tangents_memory_format_impl(a):
 return a.clone()

-def foo_meta(a):
+def log_tangents_memory_format_meta(a):
 return a.clone()

-def foo2_impl(x):
+def log_tangents_memory_format_log_impl(x):
 self.d[torch._prims_common.suggest_memory_format(x)] += 1
+self.tangent_strides.append(x.stride())
 return x.clone()

-def foo2_meta(a):
+def log_tangents_memory_format_log_meta(a):
 return a.clone()

 for backend in ["CPU", "CUDA"]:
-self.lib.impl("foo", foo_impl, backend)
-self.lib.impl("foo2", foo2_impl, backend)
+self.lib.impl(
+"log_tangents_memory_format", log_tangents_memory_format_impl, backend
+)
+self.lib.impl(
+"log_tangents_memory_format_log",
+log_tangents_memory_format_log_impl,
+backend,
+)

-self.lib.impl("foo", foo_meta, "Meta")
-self.lib.impl("foo2", foo2_meta, "Meta")
+self.lib.impl(
+"log_tangents_memory_format", log_tangents_memory_format_meta, "Meta"
+)
+self.lib.impl(
+"log_tangents_memory_format_log",
+log_tangents_memory_format_log_meta,
+"Meta",
+)

-def foo_bwd(ctx, grad):
-torch.ops._mylib.foo2(grad)
+def log_tangents_memory_format_bwd(ctx, grad):
+torch.ops._test_aotdispatch_lib.log_tangents_memory_format_log(grad)
 return grad.clone()

-torch.library.register_autograd("_mylib::foo", foo_bwd, lib=self.lib)
+torch.library.register_autograd(
+"_test_aotdispatch_lib::log_tangents_memory_format",
+log_tangents_memory_format_bwd,
+lib=self.lib,
+)

 from torch._higher_order_ops.effects import _EffectType, _register_effectful_op

-_register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
-_register_effectful_op(torch.ops._mylib.foo2.default, _EffectType.ORDERED)
+_register_effectful_op(
+torch.ops._test_aotdispatch_lib.log_tangents_memory_format.default,
+_EffectType.ORDERED,
+)
+_register_effectful_op(
+torch.ops._test_aotdispatch_lib.log_tangents_memory_format_log.default,
+_EffectType.ORDERED,
+)

 return self
@@ -6097,7 +6122,7 @@ class TestAOTModuleSimplified(AOTTestCase):
 z = y + 3
 y.mul_(2)
 r = self.conv(x)
-r = torch.ops._mylib.foo(r)
+r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r)
 return (
 r,
 r.transpose(0, 1),
@@ -6143,7 +6168,7 @@ class TestAOTModuleSimplified(AOTTestCase):

 def forward(self, x, y):
 r = self.conv(x)
-r = torch.ops._mylib.foo(r)
+r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r)
 return r, y + 1

 m = M()
@@ -6186,7 +6211,7 @@ class TestAOTModuleSimplified(AOTTestCase):

 def forward(self, x):
 r = self.conv(x)
-r = torch.ops._mylib.foo(r)
+r = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(r)
 return r

 m = M()
@@ -6466,6 +6491,116 @@ metadata incorrectly.
 _test_fn(fn_mutation)
 _test_fn(fn_inplace, check_backward=False)

+@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+@parametrize("dynamic_shapes", [True, False])
+@parametrize("test_subclasses", [True, False])
+@parametrize("device", ["cuda", "cpu"])
+def test_noncontig_nonmemformat_tangents(
+self, dynamic_shapes, test_subclasses, device
+):
+B = 2
+T = 4
+E = 6
+
+def fn(x):
+x = x + 1
+return x.transpose(1, 2)
+
+def _inp_dense():
+t = torch.randn(B, T, E, device=device, requires_grad=True)
+if dynamic_shapes:
+for i in range(t.ndim):
+torch._dynamo.mark_dynamic(t, i)
+return t
+
+def _inp_sc():
+return TwoTensor(_inp_dense(), _inp_dense())
+
+_inp = _inp_dense if not test_subclasses else _inp_sc
+
+comp_fn = torch.compile(fn, backend="aot_eager", fullgraph=True)
+
+def _tg3(y):
+t = torch.randn(
+2 * y.shape, dtype=y.dtype, layout=y.layout, device=y.device
+)
+return t.as_strided(y.shape, tuple(s * 2 for s in y.stride()))
+
+TEST_CASES = [
+(_inp, lambda y: torch.ones(y.shape, dtype=y.dtype, device=y.device)),
+# Memory overlap, dense tangent
+(
+_inp,
+lambda y: torch.tensor([1], dtype=y.dtype, device=y.device).as_strided(
+y.shape, (0,) * y.ndim
+),
+),
+# No memory overlap, not-dense tangent
+(_inp, _tg3),
+]
+
+for i, (inp_fn, tg_fn) in enumerate(TEST_CASES):
+ref_x = inp_fn()
+x = ref_x.detach().clone().requires_grad_()
+
+ref_y = fn(ref_x)
+
+y = comp_fn(x)
+self.assertEqual(ref_y, y)
+
+ref_tg = (
+tg_fn(ref_y)
+if not test_subclasses
+else TwoTensor(tg_fn(ref_y), tg_fn(ref_y))
+)
+tg = ref_tg.clone()
+
+ref_y.backward(ref_tg)
+y.backward(tg)
+
+self.assertEqual(ref_x.grad, x.grad)
+
+def test_flex_attn_noncontiguous_tangents(self):
+with GradsNoForceContiguousContextManager() as ctx:
+E = 16 # embedding dim
+H = 4 # number of heads
+
+@torch.compile(backend="aot_eager", fullgraph=True)
+def attn_fn(q, k, v):
+y = flex_attention(query=q, key=k, value=v)
+y = torch.ops._test_aotdispatch_lib.log_tangents_memory_format(y)
+return y
+
+class M(torch.nn.Module):
+def __init__(self):
+super().__init__()
+self.c_attn = torch.nn.Linear(E, 3 * E)
+
+def forward(self, x):
+B, T, E = x.size()
+q, k, v = self.c_attn(x).split(E, dim=2)
+k = k.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs)
+q = q.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs)
+v = v.view(B, T, H, E // H).transpose(1, 2) # (B, nh, T, hs)
+
+y = attn_fn(q, k, v)
+
+return y.transpose(1, 2).contiguous().view(B, T, E)
+
+m = M()
+B = 1
+T = 8
+
+def _inp():
+return torch.randn(B, T, E, requires_grad=True)
+
+x = _inp()
+y = m(x)
+y.backward(torch.ones_like(y).contiguous())
+
+self.assertEqual(1, len(ctx.tangent_strides))
+self.assertEqual((128, 4, 16, 1), ctx.tangent_strides[0])
+
 # entries in here don't work and need to be fixed.
 # Each one of these is a bug (or needs to be investigated)
@@ -6745,6 +6880,7 @@ class TestEagerFusionModuleInfo(AOTTestCase):


 instantiate_parametrized_tests(TestAOTAutograd)
+instantiate_parametrized_tests(TestAOTModuleSimplified)
 only_for = "cpu"
 instantiate_device_type_tests(
 TestPythonKey,
@@ -551,7 +551,7 @@ def check_model(

 # generate random unit norm gradients
 grads = [
-torch.rand(r.shape, device=r.device, dtype=r.dtype)
+torch.randn_like(r)
 for r in correct_flat
 if isinstance(r, torch.Tensor) and r.requires_grad
 ]
@@ -466,6 +466,33 @@ inductor_override_kwargs["cuda"] = {
 ("index_reduce.amax", f32): {"check_gradient": False},
 ("index_reduce.amax", f16): {"check_gradient": False},
 ("tanh", f16): {"atol": 1e-4, "rtol": 1e-2},
+("_unsafe_masked_index", f16): {
+"reference_in_float": True,
+"atol": 3e-4,
+"rtol": 2e-3,
+},
+("nn.functional.interpolate.linear", f16): {"reference_in_float": True},
+("nn.functional.prelu", f16): {
+"reference_in_float": True,
+"atol": 1e-3,
+"rtol": 4e-3,
+},
+("addmm", f16): {"reference_in_float": True},
+("logaddexp", f16): {"reference_in_float": True},
+("std_mean", f16): {"reference_in_float": True},
+("hypot", f16): {"reference_in_float": True, "atol": 3e-4, "rtol": 2e-3},
+("cummin", f16): {"reference_in_float": True, "atol": 5e-5, "rtol": 2e-3},
+("unfold_copy", f16): {"reference_in_float": True, "atol": 2e-5, "rtol": 1e-2},
+("nn.functional.upsample_bilinear", f16): {
+"reference_in_float": True,
+"atol": 1e-4,
+"rtol": 2e-3,
+},
+("nn.functional.embedding_bag", f16): {
+"reference_in_float": True,
+"atol": 1e-4,
+"rtol": 1e-2,
+},
 }

 inductor_override_kwargs["xpu"] = {
@@ -41,6 +41,7 @@ from .functional_utils import (
 from .schemas import (
 FunctionalTensorMetadataEq,
 InputAliasInfo,
+MemoryFormatMeta,
 MutationType,
 OutputAliasInfo,
 OutputType,
@@ -73,14 +74,14 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):

 out = x.detach()

-suggest_memory_format = torch._prims_common.suggest_memory_format
 is_subclass = is_traceable_wrapper_subclass(out)

-memory_format = suggest_memory_format(out)
+memory_format = MemoryFormatMeta.from_tensor(out)

-was = out
-out = out.contiguous(memory_format=memory_format)
-updated = out is not was
+if memory_format.memory_format is not None:
+was = out
+out = out.contiguous(memory_format=memory_format.memory_format)
+updated = was is not out

 # For subclass we keep memory format of outer strides at the beggining of the list
 out_memory_format = [memory_format] if is_subclass else memory_format
@@ -28,6 +28,7 @@ from .schemas import (
 BackwardSignature,
 GraphSignature,
 InputAliasInfo,
+MemoryFormatMeta,
 OutputAliasInfo,
 OutputType,
 ViewAndMutationMeta,
@@ -61,7 +62,9 @@ def remove_dupe_metadata(

 assert m.subclass_tangent_meta is not None
 subclass_tangent_meta = [
-PlainTensorMeta(0, memory_format=torch.contiguous_format)
+PlainTensorMeta(
+0, memory_format=MemoryFormatMeta(memory_format=torch.contiguous_format)
+)
 ] * len(filtered_inp_traced_tangents) + m.subclass_tangent_meta[num_data_mutations:]

 return ViewAndMutationMeta(
@@ -45,6 +45,7 @@ from .logging_utils import describe_input, format_guard_bug_msg, track_graph_com
 from .schemas import (
 AOTConfig,
 InputAliasInfo,
+MemoryFormatMeta,
 MutationType,
 OutputType,
 PlainTensorMeta,
@@ -1752,6 +1753,39 @@ def _backward_epilogue_functional(
 return out


+def coerce_to_expected_memory_format(x: torch.Tensor, memory_format: MemoryFormatMeta):
+if memory_format.memory_format is not None:
+# Coerce to torch.memory_format
+if not x.is_contiguous(memory_format=memory_format.memory_format):
+x = x.contiguous(memory_format=memory_format.memory_format)
+return x
+
+expected_size = memory_format.size
+assert expected_size is not None
+expected_stride = memory_format.stride
+assert expected_stride is not None
+# Expected size and stride are static ints
+# ok to use == to compare runtime tensor strides and shapes
+
+if x.shape == expected_size and x.stride() == expected_stride:
+# Runtime tangent size and stride are the same as expected, no need to coerce
+return x
+
+# Empty_strided creates a raw Tensor.
+# We are guranteed that only raw Tensors has expected size and stride.
+# Subclasses have only expected memory_format.
+restrided = torch.empty_strided(
+size=expected_size,
+stride=expected_stride,
+dtype=x.dtype,
+device=x.device,
+layout=x.layout,
+requires_grad=x.requires_grad,
+)
+restrided.copy_(x)
+return restrided
+
+
 # This is wrapped in a class just for namespacing purposes
 # No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly
 class AOTDispatchAutograd:
@@ -1761,8 +1795,8 @@
 return x, [x]

 if isinstance(x, FakeTensor):
-if not x.is_contiguous(memory_format=meta.memory_format):
-x = x.contiguous(memory_format=meta.memory_format)
+assert meta.memory_format
+x = coerce_to_expected_memory_format(x, meta.memory_format)
 return x, [x]

 expected_type: Optional[type] = torch.Tensor
@@ -1820,8 +1854,8 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 )

 # Coerce to expected memory format
-if not x.is_contiguous(memory_format=meta.memory_format):
-x = x.contiguous(memory_format=meta.memory_format)
+assert meta.memory_format
+x = coerce_to_expected_memory_format(x, meta.memory_format)

 if not is_traceable_wrapper_subclass(x):
 return x, [x]
@@ -7,7 +7,8 @@ input/output types, metadata, config, function signatures etc.
 import collections
 import dataclasses
 import functools
-from collections.abc import Iterable
+import itertools
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Callable, NewType, Optional, Union
@@ -155,10 +156,47 @@ class InputAliasInfo:
 return MutationType.MUTATED_OUT_GRAPH


+@dataclass
+class MemoryFormatMeta:
+# For static shapes we assume tangents have the same strideness as outputs
+size: Optional[Sequence[int]] = None
+stride: Optional[Sequence[int]] = None
+
+# For dynamic shapes we assume the same memory format: contiguous, channels_last etc.
+memory_format: Optional[torch.memory_format] = None
+
+@staticmethod
+def from_tensor(t: torch.Tensor) -> Optional["MemoryFormatMeta"]:
+# We only memorize expected memory format for
+# 1. Traceable wrapper subclasses
+# We can not create restrided subclass tensor, as torch.empty_strided works only with dense tensors.
+# 2. Dynamic shape tensors
+# Support for symbolic shapes is not implemented yet.
+use_memory_format: bool = is_traceable_wrapper_subclass(t)
+if not use_memory_format:
+is_static_shape = True
+for s in itertools.chain(t.shape, t.stride()):
+if not isinstance(s, int):
+is_static_shape = False
+break
+
+use_memory_format = not is_static_shape
+
+if use_memory_format:
+return MemoryFormatMeta(
+memory_format=torch._prims_common.suggest_memory_format(t),
+)
+
+return MemoryFormatMeta(
+size=t.size(),
+stride=t.stride(),
+)
+
+
 @dataclass
 class PlainTensorMeta:
 unwrapped_idx: int
-memory_format: Optional[torch.memory_format] = None
+memory_format: Optional[MemoryFormatMeta] = None


 @dataclass
@@ -204,7 +242,7 @@ class SubclassCreationMeta:

 # Used at runtime to determine the subclass type, so we don't need to save the original subclass
 original_subclass_type: Optional[type] = None
-memory_format: Optional[torch.memory_format] = None
+memory_format: Optional[MemoryFormatMeta] = None

 def compute_outer_size_and_stride(
 self,
@@ -46,16 +46,16 @@ def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
 return any_subclass_args or any_subclass_outputs


-suggest_memory_format = torch._prims_common.suggest_memory_format
+from .schemas import MemoryFormatMeta


 def maybe_suggest_memory_format(
 t, with_memory_format: bool
-) -> Optional[torch.memory_format]:
+) -> Optional[MemoryFormatMeta]:
 if not with_memory_format:
 return None

-return suggest_memory_format(t)
+return MemoryFormatMeta.from_tensor(t)


 def get_subclass_typing_container(
@@ -200,10 +200,18 @@ class NestedTensor(torch.Tensor):
 def _max_seqlen_tensor(self) -> Optional[torch.Tensor]:
 return self._metadata_cache.get("max_seqlen", None)

+@_max_seqlen_tensor.setter
+def _max_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
+self._metadata_cache["max_seqlen"] = val
+
 @property
 def _min_seqlen_tensor(self) -> Optional[torch.Tensor]:
 return self._metadata_cache.get("min_seqlen", None)

+@_min_seqlen_tensor.setter
+def _min_seqlen_tensor(self, val: Optional[torch.Tensor]) -> None:
+self._metadata_cache["min_seqlen"] = val
+
 # These are old private @property accessors that are kept around for internal BC
 # reasons. TODO: Remove these!
 @property
@@ -12696,10 +12696,26 @@ op_db: list[OpInfo] = [
 check_batched_gradgrad=True,
 sample_inputs_func=sample_inputs_linalg_cholesky_inverse,
 gradcheck_wrapper=gradcheck_wrapper_triangular_input_real_positive_diagonal,
-decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
+decorators=[
+skipCUDAIfNoMagma,
+skipCPUIfNoLapack,
+DecorateInfo(
+toleranceOverride({
+torch.float32: tol(atol=5e-03, rtol=1e-04)
+}),
+'TestCommon', device_type='cpu',
+),
+DecorateInfo(
+toleranceOverride({
+torch.float32: tol(atol=5e-03, rtol=1e-04)
+}),
+'TestEagerFusionOpInfo', device_type='cpu',
+),
+],
 skips=(
 # Strides are not the same! Original strides were ((4, 2, 1),) and strides are now ((4, 1, 2),)
-DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),)),
+DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),),
+),
 OpInfo('cholesky_solve',
 op=torch.cholesky_solve,
 dtypes=floating_and_complex_types(),