On the computational side, there has been some confusion about how TPUs and GPUs relate to BERT. BERT was trained on 4 TPU pods (256 TPU chips) in 4 days. Does this mean only Google can train a BERT model? Does this mean that GPUs are dead? There are two fundamental things to understand here: (1) A TPU is a matrix multiplication engine: it does matrix multiplication and matrix operations, but not much else. It is fast at computing matrix multiplication, but one has to understand that (2) the slowest part of matrix multiplication is fetching the elements from main memory and loading them into the processing unit. In other words, the most expensive part of matrix multiplication is the memory loads. Note that about 90% of the computational load for BERT should be matrix multiplication. From these facts, we can do a small technical analysis on this topic.
Bandwidth Model for TPUs and GPUs
Transformers for TPUs
A common operation in BERT is the matrix multiplication A*B=C, where A is 256×1024 and B is 1024×1024. A TPU computes such a matrix multiplication by splitting the matrices into many smaller 128×128 matrix multiplications. This means we need to load 16 128×128 tiles from matrix A and, due to the nature of matrix multiplication, 64 tiles from B for every tile in A. This is a total of 16*64=1024 loads of size 128×128. At 16-bit (2 bytes per element, so 32 KB per tile) that is a total of 32 MB of data.
Now we make a simplification: we assume that there is no latency if we do two memory loads one after the other. This is not too unreasonable, since you can often hide memory access latency behind thread parallelism. In simple words: while we wait for one 128×128 matrix copy to complete, we already start the next one. Done this way, we only wait for the first memory copy and not for the others. This is a core reason why GPUs are fast and why we use many threads on GPUs, so assuming zero latency for overlapping memory transfers is not too far off from the real world. Using this simplification, we can plainly use the memory bandwidth to compute the time needed to load the memory for the matrix multiplication. The TPU has a bandwidth of 600 GB/s, so we need 5.2e-05 seconds to transfer the 32 MB of data (counting 1 GB as 1024 MB).
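The TPU estimate above can be reproduced with a few lines of Python. This is only a sketch of the bandwidth model as described in the text, not a measurement: the function name `tile_load_time` and its parameters are made up for illustration, and it assumes zero latency between loads and 1 GB = 1024^3 bytes.

```python
import math

def tile_load_time(a_shape, b_shape, tile, bytes_per_element, bandwidth_gb_per_s):
    """Return (loads, total_bytes, seconds) under the zero-latency bandwidth model.

    Hypothetical helper: every tile of A is paired with every tile of B,
    and each pairing costs one tile-sized load from main memory.
    """
    tiles_a = math.ceil(a_shape[0] / tile) * math.ceil(a_shape[1] / tile)
    tiles_b = math.ceil(b_shape[0] / tile) * math.ceil(b_shape[1] / tile)
    loads = tiles_a * tiles_b
    total_bytes = loads * tile * tile * bytes_per_element
    # Assumption: 1 GB is counted as 1024**3 bytes.
    seconds = total_bytes / (bandwidth_gb_per_s * 1024**3)
    return loads, total_bytes, seconds

# TPU case from the text: 128x128 tiles, 16-bit values, 600 GB/s.
tpu_loads, tpu_bytes, tpu_seconds = tile_load_time(
    (256, 1024), (1024, 1024), tile=128, bytes_per_element=2, bandwidth_gb_per_s=600
)
print(tpu_loads, tpu_bytes / 1024**2, tpu_seconds)
```

With these inputs the sketch gives 1024 loads, 32 MB of data, and roughly 5.2e-05 seconds.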
Transformers on GPUs
For a GPU we have the same process, but we use smaller tiles with more processors. Similarly to the TPU, we use two loads in parallel to hide memory latency. For GPUs, however, we would have a tile size of 96×96 for 16-bit data. If we take a Tesla V100 GPU, then we can run 160 of these in parallel at full bandwidth with low memory latency. What this means compared to a TPU: instead of 2 matrix units which can hold 128×128 matrices, the GPU has 160 units (80 SMs with 160 thread blocks in total), each of which holds two 96×96 matrices. Again, this ensures that we can hide the memory latency through parallelism.
If we repeat the calculation from above we get the following: for matrix A with 256×1024 we have 33 96×96 tiles; for B with 1024×1024 we have 121 96×96 tiles. In total, we need to do 33*121=3993 loads of size 96×96 for a total of 70 MB. A V100 runs at 900 GB/s, so the memory loads will take 7.6e-05 seconds. Thus our model predicts that the TPU needs about 32% less time than a V100 for this specific scenario (5.2e-05 versus 7.6e-05 seconds). Note that the matrix tiles stay the same for an RTX 2080 Ti GPU, but the memory bandwidth decreases to 616 GB/s, which puts an RTX 2080 Ti about 54% behind a TPU by the same measure.
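The GPU numbers can be checked the same way. The sketch below uses a hypothetical helper `tile_load_time` (not from any library) and assumes that partial tiles are rounded up to full 96×96 loads and that 1 GB = 1024^3 bytes. The quoted percentages are read here as the fraction of load time the TPU saves relative to each GPU, which is an interpretation of the text.

```python
import math

def tile_load_time(a_shape, b_shape, tile, bytes_per_element, bandwidth_gb_per_s):
    """Return (loads, total_bytes, seconds) under the zero-latency bandwidth model."""
    # Partial tiles at the matrix edge are rounded up to full tiles (assumption).
    tiles_a = math.ceil(a_shape[0] / tile) * math.ceil(a_shape[1] / tile)
    tiles_b = math.ceil(b_shape[0] / tile) * math.ceil(b_shape[1] / tile)
    loads = tiles_a * tiles_b
    total_bytes = loads * tile * tile * bytes_per_element
    seconds = total_bytes / (bandwidth_gb_per_s * 1024**3)
    return loads, total_bytes, seconds

A, B = (256, 1024), (1024, 1024)

# 16-bit values in all three cases; tile sizes and bandwidths as given in the text.
_, _, tpu_seconds = tile_load_time(A, B, 128, 2, 600)
v100_loads, v100_bytes, v100_seconds = tile_load_time(A, B, 96, 2, 900)
_, _, rtx_seconds = tile_load_time(A, B, 96, 2, 616)

# Fraction of load time the TPU saves relative to each GPU.
tpu_saving_vs_v100 = 1 - tpu_seconds / v100_seconds
tpu_saving_vs_rtx = 1 - tpu_seconds / rtx_seconds

print(v100_loads, v100_bytes / 1024**2, v100_seconds)
print(tpu_saving_vs_v100, tpu_saving_vs_rtx)
```

Under these assumptions the V100 comes out at 3993 loads, about 70 MB, and about 7.6e-05 seconds, and the TPU saves about 32% against the V100 and a little over 50% against the RTX 2080 Ti.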
Note that both TPU and GPUs with Tensor Cores compute the respective matrix multiplication tile in one cycle. Thus the computation is about equally fast — the difference is only in how the memory is loaded.
BERT Training Time Estimate for GPUs
Using this data, with a GPU cluster of V100s/RTX 2080 Tis with good networking (InfiniBand at 56 Gbit/s or more) and good parallelization algorithms (for example using Microsoft's CNTK), we can expect to train BERT on 64 GPUs (the equivalent of 4 TPU pods) in 5 1/3 days or 8 1/2 days, respectively. On an 8 GPU machine with V100s/RTX 2080 Tis, with any software and any parallelization algorithm (PyTorch, TensorFlow), one can expect to train BERT in 42 days or 68 days, respectively. For a standard 4 GPU desktop with RTX 2080 Tis (much cheaper than the other options), one can expect to replicate BERT in 99 days.
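The step from the 64 GPU cluster to the 8 GPU machine can be sketched as a simple rescaling. This assumes that training time scales linearly with the inverse of the GPU count and ignores communication overhead, which is a simplification; the helper `scale_training_days` is hypothetical and covers only this one step.

```python
def scale_training_days(days_on_reference, reference_gpus, target_gpus):
    """Rescale a training-time estimate, assuming perfectly linear scaling."""
    return days_on_reference * reference_gpus / target_gpus

# 64-GPU estimates from the text: 5 1/3 days (V100) and 8 1/2 days (RTX 2080 Ti).
v100_days_8_gpus = scale_training_days(16 / 3, reference_gpus=64, target_gpus=8)
rtx_days_8_gpus = scale_training_days(8.5, reference_gpus=64, target_gpus=8)

print(v100_days_8_gpus, rtx_days_8_gpus)
```

Under the linear-scaling assumption this lands at roughly 42 to 43 days for V100s and 68 days for RTX 2080 Tis on an 8 GPU machine.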
Limitations of the Bandwidth Model
Note that all models are wrong, but some are useful. I would expect that this bandwidth model is within about 30% of the correct runtime values for TPU vs GPU.
The biggest limitation is that these calculations are for specific matrix sizes. Computational differences can be amplified for certain sizes. For example, if your batch size is 128, there is a slight speedup for GPUs compared to TPUs. If you go below a batch size of 128, you can expect GPUs to be significantly faster; increasing the size of matrix B further makes TPUs better and better compared to GPUs, while decreasing the size of matrix B improves the performance of GPUs. Note that the BERT paper optimized the sizes of matrices A and B for the TPU; one would not choose these dimensions when training on a GPU. So this comparison might favor TPUs slightly.
Further direct limitations include fused operations. The TPU can calculate additional element-wise operations, such as a non-linear activation function or a bias, on the fly within a matrix multiplication. This means that the TPU does not need to load from slow global memory as often as a GPU. The GPU also supports these operations, but NVIDIA has not implemented them, and thus GPU users will not be able to benefit from this. Thus one can expect a slowdown of about 1.6% (loading and storing a 256×1024 matrix) for each element-wise operation on a GPU. For example, if you apply a non-linear function and a bias, then the TPU would be about 3.2% faster than GPUs in this scenario.
The Importance of 32-bit vs 16-bit vs 8-bit
If we repeat the same calculations from above for 32-bit values (64×64 tiles) we find that TPUs would be 5.3x faster. So the datatype size has a much larger effect than switching from TPU to GPU and vice versa.
TPUs do not support 8-bit training, but Turing GPUs do. So we can also have a look at how 8-bit matrix multiplication would impact performance. I published research on 8-bit models, and it is not too difficult to train them with 8-bit alone. In fact, the literature on low-bit computing is quite rich. With 32-bit accumulation, as supported by Turing GPUs, 8-bit training should be even easier. If we can make 8-bit computing work for general models, this would entail huge speedups for transformers. If we repeat the above calculations for 8-bit on GPUs (128×128 tiles) we find that GPUs are 3.0x faster than TPUs. 8-bit computation on an affordable standard machine with 4 RTX 2080 Tis would take about 21 days. All of this makes 16-bit computational ability for a GPU an important criterion if you are looking for a GPU to work with transformers.
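The effect of the datatype can be sketched with the same bandwidth model. The text does not state which GPU bandwidth was used for these two ratios; the sketch below assumes the V100's 900 GB/s and compares against the 16-bit TPU baseline, a combination that happens to reproduce both quoted figures. The helper `tile_load_seconds` is hypothetical, and 1 GB is counted as 1024^3 bytes.

```python
import math

def tile_load_seconds(a_shape, b_shape, tile, bytes_per_element, bandwidth_gb_per_s):
    """Seconds spent on tile loads under the zero-latency bandwidth model."""
    tiles_a = math.ceil(a_shape[0] / tile) * math.ceil(a_shape[1] / tile)
    tiles_b = math.ceil(b_shape[0] / tile) * math.ceil(b_shape[1] / tile)
    total_bytes = tiles_a * tiles_b * tile * tile * bytes_per_element
    return total_bytes / (bandwidth_gb_per_s * 1024**3)

A, B = (256, 1024), (1024, 1024)

# Baseline: TPU with 16-bit values and 128x128 tiles at 600 GB/s.
tpu_16bit = tile_load_seconds(A, B, 128, 2, 600)

# Assumption: both GPU cases use 900 GB/s (the text does not say which bandwidth).
gpu_32bit = tile_load_seconds(A, B, 64, 4, 900)   # 32-bit values, 64x64 tiles
gpu_8bit = tile_load_seconds(A, B, 128, 1, 900)   # 8-bit values, 128x128 tiles

tpu_speedup_over_32bit_gpu = gpu_32bit / tpu_16bit
gpu_8bit_speedup_over_tpu = tpu_16bit / gpu_8bit

print(tpu_speedup_over_32bit_gpu, gpu_8bit_speedup_over_tpu)
```

Under these assumptions the 32-bit GPU case is about 5.3x slower than the 16-bit TPU, and the 8-bit GPU case is 3.0x faster.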
In summary, under this bandwidth model TPUs are about 32% to 54% faster for training BERT-like models. One can expect to replicate BERT on an 8 GPU machine within about 40 to 70 days. On a standard, affordable GPU machine with 4 GPUs, one can expect to train BERT in about 99 days using 16-bit or about 21 days using 8-bit.