我们只需要考虑每一个数位上的情况即可
将每一个数位上的情况相乘即为答案
考虑一个区间并起来为1即这个区间固定了,即全为1,设其为情况1
一个区间如果并起来为0,即可以区间至少有一个0,设其为情况2
如果一个区间没有被任何束缚,设其为情况3
设\(dp_i\)表示最后一个0在第\(i\)的合法方案数量
如果一个位置上是情况1,那么\(dp_i=0\)
之后我们考虑一个区间,这个区间的\(r<i\),并且这个区间的\(l\)是所有\(r<i\)的区间中最大的一个
这里的区间指的是情况二
倒数第二个0在\(l\)到\(r-1\)中必然合法,即\(dp_i=\sum_{l}^{r-1}dp_l\)
情况三中的转移和情况二的转移相同,这里就不在阐释
最后的答案为\(dp_{n+1}\)
#pragma GCC optimize(2)
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<vector>
#include<cstring>
using namespace std;
const int mod=998244353;
struct node
{
int l;
int r;
friend bool operator < (const node &a,const node &b)
{
if(a.r==b.r)
return a.l<b.l;
return a.r<b.r;
}
};
vector<node> g[40];//这几位全部都为1
vector<node> v[40];//这几位不能全部都为1
int n,k,m;
int pos,maxx;
long long dp[500005],sum[5000005],ret;
long long ans=1;
bool cmp(const node &a,const node &b)
{
if(a.l==b.l)
return a.r<b.r;
return a.l<b.l;
}
void divide(int l,int r,int val)
{
int _ind=1;
while(val)
{
if(val&1){
g[_ind].push_back((node){l,r});
//if(_ind==1)printf("pos %d, [%d, %d] all must be 1\n",_ind,l,r);
}
else{
v[_ind].push_back((node){l,r});
//printf("pos %d, [%d, %d] mustn‘t all be 1\n",_ind,l,r);
}
val>>=1;
_ind++;
}
for(int i=_ind;i<=k;i++){
v[i].push_back((node){l,r});
//printf("pos %d, [%d, %d] mustn‘t all be 1\n",_ind,l,r);
}
}
void init(int _ind)
{
for(int i=0;i<=n+1;i++)
{
dp[i]=-1;
sum[i]=0;
}
dp[0]=sum[0]=1;
maxx=0;
pos=0;
for(int i=0,sz=g[_ind].size();i<sz;i++)
{
if(pos<g[_ind][i].l)
pos=g[_ind][i].l;
while(pos<=g[_ind][i].r)
{
dp[pos]=0;
pos++;
}
}
pos=0;
}
long long solve(int _ind)
{
//printf("Now _ind == %d\n",_ind);
init(_ind);
dp[0]=1;
sum[0]=1;
int sz=v[_ind].size();
long long ret;
for(int i=1;i<=n+1;i++)
{
if(dp[i]==-1)
{
dp[i]=0;
while(pos<sz && v[_ind][pos].r<i)
{
maxx=max(v[_ind][pos].l,maxx);
pos++;
}
if(maxx==0)
dp[i]=sum[i-1];
else
dp[i]=(sum[i-1]-sum[maxx-1]+mod)%mod;
ret=dp[i];
}
sum[i]=(sum[i-1]+dp[i])%mod;
//printf("dp[%d] == %lld, sum[%d] == %lld\n",i,dp[i],i,sum[i]);
}
//puts("________________________________________________");
return ret;
}
signed main()
{
scanf("%d %d %d",&n,&k,&m);
for(int i=1;i<=m;i++)
{
int l,r,x;
scanf("%d %d %d",&l,&r,&x);
divide(l,r,x);
}
for(int i=1;i<=k;i++)
{
//v[i].push_back((node){0,0});
sort(v[i].begin(),v[i].end());
sort(g[i].begin(),g[i].end(),cmp);
}
for(int i=1;i<=k;i++)
ans=ans*solve(i)%mod;
cout<<ans;
return 0;
}
原文:https://www.cnblogs.com/loney-s/p/13341271.html