#include #include #include #include #include #include #include #ifndef C10_MOBILE #include #include #endif namespace torch::jit { namespace { c10::FunctionSchema defaultSchemaFor(const GraphFunction& function) { std::vector args; std::vector returns; Graph& g = *function.graph(); size_t num_inputs = function.num_inputs(); for (const auto i : c10::irange(num_inputs)) { const Value* v = g.inputs().at(i); std::string name = v->hasDebugName() ? v->debugNameBase() : ("argument_" + c10::to_string(i)); args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type())); } for (const auto i : c10::irange(g.outputs().size())) { returns.emplace_back("", unshapedType(g.outputs()[i]->type())); } return {function.name(), "", std::move(args), std::move(returns)}; } template T* tryToGraphFunctionImpl(F& function) noexcept { if (!function.isGraphFunction()) { return nullptr; } return static_cast(&function); } template T& toGraphFunctionImpl(F& function) { if (auto* g = tryToGraphFunctionImpl(function)) { return *g; } TORCH_INTERNAL_ASSERT( false, "Failed to downcast a Function to a GraphFunction. " "This probably indicates that the JIT calling context needs a " "special case on tryToGraphFunction() instead."); } } // namespace void placeholderCreator(GraphFunction&) { throw RecursiveMethodCallError(); } void GraphFunction::run(Stack& stack) { get_executor().run(stack); } c10::intrusive_ptr GraphFunction::runAsync( Stack& stack, TaskLauncher taskLauncher) { return get_executor().runAsync(stack, std::move(taskLauncher)); } void GraphFunction::ensure_defined() { if (function_creator_) { auto creator = function_creator_; function_creator_ = placeholderCreator; creator(*this); function_creator_ = nullptr; } check_single_output(); } const c10::FunctionSchema& GraphFunction::getSchema() const { if (schema_ == nullptr) { schema_ = std::make_unique(defaultSchemaFor(*this)); } return *schema_; } GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const { if (force_no_amp_) { return SpecializationKey::AutocastOff; } #ifdef C10_MOBILE // disabling autodiff pass for mobile build since autocast APIs don't exist return SpecializationKey::AutocastOff; #else bool cpu_enabled = at::autocast::is_cpu_enabled(); bool gpu_enabled = at::autocast::is_enabled(); if (cpu_enabled && gpu_enabled) { return SpecializationKey::CpuGpuAutocastOn; } else if (!cpu_enabled && !gpu_enabled) { return SpecializationKey::AutocastOff; } else { return gpu_enabled ? SpecializationKey::GpuAutocastOn : SpecializationKey::CpuAutocastOn; } #endif } void preoptimizeGraph(std::shared_ptr& graph, bool disable_autocast) { Inline(*graph); // Peephole Optimize cleans up many "is None" checks and creates constant prop // opportunities PeepholeOptimize(graph, true); // AliasDb construction can be slow, so run it just on immutable types // to clean up constant Ifs & other easy wins ConstantPropagationImmutableTypes(graph); #ifndef C10_MOBILE // Inject casts for automatic mixed precision // // TODO: Ideally, this pass could run earlier, before inlining // or any other optimizations. That setup is preferable because: // 1. The AMP pass would be self-contained and function independently // of the any optimizations // 2. AMP transformations would benefit from followup passes's cleanup // if (!disable_autocast) { Autocast(graph); } #endif ConstantPooling(graph); } GraphFunction* tryToGraphFunction(Function& function) noexcept { return tryToGraphFunctionImpl(function); } GraphFunction& toGraphFunction(Function& function) { return toGraphFunctionImpl(function); } const GraphFunction& toGraphFunction(const Function& function) { return toGraphFunctionImpl(function); } } // namespace torch::jit