Imagine you are an ML scientist setting out to build a brand new model based on the latest and greatest ML research. You develop an innovative new layer that makes the model perform its task more accurately, but it is 10x slower to train.
Do you abandon this model because the training process is prohibitively slow, or continue forward? Continuing forward requires framework modifications or hand-optimized, device-specific code written in a programming language you may not know well. What if you could solve this problem without ever leaving Python, and integrate the solution into your existing workflow in just a few lines of code? You can, by using Apache TVM, the open source ML stack for performance and portability, whose mission is to enable access to high-performance machine learning anywhere, for everyone.
By using TVMScript, TVM's embedded domain-specific language (DSL), OctoML engineers are able to demonstrate a 20x speedup over a straightforward PyTorch implementation on CPU, and a 1.3x speedup over a handwritten CUDA implementation on GPU, for a real-world kernel.
Not only can TVMScript dramatically improve performance, it also lets you write straightforward for-loops instead of terse array code. It provides seamless integration for both CPU and GPU from Python, removing the need to learn a separate application flow for each device.
TVMScript is a new way to use the full power of TVM directly in Python. TVMScript enables end-users to write low-level array programs in a natural loop style and accelerate them. One way to conceptualize this is as a Python-friendly way to write CUDA-like compute code directly. In the coming year, these TVMScript programs will benefit from TVM's new auto-tuning framework's ability to automatically optimize for a variety of hardware. For more details on this ongoing auto-tuning work, check out the recent GTC talk from OctoML.
By moving even the lowest level of device programming APIs into Python, engineers can quickly inspect, iterate on, and debug code, much as JAX and PyTorch already allow for higher-level computation graphs.
All that is required to get started is a kernel written in the TVMScript DSL and a few lines of code to integrate it with your existing Python code. TVMScript is human-readable and editable in Python, which makes debugging easier. For example, a user can dump intermediate state from compiler passes to understand performance or to see how code is transformed.
The below code is a simple “Hello World” program of TVMScript. It implements an element-wise addition over two vectors of the same size. TVMScript takes its arguments in an “in-out” style: what would traditionally be the return value is instead passed in as an argument.
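To make the “in-out” convention concrete before looking at the TVMScript version, here is a minimal sketch of the same idea in plain Python with NumPy (an analogy only, not TVMScript itself): the caller allocates the output buffer and passes it in alongside the inputs, and the kernel fills it in place.

```python
import numpy as np

def vector_add(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> None:
    """Element-wise addition in "in-out" style: the result buffer `c`
    is allocated by the caller and written in place, so nothing is
    returned."""
    for i in range(a.shape[0]):
        c[i] = a[i] + b[i]

a = np.array([1.0, 2.0, 3.0], dtype="float32")
b = np.array([4.0, 5.0, 6.0], dtype="float32")
c = np.empty_like(a)  # caller-allocated output buffer
vector_add(a, b, c)
print(c)  # [5. 7. 9.]
```

Passing the output buffer explicitly is what lets the compiler (rather than the language runtime) decide where and how that memory is allocated, which is why low-level kernel interfaces, including TVMScript's, are written this way.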