关于REPTree,我实在没有找到与其算法相关的资料。它或许是Weka自创的一种决策树改进算法,也可能是其他某种决策树方法的别名。该类的注释是这样描述的:Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting). Only sorts values for numeric attributes once. Missing values are dealt with by splitting the corresponding instances into pieces (i.e. as in C4.5).(大意:这是一个快速的决策树学习器,用信息增益(分类)或方差(回归)构建决策树/回归树,并用带回拟合(backfitting)的减少错误剪枝(reduced-error pruning)进行剪枝;数值属性的取值只排序一次;缺失值的处理方式与C4.5相同,即把相应的实例拆分成若干部分。)
public void buildClassifier(Instances data) throws Exception { // 首先例行公事看一下给定数据集是否能使用REPTree进行分类,REPTREE基本能支持所有类型 getCapabilities().testWithFail(data); // 把classIndex上没有数据的instance干掉,这些数据既不能用于训练也不能用于backfit data = new Instances(data); data.deleteWithMissingClass(); Random random = new Random(m_Seed); m_zeroR = null; if (data.numAttributes() == 1) { m_zeroR = new ZeroR();//如果只有一列的话,就是用m_ZerO作为分类器,很直观只有一列的话肯定就是结果列了,只有结果列无法训练分类器,只能使用最基本的米ZerO作为分类器,mZerO的分类方法再上篇日志有说到。 m_zeroR.buildClassifier(data); return; } // Randomize and stratify data.randomize(random);//进行随机排列 if (data.classAttribute().isNominal()) { data.stratify(m_NumFolds);//如果枚举型还要进行一下分层,目的是 } // 如果需要剪枝,则分为train集合和prune集合,否则只要train集合就行了 Instances train = null; Instances prune = null; if (!m_NoPruning) { train = data.trainCV(m_NumFolds, 0, random);//这里是用了多折交叉验证的方法取得train和test prune = data.testCV(m_NumFolds, 0); } else { train = data; } // 建立了两个数组,第一维数据无意义,只是把三维数组当二维数组用而已,第二维代表各属性,第三维代表排序的index(顺序统计量) int[][][] sortedIndices = new int[1][train.numAttributes()][0];//这个里面存放的是各instance的下标 double[][][] weights = new double[1][train.numAttributes()][0];//这个里面存放的是下标对应的instance的weight double[] vals = new double[train.numInstances()];//这个是临时数组,用于排序用的 for (int j = 0; j < train.numAttributes(); j++) { if (j != train.classIndex()) { weights[0][j] = new double[train.numInstances()]; if (train.attribute(j).isNominal()) { //如果是枚举类型,所做的排序工作就是简单的把Missing放到最后面 sortedIndices[0][j] = new int[train.numInstances()]; int count = 0; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (!inst.isMissing(j)) { sortedIndices[0][j][count] = i; weights[0][j][count] = inst.weight(); count++; } } for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (inst.isMissing(j)) { sortedIndices[0][j][count] = i; weights[0][j][count] = inst.weight(); count++; } } } else { // 如果是数值类型,则进行排序 for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); vals[i] = 
inst.value(j); } sortedIndices[0][j] = Utils.sort(vals); for (int i = 0; i < train.numInstances(); i++) { weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight(); } } } } // 这里建立数组存放训练集中每个类的分布 double[] classProbs = new double[train.numClasses()]; double totalWeight = 0, totalSumSquared = 0; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); if (data.classAttribute().isNominal()) {
classProbs[(int)inst.classValue()] += inst.weight();//如果是枚举类型,就进行简单的统计 totalWeight += inst.weight(); } else { classProbs[0] += inst.classValue() * inst.weight();//如果是数值型,就相加,到后面进行取平均的操作 totalSumSquared += inst.classValue() * inst.classValue() * inst.weight(); totalWeight += inst.weight(); } } m_Tree = new Tree();//建立决策树节点 double trainVariance = 0;//训练集的方差 if (data.classAttribute().isNumeric()) { trainVariance = m_Tree. singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight; classProbs[0] /= totalWeight;//这里取平均操作 } // Build tree m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs, new Instances(train, 0), m_MinNum, m_MinVarianceProp * trainVariance, 0, m_MaxDepth);//执行具体树上的构建操作,这参数还真多 // Insert pruning data and perform reduced error pruning if (!m_NoPruning) { m_Tree.insertHoldOutSet(prune);//传入剪枝数据 m_Tree.reducedErrorPrune();//进行剪枝 m_Tree.backfitHoldOutSet();//backfit } }
protected void buildTree(int[][][] sortedIndices, double[][][] weights, Instances data, double totalWeight, double[] classProbs, Instances header, double minNum, double minVariance, int depth, int maxDepth) throws Exception { //第一个参数是按属性排好序的下标,第二个是这些下标对应的weight,第三个是训练数据
<span style="white-space:pre"> </span>//第四个是总权重,第五个是各类的分布,第六个是表头,第七个是每个节点最小instance数量
<span style="white-space:pre"> </span>//第八个是最小的方差 ,第九个是当前深度(0 base),第十个是最大深度
m_Info = header;//首先存下表头 if (data.classAttribute().isNumeric()) { m_HoldOutDist = new double[2];//这个数组用于存放分布 } else { m_HoldOutDist = new double[data.numClasses()]; } // 看看是否有有效数据 int helpIndex = 0; if (data.classIndex() == 0) { helpIndex = 1;//传入的数据至少两列,因为一列的话上层就用m_zerO模型了,这个if是为了保证helpIndex对应的肯定是训练数据 } if (sortedIndices[0][helpIndex].length == 0) {//如果没数据,就直接反悔了 if (data.classAttribute().isNumeric()) { m_Distribution = new double[2];//为什么是二维的?第一维存放方差,第二维存放weight,基于约定的编程方式 } else { m_Distribution = new double[data.numClasses()]; } m_ClassProbs = null; sortedIndices[0] = null; weights[0] = null; return; } double priorVar = 0;//存放class的方差(其实是方差*num),只有class是数值才有意义,下面就是计算方差的过程。 if (data.classAttribute().isNumeric()) { // 每个sortedIndices[0][i]里面的都是一个Instances的index不同排列而已,使用helpIndex只是为了保证别对应到classIndex上 double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0; for (int i = 0; i < sortedIndices[0][helpIndex].length; i++) { Instance inst = data.instance(sortedIndices[0][helpIndex][i]); totalSum += inst.classValue() * weights[0][helpIndex][i]; totalSumSquared += inst.classValue() * inst.classValue() * weights[0][helpIndex][i]; totalSumOfWeights += weights[0][helpIndex][i]; } priorVar = singleVariance(totalSum, totalSumSquared, totalSumOfWeights); } //把分布拷贝一下 m_ClassProbs = new double[classProbs.length]; System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length); if ((//退出条件有4个
<span style="white-space:pre"> </span>//第一个是instances里面的totalweight总量(可以理解成里面的instance数量,因为weight默认都是1)小于两倍的minNum,minNum默认是2.
<span style="white-space:pre"> </span>totalWeight < (2 * minNum)) || // 如果是枚举类型,并且都在一类中 (data.classAttribute().isNominal() && Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)], Utils.sum(m_ClassProbs))) || // 数值型则比较方差是否小于minVariance,这个minVariance默认是原始方差的0.001,从上层代码可以得知 (data.classAttribute().isNumeric() && ((priorVar / totalWeight) < minVariance)) || // 达到最大深度 ((m_MaxDepth >= 0) && (depth >= maxDepth))) { // 设置成叶子 m_Attribute = -1; if (data.classAttribute().isNominal()) { // 设置枚举类型的分布 m_Distribution = new double[m_ClassProbs.length]; for (int i = 0; i < m_ClassProbs.length; i++) { m_Distribution[i] = m_ClassProbs[i]; } Utils.normalize(m_ClassProbs); } else { // 设置数值类型的“分布” m_Distribution = new double[2]; m_Distribution[0] = priorVar; m_Distribution[1] = totalWeight; } sortedIndices[0] = null; weights[0] = null; return; } // 下面是寻找分裂点的过程 double[] vals = new double[data.numAttributes()];//每个属性产生的信息增益 double[][][] dists = new double[data.numAttributes()][0][0];//每个属性下每个类的分布 double[][] props = new double[data.numAttributes()][0];//每个属性下class的概率,也就是根据上面这个数组的分布求概率 double[][] totalSubsetWeights = new double[data.numAttributes()][0];//每个属性下每个subset的数量 double[] splits = new double[data.numAttributes()];//每个属性的分裂点,如果是枚举型则为NaN if (data.classAttribute().isNominal()) { // 首先来看classAttribute是枚举类型的情况 for (int i = 0; i < data.numAttributes(); i++) { if (i != data.classIndex()) { splits[i] = distribution(props, dists, i, sortedIndices[0][i], weights[0][i], totalSubsetWeights, data);//得到分裂点、概率和分布 vals[i] = gain(dists[i], priorVal(dists[i]));//得到信息增益 } } } else { // 如果是数值类型则不算信息增益(为什么数值类型不算增益?只有因为枚举型才算的出信息熵)(吐个槽:话说这个if-else为啥不放在循环里面??) 
for (int i = 0; i < data.numAttributes(); i++) { if (i != data.classIndex()) { splits[i] = numericDistribution(props, dists, i, sortedIndices[0][i], weights[0][i], totalSubsetWeights, data, vals); } } } // 选出信息增益最大的作为分裂属性 m_Attribute = Utils.maxIndex(vals); int numAttVals = dists[m_Attribute].length; // 每个subset都要多于minNum,这样才算一个有效subset int count = 0; for (int i = 0; i < numAttVals; i++) { if (totalSubsetWeights[m_Attribute][i] >= minNum) { count++; } if (count > 1) { break; } } // 至少存在2个有效subset,才算是一个有效的split if (Utils.gr(vals[m_Attribute], 0) && (count > 1)) { // Set split point, proportions, and temp arrays m_SplitPoint = splits[m_Attribute]; m_Prop = props[m_Attribute]; double[][] attSubsetDists = dists[m_Attribute]; double[] attTotalSubsetWeights = totalSubsetWeights[m_Attribute]; // 释放内存 vals = null; dists = null; props = null; totalSubsetWeights = null; splits = null; // 得到subSet的有序index int[][][][] subsetIndices = new int[numAttVals][1][data.numAttributes()][0]; double[][][][] subsetWeights = new double[numAttVals][1][data.numAttributes()][0]; splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint, sortedIndices[0], weights[0], data); // 释放内存 sortedIndices[0] = null; weights[0] = null; //释放内存 m_Successors = new Tree[numAttVals]; for (int i = 0; i < numAttVals; i++) { m_Successors[i] = new Tree();//构建孩子节点 m_Successors[i]. 
buildTree(subsetIndices[i], subsetWeights[i], data, attTotalSubsetWeights[i], attSubsetDists[i], header, minNum, minVariance, depth + 1, maxDepth); // 还是释放内存 attSubsetDists[i] = null; } } else { // 如果不存在2个有效的subset,就直接当叶子节点了 m_Attribute = -1; sortedIndices[0] = null; weights[0] = null; } // 构建attribute用于之后的分类过程(当然这是在没有prune和backfit情况下用的) if (data.classAttribute().isNominal()) { m_Distribution = new double[m_ClassProbs.length]; for (int i = 0; i < m_ClassProbs.length; i++) { m_Distribution[i] = m_ClassProbs[i]; } Utils.normalize(m_ClassProbs); } else { m_Distribution = new double[2]; m_Distribution[0] = priorVar; m_Distribution[1] = totalWeight; } }