堆栈是一种经典的后进先出的线性结构,相关的操作主要有“入栈”(在堆栈顶插入一个元素)和“出栈”(将栈顶元素返回并从堆栈中删除)。本题要求你实现另一个附加的操作:“取中值”——即返回所有堆栈中元素键值的中值。给定 N 个元素,如果 N 是偶数,则中值定义为第 N/2 小元;若是奇数,则为第 (N+1)/2 小元。
输入格式: 输入的第一行是正整数 N(≤105)。随后 N 行,每行给出一句指令,为以下 3 种之一:
其中 key
是不超过 105 的正整数;Push
表示“入栈”;Pop
表示“出栈”;PeekMedian
表示“取中值”。
输出格式: 对每个 Push
操作,将 key
插入堆栈,无需输出;对每个 Pop
或 PeekMedian
操作,在一行中输出相应的返回值。若操作非法,则对应输出 Invalid
。
输入样例: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 17 Pop PeekMedian Push 3 PeekMedian Push 2 PeekMedian Push 1 PeekMedian Pop Pop Push 5 Push 4 PeekMedian Pop Pop Pop Pop
输出样例: 1 2 3 4 5 6 7 8 9 10 11 12 Invalid Invalid 3 2 2 1 2 4 4 5 3 Invalid
思路 Push
栈的入栈、Pop
出栈操作,PeekMedian
栈中元素的中值。如果排序查找做的话,会超时。可以省去排序这个步骤,直接在Push
操作的时候进行插入排序。
定义一个栈s
进行入栈出栈操作;定义一个vector数组存放栈中元素,用来查找中值。利用STL库中的lower_bound
函数,查找大于插入值的第一个位置,然后将该元素直接插入。
代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 #include <iostream> #include <stack> #include <algorithm> #include <vector> using namespace std ;int main () { int n; cin >> n; stack <int > s; vector <int > ve; for (int i = 0 ; i < n; i++) { string str; cin >> str; if (str == "Pop" ) { if (s.empty()) cout << "Invalid\n" ; else { auto it = lower_bound(ve.begin(), ve.end(), s.top()); ve.erase(it); cout << s.top() << endl ; s.pop(); } } else if (str == "Push" ) { int key; cin >> key; s.push(key); ve.insert(lower_bound(ve.begin(), ve.end(), s.top()), key); } else if (str == "PeekMedian" ) { if (s.empty()) cout << "Invalid\n" ; else { if (ve.size()%2 == 0 ) cout << ve[ve.size()/2 -1 ] << endl ; else cout << ve[ve.size()/2 ] << endl ; } } } }
更多做法(年前测试题解) 做法一: 题目总共有三种操作:
1.Push:在序列中插入一个数x
2.Pop: 删除序列的某个数
3.PeekMedian 查询第N/2大数
很明显,这三种操作均可以使用平衡树来完成。
平衡树并不好写,但是大多数时候可以用vector来代替普通平衡树(复杂度是O(N^2)的,但是常数小,很多时候都能跑得飞快,set是平衡树,但是不支持查询第k大值操作,不过pbds库中有支持查询第k大值的平衡树, pta也能使用pbds库,感兴趣可以自己学习)
这篇博客里有使用treap与vector两种方式来完成普通平衡树的代码https://www.cnblogs.com/HocRiser/p/8763251.html
vector(O(N^2))
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 #include <bits/stdc++.h> using namespace std ;const int N = 1e5 + 5 ;#define ins(x) (a.insert(upper_bound(a.begin(),a.end(),x),x)) #define que(x) (a[x - 1]) #define del(x) (a.erase(lower_bound(a.begin(),a.end(),x))) int main () { cin .tie(0 ), cout .tie(0 ); int n, x; cin >> n; stack <int > stk; vector <int > a; for (int i = 1 ; i <= n; i++) { string op; cin >> op; if (op == "Pop" ) { if (!stk.size()) puts ("Invalid" ); else cout << stk.top() << '\n' , del(stk.top()), stk.pop(); } else if (op == "Push" ) { cin >> x; stk.push(x); ins(x); } else { if (!stk.size()) {puts ("Invalid" ); continue ;} cout << que((stk.size() + 1 ) / 2 ) << '\n' ; } } return 0 ; }
pbds平衡树(O(NlogN))
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 #include <bits/stdc++.h> #include <ext/pb_ds/assoc_container.hpp> #include <ext/pb_ds/tree_policy.hpp> using namespace std ;__gnu_pbds::tree<array <int , 2>, __gnu_pbds::null_type, less<array <int , 2>> ,__gnu_pbds::rb_tree_tag,__gnu_pbds::tree_order_statistics_node_update> T; int main () { cin .tie(0 ), cout .tie(0 ); int n, x, t = 0 ; cin >> n; stack <int > stk; vector <int > a; for (int i = 1 ; i <= n; i++) { string op; cin >> op; if (op == "Pop" ) { if (!stk.size()) puts ("Invalid" ); else cout << stk.top() << '\n' , T.erase(T.find_by_order(T.order_of_key({stk.top(),0 }))), stk.pop(); } else if (op == "Push" ) { cin >> x; stk.push(x); T.insert({x, ++t}); } else { if (!stk.size()) {puts ("Invalid" ); continue ;} cout << (*T.find_by_order((stk.size() + 1 ) / 2 - 1 ))[0 ] << '\n' ; } } return 0 ; }
做法二: 虽然set没法动态查询第K大值,但是K是固定的,所以我们可以利用对顶堆的思想,用两个set来维护中位数,其中s1存储从小到大排序前ceil{N/2}个数,s2存剩余的数,那么每次取中位数就是取s1中的最后一个元素,set插入、删除和查询均是O(logN), 所以时间复杂度为O(NlogN)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 #include <bits/stdc++.h> using namespace std ;multiset <int > s1, s2;void add (int x) { if (s1.size() - s2.size() == 1 ) { if (s1.size() && x <= *prev(s1.end())) s2.insert(*prev(s1.end())), s1.erase(prev(s1.end())), s1.insert(x); else s2.insert(x); } else { if (s2.size() && x > *s2.begin()) s1.insert(*s2.begin()), s2.erase(s2.begin()), s2.insert(x); else s1.insert(x); } } void del (int x) { if (x <= *prev(s1.end())) s1.erase(s1.find(x)); else s2.erase(s2.find(x)); while ((int )s1.size() - (int )s2.size() < 0 ) s1.insert(*s2.begin()) ,s2.erase(s2.begin()); while ((int )s1.size() - (int )s2.size() > 1 ) s2.insert(*prev(s1.end())), s1.erase(prev(s1.end())); } int main () { cin .tie(0 ), cout .tie(0 ); int n; cin >> n; stack <int > stk; for (int i = 1 ;i <= n; i++){ string op; cin >> op; if (op == "Pop" ) { if (!stk.size()) puts ("Invalid" ); else cout << stk.top() << '\n' , del(stk.top()), stk.pop(); } else if (op == "Push" ) { int x; cin >> x; stk.push(x); add(x); } else { if (!stk.size()) {puts ("Invalid" );continue ;} cout << *prev(s1.end()) << '\n' ; } } return 0 ; }
做法三: 发现查询、删除与插入这三种操作与序列里每个数的先后顺序无关,所以我们可以用一个数组cnt[N]来记录每个数的个数,cnt[i]表示1~i的数的个数,那么中位数就是第一个满足cnt[i] >= N/2的i,显然cnt数组单调不减,所以我们可以用树状数组维护cnt数组,然后二分查询,时间复杂度O(NlogN^2)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 #include <bits/stdc++.h> using namespace std ;const int N = 1e5 + 5 ;int cnt[N];void add (int pos, int val) { for (;pos < N; pos += pos & -pos) cnt[pos]+=val; } int get (int pos) { int res = 0 ; for (;pos;pos -= pos & -pos) res += cnt[pos]; return res; } int main () { cin .tie(0 ), cout .tie(0 ); int n; cin >> n; stack <int > stk; for (int i = 1 ;i <= n; i++){ string op; cin >> op; if (op == "Pop" ) { if (!stk.size()) puts ("Invalid" ); else cout << stk.top() << '\n' , add(stk.top(), -1 ), stk.pop(); } else if (op == "Push" ) { int x; cin >> x; stk.push(x); add(x, 1 ); } else { if (!stk.size()) {puts ("Invalid" );continue ;} int l = 0 , r = 1e5 ; while (l < r){ int mid = l + r >> 1 ; if (get(mid) >= (stk.size() + 1 ) / 2 ) r = mid; else l = mid + 1 ; } cout << r << '\n' ; } } return 0 ; }
做法四: 注意到做法三实际上就是维护一个[0,100000]的值域,所以我们可以使用权值线段树来维护值域,时间复杂度为O(NlogN)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 #include <bits/stdc++.h> using namespace std ;const int N = 1e5 + 5 ;struct seg { int l, r, sum; }tr[N << 2 ]; void pushup (int u) { tr[u].sum = tr[u << 1 ].sum + tr[u << 1 | 1 ].sum; } void build (int u, int l, int r) { tr[u] = {l, r}; if (l == r) { return ; } int mid = l + r >> 1 ; build(u << 1 , l, mid), build(u << 1 | 1 , mid + 1 , r); pushup(u); } void add (int u, int pos, int val) { if (tr[u].l == tr[u].r) tr[u].sum+=val; else { int mid = tr[u].l + tr[u].r >> 1 ; if (pos <= mid) add(u << 1 , pos, val); else add(u << 1 | 1 , pos, val); pushup(u); } } int get (int u, int val) { if (tr[u].l == tr[u].r) return tr[u].r; int mid = tr[u].l + tr[u].r >> 1 ; if (tr[u << 1 ].sum >= val) return get(u << 1 , val); return get(u << 1 | 1 , val - tr[u << 1 ].sum); } int main () { cin .tie(0 ), cout .tie(0 ); int n; cin >> n; stack <int > stk; build(1 , 0 , 1e5 ); for (int i = 1 ; i <= n; i++) { string op; cin >> op; if (op == "Pop" ) { if (!stk.size()) puts ("Invalid" ); else cout << stk.top() << '\n' , add(1 , stk.top(), -1 ), stk.pop(); } else if (op == "Push" ) { int x; cin >> x; stk.push(x); add(1 , x, 1 ); } else { if (!stk.size()) {puts ("Invalid" ); continue ;} cout << get(1 , (stk.size() + 1 ) / 2 ) << '\n' ; } } return 0 ; }