본문 바로 가기

PS

2022 KOI 고등부 2차 대회 후기

 2022 한국 정보 올림피아드 2차 대회가 2022년 7월 16일에 온라인으로 진행되었습니다. 모두 수고하셨습니다.

 

모두 334점을 받았습니다.

타임라인

00:00:00 - 00:14:01

 1번 문제 "트리와 쿼리"를 읽고, 서로소 집합을 이용한 \(O(N+Q+\alpha(N)\sum K)\) 풀이를 작성해 해결했습니다.

 

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int uf[250006], sz[250006];
long long res;
vector<int> v;

int _find(int x) {
    if (uf[x] == -1) return x;
    return uf[x] = _find(uf[x]);
}

long long _merge(int x, int y) {
    v.push_back(x);
    v.push_back(y);
    x = _find(x), y = _find(y);
    long long ret = 0;
    if (x != y) {
        ret = 1LL * sz[x] * sz[y];
        sz[y] += sz[x], uf[x] = y;
    }
    return ret;
}

int n, q, pr[250006];
vector<int> adj[250006];
bool selected[250006];

void dfs(int x, int prev = -1) {
    pr[x] = prev;
    for (auto &i: adj[x]) if (i != prev) dfs(i, x);
}

void initialize() {
    res = 0;
    for (auto &i: v) uf[i] = -1, sz[i] = 1;
    v.clear();
}

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) uf[i] = -1, sz[i] = 1;
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    dfs(0);
    scanf("%d", &q);
    while (q--) {
        initialize();
        vector<int> lt;
        int k;
        scanf("%d", &k);
        lt.resize(k);
        for (auto &i: lt) scanf("%d", &i);
        for (auto &i: lt) selected[--i] = true;
        for (auto &i: lt) if (i) if (selected[pr[i]]) res += _merge(i, pr[i]);
        for (auto &i: lt) selected[i] = false;
        printf("%lld\n", res);
    }
}

 

00:14:01 ~ 00:27:42

 2번 문제 "식사 계획 세우기"를 읽고, 정해를 한번에 짤 수 있다는 확신이 없어 \(O(N^2)\) 풀이를 작성해 먼저 49점을 얻었습니다.

 

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n, a[300006];
bool selected[300006];
vector<int> v[300006], res;

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) {
        scanf("%d", a + i);
        v[--a[i]].push_back(i);
    }
    for (int i = 0; i < n; i++) reverse(v[i].begin(), v[i].end());
    int pv = -1;
    for (int i = 0; i < n; i++) {
        int total = n - i;
        int mx = 0, s = -1;
        for (int i = 0; i < n; i++) {
            if ((int)v[i].size() > mx) mx = (int)v[i].size(), s = i;
        }
        if (mx >= total / 2 + 1) {
            if (pv == s) return puts("-1"), 0;
            pv = s;
            res.push_back(v[s].back());
            v[s].pop_back();
            continue;
        }
        mx = (int)1e9, s = -1;
        for (int i = 0; i < n; i++) if (!v[i].empty()) {
            if (v[i].back() < mx && pv != i) mx = v[i].back(), s = i;
        }
        if (s == -1) return puts("-1"), 0;
        res.push_back(v[s].back());
        pv = s;
        v[s].pop_back();
    }
    for (auto &i: res) printf("%d ", i + 1);
}

 

00:27:42 ~ 00:35:56

 \(O(N\log N)\) 풀이를 작성해 2번 문제를 해결했습니다.

 

#include <set>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int n, a[300006];
bool selected[300006];
vector<int> v[300006], res;
set<pair<int, int>> s1, s2;

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) {
        scanf("%d", a + i);
        v[--a[i]].push_back(i);
    }
    for (int i = 0; i < n; i++) {
        reverse(v[i].begin(), v[i].end());
        s1.insert({ (int)v[i].size(), i });
        if (!v[i].empty()) s2.insert({ v[i].back(), i });
    }
    int pv = -1;
    for (int i = 0; i < n; i++) {
        int total = n - i;
        int mx = 0, s = -1;
        auto it = s1.rbegin();
        mx = it->first, s = it->second;
        if (mx >= total / 2 + 1) {
            if (pv == s) return puts("-1"), 0;
            pv = s;
            s1.erase({ (int)v[s].size(), s });
            s2.erase({ v[s].back(), s });
            res.push_back(v[s].back());
            v[s].pop_back();
            s1.insert({ (int)v[s].size(), s });
            if (!v[s].empty()) s2.insert({ v[s].back(), s });
            continue;
        }
        if (pv != -1 && !v[pv].empty()) s2.erase({ v[pv].back(), pv });
        if (!s2.empty()) s = s2.begin()->second;
        else s = -1;
        if (pv != -1 && !v[pv].empty()) s2.insert({ v[pv].back(), pv });
        if (s == -1) return puts("-1"), 0;
        pv = s;
        s1.erase({ (int)v[s].size(), s });
        s2.erase({ v[s].back(), s });
        res.push_back(v[s].back());
        v[s].pop_back();
        s1.insert({ (int)v[s].size(), s });
        if (!v[s].empty()) s2.insert({ v[s].back(), s });
    }
    for (auto &i: res) printf("%d ", i + 1);
}

 

00:35:56 ~ 00:54:57

 3번 문제 "레벨 업"을 읽고, \(O(MN\log N+MK)\) 풀이를 작성해 4점을 얻었습니다.

 

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n, m, k, l[100006];

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) scanf("%d", l + i);
    scanf("%d%d", &m, &k);
    for (int i = 0; i < m; i++) {
        sort(l, l + n);
        for (int i = 0; i < k; i++) l[i]++;
    }
    sort(l, l + n);
    for (int i = 0; i < n; i++) printf("%d ", l[i]);
}

 

00:54:57 ~ 03:03:55

 \(O(N\log N+K\log N\log X)\) 풀이를 작성해 3번 문제 "레벨 업"을 해결했습니다. 틀린 이유를 알지 못한 채로 오랫동안 WA를 받았는데, 그 까닭은 최대 레벨이 2×10⁹까지인 것을 10⁹까지로 처리했기 때문이었습니다.

 

#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

#define int long long

int tree1[100006];

void update1(int x, int p) {
    while (x < 100003) {
        tree1[x] += p;
        x += x & -x;
    }
}

int query1(int x) {
    int res = 0;
    while (x > 0) {
        res += tree1[x];
        x -= x & -x;
    }
    return res;
}

int query1(int l, int r) {
    return query1(r + 1) - query1(l);
}

long long tree2[100006];

void update2(int x, long long p) {
    while (x < 100003) {
        tree2[x] += p;
        x += x & -x;
    }
}

long long query2(int x) {
    long long res = 0;
    while (x > 0) {
        res += tree2[x];
        x -= x & -x;
    }
    return res;
}

long long query2(int l, int r) {
    return query2(r + 1) - query2(l);
}

int tree3[262144];
int lazy[262144];

inline void propagate(int i, int b, int e) {
    if (lazy[i] == -1) return;
    tree3[i] = lazy[i];
    if (b != e) lazy[i * 2 + 1] = lazy[i * 2 + 2] = lazy[i];
    lazy[i] = -1;
}

int query3(int i, int b, int e, int x) {
    propagate(i, b, e);
    if (x < b || e < x) return 0;
    if (b == e) return tree3[i];
    int m = (b + e) / 2;
    return query3(i * 2 + 1, b, m, x) + query3(i * 2 + 2, m + 1, e, x);
}

int query3c(int i, int b, int e, int x) {
    propagate(i, b, e);
    int m = (b + e) / 2;
    if (b == e) return b;
    else propagate(i * 2 + 1, b, m);
    if (tree3[i * 2 + 1] < x) return query3c(i * 2 + 2, m + 1, e, x);
    return query3c(i * 2 + 1, b, m, x);
}

int query3u(int i, int b, int e, int x) {
    propagate(i, b, e);
    int m = (b + e) / 2;
    if (b == e) return b;
    else propagate(i * 2 + 1, b, m);
    if (tree3[i * 2 + 1] <= x) return query3u(i * 2 + 2, m + 1, e, x);
    return query3u(i * 2 + 1, b, m, x);
}

int update3(int i, int b, int e, int l, int r, int x) {
    propagate(i, b, e);
    if (r < b || e < l) return tree3[i];
    if (l <= b && e <= r) {
        lazy[i] = max(lazy[i], x);
        propagate(i, b, e);
        return tree3[i];
    }
    int m = (b + e) / 2;
    return tree3[i] = max(update3(i * 2 + 1, b, m, l, r, x), update3(i * 2 + 2, m + 1, e, l, r, x));
}

int n, m, k, l[100006];
int pl[100006], cl[100006], to[100006];

int getc(int x) {
    return query3c(0, 0, 131071, x);
}

int getu(int x) {
    return query3u(0, 0, 131071, x);
}

signed main() {
    scanf("%lld", &n);
    for (int i = 0; i < n; i++) scanf("%lld", l + i);
    scanf("%lld%lld", &m, &k);
    sort(l, l + n);
    for (int i = 0; i < n; i++) tree3[131071 + i] = l[i];
    for (int i = n; i < 131071; i++) tree3[131071 + i] = (int)2e9 + 17;
    for (int i = 131070; i >= 0; i--) tree3[i] = max(tree3[i * 2 + 1], tree3[i * 2 + 2]);
    for (int i = 0; i < 262144; i++) lazy[i] = -1;
    for (int i = 0; i < k; i++) update1(getc(l[i]) + 1, 1);
    for (int i = 0; i < n; i++) update2(getc(l[i]) + 1, l[i]);
    for (int i = n - 1; i >= k; i--) to[i] = l[i], pl[i] = 1, cl[i] = 0;
    for (int i = k - 1; i >= 0; i--) {
        int L = l[i], R = (int)2e9 + 7;
        int getcli = getc(l[i]);
        while (L < R) {
            int M = (L + R) / 2;
            int getum = getu(M);
            int N = getum - getcli;
            long long sum = 1LL * N * M - query2(getcli, getum - 1);
            long long curr = 1LL * query1(getcli, getum - 1) * m - sum;
            if (curr >= N) L = M + 1;
            else R = M;
        }
        to[i] = L;
        int getum = getu(to[i]);
        int N = getum - getcli;
        long long sum = 1LL * N * to[i] - query2(getcli, getum - 1);
        long long curr = 1LL * query1(getcli, getum - 1) * m - sum;
        pl[i] = N, cl[i] = curr;
        if (N <= curr) return -1;
        int x = getc(L);
        if (i <= x - 1) update3(0, 0, 131071, i, min(x, n) - 1, to[i]);
    }
    for (int i = 0; i < n; i += pl[i]) {
        if (!pl[i] || pl[i] > n || cl[i] > n) return 0;
        for (int j = 0; j < pl[i] - cl[i]; j++) printf("%lld ", to[i]);
        for (int j = 0; j < cl[i]; j++) printf("%lld ", to[i] + 1);
    }
}

 

03:03:55 ~ 03:27:24

 4번 문제 "외곽 순환 도로"를 읽었습니다. 배점을 보니 334점까지는 도달하기 쉽고 그 초과는 어려울 것 같아서 부분 점수를 받기로 했습니다. 부분 문제 3을 해결해 5점을 얻었습니다.

 

#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int q, n, p[100006], c[100006];
int sp[20][100006], depth[100006];
long long ds[20][100006];
vector<pair<int, int>> adj[100006];

int dfs(int x) {
    int res = 0;
    if (adj[x].empty()) res++;
    for (auto &i: adj[x]) {
        depth[i.first] = depth[x] + 1;
        res += dfs(i.first);
    }
    return res;
}

long long dist(int x, int y) {
    long long res = 0;
    if (depth[x] < depth[y]) swap(x, y);
    int diff = depth[x] - depth[y];
    for (int t = 19; t >= 0; t--) if (diff >= 1 << t) {
        diff -= 1 << t;
        res += ds[t][x];
        x = sp[t][x];
    }
    for (int t = 19; t >= 0; t--) {
        if (sp[t][x] == sp[t][y] || sp[t][x] == -1 || sp[t][y] == -1) continue;
        res += ds[t][x] + ds[t][y];
        x = sp[t][x], y = sp[t][y];
    }
    if (x != y) res += ds[0][x] + ds[0][y];
    return res;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) scanf("%d%d", p + i, c + i);
    for (int i = 0; i < n; i++) p[i]--;
    for (int i = 0; i < n; i++) {
        sp[0][i] = p[i];
        ds[0][i] = c[i];
    }
    for (int i = 1; i < n; i++) adj[p[i]].push_back({ i, c[i] });
    int T = dfs(0);
    ds[0][0] = (long long)4e18;
    for (int t = 1; t < 20; t++) for (int i = 0; i < n; i++) {
        if (sp[t - 1][i] == -1) {
            sp[t][i] = -1;
            ds[t][i] = (long long)4e18;
        } else {
            sp[t][i] = sp[t - 1][sp[t - 1][i]];
            ds[t][i] = ds[t - 1][i] + ds[t - 1][sp[t - 1][i]];
        }
    }
    for (int i = 0; i < T; i++) scanf("%*d");
    scanf("%d", &q);
    while (q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        printf("%lld\n", dist(x, y));
    }
}

 

03:27:24 ~ 03:43:48

 4번 문제의 부분 문제 1을 해결해 6점을 추가로 얻었습니다.

 

#include <queue>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
using namespace std;

int Q, n, p[100006];
long long c[100006], w[100006], dist[100006];
int depth[100006];
vector<pair<int, long long>> adj[100006];
vector<int> leaves;

void dfs(int x, int prev = -1) {
    if ((int)adj[x].size() == 1) leaves.push_back(x);
    for (auto &i: adj[x]) if (i.first != prev) {
        depth[i.first] = depth[x] + 1;
        dfs(i.first, x);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) scanf("%d%lld", p + i, c + i);
    for (int i = 0; i < n; i++) p[i]--;
    for (int i = 1; i < n; i++) {
        adj[p[i]].push_back({ i, c[i] });
        adj[i].push_back({ p[i], c[i] });
    }
    dfs(0);
    for (int i = 0; i < (int)leaves.size(); i++) {
        scanf("%lld", w + i);
        adj[leaves[i]].push_back({ leaves[(i + 1) % (int)leaves.size()], w[i] });
        adj[leaves[(i + 1) % (int)leaves.size()]].push_back({ leaves[i], w[i] });
    }
    priority_queue<pair<long long, int>, vector<pair<long long, int>>, greater<pair<long long, int>>> q;
    q.push({ 0, 0 });
    for (int i = 1; i < n; i++) dist[i] = (long long)4e18;
    while (!q.empty()) {
        pair<long long, int> t = q.top(); q.pop();
        if (dist[t.second] < t.first) continue;
        for (auto &i: adj[t.second]) if (dist[i.first] > t.first + i.second) {
            dist[i.first] = t.first + i.second;
            q.push({ dist[i.first], i.first });
        }
    }
    scanf("%d", &Q);
    while (Q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        printf("%lld\n", dist[y]);
    }
}

 

03:43:48 ~ 03:49:42

 4번 문제의 부분 문제 2를 해결해 8점을 추가로 얻었습니다.

 

#include <queue>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
using namespace std;

int Q, n, p[100006];
long long c[100006], w[100006], dist[100006], a[100006];
int depth[100006];
vector<pair<int, long long>> adj[100006];
vector<int> leaves;

void dfs(int x, int prev = -1) {
    if ((int)adj[x].size() == 1) leaves.push_back(x);
    for (auto &i: adj[x]) if (i.first != prev) {
        depth[i.first] = depth[x] + 1;
        dfs(i.first, x);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) scanf("%d%lld", p + i, c + i);
    for (int i = 0; i < n; i++) p[i]--;
    for (int i = 1; i < n; i++) {
        adj[p[i]].push_back({ i, c[i] });
        adj[i].push_back({ p[i], c[i] });
    }
    dfs(0);
    long long sum = 0;
    for (int i = 0; i < (int)leaves.size(); i++) {
        scanf("%lld", w + i);
        sum += w[i];
        adj[leaves[i]].push_back({ leaves[(i + 1) % (int)leaves.size()], w[i] });
        adj[leaves[(i + 1) % (int)leaves.size()]].push_back({ leaves[i], w[i] });
    }
    priority_queue<pair<long long, int>, vector<pair<long long, int>>, greater<pair<long long, int>>> q;
    q.push({ 0, 0 });
    for (int i = 1; i < n; i++) dist[i] = (long long)4e18;
    while (!q.empty()) {
        pair<long long, int> t = q.top(); q.pop();
        if (dist[t.second] < t.first) continue;
        for (auto &i: adj[t.second]) if (dist[i.first] > t.first + i.second) {
            dist[i.first] = t.first + i.second;
            q.push({ dist[i.first], i.first });
        }
    }
    for (int i = 0; i < (int)leaves.size(); i++) {
        if (i) a[leaves[i]] = a[leaves[i - 1]];
        if (i) a[leaves[i]] += w[i - 1];
    }
    scanf("%d", &Q);
    while (Q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        if (x > y) swap(x, y);
        if (!x) {
            printf("%lld\n", dist[y]);
            continue;
        }
        long long D = a[y] - a[x];
        printf("%lld\n", min({ dist[x] + dist[y], D, sum - D }));
    }
}

 

03:49:42 ~ 03:59:05

 4번 문제의 부분 문제 4를 해결해 15점을 추가로 얻었습니다.

 

#include <queue>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
using namespace std;

int Q, n, p[100006], sp[20][100006];
long long c[100006], w[100006], dist[100006], ds[20][100006];
int depth[100006];
vector<pair<int, long long>> adj[100006];
vector<int> leaves;

void dfs(int x, int prev = -1) {
    if ((int)adj[x].size() == 1) leaves.push_back(x);
    for (auto &i: adj[x]) if (i.first != prev) {
        depth[i.first] = depth[x] + 1;
        dfs(i.first, x);
    }
}

long long tree_dist(int x, int y) {
    long long res = 0;
    if (depth[x] < depth[y]) swap(x, y);
    int diff = depth[x] - depth[y];
    for (int t = 19; t >= 0; t--) if (diff >= 1 << t) {
        diff -= 1 << t;
        res += ds[t][x];
        x = sp[t][x];
    }
    for (int t = 19; t >= 0; t--) {
        if (sp[t][x] == sp[t][y] || sp[t][x] == -1 || sp[t][y] == -1) continue;
        res += ds[t][x] + ds[t][y];
        x = sp[t][x], y = sp[t][y];
    }
    if (x != y) res += ds[0][x] + ds[0][y];
    return res;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) scanf("%d%lld", p + i, c + i);
    for (int i = 0; i < n; i++) p[i]--;
    for (int i = 1; i < n; i++) {
        adj[p[i]].push_back({ i, c[i] });
        adj[i].push_back({ p[i], c[i] });
    }
    for (int i = 0; i < n; i++) {
        sp[0][i] = p[i];
        ds[0][i] = c[i];
    }
    dfs(0);
    ds[0][0] = (long long)4e18;
    for (int t = 1; t < 20; t++) for (int i = 0; i < n; i++) {
        if (sp[t - 1][i] == -1) {
            sp[t][i] = -1;
            ds[t][i] = (long long)4e18;
        } else {
            sp[t][i] = sp[t - 1][sp[t - 1][i]];
            ds[t][i] = ds[t - 1][i] + ds[t - 1][sp[t - 1][i]];
        }
    }
    for (int i = 0; i < (int)leaves.size(); i++) {
        scanf("%lld", w + i);
        adj[leaves[i]].push_back({ leaves[(i + 1) % (int)leaves.size()], w[i] });
        adj[leaves[(i + 1) % (int)leaves.size()]].push_back({ leaves[i], w[i] });
    }
    priority_queue<pair<long long, int>, vector<pair<long long, int>>, greater<pair<long long, int>>> q;
    for (int i = 0; i < n; i++) dist[i] = (long long)4e18;
    for (auto &i: leaves) q.push({ 0, i }), dist[i] = 0;
    while (!q.empty()) {
        pair<long long, int> t = q.top(); q.pop();
        if (dist[t.second] < t.first) continue;
        for (auto &i: adj[t.second]) if (dist[i.first] > t.first + i.second) {
            dist[i.first] = t.first + i.second;
            q.push({ dist[i.first], i.first });
        }
    }
    scanf("%d", &Q);
    while (Q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        if (x > y) swap(x, y);
        printf("%lld\n", min({ dist[x] + dist[y], tree_dist(x, y) }));
    }
}

 

03:59:05 ~ 04:30:00

 4번 문제를 생각하다가 대회가 끝났습니다. 4번 문제는 센트로이드를 이용한 문제였다고 합니다.

 

 

소감

 받아야 할 점수를 모두 받아서, 제 실력보다 잘 보았다고 생각합니다. 특히 3번 문제를 해결했을 때에는 쾌재를 불렀는데, 지금까지 치른 KOI에서 3번 이상의 문제를 해결했던 적이 없었기 때문입니다. 또한 4번 문제의 배점을 본 뒤 334점 초과의 점수를 받기가 굉장히 어려울 것으로 생각했습니다. 제가 파악하고 있는 바에 따르면 334점이 매우 많은 반면 저는 334점을 늦게 받은 편이며 금상 수상 인원이 적어서, 은상 수상이 확실시됩니다. 문제들은 예년보다는 쉬웠다고 생각하며, 결과 역시 만족합니다.

 

1. 트리와 쿼리

서브태스크 5

 트리 그래프에서 \(Q\) 개의 쿼리 각각에 대해 주어지는 크기가 \(K\)인 집합 \(S\) 안의 정점만을 이용해 오갈 수 있는 정점의 쌍의 개수를 구해야 합니다. 그리고 생각해 보면, 정점 \(A\)와 정점 \(B\) 사이를 오갈 수 있고 정점 \(A\)와 정점 \(C\) 사이를 오갈 수 있으면 정점 \(A\)와 정점 \(B\) 사이를 오갈 수 있습니다. 따라서 서로 오갈 수 있는 정점을 하나의 묶음으로 묶으면 \(S\)의 모든 정점은 몇 개의 묶음으로 나누어지고, 각각의 묶음에 속하는 정점의 개수가 \(c_i\)라고 할 때(즉, \(\sum c_i=K\)) 한 묶음에서 서로 오갈 수 있는 쌍의 개수는 \(\frac{1}{2}c_i(c_i-1)\)입니다. 따라서 \(\frac{1}{2}\sum c_i(c_i-1)\)의 값을 빠르게 구할 수 있는 방법을 찾으면 됩니다.

 정점이 몇 개의 묶음으로 나누어지므로 서로소 집합을 생각할 수 있습니다. 또한, 서로소 집합을 관리하며 각 집합의 원소 개수와 모든 집합에서 \(\frac{1}{2}c_i(c_i-1)\)의 합을 관리합시다. 집합을 몇 번 합쳐 서로소 집합의 최종 상태가 원하는 상태가 됐을 때 답을 구하는 방법을 생각해 봅시다. 원소의 개수가 각각 \(a\), \(b\)인 서로 다른 두 집합을 합친다고 생각하면, 기존에 두 집합에서의 관리하는 값은 \(\frac{1}{2}a(a-1)+\frac{1}{2}b(b-1)=\frac{1}{2}(a^2+b^2-a-b)\)이고, 합친 후의 집합에서의 관리하는 값은 \(\frac{1}{2}(a+b)(a+b-1)=\frac{1}{2}(a^2+b^2-a-b)+ab\)입니다. 따라서 두 집합을 합칠 때마다 관리하는 값에 두 집합의 원소의 개수의 곱을 더합니다. 이렇게 구하는 값을 빠르게 구할 수 있습니다.

 이때 집합을 합치는 연산의 횟수가 \(O(K)\) 회 정도로 적어야 합니다. 이는 트리 그래프의 성질 덕분에 가능한데, 트리 그래프의 루트를 하나 잡았을 때 \(S\)의 원소를 서로 연결하는 간선의 집합은 \(S\)의 각 원소 \(v\)와 \(v\)의 부모를 연결하는 간선의 집합의 부분 집합입니다. 따라서 \(O(K)\) 개의 \(S\)의 각 원소와 부모를 연결하는 간선들이 \(S\)의 원소를 서로 연결하는 간선인지를 보고, 그러한 간선만 선택해 서로소 집합에서 두 집합을 합치면 됩니다.

 각 쿼리 이후에 서로소 집합을 초기화해야 하는데, 단순히 이번 쿼리에서의 연산에 영향을 받은 메모리만 기록해 두었다가 초기화하면 됩니다. 메모리를 기록하는 데 걸린 시간 이상의 시간이 걸리지 않습니다. 시간 복잡도는 \(O(N+Q+\alpha(N)\sum K)\)입니다.

 트리에서의 동적 계획법을 이용한 풀이도 있다고 합니다.

 

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int uf[250006], sz[250006];
long long res;
vector<int> v;

int _find(int x) {
    if (uf[x] == -1) return x;
    return uf[x] = _find(uf[x]);
}

long long _merge(int x, int y) {
    v.push_back(x);
    v.push_back(y);
    x = _find(x), y = _find(y);
    long long ret = 0;
    if (x != y) {
        ret = 1LL * sz[x] * sz[y];
        sz[y] += sz[x], uf[x] = y;
    }
    return ret;
}

int n, q, pr[250006];
vector<int> adj[250006];
bool selected[250006];

void dfs(int x, int prev = -1) {
    pr[x] = prev;
    for (auto &i: adj[x]) if (i != prev) dfs(i, x);
}

void initialize() {
    res = 0;
    for (auto &i: v) uf[i] = -1, sz[i] = 1;
    v.clear();
}

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) uf[i] = -1, sz[i] = 1;
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    dfs(0);
    scanf("%d", &q);
    while (q--) {
        initialize();
        vector<int> lt;
        int k;
        scanf("%d", &k);
        lt.resize(k);
        for (auto &i: lt) scanf("%d", &i);
        for (auto &i: lt) selected[--i] = true;
        for (auto &i: lt) if (i) if (selected[pr[i]]) res += _merge(i, pr[i]);
        for (auto &i: lt) selected[i] = false;
        printf("%lld\n", res);
    }
}

 

2. 식사 계획 세우기

서브태스크 3

 사전순으로 가장 빠른 올바른 식사 계획을 앞에서부터 세우는 전략을 생각합니다. 다음 전략을 이용해 선택한 음식을 올바른 식사 계획의 뒤에 추가하는 것을 반복하면 답을 얻을 수 있습니다.

  • 가장 많은 음식의 개수가 \(\frac{N+1}{2}\) 초과라면 불가능하므로 -1을 출력하고 종료합니다. 현재 시점에 가장 많은 음식을 가져가더라도 그 음식이 여전히 가장 많습니다.
  • 가장 많은 음식의 개수가 \(\frac{N+1}{2}\)라면 반드시 이 음식을 가져가야 합니다. 올바른 식사 계획이 사전순으로 가장 앞서야 하므로 이 음식 중 가장 앞에 있는 음식을 선택합니다.
  • 가장 많은 음식의 개수가 \(\frac{N+1}{2}\) 미만이라면 이전에 가져가지 않은 음식 종류 중 가장 앞에 있는 음식을 선택합니다.

 \(N\) 개의 음식을 하나씩 선택할 때 각각 \(O(N)\) 시간이 소요되므로 모두 \(O(N^2)\) 시간 안에 문제를 해결할 수 있습니다.

 

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n, a[300006];
bool selected[300006];
vector<int> v[300006], res;

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) {
        scanf("%d", a + i);
        v[--a[i]].push_back(i);
    }
    for (int i = 0; i < n; i++) reverse(v[i].begin(), v[i].end());
    int pv = -1;
    for (int i = 0; i < n; i++) {
        int total = n - i;
        int mx = 0, s = -1;
        for (int i = 0; i < n; i++) {
            if ((int)v[i].size() > mx) mx = (int)v[i].size(), s = i;
        }
        if (mx >= total / 2 + 1) {
            if (pv == s) return puts("-1"), 0;
            pv = s;
            res.push_back(v[s].back());
            v[s].pop_back();
            continue;
        }
        mx = (int)1e9, s = -1;
        for (int i = 0; i < n; i++) if (!v[i].empty()) {
            if (v[i].back() < mx && pv != i) mx = v[i].back(), s = i;
        }
        if (s == -1) return puts("-1"), 0;
        res.push_back(v[s].back());
        pv = s;
        v[s].pop_back();
    }
    for (auto &i: res) printf("%d ", i + 1);
}

 

서브태스크 4

 서브태스크 3의 풀이에서 \(O(\log N)\) 시간 안에 다음 음식을 선택할 수 있습니다.

  • 가장 많은 음식과 그 개수를 \(O(\log N)\) 시간 안에 구할 수 있습니다. 순서쌍 (음식의 개수, 음식의 종류)를 std::set으로 관리합니다.
  • 각 음식 중 가장 앞에 있는 음식을 \(O(1)\) 시간 안에 구할 수 있습니다. 전처리를 통해 각 음식 종류별로 큐를 관리하면 음식의 종류를 알 때 가장 앞에 있는 음식을 상수 시간 안에 알 수 있습니다.
  • 가장 앞에 있는 음식과 그 개수를 \(O(\log N)\) 시간 안에 구할 수 있습니다. 순서쌍 (가장 앞에 있는 음식의 위치, 음식의 종류)를 std::set으로 관리합니다.

 이때 이전에 선택한 음식을 다시 선택할 수 없는데, std::set의 가장 위에 있는 음식을 얻기 전 이전에 선택한 음식을 std::set에서 삭제하고, 얻은 뒤에 삽입하면 역시 \(O(\log N)\) 시간 안에 처리할 수 있습니다. 또한 음식을 선택하면 std::set에서 바뀌는 원소는 최대 하나이므로 \(O(\log N)\) 시간 안에 처리할 수 있습니다.  \(N\) 개의 음식을 하나씩 선택할 때 각각 \(O(\log N)\) 시간이 소요되므로 모두 \(O(N\log N)\) 시간 안에 문제를 해결할 수 있습니다.

 

#include <set>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int n, a[300006];
bool selected[300006];
vector<int> v[300006], res;
set<pair<int, int>> s1, s2;

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) {
        scanf("%d", a + i);
        v[--a[i]].push_back(i);
    }
    for (int i = 0; i < n; i++) {
        reverse(v[i].begin(), v[i].end());
        s1.insert({ (int)v[i].size(), i });
        if (!v[i].empty()) s2.insert({ v[i].back(), i });
    }
    int pv = -1;
    for (int i = 0; i < n; i++) {
        int total = n - i;
        int mx = 0, s = -1;
        auto it = s1.rbegin();
        mx = it->first, s = it->second;
        if (mx >= total / 2 + 1) {
            if (pv == s) return puts("-1"), 0;
            pv = s;
            s1.erase({ (int)v[s].size(), s });
            s2.erase({ v[s].back(), s });
            res.push_back(v[s].back());
            v[s].pop_back();
            s1.insert({ (int)v[s].size(), s });
            if (!v[s].empty()) s2.insert({ v[s].back(), s });
            continue;
        }
        if (pv != -1 && !v[pv].empty()) s2.erase({ v[pv].back(), pv });
        if (!s2.empty()) s = s2.begin()->second;
        else s = -1;
        if (pv != -1 && !v[pv].empty()) s2.insert({ v[pv].back(), pv });
        if (s == -1) return puts("-1"), 0;
        pv = s;
        s1.erase({ (int)v[s].size(), s });
        s2.erase({ v[s].back(), s });
        res.push_back(v[s].back());
        v[s].pop_back();
        s1.insert({ (int)v[s].size(), s });
        if (!v[s].empty()) s2.insert({ v[s].back(), s });
    }
    for (auto &i: res) printf("%d ", i + 1);
}

 

3. 레벨 업

서브태스크 1

문제에 쓰인 내용을 그대로 구현하여 \(O(MN\log N+MK)\) 풀이를 얻습니다.

 

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n, m, k, l[100006];

int main() {
    scanf("%d", &n);
    for (int i = 0; i < n; i++) scanf("%d", l + i);
    scanf("%d%d", &m, &k);
    for (int i = 0; i < m; i++) {
        sort(l, l + n);
        for (int i = 0; i < k; i++) l[i]++;
    }
    sort(l, l + n);
    for (int i = 0; i < n; i++) printf("%d ", l[i]);
}

 

서브태스크 4

 \(K=1\)이라면 레벨이 가장 낮은 캐릭터의 레벨을 하나 높이는 것을 \(M\) 회 반복합니다. 이를 최초에 레벨이 가장 낮은 캐릭터의 레벨을 하나 높이는 것을 \(M\) 회 반복하고, 이 캐릭터와 레벨이 같은 캐릭터가 있다면 그 캐릭터의 레벨을 대신 높이는 것으로 생각한다면 결국 최초에 레벨이 가장 낮은 캐릭터가 레벨을 '높여 주게' 되는 캐릭터의 묶음이 생기고, 이 묶음의 캐릭터들은 최종 레벨이 최대 1 차이가 날 것입니다. 즉, 최초에 레벨이 가장 낮은 캐릭터의 최종 레벨이 \(l\)이라면 묶음 안의 캐릭터들은 최종 레벨이 \(l\) 또는 \(l+1\)입니다. 더 나아가, 캐릭터들의 높아지는 레벨의 합은 정확히 \(M\)입니다. 이때 \(l\)의 값과 묶음을 \(O(\log^2N)\) 시간 안에 구할 수 있습니다.

 

 배열 \(L\)이 정렬되어 있다고 가정하고, \(l\)의 값이 \(m\)과 같다고 가정합시다. 또한 \(k=1\)입니다. 이제 최초 레벨이 \(\left[L_k, m\right]\)에 속하는 캐릭터들은 하나의 묶음이고, 이들의 최종 레벨은 \(m\) 또는 \(m+1\)입니다. 정확히 \(p\) 개의 캐릭터들이 이 묶음에 속한다고 하고, \(n\) 개의 캐릭터들의 최종 레벨이 \(m+1\)이라고 하면 \(p-n\) 개의 캐릭터들의 최종 레벨은 \(m\)입니다. 여기에서 캐릭터들의 높아지는 레벨의 합이 정확히 \(M\)이므로 \(M=pm+n-\sum_{i=k}^pL_i)\)이어야 합니다. 곧, \(n\)의 값은 \(M+\sum_{i=k}^pL_i-pm\)입니다.

 이제 \(m\)에 대한 이분 탐색을 적용할 수 있습니다. \(n\geq p\)라면 반드시 \(l>m\)입니다. 반면, \(n<p\)라면 \(l\leq m\)입니다. \(l\)의 값과 묶음을 구했습니다. 이때 \(p\)의 값과 \(\sum_{i=k}^pL_i\)의 값을 세그먼트 트리를 통해 \(O(\log N)\) 시간에 구할 수 있으므로 전이분 탐색의 시간 복잡도는 \(O(\log N\log X)\)입니다. \(p\)의 값은 각 \(L_i\)의 위치에 1의 값을 가지는 합 세그먼트 트리를 만들어 구할 수 있으며, \(\sum_{i=k}^pL_i\)의 값은 각 \(L_i\)의 위치에 \(L_i\)의 값을 가지는 합 세그먼트 트리를 만들어 구할 수 있습니다. 이때 \(L_i\)의 값이 매우 크므로 좌표 압축을 사용해야 합니다.

 

 \(K=1\)이라는 조건이 없을 때에도 비슷하게 해결할 수 있습니다. 최초에 레벨이 가장 낮은 \(K\) 개의 캐릭터를 선택합니다. 선택된 캐릭터의 레벨을 하나씩 높이는 것을 \(M\) 회 반복하고, 각 캐릭터와 레벨이 같은 캐릭터가 있다면 그 캐릭터의 레벨을 대신 높이는 것으로 생각합니다. 이때에는 '묶음'이 여러 개 생기며, 묶음에 속하는 캐릭터 중 선택된 캐릭터가 \(r\) 개라면 캐릭터들의 높아지는 레벨의 합은 정확히 \(rM\) 입니다. 이제 \(K=1\)일 때의 방법을 모든 선택된 캐릭터에 대해 레벨이 높은 캐릭터부터 적용합니다.

  • 선택된 캐릭터 중 레벨이 가장 높은 캐릭터에 대해서는 위의 방법을 그대로 적용합니다.
  • 선택된 캐릭터 중 레벨이 두 번째로 높은 캐릭터에 대해서는 위의 방법을 적용하되, 묶음에 이미 처리한 캐릭터가 있는 경우에는 두 묶음이 합쳐집니다.
  • 선택된 캐릭터 중 레벨이 \(K-k+1\) 번째로 높은 캐릭터에 대해서는 위의 방법을 적용하되, 묶음에 이미 처리한 캐릭터가 있는 경우에는 두 묶음이 합쳐집니다.

 이때 위의 방법을 적용하며 \(x+1\) 번째 캐릭터의 레벨이 \(k\)가 되는 것으로 판명이 났을 때, \(x\) 번째 캐릭터의 묶음 범위가 \(\left[L_x,m\right]\)이고 \(m<k\)라면 묶음에 \(x+1\) 번째 캐릭터가 포함되지 않음에 주의해야 합니다. 이는 \(x+1\) 번째 캐릭터의 묶음을 처리한 뒤에 좌표 압축 배열상에서 묶음에 속하는 캐릭터들의 좌표를 \(k\)로 갱신함으로써 해결할 수 있습니다. 이를 위해서는 또 하나의 느리게 갱신하는 세그먼트 트리가 필요하며, 이 세그먼트 트리상에서 \(O(\log N)\) 시간 안에 동작하는 lower_bound와 upper_bound를 구현해야 합니다.

 

 위 방법의 시간 복잡도는 \(O(N\log N+K\log N\log X)\)입니다. 레벨의 값이 2×10⁹까지 커질 수 있음에 주의해야 합니다. 저는 합 세그먼트 트리를 펜윅 트리로 전환했습니다.

 

#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

#define int long long

int tree1[100006];

void update1(int x, int p) {
    while (x < 100003) {
        tree1[x] += p;
        x += x & -x;
    }
}

int query1(int x) {
    int res = 0;
    while (x > 0) {
        res += tree1[x];
        x -= x & -x;
    }
    return res;
}

int query1(int l, int r) {
    return query1(r + 1) - query1(l);
}

long long tree2[100006];

void update2(int x, long long p) {
    while (x < 100003) {
        tree2[x] += p;
        x += x & -x;
    }
}

long long query2(int x) {
    long long res = 0;
    while (x > 0) {
        res += tree2[x];
        x -= x & -x;
    }
    return res;
}

long long query2(int l, int r) {
    return query2(r + 1) - query2(l);
}

int tree3[262144];
int lazy[262144];

inline void propagate(int i, int b, int e) {
    if (lazy[i] == -1) return;
    tree3[i] = lazy[i];
    if (b != e) lazy[i * 2 + 1] = lazy[i * 2 + 2] = lazy[i];
    lazy[i] = -1;
}

int query3(int i, int b, int e, int x) {
    propagate(i, b, e);
    if (x < b || e < x) return 0;
    if (b == e) return tree3[i];
    int m = (b + e) / 2;
    return query3(i * 2 + 1, b, m, x) + query3(i * 2 + 2, m + 1, e, x);
}

int query3c(int i, int b, int e, int x) {
    propagate(i, b, e);
    int m = (b + e) / 2;
    if (b == e) return b;
    else propagate(i * 2 + 1, b, m);
    if (tree3[i * 2 + 1] < x) return query3c(i * 2 + 2, m + 1, e, x);
    return query3c(i * 2 + 1, b, m, x);
}

int query3u(int i, int b, int e, int x) {
    propagate(i, b, e);
    int m = (b + e) / 2;
    if (b == e) return b;
    else propagate(i * 2 + 1, b, m);
    if (tree3[i * 2 + 1] <= x) return query3u(i * 2 + 2, m + 1, e, x);
    return query3u(i * 2 + 1, b, m, x);
}

int update3(int i, int b, int e, int l, int r, int x) {
    propagate(i, b, e);
    if (r < b || e < l) return tree3[i];
    if (l <= b && e <= r) {
        lazy[i] = max(lazy[i], x);
        propagate(i, b, e);
        return tree3[i];
    }
    int m = (b + e) / 2;
    return tree3[i] = max(update3(i * 2 + 1, b, m, l, r, x), update3(i * 2 + 2, m + 1, e, l, r, x));
}

int n, m, k, l[100006];
int pl[100006], cl[100006], to[100006];

int getc(int x) {
    return query3c(0, 0, 131071, x);
}

int getu(int x) {
    return query3u(0, 0, 131071, x);
}

signed main() {
    scanf("%lld", &n);
    for (int i = 0; i < n; i++) scanf("%lld", l + i);
    scanf("%lld%lld", &m, &k);
    sort(l, l + n);
    for (int i = 0; i < n; i++) tree3[131071 + i] = l[i];
    for (int i = n; i < 131071; i++) tree3[131071 + i] = (int)2e9 + 17;
    for (int i = 131070; i >= 0; i--) tree3[i] = max(tree3[i * 2 + 1], tree3[i * 2 + 2]);
    for (int i = 0; i < 262144; i++) lazy[i] = -1;
    for (int i = 0; i < k; i++) update1(getc(l[i]) + 1, 1);
    for (int i = 0; i < n; i++) update2(getc(l[i]) + 1, l[i]);
    for (int i = n - 1; i >= k; i--) to[i] = l[i], pl[i] = 1, cl[i] = 0;
    for (int i = k - 1; i >= 0; i--) {
        int L = l[i], R = (int)2e9 + 7;
        int getcli = getc(l[i]);
        while (L < R) {
            int M = (L + R) / 2;
            int getum = getu(M);
            int N = getum - getcli;
            long long sum = 1LL * N * M - query2(getcli, getum - 1);
            long long curr = 1LL * query1(getcli, getum - 1) * m - sum;
            if (curr >= N) L = M + 1;
            else R = M;
        }
        to[i] = L;
        int getum = getu(to[i]);
        int N = getum - getcli;
        long long sum = 1LL * N * to[i] - query2(getcli, getum - 1);
        long long curr = 1LL * query1(getcli, getum - 1) * m - sum;
        pl[i] = N, cl[i] = curr;
        if (N <= curr) return -1;
        int x = getc(L);
        if (i <= x - 1) update3(0, 0, 131071, i, min(x, n) - 1, to[i]);
    }
    for (int i = 0; i < n; i += pl[i]) {
        if (!pl[i] || pl[i] > n || cl[i] > n) return 0;
        for (int j = 0; j < pl[i] - cl[i]; j++) printf("%lld ", to[i]);
        for (int j = 0; j < cl[i]; j++) printf("%lld ", to[i] + 1);
    }
}

 

4. 외곽 순환 도로

서브태스크 3

 \(w_i\)의 값이 매우 크므로 항상 외곽 순환 도로를 이용하는 것이 이용하지 않는 것보다 손해입니다. 따라서 트리상의 간선만을 이용해야 합니다. 트리상에서 두 정점 사이의 경로는 유일하므로, 유일한 경로의 길이를 구하면 됩니다. 이는 스파스 테이블을 이용해 해결할 수 있습니다. 시간 복잡도는 \(O(N\log N)\)입니다.

 

#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
using namespace std;

int q, n, p[100006], c[100006];
int sp[20][100006], depth[100006];
long long ds[20][100006];
vector<pair<int, int>> adj[100006];

int dfs(int x) {
    int res = 0;
    if (adj[x].empty()) res++;
    for (auto &i: adj[x]) {
        depth[i.first] = depth[x] + 1;
        res += dfs(i.first);
    }
    return res;
}

long long dist(int x, int y) {
    long long res = 0;
    if (depth[x] < depth[y]) swap(x, y);
    int diff = depth[x] - depth[y];
    for (int t = 19; t >= 0; t--) if (diff >= 1 << t) {
        diff -= 1 << t;
        res += ds[t][x];
        x = sp[t][x];
    }
    for (int t = 19; t >= 0; t--) {
        if (sp[t][x] == sp[t][y] || sp[t][x] == -1 || sp[t][y] == -1) continue;
        res += ds[t][x] + ds[t][y];
        x = sp[t][x], y = sp[t][y];
    }
    if (x != y) res += ds[0][x] + ds[0][y];
    return res;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) scanf("%d%d", p + i, c + i);
    for (int i = 0; i < n; i++) p[i]--;
    for (int i = 0; i < n; i++) {
        sp[0][i] = p[i];
        ds[0][i] = c[i];
    }
    for (int i = 1; i < n; i++) adj[p[i]].push_back({ i, c[i] });
    int T = dfs(0);
    ds[0][0] = (long long)4e18;
    for (int t = 1; t < 20; t++) for (int i = 0; i < n; i++) {
        if (sp[t - 1][i] == -1) {
            sp[t][i] = -1;
            ds[t][i] = (long long)4e18;
        } else {
            sp[t][i] = sp[t - 1][sp[t - 1][i]];
            ds[t][i] = ds[t - 1][i] + ds[t - 1][sp[t - 1][i]];
        }
    }
    for (int i = 0; i < T; i++) scanf("%*d");
    scanf("%d", &q);
    while (q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        printf("%lld\n", dist(x, y));
    }
}

 

서브태스크 1

 항상 \(u=1\)이므로 다익스트라 알고리즘을 이용하여 정점 1로부터 모든 정점까지의 거리를 전처리합니다. 시간 복잡도는 \(O(N\log N)\)입니다.

 

#include <queue>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
using namespace std;

int Q, n, p[100006];
long long c[100006], w[100006], dist[100006];
int depth[100006];
vector<pair<int, long long>> adj[100006];
vector<int> leaves;

void dfs(int x, int prev = -1) {
    if ((int)adj[x].size() == 1) leaves.push_back(x);
    for (auto &i: adj[x]) if (i.first != prev) {
        depth[i.first] = depth[x] + 1;
        dfs(i.first, x);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) scanf("%d%lld", p + i, c + i);
    for (int i = 0; i < n; i++) p[i]--;
    for (int i = 1; i < n; i++) {
        adj[p[i]].push_back({ i, c[i] });
        adj[i].push_back({ p[i], c[i] });
    }
    dfs(0);
    for (int i = 0; i < (int)leaves.size(); i++) {
        scanf("%lld", w + i);
        adj[leaves[i]].push_back({ leaves[(i + 1) % (int)leaves.size()], w[i] });
        adj[leaves[(i + 1) % (int)leaves.size()]].push_back({ leaves[i], w[i] });
    }
    priority_queue<pair<long long, int>, vector<pair<long long, int>>, greater<pair<long long, int>>> q;
    q.push({ 0, 0 });
    for (int i = 1; i < n; i++) dist[i] = (long long)4e18;
    while (!q.empty()) {
        pair<long long, int> t = q.top(); q.pop();
        if (dist[t.second] < t.first) continue;
        for (auto &i: adj[t.second]) if (dist[i.first] > t.first + i.second) {
            dist[i.first] = t.first + i.second;
            q.push({ dist[i.first], i.first });
        }
    }
    scanf("%d", &Q);
    while (Q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        printf("%lld\n", dist[y]);
    }
}

 

서브태스크 2

 \(u=1\)인 경우는 서브태스크 1에서 구한 값을 그대로 이용합니다. \(u\neq1\), \(v\neq1\)인 경우만을 살펴봅시다. 이때 \(u\)와 \(v\) 사이의 경로는 정점 1을 지나는 경로와 1번 정점을 지나지 않는 경로가 있습니다. 정점 1을 지나는 경로의 최소 길이는 정점 1과 정점 \(u\) 사이의 거리와, 정점 1과 정점 \(v\) 사이의 거리의 합입니다. 정점 1을 지나지 않는 경로는 반드시 외곽 순환 도로만을 지납니다. 외곽 순환 도로의 가중치의 누적 합을 이용해 경로의 최소 길이를 구할 수 있습니다. 특정 지점으로부터 정점 \(u\)까지의 가중치의 누적 합이 \(U\), 정점 \(v\)까지의 가중치의 누적 합이 \(V\)일 때 경로의 최소 길이는 \(\min(\text{abs}(U-V),\sum w_i-\text{abs}(U-V))\)입니다. 두 경우에서 최소의 길이를 선택하면 됩니다. 시간 복잡도는 \(O(N\log N)\)입니다.

 

#include <queue>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
using namespace std;

int Q, n, p[100006];
long long c[100006], w[100006], dist[100006], a[100006];
int depth[100006];
vector<pair<int, long long>> adj[100006];
vector<int> leaves;

void dfs(int x, int prev = -1) {
    if ((int)adj[x].size() == 1) leaves.push_back(x);
    for (auto &i: adj[x]) if (i.first != prev) {
        depth[i.first] = depth[x] + 1;
        dfs(i.first, x);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) scanf("%d%lld", p + i, c + i);
    for (int i = 0; i < n; i++) p[i]--;
    for (int i = 1; i < n; i++) {
        adj[p[i]].push_back({ i, c[i] });
        adj[i].push_back({ p[i], c[i] });
    }
    dfs(0);
    long long sum = 0;
    for (int i = 0; i < (int)leaves.size(); i++) {
        scanf("%lld", w + i);
        sum += w[i];
        adj[leaves[i]].push_back({ leaves[(i + 1) % (int)leaves.size()], w[i] });
        adj[leaves[(i + 1) % (int)leaves.size()]].push_back({ leaves[i], w[i] });
    }
    priority_queue<pair<long long, int>, vector<pair<long long, int>>, greater<pair<long long, int>>> q;
    q.push({ 0, 0 });
    for (int i = 1; i < n; i++) dist[i] = (long long)4e18;
    while (!q.empty()) {
        pair<long long, int> t = q.top(); q.pop();
        if (dist[t.second] < t.first) continue;
        for (auto &i: adj[t.second]) if (dist[i.first] > t.first + i.second) {
            dist[i.first] = t.first + i.second;
            q.push({ dist[i.first], i.first });
        }
    }
    for (int i = 0; i < (int)leaves.size(); i++) {
        if (i) a[leaves[i]] = a[leaves[i - 1]];
        if (i) a[leaves[i]] += w[i - 1];
    }
    scanf("%d", &Q);
    while (Q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        if (x > y) swap(x, y);
        if (!x) {
            printf("%lld\n", dist[y]);
            continue;
        }
        long long D = a[y] - a[x];
        printf("%lld\n", min({ dist[x] + dist[y], D, sum - D }));
    }
}

 

서브태스크 4

 \(w_i=0\)이므로 외곽 순환 도로의 가중치가 없습니다. 따라서 외곽 순환 도로를 하나의 정점으로 가정하고 다익스트라 알고리즘을 이용해 각 정점에서 외곽 순환 도로까지의 거리를 구할 수 있습니다. \(u\)와 \(v\)를 연결하는 경로는 외곽 순환 도로를 지나거나 지나지 않습니다. 외곽 순환 도로를 지나는 경로의 최단 길이는 외곽 순환 도로로부터 \(u\)까지의 거리와 \(v\)까지의 거리의 합입니다. 또한 외곽 순환 도로를 지나지 않는 경로는 트리상의 간선만을 이용하므로 서브태스크 3과 같습니다. 시간 복잡도는 \(O(N\log N)\)입니다.

 

#include <queue>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
using namespace std;

int Q, n, p[100006], sp[20][100006];
long long c[100006], w[100006], dist[100006], ds[20][100006];
int depth[100006];
vector<pair<int, long long>> adj[100006];
vector<int> leaves;

void dfs(int x, int prev = -1) {
    if ((int)adj[x].size() == 1) leaves.push_back(x);
    for (auto &i: adj[x]) if (i.first != prev) {
        depth[i.first] = depth[x] + 1;
        dfs(i.first, x);
    }
}

long long tree_dist(int x, int y) {
    long long res = 0;
    if (depth[x] < depth[y]) swap(x, y);
    int diff = depth[x] - depth[y];
    for (int t = 19; t >= 0; t--) if (diff >= 1 << t) {
        diff -= 1 << t;
        res += ds[t][x];
        x = sp[t][x];
    }
    for (int t = 19; t >= 0; t--) {
        if (sp[t][x] == sp[t][y] || sp[t][x] == -1 || sp[t][y] == -1) continue;
        res += ds[t][x] + ds[t][y];
        x = sp[t][x], y = sp[t][y];
    }
    if (x != y) res += ds[0][x] + ds[0][y];
    return res;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) scanf("%d%lld", p + i, c + i);
    for (int i = 0; i < n; i++) p[i]--;
    for (int i = 1; i < n; i++) {
        adj[p[i]].push_back({ i, c[i] });
        adj[i].push_back({ p[i], c[i] });
    }
    for (int i = 0; i < n; i++) {
        sp[0][i] = p[i];
        ds[0][i] = c[i];
    }
    dfs(0);
    ds[0][0] = (long long)4e18;
    for (int t = 1; t < 20; t++) for (int i = 0; i < n; i++) {
        if (sp[t - 1][i] == -1) {
            sp[t][i] = -1;
            ds[t][i] = (long long)4e18;
        } else {
            sp[t][i] = sp[t - 1][sp[t - 1][i]];
            ds[t][i] = ds[t - 1][i] + ds[t - 1][sp[t - 1][i]];
        }
    }
    for (int i = 0; i < (int)leaves.size(); i++) {
        scanf("%lld", w + i);
        adj[leaves[i]].push_back({ leaves[(i + 1) % (int)leaves.size()], w[i] });
        adj[leaves[(i + 1) % (int)leaves.size()]].push_back({ leaves[i], w[i] });
    }
    priority_queue<pair<long long, int>, vector<pair<long long, int>>, greater<pair<long long, int>>> q;
    for (int i = 0; i < n; i++) dist[i] = (long long)4e18;
    for (auto &i: leaves) q.push({ 0, i }), dist[i] = 0;
    while (!q.empty()) {
        pair<long long, int> t = q.top(); q.pop();
        if (dist[t.second] < t.first) continue;
        for (auto &i: adj[t.second]) if (dist[i.first] > t.first + i.second) {
            dist[i.first] = t.first + i.second;
            q.push({ dist[i.first], i.first });
        }
    }
    scanf("%d", &Q);
    while (Q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        x--, y--;
        if (x > y) swap(x, y);
        printf("%lld\n", min({ dist[x] + dist[y], tree_dist(x, y) }));
    }
}

 

서브태스크 6

 트리를 센트로이드 분할하며 각 센트로이드에 대해 다익스트라 알고리즘을 이용한다고 합니다. 자세한 풀이가 완성되면 갱신하겠습니다.