import java.net.InetAddress
import org.apache.spark.Logging
/**
* The goal of this class is to provide an easy method to pipe data through an external command. It does this by
* combining a [[java.io.PipedOutputStream]] with a [[java.io.PipedInputStream]] to create a single pipe to feed data
* through. The write side runs asynchronously so data can be written and read at the same time.
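* A minimal usage sketch (assumes a POSIX `cat` is on the PATH):
* {{{
* val pipeUtils = new PipeUtils()
* import pipeUtils._
* // Feed three lines through `cat` and read them back as a lazy Stream.
* val echoed: Stream[String] = Seq("a", "b", "c").pipe("cat")
* }}}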
* Created by jmorra on 1/22/15.
*/
class PipeUtils(bufferSize: Int = 1 << 20) {
import java.io._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.language.postfixOps
import scala.sys.process._
/**
* This implicit class allows easy streaming of data through external processes. It works on a
* line-by-line basis, just like Spark's pipe command.
* http://stackoverflow.com/questions/28095469/stream-input-to-external-process-in-scala
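* A small sketch (assumes a POSIX `wc` is on the PATH): `run` blocks and returns the full output as a
* single String, while `pipe` returns a lazy Stream of output lines.
* {{{
* val lineCount = Seq("a", "b", "c").run("wc -l").trim // "3"
* }}}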
* @param s the input lines to feed to the external process
*/
implicit class IteratorStream(s: TraversableOnce[String]) {
def pipe(cmd: String): Stream[String] = cmd #< iter2is(s) lines
def pipe(cmd: Seq[String]): Stream[String] = cmd #< iter2is(s) lines
def run(cmd: String): String = cmd #< iter2is(s) !!
private[this] def iter2is[A](it: TraversableOnce[A]): InputStream = {
// What is written to the output stream will appear in the input stream.
val pos = new PipedOutputStream
val pis = new PipedInputStream(pos, bufferSize)
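// Note: writes to pos block once the pipe's internal buffer fills, so bufferSize bounds how far
// the writing Future can run ahead of the reader.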
val w = new PrintWriter(new BufferedOutputStream(pos, bufferSize), false)
// Scala 2.11 syntax (on Scala 2.10, use 'future' instead of 'Future'). Executes asynchronously.
// Fill the stream, then close.
Future {
try it.foreach(w.println)
finally w.close()
}
// Return possibly before pis is fully written to.
pis
}
}
}
/**
* A framework for running VW (Vowpal Wabbit) in a cluster environment using Apache Spark. This
* is meant only as a framework and may require some modification to work in your specific environment.
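* A hypothetical end-to-end sketch (assumes spanning_tree is already running on the driver, vw is on
* each executor's PATH, and the input path is illustrative):
* {{{
* val cluster = VwSparkCluster()
* val data: RDD[String] = sc.textFile("hdfs:///path/to/vw-input") // one VW-format example per line
* val model: Array[Byte] = cluster.train(data, "vw --loss_function logistic")
* }}}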
* Created by jmorra on 8/19/15.
*/
case class VwSparkCluster(
pipeUtils: PipeUtils = new PipeUtils,
ipAddress: String = InetAddress.getLocalHost.getHostAddress,
defaultParallelism: Int = 2) extends Logging {
import java.io._
import org.apache.commons.io.IOUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import scala.sys.process._
import pipeUtils._
/**
* This will learn a VW model in cluster mode. If you notice that this command never starts and just
* stalls, the parallelism is probably too high.
* @param data an RDD of Strings that are in VW input format.
* @param vwCmd the VW command to run. Note that this command must NOT contain --cache_file and -f. Those will automatically
* be appended if necessary.
* @param parallelism the amount of parallelism to use. This is calculated using a formula defined in getParallelism
* if it is not supplied. It is recommended to only supply this if getParallelism is not working
* in your case.
* @return a byte array containing the final VW model.
*/
def train(data: RDD[String], vwCmd: String, parallelism: Option[Int] = None): Array[Byte] = {
if (numberOfRunningProcesses("spanning_tree") != 1) {
throw new IllegalStateException("spanning_tree is not running on the driver, cannot proceed. Please start spanning_tree and try again.")
}
val sc = data.context
val conf = sc.getConf
// By using the app id and the RDD id we should get a globally unique ID.
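// e.g. spark.app.id "local-1440000000000" and RDD id 7 yield 14400000000007.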
val jobId = (conf.get("spark.app.id").replaceAll("[^\\d]", "") + data.id).toLong
logInfo(s"VW cluster job ID: $jobId")
val partitions = parallelism.getOrElse(getParallelism(sc).getOrElse(defaultParallelism))
logInfo(s"VW cluster parallelism: ${partitions}")
val repartitionedData = if (data.partitions.size == partitions) data else data.repartition(partitions)
val vwBaseCmd = s"$vwCmd --total $partitions --span_server $ipAddress --unique_id $jobId"
logInfo(s"VW cluster baseCmd: $vwBaseCmd")
val vwModels = repartitionedData.mapPartitionsWithIndex { case (partition, x) =>
Iterator(runVWOnPartition(vwBaseCmd, x, partition))
}
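// Only partition 0 returns Some(model); collect yields an Array[Option[Array[Byte]]], and the two
// flattens unwrap the Option and the inner array, leaving the bytes of that single model.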
vwModels.collect.flatten.flatten
}
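/**
* Counts the running processes whose ps entry matches the given string. The grep process itself
* appears in the ps output, hence the subtraction of 1.
* @param process the process name to grep for
* @return the number of matching processes
*/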
def numberOfRunningProcesses(process: String): Int = "ps aux".#|(s"grep $process").!!.split("\n").size - 1
/**
* Gets the executor storage status excluding the driver node.
* @param sc the SparkContext
* @return an Array of Strings that are the names of all the storage statuses.
*/
def executors(sc: SparkContext): Array[String] = {
sc.getExecutorStorageStatus.collect{
// The driver's block manager reports an executorId of "<driver>", which is excluded here.
case x if x.blockManagerId.executorId != "<driver>" =>
x.blockManagerId.executorId
}
}
/**
* Gets the parallelism of the cluster. This is very much a work in progress that seems to work for now. It took
* a lot of experimentation on Spark 1.2.0 to get working. I make no guarantees that it will work on other Spark
* versions, especially if dynamic allocation is enabled. I have also only tested this with a master of yarn-client
* and local, so I'm not sure how well it'll behave in other resource management environments (Spark Standalone,
* Mesos, etc.).
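* For example, on YARN with spark.executor.cores set to 3 and 4 non-driver executors reported, this
* returns Some(12); the numbers here are purely illustrative.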
* @param sc the SparkContext
* @return if the parallelism can be found then the expected amount of parallelism.
*/
def getParallelism(sc: SparkContext): Option[Int] = {
sc.master match {
case x if x.contains("yarn") => sc.getConf.getOption("spark.executor.cores").map(cores => cores.toInt * executors(sc).size)
case _ => Some(sc.defaultParallelism)
}
}
/**
* This will accept a base VW command, and append a cache file if necessary. It will also create a temp file
* to store the VW model. It will then run VW on the supplied data. Finally it will return the bytes of the
* model ONLY if the partition is 0.
*
* This function was tricky to write because the end result of each calculation is a file on the local disk.
* According to John, all the models should be in the same state after learning, so we can choose to save
* any one we want; therefore, transferring the contents of each file to the driver would be wasteful.
* In order to avoid this unnecessary transfer we're just going to get the first file. Now you might
* ask yourself why not just call .first on the RDD. We cannot do that because in that case Spark would
* only evaluate the first mapper, and we need all of them to be evaluated, hence the need for .collect to
* be called. Note that you may have to increase spark.driver.maxResultSize if the size of the VW model
* is too large.
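* For example (paths and values are illustrative), partition 3 of a 16-way job might end up running:
* {{{
* vw --passes 2 --total 16 --span_server 10.0.0.1 --unique_id 42 -k --cache_file /tmp/vw-cache1.cache --node 3 -f /tmp/vw-model1.model
* }}}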
* @param vwBaseCmd the base VW command without a cache file or an output specified. A cache file will automatically
* be used if --passes is specified.
* @param data an Iterator of Strings in VW input format to be passed to VW
* @param partition the partition number of this chunk of data
* @return an Array of the bytes of the VW model ONLY if this is the 0th partition, else None.
*/
def runVWOnPartition(vwBaseCmd: String, data: Iterator[String], partition: Int): Option[Array[Byte]] = {
val cacheFile = if (vwBaseCmd.contains("--passes ")) {
val c = File.createTempFile("vw-cache", ".cache")
c.deleteOnExit()
Option(c)
} else None
val vwBaseCmdWithCache = cacheFile.map(x => s"$vwBaseCmd -k --cache_file ${x.getCanonicalPath}").getOrElse(vwBaseCmd)
val output = File.createTempFile("vw-model", ".model")
output.deleteOnExit()
val vwCmd = s"$vwBaseCmdWithCache --node $partition -f ${output.getCanonicalPath}"
// Force the lazy output stream so we block until VW has exited and the model file is fully written.
data.pipe(vwCmd).foreach(line => logInfo(line))
cacheFile.foreach(_.delete)
val vwModel = if (partition == 0) {
val inputStream = new BufferedInputStream(new FileInputStream(output))
val byteArray = IOUtils.toByteArray(inputStream)
inputStream.close()
Option(byteArray)
}
else None
output.delete()
vwModel
}
}