GridSearchCV with Apache Spark

This article continues where I left off with Classification for machine learning.

Apache Spark

Apache Spark is a very popular framework for big data processing. The main reason: it’s fast! It can parallelize your task across a cluster, so the task completes sooner than it would if you executed it serially.

It can also be used on a single computer, which has the advantage that you can use all of the cores in your computer. The first solution I wrote for the classification used the sklearn package for Python. Sklearn also provides multicore processing on a single machine via joblib, but since my client explicitly wanted to use Spark, I used that.
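For comparison, a minimal sketch of the single-machine sklearn route might look like this. The synthetic data, the LogisticRegression estimator, and the parameter grid are assumptions for illustration, not the code from the project; only GridSearchCV, its n_jobs parameter, and StratifiedShuffleSplit come from sklearn as discussed in this article.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit

# Synthetic, imbalanced two-class data standing in for the real dataset (assumption).
X, y = make_classification(
    n_samples=400, n_features=10, weights=[0.8, 0.2], random_state=0
)

# Stratified splits keep the class proportions the same in every train/test split.
cv = StratifiedShuffleSplit(n_splits=3, test_size=0.25, random_state=0)

# Hypothetical estimator and grid; the article does not say which were used.
search = GridSearchCV(
    estimator=LogisticRegression(max_iter=1000),
    param_grid={"C": [0.1, 1.0, 10.0]},
    cv=cv,
    n_jobs=-1,  # ask joblib to use all available cores on this machine
)
search.fit(X, y)

print(search.best_params_, round(search.best_score_, 3))
```

With n_jobs=-1, the fits for the different parameter values and splits are spread over the local cores, which is the same work that Spark distributes in the rest of this article.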

While looking into how to migrate the sklearn code to Spark ML, I found out that there are already some initiatives to run a sklearn solution on Spark. Because the most expensive part of the code is finding the hyperparameters with GridSearchCV, it’s important to parallelize this functionality. Databricks, the company founded by the creators of Spark, has developed an integration package for sklearn on Spark. Unfortunately, it didn’t work with my code, because I used a custom cross validator, StratifiedShuffleSplit, which I need in order to keep the classes balanced in each sample. I only had to make a slight modification to the code, which I published on my GitHub.
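To see why the grid search parallelizes so well, it helps to note that every combination of hyperparameters and cross-validation split can be fitted and scored independently. The sketch below shows that decomposition using only the Python standard library. It is not the integration package's implementation: the thread pool merely stands in for Spark's workers, and the grid, the number of splits, and the toy fit_and_score function are hypothetical placeholders.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product
from statistics import mean

# Hypothetical grid and number of splits, for illustration only.
param_grid = {"C": [0.1, 1.0, 10.0], "gamma": [0.01, 0.1]}
n_splits = 3


def expand_grid(grid):
    """Yield one dict per combination of parameter values."""
    names = sorted(grid)
    for values in product(*(grid[name] for name in names)):
        yield dict(zip(names, values))


def fit_and_score(task):
    """Placeholder for 'fit on the training part, score on the test part'.

    A real implementation would train a model here; this toy formula only
    keeps the example runnable.
    """
    params, split = task
    distance = abs(params["C"] - 1.0) + abs(params["gamma"] - 0.1)
    return params, 1.0 / (1.0 + distance) - 0.01 * split


candidates = list(expand_grid(param_grid))

# One independent task per (candidate, split) pair: 6 candidates x 3 splits.
tasks = [(params, split) for params in candidates for split in range(n_splits)]

# The pool stands in for the cluster; each task could run on a different worker.
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(fit_and_score, tasks))

# Average the scores per candidate and keep the best one.
mean_scores = [
    (mean(score for p, score in results if p == params), params)
    for params in candidates
]
best_score, best_params = max(mean_scores, key=lambda item: item[0])

print(best_params, round(best_score, 3))
```

Because the tasks share nothing except the input data, the same list could be handed to any parallel backend, whether joblib on one machine or Spark on a cluster.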

The Python script can be submitted to Spark with the spark-submit command; since Spark 2.0, the pyspark command no longer supports executing scripts. spark-submit takes the Python script as an argument, along with some optional arguments. In the example below, I submit it to my local computer and specify that it should use 8 cores.

spark-submit --master local[8] build_model_spark.py

Before my modifications, it took my laptop about 14 minutes to build the model on the whole dataset. With Spark, this was reduced to less than 4 minutes, which is a pretty good improvement! My client was happy with the result and gave me a good review, so I hope this results in more ML projects!
