Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
* Deprecate ctx.saved_variables via a Python warning.
The warning advises replacing saved_variables with saved_tensors.
Also replaces all instances of ctx.saved_variables with ctx.saved_tensors in the
codebase.
Test by running:
```
import torch
from torch.autograd import Function


class MyFunction(Function):
    @staticmethod
    def forward(ctx, tensor1, tensor2):
        ctx.save_for_backward(tensor1, tensor2)
        return tensor1 + tensor2

    @staticmethod
    def backward(ctx, grad_output):
        # Deprecated accessor: reading it should trigger the warning.
        var1, var2 = ctx.saved_variables
        return (grad_output, grad_output)


x = torch.randn((3, 3), requires_grad=True)
y = torch.randn((3, 3), requires_grad=True)
# Call apply on the class; Function subclasses are not meant to be instantiated.
MyFunction.apply(x, y).sum().backward()
```
and assert that the deprecation warning shows up.
* Address comments
* Add deprecation test for saved_variables
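
The check described above can be automated along the following lines. This is only a sketch, not the test added in this commit: `AddDeprecated`, `AddUpdated`, and `saved_variables_warnings` are hypothetical names, and it assumes a PyTorch build in which `ctx.saved_variables` still exists and emits a warning whose message mentions `saved_variables`.

```python
import warnings

import torch
from torch.autograd import Function


class AddDeprecated(Function):
    """Hypothetical Function that still reads the deprecated attribute."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a + b

    @staticmethod
    def backward(ctx, grad_output):
        _a, _b = ctx.saved_variables  # deprecated spelling
        return grad_output, grad_output


class AddUpdated(Function):
    """Same Function, migrated to the recommended attribute."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a + b

    @staticmethod
    def backward(ctx, grad_output):
        _a, _b = ctx.saved_tensors  # recommended replacement
        return grad_output, grad_output


def saved_variables_warnings(fn):
    """Run forward and backward through fn; return warnings naming saved_variables."""
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.randn(3, 3, requires_grad=True)
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including categories that are ignored by default.
        warnings.simplefilter("always")
        fn.apply(x, y).sum().backward()
    return [w for w in caught if "saved_variables" in str(w.message)]
```

Calling `saved_variables_warnings(AddDeprecated)` should return at least one recorded warning, while `saved_variables_warnings(AddUpdated)` should return an empty list, which is the behavior the migration to `ctx.saved_tensors` is meant to achieve.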

| Name |
|---|
| autograd.rst |
| broadcasting.rst |
| cuda.rst |
| extending.rst |
| faq.rst |
| multiprocessing.rst |
| serialization.rst |