#include <torch/csrc/jit/passes/autocast.h>

#include <ATen/autocast_mode.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>

#include <atomic>
#include <optional>
#include <stack>
#include <unordered_set>
#include <vector>

namespace torch::jit {

namespace {

std::atomic<bool> autocast_enabled = true;

struct AutocastContext {
  bool gpu_enabled = false;
  bool cpu_enabled = false;
  c10::ScalarType gpu_scalar_type = c10::ScalarType::Undefined;
  c10::ScalarType cpu_scalar_type = c10::ScalarType::Undefined;

  operator bool() const {
    return gpu_enabled || cpu_enabled;
  }
};

struct AutocastScope {
  Value* instance = nullptr;
  AutocastContext context;
  void stack(const AutocastContext& parent_context) {}
};

bool isAutocastNode(Value* value) {
  const auto class_name = getModuleName(value);
  return class_name.has_value() &&
      (*class_name == "__torch__.torch.cuda.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.cpu.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.amp.autocast_mode.autocast");
}

// If we have an autocast instance, return it
//
// This is the pattern we're looking for (this is done after
// autocast.__init__() has been inlined)
//
//    %4 : bool = prim::Constant[value=1]()
//    %5 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
//     = prim::SetAttr[name="_enabled"](%5, %4)
//
// Notes:
//  1. There's no guarantee that the autocast instance is in the same block
//     as the prim::Enter() node
//  2. `prim::SetAttr` must follow `prim::CreateObject()` in the same block,
//     but there might be other nodes in between
//
std::optional<AutocastScope> parseAutocast(
    Value* value,
    const AutocastContext& context) {
  if (!isAutocastNode(value)) {
    // Not an autocast...
    return std::nullopt;
  }
  if (value->node()->kind() == prim::CreateObject) {
    AutocastScope scope;
    scope.instance = value;
    scope.context = context;
    std::optional<bool> enabled;
    std::string device;
    c10::ScalarType dtype = c10::ScalarType::Undefined;
    for (Use use : value->uses()) {
      // TODO: support runtime flag
      if (use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "_enabled") {
        // Search for `prim::SetAttr[name="_enabled"]`
        enabled = constant_as<bool>(use.user->input(1));
        TORCH_CHECK(
            enabled.has_value(),
            "Autocast _enabled argument must be a constant");
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "device") {
        // Search for `prim::SetAttr[name="device"]`
        auto ret = constant_as<std::string>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast device argument must be a constant");
        device = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "fast_dtype") {
        // Search for `prim::SetAttr[name="fast_dtype"]`
        auto ret = constant_as<c10::ScalarType>(use.user->input(1));
        if (ret.has_value()) {
          dtype = ret.value();
        }
      }
    }
    TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
    TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
    if (dtype == c10::ScalarType::Undefined) {
      dtype = at::autocast::get_autocast_dtype(c10::Device(device).type());
    }
    TORCH_CHECK(
        dtype != c10::ScalarType::Undefined,
        "Autocast has invalid fast_dtype attribute");
    if (device == "cuda" || device == "mps") {
      scope.context.gpu_enabled = enabled.value();
      scope.context.gpu_scalar_type = dtype;
    } else if (device == "cpu") {
      scope.context.cpu_enabled = enabled.value();
      scope.context.cpu_scalar_type = dtype;
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized device for autocast pass: ", device);
    }
    return scope;
  } else {
    // We only support simple and static autocast expressions. For example,
    // the following should report an error (since the autocast would not
    // work as expected):
    //
    //    autocast_on = autocast(enabled=True)
    //    autocast_off = autocast(enabled=False)
    //    with autocast_on if condition else autocast_off:
    //        ...
    //
    // TODO: better error message
    //
    TORCH_CHECK(false, "Unsupported autocast syntax");
  }

  return std::nullopt;
}

void castTensorInputs(
    Node* node,
    Symbol cast_op,
    const AutocastContext& context) {
  if (!context) {
    return;
  }

  const auto graph = node->owningGraph();

  std::unordered_set<Value*> casted_inputs;
  // We also need to keep the inputs in order, otherwise tracing fails its
  // sanity checks because the casting ops are inserted in random order.
  std::vector<Value*> casted_inputs_ordered;
  for (auto input : node->inputs()) {
    // TODO: update cast_op signature to take dynamic context flags
    auto input_tensor_type = input->type()->cast<TensorType>();
    if (input_tensor_type && input->node()->kind() != cast_op) {
      auto has_inserted = casted_inputs.insert(input);
      if (has_inserted.second) {
        casted_inputs_ordered.push_back(input);
      }
    }
  }

  WithInsertPoint insert_point(node);

  for (auto input : casted_inputs_ordered) {
    if (cast_op == aten::_autocast_to_full_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled))});
      node->replaceInputWith(input, new_input);
    } else if (cast_op == aten::_autocast_to_reduced_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled)),
           graph->insertConstant(IValue(context.gpu_scalar_type)),
           graph->insertConstant(IValue(context.cpu_scalar_type))});
      node->replaceInputWith(input, new_input);
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized cast_op symbol: ", cast_op.toQualString());
    }
  }
}

bool hasExplicitDtypeArgument(Node* node) {
  if (node->hasNamedInput("dtype")) {
    Value* dtype_arg = node->namedInput("dtype");
    return dtype_arg->type()->kind() != TypeKind::NoneType;
  }
  return false;
}

void castInputsToWidestType(Node* node, const AutocastContext& context) {
  if (!context) {
    return;
  }
  // Figure out the widest type
  // (really, just looking for any float32 inputs)
  //
  // TODO: revisit this (do we need to consider float64 types?)
  //
  for (auto input : node->inputs()) {
    if (auto tensor_type = input->type()->cast<TensorType>()) {
      const auto dtype = tensor_type->scalarType();
      if (!dtype.has_value() || *dtype == at::ScalarType::Float) {
        castTensorInputs(node, aten::_autocast_to_full_precision, context);
        return;
      }
    }
  }
}

// Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to
// determine whether autocasting is enabled. With JIT-scripted functions, we
// actually need to return true if eager autocast OR JIT autocast is enabled.
//
// In the case where JIT autocast is enabled, we replace
//    %x : bool = aten::is_autocast_enabled()
// with a constant "True".
//
// More context on eager vs JIT autocasting:
//
// Autocasting actually has two settings: eager autocasting and JIT
// autocasting. Eager autocasting is the thread-local setting that turns on
// the relevant bit in the dispatcher settings. JIT autocasting is the pass
// implemented in this file, which changes the graph to insert casting ops in
// order to achieve the same behavior as eager autocasting.
//
// If eager autocasting is enabled at the time a JIT-scripted function is
// invoked, then autocasting will occur regardless of what the JIT-autocasting
// settings are.
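// Illustrative sketch (not part of the pass; `fn` is a hypothetical scripted
// function). With JIT autocasting enabled, something like
//
//    @torch.jit.script
//    def fn(x):
//        with torch.cuda.amp.autocast():
//            return torch.is_autocast_enabled()
//
// produces an aten::is_autocast_enabled node inside the autocast scope;
// updateAutocastEnabledCheck() below folds it to the constant `True`, so the
// scripted function reports autocasting as enabled even when the eager
// thread-local setting is off.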
void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) {
  if (!is_jit_enabled) {
    return;
  }

  auto graph = node->owningGraph();

  WithInsertPoint insert_point(node);

  Value* true_constant = graph->insertConstant(IValue(true));
  node->output()->replaceAllUsesWith(true_constant);
  node->destroy();
}

// [Note: implicit type promotion in Autocast]
//
// The casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp,
// with a few exceptions, e.g. `aten::add`, which needs to be put on the
// promotion list for JIT autocast.
// The reason is that in eager amp, some binary ops promote inputs implicitly
// inside the operation, e.g. `aten::add` with fp16 & fp32 inputs casts both
// to fp32. In the backward pass, autograd casts each dgrad to match the
// scalar_type of the corresponding input in the forward graph, so inputs
// with mismatched scalar_types get dgrads with different scalar_types.
// In JIT, autodiff doesn't do this: the implicit cast is not visible to
// autodiff, so the backward pass for mismatched inputs ends up with dgrads
// of the same scalar_type. This has caused downstream operations, which
// expect dgrad to have the same scalar_type as the forward input, to throw a
// mismatch error.
//
// TODO: Use the list from AMP eager directly
void handleBlock(Block* block, AutocastContext initial_state) {
  std::stack<AutocastScope> autocast_stack;

  std::optional<bool> incompatible_amp = std::nullopt;

  // The current autocast enabled/disabled state
  auto current_state = [&] {
    return autocast_stack.empty() ? initial_state
                                  : autocast_stack.top().context;
  };

  for (Node* node : block->nodes()) {
    switch (node->kind()) {
      case prim::CallFunction:
        // TODO: limit it only to amp related nodes
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        TORCH_INTERNAL_ASSERT(
            !incompatible_amp.has_value() || incompatible_amp.value(),
            "Calls are not expected with AMP & JIT");
        incompatible_amp = true;
        break;

      case prim::CallMethod:
        // TODO: limit it only to amp related nodes
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
          const auto& name = node->s(attr::name);
          const auto& function = class_type->getMethod(name);
          if (!function.isGraphFunction()) {
            TORCH_INTERNAL_ASSERT(
                !incompatible_amp.has_value() || incompatible_amp.value(),
                "Calls are not expected with AMP & JIT");
            incompatible_amp = true;
          }
        } else {
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || incompatible_amp.value(),
              "Unexpected prim::CallMethod form with AMP & JIT");
          incompatible_amp = true;
        }
        break;

      case prim::Enter:
        if (auto autocast_scope =
                parseAutocast(node->input(), current_state())) {
          if (node->hasUses()) {
            // TODO: better error message
            TORCH_CHECK(false, "`with autocast() as ...` is not supported");
          }
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.push(*autocast_scope);
        }
        break;
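      // Illustrative sketch (not part of the pass): nested `with autocast()`
      // blocks compile to paired prim::Enter / prim::Exit nodes, which is
      // what autocast_stack tracks. For example (hypothetical Python,
      // assuming JIT autocast is active):
      //
      //    with autocast(enabled=True):       # prim::Enter -> push scope
      //        with autocast(enabled=False):  # prim::Enter -> push scope
      //            ...                        # current_state() is disabled
      //        ...                            # prim::Exit  -> pop scope
      //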
      case prim::Exit:
        if (isAutocastNode(node->input(0))) {
          TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
          TORCH_INTERNAL_ASSERT(
              autocast_stack.top().instance == node->input());
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.pop();
        }
        break;

      case aten::is_autocast_enabled:
        updateAutocastEnabledCheck(node, current_state().gpu_enabled);
        break;

      case aten::is_autocast_cpu_enabled:
        updateAutocastEnabledCheck(node, current_state().cpu_enabled);
        break;

      // CastPolicy::fp16 (cast all inputs to float16)
      case aten::_convolution:
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
      case aten::conv_tbc:
      case aten::conv_transpose1d:
      case aten::convolution:
      case aten::cudnn_convolution:
      case aten::cudnn_convolution_transpose:
      case aten::prelu:
      case aten::addmm:
      case aten::addmv:
      case aten::addr:
      case aten::matmul:
      case aten::mm:
      case aten::mv:
      case aten::linear:
      case aten::addbmm:
      case aten::baddbmm:
      case aten::bmm:
      case aten::chain_matmul:
      case aten::_thnn_fused_lstm_cell:
      case aten::_thnn_fused_gru_cell:
      case aten::lstm_cell:
      case aten::gru_cell:
      case aten::rnn_tanh_cell:
      case aten::rnn_relu_cell:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_reduced_precision, current_state());
        }
        break;

      // CastPolicy::fp32 (cast all inputs to float32)
      case aten::native_layer_norm:
      case aten::acos:
      case aten::asin:
      case aten::cosh:
      case aten::erfinv:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log10:
      case aten::log2:
      case aten::log1p:
      case aten::reciprocal:
      case aten::rsqrt:
      case aten::sinh:
      case aten::tan:
      case aten::pow:
      case aten::softplus:
      case aten::gelu:
      case aten::layer_norm:
      case aten::group_norm:
      case aten::frobenius_norm:
      case aten::nuclear_norm:
      case aten::cosine_similarity:
      case aten::cosine_embedding_loss:
      case aten::nll_loss:
      case aten::nll_loss2d:
      case aten::hinge_embedding_loss:
      case aten::kl_div:
      case aten::l1_loss:
      case aten::smooth_l1_loss:
      case aten::mse_loss:
      case aten::margin_ranking_loss:
      case aten::multilabel_margin_loss:
      case aten::soft_margin_loss:
      case aten::triplet_margin_loss:
      case aten::multi_margin_loss:
      case aten::binary_cross_entropy_with_logits:
      case aten::dist:
      case aten::pdist:
      case aten::cdist:
      case aten::renorm:
      case aten::logsumexp:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // CastPolicy::fp32_set_opt_dtype
      case aten::prod:
      case aten::log_softmax:
      case aten::cumprod:
      case aten::cumsum:
      case aten::sum:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // cast softmax to fp32 only on GPU
      case aten::softmax:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          auto context = current_state();
          context.cpu_enabled = false;
          castTensorInputs(node, aten::_autocast_to_full_precision, context);
        }
        break;

      // CastPolicy::promote (promote inputs to the widest type)
      case aten::addcdiv:
      case aten::addcmul:
      case aten::atan2:
      case aten::bilinear:
      case aten::cat:
      case aten::cross:
      case aten::dot:
      case aten::equal:
      case aten::index_put:
      case aten::stack:
      case aten::tensordot:
      // add, sub, mul, div were added to autocast jit, because aten implicit
      // type promotion is not visible to JIT and could cause dtype mismatch
      // on backward
      // see [Note: implicit type promotion in Autocast]
      case aten::add:
      case aten::sub:
      case aten::mul:
      case aten::div:
        if (!node->schema().is_mutable()) {
          castInputsToWidestType(node, current_state());
        }
        break;
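      // Illustrative sketch of the promote policy above (hypothetical tensor
      // names; assumes autocast is active for the device): for
      // `torch.add(a_fp16, b_fp32)`, castInputsToWidestType() sees the fp32
      // input and routes both inputs through
      // aten::_autocast_to_full_precision, so the add runs in fp32, mirroring
      // the implicit type promotion performed in eager mode
      // (see [Note: implicit type promotion in Autocast]).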
      // Banned in autocast, see binary_cross_entropy_banned()
      case aten::binary_cross_entropy:
        if (current_state()) {
          TORCH_CHECK(false, "Unsafe to autocast");
        }
    }

    // process sub-blocks, if any
    for (Block* sub_block : node->blocks()) {
      handleBlock(sub_block, current_state());
    }
  }

  // Sanity check: make sure there's no unbalanced transition
  TORCH_INTERNAL_ASSERT(autocast_stack.empty());
}

} // namespace

bool setAutocastMode(bool value) {
  return autocast_enabled.exchange(value);
}

bool autocastEnabled() {
  return autocast_enabled;
}

void Autocast(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("\nBefore Autocast: ", graph);

  if (autocastEnabled()) {
    AutocastContext init = {
        at::autocast::is_autocast_enabled(at::kCUDA),
        at::autocast::is_autocast_enabled(at::kCPU),
        at::autocast::get_autocast_dtype(at::kCUDA),
        at::autocast::get_autocast_dtype(at::kCPU)};
    handleBlock(graph->block(), init);
  }

  GRAPH_DUMP("\nAfter Autocast: ", graph);
}

} // namespace torch::jit
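// Usage sketch (illustrative only, not part of this file; `runAutocastPass`
// is a hypothetical helper). The pass is normally invoked by the JIT graph
// executor, but it can also be applied to a graph directly:
//
//    #include <torch/csrc/jit/passes/autocast.h>
//
//    void runAutocastPass(const std::shared_ptr<torch::jit::Graph>& graph) {
//      torch::jit::setAutocastMode(true); // JIT autocast is on by default
//      torch::jit::Autocast(graph);       // insert the casting ops
//    }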