2021 한국 정보 올림피아드 2차 대회가 2021년 7월 25일에 온라인으로 진행되었습니다. 모두 수고하셨습니다.
들어가기 전에 하소연을 좀 하겠습니다. 사실 제가 잘못한 것이라 억울할 것도 없지만, 3번 문제의 4번 서브태스크를 왜인지 모르겠지만 \(|A|=|B|\)로 봤고, 그래서 긁지 않았습니다. 긁었다면 총점 295점인데, 이것으로 상의 색깔이 바뀐다면 많이 아쉽겠습니다.
타임라인
00:00:00 - 00:40:04
1번 문제 "헬기 착륙장"을 읽고 해결했습니다. 1번으로 예상했던 난이도보다 어려워서 당황했고, 시간이 좀 오래 걸렸습니다.
#include <cstdio>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int t, a, b, dp[100006][506], sum[100006][506];
int f(int x, int y) {
if (x < 0) return 0;
return dp[x][y];
}
int main() {
dp[0][0] = dp[1][1] = 1;
for (int i = 0; i < 506; i++) sum[0][i] = 1;
for (int i = 1; i < 506; i++) sum[1][i] = 1;
for (int i = 2; i < 100006; i++) for (int j = 1; j < 506; j++) {
if (i >= j) dp[i][j] = sum[i - j][j - 1];
sum[i][j] = (dp[i][j] + sum[i][j - 1]) % MOD;
}
for (int i = 0; i < 100006; i++) for (int j = 0; j < 506; j++) {
if (i) dp[i][j] = (dp[i][j] + dp[i - 1][j]) % MOD;
if (j) dp[i][j] = (dp[i][j] + dp[i][j - 1]) % MOD;
if (i && j) dp[i][j] = (dp[i][j] - dp[i - 1][j - 1] + MOD) % MOD;
}
for (scanf("%d", &t); t--; ) {
scanf("%d%d", &a, &b);
int res = 0;
for (int i = 1; i < 500; i++) {
if (a >= i) {
int x = a - i;
int _min = max(0, i * (i - 1) / 2 - x);
int _max = min(b, i * (i - 1) / 2);
if (_min <= _max) {
res = (res + f(_max, i - 1)) % MOD;
res = (res - f(_min - 1, i - 1) + MOD) % MOD;
}
}
if (b >= i) {
int x = b - i;
int _min = max(0, i * (i - 1) / 2 - x);
int _max = min(a, i * (i - 1) / 2);
if (_min <= _max) {
res = (res + f(_max, i - 1)) % MOD;
res = (res - f(_min - 1, i - 1) + MOD) % MOD;
}
}
}
printf("%d\n", res);
}
}
00:40:04 - 01:03:24
2번 문제 "그래프 균형 맞추기"를 읽고 해결했습니다. 문제가 BOI 2020 Graph와 판박이라서 쉽게 해결할 수 있었습니다.
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;
int n, m, res[100006];
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
bool visited[100006];
void dfs(int x) {
visited[x] = true;
for (auto &i: adj[x]) if (!visited[i.first]) {
a[i.first] = { -a[x].first, i.second - a[x].second };
dfs(i.first);
}
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i < m; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
u--, v--;
adj[u].push_back({ v, w });
adj[v].push_back({ u, w });
}
a[0] = { 1, 0 };
dfs(0);
for (int i = 0; i < n; i++) for (auto &j: adj[i]) if (i < j.first) {
if (a[i].first != a[j.first].first) {
if (a[i].second + a[j.first].second == j.second) continue;
puts("No");
return 0;
} else v.push_back(2 * (j.second - a[j.first].second - a[i].second) / (a[i].first + a[j.first].first));
}
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
if ((int)v.size() > 1) {
puts("No");
return 0;
}
if ((int)v.size() == 1) {
if (v[0] % 2) puts("No");
else {
puts("Yes");
long long x = v[0] / 2;
for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}
return 0;
}
for (int i = 0; i < n; i++) v.push_back(-a[i].first * a[i].second);
sort(v.begin(), v.end());
long long x = v[(int)v.size() / 2];
puts("Yes");
for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}
01:03:24 - 01:28:55
3번 문제와 4번 문제를 읽고 풀이를 생각했습니다. 뒤이어 4번 문제의 9점 서브태스크를 스위핑 + DP로 긁었습니다.
#include <tuple>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
int n, m;
long long dp[100006];
vector<int> adj[100006];
vector<tuple<int, int, int>> v;
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i < n - 1; i++) {
int u, v;
scanf("%d%d", &u, &v);
u--, v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
for (int i = 0; i < m; i++) {
int c, d, g;
scanf("%d%d%d", &c, &d, &g);
c--;
v.push_back({ max(c - d, 0), min(c + d, n - 1), g });
}
sort(v.begin(), v.end());
long long curr = 0;
int k = 0;
for (int i = 0; i < n; i++) {
while (k < m && get<0>(v[k]) <= i) {
dp[get<1>(v[k])] = max(dp[get<1>(v[k])], curr + get<2>(v[k]));
k++;
}
curr = max(curr, dp[i]);
}
printf("%lld", curr);
}
01:28:55 - 01:39:45
3번 문제에서 해싱을 이용한 \(O(N^2\log N+M^2\log M)\) 풀이를 작성했고, 41점을 긁었습니다. 답안을 제출한 뒤 3번 서브태스크가 애매해서 시간 초과를 받을 거라고 생각하고 LCS 6과 비슷하게 23점을 긁을 궁리를 하고 있었는데, 41점이 긁혀서 꽤 놀랐습니다.
#include <cstdio>
#include <vector>
#include <utility>
#include <cstring>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int t, n, m;
char a[500006], b[500006];
vector<pair<int, long long>> x, y;
int main() {
for (scanf("%d", &t); t--; ) {
scanf("%s%s", a, b); x.clear(); y.clear();
n = strlen(a), m = strlen(b);
for (int i = 0; i < n; i++) {
int r = 0;
long long curr = 0;
for (int j = i; j < n; j++) {
r += a[j] == '(' ? 1 : -1;
curr = 2 * curr % MOD;
if (a[j] == '(') curr = (curr + 1) % MOD;
if (r < 0) break;
if (r == 0) x.push_back({ j - i + 1, curr });
}
}
for (int i = 0; i < m; i++) {
int r = 0;
long long curr = 0;
for (int j = i; j < m; j++) {
r += b[j] == '(' ? 1 : -1;
curr = 2 * curr % MOD;
if (b[j] == '(') curr = (curr + 1) % MOD;
if (r < 0) break;
if (r == 0) y.push_back({ j - i + 1, curr });
}
}
sort(x.begin(), x.end(), [](const pair<int, long long> &a, const pair<int, long long> &b) {
if (a.first != b.first) return a > b;
return a.second < b.second;
});
sort(y.begin(), y.end(), [](const pair<int, long long> &a, const pair<int, long long> &b) {
if (a.first != b.first) return a > b;
return a.second < b.second;
});
int k = 0, res = 0;
for (int i = 0; i < (int)x.size(); i++) {
while (k < (int)y.size() && (y[k].first > x[i].first || y[k].first == x[i].first && y[k].second < x[i].second)) k++;
while (k < (int)y.size() && y[k] == x[i]) res = max(res, x[i].first), k++;
}
printf("%d\n", res);
}
}
01:39:45 - 02:56:38
HLD를 이용한 4번 문제의 풀이를 찾았다고 생각하고, 신나서 100점 풀이를 짰지만 반례가 \(N\) 개 존재한다는 사실을 알게 되었고, 접었습니다.
#include <tuple>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
int n, m, sz[100006], depth[100006], par[100006], in[100006], rev[100006], t;
long long dp[100006], res;
vector<int> adj[100006], childs[100006];
vector<tuple<int, int, long long>> tag[100006];
int dfs_sz(int x, int prev = -1) {
sz[x] = 1; par[x] = prev;
for (auto &i: adj[x]) if (i != prev) {
childs[x].push_back(i);
depth[i] = depth[x] + 1;
sz[x] += dfs_sz(i, x);
if (sz[childs[x][0]] < sz[i]) swap(childs[x][0], sz[i]);
}
return sz[x];
}
void dfs_hld(int x) {
in[x] = t;
rev[t] = x;
t++;
for (auto &i: childs[x]) dfs_hld(i);
}
vector<int> lt;
vector<tuple<int, int, long long>> v;
void dfs_bond(int x, int y) {
dp[in[y] + (depth[x] - depth[y])] += dp[in[x]];
if (!childs[x].empty()) dfs_bond(childs[x][0], y);
}
void dfs_res(int x, int y) {
sort(v.begin(), v.end());
reverse(v.begin(), v.end());
long long curr = 0;
int k = 0;
for (int i = y; true; i = par[i]) {
while (k < (int)v.size() && get<0>(v[k]) >= i) {
if (in[get<1>(v[k])] < in[x]) {
if (!x) dp[0] = max(dp[0], curr + get<2>(v[k]));
else {
int inc = (get<0>(v[k]) + get<1>(v[k])) / 2;
int c = rev[inc];
int d = get<0>(v[k]) - c;
int diff = depth[c] - depth[par[x]];
int nd = d - diff;
if (in[x] + 1 <= in[y]) get<2>(v[k]) += dp[in[x] + 1];
tag[par[x]].push_back({ in[par[x]] + nd, in[par[x]] - nd, get<2>(v[k]) });
}
} else dp[in[get<1>(v[k])]] = max(dp[in[get<1>(v[k])]], curr + get<2>(v[k]));
k++;
}
curr = max(curr, dp[in[i]]);
if (i == x) break;
}
res = max(res, curr);
}
void dfs_calc(int x, int y, int z) {
lt.push_back(x);
for (auto &i: tag[x]) v.push_back(i);
for (int i = 0; i < (int)childs[x].size(); i++) dfs_bond(childs[x][i], x);
if (x == y) dfs_res(y, z);
else dfs_calc(par[x], y, z);
}
void dfs(int x, int y) {
if (childs[x].empty()) {
v.clear();
lt.clear();
dfs_calc(x, y, x);
return;
}
for (int i = 1; i < (int)childs[x].size(); i++) dfs(childs[x][i], childs[x][i]);
dfs(childs[x][0], y);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i < n - 1; i++) {
int u, v;
scanf("%d%d", &u, &v);
u--, v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs_sz(0);
dfs_hld(0);
for (int i = 0; i < m; i++) {
int c, d, g;
scanf("%d%d%d", &c, &d, &g);
c--;
tag[c].push_back({ in[c] + d, in[c] - d, g });
}
dfs(0, 0);
printf("%lld", res);
}
02:56:38 - 03:37:21
4번 문제의 \(O(N(N+M)\log N)\) 풀이를 찾아 28점을 추가로 얻었습니다.
#include <tuple>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
int n, m, par[20][100006], depth[100006];
long long dp[100006], mx[100006];
vector<int> adj[100006];
vector<tuple<int, int, int>> tag[100006];
void dfs(int x, int prev = -1) {
par[0][x] = prev;
for (auto &i: adj[x]) if (i != prev) {
depth[i] = depth[x] + 1;
dfs(i, x);
}
}
int up(int x, int y) {
for (int t = 19; t >= 0; t--) if (y >= 1 << t && x != -1) {
y -= 1 << t;
x = par[t][x];
}
if (x == -1) x = 0;
return x;
}
int dist(int x, int y) {
if (depth[x] > depth[y]) swap(x, y);
int diff = depth[y] - depth[x], l = diff;
for (int t = 19; t >= 0; t--) if (diff >= 1 << t) {
diff -= 1 << t;
y = par[t][y];
}
if (x == y) return l;
for (int t = 19; t >= 0; t--) if (par[t][x] != -1 && par[t][y] != -1 && par[t][x] != par[t][y]) {
x = par[t][x];
y = par[t][y];
l += 1 << t + 1;
}
return l + 2;
}
vector<int> v;
void dfs_t(int x, int prev = -1) {
v.push_back(x);
for (auto &i: adj[x]) if (i != prev) dfs_t(i, x);
}
long long dfs_res(int x, int prev = -1) {
for (auto &i: adj[x]) if (i != prev) dp[x] += dfs_res(i, x);
v.clear();
dfs_t(x, prev);
for (auto &i: tag[x]) {
long long curr = get<2>(i);
for (auto &j: v) if (dist(get<0>(i), j) == get<1>(i) + 1) curr += mx[j];
dp[x] = max(dp[x], curr);
}
return mx[x] = max(mx[x], dp[x]);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i < n - 1; i++) {
int u, v;
scanf("%d%d", &u, &v);
u--, v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs(0);
for (int t = 1; t < 20; t++) for (int i = 0; i < n; i++) {
if (par[t - 1][i] == -1) par[t][i] = -1;
else par[t][i] = par[t - 1][par[t - 1][i]];
}
for (int i = 0; i < m; i++) {
int c, d, g;
scanf("%d%d%d", &c, &d, &g);
c--;
tag[up(c, d)].push_back({ c, d, g });
}
printf("%lld", dfs_res(0));
}
03:37:21 - 04:30:00
4번 문제의 6번 서브태스크를 긁다가 대회가 끝났습니다. 나중에 들은 바에 따르면 4번 문제는 방역과 비슷했다고 합니다.
#include <tuple>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
int n, m, t, in[100006], depth[100006];
long long dp[100006], mx[100006], res, tot[100006];
vector<int> adj[100006];
vector<tuple<int, int, long long>> tag[100006];
void dfs(int x, int prev = -1) {
in[x] = t++;
for (auto &i: adj[x]) if (i != prev) dfs(i, x);
}
void dfs_sz(int x, int prev = -1) {
for (auto &i: adj[x]) if (i != prev) {
depth[in[i]] = depth[in[x]] + 1;
dfs_sz(i, x);
}
}
int tail(int x, int prev = -1) {
if ((int)adj[x].size() == 1) return x;
return tail(adj[x][0] + adj[x][1] - prev, x);
}
void solve(int x, int y) {
long long curr = 0;
vector<tuple<int, int, int>> v;
for (int i = y; i >= x; i--) for (auto &j: tag[i]) v.push_back(j);
sort(v.begin(), v.end(), [](const tuple<int, int, int> &a, const tuple<int, int, int> &b) {
if (get<0>(a) + get<1>(a) != get<0>(b) + get<1>(b)) return get<0>(a) + get<1>(a) > get<0>(b) + get<1>(b);
return get<0>(a) - get<1>(a) > get<0>(b) - get<1>(b);
});
int k = 0;
for (int i = y; i >= x; i--) {
while (k < (int)v.size() && i <= get<0>(v[k]) + get<1>(v[k])) {
if (get<0>(v[k]) - get<1>(v[k]) >= x) dp[get<0>(v[k]) - get<1>(v[k])] = max(dp[get<0>(v[k]) - get<1>(v[k])], get<2>(v[k]) + curr);
k++;
}
curr = max(curr, dp[i]);
mx[i] = curr;
}
k = 0, curr = 0;
for (int i = y; i >= x; i--) {
while (k < (int)v.size() && i <= get<0>(v[k]) + get<1>(v[k])) {
if (get<0>(v[k]) - get<1>(v[k]) < x) tag[0].push_back({ 0, get<1>(v[k]) - depth[get<0>(v[k])], get<2>(v[k]) + curr - mx[x + get<1>(v[k]) - depth[get<0>(v[k])]]});
k++;
}
curr = max(curr, dp[i]);
tot[i - x + 1] += mx[i - x + 1];
}
res = max(res, curr);
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i < n - 1; i++) {
int u, v;
scanf("%d%d", &u, &v);
u--, v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
int x = 0;
for (int i = 0; i < n; i++) if ((int)adj[i].size() > 2) x = i;
dfs(x);
dfs_sz(x);
for (int i = 0; i < m; i++) {
int c, d, g;
scanf("%d%d%d", &c, &d, &g);
c--;
tag[in[c]].push_back({ in[c], d, g });
}
for (auto &i: adj[x]) solve(in[i], in[tail(i, x)]);
for (auto &i: tag[0]) {
res = max(res, get<2>(i) + tot[get<1>(i) + 1]);
if (get<1>(i) < 0) return 1;
}
printf("%lld", max(res, tot[1]));
}
소감
아무래도 아쉬움이 큰 대회였습니다. 대회 전에는 300점을 넘길 생각을 하고 있었는데, 성적은 생각만큼 나오지 않아서 많이 아쉬웠습니다. 가장 아쉬운 점은 앞에서 말했듯 17점을 긁지 못한 점이었습니다. 다만 운이 좋으면 은상을 받을 수 있을 것 같습니다. 옆 블로그에 누구랑 누구는 PS를 시작한 지 한 해 만에 무서운 기록들을 세우던데, 저는 PS를 한 기간에 비해 원하는 만큼 성과가 나오지 않아서 슬프네요. 그래도 이번 대회를 더 빡세게 공부하는 계기로 삼아, 다음번에는 금상을 받을 수 있게 해야겠습니다. 이번 결과는 당분간 잊고, 돈 주는 NYPC 준비에 전념해야겠네요.
결과
총점 278점을 얻으며 은상을 받았습니다. 금상 커트라인은 300점으로 밝혀지며 억울함은 풀었네요. 그 외에도 대상 커트라인은 325점, 은상 커트라인은 246점이었다고 합니다. 또, 동상 커트라인은 134점, 장려상 커트라인은 41점이었다고 하네요.
헬기 착륙장
문제
서브태스크 1
\(a\leq6\), \(b\leq6\)이므로 \(a+b\leq12\)입니다. 따라서 헬기 착륙장의 반지름 \(k\leq4\)입니다. 헬기 착륙장의 반지름이 작으므로 가능한 모든 경우인 \(\sum_{i=1}^k2^i\) 가지를 모두 보면 됩니다. 문제를 \(O(Tk^22^k)\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
int a, b, cnt;
int main() {
scanf("%*d%d%d", &a, &b);
for (int c = 1; c <= 5; c++) {
for (int i = 0; i < 1 << c; i++) {
int x = 0, y = 0;
for (int j = 0; j < c; j++) {
if (i & 1 << j) x += j + 1;
else y += j + 1;
}
if (x <= a && y <= b) cnt++;
}
}
printf("%d", cnt);
}
서브태스크 2
\(a\leq100\), \(b\leq100\)이므로 \(O(T\mathrm{max}(a,b)^3)\) 풀이로 충분합니다. \(f_{i,\,j,\,k}\)를 반지름이 \(i\)이고 빨강 페인트를 \(a\) 통, 파랑 페인트를 \(b\) 통 사용하는 헬기 착륙장의 개수로 정의하면 다음 점화식이 세워집니다.
$$f_{i,\,j,\,k}=\begin{cases}0&(\text{$j<0$ or $k<0$})\\f_{i-1,\,j-i,\,k}+f_{i-1,\,j,\,k-i}&(\text{otherwise})\end{cases}\,(i\geq1)$$
그리고 편의상
$$f_{0,\,i,\,j}=\begin{cases}1&(\text{$j=0$ and $k=0$})\\0&(\text{otherwise})\end{cases}$$
문제의 답은 \(\sum_{i>0,\,j\leq a,\,k\leq b} f_{i,\,j,\,k}\)입니다. 따라서 문제를 \(O(Tab(a+b))\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int a, b, dp[206][106][106], res;
int main() {
dp[0][0][0] = 1;
scanf("%*d%d%d", &a, &b);
for (int i = 1; i <= a + b; i++) for (int x = 0; x <= a; x++) for (int y = 0; y <= b; y++) {
if (x >= i) dp[i][x][y] = (dp[i][x][y] + dp[i - 1][x - i][y]) % MOD;
if (y >= i) dp[i][x][y] = (dp[i][x][y] + dp[i - 1][x][y - i]) % MOD;
res = (res + dp[i][x][y]) % MOD;
}
printf("%d", res);
}
서브태스크 3
사실, 서브태스크 2에서 문제의 답은 \(\sum_{0<i\leq\lfloor2(a+b)\rfloor,\,j\leq a,\,k\leq b} f_{i,\,j,\,k}\)입니다. 헬기 착륙장의 최대 반지름 \(k\)에 대하여 \(a+b\leq\frac{1}{2}k(k+1)\)이므로 \(k=\lfloor\sqrt{2(a+b)}\rfloor\)이기 때문입니다. 따라서, 문제를 \(O(Tab\sqrt{a+b})\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int a, b, dp[64][1006][1006], res;
int main() {
dp[0][0][0] = 1;
scanf("%*d%d%d", &a, &b);
for (int i = 1; i < 64; i++) for (int x = 0; x <= a; x++) for (int y = 0; y <= b; y++) {
if (x >= i) dp[i][x][y] = (dp[i][x][y] + dp[i - 1][x - i][y]) % MOD;
if (y >= i) dp[i][x][y] = (dp[i][x][y] + dp[i - 1][x][y - i]) % MOD;
res = (res + dp[i][x][y]) % MOD;
}
printf("%d", res);
}
서브태스크 4
한 헬기 착륙장의 반지름 \(k\)가 정해지고 사용한 빨강 페인트 통의 수 \(A\)가 정해지면, 사용한 파랑 페인트 통의 수 \(b\)가 정해집니다. 따라서 \(f_{i,\,j}\)를 반지름 \(k=i\), 사용한 빨강 페인트 통의 수 \(A=j\)인 헬기 착륙장의 수로 두면 다음 점화식이 세워집니다.
$$f_{i,\,j}=\begin{cases}0&(\text{$j<0$ or $j>\frac{1}{2}i(i+1)$})\\f_{i-1,j}+f_{i-1,j-i}&(\text{otherwise})\end{cases}\,(i\geq1)$$
그리고 편의상
$$f_{0,\,j}=\begin{cases}1&(\text{$j=0$})\\0&(\text{otherwise}))\end{cases}$$
서브태스크 3과 마찬가지로, 문제의 답은 \(\sum_{0<i\leq\lfloor\sqrt{2(a+b)}\rfloor,\,\frac{1}{2}i(i+1)-b\leq j\leq a}f_{i,\,j}\)입니다. 따라서 문제를 \(O(Ta\sqrt{2(a+b)}\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int a, b, dp[156][10006], res;
int main() {
dp[0][0] = 1;
scanf("%*d%d%d", &a, &b);
for (int i = 1; i < 156; i++) for (int j = 0; j <= a; j++) {
dp[i][j] = dp[i - 1][j];
if (i <= j) dp[i][j] = (dp[i][j] + dp[i - 1][j - i]) % MOD;
if (j >= i * (i + 1) / 2 - b) res = (res + dp[i][j]) % MOD;
}
printf("%d", res);
}
서브태스크 5
\(f_{i,\,j}\)를 반지름 \(j\)인 빨간 동심원이 최대의 동심원이고 사용한 빨강 페인트 통의 수가 \(i\)인 헬기 착륙장의 수라고 하면 다음 점화식이 세워집니다.
$$f_{i,\,j}=\begin{cases}0&(i<j)\\\sum_{1\\leq k\leq j-1}f_{i-j,\,k}&(i\geq j)\end{cases}\,(i\geq1)$$
그리고 편의상
$$f_{0,\,i}=\begin{cases}1&(i=0)\\0&(i\neq0)\end{cases}$$
또한, \(f_{i,\,j}\)는 대칭적으로 반지름 \(j\)인 파란 동심원이 최대의 동심원이고 사용한 파랑 페인트 통의 수가 \(i\)인 헬기 착륙장의 수와 서로 같습니다. 이제 \(O(\mathrm{max}(a,b)\sqrt{2(a+b)})\) 시간 안에 동적 계획표 표를 모두 구성할 수 있습니다.
이제 각 테스트 케이스에 대해 가장 큰 동심원의 반지름 \(k\)가 될 수 있는 범위 \(k\in[1,\lfloor\sqrt{2(a+b)}\rfloor]\)에 대해 모두 생각합니다. 또, 각 테스트 케이스에서 가장 큰 동심원이 빨간 경우와 파란 경우를 각각 따져 줍니다. 일반성을 잃지 않고 가장 큰 동심원이 빨갛다고 합시다. 이제 반지름이 \(k\)인 헬기 착륙장에서 사용한 파랑 페인트 통의 수 \(b\)는 \([\mathrm{max}(0,\frac{1}{2}k(k-1)), \mathrm{min}(b,\frac{1}{2}k(k-1))]\) 범위 안에 있습니다. 따라서 각 테스트 케이스에서 \(\sum_{\mathrm{max}(0,\frac{1}{2}k(k-1))\leq i\leq\mathrm{min}(b,\frac{1}{2}k(k-1))}f_{i,\,k-1}\)의 값을 빠르게 구할 수 있다면 전체 문제를 해결할 수 있습니다. 이는 \(f\)의 누적 합 배열 \(g\)에 대하여 다음과 같은 점화식을 이용하여 알아낼 수 있습니다.
$$g_{i,\,j}=\begin{cases}0&(\text{$i<0$ or $j<0$})\\g_{i-1,\,j}+g_{i,\,j-1}-g_{i-1,\,j-1}&(\text{otherwise})\end{cases}$$
이제 각 테스트 케이스에서 \(g_{\mathrm{max}(0,\frac{1}{2}k(k-1))}-g_{\mathrm{min}(b,\frac{1}{2}k(k-1))}\)의 값을 \(O(1)\) 시간 안에 구할 수 있으므로, 전체 문제를 \(O((\mathrm{max}(a,b)+T)\sqrt{2(a+b)})\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int t, a, b, dp[100006][506], sum[100006][506];
int f(int x, int y) {
if (x < 0) return 0;
return dp[x][y];
}
int main() {
dp[0][0] = dp[1][1] = 1;
for (int i = 0; i < 506; i++) sum[0][i] = 1;
for (int i = 1; i < 506; i++) sum[1][i] = 1;
for (int i = 2; i < 100006; i++) for (int j = 1; j < 506; j++) {
if (i >= j) dp[i][j] = sum[i - j][j - 1];
sum[i][j] = (dp[i][j] + sum[i][j - 1]) % MOD;
}
for (int i = 0; i < 100006; i++) for (int j = 0; j < 506; j++) {
if (i) dp[i][j] = (dp[i][j] + dp[i - 1][j]) % MOD;
if (j) dp[i][j] = (dp[i][j] + dp[i][j - 1]) % MOD;
if (i && j) dp[i][j] = (dp[i][j] - dp[i - 1][j - 1] + MOD) % MOD;
}
for (scanf("%d", &t); t--; ) {
scanf("%d%d", &a, &b);
int res = 0;
for (int i = 1; i < 500; i++) {
if (a >= i) {
int x = a - i;
int _min = max(0, i * (i - 1) / 2 - x);
int _max = min(b, i * (i - 1) / 2);
if (_min <= _max) {
res = (res + f(_max, i - 1)) % MOD;
res = (res - f(_min - 1, i - 1) + MOD) % MOD;
}
}
if (b >= i) {
int x = b - i;
int _min = max(0, i * (i - 1) / 2 - x);
int _max = min(a, i * (i - 1) / 2);
if (_min <= _max) {
res = (res + f(_max, i - 1)) % MOD;
res = (res - f(_min - 1, i - 1) + MOD) % MOD;
}
}
}
printf("%d\n", res);
}
}
그래프 균형 맞추기
문제
서브태스크 1
정점 2와 정점 3 사이를 가중치가 \(c_1\)인 간선이, 정점 3과 정점 1 사이를 가중치가 \(c_2\)인 간선이, 정점 1과 정점 2 사이를 가중치가 \(c_3\)인 간선이 연결한다고 합시다. 이때 정점 1의 가중치 \(x\), 정점 2의 가중치 \(y\), 정점 3의 가중치 \(z\)에 대하여 다음이 성립합니다.
$$x+y=c_3,\;y+z=c_1,\;z+x=c_2$$
따라서,
$$x=\frac{c_2+c_3-c_1}{2}\\\\,\;y=\frac{c_3+c_1-c_2}{2},\;z=\frac{c_1+c_2-c_3}{2}$$
와 같이 세 가중치 \(x\), \(y\), \(z\)가 유일하게 정해집니다. \(c_1+c_2+c_3\)이 홀수라면 그래프의 균형이 맞도록 정수 가중치를 부여하는 방법은 없습니다.
#include <cstdio>
int n = 3, m = 3;
int a[3];
int main() {
scanf("%*d%*d");
for (int i = 0; i < m; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
u--, v--;
a[3 - u - v] = w;
}
if ((a[0] + a[1] + a[2]) % 2) return puts("No"), 0;
puts("Yes");
printf("%d %d %d", (a[1] + a[2] - a[0]) / 2, (a[2] + a[0] - a[1]) / 2, (a[0] + a[1] - a[2]) / 2);
}
서브태스크 2
정점들의 가중치가 유일하게 정해지지 않습니다. 간선들의 조건을 만족하는 정점들의 가중치를 정했다고 합시다. 정점들의 가중치는 \(x_1,\,x_2,\,\cdots,\,x_n\)입니다. 여기서 \(\sum|x_i|\)의 값을 최소로 해야 합니다. 이때 \(x_i=0\)인 \(i\)가 존재해도 최적해가 됩니다. 그런데 어떤 정점의 가중치가 정해지면 나머지 정점의 가중치가 모두 정해집니다. 따라서, 모든 \(N\) 개의 정점에 대해 \(x_i=0\)으로 두고, 나머지 정점들의 가중치를 계산합니다. \(\sum|x_i|\)의 최솟값이 답이 됩니다. 문제를 \(O(N^2)\) 시간에 해결할 수 있습니다.
#include <cstdio>
#include <utility>
#include <algorithm>
using namespace std;
int n, a[1006];
long long w[1006];
pair<long long, int> res = { (long long)9e18, -1 };
int main() {
scanf("%d%*d", &n);
for (int i = 0; i < n - 1; i++) scanf("%*d%*d%d", a + i);
for (int i = 0; i < n; i++) {
w[i] = 0;
for (int j = i - 1; j >= 0; j--) w[j] = a[j] - w[j + 1];
for (int j = i + 1; j < n; j++) w[j] = a[j - 1] - w[j - 1];
long long curr = 0;
for (int j = 0; j < n; j++) curr += w[j] > 0 ? w[j] : -w[j];
res = min(res, pair{ curr, i });
}
puts("Yes");
w[res.second] = 0;
for (int i = res.second - 1; i >= 0; i--) w[i] = a[i] - w[i + 1];
for (int i = res.second + 1; i < n; i++) w[i] = a[i - 1] - w[i - 1];
for (int i = 0; i < n; i++) printf("%lld ", w[i]);
}
서브태스크 3
\(i\)번 간선의 가중치를 \(a_i\)로 둡시다. \(1\)번 정점의 가중치를 \(x\)라고 두면, \(2\)번 정점의 가중치 \(a_1-x\), \(3\)번 정점의 가중치 \(a_2-a_1+x\), …, \(n\)번 정점의 가중치 \(a_n-a_{n-1}+\cdots+(-1)^nx\)가 모두 정해집니다. 따라서 최소로 하는 값은 \(f(x)=|x|+|x-a_1|+|x-a_1+a_2|+\cdots+|x-a_1+a_2-\cdots+(-1)^na_n|\)입니다. 간단히, \(c_1\leq c_2\leq\cdots\leq c_n\)에 대해 \(f(x)=|x-c_1|+|x-c_2|+\cdots+|x-c_n|\)가 되게 하는 \(c_i\)들을 \(O(N\log N)\) 시간에 찾을 수 있습니다. 이때 \(x\)의 범위에 따른 \(f(x)\)는 기울기가 음수에서 양수로 점차 증가하는 꺾인 직선 모양의 개형을 보입니다. \(x<c_{\lfloor\frac{n+1}{2}\rfloor}\)에서 \(f(x)\)의 기울기는 음수, \(x>c_{\lceil\frac{n+1}{2}\rceil}\)에서 \(f(x)\)의 기울기는 양수이고 \(c_{\lfloor\frac{n+1}{2}\rfloor}\leq x\leq c_{\lceil\frac{n+1}{2}\rceil}\)에서 \(f(x)\)의 기울기는 \(0\)이므로, \(f(x)\)는 \(x=c_{\lfloor\frac{n+1}{2}\rfloor}\)에서 최소입니다. 따라서 모든 정점의 가중치를 \(O(N)\) 시간 안에 구할 수 있고, 문제를 \(O(N\log N)\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <algorithm>
using namespace std;
int n, a[100006];
long long w[100006], c[100006];
int main() {
scanf("%d%*d", &n);
for (int i = 0; i < n - 1; i++) scanf("%*d%*d%d", a + i);
for (int i = 1; i < n; i++) {
w[i] = w[i - 1];
if (i % 2) w[i] += a[i - 1];
else w[i] -= a[i - 1];
}
for (int i = 0; i < n; i++) c[i] = w[i];
sort(c, c + n);
puts("Yes");
for (int i = 0; i < n; i++) {
if (i % 2) printf("%lld ", w[i] - c[n / 2]);
else printf("%lld ", c[n / 2] - w[i]);
}
}
서브태스크 4
서브태스크 3과 같은 관찰이 성립합니다. \(1\)번 정점의 가중치를 \(x-a_1\,(a_1=0)\)라고 두면, 트리를 탐색하며 \(i(>1)\)번 정점의 가중치 \(x-a_i\) 또는 \(-x+a_i\)를 \(O(N)\) 시간 안에 구할 수 있고, 마찬가지로 \(f(x)=|x-c_1|+|x-c_2|+\cdots+|x-c_n|\)가 \(x=c_{\lfloor\frac{n+1}{2}\rfloor}\)에서 최소이므로 모든 정점의 가중치를 \(O(N)\) 시간 안에 구할 수 있습니다. 문제를 \(O(N\log N)\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;
int n;
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
void dfs(int x, int prev = -1) {
for (auto &i: adj[x]) if (i.first != prev) {
a[i.first] = { -a[x].first, i.second - a[x].second };
dfs(i.first, x);
}
}
int main() {
scanf("%d%*d", &n);
for (int i = 0; i < n - 1; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
u--, v--;
adj[u].push_back({ v, w });
adj[v].push_back({ u, w });
}
a[0] = { 1, 0 };
dfs(0);
for (int i = 0; i < n; i++) v.push_back(-a[i].first * a[i].second);
sort(v.begin(), v.end());
long long x = v[(int)v.size() / 2];
puts("Yes");
for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}
서브태스크 5
그래프는 정확히 하나의 큰 사이클을 이룹니다. \(N\)이 홀수라면 정점들의 가중치는 유일하게 결정되며, \(N\)이 짝수라면 정점들의 가중치가 유일하게 결정되지 않습니다. 사이클 위의 정점들을 순서대로 \(1,\;2,\;\cdots,\;N\)번 정점이라고 합시다.
· \(N\)이 홀수라면 \(1\)번 정점의 가중치를 \(x-c_1\,(c_1=0)\)이라고 할 때 \(k\)번 정점의 가중치는 \(x-c_k\)입니다. 따라서 \(1\)번 정점과 \(k\)번 정점을 연결하는 간선의 가중치 \(p\)에 대하여 \(x=\frac{c_1+c_k+p}{2}\)입니다. 이제 모든 정점의 가중치를 유일하게 결정할 수 있고, \(x\)가 정수가 아니라면 그래프의 균형이 맞도록 정점에 정수 가중치를 부여하는 방법은 없습니다.
· \(N\)이 짝수라면 \(1\)번 정점의 가중치를 \(x-c_1\,(c_1=0)\)이라고 할 때 \(k\)번 정점의 가중치는 \(-x+c_k\)입니다. 따라서 \(1\)번 정점과 \(k\)번 정점을 연결하는 간선의 가중치 \(p\)에 대하여 \(c_k-c_1\neq p\)라면 그래프의 균형이 맞도록 정점에 정수 가중치를 부여하는 방법이 없고, 아니라면 서브태스크 4와 같이 문제를 해결할 수 있습니다.
따라서 문제를 \(O(N\log N)\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;
int n;
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
pair<int, int> y;
bool visited[100006];
void dfs(int x, int prev = -1) {
visited[x] = true;
for (auto &i: adj[x]) if (i.first != prev) {
if (!i.first) y = { x, i.second };
if (!visited[i.first]) {
a[i.first] = { -a[x].first, i.second - a[x].second };
dfs(i.first, x);
}
}
}
int main() {
scanf("%d%*d", &n);
for (int i = 0; i < n; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
u--, v--;
adj[u].push_back({ v, w });
adj[v].push_back({ u, w });
}
a[0] = { 1, 0 };
dfs(0);
if (n % 2) {
if ((y.second - a[y.first].second) % 2) return puts("No"), 0;
puts("Yes");
long long x = (y.second - a[y.first].second) / 2;
for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
return 0;
}
if (y.second != a[y.first].second) return puts("No"), 0;
for (int i = 0; i < n; i++) v.push_back(-a[i].first * a[i].second);
sort(v.begin(), v.end());
long long x = v[(int)v.size() / 2];
puts("Yes");
for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}
서브태스크 6
\(1\)번 정점의 가중치를 \(x-c_1\,(c_1=0)\)이라고 두면 DFS 트리를 탐색하며 나머지 \(N-1\) 개 정점의 가중치를 \(x-c_i\) 또는 \(-x+c_i\)로 유일하게 결정할 수 있습니다. 이후 그래프상에서 \(u\)번 정점과 \(v\)번 정점을 연결하는 가중치 \(p\)의 간선에서 각각 다음을 수행합니다.
· \(u\)번 정점의 가중치가 \(x-c_u\)이고 \(v\)번 정점의 가중치가 \(x-c_v\)라면 \(x=\frac{c_u+c_v+p}{2}\)입니다.
· \(u\)번 정점의 가중치가 \(x-c_u\)이고 \(v\)번 정점의 가중치가 \(-x+c_v\)라면 \(c_v-c_u=p\)이어야 합니다.
· \(u\)번 정점의 가중치가 \(-x+c_u\)이고 \(v\)번 정점의 가중치가 \(x-c_v\)라면 \(c_u-c_v=p\)이어야 합니다.
· \(u\)번 정점의 가중치가 \(-x+c_u\)이고 \(v\)번 정점의 가중치가 \(-x+c_v\)라면 \(x=\frac{c_u+c_v-p}{2}\)입니다.
이때 \(x\)의 값에 대해 모순이 발생하거나 조건이 성립하지 않을 때, 그리고 \(x\)의 값이 정수가 아닐 때 그래프의 균형이 맞도록 정점에 정수 가중치를 부여하는 방법이 없습니다. 위 검사를 모두 통과했다고 합시다. \(x\)의 값이 정해졌다면 나머지 정점의 가중치를 유일하게 결정할 수 있습니다. 아니라면, 서브태스크 2와 같은 관찰이 성립합니다. 각 \(i\)마다 \(x_i=0\)으로 두고 나머지 정점의 가중치를 유일하게 결정하면 문제를 \(O(N^2+M)\) 시간 안에 해결할 수 있습니다.
#include <cmath>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;
int n, m, res[100006];
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
bool visited[100006];
void dfs(int x) {
visited[x] = true;
for (auto &i: adj[x]) if (!visited[i.first]) {
a[i.first] = { -a[x].first, i.second - a[x].second };
dfs(i.first);
}
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i < m; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
u--, v--;
adj[u].push_back({ v, w });
adj[v].push_back({ u, w });
}
a[0] = { 1, 0 };
dfs(0);
for (int i = 0; i < n; i++) for (auto &j: adj[i]) if (i < j.first) {
if (a[i].first != a[j.first].first) {
if (a[i].second + a[j.first].second == j.second) continue;
puts("No");
return 0;
} else v.push_back(2 * (j.second - a[j.first].second - a[i].second) / (a[i].first + a[j.first].first));
}
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
if ((int)v.size() > 1) {
puts("No");
return 0;
}
if ((int)v.size() == 1) {
if (v[0] % 2) puts("No");
else {
puts("Yes");
long long x = v[0] / 2;
for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}
return 0;
}
pair<long long, int> res = { (long long)9e18, -1 };
for (int i = 0; i < n; i++) {
long long curr = 0, x = -a[i].first * a[i].second;
for (int j = 0; j < n; j++) curr += llabs(a[j].first * x + a[j].second);
res = min(res, { curr, i });
}
puts("Yes");
for (int i = 0; i < n; i++) printf("%lld ", -a[i].first * a[res.second].first * a[res.second].second + a[i].second);
}
서브태스크 7
서브태스크 6과 같이 시행하되, 최솟값을 서브태스크 3과 같이 찾아 문제를 \(O(N\log N+M)\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;
int n, m, res[100006];
vector<pair<int, int>> adj[100006];
vector<long long> v;
pair<int, long long> a[100006];
bool visited[100006];
void dfs(int x) {
visited[x] = true;
for (auto &i: adj[x]) if (!visited[i.first]) {
a[i.first] = { -a[x].first, i.second - a[x].second };
dfs(i.first);
}
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i < m; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
u--, v--;
adj[u].push_back({ v, w });
adj[v].push_back({ u, w });
}
a[0] = { 1, 0 };
dfs(0);
for (int i = 0; i < n; i++) for (auto &j: adj[i]) if (i < j.first) {
if (a[i].first != a[j.first].first) {
if (a[i].second + a[j.first].second == j.second) continue;
puts("No");
return 0;
} else v.push_back(2 * (j.second - a[j.first].second - a[i].second) / (a[i].first + a[j.first].first));
}
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
if ((int)v.size() > 1) {
puts("No");
return 0;
}
if ((int)v.size() == 1) {
if (v[0] % 2) puts("No");
else {
puts("Yes");
long long x = v[0] / 2;
for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}
return 0;
}
for (int i = 0; i < n; i++) v.push_back(-a[i].first * a[i].second);
sort(v.begin(), v.end());
long long x = v[(int)v.size() / 2];
puts("Yes");
for (int i = 0; i < n; i++) printf("%lld ", a[i].first * x + a[i].second);
}
가장 긴 공통 괄호 문자열
문제
서브태스크 1
문자열 \(A\)에서 모든 부분 문자열 \(O(N^2)\) 개, 그리고 문자열 \(B\)에서 모든 부분 문자열 \(O(M^2)\) 개에 대해 각각이 서로 같은 올바른 괄호열인지 확인합니다. 이때 서로 같은 괄호열의 길이가 서로 같음을 이용하여 문제를 \(O(N^2M\mathrm{min}(N,M))\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
int t, n, m;
char a[106], b[106];
int main() {
for (scanf("%d", &t); t--; ) {
int res = 0;
scanf("%s%s", a, b);
n = strlen(a), m = strlen(b);
for (int i = 0; i < n; i++) for (int j = i + 1; j < n; j += 2) for (int x = 0; x < m; x++) {
int y = j - i + x;
if (y >= m) continue;
int r = 0;
for (int k = i; k <= j; k++) {
r += a[k] == '(' ? 1 : -1;
if (r < 0) goto next;
}
if (r) goto next;
r = 0;
for (int k = x; k <= y; k++) {
r += b[k] == '(' ? 1 : -1;
if (r < 0) goto next;
}
if (r) goto next;
for (int k = i; k <= j; k++) if (a[k] != b[k - i + x]) goto next;
res = max(res, j - i + 1);
next:;
}
printf("%d\n", res);
}
}
서브태스크 2
서브태스크 1에서 문자열 \(A\)의 올바른 괄호열의 집합 \(a\), 문자열 \(B\)의 올바른 괄호열의 집합 \(b\)에 대해 \(a\)와 \(b\)를 각각 \(O(N^2\log N+M^2\log M)\) 시간 안에 정렬한 뒤 선형 시간에 비교합니다. 문제를 \(O(N^3+M^3)\) 시간 안에 해결할 수 있습니다. 상수가 작아서 통과합니다.
#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
#include <algorithm>
using namespace std;
int t, n, m;
char a[1006], b[1006];
int main() {
for (scanf("%d", &t); t--; ) {
vector<string> A, B;
int res = 0;
scanf("%s%s", a, b);
n = strlen(a), m = strlen(b);
for (int i = 0; i < n; i++) for (int j = i + 1; j < n; j += 2) {
string s;
int r = 0;
for (int k = i; k <= j; k++) {
r += a[k] == '(' ? 1 : -1;
s += a[k];
if (r < 0) goto next;
}
if (r) goto next;
A.push_back(s);
next:;
}
for (int i = 0; i < m; i++) for (int j = i + 1; j < m; j += 2) {
string s;
int r = 0;
for (int k = i; k <= j; k++) {
r += b[k] == '(' ? 1 : -1;
s += b[k];
if (r < 0) goto next2;
}
if (r) goto next2;
B.push_back(s);
next2:;
}
sort(A.begin(), A.end());
sort(B.begin(), B.end());
int k = 0;
for (int i = 0; i < (int)A.size(); i++) {
while (k < (int)B.size() && B[k] < A[i]) k++;
if (k < (int)B.size() && B[k] == A[i]) res = max(res, (int)A[i].size());
}
printf("%d\n", res);
}
}
서브태스크 3
문자열 \(A\)에서 모든 올바른 괄호열을 하나의 정수로 해싱합니다. 그런 정수들의 집합을 \(a\)라고 하면 \(a\)의 크기는 \(O(N^2)\)입니다. 또, 문자열 \(B\)에서 모든 올바른 괄호열을 하나의 정수로 해싱합니다. 그런 정수들의 집합을 \(b\)라고 하면 \(b\)의 크기는 \(O(M^2)\)입니다. \(a\)와 \(b\)를 각각 \(O(N^2\log N+M^2\log M)\) 시간 안에 정렬한 뒤 선형 시간에 비교하여 서로 같은 괄호열들의 목록을 얻을 수 있습니다. 이 목록에서 가장 긴 괄호열이 문제의 답이 됩니다. 문제를 \(O(N^2\log N+M^2\log M)\) 시간 안에 해결할 수 있습니다. 상수가 많이 작아서 통과합니다.
#include <cstdio>
#include <vector>
#include <utility>
#include <cstring>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int t, n, m;
char a[500006], b[500006];
vector<pair<int, long long>> x, y;
int main() {
for (scanf("%d", &t); t--; ) {
scanf("%s%s", a, b); x.clear(); y.clear();
n = strlen(a), m = strlen(b);
for (int i = 0; i < n; i++) {
int r = 0;
long long curr = 0;
for (int j = i; j < n; j++) {
r += a[j] == '(' ? 1 : -1;
curr = 2 * curr % MOD;
if (a[j] == '(') curr = (curr + 1) % MOD;
if (r < 0) break;
if (r == 0) x.push_back({ j - i + 1, curr });
}
}
for (int i = 0; i < m; i++) {
int r = 0;
long long curr = 0;
for (int j = i; j < m; j++) {
r += b[j] == '(' ? 1 : -1;
curr = 2 * curr % MOD;
if (b[j] == '(') curr = (curr + 1) % MOD;
if (r < 0) break;
if (r == 0) y.push_back({ j - i + 1, curr });
}
}
sort(x.begin(), x.end(), [](const pair<int, long long> &a, const pair<int, long long> &b) {
if (a.first != b.first) return a > b;
return a.second < b.second;
});
sort(y.begin(), y.end(), [](const pair<int, long long> &a, const pair<int, long long> &b) {
if (a.first != b.first) return a > b;
return a.second < b.second;
});
int k = 0, res = 0;
for (int i = 0; i < (int)x.size(); i++) {
while (k < (int)y.size() && (y[k].first > x[i].first || y[k].first == x[i].first && y[k].second < x[i].second)) k++;
while (k < (int)y.size() && y[k] == x[i]) res = max(res, x[i].first), k++;
}
printf("%d\n", res);
}
}
서브태스크 4
\(i\)째 문자의 이전에 등장하는 ('('의 개수) - (')'의 개수)의 값을 \(r_i\)로 정의합시다. 이제 각 \(i\)에 대하여, \(r_i=r_k\;(k>i)\)인 최소의 \(k\)를 \(b_i\)라 합시다. 단, 그러한 \(k\)가 존재하지 않거나 \(i\)째 문자부터 \(k-1\)째 문자가 올바른 괄호열을 이루지 않는다면 \(b_i=-1\)입니다. 이는 세그먼트 트리를 이용해 판별할 수 있습니다. \(b_i=k\)인 \(i\)가 존재하지 않는 \(k\)들을 관리하며 문자열 \(A\)의 올바른 괄호열의 최대 길이를 얻을 수 있습니다. 문제를 \(O(N\log N)\) 시간 안에 해결할 수 있습니다.
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
int tree[1048576];
int update(int i, int b, int e, int p, int v) {
if (p < b || e < p) return tree[i];
if (b == e) return tree[i] = v;
int m = (b + e) / 2;
return tree[i] = min(update(i * 2 + 1, b, m, p, v), update(i * 2 + 2, m + 1, e, p, v));
}
int query(int i, int b, int e, int l, int r) {
if (r < l || r < b || e < l) return 1e9;
if (l <= b && e <= r) return tree[i];
int m = (b + e) / 2;
return min(query(i * 2 + 1, b, m, l, r), query(i * 2 + 2, m + 1, e, l, r));
}
int t, n, r[500006];
char a[500006];
int ina[500006], _fa[500006], *fa;
vector<int> in_a;
int main() {
for (scanf("%d", &t); t--; ) {
scanf("%s%*s", a);
n = strlen(a);
for (int i = 0; i < n; i++) {
r[i] = (i ? r[i - 1] : 0) + (a[i] == '(' ? 1 : -1);
update(0, 0, n - 1, i, r[i]);
}
int mr = *min_element(r, r + n);
in_a.clear();
fa = _fa - min(mr, 0);
for (int i = 0; i <= n; i++) _fa[i] = 1e9;
for (int i = n; i >= 0; i--) {
int x = i ? r[i - 1] : 0;
ina[i] = fa[x];
fa[x] = i;
}
for (int i = 0; i <= n; i++) if (_fa[i] < (int)1e9) in_a.push_back(_fa[i]);
for (int i = 0; i <= n; i++) if (ina[i] < (int)1e9 && query(0, 0, n - 1, i, ina[i] - 1) < (i ? r[i - 1] : 0)) {
in_a.push_back(ina[i]);
ina[i] = 1e9;
}
int res = 0;
for (auto &i: in_a) {
int x = i;
while (ina[x] < (int)1e9) x = ina[x];
res = max(res, x - i);
}
printf("%d\n", res);
}
}
서브태스크 5
문제의 답 \(k\)에 대해 이분 탐색합니다. 문제는 "길이 \(k\) 이상의 공통 괄호 문자열이 존재하는가?"의 결정 문제로 환원되었습니다. 문자열 \(A\)의 각 문자에 대해 "각 문자에서 시작하는 길이 \(k\) 이상의 문자열 중 길이가 최소인 올바른 괄호열"들을 모두 해싱해 집합 \(a\)를 얻고, 문자열 \(B\)의 각 문자에 대해 같은 방법으로 집합 \(b\)를 얻습니다. 이는 각 문자의 이전에 등장하는 ('('의 개수) - (')'의 개수)가 같은 문자들끼리 묶어 준 뒤 투 포인터를 이용해 구하는 괄호열의 후보를 얻고, 그 후보가 유효한지 세그먼트 트리 등을 이용해 확인하는 방법으로 \(O(N\log N+M\log M)\) 시간 안에 구성할 수 있습니다. 집합 \(a\)의 크기는 \(O(N)\)이며, 집합 \(b\)의 크기는 \(O(M)\)입니다. 서브태스크 3과 같은 방법으로 문제의 답을 \(O((N\log N+M\log M)\log\mathrm{min}(N,M))\) 시간 안에 얻을 수 있습니다.
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int tree1[1048576], tree2[1048576];
int update(int *tree, int i, int b, int e, int p, int v) {
if (p < b || e < p) return tree[i];
if (b == e) return tree[i] = v;
int m = (b + e) / 2;
return tree[i] = min(update(tree, i * 2 + 1, b, m, p, v), update(tree, i * 2 + 2, m + 1, e, p, v));
}
int query(int *tree, int i, int b, int e, int l, int r) {
if (r < l || r < b || e < l) return 1e9;
if (l <= b && e <= r) return tree[i];
int m = (b + e) / 2;
return min(query(tree, i * 2 + 1, b, m, l, r), query(tree, i * 2 + 2, m + 1, e, l, r));
}
int t, n, m, r[500006], s[500006], two[500006], rev[500006];
char a[500006], b[500006];
int hasha[500006], hashb[500006];
int ina[500006], inb[500006], _fa[500006], *fa, _fb[500006], *fb;
vector<int> A, B, in_a, in_b;
int get_hash(int *hs, int n, int x, int y) {
int ret = (hs[y] - (x ? hs[x - 1] : 0) + MOD) % MOD;
return 1LL * ret * rev[x] % MOD;
}
int main() {
two[0] = rev[0] = 1;
for (int i = 1; i < 500006; i++) {
rev[i] = rev[i - 1] * 500000004LL % MOD;
two[i] = two[i - 1] * 2 % MOD;
}
for (scanf("%d", &t); t--; ) {
scanf("%s%s", a, b);
n = strlen(a), m = strlen(b);
for (int i = 0; i < n; i++) {
r[i] = (i ? r[i - 1] : 0) + (a[i] == '(' ? 1 : -1);
update(tree1, 0, 0, n - 1, i, r[i]);
hasha[i] = ((i ? hasha[i - 1] : 0) + two[i] * (a[i] == '(' ? 1 : 0)) % MOD;
}
for (int i = 0; i < m; i++) {
s[i] = (i ? s[i - 1] : 0) + (b[i] == '(' ? 1 : -1);
update(tree2, 0, 0, m - 1, i, s[i]);
hashb[i] = ((i ? hashb[i - 1] : 0) + two[i] * (b[i] == '(' ? 1 : 0)) % MOD;
}
int mr = *min_element(r, r + n), ms = *min_element(s, s + m);
in_a.clear(), in_b.clear();
fa = _fa - min(mr, 0);
for (int i = 0; i <= n; i++) _fa[i] = 1e9;
for (int i = n; i >= 0; i--) {
int x = i ? r[i - 1] : 0;
ina[i] = fa[x];
fa[x] = i;
}
for (int i = 0; i <= n; i++) if (_fa[i] < (int)1e9) in_a.push_back(_fa[i]);
for (int i = 0; i <= n; i++) if (ina[i] < (int)1e9 && query(tree1, 0, 0, n - 1, i, ina[i] - 1) < (i ? r[i - 1] : 0)) {
in_a.push_back(ina[i]);
ina[i] = 1e9;
}
fb = _fb - min(ms, 0);
for (int i = 0; i <= m; i++) _fb[i] = 1e9;
for (int i = m; i >= 0; i--) {
int x = i ? s[i - 1] : 0;
inb[i] = fb[x];
fb[x] = i;
}
for (int i = 0; i <= m; i++) if (_fb[i] < (int)1e9) in_b.push_back(_fb[i]);
for (int i = 0; i <= m; i++) if (inb[i] < (int)1e9 && query(tree2, 0, 0, m - 1, i, inb[i] - 1) < (i ? s[i - 1] : 0)) {
in_b.push_back(inb[i]);
inb[i] = 1e9;
}
int L = 1, R = min(n, m);
while (L < R) {
int M = (L + R + 1) / 2;
A.clear(), B.clear();
for (auto &i: in_a) {
int curr = i, x = curr;
while (x < (int)1e9) {
while (x < (int)1e9 && x - curr < M) x = ina[x];
if (x < (int)1e9) {
A.push_back(get_hash(hasha, n, curr, x - 1));
curr = ina[curr];
}
}
}
for (auto &i: in_b) {
int curr = i, x = curr;
while (x < (int)1e9) {
while (x < (int)1e9 && x - curr < M) x = inb[x];
if (x < (int)1e9) {
B.push_back(get_hash(hashb, m, curr, x - 1));
curr = inb[curr];
}
}
}
sort(A.begin(), A.end());
sort(B.begin(), B.end());
int k = 0;
bool none = true;
for (int i = 0; i < (int)A.size(); i++) {
while (k < (int)B.size() && B[k] < A[i]) k++;
if (k < (int)B.size() && A[i] == B[k]) {
none = false;
break;
}
}
if (none) R = M - 1;
else L = M;
}
printf("%d\n", L > 1 ? L : 0);
}
}
맛집 추천
문제