For training graphs (when inputs require grad), we previously speculated the forward and backward graphs to check for graph breaks, side effects, and so on, but never actually used these speculated graphs. We would just insert a call_function node into the graph and later rely on autograd's tracing. That approach does not work for more general graphs, such as graphs that include user-defined Triton kernels, because autograd cannot perform the higher-order function conversion. This PR speculates the forward and backward functions and emits them in a HOF that is later used via a templating mechanism. While working on this PR, I exposed some bugs in the current tracing: trampoline functions were losing source information, which produced incorrect graphs. I fixed these source-information bugs and removed the trampolines.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116358
Approved by: https://github.com/jansel
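For context, the user-facing pattern this change enables is compiling a torch.autograd.Function whose forward launches a user-defined Triton kernel under torch.compile(fullgraph=True). Below is a minimal sketch of that pattern, assuming a CUDA device and Triton are available; it reuses the add_kernel helper from torch.testing._internal.triton_utils and mirrors test_triton_kernel_basic in the file that follows, which is the authoritative coverage.

    import torch
    import triton
    from torch.testing._internal.triton_utils import add_kernel

    class Add(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            ctx.save_for_backward(x, y)
            out = torch.zeros_like(x)
            n_elements = out.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            # The forward launches a user-defined Triton kernel; with this PR,
            # Dynamo speculates this forward (and the backward) into a HOF
            # instead of deferring to autograd's tracing.
            add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=16)
            return out

        @staticmethod
        def backward(ctx, grad_output):
            # For z = x + y, the gradient w.r.t. both inputs is grad_output.
            return grad_output, grad_output

    @torch.compile(fullgraph=True, backend="inductor")
    def f(x, y):
        return Add.apply(x, y)

    x = torch.randn(10, device="cuda", requires_grad=True)
    y = torch.randn(10, device="cuda", requires_grad=True)
    f(x, y).sum().backward()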
# Owner(s): ["module: dynamo"]

import copy
import math

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda

if HAS_CUDA:
    import triton
    from torch.testing._internal.triton_utils import add_kernel


class CustomFunc1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo + foo

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class CustomFunc3(torch.autograd.Function):
    # Test there is graph break in forward function
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        torch._dynamo.graph_break()
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class Module1(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc1().apply(foo)


class Module2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        return self.fn(foo)


class Module3(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc1().apply(foo)


class Module4(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        return self.fn(foo)


class Module5(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc3().apply(foo)


class Module6(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fn = CustomFunc3.apply

    def forward(self, foo):
        return self.fn(foo)


class LinearFunction(torch.autograd.Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class ModuleLinear(torch.nn.Module):
    def forward(self, input, weight, bias=None):
        return LinearFunction.apply(input, weight, bias)


class MaterializingGradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_out1, grad_out2):
        return grad_out1, grad_out2


class MaterializingGradModule(torch.nn.Module):
    def forward(self, x):
        return MaterializingGradFunction.apply(x)


class CustomFuncBwdPrintGraphBreak(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        print("graph break!")
        return grad_output


class CustomFuncBwdPrintModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncBwdPrintGraphBreak.apply(x)


class CustomFuncStrideBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.stride()


class CustomFuncStrideModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncStrideBwd.apply(x)


class CustomFuncSaveForBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class SaveForBwdModule(torch.nn.Module):
    def forward(self, foo):
        return CustomFuncSaveForBwd().apply(foo)


class ContextSaveAndMark(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.save_for_backward(x)
            ctx.mark_non_differentiable(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ContextMarkAndSave(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.mark_non_differentiable(x)
            ctx.save_for_backward(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ModuleWithGradFunc(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.f = func.apply

    def forward(self, x):
        return self.f(x)


class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
    # Sound behaviors, tested for working capture
    def test_autograd_function_equivalence(self):
        for grad in [True, False]:
            for i in range(1, 5):
                torch._dynamo.reset()
                model = globals()[f"Module{i}"]()
                opt_model = torch._dynamo.optimize("eager")(model)
                self.assertTrue(
                    torch.allclose(
                        opt_model(torch.ones(2, 3, requires_grad=grad)),
                        torch.tensor([2.0], requires_grad=grad),
                    )
                )

    def test_autograd_function_has_graph_break(self):
        for grad in [True, False]:
            x = torch.randn(10, requires_grad=grad)
            for model in [Module5(), Module6()]:
                torch._dynamo.reset()
                cnts = torch._dynamo.testing.CompileCounter()
                opt_model = torch._dynamo.optimize(cnts)(model)
                for _ in range(3):
                    ref = model(x)
                    res = opt_model(x)
                    self.assertTrue(torch.allclose(ref, res))
                self.assertEqual(cnts.frame_count, 2)

    def test_linear_setup_context(self):
        model = ModuleLinear()
        opt_model = torch._dynamo.optimize("eager")(model)
        input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
        optim_result = opt_model(input, weight)
        eager_result = model(input, weight)
        self.assertEqual(optim_result, eager_result)

    def test_materialize_grad(self):
        model = MaterializingGradModule()
        opt_model = torch._dynamo.optimize("eager")(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        optim_result = opt_model(x)
        eager_result = model(x)
        self.assertEqual(optim_result, eager_result)

    def test_print_in_bwd(self):
        model = CustomFuncBwdPrintModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, ".*BuiltinVariable\\(print\\).*"
        ):
            opt_model(x)

    def test_stride_in_bwd(self):
        model = CustomFuncStrideModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            ".*HigherOrderOperator body's output must consist of tensors only",
        ):
            opt_model(x)

    def test_enum_arg(self):
        from enum import Enum

        class SomeEnum(Enum):
            A = 0
            B = 1

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, e):
                if e is SomeEnum.A:
                    return x.sin()
                else:
                    return x.cos()

            @staticmethod
            def backward(ctx, g):
                return g

        @torch.compile(backend="eager", fullgraph=True)
        def f(x, enum):
            output = Foo.apply(
                x,
                enum,
            )
            return output

        x = torch.tensor([[1.0, 2, 3], [4, 5, 6]], requires_grad=True)
        y = f(x, SomeEnum.A)
        self.assertEqual(y, x.sin())

    def test_save_for_bwd(self):
        model = SaveForBwdModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        opt_model(x)

    def test_allow_in_graph(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.allow_in_graph
        class AllowInGraphFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                torch._dynamo.graph_break()
                ctx.x0 = x.size(0)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.x0

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return AllowInGraphFunc.apply(x)

        x = torch.rand(2, 3, requires_grad=True)
        result = fn(x)

        self.assertEqual(result, AllowInGraphFunc.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_once_differentiable(self):
        from torch.autograd.function import once_differentiable

        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class ScaleGradient(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            @once_differentiable
            def backward(ctx, grad):
                return grad * 0.5

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return ScaleGradient.apply(x)

        x = torch.randn(3, requires_grad=True)
        result = fn(x)

        self.assertEqual(result, ScaleGradient.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_classmethod(self):
        class Shake(torch.autograd.Function):
            @classmethod
            def forward(cls, ctx, foo):
                return foo + foo

            @classmethod
            def backward(cls, ctx, grad_output):
                return grad_output

        def f(x):
            return Shake.apply(x)

        x = torch.randn(4, 4, 4, 4, requires_grad=True)
        opt_m = torch.compile(backend="eager")(f)
        opt_m(x)

    def test_function_context_save_and_mark(self):
        mod = ModuleWithGradFunc(ContextSaveAndMark)
        args, kwargs = ([torch.rand([1])], {})
        before = mod(*args, **kwargs)

        torch._dynamo.reset()
        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    def test_function_context_mark_and_save(self):
        mod = ModuleWithGradFunc(ContextMarkAndSave)
        args, kwargs = ([torch.rand([1])], {})
        before = mod(*args, **kwargs)

        torch._dynamo.reset()
        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    def test_multi_output(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone(), x.clone()

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1 + grad2

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return Foo.apply(x)

        x = torch.randn(3, requires_grad=True)
        result = f(x)

        self.assertEqual(result, Foo.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_amp_custom_fwd_bwd(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.cuda.amp.custom_fwd
            def forward(ctx, a, b):
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.cuda.amp.custom_bwd
            def backward(ctx, grad):
                a, b = ctx.saved_tensors
                return grad.mm(b.t()), a.t().mm(grad)

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a, b):
            return MyMM.apply(a, b)

        a = torch.randn([64, 64], dtype=torch.float32, requires_grad=True)
        grad = a.clone()
        res = fn(a, a)
        res.backward(grad)

        self.assertEqual(res, MyMM.apply(a, a))
        self.assertEqual(cnt.frame_count, 1)

    def test_graph_break_if_lifted_free_variable(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()
        delta = torch.randn(3)

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone(), (x + delta).clone()

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1 + grad2

        @torch.compile(backend=cnt)
        def f(x):
            return Foo.apply(x)

        x = torch.randn(3, requires_grad=True)
        result = f(x)

        self.assertEqual(result, Foo.apply(x))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(
            list(torch._dynamo.utils.counters["graph_break"].values()), [1]
        )

    def test_function_with_bound_free_variable(self):
        class LowerBound(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inputs, bound):
                ctx.save_for_backward(inputs, inputs.new_ones(1) * bound)
                return inputs.clamp(min=bound)

            @staticmethod
            def backward(ctx, grad_output):
                inputs, bound = ctx.saved_tensors
                return (inputs >= bound) * grad_output, None

        class MyMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gamma = torch.nn.Parameter(torch.rand([4, 128, 32, 32]))

            def forward(self, x):
                gamma = LowerBound.apply(self.gamma, 1)
                return x + gamma

        mod = MyMod()
        args, kwargs = ([torch.rand([4, 128, 32, 32])], {})
        before = mod(*args, **kwargs)

        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    # I pulled all of these test cases from test_autograd.py
    # In the future, we should make the Dynamo test suite actually
    # run on test_autograd.py (it's disabled right now) and delete these.
    def test_smoke_from_test_autograd(self):
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out0 = x.clone()
                out1 = x.clone()
                ctx.mark_non_differentiable(out1)
                ctx._materialize_non_diff_grads = False
                return out0, out1

            @staticmethod
            def backward(ctx, g0, g1):
                assert g1 is None
                return g0

        def mult1(x):
            return x.prod(dim=-1).prod(dim=-1)

        class Mult(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = mult1(x)
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return (grad_output * y)[:, None, None] / x

        mult2 = Mult.apply

        class Double(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, _ = ctx.saved_tensors
                return grad_output * 2 * x

        # this is equivalent, but uses the output of .forward() in .backward()
        class Double2(Double):
            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return grad_output * 2 * y / x

        double = Double.apply
        double2 = Double2.apply

        class Identity(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                return a, a + b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                return grad_a + grad_b, grad_b

        class MyFunc2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, gO):
                return torch.tensor(float("nan")).expand(10, 10)

        def run_fn(a):
            out = MyFunc2.apply(a)
            return out.sum()

        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.view_as(inp)

            @staticmethod
            def backward(ctx, grad):
                return grad

        class MyAdder(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a

            @staticmethod
            def backward(ctx, grad):
                return grad, grad

        class InplaceMul(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                result = x.mul_(2)
                ctx.mark_dirty(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                pass

            @staticmethod
            def jvp(ctx, x_t):
                if jvp_err:
                    return x_t
                else:
                    return x_t.mul_(2)

        class MyFn2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y, x

            @staticmethod
            def vjp(ctx, gO1, gO2):
                return gO1 + gO2, gO1

            @staticmethod
            def jvp(ctx, x_t, y_t):
                return x_t + y_t, fn(x_t)

        class MyFn3(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp, inplace):
                view = inp.clone()[:3]
                if inplace:
                    view += 2
                return view

            @staticmethod
            def backward(ctx, grad):
                return grad, None

        def test():
            a = torch.tensor(1.0, requires_grad=True)
            out = Func.apply(a)[0]
            out.backward()

            x = torch.ones(2, 4, 4).requires_grad_()
            mult2(x)

            x = torch.tensor(2).double().requires_grad_()
            double(x)
            double2(x)

            x = torch.randn(5, 5, requires_grad=True)
            y = torch.randn(5, 5, requires_grad=True)
            q, p = Identity.apply(x, y)

            a = torch.rand(1, 2)
            b = torch.rand(1, requires_grad=True)
            view_a = MyFn.apply(a)

            a = torch.ones(2, requires_grad=True)
            b = torch.ones(2, requires_grad=True)
            c = MyAdder.apply(a.clone(), b)
            c.sum().backward()

            z = torch.tensor(1.0, requires_grad=True)
            x = z.clone()
            y = InplaceMul.apply(x)

            a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            c = torch.tensor(1.0, dtype=torch.double)
            d = torch.tensor(1.0, dtype=torch.double)
            MyFn2.apply(a, b)
            MyFn2.apply(c, d)

            base = torch.rand(10, requires_grad=True)
            foo = MyFn3.apply(base, False)

        test()
        opt_test = torch._dynamo.optimize("eager")(test)
        opt_test()

    def test_tensor_subclass_intermediary_input(self):
        class FooTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, data, config, scale):
                self = torch.Tensor._make_wrapper_subclass(
                    cls,
                    config[0],
                    strides=config[1],
                    storage_offset=config[2],
                    dtype=config[3],
                    layout=config[4],
                    requires_grad=config[5],
                    device=data.device,
                )
                self._data = data
                self._config = config
                self._scale = scale
                return self

            def __repr__(self):
                return "FooTensor"

            def __tensor_flatten__(self):
                return ("_data",), (
                    self._config,
                    self._scale,
                )

            @staticmethod
            def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride):
                return FooTensor(tensors["_data"], metadatas[0], metadatas[1])

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs=None):
                # handling clone and view is so dynamo fakefication passes, it's not
                # intended to be handling user code
                if func == torch.ops.aten.clone.default:
                    return FooTensor(
                        args[0]._data.clone(), args[0]._config, args[0]._scale
                    )
                elif func == torch.ops.aten.view.default:
                    new_data = args[0]._data.view(*args[1:])
                    return FooTensor(new_data, args[0]._config, args[0]._scale)

                raise NotImplementedError()

            __torch_function__ = torch._C._disabled_torch_function_impl

        class foo_autograd_fn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # access some data from `x`, where `x` is a tensor subclass
                x2 = x._data + 1.0
                # create and return a tensor subclass from within a torch.autograd.Function
                x3 = FooTensor(x2, x._config, x._scale)
                return x3._data

            @staticmethod
            def backward(ctx, g):
                return g

        x_ref = torch.randn(4, 4).requires_grad_(True)
        x = copy.deepcopy(x_ref)
        scale = torch.tensor(1.0)
        # Weird that this is needed, but not having this breaks a lot of things
        torch._dynamo.allow_in_graph(FooTensor)

        def foo(x, scale):
            config = (
                x.size(),
                x.stride(),
                x.storage_offset(),
                x.dtype,
                x.layout,
                x.requires_grad,
            )
            x = FooTensor(x, config, scale)
            x = foo_autograd_fn.apply(x)
            return x

        y_ref = foo(x_ref, scale)
        y_ref.sum().backward()

        foo_opt = torch.compile(foo, backend="eager")
        y = foo_opt(x, scale)
        y.sum().backward()

        self.assertEqual(y, y_ref)
        self.assertEqual(x.grad, x_ref.grad)

    def test_smuggle_symint_issue_111031(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.x0 = x.size(0)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.x0

        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True, dynamic=True)
        def foo(x):
            return Foo.apply(x)

        foo(torch.randn(2, requires_grad=True))
        self.assertEqual(cnts.frame_count, 1)

    def test_smuggle_tensor_and_complex_structures(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.x0 = x
                ctx.x1 = [1, 2, 3]
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                x0mul = grad_out * ctx.x0
                for i in ctx.x1:
                    x0mul = (x0mul * i) + x0mul
                return x0mul

        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True, dynamic=True)
        def foo(x):
            return Foo.apply(x)

        foo(torch.randn(2, requires_grad=True))
        self.assertEqual(cnts.frame_count, 1)

    @requires_cuda()
    @skipIfRocm
    def test_triton_kernel_basic(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return x * grad_output, y * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)

    @requires_cuda()
    @skipIfRocm
    def test_triton_kernel_multiple_out(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                ctx.t1 = x
                ctx.t2 = y
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output, x

            @staticmethod
            def backward(ctx, grad_output, old_x):
                x, y = ctx.saved_tensors
                x1 = ctx.t1
                y1 = ctx.t2
                return old_x * x * x1 * grad_output, y * y1 * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z, _ = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()