比賽鏈接
C. 變化的數組(Easy Version)
題目大意
一個長度為 n n n 的非負數組 a a a,要求執行 k k k 次操作,每次操作如下:
- 有 1 2 \frac{1}{2} 21? 的概率令 a i ← a i + ( a i ? m ) + x , ? i ∈ [ 1 , n ] a_i \leftarrow a_i + (a_i \otimes m) + x, \ \forall i \in [1, n] ai?←ai?+(ai??m)+x,??i∈[1,n];
- 另 1 2 \frac{1}{2} 21? 的概率保持 ? a i \forall a_i ?ai? 不變。
求 ∑ i = 1 n a i \sum\limits_{i = 1}^{n}a_i i=1∑n?ai? 的期望,答案對 998244353 998244353 998244353 取模。
其中 ? \otimes ? 表示按位與,例如 ( 10 ) 2 ? ( 11 ) 2 = ( 10 ) 2 , ( 01 ) 2 ? ( 10 ) 2 = 0 (10)_2 \otimes (11)_2 = (10)_2, \ (01)_2 \otimes (10)_2 = 0 (10)2??(11)2?=(10)2?,?(01)2??(10)2?=0。
數據范圍
- 1 ≤ n ≤ 1 0 6 , 1 \leq n \leq 10^6, 1≤n≤106,
- 1 ≤ m , k ≤ 5 ? 1 0 3 , 1 \leq m, k \leq 5 \cdot 10^3, 1≤m,k≤5?103,
- 0 ≤ x ≤ 1 0 5 , 0 \leq x \leq 10^5, 0≤x≤105,
- 0 ≤ a i ≤ 1 0 9 . 0 \leq a_i \leq 10^9. 0≤ai?≤109.
Solution
我們觀察 a i a_i ai? 的增量 ( a i ? m ) + x (a_i \otimes m) + x (ai??m)+x,發現除了給定的 x x x, ( a i ? m ) (a_i \otimes m) (ai??m) 只與后 ? log ? 2 m ? \lfloor \log_2 m \rfloor ?log2?m? 位有關,于是記 M = 2 ? log ? 2 m ? + 1 , M = 2^{\lfloor \log_2 m \rfloor + 1}, M=2?log2?m?+1,
這樣一來我們只需要知道 a i ? ( M ? 1 ) a_i \otimes (M - 1) ai??(M?1) 就能知道 a i a_i ai? 的增量。
于是我們把每個 a i a_i ai? 劃分為兩部分,分別是 ? a i M ? \lfloor \frac{a_i}{M} \rfloor ?Mai??? 和 a i ? ( M ? 1 ) a_i \otimes (M - 1) ai??(M?1),我們稱其為高位和低位。
接下來我們就分別求高位和低位的期望 h i s \rm{his} his 和 l o s \rm{los} los,最終的答案就是 h i s × M + l o s \rm{his \times M + los} his×M+los。
對于低位來說,我們可以構造一個轉換表 s u f \rm{suf} suf,其中 s u f [ v ] = ( v + ( v ? m ) + x ) ? ( M ? 1 ) , v ∈ [ 0 , M ) . \rm{suf[v]} = (v + (v \otimes m) + x) \otimes (M - 1), \ v \in [0, M). suf[v]=(v+(v?m)+x)?(M?1),?v∈[0,M).
這樣就可以求出低位和的期望 l o s \rm{los} los。
- 假設有 j j j 次操作讓 a i a_i ai? 發生改變,現在已經求出 j ? 1 j - 1 j?1 次改變時每個低位值的個數,記為 c n t j ? 1 [ v ] cnt_{j - 1}[v] cntj?1?[v],其中 v ∈ [ 0 , M ) v \in [0, M) v∈[0,M),那么只要對 ? v ∈ [ 0 , M ) \forall v \in [0, M) ?v∈[0,M) 做一次 s u f \rm{suf} suf 變換就可以得到新的 v ′ v' v′ 以及 c n t j [ v ′ ] cnt_j[v'] cntj?[v′] 了,具體來說就是 c n t j [ s u f [ v ] ] : = c n t j [ s u f [ v ] ] + c n t j ? 1 [ v ] . \rm{cnt_j[suf[v]]} := cnt_j[suf[v]] + cnt_{j - 1}[v]. cntj?[suf[v]]:=cntj?[suf[v]]+cntj?1?[v]. 再對每個 c n t j [ v ′ ] × v ′ cnt_j[v'] \times v' cntj?[v′]×v′ 乘上 j j j 次改變的概率 ( 1 2 ) k ( k j ) , \left(\frac{1}{2}\right)^k{k \choose j}, (21?)k(jk?), 最后求和就是期望。
- 初始值 c n t 0 [ v ] cnt_0[v] cnt0?[v] 只要對 ? a i \forall a_i ?ai? 記錄 a i ? ( M ? 1 ) a_i \otimes (M - 1) ai??(M?1) 的數量即可。
高位就稍微復雜一些。
我們模仿低位,構造一個高位映射表 p r e \rm{pre} pre,其中 p r e [ v ] = ? ( v + ( v ? m ) + x ) M ? , v ∈ [ 0 , M ) . \rm{pre[v]} = \lfloor \frac{(v + (v \otimes m) + x)}{M} \rfloor, \ v \in [0, M). pre[v]=?M(v+(v?m)+x)??,?v∈[0,M).
對于高位來說,我們不用期望的原始公式 E ( X ) = ∑ i = 1 N p i X i , E(X) = \sum\limits_{i = 1}^{N}p_i X_i, E(X)=i=1∑N?pi?Xi?, 而是選擇一個基準 B B B,對其變形得到 E ( X ) = ∑ i = 1 N p i ( X i ? B + B ) = B + ∑ i = 1 N p i ( X i ? B ) . E(X) = \sum\limits_{i = 1}^{N}p_i (X_i - B + B) = B + \sum\limits_{i = 1}^{N}p_i (X_i - B). E(X)=i=1∑N?pi?(Xi??B+B)=B+i=1∑N?pi?(Xi??B). 其中 ( X i ? B ) (X_i - B) (Xi??B) 是每個隨機變量取值 X i X_i Xi? 相對于 B B B 的增量。
在高位上我們選擇的基準 B = ∑ i = 1 n ? ( a i + ( a i ? m ) + x ) M ? . B = \sum\limits_{i = 1}^{n}\lfloor \frac{(a_i + (a_i \otimes m) + x)}{M} \rfloor. B=i=1∑n??M(ai?+(ai??m)+x)??.
接下來就是算高位 增量 的期望了。
我們先給出求和式。
∑ j = 0 k ( 1 2 ) k ( k j ) ∑ i = 0 j ? 1 ∑ v = 0 M ? 1 c n t i [ v ] × p r e [ v ] , \sum\limits_{j = 0}^{k}\left( \frac{1}{2} \right)^k {k \choose j} \sum\limits_{i = 0}^{j - 1}\sum\limits_{v = 0}^{M - 1}cnt_i[v] \times pre[v], j=0∑k?(21?)k(jk?)i=0∑j?1?v=0∑M?1?cnti?[v]×pre[v],
上式中 ( 1 2 ) k ( k j ) \left( \frac{1}{2} \right)^k {k \choose j} (21?)k(jk?) 表示對數組 a a a 做了 j j j 次改變的概率,后面的兩重循環是求從開始到改變 j j j 次的增量和。
對于 j j j,之所以我們要遍歷 i ∈ [ 0 , j ) i \in [0, j) i∈[0,j),是因為 c n t i [ v ] cnt_i[v] cnti?[v] 是不斷變化的;而低位不需要這樣遍歷則是因為它不用求增量,可以直接獲得值。
但是這個三重循環的復雜度我們無法接受,所以考慮交換求和次序。
∑ j = 0 k ( 1 2 ) k ( k j ) ∑ i = 0 j ? 1 ∑ v = 0 M ? 1 c n t i [ v ] × p r e [ v ] = ∑ i = 0 k ∑ j = i + 1 k ∑ v = 0 M ? 1 ( 1 2 ) k ( k j ) × c n t i [ v ] × p r e [ v ] = ∑ i = 0 k ∑ v = 0 M ? 1 c n t i [ v ] × p r e [ v ] ∑ j = i + 1 k ( 1 2 ) k ( k j ) = ∑ i = 0 k ∑ v = 0 M ? 1 c n t i [ v ] × p r e [ v ] × s [ j ] . \begin{align*} &\sum\limits_{j = 0}^{k}\left( \frac{1}{2} \right)^k {k \choose j} \sum\limits_{i = 0}^{j - 1}\sum\limits_{v = 0}^{M - 1}cnt_i[v] \times pre[v] \\ &= \sum\limits_{i = 0}^{k}\sum\limits_{j = i + 1}^{k}\sum\limits_{v = 0}^{M - 1}\left( \frac{1}{2} \right)^k {k \choose j} \times cnt_i[v] \times pre[v] \\ &= \sum\limits_{i = 0}^{k}\sum\limits_{v = 0}^{M - 1}cnt_i[v] \times pre[v] \sum\limits_{j = i + 1}^{k}\left( \frac{1}{2} \right)^k {k \choose j} \\ &= \sum\limits_{i = 0}^{k}\sum\limits_{v = 0}^{M - 1}cnt_i[v] \times pre[v] \times s[j]. \end{align*} ?j=0∑k?(21?)k(jk?)i=0∑j?1?v=0∑M?1?cnti?[v]×pre[v]=i=0∑k?j=i+1∑k?v=0∑M?1?(21?)k(jk?)×cnti?[v]×pre[v]=i=0∑k?v=0∑M?1?cnti?[v]×pre[v]j=i+1∑k?(21?)k(jk?)=i=0∑k?v=0∑M?1?cnti?[v]×pre[v]×s[j].?
其中 s [ j ] = ∑ j = i + 1 k ( 1 2 ) k ( k j ) . s[j] = \sum\limits_{j = i + 1}^{k}\left( \frac{1}{2} \right)^k {k \choose j}. s[j]=j=i+1∑k?(21?)k(jk?).
這樣就把復雜度降到 O ( M k ) O(Mk) O(Mk) 了。
時間復雜度 O ( m k ) O(mk) O(mk)
- 雖然說 M = 2 ? log ? 2 m ? + 1 M = 2^{\lfloor \log_2 m\rfloor} + 1 M=2?log2?m?+1,但是量級上最多是 2 m 2m 2m, 2 2 2 可以看作常數。
C++ Code
#include <bits/stdc++.h>using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;template<class T>
constexpr T power(T a, i64 b) {T res = 1;for (; b; b /= 2, a *= a) {if (b % 2) {res *= a;}}return res;
}
template<int P>
struct MInt {int x;constexpr MInt() : x{} {}constexpr MInt(i64 x) : x{norm(x % getMod())} {}static int Mod;constexpr static int getMod() {if (P > 0) {return P;} else {return Mod;}}constexpr static void setMod(int Mod_) {Mod = Mod_;}constexpr int norm(int x) const {if (x < 0) {x += getMod();}if (x >= getMod()) {x -= getMod();}return x;}constexpr int val() const {return x;}explicit constexpr operator int() const {return x;}constexpr MInt operator-() const {MInt res;res.x = norm(getMod() - x);return res;}constexpr MInt inv() const {assert(x != 0);return power(*this, getMod() - 2);}constexpr MInt &operator*=(MInt rhs) & {x = 1LL * x * rhs.x % getMod();return *this;}constexpr MInt &operator+=(MInt rhs) & {x = norm(x + rhs.x);return *this;}constexpr MInt &operator-=(MInt rhs) & {x = norm(x - rhs.x);return *this;}constexpr MInt &operator/=(MInt rhs) & {return *this *= rhs.inv();}friend constexpr MInt operator*(MInt lhs, MInt rhs) {MInt res = lhs;res *= rhs;return res;}friend constexpr MInt operator+(MInt lhs, MInt rhs) {MInt res = lhs;res += rhs;return res;}friend constexpr MInt operator-(MInt lhs, MInt rhs) {MInt res = lhs;res -= rhs;return res;}friend constexpr MInt operator/(MInt lhs, MInt rhs) {MInt res = lhs;res /= rhs;return res;}friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {i64 v;is >> v;a = MInt(v);return is;}friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) {return os << a.val();}friend constexpr bool operator==(MInt lhs, MInt rhs) {return lhs.val() == rhs.val();}friend constexpr bool operator!=(MInt lhs, MInt rhs) {return lhs.val() != rhs.val();}
};template<>
int MInt<0>::Mod = 998244353;template<int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();constexpr int P = 998244353;
using Z = MInt<P>;struct Comb {int n;std::vector<Z> _fac;std::vector<Z> _invfac;std::vector<Z> _inv;Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}Comb(int n) : Comb() {init(n);}void init(int m) {m = std::min(m, Z::getMod() - 1);if (m <= n) return;_fac.resize(m + 1);_invfac.resize(m + 1);_inv.resize(m + 1);for (int i = n + 1; i <= m; i += 1) {_fac[i] = _fac[i - 1] * i;}_invfac[m] = _fac[m].inv();for (int i = m; i > n; i -= 1) {_invfac[i - 1] = _invfac[i] * i;_inv[i] = _invfac[i] * _fac[i - 1];}n = m;}Z fac(int m) {if (m > n) init(2 * m);return _fac[m];}Z invfac(int m) {if (m > n) init(2 * m);return _invfac[m];}Z inv(int m) {if (m > n) init(2 * m);return _inv[m];}Z binom(int n, int m) {if (n < m || m < 0) {return 0;}return fac(n) * invfac(m) * invfac(n - m);}Z Lucas(i64 n, i64 m, int p) {if (n < p and m < p) {return binom(n, m);}return Lucas(n / p, m / p, p) * binom(n % p, m % p);}Z Lucas(i64 n, i64 m) {if (n < Z::getMod() and m < Z::getMod()) {return binom(n, m);}return Lucas(n / Z::getMod(), m / Z::getMod()) * binom(n % Z::getMod(), m % Z::getMod()); }Z perm(int n, int m) {if (n < m or m < 0) {return 0;}return fac(n) * invfac(n - m);}
} comb;template<class T>
std::istream &operator>>(std::istream &is, std::vector<T> &v) {for (auto &x: v) {is >> x;}return is;
}int main() {std::ios::sync_with_stdio(false);std::cin.tie(nullptr);int n, x, m, k;std::cin >> n >> x >> m >> k;std::vector<int> a(n);std::cin >> a;int lm = std::__lg(m) + 1;int M = 1 << lm;std::vector<int> pre(M);std::vector<int> suf(M);for (int i = 0; i < M; i++) {int v = i + (i & m) + x;pre[i] = v >> lm;suf[i] = v & (M - 1);}Z hi0 = 0;std::vector<int> cnt(M);for (int ai: a) {cnt[ai & (M - 1)]++;hi0 += ai >> lm;}std::vector<Z> binom(k + 1);for (int i = 0; i <= k; i++) {binom[i] = comb.binom(k, i) / power(Z(2), k);}std::vector<Z> s(k + 2);for (int i = k; i >= 0; i--) {s[i] = s[i + 1] + binom[i];}Z los = 0;Z his = hi0;for (int i = 0; i <= k; i++) {Z lo = 0;Z hi = 0;for (int j = 0; j < M; j++) {lo += Z(cnt[j]) * j;hi += Z(cnt[j]) * pre[j];}los += binom[i] * lo;his += s[i + 1] * hi;std::vector<int> ncnt(M);for (int j = 0; j < M; j++) {ncnt[suf[j]] += cnt[j];}cnt = std::move(ncnt);}Z ans = his * M + los;std::cout << ans << "\n";return 0;
}