題目鏈接: 三元組最小距離
定義三元組 $(a, b, c)$($a,b,c$ 均為整數)的距離 $D=|a-b|+|b-c|+|c-a|$。
給定 $3$ 個非空整數集合 $S_1, S_2, S_3$,按升序分別存儲在 $3$ 個數組中。
請設計一個盡可能高效的算法,計算并輸出所有可能的三元組 $(a, b, c)$($a \in S_1,b \in S_2,c \in S_3$)中的最小距離。
例如 $S_1=\{-1, 0, 9\}, S_2=\{-25, -10, 10, 11\}, S_3=\{2, 9, 17, 30, 41\}$ 則最小距離為 $2$,相應的三元組為 $(9,10,9)$。
輸入格式
第一行包含三個整數 $l,m,n$,分別表示 $S_1,S_2,S_3$ 的長度。
第二行包含 l 個整數,表示 $S_1$ 中的所有元素。
第三行包含 $m$ 個整數,表示 $S_2$ 中的所有元素。
第四行包含 $n$ 個整數,表示 $S_3$ 中的所有元素。
以上三個數組中的元素都是按升序順序給出的。
輸出格式
輸出三元組的最小距離。
數據范圍
$1 \le l,m,n \le 10^5$,
所有數組元素的取值范圍 $[-10^9,10^9]$。
輸入樣例:
3 4 5
-1 0 9
-25 -10 10 11
2 9 17 30 41
輸出樣例:
2
-
暴力想法: 枚舉所有可能的答案排列 從小到大 [S1 S2 S3], [S3 S2 S1], [S2, S3, S1], [S1, S3, S2], [S2, S1, S3], [S3, S1, S2], 如果單純暴力復雜度就是O(6n^3) 鐵定過不了,這時候我們可以選擇枚舉中間值屬于哪個素組, 在每個枚舉中我們已經確定了中間的數,那么我們就可以根據這個中間值在另外兩個數組中二分找到符合排列的數,這里又有兩個前后問題,我們還是直接枚舉比較,比如確定中間的數來自于S2,那么答案排列可能是S1,S2,S3, 或者 S3, S2, S1這樣算下來總的時間復雜度O(3nlogn).
-
代碼
#include<bits/stdc++.h>
using namespace std;const int N = 1e5 + 10;
typedef long long LL;int main()
{int l, n, m; cin >> l >> n >> m;vector<int> a(l), b(n), c(m);for(int i = 0; i < l; i ++) cin >> a[i];for(int i = 0; i < n; i ++) cin >> b[i];for(int i = 0; i < m; i ++) cin >> c[i];// a中元素當中間值LL ans = 1e12;for(auto it : a){//b a cint dexb = upper_bound(b.begin(), b.end(), it) - b.begin();if(dexb != 0) dexb --;int dexc = lower_bound(c.begin(), c.end(), it) - c.begin();if(dexc == c.size()) dexc --;int xa = it, xb = b[dexb], xc = c[dexc];ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));//c a bdexc = upper_bound(c.begin(), c.end(), it) - c.begin();if(dexc != 0) dexc --;dexb = lower_bound(b.begin(), b.end(), it) - b.begin();if(dexb == b.size()) dexb --;xa = it, xb = b[dexb], xc = c[dexc];ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));}// b 當中間值for(auto it : b){//a b cint dexa = upper_bound(a.begin(), a.end(), it) - a.begin();if(dexa != 0) dexa --;int dexc = lower_bound(c.begin(), c.end(), it) - c.begin();if(dexc == c.size()) dexc --;int xb = it, xa = a[dexa], xc = c[dexc];ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));//c b adexc = upper_bound(c.begin(), c.end(), it) - c.begin();if(dexc != 0) dexc --;dexa = lower_bound(a.begin(), a.end(), it) - a.begin();if(dexa == a.size()) dexa --;xb = it, xa = a[dexa], xc = c[dexc];ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));}// c 當中間值for(auto it : c){//a c bint dexa = upper_bound(a.begin(), a.end(), it) - a.begin();if(dexa != 0) dexa --;int dexb = lower_bound(b.begin(), b.end(), it) - b.begin();if(dexb == b.size()) dexb --;int xc = it, xa = a[dexa], xb = b[dexb];ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));//b c adexb = upper_bound(b.begin(), b.end(), it) - b.begin();if(dexb != 0) dexb --;dexa = lower_bound(a.begin(), a.end(), it) - a.begin();if(dexa == a.size()) dexa --;xc = it, xa = a[dexa], xb = b[dexb];ans = min(ans, 1ll*abs(xa - xb) + abs(xa - xc) + abs(xb - xc));}cout << ans << endl;return 0;
}
- 另一個思路:滑動窗口 O(3n*log(3n))實現
我們可以標記好每個數來自哪個數組然后統一排序,滑動窗口找相鄰三個不同歸屬的數進行答案比較,這個和上面的相比實現更簡單些! - 代碼
#include<bits/stdc++.h>
using namespace std;const int N = 1e5 + 10;
typedef long long LL;int main()
{int l, n, m; cin >> l >> n >> m;// 可以結構體數組,或者pair 存數值與所屬關系,用multimap純屬個人偷懶行為multimap<int, int> cnt;for(int i = 0; i < l; i ++) {int x; cin >> x;cnt.insert(pair<int, int>(x, 1));}for(int i = 0; i < n; i ++){int x; cin >> x;cnt.insert(pair<int, int>(x, 2));}for(int i = 0; i < m; i ++){int x; cin >> x;cnt.insert(pair<int, int>(x, 3));}LL ans = 1e18;vector<LL> st(4, 1e10);// 將st1,2,3賦值1e10表示空for(auto [a,b] : cnt){st[b] = a;if(st[1] != 1e10 && st[2] != 1e10 && st[3] != 1e10){ans = min(ans, 1ll*abs(st[1]-st[2]) + abs(st[1]-st[3]) + abs(st[2]-st[3]));}}cout << ans << endl;return 0;
}
- 進階思路:O(3n)實現 三路歸并
假設x < y < z 我們化簡 |x-y|+|y-z|+|z-x| 后發現 我們每一次算出的答案都是2*(max-min)
所以我們只需要每個三元組中的最大值與最小值即可,所以我們盡可能讓max與min逼近
這時候就可以三路歸并,每次只讓最小的去靠近最大的值,實現也很簡單。 - 代碼
#include<bits/stdc++.h>
using namespace std;const int N = 1e5 + 10;
typedef long long LL;int main()
{int l, n, m; cin >> l >> n >> m;vector<int> a(l), b(n), c(m);for(int i = 0; i < l; i ++) cin >> a[i];for(int i = 0; i < n; i ++) cin >> b[i];for(int i = 0; i < m; i ++) cin >> c[i];LL ans = 1e18;for(int i = 0, j = 0, k = 0; i < l && j < n && k < m;){int x = a[i], y = b[j], z = c[k];ans = min(ans, 2*(1ll*max(x, max(y, z)) - min(x, min(y, z))));if(x <= y && x <= z) i ++;else if(y <= x && y <= z) j ++;else k ++;}cout << ans << endl;return 0;
}