Make PT2 compile backprop through custom op without autograd key a hard error (#166367)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166367
Approved by: https://github.com/bdhirsh
This commit is contained in:
parent 8f40a0c634
commit 4acc66f119
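The change in a nutshell: when PT2 traces a backward through an operator that has no kernel registered to the Autograd key, compilation now fails loudly instead of warning and silently detaching gradients. Below is a minimal sketch of the scenario, assuming a hypothetical op mylib::bar with only a CPU kernel and a fake kernel; none of these names appear in the commit itself.

import torch

torch.library.define("mylib::bar", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::bar", "CPU")
def bar_impl(x):
    # CPU kernel only; nothing is registered to the Autograd key.
    return x.clone() * 2

@torch.library.register_fake("mylib::bar")
def bar_fake(x):
    # Fake (meta) kernel so the op can be traced under torch.compile.
    return torch.empty_like(x)

@torch.compile(backend="aot_eager", fullgraph=True)
def f(x):
    return torch.ops.mylib.bar(x).sum()

x = torch.randn(3, requires_grad=True)
# Previously: tracing the backward only warned and produced silently detached
# gradients. With this commit, AOT tracing runs under
# autograd_fallback_mode("error"), so this should now raise instead.
f(x)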
@@ -414,6 +414,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
         with _dynamo_dist_per_rank_init(self.rank, self.world_size):
             model = Model().to(self.device)
+            model.emb.weight.requires_grad = False
             model_compiled = torch.compile(model)
             inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
             out = model_compiled(inp, self.world_size, **self.get_world_trs())
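The test tweaks in this hunk and in the runtime-estimate hunks further down follow one pattern: freeze the parameters so the compiled graph no longer needs a backward through an op that lacks an autograd kernel, which would now hard-error. A small illustrative sketch; the Model here is a stand-in, not the test's actual module.

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(8, 4)

    def forward(self, x):
        return self.emb(x).sum()

model = Model()
model.emb.weight.requires_grad = False   # single-parameter variant used above
for p in model.parameters():             # whole-module variant used in later hunks
    p.requires_grad = False
compiled = torch.compile(model)          # no backward graph is required now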
@@ -1340,13 +1341,11 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         assert counter.op_count == 3  # It generates 2 getattr to unpack the array
         assert same(out, correct)

+    # This doesn't work in all cases, and now we properly loudly error.
+    # See: https://github.com/pytorch/pytorch/issues/151240
+    # When differentiable funcols are implemented can revert.
+    @unittest.expectedFailure
     def test_backwards(self):
-        """
-        It's probably not that common to need backwards support for collectives.
-
-        However, I wanted to at least see if it was possible to support it as a design goal.
-        """
         def func(inp):
             ar = _functional_collectives.all_reduce(inp, "sum", "0")
             return ar
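The new comments and @unittest.expectedFailure record the consequence of the hard error: functional collectives currently have no autograd formula, so a compiled backward through all_reduce now raises rather than silently producing detached gradients (issue 151240). A hedged sketch of the shape of such a test follows; process-group setup is assumed and omitted.

import unittest

import torch
import torch.distributed._functional_collectives as _functional_collectives

class CollectiveBackwardExample(unittest.TestCase):
    @unittest.expectedFailure
    def test_backwards(self):
        def func(inp):
            # all_reduce has no autograd formula, so backprop hits the
            # not-implemented autograd fallback, which now errors.
            return _functional_collectives.all_reduce(inp, "sum", "0")

        compiled = torch.compile(func)
        inp = torch.ones(4, requires_grad=True)
        out = compiled(inp)
        out.sum().backward()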
@@ -9757,6 +9757,17 @@ def ___make_guard_fn():
             def foo_impl(x, y):
                 return torch.cat([x, y])

+            def setup_context(ctx, inputs, output):
+                (x, _) = inputs
+                ctx.xs = x.shape[0]
+
+            def foo_backward(ctx, grad):
+                return grad[: ctx.xs], grad[ctx.xs :]
+
+            torch.library.register_autograd(
+                "mylib::foo", foo_backward, setup_context=setup_context
+            )
+
             @torch.compile(backend="aot_eager", fullgraph=True)
             def f(x, i):
                 i0, i1 = i.tolist()
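This hunk is the canonical fix: give the custom op an autograd formula via torch.library.register_autograd so compiled backprop is well defined. A self-contained sketch built around the same registration; the define/impl/fake pieces for mylib::foo are assumptions here, only the setup_context/backward/register_autograd part is taken from the diff.

import torch

torch.library.define("mylib::foo", "(Tensor x, Tensor y) -> Tensor")

@torch.library.impl("mylib::foo", "CompositeExplicitAutograd")
def foo_impl(x, y):
    return torch.cat([x, y])

@torch.library.register_fake("mylib::foo")
def foo_fake(x, y):
    # Shape inference for tracing: cat along dim 0.
    return torch.empty(x.shape[0] + y.shape[0], *x.shape[1:],
                       dtype=x.dtype, device=x.device)

def setup_context(ctx, inputs, output):
    (x, _) = inputs
    ctx.xs = x.shape[0]

def foo_backward(ctx, grad):
    # Split the incoming gradient back into the two cat operands.
    return grad[: ctx.xs], grad[ctx.xs :]

torch.library.register_autograd(
    "mylib::foo", foo_backward, setup_context=setup_context
)

x = torch.randn(3, requires_grad=True)
y = torch.randn(2, requires_grad=True)
torch.ops.mylib.foo(x, y).sum().backward()   # differentiable in eager and under compile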
@@ -1254,6 +1254,8 @@ def forward(self, x_1: "f32[2][1]cpu"):
             torch._dynamo.reset()

             mod = SimpleModule().cuda()
+            for p in mod.parameters():
+                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))

@@ -1321,6 +1323,8 @@ def forward(self, x_1: "f32[2][1]cpu"):
             torch._dynamo.reset()

             mod = MixedModule().cuda()
+            for p in mod.parameters():
+                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))

@@ -1375,6 +1379,8 @@ def forward(self, x_1: "f32[2][1]cpu"):
         with self._setup_runtime_estimates_capture() as payload_buffer:
             torch._dynamo.reset()
             mod = Mixed().cuda()
+            for p in mod.parameters():
+                p.requires_grad = False
             compiled = torch.compile(mod, backend="inductor")
             compiled(torch.randn(4, 4, device="cuda"))
             payload = payload_buffer.getvalue().strip()
@@ -6,6 +6,7 @@ import warnings
 import numpy as np

 import torch
+from torch._library.autograd import autograd_fallback_mode
 from torch.library import _scoped_library
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -15,16 +16,6 @@ from torch.testing._internal.common_utils import (
 )


-@contextlib.contextmanager
-def autograd_fallback_mode(mode):
-    prev = torch._C._get_autograd_fallback_mode()
-    try:
-        torch._C._set_autograd_fallback_mode(mode)
-        yield
-    finally:
-        torch._C._set_autograd_fallback_mode(prev)
-
-
 class TestAutogradFallback(TestCase):
     test_ns = "_test_autograd_fallback"

@@ -26,6 +26,7 @@ from torch._dynamo.utils import (
 from torch._guards import detect_fake_mode
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex
 from torch._inductor.utils import BoxedBool
+from torch._library.autograd import autograd_fallback_mode
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.export._tree_utils import reorder_kwargs
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -528,6 +529,9 @@ def create_aot_state(
     stack.enter_context(
         torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing()
     )
+    # Make it an error to backprop through PT2 compliant ops that silently
+    # detach autograd
+    stack.enter_context(autograd_fallback_mode("error"))

     from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj

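create_aot_state already keeps its tracing-time context managers on an ExitStack; the commit pushes autograd_fallback_mode("error") onto the same stack so the error mode covers the whole trace and the previous mode is restored afterwards. A minimal sketch of that pattern, with trace_with_hard_error and trace_fn as hypothetical names:

import contextlib

from torch._library.autograd import autograd_fallback_mode

def trace_with_hard_error(trace_fn, *args):
    with contextlib.ExitStack() as stack:
        # While tracing, backprop through an op without an Autograd kernel is a
        # hard error instead of a warning plus silently detached gradients.
        stack.enter_context(autograd_fallback_mode("error"))
        return trace_fn(*args)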
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import contextlib
 import dataclasses
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -235,6 +236,16 @@ def not_list_of_optional_tensor(tree):
     return True


+@contextlib.contextmanager
+def autograd_fallback_mode(mode):
+    prev = _C._get_autograd_fallback_mode()
+    try:
+        _C._set_autograd_fallback_mode(mode)
+        yield
+    finally:
+        _C._set_autograd_fallback_mode(prev)
+
+
 flatten = _pytree.tree_flatten
 unflatten = _pytree.tree_unflatten
 spec_t = _pytree.TreeSpec
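The helper added above simply brackets a region with a different dispatcher fallback mode and restores the previous one on exit (the default is warn, per kAutogradFallbackMode in the C++ hunk below). Usage sketch:

from torch._library.autograd import autograd_fallback_mode

with autograd_fallback_mode("error"):
    ...  # backprop through ops lacking an Autograd kernel raises in here
# on exit, the previous fallback mode (warn by default) is restored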
@@ -50,7 +50,6 @@ AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn;
 } // namespace

 void setAutogradFallbackMode(AutogradFallbackMode mode) {
-  TORCH_CHECK(mode != AutogradFallbackMode::Error, "NYI: mode='error'");
   kAutogradFallbackMode = mode;
 }

@@ -58,41 +57,61 @@ AutogradFallbackMode getAutogradFallbackMode() {
   return kAutogradFallbackMode;
 }

-static void warnAutogradNotImplemented(const std::string& op_name) {
-  TORCH_WARN(
-      op_name,
-      ": an autograd kernel was not registered to the Autograd key(s) ",
-      "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
-      "This behavior is deprecated and will be removed in a future version of PyTorch. ",
-      "If your operator is differentiable, please ensure you have registered an "
-      "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
-      "DispatchKey::CompositeImplicitAutograd). If your operator is not "
-      "differentiable, or to squash this warning and use the previous behavior, "
-      "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+static void reportAutogradNotImplemented(
+    const std::string& op_name,
+    bool is_warn) {
+  if (is_warn) {
+    TORCH_WARN(
+        op_name,
+        ": an autograd kernel was not registered to the Autograd key(s) ",
+        "but we are trying to backprop through it. This may lead to silently incorrect behavior. ",
+        "This behavior is deprecated and will be removed in a future version of PyTorch. ",
+        "If your operator is differentiable, please ensure you have registered an "
+        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
+        "DispatchKey::CompositeImplicitAutograd). If your operator is not "
+        "differentiable, or to squash this warning and use the previous behavior, "
+        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+  } else {
+    TORCH_CHECK(
+        0,
+        op_name,
+        ": an autograd kernel was not registered to the Autograd key(s) ",
+        "but we are trying to backprop through it. This can lead to silently incorrect behavior. ",
+        "If your operator is differentiable, please ensure you have registered an "
+        "autograd kernel to the correct Autograd key (e.g. DispatchKey::Autograd, "
+        "). If your operator is not "
+        "differentiable and ensure NO gradients flow through this operator, "
+        "please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.");
+  }
 }

-struct WarnNotImplemented : public Node {
-  WarnNotImplemented(
+struct NotImplementedBackward : public Node {
+  NotImplementedBackward(
       std::string op_name,
       size_t num_outputs,
+      bool is_warn,
       edge_list&& next_edges)
       : Node(std::move(next_edges)),
         op_name(std::move(op_name)),
-        num_outputs(num_outputs) {}
+        num_outputs(num_outputs),
+        is_warn(is_warn) {}

-  WarnNotImplemented(std::string op_name, size_t num_outputs)
-      : op_name(std::move(op_name)), num_outputs(num_outputs) {}
+  NotImplementedBackward(std::string op_name, size_t num_outputs, bool is_warn)
+      : op_name(std::move(op_name)),
+        num_outputs(num_outputs),
+        is_warn(is_warn) {}

   variable_list apply(variable_list&& inputs) override;

   std::string op_name;
   size_t num_outputs;
+  bool is_warn;
 };

 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
-auto WarnNotImplemented::apply(variable_list&& inputs) -> variable_list {
+auto NotImplementedBackward::apply(variable_list&& inputs) -> variable_list {
   auto inputsLocal = std::move(inputs);
-  warnAutogradNotImplemented(op_name);
+  reportAutogradNotImplemented(op_name, is_warn);
   std::vector<at::Tensor> output(num_outputs);
   return output;
 }
@@ -111,8 +130,6 @@ static void basicAutogradNotImplementedFallbackImpl(
     op.redispatchBoxed(dispatch_keys & c10::after_autograd_keyset, stack);
     return;
   }
-  TORCH_INTERNAL_ASSERT(
-      getAutogradFallbackMode() == AutogradFallbackMode::Warn);

   bool any_input_requires_grad = false;
   _foreach_tensor(
@@ -128,7 +145,9 @@ static void basicAutogradNotImplementedFallbackImpl(
   // by putting it after the requires_grad checks.
   any_input_requires_grad = any_input_requires_grad && GradMode::is_enabled();

-  std::shared_ptr<WarnNotImplemented> grad_fn;
+  bool is_warn = getAutogradFallbackMode() == AutogradFallbackMode::Warn;
+
+  std::shared_ptr<NotImplementedBackward> grad_fn;
   if (any_input_requires_grad) {
     // NB: It is standard to collect edges from all tensors
     // (see generated/VariableTypeEverything.cpp for examples)
@@ -140,8 +159,9 @@ static void basicAutogradNotImplementedFallbackImpl(
         stack,
         stack_start,
         num_arguments);
-    grad_fn = std::shared_ptr<WarnNotImplemented>(
-        new WarnNotImplemented(op_name, all_tensors_on_stack.size()),
+    grad_fn = std::shared_ptr<NotImplementedBackward>(
+        new NotImplementedBackward(
+            op_name, all_tensors_on_stack.size(), is_warn),
         deleteNode);
     grad_fn->set_next_edges(collect_next_edges(all_tensors_on_stack));
   }
@@ -177,8 +197,8 @@ static void basicAutogradNotImplementedFallbackImpl(
         // >>> y = op(k)
         // >>> torch.autograd.grad(z.sum(), w)
         if (t.requires_grad()) {
-          t.register_hook([op_name](const at::Tensor& grad) {
-            warnAutogradNotImplemented(op_name);
+          t.register_hook([op_name, is_warn](const at::Tensor& grad) {
+            reportAutogradNotImplemented(op_name, is_warn);
           });
           // If history is rebased, then we will attempt to warn
           // on the view's base. This will catch most cases (because
@@ -188,18 +208,19 @@ static void basicAutogradNotImplementedFallbackImpl(
           const auto& base = t._base();
           if (base.requires_grad()) {
             // Can only register_hook on tensors that require grad.
-            base.register_hook([op_name](const at::TensorBase& grad) {
-              warnAutogradNotImplemented(op_name);
-            });
+            base.register_hook(
+                [op_name, is_warn](const at::TensorBase& grad) {
+                  reportAutogradNotImplemented(op_name, is_warn);
+                });
           }
         }
         return;
       }

   // If the post-autograd implementation returns any Tensors that
-  // don't require grad, then we install the WarnNotImplemented grad_fn.
-  // This grad_fn warns in backward and returns undefined tensor
-  // gradients.
+  // don't require grad, then we install the NotImplementedBackward
+  // grad_fn. This grad_fn warns in backward and returns undefined
+  // tensor gradients.
   //
   // NOTE [autograd fallback and in-place operations]
   // If the schema says the output is mutable, and the output
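Taken together, the C++ changes make the fallback's behavior depend on the mode: warn keeps the old deprecation warning and undefined gradients, while error turns the same situation into a TORCH_CHECK failure. A hedged Python-level sketch of the observable difference, using a hypothetical op mylib::baz that has only a CPU kernel:

import torch
from torch._library.autograd import autograd_fallback_mode

torch.library.define("mylib::baz", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::baz", "CPU")
def baz_impl(x):
    return x.clone()

x = torch.randn(3, requires_grad=True)

with autograd_fallback_mode("warn"):
    # Old default: backward triggers the deprecation warning from
    # reportAutogradNotImplemented(is_warn=true) and yields undefined grads.
    torch.ops.mylib.baz(x).sum().backward()

with autograd_fallback_mode("error"):
    # New mode used by PT2 tracing: the same backward raises via TORCH_CHECK.
    try:
        torch.ops.mylib.baz(x).sum().backward()
    except RuntimeError as e:
        print("hard error:", e)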