堆栈是一种经典的后进先出的线性结构,相关的操作主要有“入栈”(在堆栈顶插入一个元素)和“出栈”(将栈顶元素返回并从堆栈中删除)。本题要求你实现另一个附加的操作:“取中值”——即返回所有堆栈中元素键值的中值。给定 N 个元素,如果 N 是偶数,则中值定义为第 N/2 小元;若是奇数,则为第 (N+1)/2 小元。

输入格式:

输入的第一行是正整数 N(≤105)。随后 N 行,每行给出一句指令,为以下 3 种之一:

1
2
3
Push key
Pop
PeekMedian

其中 key 是不超过 105 的正整数;Push 表示“入栈”;Pop 表示“出栈”;PeekMedian 表示“取中值”。

输出格式:

对每个 Push 操作,将 key 插入堆栈,无需输出;对每个 PopPeekMedian 操作,在一行中输出相应的返回值。若操作非法,则对应输出 Invalid

输入样例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
17
Pop
PeekMedian
Push 3
PeekMedian
Push 2
PeekMedian
Push 1
PeekMedian
Pop
Pop
Push 5
Push 4
PeekMedian
Pop
Pop
Pop
Pop

输出样例:

1
2
3
4
5
6
7
8
9
10
11
12
Invalid
Invalid
3
2
2
1
2
4
4
5
3
Invalid

思路

Push栈的入栈、Pop出栈操作,PeekMedian栈中元素的中值。如果排序查找做的话,会超时。可以省去排序这个步骤,直接在Push操作的时候进行插入排序。

定义一个栈s进行入栈出栈操作;定义一个vector数组存放栈中元素,用来查找中值。利用STL库中的lower_bound函数,查找大于插入值的第一个位置,然后将该元素直接插入。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <iostream>
#include <stack>
#include <algorithm>
#include <vector>
using namespace std;
int main() {
int n;
cin >> n;
stack<int> s;
vector<int> ve;
for(int i = 0; i < n; i++) {
string str;
cin >> str;
if(str == "Pop") {
if(s.empty()) cout << "Invalid\n";
else {
auto it = lower_bound(ve.begin(), ve.end(), s.top());
ve.erase(it);
cout << s.top() << endl;
s.pop();

}
}
else if(str == "Push") {
int key;
cin >> key;
s.push(key);
ve.insert(lower_bound(ve.begin(), ve.end(), s.top()), key);
}
else if(str == "PeekMedian") {
if(s.empty()) cout << "Invalid\n";
else {
if(ve.size()%2 == 0) cout << ve[ve.size()/2-1] << endl;
else cout << ve[ve.size()/2] << endl;
}
}
}
}

更多做法(年前测试题解)

做法一:

题目总共有三种操作:

1.Push:在序列中插入一个数x

2.Pop: 删除序列的某个数

3.PeekMedian 查询第N/2大数

很明显,这三种操作均可以使用平衡树来完成。

平衡树并不好写,但是大多数时候可以用vector来代替普通平衡树(复杂度是O(N^2)的,但是常数小,很多时候都能跑得飞快,set是平衡树,但是不支持查询第k大值操作,不过pbds库中有支持查询第k大值的平衡树, pta也能使用pbds库,感兴趣可以自己学习)

这篇博客里有使用treap与vector两种方式来完成普通平衡树的代码https://www.cnblogs.com/HocRiser/p/8763251.html

vector(O(N^2))

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
#define ins(x) (a.insert(upper_bound(a.begin(),a.end(),x),x)) // 插入x
#define que(x) (a[x - 1]) // 查询排名为x的数
#define del(x) (a.erase(lower_bound(a.begin(),a.end(),x))) // 删除值为x的数
int main() {
cin.tie(0), cout.tie(0);
int n, x;
cin >> n;
stack<int> stk;
vector<int> a;
for (int i = 1; i <= n; i++) {
string op;
cin >> op;
if (op == "Pop") {
if (!stk.size()) puts("Invalid");
else cout << stk.top() << '\n', del(stk.top()), stk.pop();
}
else if (op == "Push") {
cin >> x;
stk.push(x);
ins(x);
}
else {
if (!stk.size()) {puts("Invalid"); continue;}
cout << que((stk.size() + 1) / 2) << '\n';
}
}
return 0;
}

pbds平衡树(O(NlogN))

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
using namespace std;
__gnu_pbds::tree<array<int, 2>, __gnu_pbds::null_type, less<array<int, 2>> ,__gnu_pbds::rb_tree_tag,__gnu_pbds::tree_order_statistics_node_update> T;
int main() {
cin.tie(0), cout.tie(0);
int n, x, t = 0;
cin >> n;
stack<int> stk;
vector<int> a;
for (int i = 1; i <= n; i++) {
string op;
cin >> op;
if (op == "Pop") {
if (!stk.size()) puts("Invalid");
else cout << stk.top() << '\n', T.erase(T.find_by_order(T.order_of_key({stk.top(),0}))), stk.pop();
}
else if (op == "Push") {
cin >> x;
stk.push(x);
T.insert({x, ++t});
}
else {
if (!stk.size()) {puts("Invalid"); continue;}
cout << (*T.find_by_order((stk.size() + 1) / 2 - 1))[0] << '\n';
}
}
return 0;
}

做法二:

虽然set没法动态查询第K大值,但是K是固定的,所以我们可以利用对顶堆的思想,用两个set来维护中位数,其中s1存储从小到大排序前ceil{N/2}个数,s2存剩余的数,那么每次取中位数就是取s1中的最后一个元素,set插入、删除和查询均是O(logN), 所以时间复杂度为O(NlogN)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include<bits/stdc++.h>
using namespace std;
multiset<int> s1, s2;
void add(int x){
if(s1.size() - s2.size() == 1) {
if(s1.size() && x <= *prev(s1.end())) s2.insert(*prev(s1.end())), s1.erase(prev(s1.end())), s1.insert(x);
else s2.insert(x);
}
else {
if(s2.size() && x > *s2.begin()) s1.insert(*s2.begin()), s2.erase(s2.begin()), s2.insert(x);
else s1.insert(x);
}
}
void del(int x){
if(x <= *prev(s1.end())) s1.erase(s1.find(x));
else s2.erase(s2.find(x));
while((int)s1.size() - (int)s2.size() < 0) s1.insert(*s2.begin()) ,s2.erase(s2.begin());
while((int)s1.size() - (int)s2.size() > 1) s2.insert(*prev(s1.end())), s1.erase(prev(s1.end()));
}
int main(){
cin.tie(0), cout.tie(0);
int n;
cin >> n;
stack<int> stk;
for(int i = 1;i <= n; i++){
string op;
cin >> op;
if(op == "Pop") {
if(!stk.size()) puts("Invalid");
else cout << stk.top() << '\n', del(stk.top()), stk.pop();
}
else if(op == "Push") {
int x;
cin >> x;
stk.push(x);
add(x);
}
else {
if(!stk.size()) {puts("Invalid");continue;}
cout << *prev(s1.end()) << '\n';
}
}
return 0;
}

做法三:

发现查询、删除与插入这三种操作与序列里每个数的先后顺序无关,所以我们可以用一个数组cnt[N]来记录每个数的个数,cnt[i]表示1~i的数的个数,那么中位数就是第一个满足cnt[i] >= N/2的i,显然cnt数组单调不减,所以我们可以用树状数组维护cnt数组,然后二分查询,时间复杂度O(NlogN^2)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
int cnt[N];
void add(int pos, int val){
for(;pos < N; pos += pos & -pos) cnt[pos]+=val;
}
int get(int pos){
int res = 0;
for(;pos;pos -= pos & -pos) res += cnt[pos];
return res;
}
int main(){
cin.tie(0), cout.tie(0);
int n;
cin >> n;
stack<int> stk;
for(int i = 1;i <= n; i++){
string op;
cin >> op;
if(op == "Pop") {
if(!stk.size()) puts("Invalid");
else cout << stk.top() << '\n', add(stk.top(), -1), stk.pop();
}
else if(op == "Push") {
int x;
cin >> x;
stk.push(x);
add(x, 1);
}
else {
if(!stk.size()) {puts("Invalid");continue;}
int l = 0, r = 1e5;
while(l < r){
int mid = l + r >> 1;
if(get(mid) >= (stk.size() + 1) / 2) r = mid;
else l = mid + 1;
}
cout << r << '\n';
}
}
return 0;
}

做法四:

注意到做法三实际上就是维护一个[0,100000]的值域,所以我们可以使用权值线段树来维护值域,时间复杂度为O(NlogN)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 5;
struct seg{
int l, r, sum;
}tr[N << 2];
void pushup(int u){
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r){
tr[u] = {l, r};
if(l == r) {
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void add(int u, int pos, int val){
if(tr[u].l == tr[u].r) tr[u].sum+=val;
else{
int mid = tr[u].l + tr[u].r >> 1;
if(pos <= mid) add(u << 1, pos, val);
else add(u << 1 | 1, pos, val);
pushup(u);
}
}
int get(int u, int val){
if(tr[u].l == tr[u].r) return tr[u].r;
int mid = tr[u].l + tr[u].r >> 1;
if(tr[u << 1].sum >= val) return get(u << 1, val);
return get(u << 1 | 1, val - tr[u << 1].sum);
}
int main() {
cin.tie(0), cout.tie(0);
int n;
cin >> n;
stack<int> stk;
build(1, 0, 1e5);
for (int i = 1; i <= n; i++) {
string op;
cin >> op;
if (op == "Pop") {
if (!stk.size()) puts("Invalid");
else cout << stk.top() << '\n', add(1, stk.top(), -1), stk.pop();
}
else if (op == "Push") {
int x;
cin >> x;
stk.push(x);
add(1, x, 1);
}
else {
if (!stk.size()) {puts("Invalid"); continue;}
cout << get(1, (stk.size() + 1) / 2) << '\n';
}
}
return 0;
}