diff --git a/core/src/main/scala/org/graphframes/GraphFrame.scala b/core/src/main/scala/org/graphframes/GraphFrame.scala
index 6f3983bc1..69e2b92ff 100644
--- a/core/src/main/scala/org/graphframes/GraphFrame.scala
+++ b/core/src/main/scala/org/graphframes/GraphFrame.scala
@@ -656,6 +656,35 @@ class GraphFrame private (
    */
   def detectingCycles: DetectingCycles = new DetectingCycles(this)
 
+  /**
+   * Converts the directed graph into an undirected one by adding the reverse of every edge. For
+   * every directed edge (src, dst), a corresponding edge (dst, src) with the same attributes is
+   * added.
+   *
+   * The two edge sets are combined with a plain union, so nothing is de-duplicated: the result
+   * has exactly twice as many edges as the input, a self-loop appears twice, and edges that
+   * were already reciprocal are repeated. Edge columns come back as `src`, `dst`, then the
+   * remaining columns in their original order.
+   *
+   * @return
+   *   a new GraphFrame with the same vertices and the symmetrized edges.
+   */
+  def asUndirected(): GraphFrame = {
+    // Each side of the union is laid out as (src, dst, nested attr column); union pairs columns
+    // up by position, so the reversed side only has to swap and re-alias the two endpoints.
+    val forward = edges.select(col(SRC), col(DST), nestAsCol(edges, ATTR))
+    val reversed = edges.select(col(DST).alias(SRC), col(SRC).alias(DST), nestAsCol(edges, ATTR))
+
+    // Read every non-endpoint edge column back out of the nested column under its original name.
+    val attributeColumns = edges.columns.toSeq
+      .filterNot(c => c == SRC || c == DST)
+      .map(c => col(ATTR).getField(c).alias(c))
+    val symmetricEdges = forward
+      .union(reversed)
+      .select(Seq(col(SRC), col(DST)) ++ attributeColumns: _*)
+    GraphFrame(vertices, symmetricEdges)
+  }
+
   // ========= Motif finding (private) =========
 
   /**
diff --git a/core/src/test/scala/org/graphframes/GraphFrameSuite.scala b/core/src/test/scala/org/graphframes/GraphFrameSuite.scala
index 23905e5d6..ff541adce 100644
--- a/core/src/test/scala/org/graphframes/GraphFrameSuite.scala
+++ b/core/src/test/scala/org/graphframes/GraphFrameSuite.scala
@@ -413,4 +413,41 @@ class GraphFrameSuite extends SparkFunSuite with GraphFrameTestSparkContext {
       .toSeq
     assert(Seq(0, 0, 0, 0, 1, 0) == clusters)
   }
+
+  test("convert directed graph to undirected") {
+    val v = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "name")
+    val e = spark.createDataFrame(Seq((1L, 2L), (2L, 3L))).toDF("src", "dst")
+    val g = GraphFrame(v, e)
+    val undirected = g.asUndirected()
+
+    // Every edge gains a reversed twin, so the edge count doubles.
+    assert(undirected.edges.count() === 2 * g.edges.count())
+
+    // Both directions of each input edge are present, and nothing else.
+    val endpoints = undirected.edges
+      .collect()
+      .map(r => (r.getLong(0), r.getLong(1)))
+      .toSeq
+      .sorted
+    assert(endpoints === Seq((1L, 2L), (2L, 1L), (2L, 3L), (3L, 2L)))
+
+    // The vertices are carried over unchanged.
+    assert(undirected.vertices.collect().toSet === g.vertices.collect().toSet)
+  }
+
+  test("convert directed graph with edge attributes to undirected") {
+    val v = spark.createDataFrame(Seq((1L, "a"), (2L, "b"))).toDF("id", "name")
+    val e = spark.createDataFrame(Seq((1L, 2L, "edge1"))).toDF("src", "dst", "attr")
+    val g = GraphFrame(v, e)
+    val undirected = g.asUndirected()
+
+    // The edge columns keep their names and order.
+    assert(undirected.edges.columns.toSeq === Seq("src", "dst", "attr"))
+
+    // The attribute value is copied onto the reversed edge.
+    val edges = undirected.edges.collect()
+    assert(edges.length === 2)
+    val triples = edges.map(r => (r.getLong(0), r.getLong(1), r.getString(2))).toSet
+    assert(triples === Set((1L, 2L, "edge1"), (2L, 1L, "edge1")))
+  }
 }