[ca] suggest to disable compiled autograd for trace-time NotImplementedErrors (#156509)
Example:
```python
File "/home/xmfan/core/a/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: TorchDispatchMode not yet implemented for compiled autograd.
You can disable compiled autograd for this operation by:
1. Relocating the unsupported autograd call outside the compiled region.
2. Wrapping the unsupported autograd call within a scope that disables compiled autograd.
3. Configuring the specific compilation unit to disable compiled autograd.
4. Globally disabling compiled autograd at the application's initialization.
```
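The first traceback can be reproduced along these lines. This is a minimal sketch, not code from this commit: `LoggingMode` is a hypothetical placeholder for whatever `TorchDispatchMode` happens to be active while the backward pass runs, and it assumes compiled autograd is enabled via `torch._dynamo.config.compiled_autograd`.
```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    # A no-op dispatch mode: forwards every op unchanged.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

torch._dynamo.config.compiled_autograd = True  # route backward through compiled autograd

@torch.compile
def f(x):
    return (x * x).sum()

x = torch.randn(4, requires_grad=True)
out = f(x)  # forward compiles normally

with LoggingMode():
    out.backward()  # expected to raise the NotImplementedError shown above
```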
No duplicate error messages for Python-side trace-time errors:
```python
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xmfan/core/a/pytorch/torch/_dynamo/compiled_autograd.py", line 344, in begin_capture
raise NotImplementedError(
NotImplementedError: Found tensor of type <class 'torch.nn.utils._expanded_weights.expanded_weights_impl.ExpandedWeight'>, which is not supported by FakeTensorMode. You can turn off compiled autograd by either:
1. Moving the unsupported autograd call outside of the torch.compile'd region.
2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager.
3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call.
4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program.
```
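For reference, the suggestions in the message map onto code roughly as follows; each numbered block is an alternative. This is a hedged sketch, not code from this commit: `fn` and `x` are toy placeholders, and `torch._dynamo.config.patch` is one way (not named in the message) to scope option 3 to a single `torch.compile` call.
```python
import torch

def fn(x):
    return (x * x).sum()

x = torch.randn(4, requires_grad=True)

# 4. Globally disable compiled autograd at the start of the program.
torch._dynamo.config.compiled_autograd = False

# 3. Disable it only for the torch.compile call containing the unsupported autograd call.
with torch._dynamo.config.patch(compiled_autograd=False):
    torch.compile(fn)(x).backward()

# 2. Keep compiled autograd enabled elsewhere, but wrap the unsupported autograd call.
with torch._dynamo.compiled_autograd._disable():
    torch.compile(fn)(x).backward()

# 1. Or simply move the unsupported autograd call outside of the torch.compile'd region.
```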
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156509
Approved by: https://github.com/jansel
ghstack dependencies: #156374
parent f1968a5e76
commit 5f2f343e1e
```diff
@@ -69,11 +69,11 @@ if TYPE_CHECKING:
     from torch.fx.proxy import Proxy
 
 
-TURN_OFF_MSG = """Turn off compiled autograd by either:
-1. Moving the unsupported autograd calls outside of the torch.compile'd region.
-2. Wrapping the unsupported in the torch._dynamo.compiled_autograd._disable() context manager.
-3. Setting torch._dynamo.config.compiled_autograd to False for the torch.compile call containing the unsupported autograd call.
-4. Setting torch._dynamo.config.compiled_autograd to False at the start of the program."""
+TURN_OFF_MSG = """You can turn off compiled autograd by either:
+1. Moving the unsupported autograd call outside of the torch.compile'd region.
+2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager.
+3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call.
+4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program."""
 
 compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
 verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
@@ -164,7 +164,7 @@ class NaNChecker:
         if grad is not None:
             assert not torch.isnan(grad).any(), (
                 f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
-                "having NaN gradient. This is not supported."
+                "having NaN gradient. This is not supported. {TURN_OFF_MSG}"
             )
 
         self.params_to_check[f"inputs[{idx}]"] = inputs[idx]
@@ -1345,7 +1345,7 @@ auto Engine::execute(
   }
 
   if (compiled_autograd != nullptr) {
-    TORCH_CHECK(
+    TORCH_CHECK_NOT_IMPLEMENTED(
         num_threads_in_compiled_autograd.load() == 0,
         "Re-entrant into Compiled Autograd from a parent Compiled Autograd call is not yet supported. Consider disabling Compiled Autograd on the re-entrant call.");
     // Allows us to assert no other threads are in backwards
@@ -594,8 +594,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   // Implementations in subclasses should call args.collect() with all node
   // attrs. These functions are only called durring backward.
   virtual void compiled_args(CompiledNodeArgs& args) const {
-    throw std::runtime_error(
-        std::string("compiled_args not implemented: ") + name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, std::string("compiled_args not implemented: ") + name());
   }
 
   // Used by compiled autograd to call apply() with different saved tensors
@@ -604,8 +604,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   virtual variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) {
-    throw std::runtime_error(
-        std::string("apply_with_saved not implemented: ") + name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, std::string("apply_with_saved not implemented: ") + name());
   }
 
   // If this node is the AOTBackward node produced by torch.compile.
@@ -24,9 +24,10 @@ struct TORCH_API FunctionPreHook {
   // only implemented for python hooks, registers hook with compiled autograd
   virtual void compiled_args(
       torch::dynamo::autograd::CompiledNodeArgs& args) const {
-    throw std::runtime_error(
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
-        typeid(*this).name());
+            typeid(*this).name());
   }
 };
 
@@ -38,9 +39,10 @@ struct TORCH_API FunctionPostHook {
   // only implemented for python hooks, registers hook with compiled autograd
   virtual void compiled_args(
       torch::dynamo::autograd::CompiledNodeArgs& args) const {
-    throw std::runtime_error(
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
         std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
-        typeid(*this).name());
+            typeid(*this).name());
   }
 };
 
@@ -51,17 +53,19 @@ struct TORCH_API PostAccumulateGradHook {
   // autograd
   virtual void compiled_args(
       torch::dynamo::autograd::CompiledNodeArgs& args) const {
-    throw std::runtime_error(
-        std::string("not yet implemented for compiled autograd: ") +
-        typeid(*this).name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
+            typeid(*this).name());
   }
 
   virtual void apply_with_saved(
       Variable&,
       torch::dynamo::autograd::SwapSavedVariables&) {
-    throw std::runtime_error(
-        std::string("not yet implemented for compiled autograd: ") +
-        typeid(*this).name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false,
+        std::string("compiled_args nyi, see [Note: Compiled Autograd] ") +
+            typeid(*this).name());
   }
 };
 
@@ -572,7 +572,8 @@ class CompiledNodeArgs {
     }
   }
   void collect(const InputMetadata& t) {
-    TORCH_CHECK(!t.is_nested_tensor(), "NestedTensor not implemented");
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        !t.is_nested_tensor(), "NestedTensor support not implemented. ");
     collect(t.options());
     collect(t.is_tensor_subclass());
     collect(t.shape_as_dim_vector());
@@ -1110,7 +1111,8 @@ struct IValuePacker {
     // with certain compiler settings
     // (see https://github.com/pytorch/pytorch/pull/144707 for examples).
     // It's not clear what the problem is, so we're going to ignore it for now.
-    TORCH_INTERNAL_ASSERT(false, "torch.compile not supported on Windows");
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, "torch.compile not supported on Windows");
 #else
     if constexpr (::std::is_same_v<T, at::Tensor>) {
       return at::TensorType::get();
@@ -1147,7 +1149,8 @@ struct IValuePacker {
     // define how to pack and unpack an object of this time into an IValue
     // by creating a specialization of IValuePacker for this type.
     // See NOTE: [Compiled Autograd and backward functions] for context.
-    TORCH_INTERNAL_ASSERT(false, "IValuePacker not implemented for type");
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, "IValuePacker not implemented for type");
     return at::NoneType::get();
   }
 #endif
@@ -10,6 +10,7 @@
 #include <iostream>
 #include <sstream>
 #include <string>
+#include <string_view>
 #include <vector>
 
 /*
@@ -56,6 +57,19 @@ namespace {
 PyObject* the_autograd_compiler = nullptr;
 int default_dyn_type_int = 0;
 PyObject* python_verbose_logger = nullptr;
+
+constexpr std::string_view _TURN_OFF_COMPILED_AUTOGRAD_MSG = R"(
+You can disable compiled autograd for this operation by:
+1. Relocating the unsupported autograd call outside the compiled region.
+2. Wrapping the unsupported autograd call within a scope that disables compiled autograd.
+3. Configuring the specific compilation unit to disable compiled autograd.
+4. Globally disabling compiled autograd at the application's initialization.
+)";
+
+std::string TURN_OFF_COMPILED_AUTOGRAD_MSG() {
+  return std::string(_TURN_OFF_COMPILED_AUTOGRAD_MSG);
+}
+
 } // namespace
 
 // see https://github.com/pytorch/pytorch/pull/34845
@@ -1172,9 +1186,10 @@ struct LockGuardWithErrorLogs {
     // performance reasons, but it shouldn't happen here since we:
     // 1. disable multithreaded autograd
    // 2. plenty of latency between backward calls
-    TORCH_INTERNAL_ASSERT(
+    TORCH_CHECK_NOT_IMPLEMENTED(
         mtx_.try_lock(),
-        "Trying to run compiled autograd within another compiled autograd call (e.g. reentrant checkpointing), this is not supported yet.");
+        "Trying to run compiled autograd within another compiled autograd call, this is not supported yet. " +
+            TURN_OFF_COMPILED_AUTOGRAD_MSG());
   }
 
   ~LockGuardWithErrorLogs() {
@@ -1190,9 +1205,10 @@ static variable_list compiled_autograd(
     const GraphTask& graph_task,
     bool accumulate_grad,
     const edge_list& output_edges) {
-  TORCH_CHECK(
+  TORCH_CHECK_NOT_IMPLEMENTED(
       c10::impl::TorchDispatchModeTLS::stack_len() == 0,
-      "TorchDispatchMode not yet implemented for compiled autograd")
+      "TorchDispatchMode not yet implemented for compiled autograd. " +
+          TURN_OFF_COMPILED_AUTOGRAD_MSG());
   static std::mutex mtx;
   LockGuardWithErrorLogs lock_guard(mtx);
   pybind11::gil_scoped_acquire gil;
@@ -1204,17 +1220,24 @@ static variable_list compiled_autograd(
   THPObjectPtr ivalue_args;
   THPObjectPtr hooks;
   THPObjectPtr packed_inputs;
-  CacheNode* cache = _compiled_autograd_impl(
-      graph_root,
-      graph_task,
-      accumulate_grad,
-      output_edges,
-      &inputs,
-      &sizes,
-      &ivalue_args,
-      &hooks,
-      &packed_inputs,
-      active_rstate);
+  CacheNode* cache = nullptr;
+  try {
+    cache = _compiled_autograd_impl(
+        graph_root,
+        graph_task,
+        accumulate_grad,
+        output_edges,
+        &inputs,
+        &sizes,
+        &ivalue_args,
+        &hooks,
+        &packed_inputs,
+        active_rstate);
+  } catch (const c10::NotImplementedError& e) {
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, std::string(e.what()) + " " + TURN_OFF_COMPILED_AUTOGRAD_MSG());
+  }
+  TORCH_INTERNAL_ASSERT(cache != nullptr);
 
   THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
       cache->runtime_wrapper.get(),
```