mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Adds mixed precision autocasting support between fp32/fp16 to TorchScript/JIT. A more in-depth description can be found at [torch/csrc/jit/JIT-AUTOCAST.md](https://github.com/pytorch/pytorch/pull/63939/files#diff-1f1772aaa508841c5bb58b74ab98f49a1e577612cd9ea5c386c8714a75db830b) This PR implements an autocast optimization pass that inserts casting ops per AMP rules (torch/csrc/jit/passes/autocast.cpp), mimicking the behavior of eager autocast. The pass also takes into consideration the context of `torch.cuda.amp.autocast` and only inserts casting ops within the enabled context manager, giving feature parity with eager AMP autocast. We currently provide JIT AMP autocast as a prototype feature, so it is off by default and can be turned on via `torch._C._jit_set_autocast_mode(True)` The JIT support for autocast is subject to different constraints compared to the eager mode implementation (mostly related to the fact that TorchScript is statically typed); restrictions on the user-facing Python code are described in torch/csrc/jit/JIT-AUTOCAST.md This is a prototype; there are also implementation limitations that are necessary to keep this PR small and get something functioning quickly upstream, so we can iterate on designs. A few limitations/challenges that are not properly resolved in this PR: 1. Autocast inserts cast operations, which have an impact on the scalar type of output tensors feeding downstream operations. We are not currently propagating the updated scalar types; this would cause issues/wrong results for operations that follow promotion rules. 2. Backward for autodiff in JIT misses the casting of dgrad to the input scalar type, as autograd does in eager mode. This forces us to explicitly mark the casting operation for certain operations (e.g. binary ops); otherwise, we might be feeding dgrad with a mismatched scalar type to the input. This could potentially break gradient functions consuming dgrad (e.g. gemm backwards, which assumes grad_output to be of the same scalar type as the input). 3. The `torch.autocast` API has an optional argument `dtype` which is not currently supported in the JIT autocast, and we require a static value. Credit goes mostly to: tlemo kevinstephano Pull Request resolved: https://github.com/pytorch/pytorch/pull/63939 Reviewed By: navahgar Differential Revision: D31093381 Pulled By: eellison fbshipit-source-id: da6e26c668c38b01e296f304507048d6c1794314
160 lines
4.5 KiB
C++
160 lines
4.5 KiB
C++
#include <c10/util/irange.h>
|
|
#include <torch/csrc/jit/api/function_impl.h>
|
|
#include <torch/csrc/jit/passes/inliner.h>
|
|
|
|
#include <torch/csrc/jit/frontend/error_report.h>
|
|
#include <torch/csrc/jit/passes/constant_pooling.h>
|
|
#include <torch/csrc/jit/passes/constant_propagation.h>
|
|
#include <torch/csrc/jit/passes/peephole.h>
|
|
|
|
#ifndef C10_MOBILE
|
|
#include <ATen/autocast_mode.h>
|
|
#include <torch/csrc/jit/passes/autocast.h>
|
|
#endif
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace {
|
|
// Builds a schema for `function` directly from its graph.
//
// Each graph input becomes an argument named after the value's debug-name
// base when it has one, or "argument_<index>" otherwise. Each graph output
// becomes an unnamed return. Both argument and return types are passed
// through unshapedType(). The schema carries the function's name and an
// empty overload name.
c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) {
  Graph& g = *function.graph();

  std::vector<c10::Argument> args;
  for (const auto idx : c10::irange(function.num_inputs())) {
    const Value* input = g.inputs().at(idx);
    std::string arg_name;
    if (input->hasDebugName()) {
      arg_name = input->debugNameBase();
    } else {
      // No debug name available: fall back to a positional name.
      arg_name = "argument_" + c10::to_string(idx);
    }
    args.emplace_back(std::move(arg_name), unshapedType(input->type()));
  }

  std::vector<c10::Argument> returns;
  for (const Value* output : g.outputs()) {
    returns.emplace_back("", unshapedType(output->type()));
  }

  return {function.name(), "", std::move(args), std::move(returns)};
}
|
|
|
|
// Shared implementation of the checked downcast used by the public
// tryToGraphFunction()/toGraphFunction() helpers.
//
// Returns a pointer to `function` viewed as a `T` when
// `function.isGraphFunction()` reports true, and nullptr otherwise.
template <typename T, typename F>
T* tryToGraphFunctionImpl(F& function) noexcept {
  return function.isGraphFunction() ? static_cast<T*>(&function) : nullptr;
}
|
|
|
|
template <typename T, typename F>
|
|
T& toGraphFunctionImpl(F& function) {
|
|
if (auto* g = tryToGraphFunctionImpl<T>(function)) {
|
|
return *g;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
false,
|
|
"Failed to downcast a Function to a GraphFunction. "
|
|
"This probably indicates that the JIT calling context needs a "
|
|
"special case on tryToGraphFunction() instead.");
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Stand-in creator that unconditionally throws RecursiveMethodCallError.
// GraphFunction::ensure_defined() installs it as the function's creator while
// the real creator is running, so asking for the definition again during
// that window ends up here.
void placeholderCreator(GraphFunction& /*unused*/) {
  throw RecursiveMethodCallError();
}
|
|
|
|
// Runs this function on `stack` by handing the stack to the function's
// executor.
void GraphFunction::run(Stack& stack) {
  auto& executor = get_executor();
  executor.run(stack);
}
|
|
|
|
// Rvalue overload: forwards to run(Stack&).
void GraphFunction::run(Stack&& stack) {
  // A named rvalue reference is an lvalue, so bind it as one explicitly and
  // dispatch to the lvalue overload.
  Stack& borrowed_stack = stack;
  run(borrowed_stack);
}
|
|
|
|
// Asynchronous counterpart of run(): delegates to the executor's runAsync(),
// passing along `stack` and transferring ownership of `taskLauncher`.
// Returns the future produced by the executor.
c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  auto& executor = get_executor();
  return executor.runAsync(stack, std::move(taskLauncher));
}
|
|
|
|
// Calls the function with the positional arguments in `stack` and the
// keyword arguments in `kwargs`.
//
// The arguments are first validated/normalized against the function's schema
// via checkAndNormalizeInputs(), then the function is run on the stack.
// Returns the value at the front of the stack after the run.
IValue GraphFunction::operator()(
    std::vector<IValue> stack,
    const Kwargs& kwargs) {
  getSchema().checkAndNormalizeInputs(stack, kwargs);
  run(stack);
  // `stack` is a by-value parameter that is destroyed on return, so move the
  // result out of it rather than copy-constructing it.
  return std::move(stack.front());
}
|
|
|
|
void GraphFunction::ensure_defined() {
|
|
if (function_creator_) {
|
|
auto creator = function_creator_;
|
|
function_creator_ = placeholderCreator;
|
|
creator(*this);
|
|
function_creator_ = nullptr;
|
|
}
|
|
check_single_output();
|
|
}
|
|
|
|
// Returns the function's schema. If no schema has been stored yet, one is
// derived from the graph with defaultSchemaFor() and cached in `schema_`
// before being returned.
const c10::FunctionSchema& GraphFunction::getSchema() const {
  if (!schema_) {
    auto derived_schema = defaultSchemaFor(*this);
    schema_ = std::make_unique<c10::FunctionSchema>(std::move(derived_schema));
  }
  return *schema_;
}
|
|
|
|
// Maps the current autocast state to a SpecializationKey.
//
// On mobile builds the result is always AutocastOff. Otherwise the key is
// chosen from the pair (at::autocast::is_cpu_enabled(),
// at::autocast::is_enabled()), the latter being treated as the GPU flag:
// both on -> CpuGpuAutocastOn, only GPU -> GpuAutocastOn,
// only CPU -> CpuAutocastOn, neither -> AutocastOff.
GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const {
#ifdef C10_MOBILE
  // The autocast pass is disabled for mobile builds since the autocast APIs
  // don't exist there.
  return SpecializationKey::AutocastOff;
#else
  const bool cpu_on = at::autocast::is_cpu_enabled();
  const bool gpu_on = at::autocast::is_enabled();
  if (gpu_on) {
    return cpu_on ? SpecializationKey::CpuGpuAutocastOn
                  : SpecializationKey::GpuAutocastOn;
  }
  return cpu_on ? SpecializationKey::CpuAutocastOn
                : SpecializationKey::AutocastOff;
#endif
}
|
|
|
|
// Runs the fixed sequence of clean-up passes on `graph`, in place:
// inlining, peephole optimization, constant propagation over immutable
// types, autocast injection (non-mobile builds only), and constant pooling.
void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
  Inline(*graph);

  // The peephole pass removes many "is None" checks, which in turn exposes
  // more opportunities for constant propagation.
  PeepholeOptimize(graph, true);

  // Building an AliasDb can be slow, so restrict constant propagation to
  // immutable types; that is enough to fold constant Ifs and pick up other
  // easy wins.
  ConstantPropagationImmutableTypes(graph);

#ifndef C10_MOBILE
  // Insert the casts required for automatic mixed precision (AMP).
  //
  // TODO: Ideally, this pass would run earlier, before inlining or any
  // other optimization. That setup is preferable because:
  //  1. The AMP pass would be self-contained and function independently of
  //     any optimizations
  //  2. AMP transformations would benefit from the cleanup done by
  //     follow-up passes
  //
  Autocast(graph);
#endif

  ConstantPooling(graph);
}
|
|
|
|
// Non-throwing downcast: returns `function` as a GraphFunction pointer when
// it reports itself as a graph function, and nullptr otherwise.
GraphFunction* tryToGraphFunction(Function& function) noexcept {
  GraphFunction* const as_graph_fn =
      tryToGraphFunctionImpl<GraphFunction>(function);
  return as_graph_fn;
}
|
|
|
|
// Asserting downcast: returns `function` as a GraphFunction reference, and
// trips an internal assert when the function is not a graph function.
GraphFunction& toGraphFunction(Function& function) {
  GraphFunction& as_graph_fn = toGraphFunctionImpl<GraphFunction>(function);
  return as_graph_fn;
}
|
|
|
|
// Const overload of the asserting downcast: returns `function` as a const
// GraphFunction reference, and trips an internal assert when the function is
// not a graph function.
const GraphFunction& toGraphFunction(const Function& function) {
  const GraphFunction& as_graph_fn =
      toGraphFunctionImpl<const GraphFunction>(function);
  return as_graph_fn;
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|