題目
521. 運輸計劃
算法標簽: 樹上倍增, l c a lca lca, 前綴和, 樹上差分, 二分
思路
注意到答案是具有二分性質的, 對于某個時間 m i d mid mid假設是最優答案, 小于該時間是不可以的, 但是大于該時間是可行的, 因此可以二分答案
這樣就將問題轉化為, 對于給定的時間 m i d mid mid, 將樹中的一條邊權變為 0 0 0, 所有的運輸路線耗時是否 ≤ m i d \le mid ≤mid
可以將所有運輸的路線分為兩類, 一種是運輸時間 ≤ m i d \le mid ≤mid的, 這種路線不要需要刪除邊
但是還有一種路線是 > m i d > mid >mid, 對于這些路線需要找個這些路線的公共邊, 將這個公共邊的權值變為 0 0 0, 但是直接枚舉所有的邊和路線會超時, 因此需要進行優化
可以在所有路線上的邊 + 1 + 1 +1, 最終結果就是公共邊被加了 t t t次, t t t是大于 m i d mid mid的路線的數量, 這樣就找到了這個邊, 利用樹上差分, 實現對每個邊 + 1 +1 +1的操作
代碼
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>using namespace std;const int N = 300010, M = N << 1, K = 19;int n, m;
int head[N], ed[M], ne[M], w[M], idx;
int fa[N][K], depth[N], d[N];
struct Path {int u, v, p, d;
} path[N];
int s[N];void add(int u, int v, int val) {ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}void dfs(int u, int pre, int dep) {depth[u] = dep;for (int i = head[u]; ~i; i = ne[i]) {int v = ed[i];if (v == pre) continue;fa[v][0] = u;for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];d[v] = d[u] + w[i];dfs(v, u, dep + 1);}
}int lca(int u, int v) {if (depth[u] < depth[v]) swap(u, v);for (int k = K - 1; k >= 0; --k) {if (depth[fa[u][k]] >= depth[v]) {u = fa[u][k];}}if (u == v) return v;for (int k = K - 1; k >= 0; --k) {if (fa[u][k] != fa[v][k]) {u = fa[u][k];v = fa[v][k];}}return fa[u][0];
}void dfs_sum(int u, int pre) {for (int i = head[u]; ~i; i = ne[i]) {int v = ed[i];if (v == pre) continue;dfs_sum(v, u);s[u] += s[v];}
}bool check(int mid) {memset(s, 0, sizeof s);int c = 0, max_d = 0;for (int i = 0; i < m; ++i) {auto [u, v, p, val] = path[i];if (val > mid) {c++;max_d = max(max_d, val);s[u]++;s[v]++;s[p] -= 2;}}if (c == 0) return true;dfs_sum(1, -1);for (int u = 2; u <= n; ++u) {if (s[u] == c && max_d - (d[u] - d[fa[u][0]]) <= mid) {return true;}}return false;
}int main() {ios::sync_with_stdio(false);cin.tie(0), cout.tie(0);memset(head, -1, sizeof head);cin >> n >> m;for (int i = 0; i < n - 1; ++i) {int u, v, w;cin >> u >> v >> w;add(u, v, w), add(v, u, w);}dfs(1, -1, 1);for (int i = 0; i < m; ++i) {int u, v;cin >> u >> v;int p = lca(u, v);int dis = d[u] + d[v] - 2 * d[p];path[i] = {u, v, p, dis};}int l = 0, r = 3e8;while (l < r) {int mid = l + r >> 1;if (check(mid)) r = mid;else l = mid + 1;}cout << l << "\n";return 0;
}
* v e c t o r vector vector存鄰接表會超時
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>using namespace std;typedef pair<int, int> PII;
const int N = 300010, M = N << 1, K = 19;int n, m;
vector<PII> head[N];
int fa[N][K], depth[N], d[N];
struct Path {int u, v, p, d;
};
vector<Path> path;
int s[M];void init() {path.resize(m + 1);
}void add(int u, int v, int w) {head[u].push_back({v, w});
}void dfs(int u, int pre, int dep) {depth[u] = dep;for (auto [v, w] : head[u]) {if (v == pre) continue;fa[v][0] = u;for (int k = 1; k < K; ++k) fa[v][k] = fa[fa[v][k - 1]][k - 1];d[v] = d[u] + w;dfs(v, u, dep + 1);}
}int lca(int u, int v) {if (depth[u] < depth[v]) swap(u, v);for (int k = K - 1; k >= 0; --k) {if (depth[fa[u][k]] >= depth[v]) {u = fa[u][k];}}if (u == v) return u;for (int k = K - 1; k >= 0; --k) {if (fa[u][k] != fa[v][k]) {u = fa[u][k];v = fa[v][k];}}return fa[u][0];
}void dfs_sum(int u, int fa) {for (auto [v, w] : head[u]) {if (v == fa) continue;dfs_sum(v, u);s[u] += s[v];}
}bool check(int mid) {memset(s, 0, sizeof s);int cnt = 0, max_d = 0;for (auto [u, v, p, dis] : path) {if (dis > mid) {cnt++;s[u]++;s[v]++;s[p] -= 2;max_d = max(max_d, dis);}}if (cnt == 0) return true;dfs_sum(1, -1);for (int u = 2; u <= n; ++u) {if (s[u] == cnt && max_d - (d[u] - d[fa[u][0]]) <= mid) return true;}return false;
}int main() {ios::sync_with_stdio(false);cin.tie(0), cout.tie(0);cin >> n >> m;init();for (int i = 0; i < n - 1; ++i) {int u, v, w;cin >> u >> v >> w;add(u, v, w), add(v, u, w);}dfs(1, -1, 1);for (int i = 0; i < m; ++i) {int u, v;cin >> u >> v;int p = lca(u, v);path[i] = {u, v, p, d[u] + d[v] - 2 * d[p]};}int l = 0, r = 3e8;while (l < r) {int mid = l + r >> 1;if (check(mid)) r = mid;else l = mid + 1;}cout << l << "\n";return 0;
}