#pragma once

#include <ATen/TensorGeometry.h>
#include <ATen/core/ivalue.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/flat_hash_map.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable_info.h>
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>
#include <typeindex>
#include <vector>

// see [Note: Compiled Autograd]

namespace torch::dynamo::autograd {
using namespace torch::autograd;

// This is a layer of indirection for calling methods on the Python
// AutogradCompilerInstance (referred to as the "py_compiler") from
// libtorch_cpu (where Python is not available).
// A PyCompilerInterfaceImpl in libtorch_python subclasses it and
// overrides the methods to do the actual calls back to Python.
struct TORCH_API PyCompilerInterface {
  PyCompilerInterface() = default;
  PyCompilerInterface(const PyCompilerInterface&) = delete;
  PyCompilerInterface& operator=(const PyCompilerInterface&) = delete;
  PyCompilerInterface(PyCompilerInterface&&) = delete;
  PyCompilerInterface& operator=(PyCompilerInterface&&) = delete;
  virtual ~PyCompilerInterface() = default;

  // Invokes py_compiler.bind_function
  virtual std::string bind_function(
      PyObject* py_compiler,
      const std::string& fn_name,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      functional_apply_t fn,
      // NOLINTNEXTLINE(performance-unnecessary-value-param)
      std::vector<at::TypePtr> packed_args_schema,
      bool is_custom_function = false,
      bool is_traceable = true) const {
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
  }

  // Invokes py_compiler.method_name(fn_name, inputs, packed_args,
  // output_metadata)
  virtual variable_list call_function(
      PyObject* py_compiler,
      const char* method_name,
      const std::string& fn_name,
      const variable_list& inputs,
      const ivalue_list& packed_args,
      const c10::IValue& output_metadata) const {
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
  }

  virtual variable_list call_copy_slices_prologue(
      PyObject* py_compiler,
      const variable_list& inputs,
      const at::TensorGeometry& base,
      const at::TensorGeometry& view) const {
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
  }

  virtual variable_list call_copy_slices_epilogue(
      PyObject* py_compiler,
      const std::vector<bool>& needs_input_grad,
      const at::Tensor& result,
      const variable_list& res,
      const at::Tensor& grad_slice) const {
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
  }

  virtual at::Tensor call_unpack(
      PyObject* py_compiler,
      std::optional<size_t> hook_id,
      size_t hook_input_id) const {
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
  }

  virtual void call_accumulate_grad(
      PyObject* py_compiler,
      const at::Tensor& variable,
      const at::Tensor& grad,
      bool has_post_hooks) const {
    TORCH_INTERNAL_ASSERT(false, "Needs to be overridden");
  }
};

TORCH_API const std::unique_ptr<PyCompilerInterface>& getPyCompilerInterface();

struct TORCH_API PyCompilerGuard {
  explicit PyCompilerGuard(std::unique_ptr<PyCompilerInterface>&& impl);
  PyCompilerGuard(const PyCompilerGuard&) = delete;
  PyCompilerGuard& operator=(const PyCompilerGuard&) = delete;
  PyCompilerGuard(PyCompilerGuard&&) = delete;
  PyCompilerGuard& operator=(PyCompilerGuard&&) = delete;
  ~PyCompilerGuard();
};

// including torch/csrc/autograd/engine.h breaks BC by somehow introducing
// symbol resolution issues. Instead of requiring downstream users to include
// engine.h to access collect_input_metadata, we provide it here (with a
// different name to avoid ambiguous symbols...)
TORCH_API std::vector<std::optional<InputMetadata>> get_input_metadata(
    const edge_list& edges);
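// Illustrative usage sketch (not part of this header's API surface; exact
// restore semantics live in the .cpp): the Python side is expected to install
// its PyCompilerInterfaceImpl, defined in libtorch_python, for the duration of
// a compiled autograd call via the RAII guard declared above, roughly:
//
//   PyCompilerGuard guard(std::make_unique<PyCompilerInterfaceImpl>());
//   // ... run the autograd engine; nodes in libtorch_cpu reach Python
//   // through getPyCompilerInterface()->call_function(py_compiler, ...).
//   // The installed interface is torn down when `guard` is destroyed.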
struct SizeInput {
  // Note: int value is still needed when dynamic to pass as an arg
  enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 };
  SizeInput(DynType dt, int64_t v) : dyn_type(dt), value(v) {}
  DynType dyn_type;
  int64_t value;
};

struct CacheKeyBuffer {
  CacheKeyBuffer(const uint8_t* key, uint16_t len) : data(new uint8_t[len]) {
    std::memcpy(data.get(), key, len);
  }
  const uint8_t* get() const {
    return data.get();
  }

 private:
  // NOLINTNEXTLINE(*c-array*)
  std::unique_ptr<uint8_t[]> data;
};

struct CacheKey {
  // Key to find the next node in the shadow graph. We use C++ RTTI for the
  // type of the node (ntype), then a key generated with a visitor pattern.
  CacheKey(const std::type_index& ntype, const uint8_t* key, uint16_t len)
      : node_type(ntype), key_size(len), key(key) {}

  bool operator<(const CacheKey& other) const {
    if (node_type != other.node_type) {
      return node_type < other.node_type;
    }
    if (key_size != other.key_size) {
      return key_size < other.key_size;
    }
    return std::memcmp(key, other.key, key_size) < 0;
  }

  bool operator==(const CacheKey& other) const {
    return node_type == other.node_type && key_size == other.key_size &&
        std::memcmp(key, other.key, key_size) == 0;
  }

  size_t hash() const {
    // don't bother hashing the key data, common case 1 cache entry per node
    return std::hash<std::type_index>()(node_type) ^ key_size;
  }

  std::type_index node_type;
  uint16_t key_size;
  const uint8_t* key;
};

struct NodeCall {
  NodeCall(uint32_t id_, std::shared_ptr<Node> node_)
      : id(id_), node(std::move(node_)) {}

  void mark_output(int input_nr, int output_idx) {
    graph_output.emplace_back(input_nr, output_idx);
  }

  uint32_t id;
  std::shared_ptr<Node> node;
  std::vector<std::pair<int, int>> tensor_pre_hooks;
  std::vector<std::pair<size_t, size_t>> cpp_tensor_pre_hooks;
  std::vector<int> pre_hooks;
  std::vector<int> post_hooks;
  std::vector<int> post_acc_grad_hooks;
  std::vector<std::pair<int, int>> graph_output;
  bool needed = true;
};

struct NodeCalls : public std::unordered_map<Node*, NodeCall> {
  NodeCall& lookup(const std::shared_ptr<Node>& function) {
    auto it = find(function.get());
    if (it == end()) {
      it = emplace(function.get(), NodeCall(_next_id++, function)).first;
      nodes.emplace_back(function.get());
    }
    return it->second;
  }

  const NodeCall& lookup(uint32_t id) const {
    TORCH_INTERNAL_ASSERT(id < nodes.size());
    auto it = find(nodes[id]);
    TORCH_INTERNAL_ASSERT(it != end());
    return it->second;
  }

  void clear() {
    _next_id = 0;
    std::unordered_map<Node*, NodeCall>::clear();
    nodes.clear();
  }

 private:
  uint32_t _next_id = 0;
  std::vector<Node*> nodes;
};

struct TensorArg {
  // Represents a de-duplicated tensor that will be passed into the graph
  TensorArg(uint32_t i = 0) : id(i) {}
  uint32_t index() const {
    TORCH_INTERNAL_ASSERT(defined());
    return id - 1;
  }
  bool defined() const {
    return id != 0;
  }
  uint32_t id;
  at::Tensor proxy_tensor;
};
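// Illustrative sketch of the TensorArg id scheme above (for exposition only):
// id 0 is reserved for the undefined tensor, so the i-th graph input gets
// id i + 1 and TensorArg::index() maps it back to its slot in
// TensorArgs::inputs. De-duplication is keyed on the TensorImpl, so assuming
// `args` is a TensorArgs instance (defined below) and `t` a defined tensor:
//
//   TensorArg& a = args.add(t);   // first time: allocates a new id
//   TensorArg& b = args.add(t);   // same TensorImpl: de-duplicated
//   // &a == &b, and args.inputs[a.index()] is `t`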
struct TensorArgs {
  // Manages a collection of TensorArgs and mappings from Tensors/SavedVariables
  // to them. This also allows us to unpack SavedVariable exactly once and
  // store the unpacked Tensor.
  TensorArgs(const std::optional<size_t>& active_node_call_idx)
      : active_node_call_idx(active_node_call_idx) {}

  TensorArg& lookup(const at::Tensor& tensor, bool create = false) {
    if (!tensor.defined()) {
      return _undefined;
    }
    auto impl = tensor.unsafeGetTensorImpl();
    auto it = _args.find(impl);
    if (it == _args.end()) {
      TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
      it = _args.emplace(impl, TensorArg(_next_id++)).first;
      inputs.emplace_back(tensor);
      if (active_node_call_idx.has_value()) {
        input_origins.emplace_back(active_node_call_idx.value());
      }
    }
    return it->second;
  }

  TensorArg& lookup(const SavedVariable& sv) {
    if (auto it = _saved_variables.find(&sv); it != _saved_variables.end()) {
      // unpacked before graph
      return *it->second;
    }
    // unpacked in graph
    auto it2 = _saved_variables_proxies.find(&sv);
    TORCH_INTERNAL_ASSERT(it2 != _saved_variables_proxies.end());
    return *it2->second;
  }

  TensorArg& add(const at::Tensor& tensor) {
    return lookup(tensor, true);
  }

  TensorArg& add(const SavedVariable& sv, const std::shared_ptr<Node>& node) {
    // no unpack hooks in this codepath
    at::Tensor tensor = sv.unpack(node);
    TensorArg& arg = add(tensor);
    _saved_variables.emplace(&sv, &arg);
    return arg;
  }

  // the concrete tensors that will get passed into the graph as inputs
  std::vector<at::Tensor> inputs;
  // NodeCall id of each input, only when verbose logging is enabled
  std::vector<uint32_t> input_origins;

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::optional<size_t>& active_node_call_idx;
  std::unordered_map<const c10::TensorImpl*, TensorArg> _args;
  // Every TensorArg from this is actually owned by _args (or _undefined) and
  // that's why we have an un-owned pointer here.
  std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables;
  std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables_proxies;
  TensorArg _undefined;
  uint32_t _next_id = 1; // id=0 used by _undefined
};

struct LiftedIValueArg {
  LiftedIValueArg() = delete;
  LiftedIValueArg(const at::IValue* ptr)
      : actual_ptr(ptr), proxy(at::IValue::uninitialized()) {}

  const at::IValue* actual_ptr; // lifetime handled by autograd node
  at::IValue proxy;
};

struct LiftedIValueArgs {
  LiftedIValueArgs(const std::optional<size_t>& active_node_call_idx)
      : active_node_call_idx(active_node_call_idx) {}

  at::IValue& next_proxy(const at::IValue* actual_ptr) {
    TORCH_INTERNAL_ASSERT(next < args.size());
    auto& iv_arg = args.at(next++);
    TORCH_INTERNAL_ASSERT(iv_arg.actual_ptr == actual_ptr);
    return iv_arg.proxy;
  }

  void add(const at::IValue* iv) {
    args.emplace_back(iv);
    if (active_node_call_idx.has_value()) {
      args_origins.emplace_back(active_node_call_idx.value());
    }
  }

  std::vector<LiftedIValueArg> args;
  size_t next = 0;
  // NodeCall id of each arg, only when verbose logging is enabled
  std::vector<uint32_t> args_origins;

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::optional<size_t>& active_node_call_idx;
};

struct AutogradCompilerCall {
  AutogradCompilerCall(SizeInput::DynType default_dyn_type)
      : active_node_call_idx(std::nullopt),
        tensor_args(active_node_call_idx),
        lifted_ivalue_args(active_node_call_idx),
        default_dyn_type(default_dyn_type) {}

  void add_size_input(const c10::SymInt& s) {
    all_size_inputs.emplace_back(
        default_dyn_type, s.guard_int(__FILE__, __LINE__));
    if (active_node_call_idx.has_value()) {
      size_input_origins.emplace_back(active_node_call_idx.value());
    }
  }

  size_t emplace_hook(c10::SafePyObject&& fn) {
    hooks.emplace_back(std::move(fn));
    return hooks.size() - 1;
  }

  size_t emplace_cpp_tensor_pre_hook(
      std::function<at::TensorBase(const at::TensorBase&)>&& fn) {
    cpp_tensor_pre_hooks.emplace_back(std::move(fn));
    return cpp_tensor_pre_hooks.size() - 1;
  }
  size_t emplace_packed_input(c10::SafePyObject&& input) {
    packed_inputs.emplace_back(std::move(input));
    return packed_inputs.size() - 1;
  }

  void set_active_node_call_idx(size_t node_call_idx) {
    active_node_call_idx = node_call_idx;
  }

  std::optional<size_t> active_node_call_idx;
  TensorArgs tensor_args;
  std::vector<SizeInput> all_size_inputs;
  LiftedIValueArgs lifted_ivalue_args;
  std::vector<int64_t> dyn_size_inputs;
  std::vector<c10::SafePyObject> hooks;
  std::vector<std::function<at::TensorBase(const at::TensorBase&)>>
      cpp_tensor_pre_hooks;
  std::vector<c10::SafePyObject> packed_inputs;
  NodeCalls node_calls;
  SizeInput::DynType default_dyn_type;
  // NodeCall id of each size, only when verbose logging is enabled
  std::vector<uint32_t> size_input_origins;
  std::unordered_map<const SavedVariable*, std::pair<size_t, size_t>>
      sv_to_hooks;
  // pynode -> backward and backward state idx
  std::unordered_map<const Node*, std::pair<size_t, std::optional<size_t>>>
      pynode_objs;
};

class CompiledNodeArgs {
  // CompiledNodeArgs builds a representation of the constant values found
  // across all the nodes in the compiled graph, via 'collect' overloads. The
  // collected constants are specialized on by concatenation into a cache key.
  // Tensor and symint arguments (which are lifted to become graph inputs
  // rather than specialized on) are forwarded to the compiler and not
  // included in the key.
 public:
  void collect(const TensorArg& t) {
    collect_size(t.id);
    if (t.defined()) {
      const at::Tensor& tensor = _compiler.tensor_args.inputs[t.index()];
      // including these in the cache key means dynamo-level tensor guards can
      // be skipped
      collect(tensor.device());
      collect(tensor.dtype());
      collect(tensor.requires_grad());
    }
  }

  void collect(const at::Tensor& t) {
    collect(_compiler.tensor_args.add(t));
  }

  void collect(const SavedVariable& sv, bool is_output) {
    if (auto hook_data = sv.retrieve_unpack_hook_data();
        hook_data.has_value()) {
      // hooks, unpack in graph
      auto& [hook, packed_input] = hook_data.value();
      size_t hook_id = _compiler.emplace_hook(std::move(hook));
      // rely on dynamo to dedup packed tensors from unpacked tensors
      size_t input_id = _compiler.emplace_packed_input(std::move(packed_input));
      _compiler.sv_to_hooks.emplace(&sv, std::make_pair(hook_id, input_id));
    } else {
      // no hooks, unpack now
      collect(
          _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
    }
  }
  void collect(const c10::SymInt& t) {
    _compiler.add_size_input(t);
  }

  void collect(const std::vector<SavedVariable>& t, bool is_output) {
    collect_size(t.size());
    for (const SavedVariable& i : t) {
      collect(i, is_output);
    }
  }

  template <typename T>
  void collect(const std::vector<T>& t) {
    collect_size(t.size());
    for (const T& i : t) {
      collect(i);
    }
  }

  void collect(const c10::ArrayRef<SavedVariable>& t, bool is_output) {
    collect_size(t.size());
    for (const SavedVariable& i : t) {
      collect(i, is_output);
    }
  }

  template <typename T>
  void collect(const c10::ArrayRef<T>& t) {
    collect_size(t.size());
    for (const T& i : t) {
      collect(i);
    }
  }

  template <typename T>
  void collect(const c10::OptionalArray<T>& t) {
    collect(t.list);
  }

  template <typename T>
  void collect(const std::optional<T>& t) {
    if (cond(t.has_value())) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      collect(*t);
    }
  }

  template <typename A, typename B>
  void collect(const std::pair<A, B>& t) {
    collect(t.first);
    collect(t.second);
  }

  template <typename T>
  void collect(const ska::flat_hash_map<std::string, T>& m) {
    collect_size(m.size());

    std::vector<std::string> keys;
    keys.reserve(m.size());
    std::transform(
        m.begin(), m.end(), std::back_inserter(keys), [](const auto& entry) {
          return entry.first;
        });
    std::sort(keys.begin(), keys.end());
    for (const auto& k : keys) {
      collect(k);
      collect(m.at(k));
    }
  }

  void collect(const at::IValue& iv, bool nested = false) {
    // used by AutogradContext::saved_data from CppNode
    if (iv.isList()) {
      c10::List<at::IValue> list = iv.toList();
      collect_size(list.size());
      for (auto&& value : list) {
        collect(value, true);
      }
    } else if (iv.isGenericDict()) {
      c10::Dict<at::IValue, at::IValue> ordered_dict = iv.toGenericDict();
      collect_size(ordered_dict.size());
      // NOLINTNEXTLINE(modernize-loop-convert)
      for (auto it = ordered_dict.begin(); it != ordered_dict.end(); it++) {
        collect(it->key());
        collect(it->value(), true);
      }
    } else if (iv.isTensor()) {
      collect(iv.toTensor());
    } else if (
        !nested &&
        (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat())) {
      // can't lift ivalues nested in collections
      _compiler.lifted_ivalue_args.add(&iv);
    } else {
      try {
        collect(static_cast<uint64_t>(at::IValue::hash(iv)));
      } catch (const std::runtime_error& e) {
        std::string msg =
            "Compiled autograd can not trace unhashable IValues, error: " +
            std::string(e.what());
        TORCH_CHECK_NOT_IMPLEMENTED(false, msg);
      }
    }
  }

  void collect(const c10::Scalar& t) {
    auto type = t.type();
    specialize_on_bytes(type);
    if (type == c10::ScalarType::Double) {
      collect(t.toDouble());
    } else if (type == c10::ScalarType::Long) {
      collect(t.toLong());
    } else if (type == c10::ScalarType::Bool) {
      collect(t.toBool());
    } else if (type == c10::ScalarType::ComplexDouble) {
      auto c = t.toComplexDouble();
      collect(c.real());
      collect(c.imag());
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }

  void collect(const c10::TensorOptions& t) {
    collect(t.device());
    collect(t.dtype());
    collect(t.layout());
    collect(t.requires_grad());
    collect(t.pinned_memory());
    collect(t.memory_format_opt());
  }

  void collect(const at::TensorGeometry& t) {
    collect(t.sym_sizes());
    collect(t.sym_strides());
    collect(t.sym_storage_offset());
  }

  void collect(const torch::autograd::TypeAndSize& t) {
    collect(t.sym_sizes);
    collect(t.options);
  }

  void collect(const c10::Device& t) {
    collect(t.type());
    collect(t.index());
  }

  void collect(const std::string& t) {
    collect_size(t.size());
    for (char c : t) {
      collect(c);
    }
  }

  void collect(const caffe2::TypeMeta& t) {
    specialize_on_bytes(t.id());
  }
  void collect(const std::shared_ptr<Node>& t) {
    // Note: this is only capturing the ID of the node not everything
    // contained inside it. This is used for tracking connections between
    // nodes and the actual details of the node itself must be handled by
    // a separate call to `node->compiled_args()`.
    if (cond((bool)t)) {
      collect(_compiler.node_calls.lookup(t));
    }
  }

  void collect(const NodeCall& t) {
    collect_size(t.id);
    collect(t.graph_output);
    collect_hooks_from(t.node.get());
  }

  void collect(const Edge& t) {
    if (cond(t.is_valid())) {
      collect_size(_compiler.node_calls.lookup(t.function).id);
      collect_size(t.input_nr);
      collect(t.function->input_metadata(t.input_nr)); // for validate_outputs
    }
  }

  void collect(const InputMetadata& t) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        !t.is_nested_tensor(), "NestedTensor support not implemented. ");
    collect(t.options());
    collect(t.is_tensor_subclass());
    collect(t.shape_as_dim_vector());
  }

  void collect(const VariableInfo& t) {
    collect(t.layout);
    collect(t.device);
    collect(t.scalar_type);
    collect(t.size);
    collect(t.requires_grad);
    collect(t.is_empty);
  }

  bool cond(bool cond) {
    collect(cond);
    return cond;
  }

#define COLLECT_AS_BYTES(T) \
  void collect(T t) {       \
    specialize_on_bytes(t); \
  }
  COLLECT_AS_BYTES(c10::ScalarType)
  COLLECT_AS_BYTES(c10::DeviceType)
  COLLECT_AS_BYTES(c10::Layout)
  COLLECT_AS_BYTES(c10::MemoryFormat)
  COLLECT_AS_BYTES(int8_t)
  COLLECT_AS_BYTES(int16_t)
  COLLECT_AS_BYTES(int32_t)
  COLLECT_AS_BYTES(int64_t)
  COLLECT_AS_BYTES(uint8_t)
  COLLECT_AS_BYTES(uint16_t)
  COLLECT_AS_BYTES(uint32_t)
  COLLECT_AS_BYTES(uint64_t)
  COLLECT_AS_BYTES(bool)
  COLLECT_AS_BYTES(float)
  COLLECT_AS_BYTES(double)
#undef COLLECT_AS_BYTES

  void collect_hooks_from(Node* fn) {
    for (auto& i : fn->tensor_pre_hooks()) {
      i->compiled_args(*this);
    }
    for (auto& [_, i] : fn->retains_grad_hooks()) {
      i->compiled_args(*this);
    }
    for (auto& i : fn->pre_hooks()) {
      i->compiled_args(*this);
    }
    for (auto& i : fn->post_hooks()) {
      i->compiled_args(*this);
    }
    collect_size(_node_call.tensor_pre_hooks.size());
    collect_size(_node_call.pre_hooks.size());
    collect_size(_node_call.post_hooks.size());
    for (const auto& h : _node_call.tensor_pre_hooks) {
      collect_size(static_cast<size_t>(h.second));
    }
  }

  CacheKey key() const {
    Node* node = _node_call.node.get();
    return CacheKey(
        typeid(*node), _specialization_key, _specialization_key_size);
  }

  void collect_pynode_objs(
      const Node* pynode,
      c10::SafePyObject&& bwd,
      std::optional<c10::SafePyObject>&& bwd_state) {
    size_t bwd_idx = _compiler.emplace_hook(std::move(bwd));
    std::optional<size_t> bwd_state_idx;
    if (auto state = std::move(bwd_state); state.has_value()) {
      bwd_state_idx = _compiler.emplace_hook(std::move(state.value()));
    }
    _compiler.pynode_objs.emplace(
        pynode, std::make_pair(bwd_idx, bwd_state_idx));
  }

  void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.tensor_pre_hooks.emplace_back(fn_id, index);
  }

  void add_cpp_single_tensor_pre_hook(
      const std::function<at::TensorBase(const at::TensorBase&)>& hook,
      size_t idx) {
    auto wrapper = [hook](const at::TensorBase& grad) {
      // handle when hook returns nothing
      auto out = hook(grad);
      if (!out.defined()) {
        return grad;
      }
      return out;
    };
    auto hook_id = _compiler.emplace_cpp_tensor_pre_hook(std::move(wrapper));
    collect_size(hook_id);
    _node_call.cpp_tensor_pre_hooks.emplace_back(hook_id, idx);
  }

  void add_pre_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.pre_hooks.emplace_back(fn_id);
  }

  void add_post_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.post_hooks.emplace_back(fn_id);
  }
  void add_post_acc_grad_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.post_acc_grad_hooks.emplace_back(fn_id);
  }

  // Need to template the size_t to silence internal 32-bit build errors due to
  // a mix of -Werror, -Wtautological-type-limit-compare and
  // -Wunknown-pragmas
  template <typename T>
  std::enable_if_t<std::is_unsigned_v<T>, void> collect_size(T s) {
    // we expect sizes to be small, so try to cram them into a single byte
    constexpr uint8_t encode_as_u64 = std::numeric_limits<uint8_t>::max();
    constexpr uint8_t encode_as_u32 = encode_as_u64 - 1;
    constexpr uint8_t encode_as_u16 = encode_as_u64 - 2;
    if (C10_UNLIKELY(s >= encode_as_u16)) {
      // first write a byte indicating the path we followed, then the data
      if (s <= std::numeric_limits<uint16_t>::max()) {
        // 3 bytes
        specialize_on_bytes(encode_as_u16);
        specialize_on_bytes(static_cast<uint16_t>(s));
      } else if (s <= std::numeric_limits<uint32_t>::max()) {
        // 5 bytes
        specialize_on_bytes(encode_as_u32);
        specialize_on_bytes(static_cast<uint32_t>(s));
      } else {
        // 9 bytes
        specialize_on_bytes(encode_as_u64);
        specialize_on_bytes(s);
      }
    } else {
      // happy case, 1 byte
      specialize_on_bytes(static_cast<uint8_t>(s));
    }
  }

  SizeInput::DynType set_default_dyn_type(SizeInput::DynType default_dyn_type) {
    return std::exchange(_compiler.default_dyn_type, default_dyn_type);
  }

  CompiledNodeArgs(AutogradCompilerCall& compiler, NodeCall& node_call)
      : _compiler(compiler),
        _node_call(node_call),
        _specialization_key(
            // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
            (uint8_t*)std::malloc(_specialization_key_storage)) {}
  CompiledNodeArgs(const CompiledNodeArgs&) = delete;
  CompiledNodeArgs(CompiledNodeArgs&&) = delete;
  CompiledNodeArgs& operator=(const CompiledNodeArgs&) = delete;
  CompiledNodeArgs& operator=(CompiledNodeArgs&&) = delete;
  ~CompiledNodeArgs() {
    // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
    std::free(_specialization_key);
  }

 private:
  template <typename T>
  void specialize_on_bytes(const T& t) {
    while (C10_UNLIKELY(
        _specialization_key_size + sizeof(T) > _specialization_key_storage)) {
      _specialization_key_storage *= 2;
      // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
      _specialization_key = (uint8_t*)std::realloc(
          _specialization_key, _specialization_key_storage);
    }
    std::memcpy(_specialization_key + _specialization_key_size, &t, sizeof(T));
    _specialization_key_size += sizeof(T);
  }

  AutogradCompilerCall& _compiler;
  NodeCall& _node_call;
  size_t _specialization_key_size{0};
  size_t _specialization_key_storage{1024};
  uint8_t* _specialization_key;
};
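// Illustrative sketch (driver names are assumptions; the real driver lives in
// the compiled autograd engine, not in this header) of how the pieces above
// fit together during the cache check: for each autograd node encountered,
// the node describes its constants via compiled_args(), and the accumulated
// bytes plus the node's RTTI become a CacheKey, roughly:
//
//   AutogradCompilerCall compiler_call(/*default_dyn_type=*/SizeInput::STATIC);
//   NodeCall& call = compiler_call.node_calls.lookup(node);
//   CompiledNodeArgs args(compiler_call, call);
//   node->compiled_args(args);   // node-specific collect() calls
//   CacheKey key = args.key();   // specialization bytes + typeid(*node)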
struct TraceState {
  TraceState(std::vector<std::optional<c10::SymInt>>&& ss, size_t num_outputs)
      : sym_sizes(std::move(ss)), outputs(num_outputs) {}

  void debug_asserts() {
    TORCH_INTERNAL_ASSERT(sym_sizes_index == sym_sizes.size());
  }
  std::optional<c10::SymInt> next_sym_size() {
    TORCH_INTERNAL_ASSERT(sym_sizes_index < sym_sizes.size());
    return sym_sizes[sym_sizes_index++];
  }

  size_t sym_sizes_index{0};
  std::vector<std::optional<c10::SymInt>> sym_sizes;
  variable_list outputs;
};

class SwapSavedVariables {
  // SwapSavedVariables is used during the tracing/compilation phase after a
  // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
  // allows tracing to happen, then swaps them back afterwards.
 public:
  std::pair<size_t, std::optional<size_t>> retrieve_pynode_objs(
      Node* pynode) const {
    auto it = compiler.pynode_objs.find(pynode);
    TORCH_INTERNAL_ASSERT(it != compiler.pynode_objs.end());
    return it->second;
  }

  void before(at::Tensor& t) {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    stashed_tensors.save(&t, std::move(t));
    if (arg.defined()) {
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = arg.proxy_tensor;
    }
  }
  void after(at::Tensor& t) {
    stashed_tensors.restore(&t);
  }

  void before(SavedVariable& t) {
    if (auto it = compiler.sv_to_hooks.find(&t);
        it != compiler.sv_to_hooks.end()) {
      const auto& pyinterface =
          torch::dynamo::autograd::getPyCompilerInterface();
      auto proxy_tensor = pyinterface->call_unpack(
          get_py_compiler(), it->second.first, it->second.second);
      stashed_variables.save(&t, std::move(t));
      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
      t = SavedVariable(proxy_tensor, false);
      at::SavedTensorDefaultHooks::set_tracing(prior);
    } else {
      // no hooks, was already unpacked
      TensorArg& arg = compiler.tensor_args.lookup(t);
      stashed_variables.save(&t, std::move(t));
      if (arg.defined()) {
        bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
        TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
        t = SavedVariable(arg.proxy_tensor, false);
        at::SavedTensorDefaultHooks::set_tracing(prior);
      }
    }
  }
  void after(SavedVariable& t) {
    stashed_variables.restore(&t);
  }

  void before(c10::SymInt& t) {
    stashed_symints.save(&t, c10::SymInt(t));
    auto opt_value = state.next_sym_size();
    if (opt_value.has_value()) {
      t = *opt_value; // dynamic shape
    }
  }
  void after(c10::SymInt& t) {
    stashed_symints.restore(&t);
  }

  void before(at::IValue& iv) {
    if (iv.isTensor()) {
      before(iv.toTensor());
    } else {
      stashed_ivalues.save(&iv, at::IValue(iv));
      if (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat()) {
        iv = compiler.lifted_ivalue_args.next_proxy(&iv);
      }
    }
  }
  void after(at::IValue& t) {
    if (t.isTensor()) {
      after(t.toTensor());
    } else {
      stashed_ivalues.restore(&t);
    }
  }

  void before(Edge& t) {
    if (t.is_valid()) {
      // need for symints used by validate_outputs
      before(t.function->mutable_input_metadata(t.input_nr));
    }
  }
  void after(Edge& t) {
    if (t.is_valid()) {
      after(t.function->mutable_input_metadata(t.input_nr));
    }
  }
  void before(InputMetadata& t) {
    before(t.mutable_shape_as_dim_vector());
  }
  void after(InputMetadata& t) {
    after(t.mutable_shape_as_dim_vector());
  }
  void before(at::TensorGeometry& t) {
    before(t.mutable_sizes());
    before(t.mutable_strides());
    before(t.mutable_storage_offset());
    t.recompute();
  }
  void after(at::TensorGeometry& t) {
    after(t.mutable_sizes());
    after(t.mutable_strides());
    after(t.mutable_storage_offset());
    t.recompute();
  }
  void before(torch::autograd::TypeAndSize& t) {
    before(t.sym_sizes);
    before(t.options);
  }
  void after(torch::autograd::TypeAndSize& t) {
    after(t.sym_sizes);
    after(t.options);
  }
  void before(VariableInfo& t) {
    before(t.size);
  }
  void after(VariableInfo& t) {
    after(t.size);
  }

  template <typename T>
  void before(std::vector<T>& t) {
    for (T& i : t) {
      before(i);
    }
  }
  template <typename T>
  void after(std::vector<T>& t) {
    for (T& i : t) {
      after(i);
    }
  }
  template <typename T, unsigned N>
  void before(c10::SmallVector<T, N>& t) {
    for (T& i : t) {
      before(i);
    }
  }
  template <typename T, unsigned N>
  void after(c10::SmallVector<T, N>& t) {
    for (T& i : t) {
      after(i);
    }
  }
  template <typename T>
  void before(c10::OptionalArray<T>& t) {
    before(t.list);
  }
  template <typename T>
  void after(c10::OptionalArray<T>& t) {
    after(t.list);
  }
  template <typename T>
  void before(std::optional<T>& t) {
    if (t.has_value()) {
      before(*t);
    }
  }
  template <typename T>
  void after(std::optional<T>& t) {
    if (t.has_value()) {
      after(*t);
    }
  }
  template <typename T>
  void before(ska::flat_hash_map<std::string, T>& m) {
    std::vector<std::string> keys;
    keys.reserve(m.size());
    std::transform(
        m.begin(), m.end(), std::back_inserter(keys), [](const auto& entry) {
          return entry.first;
        });
    std::sort(keys.begin(), keys.end());
    for (auto& k : keys) {
      before(m.at(k));
    }
  }
  template <typename T>
  void after(ska::flat_hash_map<std::string, T>& m) {
    for (auto& [_, v] : m) {
      after(v);
    }
  }

#define NO_OP_VISIT(T)     \
  void before(const T&) {} \
  void after(const T&) {}
  NO_OP_VISIT(caffe2::TypeMeta)
  NO_OP_VISIT(c10::Device)
  NO_OP_VISIT(c10::DeviceType)
  NO_OP_VISIT(c10::Layout)
  NO_OP_VISIT(c10::MemoryFormat)
  NO_OP_VISIT(c10::ScalarType)
  NO_OP_VISIT(c10::Scalar)
  NO_OP_VISIT(c10::TensorOptions)
  NO_OP_VISIT(std::string)
  NO_OP_VISIT(int64_t)
  NO_OP_VISIT(bool)
  NO_OP_VISIT(double)
#undef NO_OP_VISIT

  SwapSavedVariables(
      AutogradCompilerCall& c,
      TraceState& s,
      PyObject* p,
      const NodeCall& n)
      : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}

  PyObject* get_py_compiler() const {
    return py_compiler;
  }

  const NodeCall& get_curr_node_call() {
    return curr_node_call;
  }

  void debug_asserts() {
    stashed_variables.debug_assert();
    stashed_tensors.debug_assert();
    stashed_symints.debug_assert();
  }

 private:
  template <typename T>
  struct Stashed {
    Stashed(T&& v) : prior_value(std::move(v)) {}
    T prior_value;
    // Note: we need count here to support duplicate calls to before()
    // which happen when we have multiple autograd::Edge objects pointing
    // to the same autograd::Node
    int count = 1;
  };

  template <typename T>
  struct StashedVars : public std::unordered_map<const T*, Stashed<T>> {
    void save(const T* key, T&& value) {
      auto [it, inserted] = this->try_emplace(key, std::move(value));
      if (!inserted) {
        // keep the value from the prior save()
        it->second.count++;
      }
    }
    void restore(T* var) {
      auto it = this->find(var);
      TORCH_INTERNAL_ASSERT(it != this->end(), "missing before()");
      if (--it->second.count == 0) {
        // restore the value on the last restore()
        *var = std::move(it->second.prior_value);
        this->erase(it);
      }
    }
    void debug_assert() {
      TORCH_INTERNAL_ASSERT(this->empty(), "missing call to after()");
    }
  };

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  AutogradCompilerCall& compiler;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  TraceState& state;
  // This is a borrowed reference; we do not increment or decrement ownership.
  // Its lifecycle is entirely longer than this object's.
  PyObject* py_compiler;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const NodeCall& curr_node_call;

  // These mappings are used to save the prior values when we overwrite things
  // in before(). In after(), we use these to cleanup after ourselves.
  StashedVars<SavedVariable> stashed_variables;
  StashedVars<at::Tensor> stashed_tensors;
  StashedVars<c10::SymInt> stashed_symints;
  StashedVars<at::IValue> stashed_ivalues;
};
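// Illustrative sketch (names and the exact apply_with_saved signature are
// assumptions; see the NOTE below on backward functions): a node's
// apply_with_saved wraps its use of saved state in before()/after() pairs so
// that lifted values are proxies only while tracing, roughly:
//
//   void MyNode::apply_with_saved(
//       const variable_list& inputs, SwapSavedVariables& saved) {
//     saved.before(self_);    // swap saved tensor for its proxy
//     saved.before(sizes_);   // swap saved symints for proxies
//     // ... proxy the backward computation through the py_compiler ...
//     saved.after(self_);     // restore the original saved values
//     saved.after(sizes_);
//   }
//
// MyNode, self_ and sizes_ are made-up names; real nodes generate this
// pattern from their saved-variable lists.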
// NOTE: [Compiled Autograd and backward functions]
// Built-in autograd nodes have functional apply variants
// (e.g. MulBackward0_apply_functional). Compiled Autograd's initial graph
// capture wants to take a variant of this function and proxy it into the
// graph. Every autograd node defines an apply_with_saved function, that when
// invoked, proxies a call to a function into the Compiled Autograd graph.
//
// Some requirements that we have are:
// - The proxy'ed function must have inputs that are FX-graphable types.
// - Windows has a DLL symbol limit of 65536.
// - Node::apply_with_saved is in libtorch_cpu which does not have direct
//   access to Python
//
// There were multiple ways to skin the cat, but what we end up doing is:
// - for e.g. MulBackward0_apply_functional, we create a new C++ function
//   MulBackward0_apply_functional_ivalue that accepts vector<IValue>.
// - We define how to pack and unpack arbitrary C++ types into IValues.
// - apply_with_saved passes MulBackward0_apply_functional_ivalue and
//   the IValue arguments to Python via an indirection.
//   In Python, these get proxy'ed into a graph.

// Helper struct for packing/unpacking an arbitrary C++ type into a single
// IValue. There are various full and partial specializations for IValuePacker
// to handle packing specific types (like TensorOptions) into an IValue.
template <typename T>
struct IValuePacker {
  // Defines how to pack T into an IValue.
  static at::IValue pack(const T& t) {
    return t;
  }
  // Defines how to unpack an IValue into T.
  static T unpack(const at::IValue& t) {
    return t.to<T>();
  }
  // Returns the TypePtr for the IValue (this is like the "type" of the
  // IValue). We use this when passing the packed IValue from Python to C++.
  // In Python, the IValue is just a PyObject* with the native type.
  // For example, it may be a Python int, a Python List[int], etc.
  // When passing this PyObject* into C++, we need to know how to parse it
  // into a C++ type that then gets put into an IValue.
  // That's what the TypePtr is for: it contains the information to do the
  // parsing. See torch::jit::toIValue for more information.
  static at::TypePtr packed_type() {
    // On Windows, only CPU builds support compiled autograd.
#if defined(_WIN32) && (defined(USE_CUDA) || defined(USE_ROCM))
    // NB: the if-constexpr usage triggers compilation errors on Windows
    // with certain compiler settings
    // (see https://github.com/pytorch/pytorch/pull/144707 for examples).
    // It's not clear what the problem is, so we're going to ignore it for now.
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "torch.compile not supported on Windows");
#else
    if constexpr (::std::is_same_v<T, at::Tensor>) {
      return at::TensorType::get();
    } else if constexpr (::std::is_same_v<T, int64_t>) {
      return at::IntType::get();
    } else if constexpr (::std::is_same_v<T, c10::SymInt>) {
      return at::SymIntType::get();
    } else if constexpr (::std::is_same_v<T, bool>) {
      return at::BoolType::get();
    } else if constexpr (::std::is_same_v<T, double>) {
      return at::FloatType::get();
    } else if constexpr (::std::is_same_v<T, c10::SymFloat>) {
      return at::SymFloatType::get();
    } else if constexpr (::std::is_same_v<T, c10::SymBool>) {
      return at::SymBoolType::get();
    } else if constexpr (::std::is_same_v<T, c10::Layout>) {
      return at::LayoutType::get();
    } else if constexpr (::std::is_same_v<T, ::std::string>) {
      return at::StringType::get();
    } else if constexpr (::std::is_same_v<T, at::Device>) {
      return at::DeviceObjType::get();
    } else if constexpr (::std::is_same_v<T, at::Scalar>) {
      return at::NumberType::get();
    } else if constexpr (::std::is_same_v<T, at::MemoryFormat>) {
      return at::MemoryFormatType::get();
    } else if constexpr (::std::is_same_v<T, at::ScalarType>) {
      return at::ScalarTypeType::get();
    } else {
      // If you got here, you have probably added a member of a new type
      // to a built-in C++ autograd node.
      // Unfortunately, we don't know how to handle this type yet.
      // To get this new type to work with Compiled Autograd, please
      // either change it to be an IValue-constructible type, or
      // define how to pack and unpack an object of this type into an IValue
      // by creating a specialization of IValuePacker for this type.
      // See NOTE: [Compiled Autograd and backward functions] for context.
      TORCH_CHECK_NOT_IMPLEMENTED(
          false, "IValuePacker not implemented for type");
      return at::NoneType::get();
    }
#endif
  }
};
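// Illustrative sketch (hypothetical type, for exposition only): a new saved
// member of type `MyMeta { int64_t dim; bool keepdim; }` could be made
// Compiled Autograd friendly by specializing IValuePacker along these lines,
// mirroring the real specializations below:
//
//   template <>
//   struct IValuePacker<MyMeta> {
//     static at::IValue pack(const MyMeta& m) {
//       return std::make_tuple(m.dim, m.keepdim);
//     }
//     static MyMeta unpack(const at::IValue& iv) {
//       auto tup = iv.to<std::tuple<int64_t, bool>>();
//       return MyMeta{std::get<0>(tup), std::get<1>(tup)};
//     }
//     static at::TypePtr packed_type() {
//       return at::TupleType::create(
//           {at::IntType::get(), at::BoolType::get()});
//     }
//   };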
template <>
struct IValuePacker<size_t> {
  static at::IValue pack(const size_t& t) {
    // We generally use size_t as the size of a list of Tensors or number of
    // dimensions. The number of dimensions generally does not exceed 64
    // (TensorIterator has that limitation), and lists of Tensors generally do
    // not exceed the int64_t max (you'd probably run out of RAM or run into
    // significant Tensor overhead). If you run into this limitation the fix is
    // to figure out how to pack size_t into int64_t. Note that size_t has some
    // weird behavior on Mac OS.
    uint64_t maximum_value = std::numeric_limits<int64_t>::max();
    TORCH_INTERNAL_ASSERT(
        static_cast<uint64_t>(t) <= maximum_value,
        "size_t too large to pack into IValue");
    return static_cast<int64_t>(t); // pack as int64_t
  }
  static size_t unpack(const at::IValue& t) {
    return static_cast<size_t>(t.toInt());
  }
  static at::TypePtr packed_type() {
    return IValuePacker<int64_t>::packed_type();
  }
};

template <>
struct IValuePacker<std::vector<at::SymInt>> {
  static at::IValue pack(const std::vector<at::SymInt>& t) {
    return t;
  }
  static std::vector<at::SymInt> unpack(const at::IValue& t) {
    // We need this because there's no t.to<std::vector<at::SymInt>>() override?
    return t.toSymIntVector();
  }
  static at::TypePtr packed_type() {
    return at::ListType::create(at::SymIntType::get());
  }
};

template <>
struct IValuePacker<VariableInfo> {
  static at::IValue pack(const VariableInfo& t) {
    auto tuple = std::make_tuple(
        t.layout, t.device, t.scalar_type, t.size, t.requires_grad, t.is_empty);
    return tuple;
  }
  static VariableInfo unpack(const at::IValue& t) {
    auto tuple = t.toTuple();
    const auto& tuple_elements = tuple->elements();
    const auto elements = tuple_elements.asArrayRef();
    TORCH_INTERNAL_ASSERT(elements.size() == 6);
    VariableInfo v;
    v.layout = elements[0].toLayout();
    v.device = elements[1].toDevice();
    v.scalar_type = elements[2].toScalarType();
    v.size = elements[3].toSymIntVector();
    v.requires_grad = elements[4].toBool();
    v.is_empty = elements[5].toBool();
    return v;
  }
  static at::TypePtr packed_type() {
    return at::TupleType::create({
        at::LayoutType::get(),
        at::DeviceObjType::get(),
        at::ScalarTypeType::get(),
        at::ListType::create(at::SymIntType::get()),
        at::BoolType::get(),
        at::BoolType::get(),
    });
  }
};

template <>
struct IValuePacker<caffe2::TypeMeta> {
  static at::IValue pack(const caffe2::TypeMeta& t) {
    return at::typeMetaToScalarType(t); // pack as at::ScalarType
  }
  static caffe2::TypeMeta unpack(const at::IValue& t) {
    return caffe2::TypeMeta::fromScalarType(t.to<at::ScalarType>());
  }
  static at::TypePtr packed_type() {
    return IValuePacker<at::ScalarType>::packed_type();
  }
};

inline std::optional<at::ScalarType> optTypeMetaToScalarType(
    const std::optional<caffe2::TypeMeta>& t) {
  if (t.has_value()) {
    return at::typeMetaToScalarType(t.value());
  } else {
    return std::nullopt;
  }
}

using packed_tensoroptions_t = std::tuple<
    std::optional<bool>,
    std::optional<c10::MemoryFormat>,
    std::optional<c10::Device>,
    std::optional<c10::ScalarType>,
    std::optional<c10::Layout>,
    std::optional<bool>>;

inline packed_tensoroptions_t pack_TensorOptions(const at::TensorOptions& t) {
  auto tuple = std::make_tuple(
      t.requires_grad_opt(),
      t.memory_format_opt(),
      t.device_opt(),
      optTypeMetaToScalarType(t.dtype_opt()),
      t.layout_opt(),
      t.pinned_memory_opt());
  return tuple;
}

inline at::TensorOptions unpack_TensorOptions(
    const packed_tensoroptions_t& tuple) {
  at::TensorOptions result;
  auto maybe_requires_grad = std::get<0>(tuple);
  if (maybe_requires_grad.has_value()) {
    result = result.requires_grad(maybe_requires_grad);
  }
  auto maybe_memory_format = std::get<1>(tuple);
  if (maybe_memory_format.has_value()) {
    result = result.memory_format(maybe_memory_format);
  }
  auto maybe_device = std::get<2>(tuple);
  if (maybe_device.has_value()) {
    result = result.device(maybe_device.value());
  }
  auto maybe_dtype = std::get<3>(tuple);
  if (maybe_dtype.has_value()) {
    result =
        result.dtype(caffe2::TypeMeta::fromScalarType(maybe_dtype.value()));
  }
  auto maybe_layout = std::get<4>(tuple);
  if (maybe_layout.has_value()) {
    result = result.layout(maybe_layout);
  }
  auto maybe_pinned_memory = std::get<5>(tuple);
  if (maybe_pinned_memory.has_value()) {
    result = result.pinned_memory(maybe_pinned_memory);
  }
  return result;
}
template <>
struct IValuePacker<at::TensorOptions> {
  static at::IValue pack(const at::TensorOptions& t) {
    return pack_TensorOptions(t);
  }
  static at::TensorOptions unpack(const at::IValue& t) {
    auto tuple = t.to<packed_tensoroptions_t>();
    return unpack_TensorOptions(tuple);
  }
  static at::TypePtr packed_type() {
    return at::TupleType::create(
        {at::OptionalType::create(at::BoolType::get()),
         at::OptionalType::create(at::MemoryFormatType::get()),
         at::OptionalType::create(at::DeviceObjType::get()),
         at::OptionalType::create(at::ScalarTypeType::get()),
         at::OptionalType::create(at::LayoutType::get()),
         at::OptionalType::create(at::BoolType::get())});
  }
};

template <>
struct IValuePacker<TypeAndSize> {
  static at::IValue pack(const TypeAndSize& t) {
    auto tuple = std::make_tuple(t.sym_sizes, pack_TensorOptions(t.options));
    return tuple;
  }
  static TypeAndSize unpack(const at::IValue& t) {
    auto tuple =
        t.to<std::tuple<std::vector<at::SymInt>, packed_tensoroptions_t>>();
    TypeAndSize result;
    result.sym_sizes = std::get<0>(tuple);
    result.options = unpack_TensorOptions(std::get<1>(tuple));
    return result;
  }
  static at::TypePtr packed_type() {
    return at::TupleType::create(
        {IValuePacker<std::vector<at::SymInt>>::packed_type(),
         IValuePacker<at::TensorOptions>::packed_type()});
  }
};

template <typename T>
struct IValuePacker<std::optional<T>> {
  static at::IValue pack(const std::optional<T>& t) {
    if (t.has_value()) {
      return IValuePacker<T>::pack(t.value());
    } else {
      return std::nullopt;
    }
  }
  static std::optional<T> unpack(const at::IValue& t) {
    if (t.isNone()) {
      return std::nullopt;
    } else {
      return IValuePacker<T>::unpack(t);
    }
  }
  static at::TypePtr packed_type() {
    return at::OptionalType::create(IValuePacker<T>::packed_type());
  }
};

template <typename T>
struct IValuePacker<std::vector<T>> {
  static at::IValue pack(const std::vector<T>& t) {
    if constexpr (::std::is_constructible_v<at::IValue, std::vector<T>>) {
      return t;
    }
    if (t.empty()) {
      auto lst = c10::impl::GenericList(at::AnyType::get());
      return lst;
    }
    auto type_ptr = IValuePacker<T>::pack(t[0]).type();
    auto lst = c10::impl::GenericList(type_ptr);
    for (const auto& elt : t) {
      lst.emplace_back(IValuePacker<T>::pack(elt));
    }
    return lst;
  }
  static std::vector<T> unpack(const at::IValue& t) {
    if constexpr (::std::is_constructible_v<at::IValue, std::vector<T>>) {
      return t.to<::std::vector<T>>();
    }
    std::vector<T> result;
    auto lst = t.toList();
    for (size_t i = 0; i < lst.size(); ++i) {
      const at::IValue& elt = lst.get(i);
      result.emplace_back(IValuePacker<T>::unpack(elt));
    }
    return result;
  }
  static at::TypePtr packed_type() {
    return at::ListType::create(IValuePacker<T>::packed_type());
  }
};

template <typename T>
struct IValuePacker<c10::List<T>> {
  static at::IValue pack(const c10::List<T>& t) {
    return IValuePacker<std::vector<T>>::pack(t.vec());
  }
  static c10::List<T> unpack(const at::IValue& t) {
    return c10::List<T>(IValuePacker<std::vector<T>>::unpack(t));
  }
  static at::TypePtr packed_type() {
    return IValuePacker<std::vector<T>>::packed_type();
  }
};

template <size_t N>
struct IValuePacker<std::array<bool, N>> {
  static at::IValue pack(const std::array<bool, N>& t) {
    std::vector<bool> result(t.begin(), t.end());
    return IValuePacker<std::vector<bool>>::pack(result);
  }
  static std::array<bool, N> unpack(const at::IValue& t) {
    std::array<bool, N> result;
    auto packed = IValuePacker<std::vector<bool>>::unpack(t);
    for (size_t i = 0; i < packed.size(); i++) {
      result[i] = packed[i];
    }
    return result;
  }
  static at::TypePtr packed_type() {
    return IValuePacker<std::vector<bool>>::packed_type();
  }
};
template <>
struct IValuePacker<at::TensorGeometry> {
  static at::IValue pack(const at::TensorGeometry& t) {
    auto tuple = std::make_tuple(
        t.sym_sizes().vec(), t.sym_strides().vec(), t.sym_storage_offset());
    return tuple;
  }
  static at::TensorGeometry unpack(const at::IValue& t) {
    auto tuple = t.to<std::tuple<
        std::vector<at::SymInt>,
        std::vector<at::SymInt>,
        at::SymInt>>();
    return at::TensorGeometry(
        std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple));
  }
  static at::TypePtr packed_type() {
    return at::TupleType::create(
        {IValuePacker<std::vector<at::SymInt>>::packed_type(),
         IValuePacker<std::vector<at::SymInt>>::packed_type(),
         at::SymIntType::get()});
  }
};

template <>
struct IValuePacker<InputMetadata> {
  static at::IValue pack(const InputMetadata& t) {
    TORCH_INTERNAL_ASSERT(!t.is_nested_tensor());
    auto tuple = std::make_tuple(
        pack_TensorOptions(t.options()),
        t.shape_as_dim_vector().vec(),
        t.is_tensor_subclass(),
        t.grad_dtype());
    return tuple;
  }
  static InputMetadata unpack(const at::IValue& t) {
    auto tuple = t.to<std::tuple<
        packed_tensoroptions_t,
        std::vector<at::SymInt>,
        bool,
        std::optional<at::ScalarType>>>();
    return InputMetadata(
        unpack_TensorOptions(std::get<0>(tuple)),
        SymIntSmallVec(std::get<1>(tuple)),
        std::get<2>(tuple),
        false,
        std::get<3>(tuple));
  }
  static at::TypePtr packed_type() {
    return at::TupleType::create(
        {IValuePacker<at::TensorOptions>::packed_type(),
         IValuePacker<std::vector<at::SymInt>>::packed_type(),
         at::BoolType::get(),
         IValuePacker<std::optional<at::ScalarType>>::packed_type()});
  }
};

template <typename T>
struct IValuePacker<at::OptionalArray<T>> {
  static at::IValue pack(const at::OptionalArray<T>& t) {
    return IValuePacker<std::optional<std::vector<T>>>::pack(t.list);
  }
  static at::OptionalArray<T> unpack(const at::IValue& t) {
    auto result = IValuePacker<std::optional<std::vector<T>>>::unpack(t);
    if (result.has_value()) {
      return {result.value()};
    } else {
      return {};
    }
  }
  static at::TypePtr packed_type() {
    return IValuePacker<std::optional<std::vector<T>>>::packed_type();
  }
};

// This is a helper struct for packing and unpacking multiple arguments into
// an ivalue_list. It leverages IValuePacker.
struct PackedArgs {
  PackedArgs() = default;

  explicit PackedArgs(std::vector<at::IValue> stack_)
      : stack(std::move(stack_)) {}

  const std::vector<at::IValue>& vec() const {
    return stack;
  }

  template <typename T>
  void pack(const T& t) {
    stack.emplace_back(IValuePacker<T>::pack(t));
  }
  template <typename T>
  T unpack() {
    return IValuePacker<T>::unpack(std::move(stack[idx++]));
  }

  void pack_saved_data(const ska::flat_hash_map<std::string, at::IValue>& dct) {
    std::vector<std::string> keys;
    std::vector<at::IValue> values;
    for (const auto& [key, value] : dct) {
      keys.emplace_back(key);
      values.emplace_back(value);
    }
    pack(keys);
    for (const auto& value : values) {
      pack(value);
    }
  }

  ska::flat_hash_map<std::string, at::IValue> unpack_saved_data() {
    ska::flat_hash_map<std::string, at::IValue> dct;
    auto keys = unpack<std::vector<std::string>>();
    for (const auto& key : keys) {
      dct.insert({key, std::move(stack[idx++])});
    }
    return dct;
  }

 private:
  std::vector<at::IValue> stack;
  int64_t idx = 0;
};

} // namespace torch::dynamo::autograd

template <>
struct std::hash<torch::dynamo::autograd::CacheKey> {
  size_t operator()(const torch::dynamo::autograd::CacheKey& k) const {
    return k.hash();
  }
};