我的数据挖掘算法代码:https://github.com/linyiqun/DataMiningAlgorithm
Apriori算法是一个经典的数据挖掘算法,Apriori的单词的意思是"先验的",说明这个算法是具有先验性质的,就是说要通过上一次的结果推导出下一次的结果,这个如何体现将会在下面的分析中会慢慢的体现出来。Apriori算法的用处是挖掘频繁项集的,频繁项集粗俗的理解就是找出经常出现的组合,然后根据这些组合最终推出我们的关联规则。
Apriori算法是一种逐层搜索的迭代式算法,其中k项集用于挖掘(k+1)项集,这是依靠他的先验性质的:
频繁项集的所有非空子集一定是也是频繁的。
通过这个性质可以对候选集进行剪枝。用k项集如何生成(k+1)项集呢,这个是算法里面最难也是最核心的部分。
通过2个步骤
1、连接步,将频繁项自己与自己进行连接运算。
2、剪枝步,去除候选集项中的不符合要求的候选项,不符合要求指的是这个候选项的子集并非都是频繁项,要遵守上文提到的先验性质。
3、通过1,2步骤还不够,在后面还要根据支持度计数筛选掉不满足最小支持度数的候选集。
首先是测试数据:
交易ID |
商品ID列表 |
T100 |
I1,I2,I5 |
T200 |
I2,I4 |
T300 |
I2,I3 |
T400 |
I1,I2,I4 |
T500 |
I1,I3 |
T600 |
I2,I3 |
T700 |
I1,I3 |
T800 |
I1,I2,I3,I5 |
T900 |
I1,I2,I3 |
最后我们可以看到频繁3项集的结果为{1, 2, 3}和{1, 2, 5},然后我们去后者{1, 2, 5}作为频繁项集来生产他的关联规则,但是在这之前得先知道一些概念,怎么样才能够成为一条关联规则,关有频繁项集还是不够的。
按照字面上的意思就是限制置信度值的一个限制条件嘛,这个很好理解。
强规则就是指的是置信度满足最小置信度(就是>=最小置信度)的推断就是一个强规则,也就是文中所说的关联规则了。这个在下面的程序中会有所体现。
我自己写的算法实现可能会让你有点晦涩难懂,不过重在理解算法的整个思路即可,尤其是连接步和剪枝步是最难点所在,可能还存在bug。
输入数据:
T1 1 2 5 T2 2 4 T3 2 3 T4 1 2 4 T5 1 3 T6 2 3 T7 1 3 T8 1 2 3 5 T9 1 2 3频繁项类:
/** * 频繁项集 * * @author lyq * */ public class FrequentItem implements Comparable<FrequentItem>{ // 频繁项集的集合ID private String[] idArray; // 频繁项集的支持度计数 private int count; //频繁项集的长度,1项集或是2项集,亦或是3项集 private int length; public FrequentItem(String[] idArray, int count){ this.idArray = idArray; this.count = count; length = idArray.length; } public String[] getIdArray() { return idArray; } public void setIdArray(String[] idArray) { this.idArray = idArray; } public int getCount() { return count; } public void setCount(int count) { this.count = count; } public int getLength() { return length; } public void setLength(int length) { this.length = length; } @Override public int compareTo(FrequentItem o) { // TODO Auto-generated method stub return this.getIdArray()[0].compareTo(o.getIdArray()[0]); } }主程序类:
package DataMining_Apriori; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.text.MessageFormat; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.Map; /** * apriori算法工具类 * * @author lyq * */ public class AprioriTool { // 最小支持度计数 private int minSupportCount; // 测试数据文件地址 private String filePath; // 每个事务中的商品ID private ArrayList<String[]> totalGoodsIDs; // 过程中计算出来的所有频繁项集列表 private ArrayList<FrequentItem> resultItem; // 过程中计算出来频繁项集的ID集合 private ArrayList<String[]> resultItemID; public AprioriTool(String filePath, int minSupportCount) { this.filePath = filePath; this.minSupportCount = minSupportCount; readDataFile(); } /** * 从文件中读取数据 */ private void readDataFile() { File file = new File(filePath); ArrayList<String[]> dataArray = new ArrayList<String[]>(); try { BufferedReader in = new BufferedReader(new FileReader(file)); String str; String[] tempArray; while ((str = in.readLine()) != null) { tempArray = str.split(" "); dataArray.add(tempArray); } in.close(); } catch (IOException e) { e.getStackTrace(); } String[] temp = null; totalGoodsIDs = new ArrayList<>(); for (String[] array : dataArray) { temp = new String[array.length - 1]; System.arraycopy(array, 1, temp, 0, array.length - 1); // 将事务ID加入列表吧中 totalGoodsIDs.add(temp); } } /** * 判读字符数组array2是否包含于数组array1中 * * @param array1 * @param array2 * @return */ public boolean iSStrContain(String[] array1, String[] array2) { if (array1 == null || array2 == null) { return false; } boolean iSContain = false; for (String s : array2) { // 新的字母比较时,重新初始化变量 iSContain = false; // 判读array2中每个字符,只要包括在array1中 ,就算包含 for (String s2 : array1) { if (s.equals(s2)) { iSContain = true; break; } } // 如果已经判断出不包含了,则直接中断循环 if (!iSContain) { break; } } return iSContain; } /** * 项集进行连接运算 */ private void computeLink() { // 连接计算的终止数,k项集必须算到k-1子项集为止 int endNum = 0; // 当前已经进行连接运算到几项集,开始时就是1项集 int currentNum = 1; // 商品,1频繁项集映射图 HashMap<String, FrequentItem> itemMap = new HashMap<>(); FrequentItem tempItem; // 初始列表 ArrayList<FrequentItem> list = new ArrayList<>(); // 经过连接运算后产生的结果项集 resultItem = new ArrayList<>(); resultItemID = new ArrayList<>(); // 商品ID的种类 ArrayList<String> idType = new ArrayList<>(); for (String[] a : totalGoodsIDs) { for (String s : a) { if (!idType.contains(s)) { tempItem = new FrequentItem(new String[] { s }, 1); idType.add(s); resultItemID.add(new String[] { s }); } else { // 支持度计数加1 tempItem = itemMap.get(s); tempItem.setCount(tempItem.getCount() + 1); } itemMap.put(s, tempItem); } } // 将初始频繁项集转入到列表中,以便继续做连接运算 for (Map.Entry entry : itemMap.entrySet()) { list.add((FrequentItem) entry.getValue()); } // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少 Collections.sort(list); resultItem.addAll(list); String[] array1; String[] array2; String[] resultArray; ArrayList<String> tempIds; ArrayList<String[]> resultContainer; // 总共要算到endNum项集 endNum = list.size() - 1; while (currentNum < endNum) { resultContainer = new ArrayList<>(); for (int i = 0; i < list.size() - 1; i++) { tempItem = list.get(i); array1 = tempItem.getIdArray(); for (int j = i + 1; j < list.size(); j++) { tempIds = new ArrayList<>(); array2 = list.get(j).getIdArray(); for (int k = 0; k < array1.length; k++) { // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作 if (array1[k].equals(array2[k])) { tempIds.add(array1[k]); } else { tempIds.add(array1[k]); tempIds.add(array2[k]); } } resultArray = new String[tempIds.size()]; tempIds.toArray(resultArray); boolean isContain = false; // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的 if (resultArray.length == (array1.length + 1)) { isContain = isIDArrayContains(resultContainer, resultArray); if (!isContain) { resultContainer.add(resultArray); } } } } // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集 list = cutItem(resultContainer); currentNum++; } // 输出频繁项集 for (int k = 1; k <= currentNum; k++) { System.out.println("频繁" + k + "项集:"); for (FrequentItem i : resultItem) { if (i.getLength() == k) { System.out.print("{"); for (String t : i.getIdArray()) { System.out.print(t + ","); } System.out.print("},"); } } System.out.println(); } } /** * 判断列表结果中是否已经包含此数组 * * @param container * ID数组容器 * @param array * 待比较数组 * @return */ private boolean isIDArrayContains(ArrayList<String[]> container, String[] array) { boolean isContain = true; if (container.size() == 0) { isContain = false; return isContain; } for (String[] s : container) { // 比较的视乎必须保证长度一样 if (s.length != array.length) { continue; } isContain = true; for (int i = 0; i < s.length; i++) { // 只要有一个id不等,就算不相等 if (s[i] != array[i]) { isContain = false; break; } } // 如果已经判断是包含在容器中时,直接退出 if (isContain) { break; } } return isContain; } /** * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集 */ private ArrayList<FrequentItem> cutItem(ArrayList<String[]> resultIds) { String[] temp; // 忽略的索引位置,以此构建子集 int igNoreIndex = 0; FrequentItem tempItem; // 剪枝生成新的频繁项集 ArrayList<FrequentItem> newItem = new ArrayList<>(); // 不符合要求的id ArrayList<String[]> deleteIdArray = new ArrayList<>(); // 子项集是否也为频繁子项集 boolean isContain = true; for (String[] array : resultIds) { // 列举出其中的一个个的子项集,判断存在于频繁项集列表中 temp = new String[array.length - 1]; for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) { isContain = true; for (int j = 0, k = 0; j < array.length; j++) { if (j != igNoreIndex) { temp[k] = array[j]; k++; } } if (!isIDArrayContains(resultItemID, temp)) { isContain = false; break; } } if (!isContain) { deleteIdArray.add(array); } } // 移除不符合条件的ID组合 resultIds.removeAll(deleteIdArray); // 移除支持度计数不够的id集合 int tempCount = 0; for (String[] array : resultIds) { tempCount = 0; for (String[] array2 : totalGoodsIDs) { if (isStrArrayContain(array2, array)) { tempCount++; } } // 如果支持度计数大于等于最小最小支持度计数则生成新的频繁项集,并加入结果集中 if (tempCount >= minSupportCount) { tempItem = new FrequentItem(array, tempCount); newItem.add(tempItem); resultItemID.add(array); resultItem.add(tempItem); } } return newItem; } /** * 数组array2是否包含于array1中,不需要完全一样 * * @param array1 * @param array2 * @return */ private boolean isStrArrayContain(String[] array1, String[] array2) { boolean isContain = true; for (String s2 : array2) { isContain = false; for (String s1 : array1) { // 只要s2字符存在于array1中,这个字符就算包含在array1中 if (s2.equals(s1)) { isContain = true; break; } } // 一旦发现不包含的字符,则array2数组不包含于array1中 if (!isContain) { break; } } return isContain; } /** * 根据产生的频繁项集输出关联规则 * * @param minConf * 最小置信度阈值 */ public void printAttachRule(double minConf) { // 进行连接和剪枝操作 computeLink(); int count1 = 0; int count2 = 0; ArrayList<String> childGroup1; ArrayList<String> childGroup2; String[] group1; String[] group2; // 以最后一个频繁项集做关联规则的输出 String[] array = resultItem.get(resultItem.size() - 1).getIdArray(); // 子集总数,计算的时候除去自身和空集 int totalNum = (int) Math.pow(2, array.length); String[] temp; // 二进制数组,用来代表各个子集 int[] binaryArray; // 除去头和尾部 for (int i = 1; i < totalNum - 1; i++) { binaryArray = new int[array.length]; numToBinaryArray(binaryArray, i); childGroup1 = new ArrayList<>(); childGroup2 = new ArrayList<>(); count1 = 0; count2 = 0; // 按照二进制位关系取出子集 for (int j = 0; j < binaryArray.length; j++) { if (binaryArray[j] == 1) { childGroup1.add(array[j]); } else { childGroup2.add(array[j]); } } group1 = new String[childGroup1.size()]; group2 = new String[childGroup2.size()]; childGroup1.toArray(group1); childGroup2.toArray(group2); for (String[] a : totalGoodsIDs) { if (isStrArrayContain(a, group1)) { count1++; // 在group1的条件下,统计group2的事件发生次数 if (isStrArrayContain(a, group2)) { count2++; } } } // {A}-->{B}的意思为在A的情况下发生B的概率 System.out.print("{"); for (String s : group1) { System.out.print(s + ", "); } System.out.print("}-->"); System.out.print("{"); for (String s : group2) { System.out.print(s + ", "); } System.out.print(MessageFormat.format( "},confidence(置信度):{0}/{1}={2}", count2, count1, count2 * 1.0 / count1)); if (count2 * 1.0 / count1 < minConf) { // 不符合要求,不是强规则 System.out.println("由于此规则置信度未达到最小置信度的要求,不是强规则"); } else { System.out.println("为强规则"); } } } /** * 数字转为二进制形式 * * @param binaryArray * 转化后的二进制数组形式 * @param num * 待转化数字 */ private void numToBinaryArray(int[] binaryArray, int num) { int index = 0; while (num != 0) { binaryArray[index] = num % 2; index++; num /= 2; } } }调用类:
/** * apriori关联规则挖掘算法调用类 * @author lyq * */ public class Client { public static void main(String[] args){ String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt"; AprioriTool tool = new AprioriTool(filePath, 2); tool.printAttachRule(0.7); } }输出的结果:
频繁1项集: {1,},{2,},{3,},{4,},{5,}, 频繁2项集: {1,2,},{1,3,},{1,5,},{2,3,},{2,4,},{2,5,}, 频繁3项集: {1,2,3,},{1,2,5,}, 频繁4项集: {1, }-->{2, 5, },confidence(置信度):2/6=0.333由于此规则置信度未达到最小置信度的要求,不是强规则 {2, }-->{1, 5, },confidence(置信度):2/7=0.286由于此规则置信度未达到最小置信度的要求,不是强规则 {1, 2, }-->{5, },confidence(置信度):2/4=0.5由于此规则置信度未达到最小置信度的要求,不是强规则 {5, }-->{1, 2, },confidence(置信度):2/2=1为强规则 {1, 5, }-->{2, },confidence(置信度):2/2=1为强规则 {2, 5, }-->{1, },confidence(置信度):2/2=1为强规则
在实现Apiori算法的时候,碰到的一些问题和待优化的点特别要提一下:
1、首先程序的运行效率不高,里面有大量的for嵌套循环叠加上循环,当然这有本身算法的原因(连接运算所致)还有我的各个的方法选择,很多一部分用来比较字符串数组。
2、这个是我觉得会是程序的一个漏洞,当生成的候选项集加入resultItemId时,会出现{1, 2, 3}和{3, 2, 1}会被当成不同的侯选集,未做顺序的判断。
3、程序的调试过程中由于未按照从小到大的排序,导致,生成的候选集与真实值不一致的情况,所以这里必须在频繁1项集的时候就应该是有序的。
4、在输出关联规则的时候,用到了数字转二进制数组的形式,输出他的各个非空子集,然后最出关联规则的判断。
此算法的的应用非常广泛,但是他在运算的过程中会产生大量的侯选集,而且在匹配的时候要进行整个数据库的扫描,因为要做支持度计数的统计操作,在小规模的数据上操作还不会有大问题,如果是大型的数据库上呢,他的效率还是有待提高的。
原文:http://blog.csdn.net/androidlushangderen/article/details/43059211