How does PageRank work in GraphX for Spark

This post summarizes my efforts to understand how PageRank in GraphX works in Spark 1.4.0.

The smallest code snippet for running PageRank is on the GraphX documentation page,

https://spark.apache.org/docs/latest/graphx-programming-guide.html; look for the PageRank section.
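For reference, here is a minimal sketch of that snippet (lightly adapted from the programming guide, assuming an existing SparkContext named sc, as in spark-shell):

<pre>import org.apache.spark.graphx.GraphLoader

// Load an edge list and run PageRank until the ranks converge to the given
// tolerance; the data path and the 0.0001 tolerance follow the guide's example.
val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
val ranks = graph.pageRank(0.0001).vertices
ranks.collect().foreach(println)</pre>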

The code example works on a very small data set in graphx/data/followers.txt, with only 6 data points. A much larger example is in

examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala

We start from

spark/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala

The algorithm is clearly described in the comments, but what this file really defines are the functional components, namely (1) vertexProgram, (2) sendMessage, and (3) messageCombiner, all of which are passed into the Pregel apply function.
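Roughly, the three components look like the following. This is a paraphrased sketch of the convergence-based variant in PageRank.scala, not a verbatim copy: each vertex attribute is a (rank, delta) pair, and resetProb and tol stand in for the algorithm's reset probability and convergence tolerance.

<pre>import org.apache.spark.graphx._

// Paraphrased sketch of the PageRank Pregel components (not verbatim source).
object PageRankComponentsSketch {
  val resetProb = 0.15  // stand-in for the real reset probability parameter
  val tol = 0.0001      // stand-in for the real convergence tolerance

  // (1) vertexProgram: fold the summed incoming messages into the vertex rank,
  // and remember how much the rank changed (the delta).
  def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = {
    val (oldPR, _) = attr
    val newPR = oldPR + (1.0 - resetProb) * msgSum
    (newPR, newPR - oldPR)
  }

  // (2) sendMessage: only vertices whose rank changed by more than tol keep
  // sending messages along their out-edges.
  def sendMessage(edge: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Double)] = {
    if (edge.srcAttr._2 > tol) {
      Iterator((edge.dstId, edge.srcAttr._2 * edge.attr))
    } else {
      Iterator.empty
    }
  }

  // (3) messageCombiner: messages destined for the same vertex are summed.
  def messageCombiner(a: Double, b: Double): Double = a + b
}</pre>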

<pre>def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
   (graph: Graph[VD, ED],
    initialMsg: A,
    maxIterations: Int = Int.MaxValue,
    activeDirection: EdgeDirection = EdgeDirection.Either)
   (vprog: (VertexId, VD, A) => VD,
    sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
    mergeMsg: (A, A) => A)
  : Graph[VD, ED] =
{
  var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
  // compute the messages
  var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
  var activeMessages = messages.count()
  // Loop
  var prevG: Graph[VD, ED] = null
  var i = 0
  while (activeMessages > 0 && i < maxIterations) {
    // Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
    val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
    // Update the graph with the new vertices.
    prevG = g
    g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
    g.cache()

    val oldMessages = messages
    // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
    // get to send messages. We must cache messages so it can be materialized on the next line,
    // allowing us to uncache the previous iteration.
    messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
    // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
    // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
    // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
    activeMessages = messages.count()

    logInfo("Pregel finished iteration " + i)

    // Unpersist the RDDs hidden by newly-materialized RDDs
    oldMessages.unpersist(blocking=false)
    newVerts.unpersist(blocking=false)
    prevG.unpersistVertices(blocking=false)
    prevG.edges.unpersist(blocking=false)
    // count the iteration
    i += 1
  }

  g
} // end of apply</pre>
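For PageRank, the call into this apply function looks roughly like the following (again a paraphrase of PageRank.scala, where pagerankGraph is the input graph re-labelled with (rank, delta) vertex attributes and normalized edge weights, and initialMessage is the message every vertex receives in the first superstep):

<pre>// Paraphrased: hand the three PageRank components to Pregel, then keep only
// the rank from each (rank, delta) vertex attribute.
Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)(
  vertexProgram, sendMessage, messageCombiner)
  .mapVertices((vid, attr) => attr._1)</pre>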

The key part of the code is the call to mapReduceTriplets, even though this function is deprecated (its public replacement is aggregateMessages; a small sketch of it follows below):

<pre>var messages = g.mapReduceTriplets(sendMsg, mergeMsg)</pre>
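As an aside, the same style of message passing written against the non-deprecated public API, aggregateMessages, would look something like this. This is a minimal in-degree counting example of my own, reusing the graph loaded earlier; it is not code taken from PageRank.scala.

<pre>import org.apache.spark.graphx._

// Minimal illustration of the public aggregateMessages API: count in-edges
// per vertex. PageRank itself goes through the internal
// aggregateMessagesWithActiveSet path shown further below.
val inDegreeCounts: VertexRDD[Int] = graph.aggregateMessages[Int](
  ctx => ctx.sendToDst(1), // sendMsg: one message per edge, to the destination
  _ + _                    // mergeMsg: sum the messages arriving at each vertex
)</pre>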

mapReduceTriplets itself is implemented in the following way:

<pre>override def mapReduceTriplets[A: ClassTag](
    mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
    reduceFunc: (A, A) => A,
    activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {

  def sendMsg(ctx: EdgeContext[VD, ED, A]) {
    mapFunc(ctx.toEdgeTriplet).foreach { kv =>
      val id = kv._1
      val msg = kv._2
      if (id == ctx.srcId) {
        ctx.sendToSrc(msg)
      } else {
        assert(id == ctx.dstId)
        ctx.sendToDst(msg)
      }
    }
  }

  val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
  val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
  val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true)

  aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt)
}</pre>

And it all comes down to aggregateMessagesWithActiveSet:

<pre>override def aggregateMessagesWithActiveSet[A: ClassTag](
    sendMsg: EdgeContext[VD, ED, A] => Unit,
    mergeMsg: (A, A) => A,
    tripletFields: TripletFields,
    activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {

  vertices.cache()
  // For each vertex, replicate its attribute only to partitions where it is
  // in the relevant position in an edge.
  replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
  val view = activeSetOpt match {
    case Some((activeSet, _)) =>
      replicatedVertexView.withActiveSet(activeSet)
    case None =>
      replicatedVertexView
  }
  val activeDirectionOpt = activeSetOpt.map(_._2)

  // Map and combine.
  val preAgg = view.edges.partitionsRDD.mapPartitions(_.flatMap {
    case (pid, edgePartition) =>
      // Choose scan method
      val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
      activeDirectionOpt match {
        case Some(EdgeDirection.Both) =>
          if (activeFraction < 0.8) {
            edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
              EdgeActiveness.Both)
          } else {
            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
              EdgeActiveness.Both)
          }
        case Some(EdgeDirection.Either) =>
          // TODO: Because we only have a clustered index on the source vertex ID, we can't filter
          // the index here. Instead we have to scan all edges and then do the filter.
          edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
            EdgeActiveness.Either)
        case Some(EdgeDirection.Out) =>
          if (activeFraction < 0.8) {
            edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
              EdgeActiveness.SrcOnly)
          } else {
            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
              EdgeActiveness.SrcOnly)
          }
        case Some(EdgeDirection.In) =>
          edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
            EdgeActiveness.DstOnly)
        case _ => // None
          edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
            EdgeActiveness.Neither)
      }
  }).setName("GraphImpl.aggregateMessages - preAgg")

  // do the final reduction reusing the index map
  vertices.aggregateUsingIndex(preAgg, mergeMsg)
}</pre>

The above code is in graphx/impl/GraphImpl.scala. According to the benchmark on LiveJournalPageRank, this function takes a lot of time, more than half of the overall execution time.

Specifically, the mapPartitions call at GraphImpl.scala:235,

<pre>val preAgg = view.edges.partitionsRDD.mapPartitions(_.flatMap {...</pre>

takes a major share of the time. This part appears to be the pre-aggregation happening on the map-task side. However, it still seems to generate shuffled messages. This call looks like the performance bottleneck of the entire application.
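To make the map-side pre-aggregation plus shuffle picture concrete, here is the analogy I use in plain RDD terms. This is only an analogy and not GraphX's actual code path: messages are first combined per partition on the map side, and the final per-vertex reduction still requires a shuffle.

<pre>import org.apache.spark.graphx.VertexId
import org.apache.spark.rdd.RDD

// Analogy only (not the real GraphX implementation): map-side pre-aggregation
// of (vertexId, message) pairs, followed by a shuffle-based reduceByKey.
def combineMessages(messages: RDD[(VertexId, Double)]): RDD[(VertexId, Double)] = {
  val preAgg = messages.mapPartitions { iter =>
    // combine messages within each partition before anything is shuffled
    val acc = scala.collection.mutable.Map.empty[VertexId, Double]
    iter.foreach { case (vid, msg) => acc(vid) = acc.getOrElse(vid, 0.0) + msg }
    acc.iterator
  }
  preAgg.reduceByKey(_ + _) // this step shuffles the pre-aggregated messages
}</pre>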

Besides this, the count() implementation in graphx/impl/VertexRDDImpl.scala (line 90),

<pre>/** The number of vertices in the RDD. */
override def count(): Long = {
  partitionsRDD.map(_.size).reduce(_ + _)
}</pre>

can take a noticeable amount of time in later stages too. It is not entirely clear to me why, though the comments in the Pregel loop above hint at an explanation: the count() call is what materializes the messages RDD of each iteration, so the work of computing those messages may simply show up under this stage.

A good way to summarize the call sequence is the following stack trace, which goes all the way down to the pre-aggregation:

<pre>org.apache.spark.rdd.RDD.mapPartitions(RDD.scala:663)
org.apache.spark.graphx.impl.GraphImpl.aggregateMessagesWithActiveSet(GraphImpl.scala:235)
org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets(GraphImpl.scala:213)</pre>

Another stack trace, for the reduce call inside count():

<pre>org.apache.spark.rdd.RDD.reduce(RDD.scala:928)
org.apache.spark.graphx.impl.VertexRDDImpl.count(VertexRDDImpl.scala:90)
org.apache.spark.graphx.Pregel$.apply(Pregel.scala:145)</pre>

Next, I want to find a data set size that can reasonably finish on a single node and continue testing the performance of the application.
