Problem
国王奇遇记加强版之再加强版
TimeLimit:15Sec
MemoryLimit:128MB
Description
共一行,包括两个正整数N和M。
Output
共一行,为所求表达式的值对109+7取模的值。
Sample Output
Hint
1≤N≤109,1≤m≤5×105
Source
Bywjy1998
标签:多项式插值
Solution
好题,看了Miskcoo′sSpace中的题解才懂。以下题解全部部分摘自特殊多项式在整点上的线性插值方法和BZOJ-3157. 国王奇遇记。
1. 多项式整点插值
观察二项式系数(mx)=m!x(x−1)(x−2)⋯(x−m+1),其为一个x的m次多项式。对于(0x),(1x),⋯,(mx),由于其次数互不相同,故其线性无关。可以发现这m+1个多项式是m次多项式线性空间Pm[x]的一组基。
于是对于∀F(x)∈Pm[x],其一定可以表示为这m+1个多项式的线性组合,即
F(x)=i=0∑m(ix)ai(1)
由于i>x时,(ix)=0,于是当x>m时,有
F(x)=i=0∑x(ix)ai(2)
根据二项式定理,可知∑i=0m(−1)i(im)=[m=0],应用其对(2)进行二项式反演,得
ai=j=0∑i(−1)i−j(ji)F(j)(3)
将(3)带入(1),得
F(x)=i=0∑m(ix)j=0∑i(−1)i−j(ji)F(j)=j=0∑mF(j)i=j∑m(ix)(−1)i−j(ji)=j=0∑mF(j)i=0∑m−j(−1)i(i+jx)(ji+j)
化简一下二项式系数
(i+jx)(ji+j)=(i+j)!(x−i−j)!x!×i!j!(i+j)!=i!j!(x−i−j)!x!=j!(x−j)!x!×i!(x−i−j)!(x−j)!=(jx)(ix−j)
于是
F(x)=j=0∑mF(j)i=0∑m−j(−1)i(i+jx)(ji+j)=j=0∑mF(j)i=0∑m−j(−1)i(jx)(ix−j)=j=0∑m(jx)F(j)i=0∑m−j(−1)i(ix−j)
即
F(x)=j=0∑m(jx)F(j)i=0∑m−j(−1)i(ix−j)(4)
将后面的部分进一步化简,即化简形如∑i=0n(−1)i(ix)的式子。
根据上指标反转(qp)=(−1)q(qq−p−1),可将其化简
i=0∑n(−1)i(ix)=i=0∑n(ii−x−1)=(0−x−1+0)+(1−x−1+1)+⋯+(n−x−1+n)=(0−x−1+1)+(1−x−1+1)+⋯+(n−x−1+n)=(1−x−1+2)+(2−x−1+2)+⋯+(n−x−1+n)=(n−x−1+n+1)=(nn−x)
带入(4)得
F(x)=j=0∑m(jx)F(j)i=0∑m−j(−1)i(ix−j)=j=0∑m(jx)(m−jm−x)F(j)
将后面的(m−xx−j−1)再次上指标反转得
F(x)=j=0∑m(−1)m−j(jx)(m−jx−j−1)F(j)
即
F(x)=i=0∑m(−1)m−i(ix)(m−ix−i−1)F(i)
故对于m次多项式F(x),当x>m时,如果能得到F(0)∼F(m)的值,可以求出F(x)的值。
具体地,两个二项式系数的乘积为
(ix)(m−ix−i−1)=i!(x−i)!x!×(m−i)!(x−m−1)!(x−i−1)!=i!(x−i)(m−i)!(x−m−1)!x!=x−ix(x−1)(x−2)⋯(x−m)×(m−i)!i!=x(x−1)⋯(x−i+1)×(x−i−1)(x−i−2)⋯(x−m)×(m−i)!i!=(m−i)!i!p=x−m∏x−i−1pq=x−i+1∏xq
对于分式部分,可以预处理阶乘及其逆元以O(1)计算,对于两个乘式,可以预处理(x−m)∼x的前缀积和后缀积以O(1)计算。总时间复杂度为O(m)。
2. 幂和
记Sm(n)=∑i=1nimmi,求出m较小时的通项,瞎猜发现当m>1时,Sm(n)一定有如下形式:
Sm(n)=mnFm(n)−Fm(0)
其中Fm(n)是一个m次多项式。根据前面多项式整点线性插值,只需求出Fm(0),Fm(1),⋯,Fm(m)即可O(m)求得Fm(n)。
注意到在Sm(i)−Sm(i−1)中Fm(0)被消去,于是可以得到Fm的递推式:
Sm(i)−Sm(i−1)=miFm(i)−mi−1Fm(i−1)=immiFm(i)=mFm(i−1)+im
于是可以将Fm(i)表示为kiFm(0)+bi的形式,然而还无法求出Fm(0)。
根据多项式插值时推导的公式,当x>m时,有
Fm(x)=i=0∑m(−1)m−i(ix)(m−ix−i−1)Fm(i)
将m+1作为x带入得
Fm(m+1)=i=0∑m(−1)m−i(im+1)(m−im−i)Fm(i)=i=0∑m(−1)m−i(im+1)Fm(i)
将F(m+1)移到右边得
i=0∑m(−1)m−i(im+1)Fm(i)=0
将∀i∈[1,n+1],Fm(i)=kiFm(0)+bi带入即可解出Fm(0),复杂度O(m)。然后通过前面O(m)插值可求出Fm(n),带入Sm(n)=mnFm(n)−Fm(0)即可求得Sm(n)。总时间复杂度O(m)。
注意由于m=1时不满足性质,需要单独计算;此外还需暴力计算n≤m的情况。
简单版之再简单版见BZOJ3157,简单版见BZOJ3516。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
| #include <bits/stdc++.h> #define MAX_M 500000 #define P 1000000007 using namespace std; typedef long long lnt; template <class T> inline void read(T &x) { x = 0; int c = getchar(), f = 1; for (; !isdigit(c); c = getchar()) if (c == 45) f = -1; for (; isdigit(c); c = getchar()) (x *= 10) += f*(c-'0'); } int n, m; lnt F[MAX_M+5]; lnt pw[MAX_M+5], fac[MAX_M+5], inv[MAX_M+5]; bool NotPri[MAX_M+5]; vector <int> pri; lnt Pow(lnt x, lnt k) { lnt ret = 1; for (; k; k >>= 1, x = x*x%P) if (k&1) ret = ret*x%P; return ret; } void init(int n) { fac[0] = inv[0] = inv[1] = pw[1] = 1; for (int i = 2; i <= n; i++) { if (!NotPri[i]) pri.push_back(i), pw[i] = Pow(i, m); for (int j = 0; j < (int)pri.size(); j++) { if (i*pri[j] > n) break; NotPri[i*pri[j]] = true; pw[i*pri[j]] = pw[i]*pw[pri[j]]%P; if (!(i%pri[j])) break; } } for (int i = 1; i <= n; i++) fac[i] = fac[i-1]*i%P; for (int i = 2; i <= n; i++) inv[i] = (P-P/i*inv[P%i]%P)%P; for (int i = 2; i <= n; i++) inv[i] = inv[i-1]*inv[i]%P; } lnt C(int n, int m) {return fac[n]*inv[m]%P*inv[n-m]%P;} void get_Pnt_Val() { lnt k[MAX_M+5], b[MAX_M+5], invm; k[0] = 1, b[0] = 0, invm = Pow(m, P-2); for (int i = 1; i <= m+1; i++) k[i] = k[i-1]*invm%P, b[i] = (b[i-1]*invm%P+pw[i])%P; int f = (m&1) ? -1 : 1; k[0] *= f, f = -f; for (int i = 1; i <= m+1; i++, f = -f) k[0] = (k[0]+f*C(m+1, i)*k[i]%P)%P, b[0] = (b[0]+f*C(m+1, i)*b[i]%P)%P; F[0] = -b[0]*Pow(k[0], P-2)%P; for (int i = 1; i <= m; i++) F[i] = (k[i]*F[0]%P+b[i])%P; } #define getL(i) (i < m ? pre[m-i-1] : 1) #define getR(i) (i > 0 ? suc[m-i+1] : 1) lnt Poly_Inter() { lnt ret = 0; lnt pre[MAX_M+5], suc[MAX_M+5]; pre[0] = n-m, suc[m] = n; for (int i = 1; i <= m; i++) pre[i] = pre[i-1]*(n-m+i)%P; for (int i = m-1; i >= 1; i--) suc[i] = suc[i+1]*(n-m+i)%P; for (int i = 0, f = (m&1) ? -1 : 1; i <= m; i++, f = -f) ret = (ret+f*getL(i)*getR(i)%P*inv[i]%P*inv[m-i]%P*F[i]%P)%P; return ((Pow(m, n)*ret%P-F[0])%P+P)%P; } int main() { read(n), read(m), init(m+1); if (m == 1) printf("%lld\n", 1LL*n*(n+1)/2%P); else if (n <= m) { lnt pwm = m, sum = 0; for (int i = 1; i <= n; i++, pwm = pwm*m%P) sum = (sum+pw[i]*pwm%P)%P; printf("%lld\n", sum); } else get_Pnt_Val(), printf("%lld\n", Poly_Inter()); return 0; }
|