問題
如果今天,我需要支援以下兩種操作:
- 插入一個數字
- 查詢比一個數字大的最小元素是什麼
一個簡單的想法是,就用一個陣列通通存起來,然後查詢時一個一個找。這樣對於第一種插入操作是 $O(1)$ ,第二種是 $O(N)$。
另一個是,插入時使用 insertion sort 方法,讓整個陣列維持排序好的狀態。搜尋時就能直接使用 upper_bound 函式。這樣對於第一種插入操作是 $O(N)$ ,第二種是 $O(logN)$。
在插入與查詢次數差不多的情況下,這兩種方法都無法快速的解決問題。
二元搜尋樹
雖然現在大家可能還沒有樹的概念,還是請大家看看這個教學。
二元搜尋樹
二元搜尋樹保證,一個節點左邊一定全部都比它小,右邊一定全部比它大。
可以發現,以上兩種操作都可以透過二元搜尋樹來完成。而複雜度則是以樹的深度(從一開始走到沒路可走時的最長距離)來決定。
在一般的二元搜尋樹中,深度並沒有保證(想像插入的數字一直遞增),最高與元素數量 $N$ 相等。
但,若現在有一種二元搜尋樹,可以透過旋轉、換根等操作,保證在任何時刻,深度都是 $O(logN)$,這個題目就可以很好的解決了。
而 set 就可以做到這樣的事情!set 的內部,正是平衡二元搜尋樹之一的紅黑樹。
不同於之前的 queue 與 stack,這超難實作的,所以就直接進用法。
set 語法
set,意為集合,集合內不會存在兩個相同元素。
宣告一個 set 名為 s。
插入元素,若 set 內部已存在該元素,則無任何動作。
移除元素,若 set 內部無該元素,則無任何動作。也可以移除一個 iterator 所指向的元素,程式會自動判斷傳入型態。
1 2
| s.erase(5); s.erase(s.begin())
|
回傳該元素的 iterator,若 set 內部無該元素,則回傳 end()。
依照順序輸出內部所有元素,複雜度 $O(N)$。
1
| for(int i:s) cout << i << '\n';
|
問一個元素在不在 set 裡。可透過 find 的 return 值,或使用 s.count。
1 2
| if(s.find(10) != s.end()) cout << "In!\n"; if(s.count(10)) cout << "In!\n";
|
得到 set 的第一項,可使用 *s.begin(),最後一項使用 *s.rbegin()。
1 2
| cout << *s.begin() << '\n'; cout << *s.rbegin() << '\n';
|
那,可以輸出 set 中第 K 小的數字嗎?答案是不行的,set 沒有辦法做到這點。因此,set 中是沒有 s[k] 這樣的語法的。
但是,set 的 iterator,可以支援 ++, – 的運算,而且是 $O(1)$ 的。所以,若能以一個變數將 begin() 存下來,之後將其 ++,便可以得到第 2 大的數字。
要怎麼儲存呢?用變數。那型態是什麼?是:
1
| set<int>::iterator iter = s.begin();
|
這名稱實在太長,因此我們使用 C++11 後出現的新功能:自動判斷型態!
1 2
| auto iter = s.begin(); cout << *(++iter) << '\n';
|
因此,遍歷 set 也可以寫成:
1 2 3
| for(auto iter = s.begin(); iter != s.end(); iter++){ cout << *iter << '\n'; }
|
此外,若要得到一個 iterator 的前/後一項,可以使用 prev/next,這樣不會改到本身,常見於得到 set 中最大的數。
1
| cout << *prev(s.end()) << '\n';
|
回來看看
原來的問題,要如何用 set 去做呢?
操作一很簡單,就是 insert 而已。
操作二呢?恩…好像還是沒辦法,如果 set 中存在要被查詢的那個數,可以寫:
1 2
| auto iter = s.find(x); cout << *(++iter) << '\n';
|
但若不存在,find 函式回傳 end,就沒辦法了。
若是有一個類似 upper_bound 的函式就好了…
對,沒錯,set 的成員函式有 upper_bound()!
1
| cout << *s.upper_bound(5) << '\n';
|
順帶一提,upper_bound(s.begin(), s.end(), 5)這種寫法也能輸出正確答案,但是複雜度是糟糕的 $O(size)$。請忘記這種用法,記住 s.upper_bound() 就好!
連續刪除 a 到 b 區間內元素
1 2 3
| for(auto iter = s.find(5); iter != s.end() && *iter < 10; iter++){ s.erase(iter); }
|
這是對的嗎?看起來好像是,但其實不是,而且會導致 RE。為什麼呢?
當 iter 被 erase 後,它就不存在了,因此將其++後會發生什麼事?沒有人知道。
怎麼辦呢?
s.erase() 這個函式,其實不是 void,它會回傳被刪除的元素的下一個 iterator。
因此可以利用這個性質:
1
| for(auto iter = s.find(5); iter != s.end() && *iter < 10; iter = s.erase(iter));
|
我需要有多個相同元素的 set
set 在 insert 時,遇到相同元素會忽略,常常會因為這樣產生許多 bug。比如,插入 5 個 1 後刪除 1 個 1,在 set 中就不存在 1 了,但你真正想做的事,可能是跟剩下的 4 個 4 有關。
那要怎麼讓 set 中有多個元素呢?
第一個辦法是,多紀錄一個變數 cnt,代表每個元素的個數。只在 cnt 從 0 變為 1 時將其加入 set,從 1 變為 0 時將其移除。
但是,其實,STL 中有個東西叫 multiset,可以同時存在多個相同元素!
語法大致與 set 相同,但有一些東西要注意:
- 可用 .count() 函式回傳 multiset 中元素數量。
- s.erase(10) 代表的意思是:刪除 multiset 中所有的 10,若只刪除一個要用 s.erase(s.find(10))
我不需要排序,我只需要知道哪些東西在裡面
STL 中,有個東西叫 unordered_set,底層以 hash 實作。插入與刪除的複雜度皆為常數稍大的 $O(1)$,一般來說會比 set 的 $O(logN)$ 快。
unordered_set 失去 lower_bound 與 upper_bound 功能,遍歷時也不會照大小輸出,其餘功能與 set 相同。
自訂比較函式
這裡的比較函式跟 sort 的不一樣,是一個 struct,裡面有一個 operator(),所以要這樣寫:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| #include<bits/stdc++.h> using namespace std;
int sum_of_digits(int x){ int ans = 0; while(x){ ans += x%10; x/=10; } return ans; }
struct cmp{ bool operator()(int a, int b){ return sum_of_digits(a) < sum_of_digits(b); } };
set<int, cmp> s;
int main(){ ios::sync_with_stdio(0), cin.tie(0); s.insert(56); s.insert(11); s.insert(7); for(auto i:s) cout << i << '\n'; }
|
預設的比較函式是 less<>,有另一個定義好的比較函式是 greater<>,所以如果要從大排到小,可以這樣寫。
1
| set<int, greater<int>> s;
|
練習
TIOJ 1911
85 分解:
ZJ b938
★★★ TOJ 275
★★★★ TIOJ 1161
★★★★★ TOJ 512
★★★★★★★★ TIOJ 1941
AC Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| #include<bits/stdc++.h> using namespace std; #define ll long long multiset<ll> m; int n;
int main() { ios::sync_with_stdio(0), cin.tie(0); while(cin >> n){ if(n == 0) break; if(n > 0) m.insert(n); else{ if(m.empty()) continue; if(n == -1){ cout << *m.begin() << ' '; m.erase(m.begin()); } else{ auto iter = m.end(); iter--; cout << *iter << ' '; m.erase(iter); } } } cout << '\n'; }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| #include<bits/stdc++.h> using namespace std; int n, m; set<int> s; bool dead[1000005];
int main() { ios::sync_with_stdio(0), cin.tie(0); cin >> n >> m; for(int i=1;i<=n;i++) s.insert(i); for(int i=1,x;i<=m;i++){ cin >> x; if(dead[x]) cout << "0u0 ...... ?\n"; else{ auto iter = s.upper_bound(x); if(iter == s.end()) cout << "0u0 ...... ?\n"; else{ cout << *iter << '\n'; dead[*iter] = 1; s.erase(iter); } } } }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
|
#include<bits/stdc++.h> using namespace std; #define ll long long ll n; multiset<ll> small, big;
void pop_big(){ ll x = *big.begin(); small.insert(x); big.erase(big.begin()); }
void pop_small(){ ll x = *small.rbegin(); big.insert(x); small.erase(prev(small.end())); }
int main() { ios::sync_with_stdio(0), cin.tie(0); cout << fixed << setprecision(6); cin >> n; for(int i=1,x;i<=n;i++){ cin >> x;
if(i == 1 || x < *small.rbegin()) small.insert(x); else big.insert(x);
while(big.size() > i/2) pop_big(); while(big.size() < i/2) pop_small();
double ans; if(i&1) ans = *small.rbegin(); else ans = (*small.rbegin() + *big.begin()) * 1.0 / 2; cout << ans << '\n'; } }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| #include<bits/stdc++.h> using namespace std; #define pii pair<int,int>
multiset<int,greater<int>> s; int n, k; pii arr[1000005];
void solve(){ s.clear(); cin >> n >> k; for(int i=1;i<=n;i++) cin >> arr[i].first >> arr[i].second; sort(arr+1, arr+n+1); int ans = 1e9; for(int i=1;i<=n;i++){ s.insert(arr[i].second); if(s.size() > k) s.erase(s.begin()); if(s.size() == k) ans = min(ans, arr[i].first + *s.begin()); } cout << ans << '\n'; }
int main(){ int t; cin >> t; while(t--) solve(); }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
|
#include<bits/stdc++.h> using namespace std;
set<int> s; int n, q; int arr[1000005], pos[1000005];
void change(int x, int y){ s.erase(x); s.erase(y); swap(arr[x], arr[y]); pos[arr[x]] = x; pos[arr[y]] = y; if(arr[x] != x) s.insert(x); if(arr[y] != y) s.insert(y); }
int recover(int x, int y){ int ans = 0; while(1){ auto iter = s.lower_bound(x); if(iter == s.end() || *iter > y) break; ans++; change(pos[*iter], *iter); } return ans; }
int main(){ ios::sync_with_stdio(0), cin.tie(0); cin >> n >> q; for(int i=1;i<=n;i++) pos[i] = arr[i] = i; for(int i=1,ope,a,b;i<=q;i++){ cin >> ope >> a >> b; if(ope == 1) change(a, b); else cout << recover(a, b) << '\n'; } }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
|
#include<bits/stdc++.h> using namespace std; multiset<int> s; int n; int main(){ ios::sync_with_stdio(0), cin.tie(0); cin >> n; for(int i=1,l,r;i<=n;i++){ cin >> l >> r; s.insert(l); auto iter = s.upper_bound(r); if(iter != s.end()) s.erase(iter); } cout << s.size() << '\n'; }
|