Skip to content

Add support to partitionBy and Graphx PartitionStrategy #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions src/main/scala/org/graphframes/GraphFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import java.util.Random

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.Partitioner
import org.apache.spark.graphx.{Edge, Graph, PartitionID, PartitionStrategy}
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{array, broadcast, col, count, explode, struct, udf, monotonically_increasing_id}
import org.apache.spark.sql.functions.{array, broadcast, col, count, explode, struct, udf, monotonically_increasing_id, lit}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel

Expand Down Expand Up @@ -265,6 +266,90 @@ class GraphFrame private(
edges.select(explode(array(SRC, DST)).as(ID)).groupBy(ID).agg(count("*").cast("int").as("degree"))
}

// ========================= Partition By ====================================

// Name of the temporary column that carries each edge's target partition id
// through the shuffle; it is dropped again before the new GraphFrame is built.
private val PARTITION_ID: String = "partition_id"

/**
 * A [[org.apache.spark.Partitioner]] that routes each record to the partition whose
 * number equals the record's key: the key IS the partition id.
 *
 * Callers must guarantee that every key is an `Int` already in `[0, numPartitions)`;
 * no hashing or remapping is applied here.
 *
 * @param partitions the total number of partitions.
 */
private class ExactAsKeyPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  // The last expression of the method is its value; no explicit `return` needed.
  def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
 * Repartitions the edges in the graph according to `partitionStrategy`.
 *
 * The GraphX strategy is evaluated on the long vertex ids of `indexedEdges`, the
 * resulting partition id is attached as a temporary column, and the rows are then
 * shuffled with an [[ExactAsKeyPartitioner]] so each edge lands exactly on the
 * partition the strategy chose. Vertices are left untouched.
 *
 * @param partitionStrategy the partitioning strategy to use when partitioning the
 *                          edges in the graph.
 * @param numPartitions the number of edge partitions in the new graph; must be positive.
 *
 * @return A new `GraphFrame` constructed from the repartitioned edges and existing vertices
 */
def partitionBy(partitionStrategy: PartitionStrategy, numPartitions: Int): GraphFrame = {
  require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")

  val getPartitionIdUdf = udf[PartitionID, Long, Long, Int] {
    (src, dst, numParts) => partitionStrategy.getPartition(src, dst, numParts)
  }

  // Flatten the edge columns back out of the ATTR struct, preserving the original
  // column order recorded in edgeColumnMap: `src` and `dst` live at the top level of
  // `indexedEdges`, every other attribute is nested under ATTR. Keeping the original
  // order matters because `intermediateSchema` below is derived from `edges.schema`;
  // a hard-coded (src, dst, attrs...) order would silently mislabel columns whenever
  // src/dst are not the first two edge columns.
  val flatEdgeColumns: Seq[Column] =
    edgeColumnMap.toSeq.sortBy(_._2).map {
      case (name, _) if name == SRC || name == DST => col(name)
      case (name, _) => col(ATTR + "." + name)
    } :+ col(PARTITION_ID)

  val edgesWithPartitionId = indexedEdges
    .withColumn(
      PARTITION_ID,
      getPartitionIdUdf(col(LONG_SRC), col(LONG_DST), lit(numPartitions)))
    .drop(LONG_SRC, LONG_DST)
    .select(flatEdgeColumns: _*)

  // Drop to the RDD level so we can dictate exactly which partition each row goes to;
  // DataFrame repartitioning only offers hash/range placement.
  val partitioned = edgesWithPartitionId.rdd
    .map(r => (r.getAs[Int](PARTITION_ID), r))
    .partitionBy(new ExactAsKeyPartitioner(numPartitions))
    .values

  // Schema of the shuffled rows: the original edge schema plus the partition id column.
  val intermediateSchema =
    edges.schema.add(StructField(PARTITION_ID, IntegerType, nullable = false))

  // Rebuild a DataFrame from the partitioned rows and strip the helper column.
  val newEdges = edges.sqlContext
    .createDataFrame(partitioned, intermediateSchema)
    .drop(PARTITION_ID)

  new GraphFrame(vertices, newEdges)
}


/**
 * Repartitions the edges in the graph according to `partitionStrategy`, keeping the
 * current number of edge partitions.
 *
 * @param partitionStrategy the partitioning strategy to use when partitioning the
 *                          edges in the graph.
 *
 * @return A new `GraphFrame` constructed from new edges and existing vertices
 */
def partitionBy(partitionStrategy: PartitionStrategy): GraphFrame =
  partitionBy(partitionStrategy, edges.rdd.getNumPartitions)

// ============================ Motif finding ========================================

/**
Expand Down
40 changes: 40 additions & 0 deletions src/test/scala/org/graphframes/GraphFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,44 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {

GraphFrame.setBroadcastThreshold(defaultThreshold)
}

test("partitionBy strategy") {
  import org.apache.spark.graphx.PartitionStrategy._

  val sqlContext = this.sqlContext
  import sqlContext.implicits._

  // Builds a two-partition GraphFrame from (src, dst) pairs with a constant edge attribute.
  def mkGraph(edges: List[(Long, Long)]): GraphFrame = {
    GraphFrame.fromEdges(sc.parallelize(edges, 2)
      .toDF("src", "dst")
      .withColumn("rel", lit(0)))
  }

  // Returns one row per edge partition that actually holds at least one edge,
  // so `.count` below is the number of non-empty partitions.
  def nonemptyParts(graph: GraphFrame): DataFrame = {
    val partitionSizeDf = graph.edges.mapPartitions { iter =>
      Iterator(iter.size)
    }.toDF("size")
    partitionSizeDf.where(col("size") > 0)
  }

  val identicalEdges = List((0L, 1L), (0L, 1L))
  val canonicalEdges = List((0L, 1L), (1L, 0L))
  val sameSrcEdges = List((0L, 1L), (0L, 2L))
  val sameSrcTwoPartitionsEdges = List((0L, 1L), (0L, 2L), (1L, 1L))

  // partitionBy(RandomVertexCut) puts identical edges in the same partition
  assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(RandomVertexCut)).count === 1)

  // partitionBy(EdgePartition1D) puts same-source edges in the same partition,
  // while edges with distinct sources may be spread across several partitions
  assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1)
  assert(nonemptyParts(mkGraph(sameSrcTwoPartitionsEdges).partitionBy(EdgePartition1D)).count > 1)

  // partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into
  // the same partition
  assert(
    nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1)

  // partitionBy(EdgePartition2D) puts identical edges in the same partition
  assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1)
}
}