这道题 Robert_JYH 说是一道简单题,在这里向他致以最崇高的膜拜。
要做出来这道题,我们首先得明白亦或的本质,即“相同与不同”。相同就是0,不同就是1。
现在来考虑这样一个性质:设 \(\{a_n\}\) 是一个从小到大排好序的序列。那么 \(\min\limits_{i,j}\{a_i\oplus a_j\}=\min\limits_{1\leq i<n}\{a_i\oplus a_{i+1}\}\),即序列中两个元素亦或起来的最小值一定产生于相邻的两个元素之间。
要说明这个性质成立,我们只需说明 \(\forall 0\leq x<y<z\),一定有 \(x\oplus z\geq x\oplus y\) 且 \(x\oplus z \geq y\oplus z\)。定义 \(p(x,y)\) 为从高位向低位看,\(x\) 和 \(y\) 第一个不同的位。我们考虑这样一个过程,将 \(x\) 一点一点的向上增加,那么一定是 \(x\) 的低位先变动、\(x\) 的高位后变动。在这个过程中,\(x\) 先达到 \(y\),再达到 \(z\)。所以一定有 \(p(x,z)\geq p(x,y)\)。当然,这只是一个很不严谨的、非常感性的理解。具体证明看下面:
最后,设 DP 状态为 \(f[i]\) 表示强制以 \(i\) 结尾的合法子序列的数目。则有 DP 转移方程
时间复杂度:\(O(n^2)\)。
如何优化呢?看到亦或,当然要想到trie树。每次只需在trie树找出与 \(x\) 亦或起来大于等于 \(x\) 的所有数对应的DP值之和即可。
时间复杂度:\(O(n\log {maxa})\)。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define FILEIN(s) freopen(s".in", "r", stdin);
#define FILEOUT(s) freopen(s".out", "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)
inline long long read(void) {
long long x = 0, f = 1; char ch = getchar();
while (ch < ‘0‘ || ch > ‘9‘) { if (ch == ‘-‘) f = -1; ch = getchar(); }
while (ch >= ‘0‘ && ch <= ‘9‘) { x = x * 10 + ch - ‘0‘; ch = getchar(); }
return f * x;
}
const int maxn = 3e5 + 5, mod = 998244353;
int n, sz = 1;
long long X, a[maxn], f[maxn];
int trans[maxn * 80][2];
long long sum[maxn * 80];
inline int call(long long x, int i) { return (x >> i) & 1; }
inline void insert(long long x, long long val) {
int p = 1;
for (int i = 59; i >= 0; -- i) {
int now = call(x, i);
if (!trans[p][now]) trans[p][now] = ++ sz;
p = trans[p][now];
(sum[p] += val) %= mod;
}
}
inline long long query(long long x) {
int p = 1;
long long res = 0;
for (int i = 59; i >= 0; -- i) {
int now1 = call(x, i), now2 = call(X, i);
if (now2 == 0) {
if (trans[p][now1 ^ 1]) res += sum[trans[p][now1 ^ 1]];
}
if (!trans[p][now1 ^ now2]) trans[p][now1 ^ now2] = ++ sz;
p = trans[p][now1 ^ now2];
}
(res += sum[p]) %= mod;
return res;
}
int main() {
FILEIN("xor"); FILEOUT("xor");
n = read(); X = read();
for (int i = 1; i <= n; ++ i)
a[i] = read();
sort(a + 1, a + n + 1);
f[1] = 1;
insert(a[1], f[1]);
for (int i = 2; i <= n; ++ i) {
int tmp = query(a[i]);
f[i] = 1 + tmp;
insert(a[i], f[i]);
}
long long res = 0;
for (int i = 1; i <= n; ++ i) (res += f[i]) %= mod;
printf("%lld\n", res);
return 0;
}
原文:https://www.cnblogs.com/little-aztl/p/14829294.html