🔬 Research Summary by Fraser Mince, an independent machine learning researcher and a Senior Software Engineer at Waymark.
[Original paper by Fraser Mince, Dzung Dinh, Jonas Kgomo, Neil Thompson, and Sara Hooker]
Overview: The portability of machine learning frameworks across hardware classes suffers from significant gaps. This paper measures portability between TPUs and GPUs for TensorFlow, PyTorch, and JAX, finding large disparities: 22% of benchmarked TensorFlow functions fail partially or completely on GPUs, and 44% of PyTorch functions fail partially or completely on TPUs.
Introduction
Hardware plays a significant role in determining the trajectory of machine learning. Gaps in portability exacerbate inequality by restricting research to only those with large quantities of GPUs. Device lock-in can limit the range of ideas explored to only those compatible with particular hardware. But what is the current state of ML portability? For the most part, it has not been measured. This paper measures the difference in latency and error rates of a distribution of framework operations in PyTorch, TensorFlow, and JAX to create a big-picture view of the state of hardware portability in machine learning.
This evaluation framework arrives at a moment when hardware and software specialization is accelerating and comparative evaluations are becoming more critical. The economics of chip specialization have changed dramatically over the last decade, and with specialization come radical differences in performance. These disparities will only grow, as will the importance of co-designing implementations for specific chips. Quantitative portability analysis of this kind will therefore only become more important in the coming years as an aid to designing efficient and portable tooling.
Key Insights
Measuring Framework Hardware Portability
Methodology
This paper seeks to measure portability between frameworks and hardware. Portability is defined as the “ease with which a machine learning workload (code, data, and models) can transfer between different hardware types.” Failures are measured across three categories: “complete failure to run,” “partial failure to run,” and “intolerable latency.” In addition, latency is measured to capture changes in performance between devices.
These portability measures are gathered over a large distribution of operations, sampled according to their frequency of use in the CodeParrot Clean dataset. Five functions were sampled from each usage-frequency decile to create a representative distribution over frequency of use. The choice to focus on individual operations rather than full workloads is intentional: it avoids overfitting to common workloads and prevents sidelining the more niche and diverse ideas that a representative distribution can still capture.
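To make the sampling procedure concrete, here is a minimal Python sketch. It assumes a `usage_counts` mapping from framework function names to how often they appear in CodeParrot Clean has already been extracted (that extraction is not shown), and it illustrates decile-based sampling rather than reproducing the authors' actual script.

```python
import random
import numpy as np

def sample_functions_by_decile(usage_counts, per_decile=5, seed=0):
    """Sample `per_decile` functions from each usage-frequency decile.

    `usage_counts` is assumed to map function names (e.g. "torch.matmul")
    to how often they appear in the corpus.
    """
    rng = random.Random(seed)
    ranked = sorted(usage_counts, key=usage_counts.get)  # least to most used
    deciles = np.array_split(ranked, 10)                 # 10 roughly equal buckets
    sampled = []
    for bucket in deciles:
        bucket = list(bucket)
        sampled.extend(rng.sample(bucket, min(per_decile, len(bucket))))
    return sampled

# e.g. benchmark_fns = sample_functions_by_decile(usage_counts)  -> ~50 functions
```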
Tests were then collected from each framework's own repository for the sampled functions. This keeps the benchmark as conservative as possible, since the framework authors wrote the tests themselves. The tests were adapted to be device-agnostic and run on a range of GPUs and TPUs: specifically, T4 GPUs, A100 GPUs, TPU v2, and TPU v3.
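Below is a hedged sketch of what a device-agnostic benchmark wrapper might look like for a single PyTorch op. The helper names (`benchmark_op`, `make_inputs`) and the warmup/timing choices are illustrative assumptions, not the paper's actual harness; TPU runs would additionally require a `torch_xla` device in place of the CUDA one.

```python
import time
import torch

def benchmark_op(op, make_inputs, device, n_warmup=3, n_runs=10):
    """Run a single framework op on `device`, recording latency or failure."""
    try:
        args = make_inputs(device)
        for _ in range(n_warmup):                 # warm up kernels / compilation
            op(*args)
        if device.type == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_runs):
            op(*args)
        if device.type == "cuda":
            torch.cuda.synchronize()              # wait for async GPU work
        return {"status": "ok", "latency_s": (time.perf_counter() - start) / n_runs}
    except NotImplementedError as e:
        return {"status": "not implemented", "error": str(e)}
    except Exception as e:                        # any other partial/complete failure
        return {"status": "failed", "error": str(e)}

# The same test body runs unchanged on whichever accelerator is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
result = benchmark_op(torch.matmul,
                      lambda d: (torch.randn(256, 256, device=d),
                                 torch.randn(256, 256, device=d)),
                      device)
```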
Findings
These evaluations show that PyTorch and TensorFlow have significant portability issues. On GPUs, 22% of the benchmarked TensorFlow functions fail partially or completely. On TPUs, a remarkable 44% of the benchmarked PyTorch functions fail partially or completely. Even where functions are portable, there are significant performance gaps, with unexpected speedups and slowdowns when moving between GPU and TPU: 81.4% of PyTorch functions exhibit more than a 10x slowdown when moved from GPU to TPU.
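As a worked illustration of how such a statistic could be computed from the collected measurements (not the authors' code), assuming dictionaries of per-function latencies on each device:

```python
def slowdown_stats(gpu_latency, tpu_latency, threshold=10.0):
    """Fraction of functions slowing down by more than `threshold`x on TPU.

    `gpu_latency` and `tpu_latency` map function names to mean latency in
    seconds on each device; the aggregation here is illustrative only.
    """
    common = gpu_latency.keys() & tpu_latency.keys()
    ratios = {fn: tpu_latency[fn] / gpu_latency[fn] for fn in common}
    frac_slow = sum(r > threshold for r in ratios.values()) / len(ratios)
    return ratios, frac_slow
```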
The major exception to these portability gaps is JAX. On latency, 91.8% of the function set runs faster on TPUs in JAX, and failure rates between GPUs and TPUs are very similar, with 98% of functions succeeding on GPUs and 97% on TPUs. The paper hypothesizes that this is due to a “first-class citizen effect”: frameworks run better in the environments they were built for, and JAX in particular was built with XLA as a compilation target.
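A minimal JAX sketch of that “first-class citizen” relationship: the same jitted function is compiled by XLA for whichever backend is present, with no device-specific code. The function below is an arbitrary example, not one of the benchmarked operations.

```python
import jax
import jax.numpy as jnp

# Traced once, then compiled by XLA for the local backend (CPU, GPU, or TPU)
# with no code changes.
@jax.jit
def gelu(x):
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x ** 3)))

x = jnp.linspace(-3.0, 3.0, 1024)
print(jax.devices())   # e.g. CPU, CUDA devices, or a set of TPU cores
print(gelu(x)[:4])
```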
One question the paper explores is whether portability is affected by how frequently a function is used, the expectation being that more heavily used functions would be more portable given the incentive to support the top use cases. However, failure and partial-failure rates are fairly consistent across deciles. This holds for all three libraries, suggesting that frequency of use has not significantly affected the prioritization of support across hardware types.
The paper additionally examines the portability of the 20 most frequently used functions in each framework. Some libraries, such as JAX, have a 0% failure rate in the top 20 and low overall failure rates across all functions. However, on TPUs, PyTorch surprisingly shows a slightly higher failure rate in the top 20 than across all functions (46% vs. 44%), and TensorFlow shows a considerably higher failure rate on GPUs in the top 20 (33% vs. 22%). Across the board, error rates for the top 20 closely track those across the deciles, showing that even the most used functions do not benefit from significantly better portability.
Additionally, the authors classify errors into categories: “type failure,” “not implemented,” “timeout,” “memory issue,” and “float precision error.” The class responsible for the most failures across frameworks is “not implemented,” making it clear that unimplemented operations are a key obstacle to machine learning portability.
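A sketch of how captured exceptions might be bucketed into these categories. The category names follow the paper, but the matching heuristics and the `classify_error` helper are illustrative assumptions rather than the authors' actual rules:

```python
def classify_error(exc):
    """Bucket a captured exception into one of the paper's error categories."""
    msg = str(exc).lower()
    if isinstance(exc, NotImplementedError) or "not implemented" in msg:
        return "not implemented"
    if isinstance(exc, TimeoutError) or "timeout" in msg or "deadline" in msg:
        return "timeout"
    if isinstance(exc, MemoryError) or "out of memory" in msg or "resource exhausted" in msg:
        return "memory issue"
    if isinstance(exc, TypeError) or "dtype" in msg:
        return "type failure"
    if "mismatch" in msg and ("tolerance" in msg or "allclose" in msg):
        return "float precision error"
    return "other"
```

Counting the resulting labels with `collections.Counter` then yields the kind of per-category breakdown described above.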
Between the lines
Machine learning portability is a critical gap in the field, and better tooling is needed. While tools like JAX have gone some way toward bridging the gap, more effort is required. Device lock-in has real consequences that are often underappreciated: inequity and the concentration of research in the hands of large labs follow directly from it. Innovative ideas may never get off the ground because their originators lack the resources, and ideas that do not run well on popular hardware may never take off.
To fix this, we need better tooling: tooling that makes kernels easier for the average developer to write and understand, and compilers that target intermediate representations to make efficient kernels approachable. Fortunately, better tooling is on the horizon. Languages like Triton and Mojo make ML kernels easier to write by targeting efficient intermediate representations, and organizations like tinycorp are working to bring machine learning to AMD chips. However, these projects are often still early in their development. As they, and projects like them, mature, we should see portability across the field of machine learning improve.
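For a flavor of what “easier to write kernels” means in practice, here is a standard Triton-style vector-add kernel, in the spirit of Triton's introductory examples: a few lines of Python-level code that Triton compiles to an efficient GPU kernel, with no hand-written CUDA. It assumes a CUDA-capable GPU and the `triton` package.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                    # guard the ragged final block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
```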