pytorch/test/dynamo/test_autograd_function.py
Oguz Ulgen 97891b184c [Dynamo] Trace autograd.function in dynamo when inputs require grad (#116358)
For training graphs (when inputs require grad), we previously speculated the forward and backward graphs to determine whether there were any graph breaks, side effects, etc., but we did not actually use these speculated graphs. We would just insert a call_function node into the graph and later rely on autograd's tracing.

This approach does not work for more general graphs, such as graphs that include user-defined Triton kernels, because autograd is not able to do the higher-order function conversion.

This PR speculates the forward and backward functions and emits them into a HOF that later gets used via a templating mechanism.

While working on this PR, I exposed some bugs in the current tracing: trampoline functions were losing source information, which resulted in incorrect graphs being produced. I have fixed these source-information bugs and killed the trampolines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116358
Approved by: https://github.com/jansel
2023-12-30 01:51:30 +00:00
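Below is a minimal sketch (not part of the test file that follows) of the scenario this PR targets, assuming only the public torch.compile API: a custom autograd.Function compiled with fullgraph=True on an input that requires grad, so both the forward and backward bodies have to be traced rather than deferred to autograd's own tracing.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    # With an input that requires grad, Dynamo now traces forward and backward
    # and emits them via a higher-order function instead of relying on a plain
    # call_function node plus autograd tracing (per the PR description above).
    return Square.apply(x)

x = torch.randn(3, requires_grad=True)
fn(x).sum().backward()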


# Owner(s): ["module: dynamo"]
import copy
import math
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
if HAS_CUDA:
import triton
from torch.testing._internal.triton_utils import add_kernel
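# A simple Function with a pass-through backward; exercised via Module1-Module4.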
class CustomFunc1(torch.autograd.Function):
@staticmethod
def forward(ctx, foo):
return foo + foo
@staticmethod
def backward(ctx, grad_output):
return grad_output
class CustomFunc3(torch.autograd.Function):
# Test that there is a graph break in the forward function
@staticmethod
def forward(ctx, foo):
result = foo + foo
torch._dynamo.graph_break()
result = result + foo
ctx.save_for_backward(result)
return result
@staticmethod
def backward(ctx, grad_output):
(result,) = ctx.saved_tensors
return grad_output * math.sqrt(result.numel())
class Module1(torch.nn.Module):
def forward(self, foo):
return CustomFunc1().apply(foo)
class Module2(torch.nn.Module):
def __init__(self):
super().__init__()
self.fn = CustomFunc1.apply
def forward(self, foo):
return self.fn(foo)
class Module3(torch.nn.Module):
def forward(self, foo):
return CustomFunc1().apply(foo)
class Module4(torch.nn.Module):
def __init__(self):
super().__init__()
self.fn = CustomFunc1.apply
def forward(self, foo):
return self.fn(foo)
class Module5(torch.nn.Module):
def forward(self, foo):
return CustomFunc3().apply(foo)
class Module6(torch.nn.Module):
def __init__(self):
super().__init__()
self.fn = CustomFunc3.apply
def forward(self, foo):
return self.fn(foo)
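# Uses the newer forward/setup_context split, where forward does not take ctx.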
class LinearFunction(torch.autograd.Function):
# Note that forward, setup_context, and backward are @staticmethods
@staticmethod
def forward(input, weight, bias):
output = input.mm(weight.t())
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
# inputs is a Tuple of all of the inputs passed to forward.
# output is the output of the forward().
def setup_context(ctx, inputs, output):
input, weight, bias = inputs
ctx.save_for_backward(input, weight, bias)
# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.mm(weight)
if ctx.needs_input_grad[1]:
grad_weight = grad_output.t().mm(input)
if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
return grad_input, grad_weight, grad_bias
class ModuleLinear(torch.nn.Module):
def forward(self, input, weight, bias=None):
return LinearFunction.apply(input, weight, bias)
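# set_materialize_grads(False) allows backward to receive None for outputs with no incoming gradient.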
class MaterializingGradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.set_materialize_grads(False)
return x.clone(), x.clone()
@staticmethod
def backward(ctx, grad_out1, grad_out2):
return grad_out1, grad_out2
class MaterializingGradModule(torch.nn.Module):
def forward(self, x):
return MaterializingGradFunction.apply(x)
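# print() in backward is unsupported under fullgraph compilation (see test_print_in_bwd).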
class CustomFuncBwdPrintGraphBreak(torch.autograd.Function):
@staticmethod
def forward(ctx, foo):
return torch.add(foo, foo)
@staticmethod
def backward(ctx, grad_output):
print("graph break!")
return grad_output
class CustomFuncBwdPrintModule(torch.nn.Module):
def forward(self, x):
return CustomFuncBwdPrintGraphBreak.apply(x)
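# backward returns a non-tensor (a stride tuple); see test_stride_in_bwd for the expected error.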
class CustomFuncStrideBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, foo):
return torch.add(foo, foo)
@staticmethod
def backward(ctx, grad_output):
return grad_output.stride()
class CustomFuncStrideModule(torch.nn.Module):
def forward(self, x):
return CustomFuncStrideBwd.apply(x)
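# Same as CustomFunc3 but without the graph break in forward.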
class CustomFuncSaveForBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, foo):
result = foo + foo
result = result + foo
ctx.save_for_backward(result)
return result
@staticmethod
def backward(ctx, grad_output):
(result,) = ctx.saved_tensors
return grad_output * math.sqrt(result.numel())
class SaveForBwdModule(torch.nn.Module):
def forward(self, foo):
return CustomFuncSaveForBwd().apply(foo)
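# The next two Functions exercise ctx.save_for_backward and ctx.mark_non_differentiable in both orders.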
class ContextSaveAndMark(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
with torch.no_grad():
ctx.save_for_backward(x)
ctx.mark_non_differentiable(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output
class ContextMarkAndSave(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
with torch.no_grad():
ctx.mark_non_differentiable(x)
ctx.save_for_backward(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output
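# Thin wrapper module that applies whichever autograd.Function it is constructed with.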
class ModuleWithGradFunc(torch.nn.Module):
def __init__(self, func):
super().__init__()
self.f = func.apply
def forward(self, x):
return self.f(x)
class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
# Sound behaviors, tested for working capture
def test_autograd_function_equivalence(self):
for grad in [True, False]:
for i in range(1, 5):
torch._dynamo.reset()
model = globals()[f"Module{i}"]()
opt_model = torch._dynamo.optimize("eager")(model)
self.assertTrue(
torch.allclose(
opt_model(torch.ones(2, 3, requires_grad=grad)),
torch.tensor([2.0], requires_grad=grad),
)
)
def test_autograd_function_has_graph_break(self):
for grad in [True, False]:
x = torch.randn(10, requires_grad=grad)
for model in [Module5(), Module6()]:
torch._dynamo.reset()
cnts = torch._dynamo.testing.CompileCounter()
opt_model = torch._dynamo.optimize(cnts)(model)
for _ in range(3):
ref = model(x)
res = opt_model(x)
self.assertTrue(torch.allclose(ref, res))
self.assertEqual(cnts.frame_count, 2)
def test_linear_setup_context(self):
model = ModuleLinear()
opt_model = torch._dynamo.optimize("eager")(model)
input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
optim_result = opt_model(input, weight)
eager_result = model(input, weight)
self.assertEqual(optim_result, eager_result)
def test_materialize_grad(self):
model = MaterializingGradModule()
opt_model = torch._dynamo.optimize("eager")(model)
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
optim_result = opt_model(x)
eager_result = model(x)
self.assertEqual(optim_result, eager_result)
def test_print_in_bwd(self):
model = CustomFuncBwdPrintModule()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, ".*BuiltinVariable\\(print\\).*"
):
opt_model(x)
def test_stride_in_bwd(self):
model = CustomFuncStrideModule()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
".*HigherOrderOperator body's output must consist of tensors only",
):
opt_model(x)
def test_enum_arg(self):
from enum import Enum
class SomeEnum(Enum):
A = 0
B = 1
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x, e):
if e is SomeEnum.A:
return x.sin()
else:
return x.cos()
@staticmethod
def backward(ctx, g):
return g
@torch.compile(backend="eager", fullgraph=True)
def f(x, enum):
output = Foo.apply(
x,
enum,
)
return output
x = torch.tensor([[1.0, 2, 3], [4, 5, 6]], requires_grad=True)
y = f(x, SomeEnum.A)
self.assertEqual(y, x.sin())
def test_save_for_bwd(self):
model = SaveForBwdModule()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
opt_model(x)
def test_allow_in_graph(self):
torch._dynamo.utils.counters.clear()
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.allow_in_graph
class AllowInGraphFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
torch._dynamo.graph_break()
ctx.x0 = x.size(0)
return x * 2
@staticmethod
def backward(ctx, grad_out):
return grad_out * ctx.x0
@torch.compile(backend=cnt, fullgraph=True)
def fn(x):
return AllowInGraphFunc.apply(x)
x = torch.rand(2, 3, requires_grad=True)
result = fn(x)
self.assertEqual(result, AllowInGraphFunc.apply(x))
self.assertEqual(cnt.frame_count, 1)
def test_once_differentiable(self):
from torch.autograd.function import once_differentiable
torch._dynamo.utils.counters.clear()
cnt = torch._dynamo.testing.CompileCounter()
class ScaleGradient(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
@once_differentiable
def backward(ctx, grad):
return grad * 0.5
@torch.compile(backend=cnt, fullgraph=True)
def fn(x):
return ScaleGradient.apply(x)
x = torch.randn(3, requires_grad=True)
result = fn(x)
self.assertEqual(result, ScaleGradient.apply(x))
self.assertEqual(cnt.frame_count, 1)
def test_classmethod(self):
class Shake(torch.autograd.Function):
@classmethod
def forward(cls, ctx, foo):
return foo + foo
@classmethod
def backward(cls, ctx, grad_output):
return grad_output
def f(x):
return Shake.apply(x)
x = torch.randn(4, 4, 4, 4, requires_grad=True)
opt_m = torch.compile(backend="eager")(f)
opt_m(x)
def test_function_context_save_and_mark(self):
mod = ModuleWithGradFunc(ContextSaveAndMark)
args, kwargs = ([torch.rand([1])], {})
before = mod(*args, **kwargs)
torch._dynamo.reset()
compiled_model = torch._dynamo.optimize("eager")(mod)
after = compiled_model(*args, **kwargs)
self.assertEqual(before, after)
def test_function_context_mark_and_save(self):
mod = ModuleWithGradFunc(ContextMarkAndSave)
args, kwargs = ([torch.rand([1])], {})
before = mod(*args, **kwargs)
torch._dynamo.reset()
compiled_model = torch._dynamo.optimize("eager")(mod)
after = compiled_model(*args, **kwargs)
self.assertEqual(before, after)
def test_multi_output(self):
torch._dynamo.utils.counters.clear()
cnt = torch._dynamo.testing.CompileCounter()
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone(), x.clone()
@staticmethod
def backward(ctx, grad1, grad2):
return grad1 + grad2
@torch.compile(backend=cnt, fullgraph=True)
def f(x):
return Foo.apply(x)
x = torch.randn(3, requires_grad=True)
result = f(x)
self.assertEqual(result, Foo.apply(x))
self.assertEqual(cnt.frame_count, 1)
def test_amp_custom_fwd_bwd(self):
torch._dynamo.utils.counters.clear()
cnt = torch._dynamo.testing.CompileCounter()
class MyMM(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
return a.mm(b)
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad):
a, b = ctx.saved_tensors
return grad.mm(b.t()), a.t().mm(grad)
@torch.compile(backend=cnt, fullgraph=True)
def fn(a, b):
return MyMM.apply(a, b)
a = torch.randn([64, 64], dtype=torch.float32, requires_grad=True)
grad = a.clone()
res = fn(a, a)
res.backward(grad)
self.assertEqual(res, MyMM.apply(a, a))
self.assertEqual(cnt.frame_count, 1)
def test_graph_break_if_lifted_free_variable(self):
torch._dynamo.utils.counters.clear()
cnt = torch._dynamo.testing.CompileCounter()
delta = torch.randn(3)
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone(), (x + delta).clone()
@staticmethod
def backward(ctx, grad1, grad2):
return grad1 + grad2
@torch.compile(backend=cnt)
def f(x):
return Foo.apply(x)
x = torch.randn(3, requires_grad=True)
result = f(x)
self.assertEqual(result, Foo.apply(x))
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(
list(torch._dynamo.utils.counters["graph_break"].values()), [1]
)
def test_function_with_bound_free_variable(self):
class LowerBound(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, bound):
ctx.save_for_backward(inputs, inputs.new_ones(1) * bound)
return inputs.clamp(min=bound)
@staticmethod
def backward(ctx, grad_output):
inputs, bound = ctx.saved_tensors
return (inputs >= bound) * grad_output, None
class MyMod(torch.nn.Module):
def __init__(self):
super().__init__()
self.gamma = torch.nn.Parameter(torch.rand([4, 128, 32, 32]))
def forward(self, x):
gamma = LowerBound.apply(self.gamma, 1)
return x + gamma
mod = MyMod()
args, kwargs = ([torch.rand([4, 128, 32, 32])], {})
before = mod(*args, **kwargs)
compiled_model = torch._dynamo.optimize("eager")(mod)
after = compiled_model(*args, **kwargs)
self.assertEqual(before, after)
# I pulled all of these test cases from test_autograd.py
# In the future, we should make the Dynamo test suite actually
# run on test_autograd.py (it's disabled right now) and delete these.
def test_smoke_from_test_autograd(self):
class Func(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
out0 = x.clone()
out1 = x.clone()
ctx.mark_non_differentiable(out1)
ctx._materialize_non_diff_grads = False
return out0, out1
@staticmethod
def backward(ctx, g0, g1):
assert g1 is None
return g0
def mult1(x):
return x.prod(dim=-1).prod(dim=-1)
class Mult(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = mult1(x)
ctx.save_for_backward(x, y)
return y
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
return (grad_output * y)[:, None, None] / x
mult2 = Mult.apply
class Double(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = x**2
ctx.save_for_backward(x, y)
return y
@staticmethod
def backward(ctx, grad_output):
x, _ = ctx.saved_tensors
return grad_output * 2 * x
# this is equivalent, but uses the output of .forward() in .backward()
class Double2(Double):
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
return grad_output * 2 * y / x
double = Double.apply
double2 = Double2.apply
class Identity(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
return a, a + b
@staticmethod
def backward(ctx, grad_a, grad_b):
return grad_a + grad_b, grad_b
class MyFunc2(torch.autograd.Function):
@staticmethod
def forward(ctx, inp):
return inp.clone()
@staticmethod
def backward(ctx, gO):
return torch.tensor(float("nan")).expand(10, 10)
def run_fn(a):
out = MyFunc2.apply(a)
return out.sum()
class MyFn(torch.autograd.Function):
@staticmethod
def forward(ctx, inp):
return inp.view_as(inp)
@staticmethod
def backward(ctx, grad):
return grad
class MyAdder(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
a.add_(b)
ctx.mark_dirty(a)
return a
@staticmethod
def backward(ctx, grad):
return grad, grad
class InplaceMul(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
result = x.mul_(2)
ctx.mark_dirty(result)
return result
@staticmethod
def backward(ctx, grad_output):
pass
@staticmethod
def jvp(ctx, x_t):
if jvp_err:
return x_t
else:
return x_t.mul_(2)
class MyFn2(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
return x + y, x
@staticmethod
def vjp(ctx, gO1, gO2):
return gO1 + gO2, gO1
@staticmethod
def jvp(ctx, x_t, y_t):
return x_t + y_t, fn(x_t)
class MyFn3(torch.autograd.Function):
@staticmethod
def forward(ctx, inp, inplace):
view = inp.clone()[:3]
if inplace:
view += 2
return view
@staticmethod
def backward(ctx, grad):
return grad, None
def test():
a = torch.tensor(1.0, requires_grad=True)
out = Func.apply(a)[0]
out.backward()
x = torch.ones(2, 4, 4).requires_grad_()
mult2(x)
x = torch.tensor(2).double().requires_grad_()
double(x)
double2(x)
x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
q, p = Identity.apply(x, y)
a = torch.rand(1, 2)
b = torch.rand(1, requires_grad=True)
view_a = MyFn.apply(a)
a = torch.ones(2, requires_grad=True)
b = torch.ones(2, requires_grad=True)
c = MyAdder.apply(a.clone(), b)
c.sum().backward()
z = torch.tensor(1.0, requires_grad=True)
x = z.clone()
y = InplaceMul.apply(x)
a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
c = torch.tensor(1.0, dtype=torch.double)
d = torch.tensor(1.0, dtype=torch.double)
MyFn2.apply(a, b)
MyFn2.apply(c, d)
base = torch.rand(10, requires_grad=True)
foo = MyFn3.apply(base, False)
test()
opt_test = torch._dynamo.optimize("eager")(test)
opt_test()
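# A wrapper tensor subclass is created inside forward() and must survive compilation.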
def test_tensor_subclass_intermediary_input(self):
class FooTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, config, scale):
self = torch.Tensor._make_wrapper_subclass(
cls,
config[0],
strides=config[1],
storage_offset=config[2],
dtype=config[3],
layout=config[4],
requires_grad=config[5],
device=data.device,
)
self._data = data
self._config = config
self._scale = scale
return self
def __repr__(self):
return "FooTensor"
def __tensor_flatten__(self):
return ("_data",), (
self._config,
self._scale,
)
@staticmethod
def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride):
return FooTensor(tensors["_data"], metadatas[0], metadatas[1])
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# handling clone and view here is only so that dynamo fakefication passes; it's
# not intended to handle user code
if func == torch.ops.aten.clone.default:
return FooTensor(
args[0]._data.clone(), args[0]._config, args[0]._scale
)
elif func == torch.ops.aten.view.default:
new_data = args[0]._data.view(*args[1:])
return FooTensor(new_data, args[0]._config, args[0]._scale)
raise NotImplementedError()
__torch_function__ = torch._C._disabled_torch_function_impl
class foo_autograd_fn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# access some data from `x`, where `x` is a tensor subclass
x2 = x._data + 1.0
# create and return a tensor subclass from within a torch.autograd.Function
x3 = FooTensor(x2, x._config, x._scale)
return x3._data
@staticmethod
def backward(ctx, g):
return g
x_ref = torch.randn(4, 4).requires_grad_(True)
x = copy.deepcopy(x_ref)
scale = torch.tensor(1.0)
# Weird that this is needed, but not having this breaks a lot of things
torch._dynamo.allow_in_graph(FooTensor)
def foo(x, scale):
config = (
x.size(),
x.stride(),
x.storage_offset(),
x.dtype,
x.layout,
x.requires_grad,
)
x = FooTensor(x, config, scale)
x = foo_autograd_fn.apply(x)
return x
y_ref = foo(x_ref, scale)
y_ref.sum().backward()
foo_opt = torch.compile(foo, backend="eager")
y = foo_opt(x, scale)
y.sum().backward()
self.assertEqual(y, y_ref)
self.assertEqual(x.grad, x_ref.grad)
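# Regression test for #111031: a SymInt (x.size(0)) stashed on ctx under dynamic shapes.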
def test_smuggle_symint_issue_111031(self):
from torch.autograd import Function
class Foo(Function):
@staticmethod
def forward(ctx, x):
ctx.x0 = x.size(0)
return x * 2
@staticmethod
def backward(ctx, grad_out):
return grad_out * ctx.x0
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts, fullgraph=True, dynamic=True)
def foo(x):
return Foo.apply(x)
foo(torch.randn(2, requires_grad=True))
self.assertEqual(cnts.frame_count, 1)
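# Non-tensor state (a Python list) and a tensor are stashed directly on ctx.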
def test_smuggle_tensor_and_complex_structures(self):
from torch.autograd import Function
class Foo(Function):
@staticmethod
def forward(ctx, x):
ctx.x0 = x
ctx.x1 = [1, 2, 3]
return x * 2
@staticmethod
def backward(ctx, grad_out):
x0mul = grad_out * ctx.x0
for i in ctx.x1:
x0mul = (x0mul * i) + x0mul
return x0mul
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts, fullgraph=True, dynamic=True)
def foo(x):
return Foo.apply(x)
foo(torch.randn(2, requires_grad=True))
self.assertEqual(cnts.frame_count, 1)
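# The Triton tests below exercise the case described in the PR summary above: the
# Function's forward launches a user-defined Triton kernel, so Dynamo must trace
# the forward/backward bodies itself instead of deferring to autograd.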
@requires_cuda()
@skipIfRocm
def test_triton_kernel_basic(self):
class Add(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
return output
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
return x * grad_output, y * grad_output
@torch.compile(fullgraph=True, backend="inductor")
def f(x, y):
z = Add.apply(x, y)
return z
x = torch.randn(10, device="cuda", requires_grad=True)
y = torch.randn(10, device="cuda", requires_grad=True)
z = f(x, y)
loss = z.sum()
loss.backward()
self.assertEqual(x + y, z)
@requires_cuda()
@skipIfRocm
def test_triton_kernel_multiple_out(self):
class Add(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
ctx.t1 = x
ctx.t2 = y
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: ( # noqa: E731
triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
return output, x
@staticmethod
def backward(ctx, grad_output, old_x):
x, y = ctx.saved_tensors
x1 = ctx.t1
y1 = ctx.t2
return old_x * x * x1 * grad_output, y * y1 * grad_output
@torch.compile(fullgraph=True, backend="inductor")
def f(x, y):
z = Add.apply(x, y)
return z
x = torch.randn(10, device="cuda", requires_grad=True)
y = torch.randn(10, device="cuda", requires_grad=True)
z, _ = f(x, y)
loss = z.sum()
loss.backward()
self.assertEqual(x + y, z)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()