본문 바로 가기

PS

2021 KOI 고등부 2차 대회 후기

2021 한국 정보 올림피아드 2차 대회가 2021년 7월 25일에 온라인으로 진행되었습니다. 모두 수고하셨습니다.

 

총점 278점을 받았습니다.

 들어가기 전에 하소연을 좀 하겠습니다. 사실 제가 잘못한 것이라 억울할 것도 없지만, 3번 문제의 4번 서브태스크를 왜인지 모르겠지만 \(|A|=|B|\)로 봤고, 그래서 긁지 않았습니다. 긁었다면 총점 295점인데, 이것으로 상의 색깔이 바뀐다면 많이 아쉽겠습니다.

타임라인

00:00:00 - 00:40:04

 1번 문제 "헬기 착륙장"을 읽고 해결했습니다. 1번으로 예상했던 난이도보다 어려워서 당황했고, 시간이 좀 오래 걸렸습니다.

 

#include <cstdio>
#include <algorithm>
using namespace std;

const int MOD = 1e9 + 7;

int t, a, b, dp[100006][506], sum[100006][506];

int f(int x, int y) {
	if (x < 0) return 0;
	return dp[x][y];
}

int main() {
	dp[0][0] = dp[1][1] = 1;
	for (int i = 0; i < 506; i++) sum[0][i] = 1;
	for (int i = 1; i < 506; i++) sum[1][i] = 1;
	for (int i = 2; i < 100006; i++) for (int j = 1; j < 506; j++) {
		if (i >= j) dp[i][j] = sum[i - j][j - 1];
		sum[i][j] = (dp[i][j] + sum[i][j - 1]) % MOD;
	}
	for (int i = 0; i < 100006; i++) for (int j = 0; j < 506; j++) {
		if (i) dp[i][j] = (dp[i][j] + dp[i - 1][j]) % MOD;
		if (j) dp[i][j] = (dp[i][j] + dp[i][j - 1]) % MOD;
		if (i && j) dp[i][j] = (dp[i][j] - dp[i - 1][j - 1] + MOD) % MOD;
	}
	for (scanf("%d", &t); t--; ) {
		scanf("%d%d", &a, &b);
		int res = 0;
		for (int i = 1; i < 500; i++) {
			if (a >= i) {
				int x = a - i;
				int _min = max(0, i * (i - 1) / 2 - x);
				int _max = min(b, i * (i - 1) / 2);
				if (_min <= _max) {
					res = (res + f(_max, i - 1)) % MOD;
					res = (res - f(_min - 1, i - 1) + MOD) % MOD;
				}
			}
			if (b >= i) {
				int x = b - i;
				int _min = max(0, i * (i - 1) / 2 - x);
				int _max = min(a, i * (i - 1) / 2);
				if (_min <= _max) {
					res = (res + f(_max, i - 1)) % MOD;
					res = (res - f(_min - 1, i - 1) + MOD) % MOD;
				}
			}
		}
		printf("%d\n", res);
	}
}

 

00:40:04 - 01:03:24

 2번 문제 "그래프 균형 맞추기"를 읽고 해결했습니다. 문제가 BOI 2020 Graph와 판박이라서 쉽게 해결할 수 있었습니다.

 

#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int n, m, res[100006];
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
bool visited[100006];

void dfs(int x) {
    visited[x] = true;
    for (auto &i: adj[x]) if (!visited[i.first]) {
        a[i.first] = { -a[x].first, i.second - a[x].second };
        dfs(i.first);
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        u--, v--;
        adj[u].push_back({ v, w });
        adj[v].push_back({ u, w });
    }
    a[0] = { 1, 0 };
    dfs(0);
    for (int i = 0; i < n; i++) for (auto &j: adj[i]) if (i < j.first) {
        if (a[i].first != a[j.first].first) {
            if (a[i].second + a[j.first].second == j.second) continue;
            puts("No");
            return 0;
        } else v.push_back(2 * (j.second - a[j.first].second - a[i].second) / (a[i].first + a[j.first].first));
    }
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
    if ((int)v.size() > 1) {
        puts("No");
        return 0;
    }
    if ((int)v.size() == 1) {
        if (v[0] % 2) puts("No");
        else {
            puts("Yes");
            long long x = v[0] / 2;
            for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
        }
        return 0;
    }
    for (int i = 0; i < n; i++) v.push_back(-a[i].first * a[i].second);
    sort(v.begin(), v.end());
    long long x = v[(int)v.size() / 2];
    puts("Yes");
    for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}

 

01:03:24 - 01:28:55

 3번 문제와 4번 문제를 읽고 풀이를 생각했습니다. 뒤이어 4번 문제의 9점 서브태스크를 스위핑 + DP로 긁었습니다.

 

#include <tuple>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n, m;
long long dp[100006];
vector<int> adj[100006];
vector<tuple<int, int, int>> v;

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 0; i < m; i++) {
        int c, d, g;
        scanf("%d%d%d", &c, &d, &g);
        c--;
        v.push_back({ max(c - d, 0), min(c + d, n - 1), g });
    }
    sort(v.begin(), v.end());
    long long curr = 0;
    int k = 0;
    for (int i = 0; i < n; i++) {
        while (k < m && get<0>(v[k]) <= i) {
            dp[get<1>(v[k])] = max(dp[get<1>(v[k])], curr + get<2>(v[k]));
            k++;
        }
        curr = max(curr, dp[i]);
    }
    printf("%lld", curr);
}

 

01:28:55 - 01:39:45

 3번 문제에서 해싱을 이용한 \(O(N^2\log N+M^2\log M)\) 풀이를 작성했고, 41점을 긁었습니다. 답안을 제출한 뒤 3번 서브태스크가 애매해서 시간 초과를 받을 거라고 생각하고 LCS 6과 비슷하게 23점을 긁을 궁리를 하고 있었는데, 41점이 긁혀서 꽤 놀랐습니다.

 

#include <cstdio>
#include <vector>
#include <utility>
#include <cstring>
#include <algorithm>
using namespace std;

const int MOD = 1e9 + 7;

int t, n, m;
char a[500006], b[500006];
vector<pair<int, long long>> x, y;

int main() {
    for (scanf("%d", &t); t--; ) {
        scanf("%s%s", a, b); x.clear(); y.clear();
        n = strlen(a), m = strlen(b);
        for (int i = 0; i < n; i++) {
            int r = 0;
            long long curr = 0;
            for (int j = i; j < n; j++) {
                r += a[j] == '(' ? 1 : -1;
                curr = 2 * curr % MOD;
                if (a[j] == '(') curr = (curr + 1) % MOD;
                if (r < 0) break;
                if (r == 0) x.push_back({ j - i + 1, curr });
            }
        }
        for (int i = 0; i < m; i++) {
            int r = 0;
            long long curr = 0;
            for (int j = i; j < m; j++) {
                r += b[j] == '(' ? 1 : -1;
                curr = 2 * curr % MOD;
                if (b[j] == '(') curr = (curr + 1) % MOD;
                if (r < 0) break;
                if (r == 0) y.push_back({ j - i + 1, curr });
            }
        }
        sort(x.begin(), x.end(), [](const pair<int, long long> &a, const pair<int, long long> &b) {
            if (a.first != b.first) return a > b;
            return a.second < b.second;
        });
        sort(y.begin(), y.end(), [](const pair<int, long long> &a, const pair<int, long long> &b) {
            if (a.first != b.first) return a > b;
            return a.second < b.second;
        });
        int k = 0, res = 0;
        for (int i = 0; i < (int)x.size(); i++) {
            while (k < (int)y.size() && (y[k].first > x[i].first || y[k].first == x[i].first && y[k].second < x[i].second)) k++;
            while (k < (int)y.size() && y[k] == x[i]) res = max(res, x[i].first), k++;
        }
        printf("%d\n", res);
    }
}

 

01:39:45 - 02:56:38

 HLD를 이용한 4번 문제의 풀이를 찾았다고 생각하고, 신나서 100점 풀이를 짰지만 반례가 \(N\) 개 존재한다는 사실을 알게 되었고, 접었습니다.

 

#include <tuple>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n, m, sz[100006], depth[100006], par[100006], in[100006], rev[100006], t;
long long dp[100006], res;
vector<int> adj[100006], childs[100006];
vector<tuple<int, int, long long>> tag[100006];

int dfs_sz(int x, int prev = -1) {
    sz[x] = 1; par[x] = prev;
    for (auto &i: adj[x]) if (i != prev) {
        childs[x].push_back(i);
        depth[i] = depth[x] + 1;
        sz[x] += dfs_sz(i, x);
        if (sz[childs[x][0]] < sz[i]) swap(childs[x][0], sz[i]);
    }
    return sz[x];
}

void dfs_hld(int x) {
    in[x] = t;
    rev[t] = x;
    t++;
    for (auto &i: childs[x]) dfs_hld(i);
}

vector<int> lt;
vector<tuple<int, int, long long>> v;

void dfs_bond(int x, int y) {
    dp[in[y] + (depth[x] - depth[y])] += dp[in[x]];
    if (!childs[x].empty()) dfs_bond(childs[x][0], y);
}

void dfs_res(int x, int y) {
    sort(v.begin(), v.end());
    reverse(v.begin(), v.end());
    long long curr = 0;
    int k = 0;
    for (int i = y; true; i = par[i]) {
        while (k < (int)v.size() && get<0>(v[k]) >= i) {
            if (in[get<1>(v[k])] < in[x]) {
                if (!x) dp[0] = max(dp[0], curr + get<2>(v[k]));
                else {
                    int inc = (get<0>(v[k]) + get<1>(v[k])) / 2;
                    int c = rev[inc];
                    int d = get<0>(v[k]) - c;
                    int diff = depth[c] - depth[par[x]];
                    int nd = d - diff;
                    if (in[x] + 1 <= in[y]) get<2>(v[k]) += dp[in[x] + 1];
                    tag[par[x]].push_back({ in[par[x]] + nd, in[par[x]] - nd, get<2>(v[k]) });
                }
            } else dp[in[get<1>(v[k])]] = max(dp[in[get<1>(v[k])]], curr + get<2>(v[k]));
            k++;
        }
        curr = max(curr, dp[in[i]]);
        if (i == x) break;
    }
    res = max(res, curr);
}

void dfs_calc(int x, int y, int z) {
    lt.push_back(x);
    for (auto &i: tag[x]) v.push_back(i);
    for (int i = 0; i < (int)childs[x].size(); i++) dfs_bond(childs[x][i], x);
    if (x == y) dfs_res(y, z);
    else dfs_calc(par[x], y, z);
}

void dfs(int x, int y) {
    if (childs[x].empty()) {
        v.clear();
        lt.clear();
        dfs_calc(x, y, x);
        return;
    }
    for (int i = 1; i < (int)childs[x].size(); i++) dfs(childs[x][i], childs[x][i]);
    dfs(childs[x][0], y);
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs_sz(0);
    dfs_hld(0);
    for (int i = 0; i < m; i++) {
        int c, d, g;
        scanf("%d%d%d", &c, &d, &g);
        c--;
        tag[c].push_back({ in[c] + d, in[c] - d, g });
    }
    dfs(0, 0);
    printf("%lld", res);
}

 

02:56:38 - 03:37:21

 4번 문제의 \(O(N(N+M)\log N)\) 풀이를 찾아 28점을 추가로 얻었습니다. 

 

#include <tuple>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n, m, par[20][100006], depth[100006];
long long dp[100006], mx[100006];
vector<int> adj[100006];
vector<tuple<int, int, int>> tag[100006];

void dfs(int x, int prev = -1) {
    par[0][x] = prev;
    for (auto &i: adj[x]) if (i != prev) {
        depth[i] = depth[x] + 1;
        dfs(i, x);
    }
}

int up(int x, int y) {
    for (int t = 19; t >= 0; t--) if (y >= 1 << t && x != -1) {
        y -= 1 << t;
        x = par[t][x];
    }
    if (x == -1) x = 0;
    return x;
}

int dist(int x, int y) {
    if (depth[x] > depth[y]) swap(x, y);
    int diff = depth[y] - depth[x], l = diff;
    for (int t = 19; t >= 0; t--) if (diff >= 1 << t) {
        diff -= 1 << t;
        y = par[t][y];
    }
    if (x == y) return l;
    for (int t = 19; t >= 0; t--) if (par[t][x] != -1 && par[t][y] != -1 && par[t][x] != par[t][y]) {
        x = par[t][x];
        y = par[t][y];
        l += 1 << t + 1;
    }
    return l + 2;
}

vector<int> v;

void dfs_t(int x, int prev = -1) {
    v.push_back(x);
    for (auto &i: adj[x]) if (i != prev) dfs_t(i, x);
}

long long dfs_res(int x, int prev = -1) {
    for (auto &i: adj[x]) if (i != prev) dp[x] += dfs_res(i, x);
    v.clear();
    dfs_t(x, prev);
    for (auto &i: tag[x]) {
        long long curr = get<2>(i);
        for (auto &j: v) if (dist(get<0>(i), j) == get<1>(i) + 1) curr += mx[j];
        dp[x] = max(dp[x], curr);
    }
    return mx[x] = max(mx[x], dp[x]);
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(0);
    for (int t = 1; t < 20; t++) for (int i = 0; i < n; i++) {
        if (par[t - 1][i] == -1) par[t][i] = -1;
        else par[t][i] = par[t - 1][par[t - 1][i]];
    }
    for (int i = 0; i < m; i++) {
        int c, d, g;
        scanf("%d%d%d", &c, &d, &g);
        c--;
        tag[up(c, d)].push_back({ c, d, g });
    }
    printf("%lld", dfs_res(0));
}

 

03:37:21 - 04:30:00

 4번 문제의 6번 서브태스크를 긁다가 대회가 끝났습니다. 나중에 들은 바에 따르면 4번 문제는 방역과 비슷했다고 합니다.

 

#include <tuple>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n, m, t, in[100006], depth[100006];
long long dp[100006], mx[100006], res, tot[100006];
vector<int> adj[100006];
vector<tuple<int, int, long long>> tag[100006];

void dfs(int x, int prev = -1) {
    in[x] = t++;
    for (auto &i: adj[x]) if (i != prev) dfs(i, x);
}

void dfs_sz(int x, int prev = -1) {
    for (auto &i: adj[x]) if (i != prev) {
        depth[in[i]] = depth[in[x]] + 1;
        dfs_sz(i, x);
    }
}

int tail(int x, int prev = -1) {
    if ((int)adj[x].size() == 1) return x;
    return tail(adj[x][0] + adj[x][1] - prev, x);
}

void solve(int x, int y) {
    long long curr = 0;
    vector<tuple<int, int, int>> v;
    for (int i = y; i >= x; i--) for (auto &j: tag[i]) v.push_back(j);
    sort(v.begin(), v.end(), [](const tuple<int, int, int> &a, const tuple<int, int, int> &b) {
        if (get<0>(a) + get<1>(a) != get<0>(b) + get<1>(b)) return get<0>(a) + get<1>(a) > get<0>(b) + get<1>(b);
        return get<0>(a) - get<1>(a) > get<0>(b) - get<1>(b);
    });
    int k = 0;
    for (int i = y; i >= x; i--) {
        while (k < (int)v.size() && i <= get<0>(v[k]) + get<1>(v[k])) {
            if (get<0>(v[k]) - get<1>(v[k]) >= x) dp[get<0>(v[k]) - get<1>(v[k])] = max(dp[get<0>(v[k]) - get<1>(v[k])], get<2>(v[k]) + curr);
            k++;
        }
        curr = max(curr, dp[i]);
        mx[i] = curr;
    }
    k = 0, curr = 0;
    for (int i = y; i >= x; i--) {
        while (k < (int)v.size() && i <= get<0>(v[k]) + get<1>(v[k])) {
            if (get<0>(v[k]) - get<1>(v[k]) < x) tag[0].push_back({ 0, get<1>(v[k]) - depth[get<0>(v[k])], get<2>(v[k]) + curr - mx[x + get<1>(v[k]) - depth[get<0>(v[k])]]});
            k++;
        }
        curr = max(curr, dp[i]);
        tot[i - x + 1] += mx[i - x + 1];
    }
    res = max(res, curr);
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    int x = 0;
    for (int i = 0; i < n; i++) if ((int)adj[i].size() > 2) x = i;
    dfs(x);
    dfs_sz(x);
    for (int i = 0; i < m; i++) {
        int c, d, g;
        scanf("%d%d%d", &c, &d, &g);
        c--;
        tag[in[c]].push_back({ in[c], d, g });
    }
    for (auto &i: adj[x]) solve(in[i], in[tail(i, x)]);
    for (auto &i: tag[0]) {
        res = max(res, get<2>(i) + tot[get<1>(i) + 1]);
        if (get<1>(i) < 0) return 1;
    }
    printf("%lld", max(res, tot[1]));
}

 

소감

 아무래도 아쉬움이 큰 대회였습니다. 대회 전에는 300점을 넘길 생각을 하고 있었는데, 성적은 생각만큼 나오지 않아서 많이 아쉬웠습니다. 가장 아쉬운 점은 앞에서 말했듯 17점을 긁지 못한 점이었습니다. 다만 운이 좋으면 은상을 받을 수 있을 것 같습니다. 옆 블로그에 누구누구는 PS를 시작한 지 한 해 만에 무서운 기록들을 세우던데, 저는 PS를 한 기간에 비해 원하는 만큼 성과가 나오지 않아서 슬프네요. 그래도 이번 대회를 더 빡세게 공부하는 계기로 삼아, 다음번에는 금상을 받을 수 있게 해야겠습니다. 이번 결과는 당분간 잊고, 돈 주는 NYPC 준비에 전념해야겠네요.

 

결과

은상을 받았습니다.

 총점 278점을 얻으며 은상을 받았습니다. 금상 커트라인은 300점으로 밝혀지며 억울함은 풀었네요. 그 외에도 대상 커트라인은 325점, 은상 커트라인은 246점이었다고 합니다. 또, 동상 커트라인은 134점, 장려상 커트라인은 41점이었다고 하네요.

 

헬기 착륙장

문제

서브태스크 1

 \(a\leq6\), \(b\leq6\)이므로 \(a+b\leq12\)입니다. 따라서 헬기 착륙장의 반지름 \(k\leq4\)입니다. 헬기 착륙장의 반지름이 작으므로 가능한 모든 경우인 \(\sum_{i=1}^k2^i\) 가지를 모두 보면 됩니다. 문제를 \(O(Tk^22^k)\) 시간 안에 해결할 수 있습니다.

#include <cstdio>

int a, b, cnt;

int main() {
    scanf("%*d%d%d", &a, &b);
    for (int c = 1; c <= 5; c++) {
        for (int i = 0; i < 1 << c; i++) {
            int x = 0, y = 0;
            for (int j = 0; j < c; j++) {
                if (i & 1 << j) x += j + 1;
                else y += j + 1;
            }
            if (x <= a && y <= b) cnt++;
        }
    }
    printf("%d", cnt);
}

 

서브태스크 2

\(a\leq100\), \(b\leq100\)이므로 \(O(T\mathrm{max}(a,b)^3)\) 풀이로 충분합니다. \(f_{i,\,j,\,k}\)를 반지름이 \(i\)이고 빨강 페인트를 \(a\) 통, 파랑 페인트를 \(b\) 통 사용하는 헬기 착륙장의 개수로 정의하면 다음 점화식이 세워집니다.

$$f_{i,\,j,\,k}=\begin{cases}0&(\text{$j<0$ or $k<0$})\\f_{i-1,\,j-i,\,k}+f_{i-1,\,j,\,k-i}&(\text{otherwise})\end{cases}\,(i\geq1)$$

그리고 편의상

$$f_{0,\,i,\,j}=\begin{cases}1&(\text{$j=0$ and $k=0$})\\0&(\text{otherwise})\end{cases}$$

문제의 답은 \(\sum_{i>0,\,j\leq a,\,k\leq b} f_{i,\,j,\,k}\)입니다. 따라서 문제를 \(O(Tab(a+b))\) 시간 안에 해결할 수 있습니다.

#include <cstdio>
#include <algorithm>
using namespace std;

const int MOD = 1e9 + 7;

int a, b, dp[206][106][106], res;

int main() {
    dp[0][0][0] = 1;
    scanf("%*d%d%d", &a, &b);
    for (int i = 1; i <= a + b; i++) for (int x = 0; x <= a; x++) for (int y = 0; y <= b; y++) {
        if (x >= i) dp[i][x][y] = (dp[i][x][y] + dp[i - 1][x - i][y]) % MOD;
        if (y >= i) dp[i][x][y] = (dp[i][x][y] + dp[i - 1][x][y - i]) % MOD;
        res = (res + dp[i][x][y]) % MOD;
    }
    printf("%d", res);
}

 

서브태스크 3

 사실, 서브태스크 2에서 문제의 답은 \(\sum_{0<i\leq\lfloor2(a+b)\rfloor,\,j\leq a,\,k\leq b} f_{i,\,j,\,k}\)입니다. 헬기 착륙장의 최대 반지름 \(k\)에 대하여 \(a+b\leq\frac{1}{2}k(k+1)\)이므로 \(k=\lfloor\sqrt{2(a+b)}\rfloor\)이기 때문입니다. 따라서, 문제를 \(O(Tab\sqrt{a+b})\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <algorithm>
using namespace std;

const int MOD = 1e9 + 7;

int a, b, dp[64][1006][1006], res;

int main() {
    dp[0][0][0] = 1;
    scanf("%*d%d%d", &a, &b);
    for (int i = 1; i < 64; i++) for (int x = 0; x <= a; x++) for (int y = 0; y <= b; y++) {
        if (x >= i) dp[i][x][y] = (dp[i][x][y] + dp[i - 1][x - i][y]) % MOD;
        if (y >= i) dp[i][x][y] = (dp[i][x][y] + dp[i - 1][x][y - i]) % MOD;
        res = (res + dp[i][x][y]) % MOD;
    }
    printf("%d", res);
}

 

서브태스크 4

 한 헬기 착륙장의 반지름 \(k\)가 정해지고 사용한 빨강 페인트 통의 수 \(A\)가 정해지면, 사용한 파랑 페인트 통의 수 \(b\)가 정해집니다. 따라서 \(f_{i,\,j}\)를 반지름 \(k=i\), 사용한 빨강 페인트 통의 수 \(A=j\)인 헬기 착륙장의 수로 두면 다음 점화식이 세워집니다.

$$f_{i,\,j}=\begin{cases}0&(\text{$j<0$ or $j>\frac{1}{2}i(i+1)$})\\f_{i-1,j}+f_{i-1,j-i}&(\text{otherwise})\end{cases}\,(i\geq1)$$

그리고 편의상

$$f_{0,\,j}=\begin{cases}1&(\text{$j=0$})\\0&(\text{otherwise}))\end{cases}$$

서브태스크 3과 마찬가지로, 문제의 답은 \(\sum_{0<i\leq\lfloor\sqrt{2(a+b)}\rfloor,\,\frac{1}{2}i(i+1)-b\leq j\leq a}f_{i,\,j}\)입니다. 따라서 문제를 \(O(Ta\sqrt{2(a+b)}\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <algorithm>
using namespace std;

const int MOD = 1e9 + 7;

int a, b, dp[156][10006], res;

int main() {
    dp[0][0] = 1;
    scanf("%*d%d%d", &a, &b);
    for (int i = 1; i < 156; i++) for (int j = 0; j <= a; j++) {
        dp[i][j] = dp[i - 1][j];
        if (i <= j) dp[i][j] = (dp[i][j] + dp[i - 1][j - i]) % MOD;
        if (j >= i * (i + 1) / 2 - b) res = (res + dp[i][j]) % MOD;
    }
    printf("%d", res);
}

 

서브태스크 5

 \(f_{i,\,j}\)를 반지름 \(j\)인 빨간 동심원이 최대의 동심원이고 사용한 빨강 페인트 통의 수가 \(i\)인 헬기 착륙장의 수라고 하면 다음 점화식이 세워집니다.

$$f_{i,\,j}=\begin{cases}0&(i<j)\\\sum_{1\\leq k\leq j-1}f_{i-j,\,k}&(i\geq j)\end{cases}\,(i\geq1)$$

그리고 편의상

$$f_{0,\,i}=\begin{cases}1&(i=0)\\0&(i\neq0)\end{cases}$$

 또한, \(f_{i,\,j}\)는 대칭적으로 반지름 \(j\)인 파란 동심원이 최대의 동심원이고 사용한 파랑 페인트 통의 수가 \(i\)인 헬기 착륙장의 수와 서로 같습니다. 이제 \(O(\mathrm{max}(a,b)\sqrt{2(a+b)})\) 시간 안에 동적 계획표 표를 모두 구성할 수 있습니다.

 이제 각 테스트 케이스에 대해 가장 큰 동심원의 반지름 \(k\)가 될 수 있는 범위 \(k\in[1,\lfloor\sqrt{2(a+b)}\rfloor]\)에 대해 모두 생각합니다. 또, 각 테스트 케이스에서 가장 큰 동심원이 빨간 경우와 파란 경우를 각각 따져 줍니다. 일반성을 잃지 않고 가장 큰 동심원이 빨갛다고 합시다. 이제 반지름이 \(k\)인 헬기 착륙장에서 사용한 파랑 페인트 통의 수 \(b\)는 \([\mathrm{max}(0,\frac{1}{2}k(k-1)), \mathrm{min}(b,\frac{1}{2}k(k-1))]\) 범위 안에 있습니다. 따라서 각 테스트 케이스에서 \(\sum_{\mathrm{max}(0,\frac{1}{2}k(k-1))\leq i\leq\mathrm{min}(b,\frac{1}{2}k(k-1))}f_{i,\,k-1}\)의 값을 빠르게 구할 수 있다면 전체 문제를 해결할 수 있습니다. 이는 \(f\)의 누적 합 배열 \(g\)에 대하여 다음과 같은 점화식을 이용하여 알아낼 수 있습니다.

$$g_{i,\,j}=\begin{cases}0&(\text{$i<0$ or $j<0$})\\g_{i-1,\,j}+g_{i,\,j-1}-g_{i-1,\,j-1}&(\text{otherwise})\end{cases}$$

이제 각 테스트 케이스에서 \(g_{\mathrm{max}(0,\frac{1}{2}k(k-1))}-g_{\mathrm{min}(b,\frac{1}{2}k(k-1))}\)의 값을 \(O(1)\) 시간 안에 구할 수 있으므로, 전체 문제를 \(O((\mathrm{max}(a,b)+T)\sqrt{2(a+b)})\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <algorithm>
using namespace std;

const int MOD = 1e9 + 7;

int t, a, b, dp[100006][506], sum[100006][506];

int f(int x, int y) {
    if (x < 0) return 0;
    return dp[x][y];
}

int main() {
    dp[0][0] = dp[1][1] = 1;
    for (int i = 0; i < 506; i++) sum[0][i] = 1;
    for (int i = 1; i < 506; i++) sum[1][i] = 1;
    for (int i = 2; i < 100006; i++) for (int j = 1; j < 506; j++) {
        if (i >= j) dp[i][j] = sum[i - j][j - 1];
        sum[i][j] = (dp[i][j] + sum[i][j - 1]) % MOD;
    }
    for (int i = 0; i < 100006; i++) for (int j = 0; j < 506; j++) {
        if (i) dp[i][j] = (dp[i][j] + dp[i - 1][j]) % MOD;
        if (j) dp[i][j] = (dp[i][j] + dp[i][j - 1]) % MOD;
        if (i && j) dp[i][j] = (dp[i][j] - dp[i - 1][j - 1] + MOD) % MOD;
    }
    for (scanf("%d", &t); t--; ) {
        scanf("%d%d", &a, &b);
        int res = 0;
        for (int i = 1; i < 500; i++) {
            if (a >= i) {
                int x = a - i;
                int _min = max(0, i * (i - 1) / 2 - x);
                int _max = min(b, i * (i - 1) / 2);
                if (_min <= _max) {
                    res = (res + f(_max, i - 1)) % MOD;
                    res = (res - f(_min - 1, i - 1) + MOD) % MOD;
                }
            }
            if (b >= i) {
                int x = b - i;
                int _min = max(0, i * (i - 1) / 2 - x);
                int _max = min(a, i * (i - 1) / 2);
                if (_min <= _max) {
                    res = (res + f(_max, i - 1)) % MOD;
                    res = (res - f(_min - 1, i - 1) + MOD) % MOD;
                }
            }
        }
        printf("%d\n", res);
    }
}

 

그래프 균형 맞추기

문제

서브태스크 1

 정점 2와 정점 3 사이를 가중치가 \(c_1\)인 간선이, 정점 3과 정점 1 사이를 가중치가 \(c_2\)인 간선이, 정점 1과 정점 2 사이를 가중치가 \(c_3\)인 간선이 연결한다고 합시다. 이때 정점 1의 가중치 \(x\), 정점 2의 가중치 \(y\), 정점 3의 가중치 \(z\)에 대하여 다음이 성립합니다.

$$x+y=c_3,\;y+z=c_1,\;z+x=c_2$$

따라서,

$$x=\frac{c_2+c_3-c_1}{2}\\\\,\;y=\frac{c_3+c_1-c_2}{2},\;z=\frac{c_1+c_2-c_3}{2}$$

와 같이 세 가중치 \(x\), \(y\), \(z\)가 유일하게 정해집니다. \(c_1+c_2+c_3\)이 홀수라면 그래프의 균형이 맞도록 정수 가중치를 부여하는 방법은 없습니다.

 

#include <cstdio>

int n = 3, m = 3;
int a[3];

int main() {
    scanf("%*d%*d");
    for (int i = 0; i < m; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        u--, v--;
        a[3 - u - v] = w;
    }
    if ((a[0] + a[1] + a[2]) % 2) return puts("No"), 0;
    puts("Yes");
    printf("%d %d %d", (a[1] + a[2] - a[0]) / 2, (a[2] + a[0] - a[1]) / 2, (a[0] + a[1] - a[2]) / 2);
}

 

서브태스크 2

 정점들의 가중치가 유일하게 정해지지 않습니다. 간선들의 조건을 만족하는 정점들의 가중치를 정했다고 합시다. 정점들의 가중치는 \(x_1,\,x_2,\,\cdots,\,x_n\)입니다. 여기서 \(\sum|x_i|\)의 값을 최소로 해야 합니다. 이때 \(x_i=0\)인 \(i\)가 존재해도 최적해가 됩니다. 그런데 어떤 정점의 가중치가 정해지면 나머지 정점의 가중치가 모두 정해집니다. 따라서, 모든 \(N\) 개의 정점에 대해 \(x_i=0\)으로 두고, 나머지 정점들의 가중치를 계산합니다. \(\sum|x_i|\)의 최솟값이 답이 됩니다. 문제를 \(O(N^2)\) 시간에 해결할 수 있습니다.

 

#include <cstdio>
#include <utility>
#include <algorithm>
using namespace std;

int n, a[1006];
long long w[1006];
pair<long long, int> res = { (long long)9e18, -1 };

int main() {
    scanf("%d%*d", &n);
    for (int i = 0; i < n - 1; i++) scanf("%*d%*d%d", a + i);
    for (int i = 0; i < n; i++) {
        w[i] = 0;
        for (int j = i - 1; j >= 0; j--) w[j] = a[j] - w[j + 1];
        for (int j = i + 1; j < n; j++) w[j] = a[j - 1] - w[j - 1];
        long long curr = 0;
        for (int j = 0; j < n; j++) curr += w[j] > 0 ? w[j] : -w[j];
        res = min(res, pair{ curr, i });
    }
    puts("Yes");
    w[res.second] = 0;
    for (int i = res.second - 1; i >= 0; i--) w[i] = a[i] - w[i + 1];
    for (int i = res.second + 1; i < n; i++) w[i] = a[i - 1] - w[i - 1];
    for (int i = 0; i < n; i++) printf("%lld ", w[i]);
}

 

서브태스크 3

 \(i\)번 간선의 가중치를 \(a_i\)로 둡시다. \(1\)번 정점의 가중치를 \(x\)라고 두면, \(2\)번 정점의 가중치 \(a_1-x\), \(3\)번 정점의 가중치 \(a_2-a_1+x\), …, \(n\)번 정점의 가중치 \(a_n-a_{n-1}+\cdots+(-1)^nx\)가 모두 정해집니다. 따라서 최소로 하는 값은 \(f(x)=|x|+|x-a_1|+|x-a_1+a_2|+\cdots+|x-a_1+a_2-\cdots+(-1)^na_n|\)입니다. 간단히, \(c_1\leq c_2\leq\cdots\leq c_n\)에 대해 \(f(x)=|x-c_1|+|x-c_2|+\cdots+|x-c_n|\)가 되게 하는 \(c_i\)들을 \(O(N\log N)\) 시간에 찾을 수 있습니다. 이때 \(x\)의 범위에 따른 \(f(x)\)는 기울기가 음수에서 양수로 점차 증가하는 꺾인 직선 모양의 개형을 보입니다. \(x<c_{\lfloor\frac{n+1}{2}\rfloor}\)에서 \(f(x)\)의 기울기는 음수, \(x>c_{\lceil\frac{n+1}{2}\rceil}\)에서 \(f(x)\)의 기울기는 양수이고 \(c_{\lfloor\frac{n+1}{2}\rfloor}\leq x\leq c_{\lceil\frac{n+1}{2}\rceil}\)에서 \(f(x)\)의 기울기는 \(0\)이므로, \(f(x)\)는 \(x=c_{\lfloor\frac{n+1}{2}\rfloor}\)에서 최소입니다. 따라서 모든 정점의 가중치를 \(O(N)\) 시간 안에 구할 수 있고, 문제를 \(O(N\log N)\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <algorithm>
using namespace std;

int n, a[100006];
long long w[100006], c[100006];

int main() {
    scanf("%d%*d", &n);
    for (int i = 0; i < n - 1; i++) scanf("%*d%*d%d", a + i);
    for (int i = 1; i < n; i++) {
        w[i] = w[i - 1];
        if (i % 2) w[i] += a[i - 1];
        else w[i] -= a[i - 1];
    }
    for (int i = 0; i < n; i++) c[i] = w[i];
    sort(c, c + n);
    puts("Yes");
    for (int i = 0; i < n; i++) {
        if (i % 2) printf("%lld ", w[i] - c[n / 2]);
        else printf("%lld ", c[n / 2] - w[i]);
    }
}

 

서브태스크 4

 서브태스크 3과 같은 관찰이 성립합니다. \(1\)번 정점의 가중치를 \(x-a_1\,(a_1=0)\)라고 두면, 트리를 탐색하며 \(i(>1)\)번 정점의 가중치 \(x-a_i\) 또는 \(-x+a_i\)를 \(O(N)\) 시간 안에 구할 수 있고, 마찬가지로 \(f(x)=|x-c_1|+|x-c_2|+\cdots+|x-c_n|\)가 \(x=c_{\lfloor\frac{n+1}{2}\rfloor}\)에서 최소이므로 모든 정점의 가중치를 \(O(N)\) 시간 안에 구할 수 있습니다. 문제를 \(O(N\log N)\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int n;
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];

void dfs(int x, int prev = -1) {
    for (auto &i: adj[x]) if (i.first != prev) {
        a[i.first] = { -a[x].first, i.second - a[x].second };
        dfs(i.first, x);
    }
}

int main() {
    scanf("%d%*d", &n);
    for (int i = 0; i < n - 1; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        u--, v--;
        adj[u].push_back({ v, w });
        adj[v].push_back({ u, w });
    }
    a[0] = { 1, 0 };
    dfs(0);
    for (int i = 0; i < n; i++) v.push_back(-a[i].first * a[i].second);
    sort(v.begin(), v.end());
    long long x = v[(int)v.size() / 2];
    puts("Yes");
    for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}

 

서브태스크 5

 그래프는 정확히 하나의 큰 사이클을 이룹니다. \(N\)이 홀수라면 정점들의 가중치는 유일하게 결정되며, \(N\)이 짝수라면 정점들의 가중치가 유일하게 결정되지 않습니다. 사이클 위의 정점들을 순서대로 \(1,\;2,\;\cdots,\;N\)번 정점이라고 합시다.

 · \(N\)이 홀수라면 \(1\)번 정점의 가중치를 \(x-c_1\,(c_1=0)\)이라고 할 때 \(k\)번 정점의 가중치는 \(x-c_k\)입니다. 따라서 \(1\)번 정점과 \(k\)번 정점을 연결하는 간선의 가중치 \(p\)에 대하여 \(x=\frac{c_1+c_k+p}{2}\)입니다. 이제 모든 정점의 가중치를 유일하게 결정할 수 있고, \(x\)가 정수가 아니라면 그래프의 균형이 맞도록 정점에 정수 가중치를 부여하는 방법은 없습니다.

 · \(N\)이 짝수라면 \(1\)번 정점의 가중치를 \(x-c_1\,(c_1=0)\)이라고 할 때 \(k\)번 정점의 가중치는 \(-x+c_k\)입니다. 따라서 \(1\)번 정점과 \(k\)번 정점을 연결하는 간선의 가중치 \(p\)에 대하여 \(c_k-c_1\neq p\)라면 그래프의 균형이 맞도록 정점에 정수 가중치를 부여하는 방법이 없고, 아니라면 서브태스크 4와 같이 문제를 해결할 수 있습니다.

 따라서 문제를 \(O(N\log N)\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int n;
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
pair<int, int> y;
bool visited[100006];

void dfs(int x, int prev = -1) {
    visited[x] = true;
    for (auto &i: adj[x]) if (i.first != prev) {
        if (!i.first) y = { x, i.second };
        if (!visited[i.first]) {
            a[i.first] = { -a[x].first, i.second - a[x].second };
            dfs(i.first, x);
        }
    }
}

int main() {
    scanf("%d%*d", &n);
    for (int i = 0; i < n; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        u--, v--;
        adj[u].push_back({ v, w });
        adj[v].push_back({ u, w });
    }
    a[0] = { 1, 0 };
    dfs(0);
    if (n % 2) {
        if ((y.second - a[y.first].second) % 2) return puts("No"), 0;
        puts("Yes");
        long long x = (y.second - a[y.first].second) / 2;
        for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
        return 0;
    }
    if (y.second != a[y.first].second) return puts("No"), 0;
    for (int i = 0; i < n; i++) v.push_back(-a[i].first * a[i].second);
    sort(v.begin(), v.end());
    long long x = v[(int)v.size() / 2];
    puts("Yes");
    for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}

 

서브태스크 6

 \(1\)번 정점의 가중치를 \(x-c_1\,(c_1=0)\)이라고 두면 DFS 트리를 탐색하며 나머지 \(N-1\) 개 정점의 가중치를 \(x-c_i\) 또는 \(-x+c_i\)로 유일하게 결정할 수 있습니다. 이후 그래프상에서 \(u\)번 정점과 \(v\)번 정점을 연결하는 가중치 \(p\)의 간선에서 각각 다음을 수행합니다.

 · \(u\)번 정점의 가중치가 \(x-c_u\)이고 \(v\)번 정점의 가중치가 \(x-c_v\)라면 \(x=\frac{c_u+c_v+p}{2}\)입니다.

 · \(u\)번 정점의 가중치가 \(x-c_u\)이고 \(v\)번 정점의 가중치가 \(-x+c_v\)라면 \(c_v-c_u=p\)이어야 합니다.

 · \(u\)번 정점의 가중치가 \(-x+c_u\)이고 \(v\)번 정점의 가중치가 \(x-c_v\)라면 \(c_u-c_v=p\)이어야 합니다.

 · \(u\)번 정점의 가중치가 \(-x+c_u\)이고 \(v\)번 정점의 가중치가 \(-x+c_v\)라면 \(x=\frac{c_u+c_v-p}{2}\)입니다.

 이때 \(x\)의 값에 대해 모순이 발생하거나 조건이 성립하지 않을 때, 그리고 \(x\)의 값이 정수가 아닐 때 그래프의 균형이 맞도록 정점에 정수 가중치를 부여하는 방법이 없습니다. 위 검사를 모두 통과했다고 합시다. \(x\)의 값이 정해졌다면 나머지 정점의 가중치를 유일하게 결정할 수 있습니다. 아니라면, 서브태스크 2와 같은 관찰이 성립합니다. 각 \(i\)마다 \(x_i=0\)으로 두고 나머지 정점의 가중치를 유일하게 결정하면 문제를 \(O(N^2+M)\) 시간 안에 해결할 수 있습니다. 

 

#include <cmath>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int n, m, res[100006];
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
bool visited[100006];

void dfs(int x) {
    visited[x] = true;
    for (auto &i: adj[x]) if (!visited[i.first]) {
        a[i.first] = { -a[x].first, i.second - a[x].second };
        dfs(i.first);
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        u--, v--;
        adj[u].push_back({ v, w });
        adj[v].push_back({ u, w });
    }
    a[0] = { 1, 0 };
    dfs(0);
    for (int i = 0; i < n; i++) for (auto &j: adj[i]) if (i < j.first) {
        if (a[i].first != a[j.first].first) {
            if (a[i].second + a[j.first].second == j.second) continue;
            puts("No");
            return 0;
        } else v.push_back(2 * (j.second - a[j.first].second - a[i].second) / (a[i].first + a[j.first].first));
    }
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
    if ((int)v.size() > 1) {
        puts("No");
        return 0;
    }
    if ((int)v.size() == 1) {
        if (v[0] % 2) puts("No");
        else {
            puts("Yes");
            long long x = v[0] / 2;
            for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
        }
        return 0;
    }
    pair<long long, int> res = { (long long)9e18, -1 };
    for (int i = 0; i < n; i++) {
        long long curr = 0, x = -a[i].first * a[i].second;
        for (int j = 0; j < n; j++) curr += llabs(a[j].first * x + a[j].second);
        res = min(res, { curr, i });
    }
    puts("Yes");
    for (int i = 0; i < n; i++) printf("%lld ", -a[i].first * a[res.second].first * a[res.second].second + a[i].second);
}

 

서브태스크 7

 서브태스크 6과 같이 시행하되, 최솟값을 서브태스크 3과 같이 찾아 문제를 \(O(N\log N+M)\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int n, m, res[100006];
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
bool visited[100006];

void dfs(int x) {
    visited[x] = true;
    for (auto &i: adj[x]) if (!visited[i.first]) {
        a[i.first] = { -a[x].first, i.second - a[x].second };
        dfs(i.first);
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        u--, v--;
        adj[u].push_back({ v, w });
        adj[v].push_back({ u, w });
    }
    a[0] = { 1, 0 };
    dfs(0);
    for (int i = 0; i < n; i++) for (auto &j: adj[i]) if (i < j.first) {
        if (a[i].first != a[j.first].first) {
            if (a[i].second + a[j.first].second == j.second) continue;
            puts("No");
            return 0;
        } else v.push_back(2 * (j.second - a[j.first].second - a[i].second) / (a[i].first + a[j.first].first));
    }
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
    if ((int)v.size() > 1) {
        puts("No");
        return 0;
    }
    if ((int)v.size() == 1) {
        if (v[0] % 2) puts("No");
        else {
            puts("Yes");
            long long x = v[0] / 2;
            for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
        }
        return 0;
    }
    for (int i = 0; i < n; i++) v.push_back(-a[i].first * a[i].second);
    sort(v.begin(), v.end());
    long long x = v[(int)v.size() / 2];
    puts("Yes");
    for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}

 

가장 긴 공통 괄호 문자열

문제

서브태스크 1

 문자열 \(A\)에서 모든 부분 문자열 \(O(N^2)\) 개, 그리고 문자열 \(B\)에서 모든 부분 문자열 \(O(M^2)\) 개에 대해 각각이 서로 같은 올바른 괄호열인지 확인합니다. 이때 서로 같은 괄호열의 길이가 서로 같음을 이용하여 문제를 \(O(N^2M\mathrm{min}(N,M))\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

int t, n, m;
char a[106], b[106];

int main() {
    for (scanf("%d", &t); t--; ) {
        int res = 0;
        scanf("%s%s", a, b);
        n = strlen(a), m = strlen(b);
        for (int i = 0; i < n; i++) for (int j = i + 1; j < n; j += 2) for (int x = 0; x < m; x++) {
            int y = j - i + x;
            if (y >= m) continue;
            int r = 0;
            for (int k = i; k <= j; k++) {
                r += a[k] == '(' ? 1 : -1;
                if (r < 0) goto next;
            }
            if (r) goto next;
            r = 0;
            for (int k = x; k <= y; k++) {
                r += b[k] == '(' ? 1 : -1;
                if (r < 0) goto next;
            }
            if (r) goto next;
            for (int k = i; k <= j; k++) if (a[k] != b[k - i + x]) goto next;
            res = max(res, j - i + 1);
            next:;
        }
        printf("%d\n", res);
    }
}

 

서브태스크 2

 서브태스크 1에서 문자열 \(A\)의 올바른 괄호열의 집합 \(a\), 문자열 \(B\)의 올바른 괄호열의 집합 \(b\)에 대해 \(a\)와 \(b\)를 각각 \(O(N^2\log N+M^2\log M)\) 시간 안에 정렬한 뒤 선형 시간에 비교합니다. 문제를 \(O(N^3+M^3)\) 시간 안에 해결할 수 있습니다. 상수가 작아서 통과합니다.

 

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
#include <algorithm>
using namespace std;

int t, n, m;
char a[1006], b[1006];

int main() {
    for (scanf("%d", &t); t--; ) {
        vector<string> A, B;
        int res = 0;
        scanf("%s%s", a, b);
        n = strlen(a), m = strlen(b);
        for (int i = 0; i < n; i++) for (int j = i + 1; j < n; j += 2) {
            string s;
            int r = 0;
            for (int k = i; k <= j; k++) {
                r += a[k] == '(' ? 1 : -1;
                s += a[k];
                if (r < 0) goto next;
            }
            if (r) goto next;
            A.push_back(s);
            next:;
        }
        for (int i = 0; i < m; i++) for (int j = i + 1; j < m; j += 2) {
            string s;
            int r = 0;
            for (int k = i; k <= j; k++) {
                r += b[k] == '(' ? 1 : -1;
                s += b[k];
                if (r < 0) goto next2;
            }
            if (r) goto next2;
            B.push_back(s);
            next2:;
        }
        sort(A.begin(), A.end());
        sort(B.begin(), B.end());
        int k = 0;
        for (int i = 0; i < (int)A.size(); i++) {
            while (k < (int)B.size() && B[k] < A[i]) k++;
            if (k < (int)B.size() && B[k] == A[i]) res = max(res, (int)A[i].size());
        }
        printf("%d\n", res);
    }
}

 

서브태스크 3

 문자열 \(A\)에서 모든 올바른 괄호열을 하나의 정수로 해싱합니다. 그런 정수들의 집합을 \(a\)라고 하면 \(a\)의 크기는 \(O(N^2)\)입니다. 또, 문자열 \(B\)에서 모든 올바른 괄호열을 하나의 정수로 해싱합니다. 그런 정수들의 집합을 \(b\)라고 하면 \(b\)의 크기는 \(O(M^2)\)입니다. \(a\)와 \(b\)를 각각 \(O(N^2\log N+M^2\log M)\) 시간 안에 정렬한 뒤 선형 시간에 비교하여 서로 같은 괄호열들의 목록을 얻을 수 있습니다. 이 목록에서 가장 긴 괄호열이 문제의 답이 됩니다. 문제를 \(O(N^2\log N+M^2\log M)\) 시간 안에 해결할 수 있습니다. 상수가 많이 작아서 통과합니다.

 

#include <cstdio>
#include <vector>
#include <utility>
#include <cstring>
#include <algorithm>
using namespace std;

const int MOD = 1e9 + 7;

int t, n, m;
char a[500006], b[500006];
vector<pair<int, long long>> x, y;

int main() {
    for (scanf("%d", &t); t--; ) {
        scanf("%s%s", a, b); x.clear(); y.clear();
        n = strlen(a), m = strlen(b);
        for (int i = 0; i < n; i++) {
            int r = 0;
            long long curr = 0;
            for (int j = i; j < n; j++) {
                r += a[j] == '(' ? 1 : -1;
                curr = 2 * curr % MOD;
                if (a[j] == '(') curr = (curr + 1) % MOD;
                if (r < 0) break;
                if (r == 0) x.push_back({ j - i + 1, curr });
            }
        }
        for (int i = 0; i < m; i++) {
            int r = 0;
            long long curr = 0;
            for (int j = i; j < m; j++) {
                r += b[j] == '(' ? 1 : -1;
                curr = 2 * curr % MOD;
                if (b[j] == '(') curr = (curr + 1) % MOD;
                if (r < 0) break;
                if (r == 0) y.push_back({ j - i + 1, curr });
            }
        }
        sort(x.begin(), x.end(), [](const pair<int, long long> &a, const pair<int, long long> &b) {
            if (a.first != b.first) return a > b;
            return a.second < b.second;
        });
        sort(y.begin(), y.end(), [](const pair<int, long long> &a, const pair<int, long long> &b) {
            if (a.first != b.first) return a > b;
            return a.second < b.second;
        });
        int k = 0, res = 0;
        for (int i = 0; i < (int)x.size(); i++) {
            while (k < (int)y.size() && (y[k].first > x[i].first || y[k].first == x[i].first && y[k].second < x[i].second)) k++;
            while (k < (int)y.size() && y[k] == x[i]) res = max(res, x[i].first), k++;
        }
        printf("%d\n", res);
    }
}

 

서브태스크 4

 \(i\)째 문자의 이전에 등장하는 ('('의 개수) - (')'의 개수)의 값을 \(r_i\)로 정의합시다. 이제 각 \(i\)에 대하여, \(r_i=r_k\;(k>i)\)인 최소의 \(k\)를 \(b_i\)라 합시다. 단, 그러한 \(k\)가 존재하지 않거나 \(i\)째 문자부터 \(k-1\)째 문자가 올바른 괄호열을 이루지 않는다면 \(b_i=-1\)입니다. 이는 세그먼트 트리를 이용해 판별할 수 있습니다. \(b_i=k\)인 \(i\)가 존재하지 않는 \(k\)들을 관리하며 문자열 \(A\)의 올바른 괄호열의 최대 길이를 얻을 수 있습니다. 문제를 \(O(N\log N)\) 시간 안에 해결할 수 있습니다.

 

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;

int tree[1048576];

int update(int i, int b, int e, int p, int v) {
    if (p < b || e < p) return tree[i];
    if (b == e) return tree[i] = v;
    int m = (b + e) / 2;
    return tree[i] = min(update(i * 2 + 1, b, m, p, v), update(i * 2 + 2, m + 1, e, p, v));
}

int query(int i, int b, int e, int l, int r) {
    if (r < l || r < b || e < l) return 1e9;
    if (l <= b && e <= r) return tree[i];
    int m = (b + e) / 2;
    return min(query(i * 2 + 1, b, m, l, r), query(i * 2 + 2, m + 1, e, l, r));
}

int t, n, r[500006];
char a[500006];
int ina[500006], _fa[500006], *fa;
vector<int> in_a;

int main() {
    for (scanf("%d", &t); t--; ) {
        scanf("%s%*s", a);
        n = strlen(a);
        for (int i = 0; i < n; i++) {
            r[i] = (i ? r[i - 1] : 0) + (a[i] == '(' ? 1 : -1);
            update(0, 0, n - 1, i, r[i]);
        }
        int mr = *min_element(r, r + n);
        in_a.clear();
        fa = _fa - min(mr, 0);
        for (int i = 0; i <= n; i++) _fa[i] = 1e9;
        for (int i = n; i >= 0; i--) {
            int x = i ? r[i - 1] : 0;
            ina[i] = fa[x];
            fa[x] = i;
        }
        for (int i = 0; i <= n; i++) if (_fa[i] < (int)1e9) in_a.push_back(_fa[i]);
        for (int i = 0; i <= n; i++) if (ina[i] < (int)1e9 && query(0, 0, n - 1, i, ina[i] - 1) < (i ? r[i - 1] : 0)) {
            in_a.push_back(ina[i]);
            ina[i] = 1e9;
        }
        int res = 0;
        for (auto &i: in_a) {
            int x = i;
            while (ina[x] < (int)1e9) x = ina[x];
            res = max(res, x - i);
        }
        printf("%d\n", res);
    }
}

 

서브태스크 5

 문제의 답 \(k\)에 대해 이분 탐색합니다. 문제는 "길이 \(k\) 이상의 공통 괄호 문자열이 존재하는가?"의 결정 문제로 환원되었습니다. 문자열 \(A\)의 각 문자에 대해 "각 문자에서 시작하는 길이 \(k\) 이상의 문자열 중 길이가 최소인 올바른 괄호열"들을 모두 해싱해 집합 \(a\)를 얻고, 문자열 \(B\)의 각 문자에 대해 같은 방법으로 집합 \(b\)를 얻습니다. 이는 각 문자의 이전에 등장하는 ('('의 개수) - (')'의 개수)가 같은 문자들끼리 묶어 준 뒤 투 포인터를 이용해 구하는 괄호열의 후보를 얻고, 그 후보가 유효한지 세그먼트 트리 등을 이용해 확인하는 방법으로 \(O(N\log N+M\log M)\) 시간 안에 구성할 수 있습니다. 집합 \(a\)의 크기는 \(O(N)\)이며, 집합 \(b\)의 크기는 \(O(M)\)입니다. 서브태스크 3과 같은 방법으로 문제의 답을 \(O((N\log N+M\log M)\log\mathrm{min}(N,M))\) 시간 안에 얻을 수 있습니다.

 

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;

const int MOD = 1e9 + 7;

int tree1[1048576], tree2[1048576];

int update(int *tree, int i, int b, int e, int p, int v) {
    if (p < b || e < p) return tree[i];
    if (b == e) return tree[i] = v;
    int m = (b + e) / 2;
    return tree[i] = min(update(tree, i * 2 + 1, b, m, p, v), update(tree, i * 2 + 2, m + 1, e, p, v));
}

int query(int *tree, int i, int b, int e, int l, int r) {
    if (r < l || r < b || e < l) return 1e9;
    if (l <= b && e <= r) return tree[i];
    int m = (b + e) / 2;
    return min(query(tree, i * 2 + 1, b, m, l, r), query(tree, i * 2 + 2, m + 1, e, l, r));
}

int t, n, m, r[500006], s[500006], two[500006], rev[500006];
char a[500006], b[500006];
int hasha[500006], hashb[500006];
int ina[500006], inb[500006], _fa[500006], *fa, _fb[500006], *fb;
vector<int> A, B, in_a, in_b;

int get_hash(int *hs, int n, int x, int y) {
    int ret = (hs[y] - (x ? hs[x - 1] : 0) + MOD) % MOD;
    return 1LL * ret * rev[x] % MOD;
}

int main() {
    two[0] = rev[0] = 1;
    for (int i = 1; i < 500006; i++) {
        rev[i] = rev[i - 1] * 500000004LL % MOD;
        two[i] = two[i - 1] * 2 % MOD;
    }
    for (scanf("%d", &t); t--; ) {
        scanf("%s%s", a, b);
        n = strlen(a), m = strlen(b);
        for (int i = 0; i < n; i++) {
            r[i] = (i ? r[i - 1] : 0) + (a[i] == '(' ? 1 : -1);
            update(tree1, 0, 0, n - 1, i, r[i]);
            hasha[i] = ((i ? hasha[i - 1] : 0) + two[i] * (a[i] == '(' ? 1 : 0)) % MOD;
        }
        for (int i = 0; i < m; i++) {
            s[i] = (i ? s[i - 1] : 0) + (b[i] == '(' ? 1 : -1);
            update(tree2, 0, 0, m - 1, i, s[i]);
            hashb[i] = ((i ? hashb[i - 1] : 0) + two[i] * (b[i] == '(' ? 1 : 0)) % MOD;
        }
        int mr = *min_element(r, r + n), ms = *min_element(s, s + m);
        in_a.clear(), in_b.clear();
        fa = _fa - min(mr, 0);
        for (int i = 0; i <= n; i++) _fa[i] = 1e9;
        for (int i = n; i >= 0; i--) {
            int x = i ? r[i - 1] : 0;
            ina[i] = fa[x];
            fa[x] = i;
        }
        for (int i = 0; i <= n; i++) if (_fa[i] < (int)1e9) in_a.push_back(_fa[i]);
        for (int i = 0; i <= n; i++) if (ina[i] < (int)1e9 && query(tree1, 0, 0, n - 1, i, ina[i] - 1) < (i ? r[i - 1] : 0)) {
            in_a.push_back(ina[i]);
            ina[i] = 1e9;
        }
        fb = _fb - min(ms, 0);
        for (int i = 0; i <= m; i++) _fb[i] = 1e9;
        for (int i = m; i >= 0; i--) {
            int x = i ? s[i - 1] : 0;
            inb[i] = fb[x];
            fb[x] = i;
        }
        for (int i = 0; i <= m; i++) if (_fb[i] < (int)1e9) in_b.push_back(_fb[i]);
        for (int i = 0; i <= m; i++) if (inb[i] < (int)1e9 && query(tree2, 0, 0, m - 1, i, inb[i] - 1) < (i ? s[i - 1] : 0)) {
            in_b.push_back(inb[i]);
            inb[i] = 1e9;
        }
        int L = 1, R = min(n, m);
        while (L < R) {
            int M = (L + R + 1) / 2;
            A.clear(), B.clear();
            for (auto &i: in_a) {
                int curr = i, x = curr;
                while (x < (int)1e9) {
                    while (x < (int)1e9 && x - curr < M) x = ina[x];
                    if (x < (int)1e9) {
                        A.push_back(get_hash(hasha, n, curr, x - 1));
                        curr = ina[curr];
                    }
                }
            }
            for (auto &i: in_b) {
                int curr = i, x = curr;
                while (x < (int)1e9) {
                    while (x < (int)1e9 && x - curr < M) x = inb[x];
                    if (x < (int)1e9) {
                        B.push_back(get_hash(hashb, m, curr, x - 1));
                        curr = inb[curr];
                    }
                }
            }
            sort(A.begin(), A.end());
            sort(B.begin(), B.end());
            int k = 0;
            bool none = true;
            for (int i = 0; i < (int)A.size(); i++) {
                while (k < (int)B.size() && B[k] < A[i]) k++;
                if (k < (int)B.size() && A[i] == B[k]) {
                    none = false;
                    break;
                }
            }
            if (none) R = M - 1;
            else L = M;
        }
        printf("%d\n", L > 1 ? L : 0);
    }
}

 

 

맛집 추천

문제