根據貪心,不難想到每次會把最長隊伍末尾的那輛車移動到最短隊伍的末尾。但由于 k k k 的存在,會導致一些冗余移動的存在。設需要挪動 C C C 輛車,則怒氣值可以表示為 f ( C ) + k C f(C) + kC f(C)+kC,其中 f ( C ) f(C) f(C) 是排隊所產生的怒氣值, k C kC kC 為變道產生的額外怒氣值。仔細分析以后,可以發現這是一個凸函數,因此考慮三分答案。
一開始想要三分需要挪車的最短長度 y y y,但是不能忽略 k k k 的影響,有些隊伍的長度雖然 > y > y >y,但挪動不移動會更優。于是三分挪動車輛的數量才是最優的。
具體來說,可以枚舉哪些隊伍的車輛會減少/增加。若現在考慮會減少的隊伍的車輛,給 a i a_i ai? 排序后,設當前最長隊伍的車輛數為 x x x,次長的為 y y y ( x ≠ y x \neq y x=y),然后長度為 x , y x,y x,y 的隊伍的數量分別為 f x , f y f_x,f_y fx?,fy?。若共需要移動 C C C 輛車,則有兩種情況:
-
C ≥ ( x ? y ) × f x C \ge (x - y) \times f_x C≥(x?y)×fx?,也就是說長度為 x x x 的車可以直接變為 y y y, C ← C ? ( x ? y ) × f x ; f y ← f x + f y ; f x ← 0 C \leftarrow C - (x - y) \times f_x;\ f_y \leftarrow f_x + f_y;\ f_x \leftarrow 0 C←C?(x?y)×fx?;?fy?←fx?+fy?;?fx?←0。
-
C < ( x ? y ) × f x C < (x - y) \times f_x C<(x?y)×fx?,此時會產生新的隊伍長度,也就是 C ← 0 ; f x ? ? C f x ? ? 1 ← f x ? ? C f x ? ? 1 + C m o d f x ; ← f x ? ? C f x ? + ( f x ? C m o d f x ) C \leftarrow 0;\ f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} + C \bmod f_x;\ \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor} + (f_x - C \bmod f_x) C←0;?fx??fx?C???1?←fx??fx?C???1?+Cmodfx?;?←fx??fx?C???+(fx??Cmodfx?)。
可以發現最后隊伍長度的種類數不會超過 n + 2 n + 2 n+2,因此這是 O ( n ) O(n) O(n) 的。考慮增加的隊伍的車輛同理,用 STL 來寫會簡單一點。但是由于多了一支 log ? \log log,實測會超時:
ll tot = sum * k,res = sum,number = sum;
set <int> s;map <int,int> bg,sm;
s.insert (-1e9);
for (int i = 1;i <= n;++i) s.insert (a[i]),++bg[a[i]];
while (sum)
{int x = *(--s.end ()),num = bg[x];s.erase (x);int y = *(--s.end ());if (sum >= 1ll * (x - y) * num){sum -= 1ll * (x - y) * num;bg[y] += num;bg[x] = 0;}else {bg[x] = 0;int tmp = sum % num;if (tmp) bg[x - sum / num - 1] += tmp;bg[x - sum / num] += num - tmp;sum = 0;}
}
s.clear ();
for (auto [x,num] : bg)if (num) s.insert (x),sm[x] = num;
s.insert (1e9);
while (res)
{int x = *s.begin (),num = sm[x];s.erase (x);int y = *s.begin ();if (res >= 1ll * (y - x) * num){res -= 1ll * (y - x) * num;sm[y] += num;sm[x] = 0;}else{sm[x] = 0;int tmp = res % num;if (tmp) sm[x + res / num + 1] += tmp;sm[x + res / num] += num - tmp;res = 0;}
}
for (auto [x,num] : sm) tot += 1ll * x * (x + 1) / 2 * num;
return tot;
};
再次思考可以發現 STL 的 log ? \log log 完全是多余的,可以通過數組來替代,但需要小心清空與去重的問題。最后的 AC 代碼如下,時間復雜度 O ( n log ? n ) O(n \log n) O(nlogn):
#include <bits/stdc++.h>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 2e18
#define pii pair <int,int>
using namespace std;
const int MAX = 2e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int a[MAX],b[MAX];
vector <int> bg (1000001,0),sm (1000001,0);
void solve ()
{int n = read (),k = read ();ll ans = INF;for (int i = 1;i <= n;++i) a[i] = read ();sort (a + 1,a + 1 + n);auto check = [&] (ll sum) -> ll{ll tot = sum * k,res = sum;int cnt = 0;vector <int> p;for (int i = 1;i <= n;++i) p.push_back (a[i]);for (int i = 1;i <= n;++i) {if (!bg[a[i]]) b[++cnt] = a[i];++bg[a[i]];}b[0] = -1e9;while (sum > 0){int x = b[cnt--],num = bg[x];int y = b[cnt];if (sum >= 1ll * (x - y) * num){sum -= 1ll * (x - y) * num;bg[y] += num;bg[x] = 0;}else {bg[x] = 0;int tmp = sum % num;bg[x - sum / num] += num - tmp,p.push_back (x - sum / num);if (tmp) bg[x - sum / num - 1] += tmp,p.push_back (x - sum / num - 1);sum = 0;}}cnt = 0;for (auto v : p)if (bg[v]) b[++cnt] = v,sm[v] = bg[v],bg[v] = 0;p.clear ();for (int i = 1;i <= cnt;++i) p.push_back (b[i]);b[++cnt] = 1e9;cnt = 1;while (res > 0){int x = b[cnt++],num = sm[x];int y = b[cnt];if (res >= 1ll * (y - x) * num){res -= 1ll * (y - x) * num;sm[y] += num;sm[x] = 0;}else{sm[x] = 0;int tmp = res % num;if (tmp) sm[x + res / num + 1] += tmp,p.push_back (x + res / num + 1);sm[x + res / num] += num - tmp,p.push_back (x + res / num);res = 0;}}for (auto v : p) tot += 1ll * v * (v + 1) / 2 * sm[v],sm[v] = 0;return tot;};ll l = 0,r = accumulate (a + 1,a + n + 1,0ll);while (l < r){ll midl = l + (r - l) / 3,midr = r - (r - l) / 3;ll v1 = check (midl),v2 = check (midr);ans = min (ans,min (v1,v2));if (v1 <= v2) r = midr - 1;else l = midl + 1;}printf ("%lld\n",ans);
}
int main ()
{int t = read ();while (t--) solve ();return 0;
}
inline int read ()
{int s = 0;int f = 1;char ch = getchar ();while ((ch < '0' || ch > '9') && ch != EOF){if (ch == '-') f = -1;ch = getchar ();}while (ch >= '0' && ch <= '9'){s = s * 10 + ch - '0';ch = getchar ();}return s * f;
}