Wavelet Tree
Key Points
- A wavelet tree is a recursive data structure built over a sequence's alphabet that answers rank, select, and quantile (k-th smallest) queries in O(log σ) time, where σ is the number of distinct values.
- It partitions the array level-by-level using value ranges (or bits), storing a prefix-count bitvector per level to route any subarray to the correct child.
- Coordinate compression reduces the alphabet to 0..σ-1 so height becomes O(log σ) instead of O(log U), where U is the raw integer universe.
- Classic queries supported include: count of values ≤ x in [l,r], count of value x in [l,r], k-th smallest in [l,r], predecessor/successor in [l,r], and sometimes select of the k-th occurrence.
- Space usage is O(n log σ) bits (or O(n log σ) integers if implemented plainly), and build time is O(n log σ).
- Wavelet trees transform 1D range statistics into 2D dominance queries on points (i, A[i]), enabling rectangle counting and quantiles.
- They are static by default; dynamic updates are nontrivial and require advanced bitvectors or different structures (e.g., Fenwick/segment trees or dynamic succinct dictionaries).
- Augmentations (like prefix sums per node) let you compute the sum of the k smallest elements or the sum of values ≤ x in a range, still in O(log σ).
Prerequisites
- Prefix sums and rank/select basics: Wavelet trees rely on counting elements in prefixes and mapping between levels using prefix counts (rank).
- Coordinate compression: Reduces the alphabet to 0..σ-1 so height becomes O(log σ) and comparisons use ranks.
- Binary search and recursion: Queries descend logarithmically by making binary decisions on value ranges with proper index mapping.
- Asymptotic analysis (Big-O): To reason about O(log σ) query time and O(n log σ) build/space costs.
- Order statistics: Understanding k-th smallest and quantiles clarifies what the wavelet tree returns.
- 2D range counting perspective: Interpreting (i, A[i]) as points makes many queries intuitively clear and suggests extensions.
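Coordinate compression, listed above, can be sketched in a few lines. This is a minimal illustration rather than the page's own implementation; the helper names `distinctSorted` and `compress` are assumptions introduced here, and the input is a plain 0-based vector.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: sorted distinct values of `a`.
// The rank of a value is its index in this vector.
std::vector<int> distinctSorted(std::vector<int> a) {
    std::sort(a.begin(), a.end());
    a.erase(std::unique(a.begin(), a.end()), a.end());
    return a;
}

// Hypothetical helper: replace each value by its rank in [0, sigma - 1].
std::vector<int> compress(const std::vector<int>& a, const std::vector<int>& uniq) {
    std::vector<int> ranks(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) {
        auto it = std::lower_bound(uniq.begin(), uniq.end(), a[i]);
        assert(it != uniq.end() && *it == a[i]);  // every value must appear in uniq
        ranks[i] = int(it - uniq.begin());
    }
    return ranks;
}
```

After this step the tree only ever sees ranks, so negative or very large raw values do not affect its height.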
Detailed Explanation
01 Overview
A wavelet tree is a compact, recursive data structure for sequences that supports fast order-statistics and frequency queries. Think of an array A of length n whose values come from a set (alphabet) of size σ. The wavelet tree hierarchically partitions A by value ranges: the root covers the whole value range, the left child covers the lower half, and the right child the upper half, and so on. At each node, we keep a prefix-count array (a bitvector with ranks) that records, for every prefix of the subsequence at that node, how many elements go to the left child. Using this information, any subarray [l,r] can be translated to the corresponding subarray in either child in O(1). Repeating this for O(log σ) levels answers queries like “how many numbers ≤ x in [l,r]?” or “what is the k-th smallest number in [l,r]?”
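The O(1) translation of a subarray to a child can be sketched as follows. It is an illustration under assumptions: the function names `buildPrefLeft`, `toLeft`, and `toRight` are made up for this example, positions inside a node are 1-based, and an empty range is represented by `l = r + 1`.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// pref[i] = how many of the first i elements of the node's sequence go left
// (an element goes left when its rank is <= mid).
std::vector<int> buildPrefLeft(const std::vector<int>& seq, int mid) {
    std::vector<int> pref(seq.size() + 1, 0);
    for (std::size_t i = 0; i < seq.size(); ++i)
        pref[i + 1] = pref[i] + (seq[i] <= mid ? 1 : 0);
    return pref;
}

// Map the 1-based range [l, r] of this node to the matching range in the left child.
std::pair<int, int> toLeft(const std::vector<int>& pref, int l, int r) {
    assert(1 <= l && l <= r + 1 && r < (int)pref.size());
    return {pref[l - 1] + 1, pref[r]};
}

// Same mapping for the right child: positions not counted in pref go right.
std::pair<int, int> toRight(const std::vector<int>& pref, int l, int r) {
    assert(1 <= l && l <= r + 1 && r < (int)pref.size());
    return {l - pref[l - 1], r - pref[r]};
}
```

If none of the elements in [l, r] go to a child, the mapped range comes out with its left end one past its right end, which callers can treat as empty.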
The power of wavelet trees is that they combine the speed of binary search on the value domain with the stability of maintaining original order information via the per-level bitvectors. After coordinate compression (mapping values to ranks 0..σ-1), the tree height is O(log σ), ensuring operations are fast even when raw values are large or negative. While the classic implementation is static (no updates), the structure is extremely useful in competitive programming and information retrieval, and can be further augmented (e.g., with prefix sums) to support sum queries of the smallest k elements in a range.
02 Intuition & Analogies
Imagine sorting socks by color using a sequence of yes/no questions. First, you ask, “Is the color index ≤ mid?” All socks answering “yes” go left, others go right. You record, for each sock in the original lineup, whether it went left. Now, if a friend asks, “Among socks 10 to 30, how many are blue or earlier colors?” you can simulate the same sequence of questions but only for that subline. At each step, the recorded left-counts let you jump to the appropriate subline on the next level without actually moving socks again.
Another analogy: Think of a decision tree that splits movie ratings into low vs. high at each level. You also keep a running counter along the original timeline of ratings saying “how many went left so far?” If someone asks, “Among days 100..200, what’s the 10th lowest rating?” you follow the tree from the root down: at each split you check how many of those days went left; if at least 10, you go left; otherwise you subtract and go right. In just a handful of steps (one per level), you find the exact value.
Finally, picture the array as 2D points (i, A[i]). Counting how many numbers in a range [l,r] are ≤ x becomes counting the points whose index lies in [l, r] and whose value lies in (-∞, x]. The wavelet tree is essentially a compact index that lets you navigate these rectangles quickly by narrowing the value range at each level, while preserving the original indices via the bitvector prefixes.
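The rectangle view also suggests a brute-force reference that is handy for checking a wavelet tree on small inputs. The sketch below is only such an oracle, not a wavelet tree; the names `countLteNaive` and `kthNaive` are assumptions, and the array is taken to be 1-based with an unused slot at index 0.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Count points (i, a[i]) with l <= i <= r and a[i] <= x by direct scan: O(r - l + 1).
int countLteNaive(const std::vector<int>& a, int l, int r, int x) {
    assert(1 <= l && r < (int)a.size());
    int cnt = 0;
    for (int i = l; i <= r; ++i)
        if (a[i] <= x) ++cnt;
    return cnt;
}

// k-th smallest (1-based k) of a[l..r] via selection on a copy of the subarray.
int kthNaive(const std::vector<int>& a, int l, int r, int k) {
    assert(1 <= l && l <= r && r < (int)a.size() && 1 <= k && k <= r - l + 1);
    std::vector<int> sub(a.begin() + l, a.begin() + r + 1);
    std::nth_element(sub.begin(), sub.begin() + (k - 1), sub.end());
    return sub[k - 1];
}
```

Comparing a tree's answers with these two functions on random small arrays is a cheap way to catch off-by-one errors in the range mapping.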
03 Formal Definition
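One common way to state the definition, written to match the prefix-count representation used in the code on this page, is sketched below. The symbols A_v, P_v, lo_v, hi_v, and mid_v are notation assumed for this sketch.

```latex
% Assumed notation: A[1..n] holds ranks in {0, ..., \sigma - 1}; every node v covers a rank
% interval [lo_v, hi_v] and represents A_v, the subsequence of A with values in that
% interval, kept in original order. A node with lo_v = hi_v is a leaf.
\[
\begin{aligned}
A_{\mathrm{root}} &= A, \qquad [lo_{\mathrm{root}}, hi_{\mathrm{root}}] = [0, \sigma - 1] \\
mid_v &= \lfloor (lo_v + hi_v) / 2 \rfloor \\
\mathrm{left}(v) &\text{ covers } [lo_v, mid_v], \qquad \mathrm{right}(v) \text{ covers } [mid_v + 1, hi_v] \\
P_v[i] &= \bigl|\{\, j \le i : A_v[j] \le mid_v \,\}\bigr|, \qquad 0 \le i \le |A_v| \\
[l, r] &\mapsto [\,P_v[l-1] + 1,\; P_v[r]\,] \quad \text{in } \mathrm{left}(v) \\
[l, r] &\mapsto [\,l - P_v[l-1],\; r - P_v[r]\,] \quad \text{in } \mathrm{right}(v)
\end{aligned}
\]
```

Only the arrays P_v need to be stored; the sequences A_v are used during construction and can be discarded afterwards.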
04 When to Use
Use a wavelet tree when you need fast order-statistic and frequency queries on static arrays:
- k-th smallest in a subarray [l, r] (quantiles) for median, percentiles, and selection.
- Count of values ≤ x or in a value range [x1, x2] within [l, r] (range counting and dominance queries).
- Count of occurrences of a value x in [l, r] (rank on a value).
- Predecessor/successor in [l, r] (largest ≤ x, smallest ≥ x) via counts and quantiles.
- Sum of the k smallest (or sum of values ≤ x) in [l, r] when augmented with per-node prefix sums.
They are particularly strong for competitive programming tasks that require many queries with no updates, or offline updates that can be transformed. If your data is dynamic with frequent point updates, consider alternatives like a Fenwick/segment tree of ordered containers, a persistent segment tree, or a dynamic wavelet tree with sophisticated bitvectors (advanced). When σ is small relative to n, wavelet trees are especially memory- and time-efficient; for very large universes but few distinct values, coordinate compression keeps height small.
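The predecessor/successor item in the list above combines one count with one quantile. The sketch below shows that combination with brute-force stand-ins for the two wavelet tree queries; `lteScan`, `ltScan`, `kthSorted`, `predecessor`, and `successor` are hypothetical names, and the array is 1-based with an unused slot at index 0.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Stand-ins for wavelet tree queries (linear scans, for illustration only).
int lteScan(const std::vector<int>& a, int l, int r, int x) {
    int c = 0;
    for (int i = l; i <= r; ++i) c += (a[i] <= x);
    return c;
}
int ltScan(const std::vector<int>& a, int l, int r, int x) {
    int c = 0;
    for (int i = l; i <= r; ++i) c += (a[i] < x);
    return c;
}
int kthSorted(const std::vector<int>& a, int l, int r, int k) {
    std::vector<int> sub(a.begin() + l, a.begin() + r + 1);
    std::sort(sub.begin(), sub.end());
    return sub[k - 1];
}

// Largest value <= x in a[l..r]; returns false when no such value exists.
bool predecessor(const std::vector<int>& a, int l, int r, int x, int& out) {
    assert(1 <= l && l <= r && r < (int)a.size());
    int c = lteScan(a, l, r, x);      // c values are <= x, so the c-th smallest is the answer
    if (c == 0) return false;
    out = kthSorted(a, l, r, c);
    return true;
}

// Smallest value >= x in a[l..r]; returns false when no such value exists.
bool successor(const std::vector<int>& a, int l, int r, int x, int& out) {
    assert(1 <= l && l <= r && r < (int)a.size());
    int c = ltScan(a, l, r, x);       // c values are < x, so the (c+1)-th smallest is the answer
    if (c == r - l + 1) return false;
    out = kthSorted(a, l, r, c + 1);
    return true;
}
```

With a real wavelet tree, each stand-in would be replaced by the corresponding O(log σ) query, giving predecessor and successor in O(log σ) as well.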
⚠️ Common Mistakes
- Forgetting coordinate compression: building by integer mid on a huge universe U yields O(log U) height even if only σ values appear. Always compress to ranks 0..σ-1 to achieve O(log σ).
- Off-by-one errors in mapping ranges across levels: remember that the mapping to children uses the prefix counts B[l-1] and B[r], and that node sequences are 1-indexed in many references. Keep indices consistent throughout.
- Misinterpreting rank and select: rank_x(i) counts occurrences of x in A[1..i]; select_x(k) returns the position of the k-th x. Implementations that only store rank usually need extra logic or binary search to simulate select.
- Assuming dynamic updates are easy: the classic wavelet tree is static. Supporting updates or range swaps requires advanced dynamic bitvectors or different data structures.
- Memory blow-up: storing the full subsequence at every level as 32-bit ints is heavy. Prefer bitvectors with rank support (or, at most, one 32-bit prefix-count array per node), compress values, and discard the per-level sequences once the tree is built.
- Confusing ≤ vs. <: when implementing count ≤ x or < x, be explicit about whether you include equals. Ensure the pivot comparison matches your split (e.g., goLeft if rank ≤ mid).
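The ≤ versus < distinction in the last item shows up first when a query value is translated to a rank. A minimal sketch, assuming `uniq` is the sorted vector of distinct values and using the hypothetical names `rankLte` and `rankLt`:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Largest rank whose value is <= x, or -1 if every value is greater than x.
int rankLte(const std::vector<int>& uniq, int x) {
    return int(std::upper_bound(uniq.begin(), uniq.end(), x) - uniq.begin()) - 1;
}

// Largest rank whose value is < x, or -1 if no value is smaller than x.
int rankLt(const std::vector<int>& uniq, int x) {
    return int(std::lower_bound(uniq.begin(), uniq.end(), x) - uniq.begin()) - 1;
}
```

Counting values in [x1, x2] can then use the count of ranks ≤ `rankLte(uniq, x2)` minus the count of ranks ≤ `rankLt(uniq, x1)`, which avoids computing x1 - 1 and therefore cannot overflow.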
Key Formulas
Tree Height
$$h = \lceil \log_2 \sigma \rceil$$
Explanation: The wavelet tree has height equal to the number of times you can split the alphabet in half until reaching singletons. Fewer distinct values mean a shorter tree.
Build Complexity
$$T_{\text{build}}(n, \sigma) = O(n \log \sigma)$$
Explanation: At each of the O(log σ) levels we partition the current sequence once, resulting in linear work per level.
Space Complexity
$$S(n, \sigma) = O(n \log \sigma) \text{ bits}$$
Explanation: Each level stores a bitvector (or prefix-count array) for the n elements; across O(log σ) levels this yields O(n log σ) bits (plus overhead).
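To make the "bits plus overhead" remark concrete, here is a sketch of a bitvector with rank support that could stand in for a per-node prefix-count array. It is an assumption-laden illustration: `RankBitvector` and `popcount64` are names invented here, positions are 0-based, and a bit set to 1 is meant to mark "goes left".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Portable population count (Kernighan's method), used instead of compiler builtins.
inline int popcount64(std::uint64_t x) {
    int c = 0;
    while (x) { x &= x - 1; ++c; }
    return c;
}

struct RankBitvector {
    std::size_t n = 0;                 // number of bits
    std::vector<std::uint64_t> words;  // packed bits, 64 per word
    std::vector<int> blockRank;        // blockRank[w] = number of 1 bits in words[0..w-1]

    explicit RankBitvector(const std::vector<bool>& bits) : n(bits.size()) {
        words.assign((n + 63) / 64, 0);
        for (std::size_t i = 0; i < n; ++i)
            if (bits[i]) words[i >> 6] |= (std::uint64_t(1) << (i & 63));
        blockRank.assign(words.size() + 1, 0);
        for (std::size_t w = 0; w < words.size(); ++w)
            blockRank[w + 1] = blockRank[w] + popcount64(words[w]);
    }

    // Number of 1 bits among the first i bits (positions 0..i-1), for 0 <= i <= n.
    int rank1(std::size_t i) const {
        assert(i <= n);
        std::size_t w = i >> 6, b = i & 63;
        int res = blockRank[w];
        if (b) res += popcount64(words[w] & ((std::uint64_t(1) << b) - 1));
        return res;
    }
};
```

Here `rank1(i)` plays the role of `prefLeft[i]`, at a cost of one bit per element plus one counter per 64 elements instead of one 32-bit integer per element.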
Query Time
$$T_{\text{query}}(\sigma) = O(\log \sigma)$$
Explanation: Each query descends one level of the tree at a time, performing O(1) work per level to remap [l,r].
Range Count by Two LTEs
$$\mathrm{count}(l, r, [x_1, x_2]) = \mathrm{lte}(l, r, x_2) - \mathrm{lte}(l, r, x_1 - 1)$$
Explanation: Counting values in a closed interval reduces to two counts of values not exceeding a bound.
Dominance Interpretation
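$$\mathrm{lte}(l, r, x) = \bigl|\{\, i : l \le i \le r,\ A[i] \le x \,\}\bigr| = \bigl|\{\, (i, A[i]) \in [l, r] \times (-\infty, x] \,\}\bigr|$$
Explanation: Counting values ≤ x in the index range [l,r] equals counting points in a 2D rectangle, which the wavelet tree supports efficiently.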
Build Recurrence
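$$T(n, \sigma) = T(n_L, \lceil \sigma / 2 \rceil) + T(n_R, \lfloor \sigma / 2 \rfloor) + O(n), \quad n_L + n_R = n$$
Explanation: Each node partitions its sequence into left and right subsequences and recurses; the sum across a level is linear, over O(log σ) levels.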
Sum-k Smallest Recurrence
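$$S(v, l, r, k) = \begin{cases} S(\mathrm{left}(v), l_L, r_L, k) & \text{if } k \le c \\ s + S(\mathrm{right}(v), l_R, r_R, k - c) & \text{if } k > c \end{cases}$$
Explanation: Here c is the number of elements of [l,r] that go left at node v, s is the sum of those elements, and [l_L, r_L], [l_R, r_R] are the ranges mapped into the children. At each node, if k exceeds the number going left (c), take all left contributions (s) and continue with the right child for the remaining k - c; otherwise continue in the left child with k unchanged.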
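The recurrence can be exercised without any tree by partitioning an explicit list of values at each step. The sketch below does exactly that; `sumKSmallestByHalving` is a hypothetical name, it works on raw values in a known range [lo, hi] rather than on compressed ranks, and it spends linear time per level because it re-partitions instead of using stored prefix counts and sums.

```cpp
#include <cassert>
#include <vector>

// Sum of the k smallest entries of vals; every entry must lie in [lo, hi]
// and k must satisfy 0 <= k <= vals.size().
long long sumKSmallestByHalving(const std::vector<int>& vals, int lo, int hi, int k) {
    assert(0 <= k && k <= (int)vals.size());
    if (k == 0) return 0;
    if (lo == hi) return 1LL * lo * k;   // all remaining values are equal to lo
    int mid = lo + (hi - lo) / 2;
    std::vector<int> left, right;
    long long leftSum = 0;               // plays the role of s in the recurrence
    for (int v : vals) {
        if (v <= mid) { left.push_back(v); leftSum += v; }
        else right.push_back(v);
    }
    int c = (int)left.size();            // number of elements going left
    if (k <= c) return sumKSmallestByHalving(left, lo, mid, k);
    return leftSum + sumKSmallestByHalving(right, mid + 1, hi, k - c);
}
```

The augmented wavelet tree follows the same branches but reads c and the left sum from per-node prefix arrays, which is what brings each level down to O(1).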
Complexity Analysis
Code Examples
    #include <bits/stdc++.h>
    using namespace std;

    // Wavelet tree over compressed ranks [lo, hi] with per-node prefix counts.
    // Supports:
    //   - kth(l, r, k): k-th smallest in A[l..r]
    //   - lte(l, r, x): count of values <= x in A[l..r]
    //   - countEqual(l, r, x): count of value x in A[l..r]
    //   - selectKthOccurrence(x, k): position of the k-th occurrence of x in the
    //     whole array (O(log n * log sigma))
    // Indices l, r are 1-based inside the structure.
    // Nodes are never freed, which is acceptable for a short-lived program.

    struct WTNode {
        int lo, hi;            // value ranks covered by this node (inclusive)
        vector<int> prefLeft;  // prefLeft[i] = #elements routed left among the first i (i in 0..m)
        WTNode *L = nullptr, *R = nullptr;
    };

    struct WaveletTree {
        WTNode* root = nullptr;
        vector<int> uniq;  // sorted unique values for rank <-> value mapping
        int n = 0;         // length of the indexed array

        WaveletTree() {}
        // Build from a 1-based array (a[0] is ignored); values are compressed internally.
        WaveletTree(const vector<int>& a) { buildFromValues(a); }

        // Map value -> rank in [0..sigma-1] (meaningful only if x occurs in the array)
        int rankOf(int x) const {
            return int(lower_bound(uniq.begin(), uniq.end(), x) - uniq.begin());
        }

        // Map rank -> original value
        int valueOf(int r) const { return uniq[r]; }

        // k-th smallest in A[l..r] (1-based); returns INT_MIN on invalid input
        int kth(int l, int r, int k) const {
            if (!validRange(l, r) || k < 1 || k > r - l + 1) return INT_MIN;
            return valueOf(kthRank(root, l, r, k));
        }

        // Count of values <= x in A[l..r]
        int lte(int l, int r, int x) const {
            if (!validRange(l, r)) return 0;
            // largest rank whose value is <= x
            int rx = int(upper_bound(uniq.begin(), uniq.end(), x) - uniq.begin()) - 1;
            if (rx < 0) return 0;  // every value is greater than x
            return lteRank(root, l, r, rx);
        }

        // Count of occurrences of value x in A[l..r]
        int countEqual(int l, int r, int x) const {
            if (!validRange(l, r)) return 0;
            int rx = rankOf(x);
            if (rx >= (int)uniq.size() || uniq[rx] != x) return 0;
            return countRank(root, l, r, rx);
        }

        // Position (1-based) of the k-th occurrence of x in the WHOLE array, or -1 if none.
        // Uses a binary search over prefLeft at each level: O(log n * log sigma).
        int selectKthOccurrence(int x, int k) const {
            if (!root || k <= 0) return -1;
            int rx = rankOf(x);
            if (rx >= (int)uniq.size() || uniq[rx] != x) return -1;
            return selectInNode(root, rx, k);  // position in the root sequence = index in A
        }

    private:
        bool validRange(int l, int r) const { return root && 1 <= l && l <= r && r <= n; }

        void buildFromValues(const vector<int>& a) {
            if (a.size() <= 1) return;  // empty input: root stays null, queries report "nothing"
            n = (int)a.size() - 1;      // expect 1-based input
            uniq.assign(a.begin() + 1, a.end());
            sort(uniq.begin(), uniq.end());
            uniq.erase(unique(uniq.begin(), uniq.end()), uniq.end());
            vector<int> ranks(n);       // compressed, 0-based sequence for the root
            for (int i = 1; i <= n; ++i) ranks[i - 1] = rankOf(a[i]);
            root = buildNode(ranks, 0, (int)uniq.size() - 1);
        }

        // Build the node for rank interval [lo, hi] from its (0-based) sequence.
        WTNode* buildNode(const vector<int>& seq, int lo, int hi) {
            WTNode* node = new WTNode();
            node->lo = lo; node->hi = hi;
            int m = (int)seq.size();
            node->prefLeft.assign(m + 1, 0);
            if (lo == hi || m == 0) return node;
            int mid = (lo + hi) >> 1;
            // Stable partition into left (<= mid) and right (> mid)
            vector<int> leftPart, rightPart;
            leftPart.reserve(m); rightPart.reserve(m);
            for (int i = 0; i < m; ++i) {
                bool goLeft = (seq[i] <= mid);
                node->prefLeft[i + 1] = node->prefLeft[i] + (goLeft ? 1 : 0);
                (goLeft ? leftPart : rightPart).push_back(seq[i]);
            }
            node->L = buildNode(leftPart, lo, mid);
            node->R = buildNode(rightPart, mid + 1, hi);
            return node;
        }

        // k-th smallest rank in [l, r] of this node's sequence
        int kthRank(WTNode* node, int l, int r, int k) const {
            if (node->lo == node->hi) return node->lo;
            int lb = node->prefLeft[l - 1], rb = node->prefLeft[r];
            int leftInRange = rb - lb;
            if (k <= leftInRange) return kthRank(node->L, lb + 1, rb, k);
            return kthRank(node->R, l - lb, r - rb, k - leftInRange);
        }

        int lteRank(WTNode* node, int l, int r, int rx) const {
            if (l > r || rx < node->lo) return 0;
            if (node->hi <= rx) return r - l + 1;
            int mid = (node->lo + node->hi) >> 1;
            int lb = node->prefLeft[l - 1], rb = node->prefLeft[r];
            int leftCnt = lteRank(node->L, lb + 1, rb, rx);
            int rightCnt = (rx > mid) ? lteRank(node->R, l - lb, r - rb, rx) : 0;
            return leftCnt + rightCnt;
        }

        int countRank(WTNode* node, int l, int r, int rx) const {
            if (l > r || rx < node->lo || rx > node->hi) return 0;
            if (node->lo == node->hi) return r - l + 1;
            int mid = (node->lo + node->hi) >> 1;
            int lb = node->prefLeft[l - 1], rb = node->prefLeft[r];
            if (rx <= mid) return countRank(node->L, lb + 1, rb, rx);
            return countRank(node->R, l - lb, r - rb, rx);
        }

        // Select: position in this node's sequence of the k-th occurrence of rank rx
        int selectInNode(WTNode* node, int rx, int k) const {
            int m = (int)node->prefLeft.size() - 1;  // length of this node's sequence
            if (node->lo == node->hi) return (k <= m) ? k : -1;  // a leaf holds only copies of rx
            int mid = (node->lo + node->hi) >> 1;
            bool goLeft = (rx <= mid);
            int posChild = selectInNode(goLeft ? node->L : node->R, rx, k);
            if (posChild == -1) return -1;
            // routed(i) = how many of the first i elements went to the chosen child
            auto routed = [&](int i) { return goLeft ? node->prefLeft[i] : i - node->prefLeft[i]; };
            // The element sits at the smallest i with routed(i) >= posChild.
            int lo = 1, hi = m, ans = -1;
            while (lo <= hi) {
                int md = (lo + hi) >> 1;
                if (routed(md) >= posChild) { ans = md; hi = md - 1; }
                else lo = md + 1;
            }
            return ans;
        }
    };

    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        // Example usage. Input array is 1-based (index 0 is a placeholder); negatives allowed.
        vector<int> A = {0, 5, -2, 7, 5, 9, -2, 10};  // A[1..7]

        WaveletTree wt(A);

        cout << "kth(2,7,3) = " << wt.kth(2, 7, 3) << "\n";                // 3rd smallest in A[2..7]: 5
        cout << "lte(1,7,5) = " << wt.lte(1, 7, 5) << "\n";                // count of <= 5 in A[1..7]: 4
        cout << "countEqual(1,7,5) = " << wt.countEqual(1, 7, 5) << "\n";  // number of 5s: 2

        // Select: position of the 2nd occurrence of 5 in the whole array: 4
        cout << "select 5, k=2 -> position = " << wt.selectKthOccurrence(5, 2) << "\n";

        return 0;
    }
This program builds a compressed wavelet tree from a 1-based array and supports k-th smallest, count of values ≤ x, count of a specific x, and select (k-th occurrence) over the whole array. Coordinate compression keeps the height small. The mapping of ranges across levels uses the prefix counts. Select is implemented via binary search on prefix counts per level, giving O(log n · log σ) time.
    #include <bits/stdc++.h>
    using namespace std;

    struct WTNode {
        int lo, hi;            // rank interval covered by this node
        vector<int> prefLeft;  // prefLeft[i] = #elements routed left among the first i
        WTNode *L = nullptr, *R = nullptr;
    };

    struct WaveletTree {
        WTNode* root = nullptr;
        vector<int> uniq;  // sorted distinct values

        WaveletTree() {}
        WaveletTree(const vector<int>& a) { build(a); }

        // a is 1-based: a[0] is ignored.
        void build(const vector<int>& a) {
            int n = (int)a.size() - 1;
            if (n <= 0) return;  // nothing to index; root stays null
            uniq.assign(a.begin() + 1, a.end());
            sort(uniq.begin(), uniq.end());
            uniq.erase(unique(uniq.begin(), uniq.end()), uniq.end());
            vector<int> ranks(n + 1);
            for (int i = 1; i <= n; ++i) ranks[i] = rankOf(a[i]);
            root = buildNode(ranks, 1, n, 0, (int)uniq.size() - 1);
        }

        // All queries expect 1 <= l <= r <= n.
        int kth(int l, int r, int k) const {
            if (!root || l > r || k < 1 || k > r - l + 1) return INT_MIN;  // invalid query
            return uniq[kthRank(root, l, r, k)];
        }
        int lte(int l, int r, int x) const {
            if (!root || l > r) return 0;
            int rx = int(upper_bound(uniq.begin(), uniq.end(), x) - uniq.begin()) - 1;
            if (rx < 0) return 0;
            return lteRank(root, l, r, rx);
        }
        int countEqual(int l, int r, int x) const {
            if (!root || l > r) return 0;
            int rx = rankOf(x);
            if (rx >= (int)uniq.size() || uniq[rx] != x) return 0;
            return countRank(root, l, r, rx);
        }
        int countInRange(int l, int r, int x1, int x2) const {
            if (x1 > x2) return 0;
            if (x1 == INT_MIN) return lte(l, r, x2);  // avoids overflow in x1 - 1
            return lte(l, r, x2) - lte(l, r, x1 - 1);
        }

    private:
        int rankOf(int x) const {
            return int(lower_bound(uniq.begin(), uniq.end(), x) - uniq.begin());
        }

        // Build the node for ranks [lo, hi] from the segment arr[L..R].
        // The root passes the 1-based segment [1..n]; children pass 0-based [0..size-1].
        // (The original listing declared this function twice with the same signature,
        // which does not compile; a single definition covers both uses.)
        WTNode* buildNode(const vector<int>& arr, int L, int R, int lo, int hi) const {
            WTNode* node = new WTNode();
            node->lo = lo; node->hi = hi;
            int m = (R >= L) ? (R - L + 1) : 0;
            node->prefLeft.assign(m + 1, 0);
            if (lo == hi || m <= 0) return node;
            int mid = (lo + hi) >> 1;
            vector<int> leftPart, rightPart;
            leftPart.reserve(m); rightPart.reserve(m);
            for (int i = 0; i < m; ++i) {
                bool goLeft = (arr[L + i] <= mid);
                node->prefLeft[i + 1] = node->prefLeft[i] + (goLeft ? 1 : 0);
                (goLeft ? leftPart : rightPart).push_back(arr[L + i]);
            }
            node->L = buildNode(leftPart, 0, (int)leftPart.size() - 1, lo, mid);
            node->R = buildNode(rightPart, 0, (int)rightPart.size() - 1, mid + 1, hi);
            return node;
        }

        int kthRank(WTNode* node, int l, int r, int k) const {
            if (node->lo == node->hi) return node->lo;
            int lb = node->prefLeft[l - 1], rb = node->prefLeft[r];
            int leftIn = rb - lb;
            if (k <= leftIn) return kthRank(node->L, lb + 1, rb, k);
            return kthRank(node->R, l - lb, r - rb, k - leftIn);
        }

        int lteRank(WTNode* node, int l, int r, int rx) const {
            if (l > r || rx < node->lo) return 0;
            if (node->hi <= rx) return r - l + 1;
            int lb = node->prefLeft[l - 1], rb = node->prefLeft[r];
            // One of the two calls returns immediately, so the descent stays O(log sigma).
            return lteRank(node->L, lb + 1, rb, rx) + lteRank(node->R, l - lb, r - rb, rx);
        }

        int countRank(WTNode* node, int l, int r, int rx) const {
            if (l > r || rx < node->lo || rx > node->hi) return 0;
            if (node->lo == node->hi) return r - l + 1;
            int mid = (node->lo + node->hi) >> 1;
            int lb = node->prefLeft[l - 1], rb = node->prefLeft[r];
            if (rx <= mid) return countRank(node->L, lb + 1, rb, rx);
            return countRank(node->R, l - lb, r - rb, rx);
        }
    };

    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        int n, q; cin >> n >> q;
        vector<int> A(n + 1);
        for (int i = 1; i <= n; ++i) cin >> A[i];
        WaveletTree wt(A);

        // Supported queries:
        // K l r k     -> k-th smallest in [l,r]
        // C l r x     -> count of x in [l,r]
        // L l r x     -> count of values <= x in [l,r]
        // R l r x1 x2 -> count of values in [x1,x2] in [l,r]
        while (q--) {
            char type; cin >> type;
            if (type == 'K') {
                int l, r, k; cin >> l >> r >> k;
                cout << wt.kth(l, r, k) << "\n";
            } else if (type == 'C') {
                int l, r, x; cin >> l >> r >> x;
                cout << wt.countEqual(l, r, x) << "\n";
            } else if (type == 'L') {
                int l, r, x; cin >> l >> r >> x;
                cout << wt.lte(l, r, x) << "\n";
            } else if (type == 'R') {
                int l, r, x1, x2; cin >> l >> r >> x1 >> x2;
                cout << wt.countInRange(l, r, x1, x2) << "\n";
            }
        }
        return 0;
    }
This program reads an array and answers multiple query types using a compressed wavelet tree. It demonstrates how to expose a small, practical API: k-th smallest, count of a single value, count of ≤ x, and count in [x1, x2] by subtracting two LTE queries. This is a common competitive programming pattern.
    #include <bits/stdc++.h>
    using namespace std;

    // Augmented wavelet tree: in addition to prefLeft (counts), store prefLeftSum
    // (sum of the original values routed left). This allows computing the sum of the
    // k smallest in [l,r] and the sum of values <= x in [l,r], both in O(log sigma).

    struct Node {
        int lo, hi;
        vector<int> prefLeft;           // counts
        vector<long long> prefLeftSum;  // sums of ORIGINAL values that go left
        Node *L = nullptr, *R = nullptr;
    };

    struct WaveletTreeSum {
        Node* root = nullptr;
        vector<int> uniq;

        WaveletTreeSum() {}
        WaveletTreeSum(const vector<int>& a) { build(a); }

        // a is 1-based: a[0] is ignored.
        void build(const vector<int>& a) {
            int n = (int)a.size() - 1;
            if (n <= 0) return;  // nothing to index; root stays null
            uniq.assign(a.begin() + 1, a.end());
            sort(uniq.begin(), uniq.end());
            uniq.erase(unique(uniq.begin(), uniq.end()), uniq.end());
            vector<int> ranks(n), orig(a.begin() + 1, a.end());
            for (int i = 0; i < n; ++i)
                ranks[i] = int(lower_bound(uniq.begin(), uniq.end(), orig[i]) - uniq.begin());
            root = buildNode(ranks, orig, 0, (int)uniq.size() - 1);
        }

        // Sum of the k smallest values in A[l..r] (k is clamped to the range length)
        long long sumKSmallest(int l, int r, int k) const {
            if (!root || l > r) return 0;
            k = min(k, r - l + 1);
            if (k <= 0) return 0;
            return sumKSmallest(root, l, r, k);
        }

        // Sum of values <= x in A[l..r]
        long long sumLTE(int l, int r, int x) const {
            if (!root || l > r) return 0;
            int rx = int(upper_bound(uniq.begin(), uniq.end(), x) - uniq.begin()) - 1;
            if (rx < 0) return 0;
            return sumLTE(root, l, r, rx);
        }

    private:
        // ranks and orig are parallel 0-based sequences for this node.
        Node* buildNode(const vector<int>& ranks, const vector<int>& orig, int lo, int hi) {
            Node* node = new Node();
            node->lo = lo; node->hi = hi;
            int m = (int)ranks.size();
            node->prefLeft.assign(m + 1, 0);
            node->prefLeftSum.assign(m + 1, 0);
            if (lo == hi || m == 0) return node;
            int mid = (lo + hi) >> 1;
            vector<int> lR, rR, lO, rO;
            lR.reserve(m); rR.reserve(m); lO.reserve(m); rO.reserve(m);
            for (int i = 0; i < m; ++i) {
                bool goLeft = (ranks[i] <= mid);
                node->prefLeft[i + 1] = node->prefLeft[i] + (goLeft ? 1 : 0);
                node->prefLeftSum[i + 1] = node->prefLeftSum[i] + (goLeft ? (long long)orig[i] : 0LL);
                if (goLeft) { lR.push_back(ranks[i]); lO.push_back(orig[i]); }
                else { rR.push_back(ranks[i]); rO.push_back(orig[i]); }
            }
            node->L = buildNode(lR, lO, lo, mid);
            node->R = buildNode(rR, rO, mid + 1, hi);
            return node;
        }

        long long sumKSmallest(Node* node, int l, int r, int k) const {
            if (k == 0 || l > r) return 0;
            if (node->lo == node->hi) return (long long)uniq[node->lo] * k;  // k equal values
            int lb = node->prefLeft[l - 1], rb = node->prefLeft[r];
            int leftCnt = rb - lb;
            if (k <= leftCnt) return sumKSmallest(node->L, lb + 1, rb, k);
            long long leftSum = node->prefLeftSum[r] - node->prefLeftSum[l - 1];
            return leftSum + sumKSmallest(node->R, l - lb, r - rb, k - leftCnt);
        }

        long long sumLTE(Node* node, int l, int r, int rx) const {
            if (l > r || rx < node->lo) return 0;
            if (node->lo == node->hi) return (long long)uniq[node->lo] * (r - l + 1);
            int mid = (node->lo + node->hi) >> 1;
            int lb = node->prefLeft[l - 1], rb = node->prefLeft[r];
            // rx <= mid: only the left child can contribute.
            if (rx <= mid) return sumLTE(node->L, lb + 1, rb, rx);
            // rx > mid: every value routed left is <= rx, so take its stored sum
            // and continue only in the right child. This keeps the query O(log sigma).
            long long leftSum = node->prefLeftSum[r] - node->prefLeftSum[l - 1];
            return leftSum + sumLTE(node->R, l - lb, r - rb, rx);
        }
    };

    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        int n; cin >> n;
        vector<int> A(n + 1);
        for (int i = 1; i <= n; ++i) cin >> A[i];
        WaveletTreeSum wts(A);

        int q; cin >> q;
        // Queries:
        // S l r k : sum of k smallest values in A[l..r]
        // T l r x : sum of values <= x in A[l..r]
        while (q--) {
            char type; cin >> type;
            if (type == 'S') {
                int l, r, k; cin >> l >> r >> k;
                cout << wts.sumKSmallest(l, r, k) << "\n";
            } else if (type == 'T') {
                int l, r, x; cin >> l >> r >> x;
                cout << wts.sumLTE(l, r, x) << "\n";
            }
        }
        return 0;
    }
This augmented variant stores, at each node, prefix sums of the original values that go left. With this, the sum of the k smallest in [l,r] is computed by taking entire left contributions when possible and continuing on the right for the remainder. The sum of values ≤ x can be answered via a similar descent that accumulates contributions fully within the allowed value range.