Compiling classical ML for performance gains (up to 30x) and hardware portability


Feb 25, 2021

7 minutes


Authors: Masahiro Masuda, OctoML; Jason Knight, OctoML; Matteo Interlandi, Microsoft; Karla Saur, Microsoft

Classical machine learning software today

Today, machine learning engineers and data scientists use popular frameworks such as Scikit-learn, XGBoost, and LightGBM to train and deploy classical ML models such as linear and logistic regression, decision trees, and gradient boosting. But what if one wants more performance? Not only on CPUs, which are most widely used today, but also by leveraging GPUs and the ML accelerators of the future? And how about integrating trained models into a larger application, particularly if the application is written in a language other than Python?

Apache TVM is a machine learning compiler stack that compiles models from popular frameworks such as PyTorch and TensorFlow into optimized machine code for a wide variety of platforms. Here at OctoML, we’ve shown that TVM excels at accelerating deep learning tasks on a variety of platforms. But given that classical ML is still the most commonly used set of algorithms in practice, as shown in the Kaggle survey conducted last year, is there any way to apply TVM to classical ML workloads?

This question was answered with a resounding YES when, last year, a team of researchers and engineers from Microsoft demonstrated that TVM can be used to accelerate classical ML workloads through a project called “Hummingbird”.

Introducing Hummingbird and its TVM backend

In a nutshell, Hummingbird takes a trained classical machine learning model, such as a decision tree trained in Scikit-learn, and compiles it into “tensor operations” that are supported by, and can be accelerated by, deep learning frameworks and compilers such as TVM.

To understand how this works in more detail, consider a decision tree. Prediction with a decision tree is an instance of tree traversal: at each node, we look up an input feature value and the corresponding threshold stored in the model, and based on these two values we decide whether to proceed to the left or right child to continue down the tree. To turn this into tensor operations, we can collect several of these per-node operations and encode them using data-parallel tensor operations such as element-wise arithmetic, gather, and where (conditional selection).
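To make this concrete, here is a minimal NumPy sketch of one traversal step expressed with gather and where over a whole batch. This is our own illustration, not Hummingbird's actual kernels, and the toy arrays and names are hypothetical:

import numpy as np

# Toy per-node model data, stored as flat arrays (one entry per internal node)
feature_ids = np.array([0, 1, 2])                          # feature tested at each node
thresholds = np.array([0.5, 1.0, -0.2], dtype=np.float32)  # threshold at each node
left_child = np.array([1, 3, 5])                           # next node if value < threshold
right_child = np.array([2, 4, 6])                          # next node otherwise

X = np.random.rand(8, 3).astype(np.float32)                # batch of 8 rows, 3 features
node = np.zeros(8, dtype=np.int64)                         # every row starts at the root

# One traversal level for all rows at once, using only gather / where
feat = feature_ids[node]                                   # gather: feature id per row
thr = thresholds[node]                                     # gather: threshold per row
vals = X[np.arange(X.shape[0]), feat]                      # gather: each row's feature value
node = np.where(vals < thr, left_child[node], right_child[node])  # conditional select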

Compared to the typical “scalar” traversal used by standard decision tree implementations, we actually end up doing more redundant work, but in a massively parallel manner. And since tensor operations are a great fit for GPU execution, and Hummingbird backends such as TVM have excellent support for them, GPU acceleration for classical machine learning algorithms comes for free. Even on CPUs we can use multithreading and vector instructions to better exploit data parallelism.

For more information on Hummingbird in general, please refer to its GitHub repository and previous blog posts.

Originally, Hummingbird leveraged PyTorch as its tensor execution backend, but in joint work we are pleased to announce that Hummingbird now supports TVM as a first-class backend, bringing an end-to-end tensor compilation stack to the project. Below, we give examples of its use, and some benchmark data to whet your appetite.

How do I use it?

Using Hummingbird with the TVM backend is simple. The input can either be a model trained in Scikit-learn directly, or an XGBoost or a LightGBM model trained with the Scikit-learn API.

Hummingbird then offers a hummingbird.ml.convert function that takes our model and the name of the backend, and returns a compiled model with the same prediction API as Scikit-learn. For the TVM backend, we additionally require a "test input" to be passed in, whose number of rows must match the input you will later pass to the predict(...) method of the Hummingbird-compiled model. For now this restriction is required because TVM code generation relies on static input shapes; dynamic shape compilation is a problem the TVM community is actively working on.

Let’s look at that in code. Here is an example of how you would train a logistic regression model in Scikit-learn, convert and compile the model to TVM using Hummingbird, and run prediction, verifying that the two outputs are identical.

import numpy as np
import hummingbird.ml
from sklearn.linear_model import LogisticRegression

# X, y: training features and labels from any Scikit-learn-compatible dataset
model = LogisticRegression(max_iter=1000)
model.fit(X, y)

tvm_model = hummingbird.ml.convert(model, "tvm", X)

np.testing.assert_equal(model.predict(X), tvm_model.predict(X))

A random forest can also be compiled to TVM:

from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(max_depth=8)
model.fit(X, y)

tvm_model = hummingbird.ml.convert(model, "tvm", X)

np.testing.assert_equal(model.predict(X), tvm_model.predict(X))

We also support regression models, using the same API.

Let’s compare the performance of a Scikit-learn RandomForestRegressor and the same model compiled to TVM:

import timeit
from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import RandomForestRegressor

X, y = fetch_california_housing(return_X_y=True)  # input shape: (20640, 8)
X = X.astype(np.float32)  # make sure to use fp32 input

model = RandomForestRegressor(max_depth=8, n_estimators=250)
model.fit(X, y)

tvm_model = hummingbird.ml.convert(model, "tvm", X)

loop = 20
res_sk = timeit.timeit('model.predict(X)', number=loop, globals=globals())
res_tvm = timeit.timeit('tvm_model.predict(X)', number=loop, globals=globals())

In [2]: res_sk
Out[2]: 3.173023913999998

In [3]: res_tvm
Out[3]: 0.7454483920000143

As you can see, the TVM compiled model runs more than 4x faster.

We can also run compiled models on GPU for much better performance. To target NVIDIA GPUs, we pass device="cuda".

tvm_model = hummingbird.ml.convert(model, "tvm", X, device="cuda")
tvm_model.predict(X)  # warmup, this is important
res_tvm_gpu = timeit.timeit('tvm_model.predict(X)', number=loop, globals=globals())

In [5]: res_tvm_gpu                            
Out[5]: 0.0787845610000204

We got a further 10x speedup by simply changing one line, which adds up to more than a 30x performance improvement over using Scikit-learn on the CPU alone.

Beyond Scikit-learn, we also support gradient boosting models from XGBoost and LightGBM. The usage is identical to the Scikit-learn case, but you have to train your model using each library's Scikit-learn API.

import xgboost as xgb

# X, y here should be a classification dataset (integer labels), not the
# regression targets used above
model = xgb.XGBClassifier(max_depth=8)
model.fit(X, y)

tvm_model = hummingbird.ml.convert(model, "tvm", X)
np.testing.assert_equal(model.predict(X), tvm_model.predict(X))
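LightGBM works the same way through its Scikit-learn API. Here is a hedged sketch of what that looks like, assuming lightgbm is installed and X, y hold a classification dataset as above:

import lightgbm as lgb

model = lgb.LGBMClassifier(max_depth=8)
model.fit(X, y)

tvm_model = hummingbird.ml.convert(model, "tvm", X)
np.testing.assert_equal(model.predict(X), tvm_model.predict(X))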

A runnable script containing the examples above is available here. Also check out our notebook for more usage demonstrations.

Benchmarks

The Hummingbird repository has a comprehensive benchmark script for comparing the performance of the backends Hummingbird supports, such as PyTorch, ONNX Runtime, and TVM, against popular frameworks such as Scikit-learn, XGBoost, and LightGBM. Runtime is measured on real-world datasets. Here, we show some of the results. We highly encourage you to try it for yourself by following the instructions here.

We trained a Scikit-learn RandomForestClassifier and an XGBoost XGBClassifier on each dataset and measured the runtime of model.predict(X) on batches X of 1000 to 50000 samples. The results are averaged over 100 iterations and plotted with the TVM results normalized to 1. The number of trees and maximum depth can be varied, but here we only show results for 500 trees and maximum depth 8. The CPU is an Intel Core i7-8700K with 6 physical cores, and the GPU is a GTX 1070 Ti.

Scikit-learn on CPU

These are CPU runtime comparisons against the Scikit-learn RandomForestClassifier, using batch sizes of 10000 and 50000. The result on the left is obtained with this command:

hummingbird/benchmarks/trees$ python run.py -operator rf -backend hb-tvm -niters 100 -batch_benchmark -batch_size 10000 -max_depth 8 -ntrees 500 -dataset fraud,epsilon,year,covtype,higgs

Accelerating Scikit-learn models on CPU up to 16x

As you can see (smaller numbers are better, and everything is normalized to the Hummingbird + TVM runtime), Hummingbird + TVM enables 4–16x speedups for the smaller batch size in the first plot, and 1.5–3x speedups against Scikit-learn for the larger batch size.

XGBoost on GPU

These are GPU runtime comparisons against the XGBoost XGBClassifier, using batch sizes of 1000 and 10000. The result on the left is obtained with this command:

hummingbird/benchmarks/trees$ python run.py -operator xgb -gpu -backend hb-tvm -niters 100 -batch_benchmark -batch_size 1000 -max_depth 8 -ntrees 500 -dataset fraud,epsilon,year,covtype,higgs

Accelerating XGBoost models up to 8x on GPU

Again, the plots show normalized runtimes (lower is better), with everything normalized to the Hummingbird + TVM performance so that everything fits nicely on one plot. Here we see that Hummingbird + TVM offers 2–8x performance improvements on GPU for batch size 1000, and 0.9–4x for batch size 10000.

We also want to highlight one case where Hummingbird + TVM performs poorly. Since Hummingbird encodes tree traversal level by level, the runtime scales linearly with tree depth. For deeper trees, we observed that Hummingbird performed considerably worse than XGBoost.

For models with deeper trees (here, max depth 12) there are still improvements to be made

For example, here are the results for trees with max depth 12, compared to the depth of 8 used in all the figures above. Compared to the results above, the TVM numbers are either only slightly better or sometimes significantly worse. In particular, on the covtype dataset, the only multiclass dataset used here, TVM is more than 4x slower than XGBoost. We also observed that the performance of XGBoost on GPU did not change much as we increased the tree depth. Closing this gap is an important ongoing challenge for Hummingbird's approach.

If you are interested in further benchmarks, make sure to check out Microsoft’s OSDI paper: A Tensor Compiler for Unified Machine Learning Prediction Serving, by Supun Nakandala, Karla Saur, Gyeong-In Yu, Konstantinos Karanasos, Carlo Curino, Markus Weimer, and Matteo Interlandi. In particular, Table 7 shows Hummingbird + TVM achieving wider model coverage (and often better performance) than NVIDIA FIL (RAPIDS).

Further benefits of TVM — Portability with minimal dependencies

Model compilation in TVM has other benefits as well. Beyond improved performance, it also removes the dependency on Python: you can compile your classical ML workload into a binary that can then be linked into an application written in C, C++, Rust, Java, etc., without worrying about the Python dependencies of, e.g., Scikit-learn.
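As a rough illustration of the underlying mechanism (this is plain TVM, not Hummingbird's public API, and the toy Relay function below merely stands in for the tensorized model Hummingbird builds internally), a TVM-compiled module can be exported as a shared library and then loaded from C, C++, Rust, or Java via the TVM runtime:

import tvm
from tvm import relay

# A toy Relay function standing in for the tensorized model Hummingbird builds
x = relay.var("x", shape=(1000, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))

lib = relay.build(mod, target="llvm")
lib.export_library("toy_model.so")  # a Python-free binary other languages can load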

Additionally, you can deploy to constrained platforms such as microcontrollers (ARM/RISC-V) through µTVM, or target WASM and WebGPU in the browser.

Conclusion

We have shown that Hummingbird, together with the new TVM backend, can bring significant performance improvements to classical machine learning models. The biggest limitation, in terms of usability, is the requirement of a static input size. Since TVM needs input shapes fixed ahead of time to generate optimized code, users have to specify the batch size (the number of rows for tabular data) to be used for prediction, and the compiled model will not run on test data with a different number of rows. If your use case fits within this constraint, however, Hummingbird and its TVM backend can be a great tool for accelerating your models.
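As a sketch of what this constraint looks like in practice (reusing model and X from the random forest regression example earlier), you compile once per batch size you need:

# The model compiled with a 1000-row test input only accepts 1000-row batches
tvm_1k = hummingbird.ml.convert(model, "tvm", X[:1000])
tvm_1k.predict(X[:1000])      # OK: batch size matches the compiled shape
# tvm_1k.predict(X[:500])     # would not run: shape differs from compile time
tvm_500 = hummingbird.ml.convert(model, "tvm", X[:500])  # compile separately
tvm_500.predict(X[:500])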

If you use Scikit-learn, XGBoost, or LightGBM for deployment, we highly recommend taking a look at the Hummingbird project. If you encounter any problems, please open an issue on GitHub. TVM is also open source; you can find more info on the TVM homepage, or join us on the discuss forums.

And if you’d like to see more of these kinds of posts, please follow us here on Twitter, or feel free to reach out directly at info@octoml.ai.
