定义 mex(i, j) 为序列中第 i 项到第 j 项中没有出现的最小自然数。给定序列,求 Σ1≤i,j≤n,i≤j mex(i, j)。
首先我们可以 O(n) 预处理出 mex(1, 1 ~ n),因为显然的是mex是递增的。然后我们考虑怎么从 mex(i, i ~ n) 推出 mex(i + 1, i + 1 ~ n),我们删掉 a[i] 这个数后,哪些区间的mex会改变呢?其实就是到下一个a[i]出现前mex大于a[i]的区间,因为这段区间没有了a[i]这个数,而他们原本的mex却大于a[i],所以可以变小。所以要区间查询、修改、求和,用线段树就可以了。
#include <cstdio> #include <cstring> #include <algorithm> #include <map> #define i64 long long using namespace std; const int N = 2e5 + 10; int n, a[N], mex[N], nxt[N]; i64 ans; map<int, int> mp; struct node { int s, mx, tag; } tr[N * 8]; void init() { int now = 0; for (int i = 1; i <= n; i ++) { mp[a[i]] = 1; while (mp.count(now)) now ++; mex[i] = now; } mp.clear(); for (int i = n; i; i --) { if (mp.count(a[i])) nxt[i] = mp[a[i]]; else nxt[i] = n + 1; mp[a[i]] = i; } } void build(int o, int l, int r) { if (l == r) { tr[o].s = tr[o].mx = mex[l]; tr[o].tag = -1; return; } tr[o].tag = -1; int m = l + r >> 1; build(o << 1, l, m); build(o << 1 | 1, m + 1, r); tr[o].s = tr[o << 1].s + tr[o << 1 | 1].s; tr[o].mx = max(tr[o << 1].mx, tr[o << 1 | 1].mx); } void pushdown(int o, int l, int r) { if (tr[o].tag == -1) return; tr[o << 1].tag = tr[o << 1 | 1].tag = tr[o].tag; tr[o].s = tr[o].tag * (r - l + 1); tr[o].mx = tr[o].tag; tr[o].tag = -1; } int find(int o, int l, int r, int v) { if (l == r) return l; int m = l + r >> 1; pushdown(o << 1, l, m); pushdown(o << 1 | 1, m + 1, r); if (tr[o << 1].mx > v) return find(o << 1, l, m, v); else return find(o << 1 | 1, m + 1, r, v); } void updata(int o, int l, int r) { int m = l + r >> 1, x, y; if (tr[o << 1].tag != -1) x = tr[o << 1].tag; else x = tr[o << 1].mx; if (tr[o << 1 | 1].tag != -1) y = tr[o << 1 | 1].tag; else y = tr[o << 1 | 1].mx; tr[o].mx = max(x, y); if (tr[o << 1].tag != -1) x = tr[o << 1].tag * (m - l + 1); else x = tr[o << 1].s; if (tr[o << 1 | 1].tag != -1) y = tr[o << 1 | 1].tag * (r - m); else y = tr[o << 1 | 1].s; tr[o].s = x + y; } void modify(int o, int l, int r, int x, int y, int v) { if (x <= l && r <= y) { tr[o].tag = v; return; } pushdown(o, l, r); int m = l + r >> 1; if (x <= m) modify(o << 1, l, m, x, y, v); if (y > m) modify(o << 1 | 1, m + 1, r, x, y, v); updata(o, l, r); } int query(int o, int l, int r, int x, int y) { pushdown(o, l, r); if (x <= l && r <= y) return tr[o].s; int m = l + r >> 1, t = 0; if (x <= m) t = query(o << 1, l, m, x, y); if (y > m) t += query(o << 1 | 1, m + 1, r, x, y); return t; } void work() { ans += (i64)query(1, 1, n, 1, n - 1); for (int i = 1; i < n - 1; i ++) { pushdown(1, 1, n); int k = find(1, 1, n, a[i]); if (k < nxt[i]) modify(1, 1, n, k, nxt[i] - 1, a[i]); ans += (i64)query(1, 1, n, i + 1, n - 1); } printf("%lld", ans); } int main() { scanf("%d", &n); for (int i = 1; i <= n; i ++) scanf("%d", &a[i]); init(); mex[++ n] = N; build(1, 1, n); work(); return 0; }
原文:http://www.cnblogs.com/awner/p/5778023.html