[ca] suggest disabling compiled autograd for trace-time NotImplementedErrors (#156509)

Example:

```python
  File "/home/xmfan/core/a/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: TorchDispatchMode not yet implemented for compiled autograd.
  You can disable compiled autograd for this operation by:
  1.  Relocating the unsupported autograd call outside the compiled region.
  2.  Wrapping the unsupported autograd call within a scope that disables compiled autograd.
  3.  Configuring the specific compilation unit to disable compiled autograd.
  4.  Globally disabling compiled autograd at the application's initialization.
```
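For context, here is a minimal sketch (not part of the PR) of the kind of usage that hits this trace-time check: running the backward pass under a `TorchDispatchMode` while compiled autograd is enabled. The `LoggingMode` class and the toy compiled function below are illustrative assumptions, not code from this commit.

```python
# Hypothetical repro sketch: backward under a TorchDispatchMode while compiled
# autograd is enabled should surface the NotImplementedError shown above.
import torch
import torch._dynamo
from torch.utils._python_dispatch import TorchDispatchMode


class LoggingMode(TorchDispatchMode):
    # Pass every op through unchanged; only its presence on the mode stack matters here.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))


torch._dynamo.config.compiled_autograd = True  # route backward through compiled autograd


@torch.compile
def f(x):
    return (x * x).sum()


x = torch.randn(4, requires_grad=True)
loss = f(x)
with LoggingMode():
    loss.backward()  # trace-time NotImplementedError, now with the guidance above
```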

No duplicate error messages for Python-side trace-time errors:
```python
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/compiled_autograd.py", line 344, in begin_capture
    raise NotImplementedError(
NotImplementedError: Found tensor of type <class 'torch.nn.utils._expanded_weights.expanded_weights_impl.ExpandedWeight'>, which is not supported by FakeTensorMode. You can turn off compiled autograd by either:
1. Moving the unsupported autograd call outside of the torch.compile'd region.
2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager.
3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call.
4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program.
```
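As a rough illustration (not part of the commit), the mitigations listed in this message map onto code along the following lines; the tensor and loss are placeholders.

```python
# Illustrative sketch of the suggested workarounds; option numbers follow the message above.
import torch
import torch._dynamo
import torch._dynamo.compiled_autograd

# Option 1 is simply moving the backward()/autograd call out of the torch.compile'd region.

# Option 4: disable compiled autograd globally at the start of the program.
# (Option 3 is the same flag, toggled around a specific torch.compile call site.)
torch._dynamo.config.compiled_autograd = False

x = torch.randn(4, requires_grad=True)
loss = (x * x).sum()

# Option 2: wrap only the unsupported autograd call in the _disable() context
# manager so this backward falls back to the eager autograd engine.
with torch._dynamo.compiled_autograd._disable():
    loss.backward()
```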

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156509
Approved by: https://github.com/jansel
ghstack dependencies: #156374
Simon Fan authored 2025-06-21 11:25:11 -07:00; committed by PyTorch MergeBot
parent f1968a5e76
commit 5f2f343e1e
6 changed files with 69 additions and 39 deletions

View File

```diff
@@ -69,11 +69,11 @@ if TYPE_CHECKING:
     from torch.fx.proxy import Proxy
 
-TURN_OFF_MSG = """Turn off compiled autograd by either:
-1. Moving the unsupported autograd calls outside of the torch.compile'd region.
-2. Wrapping the unsupported in the torch._dynamo.compiled_autograd._disable() context manager.
-3. Setting torch._dynamo.config.compiled_autograd to False for the torch.compile call containing the unsupported autograd call.
-4. Setting torch._dynamo.config.compiled_autograd to False at the start of the program."""
+TURN_OFF_MSG = """You can turn off compiled autograd by either:
+1. Moving the unsupported autograd call outside of the torch.compile'd region.
+2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager.
+3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call.
+4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program."""
 
 compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
 verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
@@ -164,7 +164,7 @@ class NaNChecker:
             if grad is not None:
                 assert not torch.isnan(grad).any(), (
                     f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
-                    "having NaN gradient. This is not supported."
+                    "having NaN gradient. This is not supported. {TURN_OFF_MSG}"
                 )
                 self.params_to_check[f"inputs[{idx}]"] = inputs[idx]
```

View File

```diff
@@ -1345,7 +1345,7 @@ auto Engine::execute(
   }
 
   if (compiled_autograd != nullptr) {
-    TORCH_CHECK(
+    TORCH_CHECK_NOT_IMPLEMENTED(
         num_threads_in_compiled_autograd.load() == 0,
         "Re-entrant into Compiled Autograd from a parent Compiled Autograd call is not yet supported. Consider disabling Compiled Autograd on the re-entrant call.");
     // Allows us to assert no other threads are in backwards
```

View File

```diff
@@ -594,8 +594,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   // Implementations in subclasses should call args.collect() with all node
   // attrs. These functions are only called durring backward.
   virtual void compiled_args(CompiledNodeArgs& args) const {
-    throw std::runtime_error(
-        std::string("compiled_args not implemented: ") + name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, std::string("compiled_args not implemented: ") + name());
   }
 
   // Used by compiled autograd to call apply() with different saved tensors
@@ -604,8 +604,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   virtual variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) {
-    throw std::runtime_error(
-        std::string("apply_with_saved not implemented: ") + name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, std::string("apply_with_saved not implemented: ") + name());
   }
 
   // If this node is the AOTBackward node produced by torch.compile.
```

View File

```diff
@@ -24,9 +24,10 @@ struct TORCH_API FunctionPreHook {
   // only implemented for python hooks, registers hook with compiled autograd
   virtual void compiled_args(
       torch::dynamo::autograd::CompiledNodeArgs& args) const {
-    throw std::runtime_error(
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
-        typeid(*this).name());
+            typeid(*this).name());
   }
 };
@@ -38,9 +39,10 @@ struct TORCH_API FunctionPostHook {
   // only implemented for python hooks, registers hook with compiled autograd
   virtual void compiled_args(
       torch::dynamo::autograd::CompiledNodeArgs& args) const {
-    throw std::runtime_error(
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
-        typeid(*this).name());
+            typeid(*this).name());
   }
 };
@@ -51,17 +53,19 @@ struct TORCH_API PostAccumulateGradHook {
   // autograd
   virtual void compiled_args(
       torch::dynamo::autograd::CompiledNodeArgs& args) const {
-    throw std::runtime_error(
-        std::string("not yet implemented for compiled autograd: ") +
-        typeid(*this).name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
+            typeid(*this).name());
   }
 
   virtual void apply_with_saved(
       Variable&,
       torch::dynamo::autograd::SwapSavedVariables&) {
-    throw std::runtime_error(
-        std::string("not yet implemented for compiled autograd: ") +
-        typeid(*this).name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
+            typeid(*this).name());
   }
 };
```

View File

```diff
@@ -572,7 +572,8 @@ class CompiledNodeArgs {
     }
   }
   void collect(const InputMetadata& t) {
-    TORCH_CHECK(!t.is_nested_tensor(), "NestedTensor not implemented");
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        !t.is_nested_tensor(), "NestedTensor support not implemented. ");
     collect(t.options());
     collect(t.is_tensor_subclass());
     collect(t.shape_as_dim_vector());
@@ -1110,7 +1111,8 @@ struct IValuePacker {
   // with certain compiler settings
   // (see https://github.com/pytorch/pytorch/pull/144707 for examples).
   // It's not clear what the problem is, so we're going to ignore it for now.
-  TORCH_INTERNAL_ASSERT(false, "torch.compile not supported on Windows");
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false, "torch.compile not supported on Windows");
 #else
   if constexpr (::std::is_same_v<T, at::Tensor>) {
     return at::TensorType::get();
@@ -1147,7 +1149,8 @@
     // define how to pack and unpack an object of this time into an IValue
     // by creating a specialization of IValuePacker for this type.
     // See NOTE: [Compiled Autograd and backward functions] for context.
-    TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type");
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, "IValuePacker not implemented for type");
     return at::NoneType::get();
   }
 #endif
```

View File

```diff
@@ -10,6 +10,7 @@
 #include <iostream>
 #include <sstream>
 #include <string>
+#include <string_view>
 #include <vector>
 
 /*
@@ -56,6 +57,19 @@ namespace {
 PyObject* the_autograd_compiler = nullptr;
 int default_dyn_type_int = 0;
 PyObject* python_verbose_logger = nullptr;
+
+constexpr std::string_view _TURN_OFF_COMPILED_AUTOGRAD_MSG = R"(
+You can disable compiled autograd for this operation by:
+1. Relocating the unsupported autograd call outside the compiled region.
+2. Wrapping the unsupported autograd call within a scope that disables compiled autograd.
+3. Configuring the specific compilation unit to disable compiled autograd.
+4. Globally disabling compiled autograd at the application's initialization.
+)";
+
+std::string TURN_OFF_COMPILED_AUTOGRAD_MSG() {
+  return std::string(_TURN_OFF_COMPILED_AUTOGRAD_MSG);
+}
 } // namespace
 
 // see https://github.com/pytorch/pytorch/pull/34845
@@ -1172,9 +1186,10 @@ struct LockGuardWithErrorLogs {
     // performance reasons, but it shouldn't happen here since we:
     // 1. disable multithreaded autograd
     // 2. plenty of latency between backward calls
-    TORCH_INTERNAL_ASSERT(
+    TORCH_CHECK_NOT_IMPLEMENTED(
         mtx_.try_lock(),
-        "Trying to run compiled autograd within another compiled autograd call (e.g. reentrant checkpointing), this is not supported yet.");
+        "Trying to run compiled autograd within another compiled autograd call, this is not supported yet. " +
+            TURN_OFF_COMPILED_AUTOGRAD_MSG());
   }
 
   ~LockGuardWithErrorLogs() {
@@ -1190,9 +1205,10 @@ static variable_list compiled_autograd(
     const GraphTask& graph_task,
     bool accumulate_grad,
     const edge_list& output_edges) {
-  TORCH_CHECK(
+  TORCH_CHECK_NOT_IMPLEMENTED(
      c10::impl::TorchDispatchModeTLS::stack_len() == 0,
-      "TorchDispatchMode not yet implemented for compiled autograd")
+      "TorchDispatchMode not yet implemented for compiled autograd. " +
+          TURN_OFF_COMPILED_AUTOGRAD_MSG());
   static std::mutex mtx;
   LockGuardWithErrorLogs lock_guard(mtx);
   pybind11::gil_scoped_acquire gil;
@@ -1204,17 +1220,24 @@ static variable_list compiled_autograd(
   THPObjectPtr ivalue_args;
   THPObjectPtr hooks;
   THPObjectPtr packed_inputs;
-  CacheNode* cache = _compiled_autograd_impl(
-      graph_root,
-      graph_task,
-      accumulate_grad,
-      output_edges,
-      &inputs,
-      &sizes,
-      &ivalue_args,
-      &hooks,
-      &packed_inputs,
-      active_rstate);
+  CacheNode* cache = nullptr;
+  try {
+    cache = _compiled_autograd_impl(
+        graph_root,
+        graph_task,
+        accumulate_grad,
+        output_edges,
+        &inputs,
+        &sizes,
+        &ivalue_args,
+        &hooks,
+        &packed_inputs,
+        active_rstate);
+  } catch (const c10::NotImplementedError& e) {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, std::string(e.what()) + " " + TURN_OFF_COMPILED_AUTOGRAD_MSG());
+  }
+  TORCH_INTERNAL_ASSERT(cache != nullptr);
 
   THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
       cache->runtime_wrapper.get(),
```