#include #include #include #include #include namespace c10 { SymbolicShapeMeta::SymbolicShapeMeta(const SymbolicShapeMeta& other) // Non-mutables can be accessed outside the mutex : sizes_(other.sizes_), strides_(other.strides_), storage_offset_(other.storage_offset_), strides_valid_(other.strides_valid_) { std::scoped_lock lock(other.mutables_); // These must be copied under lock, so ignore clang-tidy here! // NOLINTBEGIN(cppcoreguidelines-prefer-member-initializer) numel_ = other.numel_; is_contiguous_ = other.is_contiguous_; is_channels_last_contiguous_ = other.is_channels_last_contiguous_; is_channels_last_3d_contiguous_ = other.is_channels_last_3d_contiguous_; is_channels_last_ = other.is_channels_last_; is_channels_last_3d_ = other.is_channels_last_3d_; is_non_overlapping_and_dense_ = other.is_non_overlapping_and_dense_; available_.store(other.available_.load()); // NOLINTEND(cppcoreguidelines-prefer-member-initializer) } // base, sizes, strides static std::optional< std::tuple, std::vector>> normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) { // Look for a SymNode to dispatch on SymNode base; bool all_hinted = true; // NB: sizes/strides guaranteed to be positive, so only need // is_heap_allocated for (const auto& s : sizes) { if (all_hinted && !s.has_hint()) { all_hinted = false; } if (!base && s.is_heap_allocated()) { base = s.toSymNode(); } } for (const auto& s : strides) { if (all_hinted && !s.has_hint()) { all_hinted = false; } if (!base && s.is_heap_allocated()) { base = s.toSymNode(); } } if (!base || all_hinted) { // Couldn't find. Tell the caller to do the normal computation // Alternately, if everything is hinted, we want the normal computation // too return std::nullopt; } // Populate the SymNode array std::vector size_nodes; std::vector stride_nodes; size_nodes.reserve(sizes.size()); stride_nodes.reserve(strides.size()); for (const auto& s : sizes) { size_nodes.emplace_back(s.wrap_node(base)); } for (const auto& s : strides) { stride_nodes.emplace_back(s.wrap_node(base)); } return std::tuple, std::vector>( std::move(base), std::move(size_nodes), std::move(stride_nodes)); } namespace { bool all_hinted( const c10::SymIntArrayRef& sizes, const c10::SymIntArrayRef& strides) { auto all_hinted = true; for (const auto& s : sizes) { if (!s.has_hint()) { return false; } } if (all_hinted) { for (const auto& s : strides) { if (!s.has_hint()) { return false; } } } return all_hinted; } } // namespace // Special treatment because of numel SymBool SymbolicShapeMeta::compute_contiguous() const { if (!strides_valid_) { return false; } c10::SymIntArrayRef sizes(sizes_); c10::SymIntArrayRef strides(strides_); auto result = _compute_contiguous_sym(sizes, strides, numel()); // If the result is already determined without guarding, just return it. auto maybe_as_bool = result.maybe_as_bool(); if (maybe_as_bool.has_value()) { return maybe_as_bool.value(); } if (all_hinted(sizes, strides)) { // We avoid going through the slow path if everything is hinted, // because evaluating a large SymPy expression can be expensive. // TODO exclude backed_size_oblivious from this path. return _compute_contiguous(sizes_, strides_, numel()); } return result; } SymBool SymbolicShapeMeta::compute_channels_last_contiguous_2d() const { if (!strides_valid_) { return false; } c10::SymIntArrayRef sizes(sizes_); c10::SymIntArrayRef strides(strides_); auto result = _compute_channels_last_contiguous_2d_sym(sizes, strides); // If the result is already determined without guarding, just return it. auto maybe_as_bool = result.maybe_as_bool(); if (maybe_as_bool.has_value()) { return maybe_as_bool.value(); } if (all_hinted(sizes, strides)) { // We avoid going through the slow path if everything is hinted, // because evaluating a large SymPy expression can be expensive. // TODO exclude backed_size_oblivious from this path. return _compute_channels_last_contiguous_2d(sizes_, strides_); } return result; } SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d() const { if (!strides_valid_) { return false; } c10::SymIntArrayRef sizes(sizes_); c10::SymIntArrayRef strides(strides_); auto result = _compute_channels_last_contiguous_3d_sym(sizes, strides); // If the result is already determined without guarding, just return it. auto maybe_as_bool = result.maybe_as_bool(); if (maybe_as_bool.has_value()) { return maybe_as_bool.value(); } if (all_hinted(sizes, strides)) { // We avoid going through the slow path if everything is hinted, // because evaluating a large SymPy expression can be expensive. // TODO exclude backed_size_oblivious from this path. return _compute_channels_last_contiguous_3d(sizes_, strides_); } return result; } // The rest of them #define DEFINE_EAGER_SYMBOOL_COMPUTE(name, fallback) \ SymBool SymbolicShapeMeta::name() const { \ if (!strides_valid_) { \ return false; \ } \ c10::SymIntArrayRef sizes(sizes_); \ c10::SymIntArrayRef strides(strides_); \ return fallback(sizes, strides); \ } #define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ SymBool SymbolicShapeMeta::name() const { \ if (!strides_valid_) { \ return false; \ } \ auto n = normalize_sym_sizes_strides(sizes_, strides_); \ if (n.has_value()) { \ auto [base, size_nodes, stride_nodes] = *n; \ return SymBool(base->nodeimpl(size_nodes, stride_nodes)); \ } else { \ c10::SymIntArrayRef sizes(sizes_); \ c10::SymIntArrayRef strides(strides_); \ return fallback(sizes, strides); \ } \ } // clang-format off DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d) DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d) DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) // clang-format on #undef DEFINE_SYMBOOL_COMPUTE // Glue compute // NB: this logic very intentionally short circuits if possible. Without // short circuiting, it causes // python test/functorch/test_aotdispatch.py -k // test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 to run // very slowly. SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const { init_is_contiguous(); if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) { return true; } init_is_channels_last_contiguous(); if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) { return true; } return is_contiguous() | is_channels_last_contiguous() | compute_non_overlapping_and_dense(); } SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const { init_is_channels_last_contiguous(); if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) { return false; } return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d(); } SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const { init_is_channels_last_3d_contiguous(); if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) { return false; } return ~is_channels_last_3d_contiguous() & compute_strides_like_channels_last_2d(); } SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const { if (guard_or_false(is_channels_last(), __FILE__, __LINE__)) { return false; } return ~is_channels_last() & compute_strides_like_channels_last_3d(); } SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const { if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) { return true; } if (guard_or_false(is_channels_last_contiguous(), __FILE__, __LINE__)) { return true; } if (guard_or_false(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) { return true; } return is_contiguous() | is_channels_last_contiguous() | is_channels_last_3d_contiguous() | compute_non_overlapping_and_dense(); } SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const { if (guard_or_false(is_contiguous(), __FILE__, __LINE__)) { return true; } return is_contiguous() | compute_non_overlapping_and_dense(); } void SymbolicShapeMeta::set_numel(SymInt val) const { std::scoped_lock lock(mutables_); if (has_numel()) { return; } numel_ = std::move(val); available_.fetch_or(numel_avail); } void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_contiguous()) { return; } is_contiguous_ = std::move(val); available_.fetch_or(is_contiguous_avail); } void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_contiguous()) { return; } is_channels_last_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_contiguous_avail); } void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_3d_contiguous()) { return; } is_channels_last_3d_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_3d_contiguous_avail); } void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last()) { return; } is_channels_last_ = std::move(val); available_.fetch_or(is_channels_last_avail); } void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_3d()) { return; } is_channels_last_3d_ = std::move(val); available_.fetch_or(is_channels_last_3d_avail); } void SymbolicShapeMeta::set_is_non_overlapping_and_dense(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_non_overlapping_and_dense()) { return; } is_non_overlapping_and_dense_ = std::move(val); available_.fetch_or(is_non_overlapping_and_dense_avail); } void SymbolicShapeMeta::init_numel() const { set_numel(multiply_integers(sizes_)); } void SymbolicShapeMeta::init_is_contiguous() const { set_is_contiguous(compute_contiguous()); } void SymbolicShapeMeta::init_is_channels_last_contiguous() const { set_is_channels_last_contiguous([&] { switch (dim()) { case 5: case 4: { return compute_channels_last_contiguous_2d(); } default: return SymBool{false}; } }()); } void SymbolicShapeMeta::init_is_channels_last_3d_contiguous() const { set_is_channels_last_3d_contiguous([&] { switch (dim()) { case 5: return compute_channels_last_contiguous_3d_dim5(); default: return SymBool{false}; } }()); } void SymbolicShapeMeta::init_is_channels_last() const { set_is_channels_last([&] { switch (dim()) { case 5: return compute_channels_last_2d_dim5(); case 4: return compute_strides_like_channels_last_2d(); default: return SymBool{false}; } }()); } void SymbolicShapeMeta::init_is_channels_last_3d() const { set_is_channels_last_3d([&] { switch (dim()) { case 5: return compute_channels_last_3d_dim5(); default: return SymBool{false}; } }()); } void SymbolicShapeMeta::init_is_non_overlapping_and_dense() const { set_is_non_overlapping_and_dense([&] { switch (dim()) { case 5: return compute_is_non_overlapping_and_dense_dim5(); case 4: return compute_is_non_overlapping_and_dense_dim4(); default: return compute_is_non_overlapping_and_dense_anydim(); } }()); } } // namespace c10