给出 \(n\) 个物品,第 \(i\) 个物品体积为 \(a_i\) 。
对于每个体积 \(V\) ,求选出 \(3\) 个物品,体积之和为 \(V\) 的方案总数。
选择顺序不同算同一种方案。
\(n\) 保证不会读入到 \(TLE\) , \(a_i\le 4 \times 10^4\) 。
\(FFT\) ,生成函数。
设生成函数 \(A(x)\) 为只选择一个物品的生成函数。其中 \([x^m]A(x)\) 的系数代表了体积 \(m\) 有多少种选法。
同理设 \(B(x)\) 为选择两个相同物品的生成函数,设 \(C(x)\) 为选择三个相同物品的生成函数。
则对于最后的答案而言:
若选择的 \(3\) 个物品互不相同,则方案数为:
\[ \frac{A^3(x)-3B(x)A(x)+2C(x)}{6} \]
因为根据容斥,\(A^3(x)\) 等于所有选择三个物品的方案数,\(B(x)A(x)\) 则是所有形如 \((a, a, b)\) 的方案数,由于这种方案在 \(A^3(x)\) 会出现三次,所以要乘 \(3\) ,然后对于所有 \((a,a,a)\) ,也即生成函数 \(C(x)\) 在 \(B(x)A(x)\) 中出现了 \(3\) 次,但实际上在 \(A^3(x)\) 只会被计算一次,所以还要加回 \(2\) 个来。
若选择 \(2\) 个物品,那么方案为:
\[ \frac{A^2(x)-B(x)}{2} \]
这个很好理解。
选择一个物品的方案自然就是 \(A(x)\) 了。
\(FFT\) 即可。
#include <cmath>
#include <complex>
#include <cstdio>
#include <iostream>
using namespace std;
#define LL long long
#define cp complex<double>
#define inline __inline__ __attribute__((always_inline))
inline LL read() {
LL x = 0, w = 1;
char ch = getchar();
while (!isdigit(ch)) {
if (ch == '-') w = -1;
ch = getchar();
}
while (isdigit(ch)) {
x = (x << 3) + (x << 1) + ch - '0';
ch = getchar();
}
return x * w;
}
const int Max_n = 4e5 + 5, Ml = 1.2e5;
const double pi = acos(-1);
cp ans[Max_n], A[Max_n], B[Max_n], C[Max_n];
namespace Input {
void main() {
int n = read();
for (int i = 1, x; i <= n; i++)
x = read(), A[x] += 1, B[x * 2] += 1, C[x * 3] += 1;
}
} // namespace Input
namespace Solve {
int bit, len, rev[Max_n];
void init() {
int bit = log2(Ml + 1) + 1;
len = 1 << bit;
for (int i = 0; i < len; i++)
rev[i] = rev[i >> 1] >> 1 | ((i & 1) << (bit - 1));
}
void dft(cp *f, int t) {
for (int i = 0; i < len; i++)
if (i < rev[i]) swap(f[i], f[rev[i]]);
for (int l = 1; l < len; l <<= 1) {
cp Wn(cos(t * pi / (double)l), sin(t * pi / (double)l));
for (int i = 0; i < len; i += (l << 1)) {
cp Wnk(1, 0);
for (int k = i; k < i + l; k++, Wnk *= Wn) {
cp x = f[k], y = f[k + l] * Wnk;
f[k] = x + y, f[k + l] = x - y;
}
}
}
}
void main() {
init();
dft(A, 1), dft(B, 1), dft(C, 1);
for (int i = 0; i < len; i++) {
ans[i] = (A[i] * A[i] * A[i] - A[i] * B[i] * 3.0 + 2.0 * C[i]) / 6.0;
ans[i] += (A[i] * A[i] - B[i]) / 2.0 + A[i];
}
dft(ans, -1);
for (int i = 0; i <= Ml; i++) ans[i] /= (double)len;
for (int i = 0; i <= Ml; i++) {
LL Ans = (LL)(ans[i].real() + 0.5);
if (Ans) printf("%d %lld\n", i, Ans);
}
}
} // namespace Solve
int main() {
Input::main();
Solve::main();
}
原文:https://www.cnblogs.com/luoshuitianyi/p/12056962.html