/**
 * Inner join of `this` with `other` on key. The result holds one (k, (v1, v2)) tuple for
 * every combination of a (k, v1) element in `this` and a (k, v2) element in `other`.
 * The output partitioner is whatever `defaultPartitioner(self, other)` selects.
 */
def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = self.withScope {
  val chosenPartitioner = defaultPartitioner(self, other)
  join(other, chosenPartitioner)
}

/**
 * Inner join of `this` with `other` on key, with the output partitioned by `partitioner`.
 * The result holds one (k, (v1, v2)) tuple for every combination of a (k, v1) element in
 * `this` and a (k, v2) element in `other`.
 */
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = self.withScope {
  val grouped = this.cogroup(other, partitioner)
  // Per key, emit the cross product of the values seen on each side. A key that is missing
  // on either side has an empty iterable there and therefore contributes no output.
  grouped.flatMapValues { case (leftValues, rightValues) =>
    leftValues.iterator.flatMap(v => rightValues.iterator.map(w => (v, w)))
  }
}

/**
 * Groups `this` and `other` by key: for each key k present in either RDD, the result
 * contains a tuple of the values for k in `this` and the values for k in `other`.
 *
 * Throws a SparkException when a HashPartitioner is combined with array-typed keys.
 */
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
    : RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope {
  partitioner match {
    case _: HashPartitioner if keyClass.isArray =>
      throw new SparkException("HashPartitioner cannot partition array keys.")
    case _ => // any other partitioner / key type combination is accepted
  }
  val coGrouped = new CoGroupedRDD[K](Seq(self, other), partitioner)
  // CoGroupedRDD yields one untyped Iterable per input RDD, in the order the inputs were
  // passed above; re-attach the static value types here.
  coGrouped.mapValues { case Array(selfValues, otherValues) =>
    (selfValues.asInstanceOf[Iterable[V]], otherValues.asInstanceOf[Iterable[W]])
  }
}
/**
 * An RDD that cogroups its parent RDDs by key: each output element pairs a key with an
 * array holding, for every parent (indexed by dependency number), the values that parent
 * has for that key.
 *
 * @param rdds parent key/value RDDs to cogroup
 * @param part partitioner that decides how the output is partitioned
 */
class CoGroupedRDD[K: ClassTag](
    @transient var rdds: Seq[RDD[_ <: Product2[K, _]]],
    part: Partitioner)
  extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) {

  /**
   * Builds one dependency per parent RDD: a one-to-one dependency when the parent is already
   * partitioned by `part`, otherwise a shuffle dependency that repartitions it with `part`.
   */
  override def getDependencies: Seq[Dependency[_]] = {
    rdds.map { rdd: RDD[_] =>
      if (rdd.partitioner == Some(part)) {
        logDebug("Adding one-to-one dependency with " + rdd)
        new OneToOneDependency(rdd)
      } else {
        logDebug("Adding shuffle dependency with " + rdd)
        // NOTE(review): `serializer` is not defined in the lines shown; it is presumably a
        // member of this class declared elsewhere.
        new ShuffleDependency[K, Any, CoGroupCombiner](
          rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer)
      }
    }
  }

  /**
   * Computes one output partition: obtains an iterator over every parent's records for this
   * partition, feeds them all into an external append-only map keyed by K, records the
   * map's spill / peak-memory figures in the task metrics, and returns the map's iterator.
   */
  override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = {
    val split = s.asInstanceOf[CoGroupPartition]
    val numRdds = dependencies.length

    // A list of (rdd iterator, dependency number) pairs
    val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
    for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
      case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked =>
        val dependencyPartition = split.narrowDeps(depNum).get.split
        // Read them from the parent
        val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
        rddIterators += ((it, depNum))

      case shuffleDependency: ShuffleDependency[_, _, _] =>
        // Read map outputs of shuffle
        val it = SparkEnv.get.shuffleManager
          .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
          .read()
        rddIterators += ((it, depNum))
    }

    val map = createExternalMap(numRdds)
    for ((it, depNum) <- rddIterators) {
      // Tag every value with the number of the dependency it came from so the combiner
      // functions know which slot of the per-key array to append to.
      map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
    }
    context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(map.peakMemoryUsedBytes)
    new InterruptibleIterator(context,
      map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
  }

  /**
   * Creates the external map used by `compute`. Its combiner is an array with `numRdds`
   * slots, one per dependency; a (value, dependency number) pair is appended to the slot
   * selected by its dependency number.
   */
  private def createExternalMap(numRdds: Int)
    : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = {

    // First value seen for a key: allocate all slots, then store the value in its slot.
    val createCombiner: (CoGroupValue => CoGroupCombiner) = value => {
      val newCombiner = Array.fill(numRdds)(new CoGroup)
      newCombiner(value._2) += value._1
      newCombiner
    }
    // Subsequent value for a key: append to the existing slot in place.
    val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner =
      (combiner, value) => {
        combiner(value._2) += value._1
        combiner
      }
    // Two combiners for the same key: concatenate slot by slot into the first one.
    val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
      (combiner1, combiner2) => {
        var depNum = 0
        while (depNum < numRdds) {
          combiner1(depNum) ++= combiner2(depNum)
          depNum += 1
        }
        combiner1
      }
    new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
      createCombiner, mergeValue, mergeCombiners)
  }
  // NOTE(review): the excerpt stops here without closing the class body; the remaining
  // members of CoGroupedRDD are outside the lines shown.
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // --- BroadcastHashJoin -------------------------------------------------------------------- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if canBuildRight(joinType) && canBroadcast(right) => Seq(joins.BroadcastHashJoinExec( leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if canBuildLeft(joinType) && canBroadcast(left) => Seq(joins.BroadcastHashJoinExec( leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) // --- ShuffledHashJoin --------------------------------------------------------------------- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right) && muchSmaller(right, left) || !RowOrdering.isOrderable(leftKeys) => Seq(joins.ShuffledHashJoinExec( leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left) && muchSmaller(left, right) || !RowOrdering.isOrderable(leftKeys) => Seq(joins.ShuffledHashJoinExec( leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) // --- SortMergeJoin ------------------------------------------------------------ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => joins.SortMergeJoinExec( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil ...
/**
 * Internal boolean setting, registered under the key "spark.sql.join.preferSortMergeJoin",
 * that expresses a preference for sort merge join over shuffle hash join.
 * The default value is `true`.
 */
val PREFER_SORTMERGEJOIN =
  SQLConfigBuilder("spark.sql.join.preferSortMergeJoin")
    .internal()
    .doc("When true, prefer sort merge join over shuffle hash join.")
    .booleanConf
    .createWithDefault(true)
/**
 * Joins the streamed rows against an already-built hashed relation.
 *
 * The join algorithm is selected by `joinType`; every row it produces is counted in
 * `numOutputRows` and then passed through the result projection.
 *
 * @param streamedIter  rows of the streamed (probe) side
 * @param hashed        hashed relation built from the other side
 * @param numOutputRows metric incremented once per emitted row
 * @return iterator over the projected join output
 * @throws IllegalArgumentException if `joinType` is not one of the handled join types
 */
protected def join(
    streamedIter: Iterator[InternalRow],
    hashed: HashedRelation,
    numOutputRows: SQLMetric): Iterator[InternalRow] = {

  val joinedIter = joinType match {
    case _: InnerLike =>
      innerJoin(streamedIter, hashed)
    case LeftOuter | RightOuter =>
      outerJoin(streamedIter, hashed)
    case LeftSemi =>
      semiJoin(streamedIter, hashed)
    case LeftAnti =>
      antiJoin(streamedIter, hashed)
    case j: ExistenceJoin =>
      existenceJoin(streamedIter, hashed)
    case x =>
      throw new IllegalArgumentException(
        s"BroadcastHashJoin should not take $x as the JoinType")
  }

  // Create the projection once, then apply it lazily to each joined row.
  val resultProj = createResultProjection
  joinedIter.map { r =>
    numOutputRows += 1
    resultProj(r)
  }
}

/**
 * Inner join: for each streamed row, looks up its join key in `hashedRelation` and emits
 * the joined row for each match. A `null` lookup result is treated as "no matches".
 *
 * A single `JoinedRow` instance is reused for all output rows.
 */
private def innerJoin(
    streamIter: Iterator[InternalRow],
    hashedRelation: HashedRelation): Iterator[InternalRow] = {
  val joinRow = new JoinedRow
  val joinKeys = streamSideKeyGenerator()
  streamIter.flatMap { srow =>
    joinRow.withLeft(srow)
    val matches = hashedRelation.get(joinKeys(srow))
    if (matches != null) {
      // NOTE(review): this branch was empty in the extracted text. Restored so each match
      // is attached as the right side and filtered by the join's bound condition; confirm
      // that `boundCondition` is the member in scope for this operator.
      matches.map(joinRow.withRight(_)).filter(boundCondition)
    } else {
      Seq.empty
    }
  }
}
/**
 * Executes the join by broadcasting the build side as a HashedRelation and probing it from
 * every partition of the streamed side. Each task works on its own read-only copy of the
 * broadcast relation and adds that relation's estimated size to its peak execution memory.
 */
protected override def doExecute(): RDD[InternalRow] = {
  val outputRowCounter = longMetric("numOutputRows")
  val broadcastHandle = buildPlan.executeBroadcast[HashedRelation]()
  val streamedRdd = streamedPlan.execute()

  streamedRdd.mapPartitions { partitionRows =>
    val localRelation = broadcastHandle.value.asReadOnlyCopy()
    val metrics = TaskContext.get().taskMetrics()
    metrics.incPeakExecutionMemory(localRelation.estimatedSize)
    join(partitionRows, localRelation, outputRowCounter)
  }
}
/**
 * Executes the join partition by partition: the streamed and build sides are zipped, a
 * hashed relation is built from the build-side rows of each partition pair, and the
 * streamed rows of that pair are joined against it.
 */
protected override def doExecute(): RDD[InternalRow] = {
  val outputRowCounter = longMetric("numOutputRows")
  val streamedRdd = streamedPlan.execute()
  val buildRdd = buildPlan.execute()

  streamedRdd.zipPartitions(buildRdd) { (streamRows, buildRows) =>
    val partitionRelation = buildHashedRelation(buildRows)
    join(streamRows, partitionRelation, outputRowCounter)
  }
}
protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { { cond => newPredicate(cond, left.output ++ right.output).eval _ }.getOrElse { (r: InternalRow) => true } } // An ordering that can be used to compare keys from both sides. val keyOrdering = newNaturalAscendingOrdering( val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) joinType match { case _: InnerLike => new RowIterator { private[this] var currentLeftRow: InternalRow = _ private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ private[this] var currentMatchIdx: Int = -1 private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter) ) private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow currentMatchIdx = 0 } override def advanceNext(): Boolean = { while (currentMatchIdx >= 0) { if (currentMatchIdx == currentRightMatches.length) { if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow currentMatchIdx = 0 } else { currentRightMatches = null currentLeftRow = null currentMatchIdx = -1 return false } } joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) currentMatchIdx += 1 if (boundCondition(joinRow)) { numOutputRows += 1 return true } } false } override def getRow: InternalRow = resultProj(joinRow) }.toScala ...
/**
 * Advances the streamed side to its next row with a null-free join key and positions the
 * buffered side on the rows that share that key.
 *
 * @return true if `bufferedMatches` now holds the buffered rows matching the current
 *         streamed row; false once no further matches can exist, in which case
 *         `matchJoinKey` is reset to null and `bufferedMatches` is cleared.
 */
final def findNextInnerJoinRows(): Boolean = {
  while (advancedStreamed() && streamedRowKey.anyNull) {
    // Advance the streamed side of the join until we find the next row whose join key contains
    // no nulls or we hit the end of the streamed iterator.
  }
  if (streamedRow == null) {
    // We have consumed the entire streamed iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
    // The new streamed row has the same join key as the previous row, so return the same matches.
    true
  } else if (bufferedRow == null) {
    // The streamed row's join key does not match the current batch of buffered rows and there are
    // no more rows to read from the buffered iterator, so there can be no more matches.
    matchJoinKey = null
    bufferedMatches.clear()
    false
  } else {
    // Advance both the streamed and buffered iterators to find the next pair of matching rows.
    var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
    do {
      if (streamedRowKey.anyNull) {
        advancedStreamed()
      } else {
        assert(!bufferedRowKey.anyNull)
        comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
        // Move whichever side currently has the smaller key.
        if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
        else if (comp < 0) advancedStreamed()
      }
    } while (streamedRow != null && bufferedRow != null && comp != 0)
    if (streamedRow == null || bufferedRow == null) {
      // We have hit the end of one of the iterators, so there can be no more matches.
      matchJoinKey = null
      bufferedMatches.clear()
      false
    } else {
      // The streamed row's join key matches the current buffered row's join key, so walk through
      // the buffered iterator to buffer the rest of the matching rows.
      assert(comp == 0)
      bufferMatchingRows()
      true
    }
  }
}

/**
 * Advance the streamed iterator and compute the new row's join key.
 * On exhaustion both `streamedRow` and `streamedRowKey` are set to null.
 * @return true if the streamed iterator returned a row and false otherwise.
 */
private def advancedStreamed(): Boolean = {
  if (streamedIter.advanceNext()) {
    streamedRow = streamedIter.getRow
    streamedRowKey = streamedKeyGenerator(streamedRow)
    true
  } else {
    streamedRow = null
    streamedRowKey = null
    false
  }
}

/**
 * Advance the buffered iterator until we find a row with join key that does not contain nulls.
 * If no such row exists, both `bufferedRow` and `bufferedRowKey` are set to null.
 * @return true if the buffered iterator returned a row and false otherwise.
 */
private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
  var foundRow: Boolean = false
  while (!foundRow && bufferedIter.advanceNext()) {
    bufferedRow = bufferedIter.getRow
    bufferedRowKey = bufferedKeyGenerator(bufferedRow)
    foundRow = !bufferedRowKey.anyNull
  }
  if (!foundRow) {
    bufferedRow = null
    bufferedRowKey = null
    false
  } else {
    true
  }
}

/**
 * Called when the streamed and buffered join keys match in order to buffer the matching rows.
 * Replaces the contents of `bufferedMatches` with copies of every consecutive buffered row
 * whose key equals the current streamed key, and records that key in `matchJoinKey`.
 */
private def bufferMatchingRows(): Unit = {
  assert(streamedRowKey != null)
  assert(!streamedRowKey.anyNull)
  assert(bufferedRowKey != null)
  assert(!bufferedRowKey.anyNull)
  assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
  // This join key may have been produced by a mutable projection, so we need to make a copy:
  matchJoinKey = streamedRowKey.copy()
  bufferedMatches.clear()
  do {
    bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
    advancedBufferedToRowWithNullFreeJoinKey()
  } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
}
/**
 * A row iterator whose protocol separates "move forward" from "read": callers invoke
 * [[advanceNext]] and, only when it returns `true`, read the current row with [[getRow]].
 */
abstract class RowIterator {

  /**
   * Moves to the next row.
   *
   * @return `true` if a row is now available through [[getRow]]; `false` if the iterator
   *         has no more rows.
   */
  def advanceNext(): Boolean

  /**
   * Returns the current row. The call is idempotent: it may be repeated without advancing.
   * Calling it after [[advanceNext]] has returned `false` is illegal.
   */
  def getRow: InternalRow

  /**
   * Wraps this iterator as a [[scala.collection.Iterator]] of rows.
   */
  def toScala: Iterator[InternalRow] = new RowIteratorToScala(this)
}