题目描述
牛牛准备送给牛妹一条项链,但是项链太长了,于是他准备将项链拆掉。
牛妹有一些不喜欢的数,所以拆分出来的长度中一定要有至少一个不是牛妹不喜欢的(不然就送不出去了)。
注意,由于项链上的每一颗珠子都不同,所以1 5 1和1 1 5是两种不同的拆分方案,不拆也是一种方案。
输入描述:
第一行:两个数N和P。
第二行:一个数K,表示牛妹有K个不喜欢的数。
第三行:K个牛妹不喜欢的数T。
输出描述:
一个整数表示拆分方案对P取模后的值,不保证P是质数。
示例1
输入
复制
5 10007
2
2 3
输出
复制
14
备注:
对于10%的数据:N <= 10 ;
对于30%的数据:N <= 1e3;
对于100%的数据:N <= 1e18,P <= 1e9 ,K , T <= 100 。
思路:
首先对于长度n的项链,拆开的方案数为 2 n − 1 2^{n-1} 2n−1,因为相当于有 n − 1 n-1 n−1个隔板,每个隔板可以存在也可以不存在。
正难则反,合法方案数就是总方案数减去完全由牛妹不喜欢数构成的方案数。
定义 d p [ i ] dp[i] dp[i]为所有不喜欢数构成总和为 i i i的方案数,则转移方程为 d p [ i ] = ∑ d p [ i − a [ j ] ] dp[i]=∑dp[i-a[j]] dp[i]=∑dp[i−a[j]]。复杂度O(NK)
注意到 n n n的范围1e18,直接转移肯定不行。
因为 a [ j ] a[j] a[j]的范围很小,为100,可以想到用矩阵快速幂优化DP转移过程。
两个矩阵,一个是用 d p dp dp数组构成的矩阵,一个是用来辅助转移的矩阵。
则状态矩阵 a n s ans ans为
f [ i ] f [ i − 1 ] f [ i − 2 ] . . . f [ 0 ] (1) \begin{matrix} f[i] \\ f[i-1] \\ f[i-2] \\ ... \\ f[0] \end{matrix} \tag{1} f[i]f[i−1]f[i−2]...f[0](1)
转移矩阵 m a t r i x matrix matrix为
x y z 1 0 0 0 1 0 . . . 0 . . . 1 (1) \begin{matrix} x & y & z \\ 1 & 0 & 0 \\ 0 & 1 & 0\\ ... \\ 0 & ... & 1 \end{matrix} \tag{1} x10...0y01...z001(1)
其中第一行为 x , y , z . . . , 这 些 数 为 0 或 者 1 x,y,z...,这些数为0或者1 x,y,z...,这些数为0或者1,取决与不喜欢数构成的数组,如果存在 a [ j ] a[j] a[j],则 m a t r i x [ 0 ] [ a [ j ] − 1 ] = 1 matrix[0][a[j]-1]=1 matrix[0][a[j]−1]=1。
对于状态矩阵,与转移矩阵乘一次,结果矩阵为
f [ i + 1 ] f [ i ] f [ i − 1 ] . . . f [ 1 ] (1) \begin{matrix} f[i+1] \\ f[i] \\ f[i-1] \\ ... \\ f[1] \end{matrix} \tag{1} f[i+1]f[i]f[i−1]...f[1](1)
也就是第一行为新的状态,其他行直接继承之前的状态。
所以最终结果矩阵就是 m a t r i x n ∗ a n s matrix^{n}*ans matrixn∗ans,
利用快速幂,就可以加速这个转移过程了。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <cmath>using namespace std;typedef long long ll;
const int maxn = 105;
int a[maxn],up;
ll n,mod,k;struct Matrix {ll A[105][105];Matrix() {memset(A,0,sizeof(A));}Matrix operator * (const Matrix &B) const {Matrix C;for(int i = 0;i < up;i++) {for(int j = 0;j < up;j++) {for(int k = 0;k < up;k++) {C.A[i][j] = (C.A[i][j] + B.A[i][k] * A[k][j] % mod) % mod;}}}return C;}
}ans,matrix;ll qpow(ll x,ll n) {ll res = 1;while(n) {if(n & 1) res = res * x % mod;x = x * x % mod;n >>= 1;}return res;
}int main() {scanf("%lld%lld%lld",&n,&mod,&k);for(int i = 1;i <= k;i++) {scanf("%d",&a[i]);matrix.A[0][a[i] - 1] = 1;up = max(up,a[i]);}for(int i = 0;i < up - 1;i++) {matrix.A[i + 1][i] = 1;}ll num = n;ans.A[0][0] = 1;while(num) {if(num & 1) ans = ans * matrix;matrix = matrix * matrix;num >>= 1;}printf("%lld\n",((qpow(2,n - 1) - ans.A[0][0] + mod) % mod) % mod);return 0;
}