http://codeforces.com/contest/1042/problem/E
题意:
一个 n*m 的矩阵,每个位置有一个元素,给定一个起点。每次从所有元素值严格小于当前位置元素值的位置中等概率随机选一个走过去,直到不存在这样的位置为止;每走一步的得分为两点欧几里得距离的平方。求总得分的期望,对 998244353 取模。
分析:
逆推期望。
将所有元素取出,按元素大小排序。最小元素的期望为 0,然后从小到大,每个元素向所有严格大于它的元素转移即可。复杂度为 $O((nm)^2)$,见下方考试代码中的 solve1。
考虑每个点,从所有小于它的元素转移。排序后,维护前缀和,可以做到O(1)转移。
$f[i] = \frac{1}{cnt_i}\sum\limits_{val[j]<val[i]}\left(f[j] + (x_j - x_i)^2 + (y_j - y_i)^2\right)$,其中 $cnt_i$ 为元素值严格小于 $val[i]$ 的元素个数;$cnt_i = 0$ 时 $f[i] = 0$。
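$f[i] = \frac{1}{cnt_i}\sum\limits_{val[j]<val[i]}\left(f[j] + x_j^2 - 2x_jx_i + x_i^2 + y_j^2 - 2y_jy_i + y_i^2\right)$,即 $f[i] = \frac{1}{cnt_i}\left(\sum f[j] + \sum x_j^2 + \sum y_j^2 - 2x_i\sum x_j - 2y_i\sum y_j + cnt_i\,(x_i^2 + y_i^2)\right)$。排序后,这些和式都是前缀和,可以直接取出。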
代码:
// Codeforces 1042E "Vasya and Magic Matrix" -- final O(nm log(nm)) solution.
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<iostream>
#include<cctype>
#include<set>
#include<vector>
#include<queue>
#include<map>
#define fi(s) freopen(s,"r",stdin);
#define fo(s) freopen(s,"w",stdout);
using namespace std;
typedef long long LL;

// Reads one decimal integer (optional leading '-') from stdin.
// Non-digit characters before the number are skipped; returns 0 if EOF is
// reached before any digit.
inline int read() {
    int x = 0, f = 1;
    // `int`, not `char`: getchar() may return EOF, and isdigit() requires a
    // value representable as unsigned char (or EOF).
    int ch = getchar();
    for (; ch != EOF && !isdigit(ch); ch = getchar())
        if (ch == '-') f = -1;
    for (; ch != EOF && isdigit(ch); ch = getchar())
        x = x * 10 + ch - '0';
    return x * f;
}

const LL mod = 998244353;
const int N = 1000010;

// One matrix cell: 1-based row x, column y, and its value val.
struct Node {
    int x, y, val;
    bool zh;  // true only for the starting cell
    bool operator < (const Node &A) const {
        return val < A.val;
    }
} A[N];

// After sorting A[1..n] by val:
//   f[i]    expected total score when starting from A[i]
//   cnt[i]  number of cells whose value is strictly less than A[i].val
//           (after sorting these are exactly A[1..cnt[i]])
//   sumx / sumy / sumx2 / sumy2  prefix sums (mod `mod`) of x, y, x^2, y^2
LL f[N], cnt[N], sumx[N], sumy[N], sumx2[N], sumy2[N];

// Returns a^b mod `mod` by binary exponentiation (b >= 0).
LL ksm(LL a, LL b) {
    LL ans = 1;
    a %= mod;
    while (b) {
        if (b & 1) ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}

// x += y (mod `mod`); both operands must already lie in [0, mod).
inline void add(LL &x, LL y) { (x += y) >= mod ? (x -= mod) : x; }
// x -= y (mod `mod`); both operands must already lie in [0, mod).
inline void sub(LL &x, LL y) { (x -= y) < 0 ? (x += mod) : x; }

// Computes the expected score of the starting cell and prints it.
// Precondition: A[1..n] is sorted by val and exactly one element has zh set.
// Recurrence, with c = cnt[i] and j ranging over 1..c:
//   f[i] = (1/c) * sum_j (f[j] + (x_j - x_i)^2 + (y_j - y_i)^2)
//        = (1/c) * (sum f[j] + sum x_j^2 + sum y_j^2
//                   - 2 x_i sum x_j - 2 y_i sum y_j + c (x_i^2 + y_i^2))
// and f[i] = 0 when c == 0 (no smaller cell: the walk stops immediately).
void solve2(int n) {
    A[0].val = -1;  // sentinel below every value (values are assumed >= 0)
    for (int i = 1; i <= n; ++i) {
        // Cells with equal values share the same "strictly smaller" prefix.
        if (A[i].val == A[i - 1].val) cnt[i] = cnt[i - 1];
        else cnt[i] = i - 1;
        sumx[i] = (sumx[i - 1] + A[i].x) % mod;
        sumy[i] = (sumy[i - 1] + A[i].y) % mod;
        sumx2[i] = (sumx2[i - 1] + 1ll * A[i].x * A[i].x % mod) % mod;
        sumy2[i] = (sumy2[i - 1] + 1ll * A[i].y * A[i].y % mod) % mod;
    }

    LL sum = 0;  // sum of f[j] over all cells strictly smaller than the current group
    LL tmp = 0;  // sum of f[j] over the current group of equal values
    for (int i = 1; i <= n; ++i) {
        const LL c = cnt[i];
        f[i] = 0;
        if (c > 0) {
            LL x2 = sumx2[c];
            LL y2 = sumy2[c];
            LL z1 = sumx[c] * 2 % mod * A[i].x % mod;
            LL z2 = sumy[c] * 2 % mod * A[i].y % mod;
            LL h1 = c * A[i].x % mod * A[i].x % mod;
            LL h2 = c * A[i].y % mod * A[i].y % mod;

            add(f[i], x2); add(f[i], y2);
            sub(f[i], z1); sub(f[i], z2);
            add(f[i], h1); add(f[i], h2);
            add(f[i], sum);

            // Divide by c using the Fermat inverse (mod is prime, 0 < c < mod).
            f[i] = f[i] * ksm(c, mod - 2) % mod;
        }
        if (A[i].zh) {
            cout << f[i];
            return;
        }
        add(tmp, f[i]);
        // Only strictly smaller cells may be transferred from: fold the current
        // group into `sum` once the value strictly increases (or the input ends).
        if (i == n || A[i].val < A[i + 1].val) {
            add(sum, tmp);
            tmp = 0;
        }
    }
}

// Input: n m, then the n*m matrix row by row, then the start cell (row, column).
int main() {
    int n = read(), m = read(), tot = 0;
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j) {
            ++tot;
            A[tot].x = i;
            A[tot].y = j;
            A[tot].val = read();
            A[tot].zh = false;
        }

    int x = read(), y = read();
    A[(x - 1) * m + y].zh = true;  // cells were stored row-major, 1-based

    sort(A + 1, A + tot + 1);

    solve2(tot);
    return 0;
}
比赛时代码,记录调试历程。
// Codeforces 1042E -- code written during the contest, kept as a record of
// the debugging process (O(n^2) brute force solve1 plus the prefix-sum solve2).
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<iostream>
#include<cctype>
#include<set>
#include<vector>
#include<queue>
#include<map>
#define fi(s) freopen(s,"r",stdin);
#define fo(s) freopen(s,"w",stdout);
using namespace std;
typedef long long LL;

// Reads one decimal integer (optional leading '-') from stdin.
// Non-digit characters before the number are skipped; returns 0 if EOF is
// reached before any digit.
inline int read() {
    int x = 0, f = 1;
    // `int`, not `char`: getchar() may return EOF, and isdigit() requires a
    // value representable as unsigned char (or EOF).
    int ch = getchar();
    for (; ch != EOF && !isdigit(ch); ch = getchar())
        if (ch == '-') f = -1;
    for (; ch != EOF && isdigit(ch); ch = getchar())
        x = x * 10 + ch - '0';
    return x * f;
}

const LL mod = 998244353;
const int N = 1000010;

// One matrix cell: 1-based row x, column y, and its value val.
struct Node {
    int x, y, val;
    bool zh;  // true only for the starting cell
    bool operator < (const Node &A) const {
        return val < A.val;
    }
} A[N];

// deg[i]: number of strictly smaller cells pushed into f[i] by solve1.
// f[i]:   expected total score when starting from A[i] (after sorting).
LL inv[N], deg[N], f[N];
// double dp[N];   // floating-point version used while debugging

// Returns a^b mod `mod` by binary exponentiation (b >= 0).
LL ksm(LL a, LL b) {
    LL ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ans;
}

// Squared Euclidean distance between cells A[i] and A[j], reduced mod `mod`.
LL Calc(int i, int j) {
    return ((A[i].x - A[j].x) * (A[i].x - A[j].x) % mod
            + (A[i].y - A[j].y) * (A[i].y - A[j].y) % mod) % mod;
}

// O(n^2) brute force: visit cells in increasing value order; each cell first
// divides its accumulated sum by its in-degree, then pushes f[i] + dist^2 to
// every strictly larger cell. Prints f of the starting cell.
void solve1(int n) {
    for (int i = 1; i <= n; ++i) {
        // cout << A[i].val << ": ";
        if (deg[i]) {
            // dp[i] = dp[i] / (double)(deg[i]);
            f[i] = 1ll * ksm(deg[i], mod - 2) * f[i] % mod;
        }
        if (A[i].zh) {
            cout << f[i];
            return;
        }
        for (int j = i + 1; j <= n; ++j)
            if (A[j].val > A[i].val) {
                // dp[j] = dp[j] + dp[i] + Calc(i, j);
                // cout << A[j].val << " " << Calc(i, j) << "--";
                f[j] = (f[j] + f[i] + Calc(i, j)) % mod;
                deg[j]++;
            }
        // puts("");
    }
}

// Prefix-sum arrays for solve2 (values stay < mod, so they fit in int).
int cnt[N], sumx[N], sumy[N], sumx2[N], sumy2[N];

// x += y (mod `mod`); both operands must already lie in [0, mod).
inline void add(LL &x, LL y) { (x += y) >= mod ? (x -= mod) : x; }
// x -= y (mod `mod`); both operands must already lie in [0, mod).
inline void sub(LL &x, LL y) { (x -= y) < 0 ? (x += mod) : x; }

// O(n log mod) version: f[i] = (1/c) * sum_{j<=c} (f[j] + (x_j-x_i)^2 + (y_j-y_i)^2)
// with c = cnt[i], expanded into prefix sums of f, x, y, x^2, y^2.
// Precondition: A[1..n] sorted by val, exactly one element has zh set.
void solve2(int n) {
    A[0].val = -1;  // sentinel below every value (values are assumed >= 0)
    for (int i = 1; i <= n; ++i) {
        // Cells with equal values share the same "strictly smaller" prefix.
        if (A[i].val == A[i - 1].val) cnt[i] = cnt[i - 1];
        else cnt[i] = i - 1;
        sumx[i] = (sumx[i - 1] + A[i].x) % mod;
        sumy[i] = (sumy[i - 1] + A[i].y) % mod;
        sumx2[i] = (sumx2[i - 1] + 1ll * A[i].x * A[i].x % mod) % mod;
        sumy2[i] = (sumy2[i - 1] + 1ll * A[i].y * A[i].y % mod) % mod;
    }

    LL sum = 0;  // sum of f[j] over all cells strictly smaller than the current group
    LL tmp = 0;  // sum of f[j] over the current group of equal values
    for (int i = 1; i <= n; ++i) {
        const LL c = cnt[i];
        f[i] = 0;
        if (c > 0) {
            LL x2 = sumx2[c];
            LL y2 = sumy2[c];
            // 1ll promotion is required: sumx/sumy are int and sumx*2 can exceed INT_MAX.
            LL z1 = 1ll * sumx[c] * 2 % mod * A[i].x % mod;
            LL z2 = 1ll * sumy[c] * 2 % mod * A[i].y % mod;
            LL h1 = c * A[i].x % mod * A[i].x % mod;
            LL h2 = c * A[i].y % mod * A[i].y % mod;

            add(f[i], x2); add(f[i], y2);
            sub(f[i], z1); sub(f[i], z2);
            add(f[i], h1); add(f[i], h2);
            add(f[i], sum);

            // Divide by c using the Fermat inverse (mod is prime, 0 < c < mod).
            f[i] = 1ll * f[i] * ksm(c, mod - 2) % mod;
        }
        if (A[i].zh) {
            cout << f[i];
            return;
        }
        add(tmp, f[i]);
        // Only strictly smaller cells may be transferred from: fold the current
        // group into `sum` once the value strictly increases (or the input ends).
        if (i == n || A[i].val < A[i + 1].val) {
            add(sum, tmp);
            tmp = 0;
        }
    }
}

// Input: n m, then the n*m matrix row by row, then the start cell (row, column).
int main() {
    int n = read(), m = read(), tot = 0;
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= m; ++j) {
            ++tot;
            A[tot].x = i;
            A[tot].y = j;
            A[tot].val = read();
            A[tot].zh = false;
        }

    int x = read(), y = read();
    A[(x - 1) * m + y].zh = true;  // cells were stored row-major, 1-based

    sort(A + 1, A + tot + 1);

    // if (tot <= 1000) {
    //     solve1(tot); return 0;
    // }
    solve2(tot);
    return 0;
}
CF 1042 E. Vasya and Magic Matrix
原文:https://www.cnblogs.com/mjtcn/p/9664893.html