首页 > 其他 > 详细

SparkGraphx计算指定节点的N度关系节点

时间:2017-01-20 19:39:44      阅读:1208      评论:0      收藏:0      [点我收藏+]

直接上代码:

/**
  * Created by yepei.ye on 2017/1/19.
  * Description:用于在图中为指定的节点计算这些节点的N度关系节点,输出这些节点与源节点的路径长度和节点id
  */
object GraphNdegUtil {
  /**
    * 计算节点的N度关系
    *
    * @param edges
    * @param choosedVertex
    * @param degree
    * @tparam ED
    * @return
    */
  def aggNdegreedVertices[ED: ClassTag](edges: RDD[(VertexId, VertexId)], choosedVertex: RDD[VertexId], degree: Int): VertexRDD[Seq[(Int, ArrayBuffer[VertexId])]] = aggNdegreedVertices(Graph.fromEdgeTuples(edges, 0), choosedVertex, degree)

  /**
    * 计算指定节点的小于等于N度的所有对应度数下的各个关系节点
    * 计算步骤:
    * 1.在图中给被选节点找Tag,标记要为这些节点找N度关系节点
    * 2.为图中每个节点定义初始的距离向量:(vertexId,distance)
    * 3.迭代发消息,对某一端包含Tag的边发双向消息(如果dstType不在choosedType集合中则不发),消息内容是:消息接收者不包含的距离向量
    * 4.merge消息:合并关系节点,新的tag取并集
    * 5.更新消息接收节点的路径向量集合:把接收到的距离向量的值分别加1
    * 6.迭代结束,去掉距离向量值为0的,剩下的按距离值汇总即得到结果
    *
    * @param graph         注意:图参数的顶点和边属性应该尽量简单,除了filter函数会使用到顶点的属性参数外,整个计算过程中只使用了顶点的id和边的srcId及dstId
    * @param choosedVertex ,要计算的节点的Id
    * @param filter        ,只取符合过滤条件的关系节点
    * @param degree        ,指定最多取几度关系节点,在实际迭代计算中,迭代的次数要比degree参数大1才能保证找的N度关系没有漏的
    * @return
    */
  def aggNdegreedVertices[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED],
                                                      choosedVertex: RDD[VertexId],
                                                      degree: Int,
                                                      filter: VD => Boolean = (_: VD) => true): VertexRDD[Seq[(Int, ArrayBuffer[VertexId])]] = {
    if (degree < 1) {
      throw new InvalidParameterException("度参数错误:" + degree)
    }
    val tag = true
    //要找的标记,仅一个boolean值
    val tupleRdd = choosedVertex.map(e => (e, tag)).persist(StorageLevel.MEMORY_AND_DISK_SER)
    //打标记并定义初始距离向量
    val tagedGraph: Graph[DegVertex[VD], ED] = graph.outerJoinVertices(tupleRdd)((id, attr, tg) => DegVertex(attr, tg.getOrElse(false), mutable.Map(id -> 0)))
    val result = tagedGraph.pregel[(Boolean, mutable.Map[VertexId, Int])]((false, mutable.Map()), degree + 1, EdgeDirection.Either)((_, attr, msg) => attr.addMsg(msg), sendMsg(_, filter), mergeMsg)
    val maped = result.vertices.join(tupleRdd).map(e => (e._1, sortResult(e._2._1))).persist(StorageLevel.MEMORY_AND_DISK_SER)
    tupleRdd.unpersist()
    tagedGraph.unpersist()
    VertexRDD(maped)
  }

  private def sortResult[VD: ClassTag](degs: DegVertex[VD]): Seq[(Int, ArrayBuffer[VertexId])] = degs.degVertices.toSeq.map(e => (e._2, ArrayBuffer(e._1))).filter(e => e._1 > 0).reduceByKey(_ ++ _).toSeq.sortBy(e => e._1)

  case class DegVertex[VD: ClassTag](attr: VD, var taged: Boolean, degVertices: mutable.Map[VertexId, Int]) {
    def addMsg(msg: (Boolean, mutable.Map[VertexId, Int])): DegVertex[VD] = {
      try {
        val newMsg = msg._2.map(e => (e._1, e._2 + 1))
        val newMap: mutable.Map[VertexId, Int] = mutable.Map() //注意,这里一定要新起一个map,不要直接在this.degVertices上操作,否则二轮迭代时,值传不过去
        for (k <- this.degVertices.keySet) {
          val v = this.degVertices.get(k)
          if (v.isDefined) {
            newMap.put(k, v.get)
          }
        }
        for (k <- newMsg.keySet) {
          val v = newMsg.get(k)
          if (!newMap.contains(k) && v.isDefined) {
            newMap.put(k, v.get)
          }
        }
        val newTag = if (!this.taged) msg._1 else this.taged
        DegVertex(attr, newTag, newMap)
      } catch {
        case e: Exception =>
          println("=============error found:" + msg)
          e.printStackTrace()
          this
      }
    }
  }

  private def sendMsg[VD: ClassTag, ED: ClassTag](e: EdgeTriplet[DegVertex[VD], ED], filter: VD => Boolean): Iterator[(VertexId, (Boolean, mutable.Map[VertexId, Int]))] = {
    var r: ArrayBuffer[(VertexId, (Boolean, mutable.Map[VertexId, Int]))] = ArrayBuffer[(VertexId, (Boolean, mutable.Map[VertexId, Int]))]()
    try {
      if ((e.srcAttr.taged || e.dstAttr.taged) && !isAttrSame(e.srcAttr, e.dstAttr)) {
        //1.发给src
        if (filter(e.dstAttr.attr)) {
          val ele: (VertexId, (Boolean, mutable.Map[VertexId, Int])) = (e.srcId, (e.dstAttr.taged, e.dstAttr.degVertices))
          r += ele
        } //2.发给dst
        if (filter(e.srcAttr.attr)) {
          val ele: (VertexId, (Boolean, mutable.Map[VertexId, Int])) = (e.dstId, (e.srcAttr.taged, e.srcAttr.degVertices))
          r += ele
        }
      }
    } catch {
      case ex: Exception =>
        println("==========error found:" + e)
        ex.printStackTrace()
    }
    r.iterator
  }

  private def mergeMsg(a: (Boolean, mutable.Map[VertexId, Int]), b: (Boolean, mutable.Map[VertexId, Int])): (Boolean, mutable.Map[VertexId, Int]) = {
    try {
      for (k <- b._2.keySet) {
        val v = b._2.get(k)
        if (!a._2.contains(k) && v.isDefined) {
          a._2.put(k, v.get)
        }
      }
      (b._1 || a._1, a._2)
    } catch {
      case e: Exception =>
        println("==========error found:" + a + "," + b)
        e.printStackTrace()
        (false, mutable.Map())
    }
  }

  private def isAttrSame[VD: ClassTag](a: DegVertex[VD], b: DegVertex[VD]): Boolean = a.taged == b.taged && hasSameKey(a.degVertices, b.degVertices)

  private def hasSameKey(a: mutable.Map[VertexId, Int], b: mutable.Map[VertexId, Int]): Boolean = {
    if (a.size != b.size) return false
    for (k <- a.keySet) {
      if (!b.contains(k)) return false
    }
    true
  }
}

其中sortResult方法里对Traversable[(K,V)]类型的集合使用了reduceByKey方法,这个方法是自行封装的,使用时需要导入,代码如下:

/**
  * Created by yepei.ye on 2016/12/21.
  * Description:
  */
object CollectionUtil {
  /**
    * 对具有Traversable[(K, V)]类型的集合添加reduceByKey相关方法
    *
    * @param collection
    * @param kt
    * @param vt
    * @tparam K
    * @tparam V
    */
  implicit class CollectionHelper[K, V](collection: Traversable[(K, V)])(implicit kt: ClassTag[K], vt: ClassTag[V]) {
    def reduceByKey(f: (V, V) => V): Traversable[(K, V)] = collection.groupBy(_._1).map { case (_: K, values: Traversable[(K, V)]) => values.reduce((a, b) => (a._1, f(a._2, b._2))) }

    /**
      * reduceByKey的同时,返回被reduce掉的元素的集合
      *
      * @param f
      * @return
      */
    def reduceByKeyWithReduced(f: (V, V) => V)(implicit kt: ClassTag[K], vt: ClassTag[V]): (Traversable[(K, V)], Traversable[(K, V)]) = {
      val reduced: ArrayBuffer[(K, V)] = ArrayBuffer()
      val newSeq = collection.groupBy(_._1).map {
        case (_: K, values: Traversable[(K, V)]) => values.reduce((a, b) => {
          val newValue: V = f(a._2, b._2)
          val reducedValue: V = if (newValue == a._2) b._2 else a._2
          val reducedPair: (K, V) = (a._1, reducedValue)
          reduced += reducedPair
          (a._1, newValue)
        })
      }
      (newSeq, reduced.toTraversable)
    }
  }
}

SparkGraphx计算指定节点的N度关系节点

原文:http://www.cnblogs.com/yepei/p/6323545.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!