문제

  • 아래 조건들을 만족하는 수식이 주어진다.
    • 길이가 1이상 19 이하
    • 수식의 수들은 모두 한 자릿수(0-9)
    • 연산자는 +, -, *
    • 연산자 우선순위 존재
    • 괄호 없음
  • 아래 조건을 만족하며 수식에 괄호를 추가할 수 있을 때 결과의 최댓값을 출력해야 한다.
    • 괄호는 여러 개의 연산자 포함 가능
    • 중복된 괄호 가능
  • 입력은 올바른 수식만 주어진다.
  • 정답은 \(-2^{31}\) 초과 \(2^{31}\) 미만이다.

풀이

  • 이전에 생각했던 대로 DFS를 사용하여 풀었다.
    • 각 연산자를 노드로 생각하고, 수식을 인자로 받는 함수로 구현해서 모든 경우를 탐색했다.
    • 실행 시간은 112ms이다.
DFS 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <iostream>
#include <stack>
#include <string>
#include <vector>
#define v vector

using namespace std;
using vi = vector<int>;

//DFS
int pow(int x, int p) {
    int r = 1;
    while (p) {
        if (p & 0x01) r *= x;
        x *= x;
        p >>= 1;
    }
    return r;
}

bool isdigit(char c) {
    if (c >= '0' && c <= '9')
        return true;
    else
        return false;
}

int solve(vector<long long> ts) {
    int res = -1 * pow(2, 31), tr;
    if (ts.size() == 1) return ts[0];
    for (int i = 1; i < ts.size(); i += 2) {
        int n1 = ts[i - 1], n2 = ts[i + 1];
        vector<long long> tts(0);
        tts.insert(tts.end(), ts.begin(), ts.begin() + i - 1);
        if (ts[i] == -1) {
            tts.push_back(n1 + n2);
        } else if (ts[i] == 0)
            tts.push_back(n1 - n2);
        else
            tts.push_back(n1 * n2);
        tts.insert(tts.end(), ts.begin() + i + 2, ts.end());
        tr = solve(tts);
        if (res < tr) res = tr;
    }
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    int n;
    vector<long long> ts;
    string s;
    cin >> n;
    cin >> s;
    for (int i = 0; i < s.length(); i++) {
        if (isdigit(s[i])) {
            ts.push_back(s[i] - '0');
        } else if (s[i] == '+')
            ts.push_back(-1);
        else if (s[i] == '-')
            ts.push_back(0);
        else
            ts.push_back(1);
    }
    cout << solve(ts);
    return 0;
}

실행 시간이 비교적 느려 다른 사람들의 코드를 확인해보니, dp로도 풀 수 있는 것을 알았고, dp로도 풀어보았다.

  • 점화식은 아래와 같다.
\[\text{dmax(i)(j)}=\text{max} \begin{pmatrix} dmax(i)(i+k)\;op(i+k)(i+k+1)\;dmax(i+k+1)(j) \\ dmax(i)(i+k)\;op(i+k)(i+k+1)\;dmin(i+k+1)(j) \\ dmin(i)(i+k)\;op(i+k)(i+k+1)\;dmax(i+k+1)(j) \\ dmin(i)(i+k)\;op(i+k)(i+k+1)\;dmin(i+k+1)(j) \end{pmatrix}\] \[\text{dmin(i)(j)}=\text{min} \begin{pmatrix} dmax(i)(i+k)\;op(i+k)(i+k+1)\;dmax(i+k+1)(j) \\ dmax(i)(i+k)\;op(i+k)(i+k+1)\;dmin(i+k+1)(j) \\ dmin(i)(i+k)\;op(i+k)(i+k+1)\;dmax(i+k+1)(j) \\ dmin(i)(i+k)\;op(i+k)(i+k+1)\;dmin(i+k+1)(j) \end{pmatrix}\]
  • dmax, dmin은 수식의 숫자만을 인덱스로 가진다.
  • dmax(i)(j)는 인덱스가 [i, j]인 숫자들을 활용하여 계산했을 때 가능한 최댓값이다.
  • op(i)(j)i, j 번째 숫자 사이의 연산자이다.
DP 코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>
#define v vector

using namespace std;
using vi = vector<int>;

// dp
long long calc(long long n1, long long n2, char op) {
    if (op == '+')
        return n1 + n2;
    else if (op == '-')
        return n1 - n2;
    else
        return n1 * n2;
}

int main() {
    int n, nn;
    string str;
    cin >> n;
    cin >> str;
    nn = n / 2 + 1;
    v<v<long long>> dmax(nn, v<long long>(nn, -3486784410)), dmin(nn, v<long long>(nn, 3486784410));
    for (int i = 0; i < nn; i++) {
        dmax[i][i] = str[i * 2] - '0';
        dmin[i][i] = str[i * 2] - '0';
    }
    for (int d = 1; d < nn; d++) {
        for (int s = 0; s < nn - d; s++) {
            for (int pv = 0; pv < d; pv++) {
                char op = str[(s + pv) * 2 + 1];
                v<long long> tv;
                tv.push_back(calc(dmax[s][s + pv], dmax[s + pv + 1][s + d], op));
                tv.push_back(calc(dmax[s][s + pv], dmin[s + pv + 1][s + d], op));
                tv.push_back(calc(dmin[s][s + pv], dmax[s + pv + 1][s + d], op));
                tv.push_back(calc(dmin[s][s + pv], dmin[s + pv + 1][s + d], op));
                sort(tv.begin(), tv.end());
                if (dmax[s][s + d] < tv[3]) dmax[s][s + d] = tv[3];
                if (dmin[s][s + d] > tv[0]) dmin[s][s + d] = tv[0];
            }
        }
    }
    cout << dmax[0][nn - 1];
    return 0;
}

풀고나서

  • DFS로 얼른 풀겠다는 생각이 강해 dp를 고려해보지 못한게 아쉬웠다.

Tags: , ,

Categories:

Updated:

Leave a comment