BOJ 16639 - 괄호 추가하기 3
문제
- 아래 조건들을 만족하는 수식이 주어진다.
- 길이가 1이상 19 이하
- 수식의 수들은 모두 한 자릿수(0-9)
- 연산자는 +, -, *
- 연산자 우선순위 존재
- 괄호 없음
- 아래 조건을 만족하며 수식에 괄호를 추가할 수 있을 때 결과의 최댓값을 출력해야 한다.
- 괄호는 여러 개의 연산자 포함 가능
- 중복된 괄호 가능
- 입력은 올바른 수식만 주어진다.
- 정답은 \(-2^{31}\) 초과 \(2^{31}\) 미만이다.
풀이
- 이전에 생각했던 대로 DFS를 사용하여 풀었다.
- 각 연산자를 노드로 생각하고, 수식을 인자로 받는 함수로 구현해서 모든 경우를 탐색했다.
- 실행 시간은 112ms이다.
DFS 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <iostream>
#include <stack>
#include <string>
#include <vector>
#define v vector
using namespace std;
using vi = vector<int>;
//DFS
int pow(int x, int p) {
int r = 1;
while (p) {
if (p & 0x01) r *= x;
x *= x;
p >>= 1;
}
return r;
}
bool isdigit(char c) {
if (c >= '0' && c <= '9')
return true;
else
return false;
}
int solve(vector<long long> ts) {
int res = -1 * pow(2, 31), tr;
if (ts.size() == 1) return ts[0];
for (int i = 1; i < ts.size(); i += 2) {
int n1 = ts[i - 1], n2 = ts[i + 1];
vector<long long> tts(0);
tts.insert(tts.end(), ts.begin(), ts.begin() + i - 1);
if (ts[i] == -1) {
tts.push_back(n1 + n2);
} else if (ts[i] == 0)
tts.push_back(n1 - n2);
else
tts.push_back(n1 * n2);
tts.insert(tts.end(), ts.begin() + i + 2, ts.end());
tr = solve(tts);
if (res < tr) res = tr;
}
return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int n;
vector<long long> ts;
string s;
cin >> n;
cin >> s;
for (int i = 0; i < s.length(); i++) {
if (isdigit(s[i])) {
ts.push_back(s[i] - '0');
} else if (s[i] == '+')
ts.push_back(-1);
else if (s[i] == '-')
ts.push_back(0);
else
ts.push_back(1);
}
cout << solve(ts);
return 0;
}
실행 시간이 비교적 느려 다른 사람들의 코드를 확인해보니, dp로도 풀 수 있는 것을 알았고, dp로도 풀어보았다.
- 점화식은 아래와 같다.
dmax
,dmin
은 수식의 숫자만을 인덱스로 가진다.dmax(i)(j)
는 인덱스가[i, j]
인 숫자들을 활용하여 계산했을 때 가능한 최댓값이다.op(i)(j)
는i
,j
번째 숫자 사이의 연산자이다.
DP 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
#define v vector
using namespace std;
using vi = vector<int>;
// dp
long long calc(long long n1, long long n2, char op) {
if (op == '+')
return n1 + n2;
else if (op == '-')
return n1 - n2;
else
return n1 * n2;
}
int main() {
int n, nn;
string str;
cin >> n;
cin >> str;
nn = n / 2 + 1;
v<v<long long>> dmax(nn, v<long long>(nn, -3486784410)), dmin(nn, v<long long>(nn, 3486784410));
for (int i = 0; i < nn; i++) {
dmax[i][i] = str[i * 2] - '0';
dmin[i][i] = str[i * 2] - '0';
}
for (int d = 1; d < nn; d++) {
for (int s = 0; s < nn - d; s++) {
for (int pv = 0; pv < d; pv++) {
char op = str[(s + pv) * 2 + 1];
v<long long> tv;
tv.push_back(calc(dmax[s][s + pv], dmax[s + pv + 1][s + d], op));
tv.push_back(calc(dmax[s][s + pv], dmin[s + pv + 1][s + d], op));
tv.push_back(calc(dmin[s][s + pv], dmax[s + pv + 1][s + d], op));
tv.push_back(calc(dmin[s][s + pv], dmin[s + pv + 1][s + d], op));
sort(tv.begin(), tv.end());
if (dmax[s][s + d] < tv[3]) dmax[s][s + d] = tv[3];
if (dmin[s][s + d] > tv[0]) dmin[s][s + d] = tv[0];
}
}
}
cout << dmax[0][nn - 1];
return 0;
}
풀고나서
- DFS로 얼른 풀겠다는 생각이 강해 dp를 고려해보지 못한게 아쉬웠다.
Leave a comment