• 길이 N의 수열이 주어지면 버블소트 하는 과정에서 swap이 몇 번 일어나는지 구하는 문제이다.

문제 풀이

Segment tree를 이용한 풀이

  • 수열을 인덱스와 함께 pair로 저장한 후 값을 기준으로 오름차순 정렬한다. 첫 번째 원소부터 순서대로 segment tree에 update한다. 이 때 현재 원소의 index부터 n-1사이에 원소가 존재한다면 그것은 현재 원소보다 값은 작으면서 인덱스는 뒤에 있다는 뜻이다. 즉 swap이 일어나기 때문에 update 전에 [현재 원소의 인덱스, n-1]의 구간합을 구한다.

  • stl의 sort()는 stable한 소트가 아니기 때문에 같은 값을 가진 원소들의 index가 섞인다. 따라서 stable_sort()를 이용해야 한다.

Merge sort를 이용한 풀이

  • merge sort의 partition을 합치는 과정(내 코드에서는 merge()함수이다.)에서 swap이 일어난다. 두 구간 A, B를 P로 합치는 과정에서 A에 원소가 남아있는데 B의 원소가 더 작아 먼저 P로 들어가면 A의 남아있는 원소들과 swap이 되는 것과 동일하다.(이렇게 swap이 한번에 하나씩 일어나지 않고 한 번에 여러 원소들을 건너뛰기 때문에 merge sort가 더 빠르다) 따라서 이런 경우에 cnt를 증가시켜주면 된다.

  • 아래 주석이 없는 코드는 merge sort, 주석이 있는 코드는 segment tree를 이용한 풀이이다. merge sort 코드에서도 A의 원소가 이동할 때, B의 원소가 이동할 때 cnt갱신(주석)을 각각 구현해봤다. B를 이용한 갱신은 A에 남아있는 원소 개수만큼을 더해주면 끝난다. 하지만 A를 이용한 갱신은 사용한 B의 개수만큼을 더해주고, 한 구간이 모두 끝난 후에 후처리 과정이 하나 더 있다. A가 남으면 모든 B들이 남은 A의 개수만큼 swap을 했다는 뜻이기 때문에 (남은 A 개수)*(B의 총 개수)를 더해줘야 한다.

소스코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include<bits/stdc++.h>
using namespace std;
typedef vector<int> vi;
typedef long long int lld;

lld cnt;

int mid(int s, int e){return (s+e)>>1;}
void mergesort(vi& p, int s, int e);
void merge(vi& p, int s, int mid, int e);

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(0);
	int n;
	cin>>n;
	vi p(n);
	for(auto& i:p) cin>>i;
	mergesort(p, 0, n-1);
	cout<<cnt;
}

void mergesort(vi& p, int s, int e){
	if(s==e) return;
	mergesort(p, s, mid(s, e));
	mergesort(p, mid(s, e)+1, e);
	merge(p, s, mid(s, e), e);
}

void merge(vi& p, int s, int mid, int e){
	int aidx=s, bidx=mid+1, idx=0;
	vi t(e-s+1);
	while(aidx!=mid+1 && bidx!=e+1){
		if(p[aidx]<=p[bidx]){
			t[idx++]=p[aidx++];
			cnt+=bidx-mid-1;
		}
		else{
			t[idx++]=p[bidx++];
			//cnt+=mid-aidx+1;
		}
	}
	if(aidx!=mid+1){
		cnt+=(lld)(e-mid)*(mid-aidx+1);
		for(;aidx!=mid+1;) t[idx++]=p[aidx++];
	}
	else for(;bidx!=e+1;) t[idx++]=p[bidx++];
	for(int i=0;i<idx;i++) p[i+s]=t[i];
}

// #include<bits/stdc++.h>
// using namespace std;
// typedef vector<int> vi;
// typedef long long int lld;
// typedef pair<int, int> pii;

// bool compare(pii p1, pii p2){return p1.first<p2.first;}
// int mid(int s, int e){return (s+e)>>1;}
// void update(vi& st, int tidx, int s, int e, int nidx, int diff);
// int query(vi& st, int tidx, int s, int e, int fl, int fr);

// int main()
// {
// 	ios::sync_with_stdio(false);
// 	cin.tie(0);
// 	int n, sz=1;
// 	cin>>n;
// 	for(;sz<n;sz*=2);
// 	vi st(2*sz, 0);
// 	vector<pii> pairs;
// 	for(int i=0;i<n;i++){
// 		int t;
// 		cin>>t;
// 		pairs.push_back({t, i});
// 	}
// 	stable_sort(pairs.begin(), pairs.end(), compare);
// 	lld cnt=0;
// 	for(int i=0;i<n;i++){
// 		cnt+=query(st, 1, 0, n-1, pairs[i].second, n-1);
// 		update(st, 1, 0, n-1, pairs[i].second, 1);
// 	}
// 	cout<<cnt;
// }

// void update(vi& st, int tidx, int s, int e, int nidx, int diff){
// 	if(s>nidx || e<nidx) return;
// 	st[tidx]+=diff;
// 	if(s!=e){
// 		update(st, tidx*2, s, mid(s, e), nidx, diff);
// 		update(st, tidx*2+1, mid(s, e)+1, e, nidx, diff);
// 	}
// }

// int query(vi& st, int tidx, int s, int e, int fl, int fr){
// 	if(s>fr || e<fl) return 0;
// 	if(s>=fl && e<=fr) return st[tidx];
// 	return query(st, tidx*2, s, mid(s, e), fl, fr)+query(st, tidx*2+1, mid(s, e)+1, e, fl, fr);
// }

풀고나서

  • merge sort를 이용한 풀이에서 A를 이용한 갱신에서 시간을 낭비했다. 후처리 과정에서 (남은 A)_(모든 B)를 cnt에 더해야 하는데 이 항을 캐스팅하지 않아서 계속 WA가 나왔다. 심지어 (lld)((~)_(~))로 하면 계산된 결과가 캐스팅되기 때문에 오버플로우 값 그대로 나온다는 것을 간과했다.

Leave a comment