diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
index 5ca41e0dc..b49c85eb1 100644
--- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
+++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
@@ -327,22 +327,22 @@ object ConnectedComponents extends Logging {
       // checkpointing
       if (shouldCheckpoint && (iteration % checkpointInterval == 0)) {
-        // TODO: remove this after DataFrame.checkpoint is implemented
-        val out = s"${checkpointDir.get}/$iteration"
-        ee.write.parquet(out)
-        // may hit S3 eventually consistent issue
-        ee = spark.read.parquet(out)
-
-        // remove previous checkpoint
+        // enable checkpointing if not yet done
+        if (spark.sparkContext.getCheckpointDir.isEmpty) {
+          spark.sparkContext.setCheckpointDir(checkpointDir.get)
+        }
+        ee = ee.checkpoint(eager = true)
+        // remove previous checkpoint manually if needed
         if (iteration > checkpointInterval) {
-          val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
-          path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
+          val oldCheckpointPath = new Path(
+            s"${checkpointDir.get}/${iteration - checkpointInterval}")
+          oldCheckpointPath
+            .getFileSystem(sc.hadoopConfiguration)
+            .delete(oldCheckpointPath, true)
         }
-        System.gc() // hint Spark to clean shuffle directories
       }
-      ee.persist(intermediateStorageLevel)
       currRoundPersistedDFs = currRoundPersistedDFs :+ ee
 
       minNbrs1 = minNbrs(ee) // src >= min_nbr
diff --git a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
index 9614d1e20..c1b678733 100644
--- a/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
+++ b/src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
@@ -17,17 +17,17 @@
 
 package org.graphframes.lib
 
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.lit
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types.DataTypes
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.storage.StorageLevel
 import org.graphframes.GraphFrame._
 import org.graphframes._
 import org.graphframes.examples.Graphs
 
 import java.io.IOException
+import java.net.URI
 
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
@@ -173,6 +173,20 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
     assert(components.groupBy("component").count().count() === 1L)
   }
 
+  def listCheckpointFiles(path: String): Set[String] = {
+    val fs = FileSystem.get(new URI(path), spark.sparkContext.hadoopConfiguration)
+    val checkpointPath = new Path(path)
+    if (!fs.exists(checkpointPath)) return Set.empty
+    fs.listStatus(checkpointPath)
+      .flatMap {
+        case status if status.isDirectory =>
+          fs.listStatus(status.getPath).toSeq.map(_.getPath.toString)
+        case status =>
+          Seq(status.getPath.toString)
+      }
+      .toSet
+  }
+
   test("checkpoint interval") {
     val friends = Graphs.friends
     val expected = Set(Set("a", "b", "c", "d", "e", "f"), Set("g"))
@@ -182,8 +196,12 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
       cc.getCheckpointInterval === 2,
       s"Default checkpoint interval should be 2, but got ${cc.getCheckpointInterval}.")
 
-    val checkpointDir = sc.getCheckpointDir
-    assert(checkpointDir.nonEmpty)
+    val checkpointDirOpt = sc.getCheckpointDir
+    assert(checkpointDirOpt.nonEmpty)
+    val checkpointDir = checkpointDirOpt.get
+    // clean up the checkpoint dir first to avoid test conflicts
+    val checkpointPath = new Path(s"${checkpointDir}/")
+    checkpointPath.getFileSystem(sc.hadoopConfiguration).delete(checkpointPath, true)
 
     sc.setCheckpointDir(null)
     withClue(
@@ -194,30 +212,31 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
       }
     }
 
-    // Checks whether the input DataFrame is from some checkpoint data.
-    // TODO: The implemetnation is a little hacky.
-    def isFromCheckpoint(df: DataFrame): Boolean = {
-      df.queryExecution.logical.toString().toLowerCase.contains("parquet")
-    }
-
+    // The `dataframe.rdd.isCheckpointed` behavior changed and can't be used to verify whether
+    // a dataframe is checkpointed, so we have to check the files instead.
     val components0 = cc.setCheckpointInterval(0).run()
+    val after0 = listCheckpointFiles(checkpointDir)
     assertComponents(components0, expected)
     assert(
-      !isFromCheckpoint(components0),
+      after0.isEmpty,
       "The result shouldn't depend on checkpoint data if checkpointing is disabled.")
 
-    sc.setCheckpointDir(checkpointDir.get)
+    sc.setCheckpointDir(checkpointDir)
+    val before1 = listCheckpointFiles(checkpointDir)
     val components1 = cc.setCheckpointInterval(1).run()
+    val after1 = listCheckpointFiles(checkpointDir)
     assertComponents(components1, expected)
     assert(
-      isFromCheckpoint(components1),
+      (after1 -- before1).nonEmpty,
       "The result should depend on checkpoint data if checkpoint interval is 1.")
 
+    val before10 = listCheckpointFiles(checkpointDir)
     val components10 = cc.setCheckpointInterval(10).run()
+    val after10 = listCheckpointFiles(checkpointDir)
     assertComponents(components10, expected)
     assert(
-      !isFromCheckpoint(components10),
+      (after10 -- before10).isEmpty,
       "The result shouldn't depend on checkpoint data if converged before first checkpoint.")
   }
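
Note for reviewers: below is a minimal, standalone sketch of the `Dataset.checkpoint(eager = true)` pattern this patch switches to, in case the API is unfamiliar. The `local[*]` session setup and the `/tmp` checkpoint path are illustrative assumptions, not part of the change.

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch of the reliable-checkpoint pattern the patch relies on.
// The local master and the /tmp path here are assumptions for illustration.
object CheckpointSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("checkpoint-sketch")
      .getOrCreate()

    // Dataset.checkpoint fails unless a checkpoint directory is configured,
    // hence the same guard the patch adds before checkpointing.
    if (spark.sparkContext.getCheckpointDir.isEmpty) {
      spark.sparkContext.setCheckpointDir("/tmp/checkpoint-sketch")
    }

    var df = spark.range(10).toDF("id")
    // eager = true materializes the checkpoint immediately and truncates the
    // logical plan, which is what bounds lineage growth across iterations.
    df = df.checkpoint(eager = true)
    df.show()

    spark.stop()
  }
}
```

Compared with the old manual parquet write/read round trip, `checkpoint(eager = true)` both materializes the data and truncates the plan in one call, so lineage stays bounded without the `System.gc()` hint the old code needed.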