問題 如果今天,我需要支援以下兩種操作:
插入一個數字
查詢比一個數字大的最小元素是什麼
一個簡單的想法是,就用一個陣列通通存起來,然後查詢時一個一個找。這樣對於第一種插入操作是 $O(1)$ ,第二種是 $O(N)$。
另一個是,插入時使用 insertion sort 方法,讓整個陣列維持排序好的狀態。搜尋時就能直接使用 upper_bound 函式。這樣對於第一種插入操作是 $O(N)$ ,第二種是 $O(logN)$。
在插入與查詢次數差不多的情況下,這兩種方法都無法快速的解決問題。
二元搜尋樹 雖然現在大家可能還沒有樹的概念,還是請大家看看這個教學。
二元搜尋樹
二元搜尋樹保證,一個節點左邊一定全部都比它小,右邊一定全部比它大。
可以發現,以上兩種操作都可以透過二元搜尋樹來完成。而複雜度則是以樹的深度(從一開始走到沒路可走時的最長距離)來決定。
在一般的二元搜尋樹中,深度並沒有保證(想像插入的數字一直遞增),最高與元素數量 $N$ 相等。
但,若現在有一種二元搜尋樹,可以透過旋轉、換根等操作,保證在任何時刻,深度都是 $O(logN)$,這個題目就可以很好的解決了。
而 set 就可以做到這樣的事情!set 的內部,正是平衡二元搜尋樹 之一的紅黑樹 。
不同於之前的 queue 與 stack,這超難實作的,所以就直接進用法。
set 語法 set,意為集合,集合內不會存在兩個相同元素。
宣告一個 set 名為 s。
插入元素,若 set 內部已存在該元素,則無任何動作。
移除元素,若 set 內部無該元素,則無任何動作。也可以移除一個 iterator 所指向的元素,程式會自動判斷傳入型態。
1 2 s.erase(5 ); s.erase(s.begin())
回傳該元素的 iterator ,若 set 內部無該元素,則回傳 end()。
依照順序輸出內部所有元素,複雜度 $O(N)$。
1 for (int i:s) cout << i << '\n' ;
問一個元素在不在 set 裡。可透過 find 的 return 值,或使用 s.count。
1 2 if (s.find(10 ) != s.end()) cout << "In!\n" ;if (s.count(10 )) cout << "In!\n" ;
得到 set 的第一項,可使用 *s.begin(),最後一項使用 *s.rbegin()。
1 2 cout << *s.begin() << '\n' ;cout << *s.rbegin() << '\n' ;
那,可以輸出 set 中第 K 小的數字嗎?答案是不行的,set 沒有辦法做到這點。因此,set 中是沒有 s[k] 這樣的語法的。
但是,set 的 iterator,可以支援 ++, – 的運算,而且是 $O(1)$ 的。所以,若能以一個變數將 begin() 存下來,之後將其 ++,便可以得到第 2 大的數字。
要怎麼儲存呢?用變數。那型態是什麼?是:
1 set <int >::iterator iter = s.begin();
這名稱實在太長,因此我們使用 C++11 後出現的新功能:自動判斷型態!
1 2 auto iter = s.begin();cout << *(++iter) << '\n' ;
因此,遍歷 set 也可以寫成:
1 2 3 for (auto iter = s.begin(); iter != s.end(); iter++){ cout << *iter << '\n' ; }
此外,若要得到一個 iterator 的前/後一項,可以使用 prev/next,這樣不會改到本身,常見於得到 set 中最大的數。
1 cout << *prev(s.end()) << '\n' ;
回來看看 原來的問題,要如何用 set 去做呢?
操作一很簡單,就是 insert 而已。
操作二呢?恩…好像還是沒辦法,如果 set 中存在要被查詢的那個數,可以寫:
1 2 auto iter = s.find(x);cout << *(++iter) << '\n' ;
但若不存在,find 函式回傳 end,就沒辦法了。
若是有一個類似 upper_bound 的函式就好了…
對,沒錯,set 的成員函式有 upper_bound()!
1 cout << *s.upper_bound(5 ) << '\n' ;
順帶一提,upper_bound(s.begin(), s.end(), 5)這種寫法也能輸出正確答案,但是複雜度是糟糕的 $O(size)$。請忘記這種用法,記住 s.upper_bound() 就好!
連續刪除 a 到 b 區間內元素 1 2 3 for (auto iter = s.find(5 ); iter != s.end() && *iter < 10 ; iter++){ s.erase(iter); }
這是對的嗎?看起來好像是,但其實不是,而且會導致 RE。為什麼呢?
當 iter 被 erase 後,它就不存在了,因此將其++後會發生什麼事?沒有人知道。
怎麼辦呢? s.erase() 這個函式,其實不是 void,它會回傳被刪除的元素的下一個 iterator。
因此可以利用這個性質:
1 for (auto iter = s.find(5 ); iter != s.end() && *iter < 10 ; iter = s.erase(iter));
我需要有多個相同元素的 set set 在 insert 時,遇到相同元素會忽略,常常會因為這樣產生許多 bug。比如,插入 5 個 1 後刪除 1 個 1,在 set 中就不存在 1 了,但你真正想做的事,可能是跟剩下的 4 個 4 有關。
那要怎麼讓 set 中有多個元素呢?
第一個辦法是,多紀錄一個變數 cnt,代表每個元素的個數。只在 cnt 從 0 變為 1 時將其加入 set,從 1 變為 0 時將其移除。
但是,其實,STL 中有個東西叫 multiset,可以同時存在多個相同元素!
語法大致與 set 相同,但有一些東西要注意:
可用 .count() 函式回傳 multiset 中元素數量。
s.erase(10) 代表的意思是:刪除 multiset 中所有 的 10,若只刪除一個要用 s.erase(s.find(10))
我不需要排序,我只需要知道哪些東西在裡面 STL 中,有個東西叫 unordered_set,底層以 hash 實作。插入與刪除的複雜度皆為常數稍大的 $O(1)$,一般來說會比 set 的 $O(logN)$ 快。
unordered_set 失去 lower_bound 與 upper_bound 功能,遍歷時也不會照大小輸出,其餘功能與 set 相同。
自訂比較函式 這裡的比較函式跟 sort 的不一樣,是一個 struct,裡面有一個 operator(),所以要這樣寫:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 #include <bits/stdc++.h> using namespace std ;int sum_of_digits (int x) { int ans = 0 ; while (x){ ans += x%10 ; x/=10 ; } return ans; } struct cmp { bool operator () (int a, int b) { return sum_of_digits(a) < sum_of_digits(b); } }; set <int , cmp> s;int main () { ios::sync_with_stdio(0 ), cin .tie(0 ); s.insert(56 ); s.insert(11 ); s.insert(7 ); for (auto i:s) cout << i << '\n' ; }
預設的比較函式是 less<>,有另一個定義好的比較函式是 greater<>,所以如果要從大排到小,可以這樣寫。
1 set <int , greater<int >> s;
練習 TIOJ 1911
85 分解:ZJ b938
★★★ TOJ 275
★★★★ TIOJ 1161
★★★★★ TOJ 512
★★★★★★★★ TIOJ 1941
AC Code 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 #include <bits/stdc++.h> using namespace std ;#define ll long long multiset <ll> m;int n;int main () { ios::sync_with_stdio(0 ), cin .tie(0 ); while (cin >> n){ if (n == 0 ) break ; if (n > 0 ) m.insert(n); else { if (m.empty()) continue ; if (n == -1 ){ cout << *m.begin() << ' ' ; m.erase(m.begin()); } else { auto iter = m.end(); iter--; cout << *iter << ' ' ; m.erase(iter); } } } cout << '\n' ; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 #include <bits/stdc++.h> using namespace std ;int n, m;set <int > s;bool dead[1000005 ];int main () { ios::sync_with_stdio(0 ), cin .tie(0 ); cin >> n >> m; for (int i=1 ;i<=n;i++) s.insert(i); for (int i=1 ,x;i<=m;i++){ cin >> x; if (dead[x]) cout << "0u0 ...... ?\n" ; else { auto iter = s.upper_bound(x); if (iter == s.end()) cout << "0u0 ...... ?\n" ; else { cout << *iter << '\n' ; dead[*iter] = 1 ; s.erase(iter); } } } }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 #include <bits/stdc++.h> using namespace std ;#define ll long long ll n; multiset <ll> small, big;void pop_big () { ll x = *big.begin(); small.insert(x); big.erase(big.begin()); } void pop_small () { ll x = *small.rbegin(); big.insert(x); small.erase(prev(small.end())); } int main () { ios::sync_with_stdio(0 ), cin .tie(0 ); cout << fixed << setprecision(6 ); cin >> n; for (int i=1 ,x;i<=n;i++){ cin >> x; if (i == 1 || x < *small.rbegin()) small.insert(x); else big.insert(x); while (big.size() > i/2 ) pop_big(); while (big.size() < i/2 ) pop_small(); double ans; if (i&1 ) ans = *small.rbegin(); else ans = (*small.rbegin() + *big.begin()) * 1.0 / 2 ; cout << ans << '\n' ; } }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 #include <bits/stdc++.h> using namespace std ;#define pii pair<int,int> multiset <int ,greater<int >> s;int n, k;pii arr[1000005 ]; void solve () { s.clear(); cin >> n >> k; for (int i=1 ;i<=n;i++) cin >> arr[i].first >> arr[i].second; sort(arr+1 , arr+n+1 ); int ans = 1e9 ; for (int i=1 ;i<=n;i++){ s.insert(arr[i].second); if (s.size() > k) s.erase(s.begin()); if (s.size() == k) ans = min(ans, arr[i].first + *s.begin()); } cout << ans << '\n' ; } int main () { int t; cin >> t; while (t--) solve(); }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 #include <bits/stdc++.h> using namespace std ;set <int > s;int n, q;int arr[1000005 ], pos[1000005 ];void change (int x, int y) { s.erase(x); s.erase(y); swap(arr[x], arr[y]); pos[arr[x]] = x; pos[arr[y]] = y; if (arr[x] != x) s.insert(x); if (arr[y] != y) s.insert(y); } int recover (int x, int y) { int ans = 0 ; while (1 ){ auto iter = s.lower_bound(x); if (iter == s.end() || *iter > y) break ; ans++; change(pos[*iter], *iter); } return ans; } int main () { ios::sync_with_stdio(0 ), cin .tie(0 ); cin >> n >> q; for (int i=1 ;i<=n;i++) pos[i] = arr[i] = i; for (int i=1 ,ope,a,b;i<=q;i++){ cin >> ope >> a >> b; if (ope == 1 ) change(a, b); else cout << recover(a, b) << '\n' ; } }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 #include <bits/stdc++.h> using namespace std ;multiset <int > s;int n;int main () { ios::sync_with_stdio(0 ), cin .tie(0 ); cin >> n; for (int i=1 ,l,r;i<=n;i++){ cin >> l >> r; s.insert(l); auto iter = s.upper_bound(r); if (iter != s.end()) s.erase(iter); } cout << s.size() << '\n' ; }