spark 使用的時候,總有些需求比較另類吧,比如有球友問過這樣一個需求:

浪尖,我想要在driver端獲取executor執行task返回的結果,比如task是個規則引擎,
我想知道每條規則命中了幾條數據,請問這個怎麼做呢?

這個是不是很騷氣,也很常見,按理說你輸出之後,在mysql里跑條sql就行了,但是這個往往顯的比較麻煩。而且有時候,在 driver可能還要用到這些數據呢?具體該怎麼做呢?

大部分的想法估計是collect方法,那麼用collect如何實現呢?大家自己可以考慮一下,我只能告訴你不簡單,不如輸出到資料庫里,然後driver端寫sql分析一下。

還有一種考慮就是使用自定義累加器。這樣就可以在executor端將結果累加然後在driver端使用,不過具體實現也是很麻煩。大家也可以自己琢磨一下下~

那麼,浪尖就給大家介紹一個比較常用也比較騷的操作吧。

其實,這種操作我們最先想到的應該是count函數,因為他就是將task的返回值返回到driver端,然後進行聚合的。我們可以從idea count函數點擊進去,可以看到

def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum

也即是sparkcontext的runJob方法。

Utils.getIteratorSize _這個方法主要是計算每個iterator的元素個數,也即是每個分區的元素個數,返回值就是元素個數:

/**
* Counts the number of elements of an iterator using a while loop rather than calling
* [[scala.collection.Iterator#size]] because it uses a for loop, which is slightly slower
* in the current version of Scala.
*/
def getIteratorSize[T](iterator: Iterator[T]): Long = {
var count = 0L
while (iterator.hasNext) {
count += 1L
iterator.next()
}
count
}

然後就是runJob返回的是一個數組,每個數組的元素就是我們task執行函數的返回值,然後調用sum就得到我們的統計值了。

那麼我們完全可以藉助這個思路實現我們開頭的目標。浪尖在這裡直接上案例了:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.elasticsearch.hadoop.cfg.ConfigurationOptions

object es2sparkRunJob {

def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local[*]").setAppName(this.getClass.getCanonicalName)

conf.set(ConfigurationOptions.ES_NODES, "127.0.0.1")
conf.set(ConfigurationOptions.ES_PORT, "9200")
conf.set(ConfigurationOptions.ES_NODES_WAN_ONLY, "true")
conf.set(ConfigurationOptions.ES_INDEX_AUTO_CREATE, "true")
conf.set(ConfigurationOptions.ES_NODES_DISCOVERY, "false")
conf.set("es.write.rest.error.handlers", "ignoreConflict")
conf.set("es.write.rest.error.handler.ignoreConflict", "com.jointsky.bigdata.handler.IgnoreConflictsHandler")

val sc = new SparkContext(conf)
import org.elasticsearch.spark._

val rdd = sc.esJsonRDD("posts").repartition(10)

rdd.count()
val func = (itr : Iterator[(String,String)]) => {
var count = 0
itr.foreach(each=>{
count += 1
})
(TaskContext.getPartitionId(),count)
}

val res = sc.runJob(rdd,func)

res.foreach(println)

sc.stop()
}
}

例子中driver端獲取的就是每個task處理的數據量。

效率高,而且操作靈活高效~

是不是很騷氣~~

推薦閱讀:

相关文章