关联规则求解算法Apriori的实现
code + 报告 见:https://github.com/JianmingS/Apriori
1 // by Shi Jianming 2 /* 3 数据挖掘:关联规则求解算法Apriori的实现 4 */ 5 6 #define _CRT_SECURE_NO_WARNINGS 7 #define HOME 8 9 #include <iostream> 10 #include <cstdio> 11 #include <vector> 12 #include <string> 13 #include <cmath> 14 #include <map> 15 #include <locale> 16 using namespace std; 17 const double eps = 1e-8; 18 const int MaxColNum = 100; 19 20 int rowNum, columnNum; // 行数,列数 21 double supportMin, confidenceMin; // 最小支持度, 最小置信度 22 int supporNum; // 最小支持数 23 int total; 24 int Case; 25 26 vector<vector<int> > dataBase; // 保存原始数据集 27 vector<string> columnName; // 保存每一列的栏目名 28 29 // 数据集 30 struct itemset 31 { 32 vector<int> item; // 事务(包含0个或多个项) 33 int cnt; // 事务出现次数 34 int id; // 事务唯一标识 35 itemset() 36 { 37 cnt = 0; 38 id = -1; 39 } 40 }; 41 42 vector<itemset> preL; // 频繁(k-1)-项集 43 vector<itemset> C; // 候选(k)-项集 44 vector<itemset> L; // 频繁(k)-项集 45 46 map<int, itemset> forL; // 为构造频繁(k)-项集 47 48 int C1[MaxColNum]; // 记录C1 49 50 51 /****************************************************/ 52 /* 53 Hash树: 54 Hash函数: h(p) = p mod k 55 时间复杂度:O(k) 56 */ 57 58 struct hashTrie 59 { 60 hashTrie *next[MaxColNum]; // Hash树后继节点 61 vector<itemset> C; // 候选(k)-项集 62 hashTrie() 63 { 64 fill(next, next + MaxColNum, nullptr); 65 } 66 }; 67 // 创建Hash树 68 void CrehashTrie(hashTrie *root, vector<int> branch) 69 { 70 hashTrie *p = root; 71 for (auto i = 0; i < branch.size(); ++i) 72 { 73 int pos = branch[i] % branch.size(); 74 if (nullptr == p->next[pos]) 75 { 76 p->next[pos] = new hashTrie; 77 } 78 p = p->next[pos]; 79 } 80 itemset itsetTmp; 81 itsetTmp.item = branch; 82 itsetTmp.id = (total++); 83 p->C.push_back(itsetTmp); 84 } 85 // 查找branch的值,判断是否可以在Hash树中匹配成功,并记录在Hash树中匹配成功的次数,保存频繁集 86 bool FindhashTrie(hashTrie *root, vector<int> branch) 87 { 88 hashTrie *p = root; 89 for (auto i = 0; i < branch.size(); ++i) 90 { 91 int pos = branch[i] % branch.size(); 92 if (nullptr == p->next[pos]) 93 { 94 return false; 95 } 96 p = p->next[pos]; 97 } 98 for (auto &tmp : p->C) 99 { 100 auto i = 0; 101 for (; i != tmp.item.size(); ++i) 102 { 103 if (tmp.item[i] != branch[i]) 104 { 105 break; 106 } 107 } 108 if (i == tmp.item.size()) 109 { 110 111 ++(tmp.cnt); 112 if (tmp.cnt >= (supporNum)) 113 { 114 if (forL.find(tmp.id) != forL.end()) 115 { 116 ++(forL[tmp.id].cnt); 117 }else 118 { 119 forL.insert({tmp.id, tmp}); 120 } 121 } 122 return true; 123 } 124 } 125 return false; 126 } 127 // 销毁Hash树 128 void DelhashTrie(hashTrie *T, int len) 129 { 130 for (int i = 0; i < len; ++i) 131 { 132 if (T->next[i] != nullptr) 133 { 134 DelhashTrie(T->next[i], len); 135 } 136 } 137 if (!T->C.empty()) 138 { 139 T->C.clear(); 140 } 141 delete[] T->next; 142 total = 0; 143 } 144 145 /****************************************************/ 146 147 148 149 /****************************************************/ 150 /* 151 从集合{0,1,2,3..,(N-1)} 中找出所有大小为k的子集, 并按照字典序排序 152 */ 153 vector<vector<int>> combine; 154 int arr[MaxColNum]; 155 int visit[MaxColNum]; 156 int combineN, combineK; 157 // 起始:dfs(0, 0) 158 void dfs(int d, int pos) 159 { 160 if (d == combineK) 161 { 162 vector<int> tmp; 163 for (int i = 0; i < combineK; ++i) 164 { 165 tmp.push_back(arr[i]); 166 } 167 combine.push_back(tmp); 168 return; 169 } 170 for (int i = pos; i < combineN; ++i) 171 { 172 if (!visit[i]) 173 { 174 visit[i] = true; 175 arr[d] = i; 176 dfs(d + 1, i + 1); 177 visit[i] = false; 178 } 179 } 180 } 181 /****************************************************/ 182 183 // 读取原始数据集 184 void Input() 185 { 186 cin >> rowNum >> columnNum; 187 supporNum = ceil(supportMin*(rowNum - 1)); 188 string rowFirst; 189 for (auto i = 0; i < rowNum; ++i) 190 { 191 cin >> rowFirst; 192 vector<int> dataRow; 193 int valueTmp; 194 // 去掉输入数据的第一列 195 for (auto j = 0; j < (columnNum - 1); ++j) 196 { 197 if (i != 0) 198 { 199 cin >> valueTmp; 200 if (valueTmp) { 201 ++C1[j]; 202 dataRow.push_back(j); 203 } 204 } 205 else 206 { 207 string colNameTmp; 208 cin >> colNameTmp; 209 columnName.push_back(colNameTmp); 210 } 211 } 212 if (i != 0) dataBase.push_back(dataRow); 213 } 214 } 215 216 // 获取频繁1-项集 217 void Ini() 218 { 219 for (auto i = 0; i < (columnNum - 1); ++i) 220 { 221 if (C1[i] >= supporNum) 222 { 223 itemset itemsetTmp; 224 itemsetTmp.item.push_back(i); 225 itemsetTmp.cnt = C1[i]; 226 preL.push_back(itemsetTmp); 227 } 228 } 229 } 230 231 232 // 枚举所有事务(即原始数据)包含的k-项集,计算支持度 233 void getItemsK(hashTrie *root, int k) 234 { 235 vector<int> branch; 236 // int bbb = 0; 237 for (auto tmp : dataBase) 238 { 239 // cout << bbb++ << " : " << endl; 240 if (tmp.size() < k) continue; 241 242 combineN = tmp.size(); 243 combineK = k; 244 dfs(0, 0); 245 246 for (int i = 0; i < combine.size(); ++i) 247 { 248 for (int j = 0; j < combine[i].size(); ++j) 249 { 250 branch.push_back(tmp[combine[i][j]]); 251 } 252 /***********************/ 253 /* 254 匹配候选k-项集,计算支持数 255 */ 256 FindhashTrie(root, branch); 257 // if (FindhashTrie(root, branch)) 258 // { 259 // for (auto aaa = 0; aaa < branch.size(); ++aaa) 260 // { 261 // cout << branch[aaa] << " "; 262 // } 263 // cout << endl; 264 // } 265 // /***********************/ 266 branch.clear(); 267 } 268 combine.clear(); 269 // cout << endl; 270 } 271 272 } 273 274 // 判断生成的候选(k)-项集的某个(k-1)-项子集是否为频繁项集 275 bool isInfrequentSubset(itemset c) 276 { 277 hashTrie *root = new hashTrie; 278 int k = c.item.size() - 1; 279 for (auto tmp : preL) 280 { 281 CrehashTrie(root, tmp.item); 282 } 283 vector<int> branch; 284 285 combineN = c.item.size(); 286 combineK = k; 287 dfs(0, 0); 288 289 for (int i = 0; i < combine.size(); ++i) 290 { 291 for (int j = 0; j < combine[i].size(); ++j) 292 { 293 branch.push_back(c.item[combine[i][j]]); 294 } 295 296 /***********************/ 297 /* 298 判断生成的((k-1)-项子集是否为频繁的。 299 */ 300 if (!FindhashTrie(root, branch)) 301 { 302 combine.clear(); 303 DelhashTrie(root, k); 304 return true; 305 } 306 /***********************/ 307 branch.clear(); 308 } 309 combine.clear(); 310 DelhashTrie(root, k); 311 return false; 312 } 313 314 // 产生候选(k)-项集 315 void apriori_gen(int k) 316 { 317 for (auto L1 = 0; L1 < preL.size(); ++L1) 318 { 319 for (auto L2 = L1 + 1; L2 < preL.size(); ++L2) 320 { 321 auto judge = true; 322 for (auto i = 0; i < (k - 1); ++i) 323 { 324 if (preL[L1].item[i] != preL[L2].item[i]) 325 { 326 judge = false; 327 } 328 } 329 if (!judge) continue; 330 itemset itemsetTmp; 331 for (auto i = 0; i < (k - 1); ++i) 332 { 333 itemsetTmp.item.push_back(preL[L1].item[i]); 334 } 335 itemsetTmp.item.push_back(preL[L1].item[k - 1]); 336 itemsetTmp.item.push_back(preL[L2].item[k - 1]); 337 if (isInfrequentSubset(itemsetTmp)) { 338 continue; 339 } 340 C.push_back(itemsetTmp); 341 } 342 } 343 } 344 // Apriori算法实现,并输出关联规则集 345 void Apriori() 346 { 347 for (auto k = 2; !preL.empty(); ++k) 348 { 349 hashTrie *root = new hashTrie; 350 apriori_gen(k - 1); // 求出候选(k)-项集; 351 for (auto i = 0; i < C.size(); ++i) 352 { 353 CrehashTrie(root, C[i].item); 354 } 355 C.clear(); 356 getItemsK(root, k); 357 DelhashTrie(root, k); 358 for (auto tmp : forL) 359 { 360 L.push_back(tmp.second); 361 } 362 forL.clear(); 363 if (L.empty()) 364 { 365 break; 366 } 367 for (auto fromTmp : L) 368 { 369 for (auto toTmp : preL) 370 { 371 auto i = 0; 372 for (; i < toTmp.item.size(); ++i) 373 { 374 if (toTmp.item[i] != fromTmp.item[i]) 375 { 376 break; 377 } 378 } 379 if (i == toTmp.item.size()) 380 { 381 // double aaa = (1.0*fromTmp.cnt) / (1.0*toTmp.cnt); 382 // double bbb = (1.0*fromTmp.cnt) / (1.0*toTmp.cnt) - confidenceMin; 383 if ((1.0*fromTmp.cnt)/(1.0*toTmp.cnt) - confidenceMin >= 0.0) 384 { 385 cout << "Case " << Case++ << " : " << endl; 386 for (auto j = 0; j < toTmp.item.size(); ++j) 387 { 388 cout << columnName[toTmp.item[j]]; 389 if (j != toTmp.item.size() - 1) 390 { 391 cout << ","; 392 } 393 } 394 cout << " => " << columnName[fromTmp.item[toTmp.item.size()]] << endl; 395 } 396 } 397 } 398 } 399 preL.clear(); 400 preL = L; 401 L.clear(); 402 } 403 } 404 405 int main() 406 { 407 #ifdef HOME 408 freopen("in", "r", stdin); 409 freopen("out", "w", stdout); 410 #endif 411 cin >> supportMin >> confidenceMin; 412 Case = 0; 413 total = 0; 414 Input(); 415 Ini(); 416 Apriori(); 417 418 #ifdef HOME 419 cerr << "Time elapsed: " << clock() / CLOCKS_PER_SEC << " ms" << endl; 420 #endif 421 return 0; 422 }
原文:http://www.cnblogs.com/shijianming/p/4992610.html