有點難😅
考慮加入每一列,發現我們只關心當前還未確定的行的數目
有點難算😅
設 d p i , j dp_{i,j} dpi,j?表示有 i i i列,其中 j j j行未確定的方案數。欽定每一列至少有一個黑色格子。
d p i , j = j ( j + 1 ) 2 d p i ? 1 , j + ∑ k ≥ 1 ∑ k ≤ l ≤ j ( j ? l + 1 ) ( l k ) d p i ? 1 , j ? k dp_{i,j}=\frac{j(j+1)}{2}dp_{i-1,j}+\sum_{k\ge 1}\sum_{k\le l\le j}(j-l+1)\binom{l}{k}dp_{i-1,j-k} dpi,j?=2j(j+1)?dpi?1,j?+∑k≥1?∑k≤l≤j?(j?l+1)(kl?)dpi?1,j?k?
暴力 D P DP DP的復雜度為 O ( N 3 M ) O(N^3M) O(N3M),考慮優化
發現可以看成從 j + 2 j+2 j+2個數中選 k + 2 k+2 k+2個數的方案數,上面的式子其實是在枚舉倒數第二個被選中的數的位置。
d p i , j = j ( j + 1 ) 2 d p i ? 1 , j + ∑ k < j ( j + 2 k ) d p i ? 1 , k dp_{i,j}=\frac{j(j+1)}{2}dp_{i-1,j}+\sum_{k<j}\binom{j+2}{k}dp_{i-1,k} dpi,j?=2j(j+1)?dpi?1,j?+∑k<j?(kj+2?)dpi?1,k?
這樣優化到了 O ( N 2 M ) O(N^2M) O(N2M)
將組合數拆成階乘的形式,可以用多項式優化。
復雜度 O ( N M log ? N ) O(NM\log N) O(NMlogN)。
#include<bits/stdc++.h>
#define fi first
#define se second
#define ll long long
#define pb push_back
#define db double
#define inf 0x3f3f3f3f
using namespace std;
const int mod=998244353;
const int N=8005;
const int M=205;
int n,m;
ll dp[N],res;
ll fac[N],inv[N];
ll fpow(ll x,ll y=mod-2){ll z(1);for(;y;y>>=1){if(y&1)z=z*x%mod;x=x*x%mod;}return z;
}
void init(int n){fac[0]=1;for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%mod;inv[n]=fpow(fac[n]);for(int i=n;i>=1;i--)inv[i-1]=inv[i]*i%mod;
}
ll binom(int x,int y){if(x<0||y<0||x<y)return 0;return fac[x]*inv[y]%mod*inv[x-y]%mod;
}
void add(ll &x,ll y){x=(x+y)%mod;
}
int len;
ll invlen;
ll omega[N<<2][2];
void ntt(vector<ll>&a,int len,int f=0){int k=0;while((1<<k)<len)k++;for(int i=0;i<len;i++){int t=0;for(int j=0;j<k;j++){if(i>>j&1)t+=(1<<k-j-1);}if(i<t)swap(a[i],a[t]);}for(int l=2;l<=len;l<<=1){int k=l/2;ll x=omega[l][f];for(int i=0;i!=len;i+=l){ll y=1;for(int j=0;j<k;j++){ll tmp=a[i+j+k]*y%mod;a[i+j+k]=(a[i+j]-tmp)%mod;a[i+j]=(a[i+j]+tmp)%mod;y=y*x%mod;}}}if(f)for(int i=0;i<len;i++)a[i]=a[i]*invlen%mod;
}
struct poly{vector<ll>a;friend poly operator *(poly a,poly b){ntt(a.a,len),ntt(b.a,len);for(int i=0;i<len;i++)a.a[i]=a.a[i]*b.a[i]%mod;ntt(a.a,len,1);return a;}
}f,g;
signed main(){ios::sync_with_stdio(false);cin.tie(0),cout.tie(0);cin>>n>>m,init(max(n,m)+2);dp[0]=1;len=1;while(len<=2*(n+2))len<<=1;invlen=fpow(len);omega[len][0]=fpow(3,(mod-1)/len);omega[len][1]=fpow(3,mod-1-(mod-1)/len);for(int i=len/2;i;i>>=1){omega[i][0]=omega[i<<1][0]*omega[i<<1][0]%mod;omega[i][1]=omega[i<<1][1]*omega[i<<1][1]%mod;}g.a.resize(len);for(int i=3;i<=n+2;i++)g.a[i]=inv[i];add(res,1);for(int i=1;i<=m;i++){f.a.clear(),f.a.resize(len);for(int j=0;j<=n;j++)f.a[j]=dp[j]*inv[j]%mod;f=f*g;for(int j=0;j<=n;j++){dp[j]=(j*(j+1)/2*dp[j]%mod+f.a[j+2]*fac[j+2])%mod;}for(int j=0;j<=n;j++)add(res,dp[j]*binom(n,j)%mod*binom(m,i));}cout<<(res+mod)%mod;
}