More aggressively market functorch.vmap when torch.vmap gets called (#67347)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67347

This PR:
- changes the warning when torch.vmap gets called to suggest using functorch.vmap
- changes the warning when a batching rule isn't implemented to suggest using functorch.vmap

Test Plan:
- test/test_vmap.py

Reviewed By: H-Huang

Differential Revision: D31966603

Pulled By: zou3519

fbshipit-source-id: b01dc1c2e298ce899b4a3a5fb333222a8d5bfb56
commit a8b93cb3ec
parent da5ffe752a
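The user-visible effect of the first change is easiest to see from Python. Below is a minimal sketch, assuming a PyTorch build that includes this commit; the function f is just an illustrative payload and is not taken from the diff.

import warnings
import torch

def f(x):
    return x.sum()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Calling torch.vmap now emits the updated warning pointing at functorch.vmap.
    torch.vmap(f)(torch.randn(3, 4))

# The first message should now read "Please use functorch.vmap instead of torch.vmap ..."
print([str(w.message) for w in caught])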
@@ -66,13 +66,16 @@ static bool isInplaceOp(const c10::FunctionSchema& schema) {
   return return_alias_info && return_alias_info->isWrite();
 }
 
-static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
+static void warnFallback(const c10::FunctionSchema& schema) {
   if (!globalContext().areVmapFallbackWarningsEnabled()) {
     return;
   }
-  auto uses_stack = is_inplace ? "" : " and stack";
-  TORCH_WARN("Batching rule not implemented for ", schema.operator_name(), " falling back "
-             "to slow (for loop", uses_stack, ") implementation");
+  TORCH_WARN("There is a performance drop because we have not yet implemented ",
+             "the batching rule for ", schema.operator_name(), ". ",
+             "We've moved development of vmap to functorch "
+             "(https://github.com/pytorch/functorch), please try functorch.vmap "
+             "instead and/or file ",
+             " an issue on GitHub so that we can prioritize its implementation.");
 }
 
 // The general flow of the algorithm is as follows.
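The reworded TORCH_WARN is still gated by globalContext().areVmapFallbackWarningsEnabled(), so it stays silent by default. A hedged sketch of turning it on from Python follows, using the torch._C._debug_only_display_vmap_fallback_warnings toggle that this same diff references in the old vmap docstring; the context manager below is illustrative (it mirrors the EnableVmapFallbackWarnings helper in test/test_vmap.py but simply resets the flag to False rather than restoring the previous state).

import contextlib
import torch

@contextlib.contextmanager
def vmap_fallback_warnings_enabled():
    # Enable the off-by-default fallback warnings for the duration of the block.
    torch._C._debug_only_display_vmap_fallback_warnings(True)
    try:
        yield
    finally:
        torch._C._debug_only_display_vmap_fallback_warnings(False)

# Usage: any vmapped op that hits the for-loop fallback inside this block
# now warns "There is a performance drop because we have not yet implemented
# the batching rule for <op> ..." instead of the old "falling back to slow
# (for loop and stack) implementation" message.
with vmap_fallback_warnings_enabled():
    pass  # call torch.vmap(...) on the workload being investigated here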
@@ -88,7 +91,7 @@ static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace) {
 // the operator, and then pop the results off the stack.
 void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   const auto& schema = op.schema();
-  warnFallback(schema, /*in_place*/true);
+  warnFallback(schema);
 
   const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
   const auto arguments = torch::jit::last(stack, num_arguments);
@@ -260,7 +263,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   TORCH_CHECK(num_returns >= 1,
               "Batching rule not implemented for ", schema.operator_name(), ". ",
               "The fallback path does not support operations with no returns.");
-  warnFallback(schema, /*in_place*/false);
+  warnFallback(schema);
 
   const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
   const auto arguments = torch::jit::last(stack, num_arguments);
@@ -12,7 +12,7 @@ from torch.testing._internal.common_device_type import instantiate_device_type_tests
 import types
 
 
-FALLBACK_REGEX = r'falling back to slow \(for loop( and stack)?\) implementation'
+FALLBACK_REGEX = r'There is a performance drop'
 
 class EnableVmapFallbackWarnings:
     def __enter__(self):
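The test constant has to change because the old pattern no longer appears anywhere in the new warning text. A small standalone sketch of that, with the sample message shaped like the new TORCH_WARN output rather than captured from a real run, and aten::some_op standing in for an arbitrary operator without a batching rule:

import re

OLD_REGEX = r'falling back to slow \(for loop( and stack)?\) implementation'
NEW_REGEX = r'There is a performance drop'

msg = ("There is a performance drop because we have not yet implemented "
       "the batching rule for aten::some_op. We've moved development of vmap "
       "to functorch (https://github.com/pytorch/functorch), please try "
       "functorch.vmap instead.")

assert re.search(NEW_REGEX, msg)          # the updated FALLBACK_REGEX still fires
assert not re.search(OLD_REGEX, msg)      # the old pattern would miss it, hence the update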
@@ -158,9 +158,10 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
         gradients when composed with autograd.
 
     .. note::
-        We are actively developing a different and improved vmap prototype
-        `here. <https://github.com/zou3519/functorch>`_ The improved
-        prototype is able to arbitrarily compose with gradient computation.
+        We have moved development of vmap to
+        `functorch. <https://github.com/pytorch/functorch>`_ functorch's
+        vmap is able to arbitrarily compose with gradient computation
+        and contains significant performance improvements.
         Please give that a try if that is what you're looking for.
 
     Furthermore, if you're interested in using vmap for your use case,
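The headline claim in the rewritten note is that functorch's vmap composes with gradient computation. A minimal sketch of that composition, assuming functorch is installed alongside this PyTorch build:

import torch
from functorch import vmap, grad

def loss(sample):
    # Scalar loss for a single sample of shape (3,).
    return (sample ** 2).sum()

batch = torch.randn(5, 3)

# grad(loss) differentiates the per-sample loss; vmap maps it over the batch
# dimension, yielding per-sample gradients in one call.
per_sample_grads = vmap(grad(loss))(batch)
print(per_sample_grads.shape)  # torch.Size([5, 3])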
@@ -247,12 +248,11 @@ def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
         sequences out of the box.
     """
     warnings.warn(
-        'torch.vmap is an experimental prototype that is subject to '
-        'change and/or deletion. Please use at your own risk. There may be '
-        'unexpected performance cliffs due to certain operators not being '
-        'implemented. To see detailed performance warnings please use '
-        '`torch._C._debug_only_display_vmap_fallback_warnings(True) '
-        'before the call to `vmap`.',
+        'Please use functorch.vmap instead of torch.vmap '
+        '(https://github.com/pytorch/functorch). '
+        'We\'ve moved development on torch.vmap over to functorch; '
+        'functorch\'s vmap has a multitude of significant performance and '
+        'functionality improvements.',
         stacklevel=2)
     return _vmap(func, in_dims, out_dims)
 
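Since the runtime warning now asks users to switch APIs, a hedged before/after migration sketch is shown below. It assumes functorch is installed; for a simple elementwise reduce like this the two implementations are expected to agree, but that equivalence is an assumption made here, not something this diff asserts.

import torch
from functorch import vmap as functorch_vmap

def dot(a, b):
    return (a * b).sum()

a = torch.randn(8, 4)
b = torch.randn(8, 4)

out_old = torch.vmap(dot)(a, b)          # warns: "Please use functorch.vmap instead ..."
out_new = functorch_vmap(dot)(a, b)      # the recommended replacement

assert torch.allclose(out_old, out_new)  # same batched result for this simple case (assumed)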