go-torch is now 115x faster; a short note on engineering optimization
While developing go-torch, a PyTorch-like library in Go, one of the most critical operations, matrix multiplication (MatMul), started out as a significant bottleneck. After a couple of rounds of optimization, our MatMul now runs about 115 times faster for a 1024x1024 matrix multiplication, dropping from an initial ~15 seconds to ~130 milliseconds!
The Starting Point: Naive Matrix Multiplication
Every matrix multiplication journey begins with the standard definition. For two matrices $A$ (of size $M \times K$) and $B$ (of size $K \times N$), the resulting matrix $C$ (of size $M \times N$) is defined by:
$$C_{ij} = \sum_{k=1}^{K} A_{ik} \cdot B_{kj}$$
A direct implementation of this involves three nested loops:
```go
for i := 0; i < M; i++ { // Loop over rows of C (and A)
	for j := 0; j < N; j++ { // Loop over columns of C (and B)
		sum := 0.0
		for k := 0; k < K; k++ { // Loop for the dot product
			sum += A[i][k] * B[k][j] // Accessing B column-wise
		}
		C[i][j] = sum
	}
}
```
While correct, this “naive” approach is notoriously slow for large matrices. Our initial go-torch implementation took around 15 seconds for a 1024x1024 matrix multiplication.
Optimization Step 1: Go-routines for Parallelism & Cache Friendliness
The first leap in performance came from addressing two key aspects: leveraging multiple CPU cores and improving how we access data in memory.
Why Parallelism?
Modern CPUs have multiple cores, each capable of executing instructions independently. At the same time, the calculation of each row (or even block of rows) of the output matrix $C$ is independent of the other rows. This makes matrix multiplication an “embarrassingly parallel” problem, ideal for dividing the work.
Go makes concurrency straightforward with goroutines, which are lightweight, concurrently executing functions. We can spawn multiple goroutines, each responsible for computing a subset of the rows of the output matrix.
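As a minimal, generic sketch of that pattern (the function and names below are just for illustration, not go-torch code), rows can be split evenly across `runtime.NumCPU()` goroutines and joined with a `sync.WaitGroup`:

```go
import (
	"runtime"
	"sync"
)

// parallelRows runs work(row) for every row in [0, numRows),
// splitting the rows evenly across one goroutine per CPU core.
func parallelRows(numRows int, work func(row int)) {
	numWorkers := runtime.NumCPU()
	rowsPerWorker := (numRows + numWorkers - 1) / numWorkers

	var wg sync.WaitGroup
	for w := 0; w < numWorkers; w++ {
		start := w * rowsPerWorker
		end := start + rowsPerWorker
		if end > numRows {
			end = numRows
		}
		if start >= end {
			break // more workers than rows; nothing left to assign
		}
		wg.Add(1)
		go func(start, end int) {
			defer wg.Done()
			for r := start; r < end; r++ {
				work(r)
			}
		}(start, end)
	}
	wg.Wait()
}
```

The actual go-torch version below follows the same shape, with the row computation inlined.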
Why Cache Friendliness?
CPUs don’t fetch data directly from RAM every time; that’s too slow. They use smaller, much faster memory caches (L1, L2, L3). If the data the CPU needs is already in a cache (a “cache hit”), access is fast; if it’s not (a “cache miss”), the CPU stalls while fetching it from RAM.
In the naive MatMul $A_{ik} \cdot B_{kj}$, if matrices are stored row-by-row (row-major order, common in C-like languages and Go slices), accessing $B_{kj}$ means jumping across memory for each element in a column of B. This leads to poor cache utilization.
The Solution: Transposition + Goroutines
- Transpose Matrix B: We first transpose matrix $B$ to get $B^T$. Now, accessing a column of $B$ is equivalent to accessing a row of $B^T$. The multiplication becomes: $$C_{ij} = \sum_{k=1}^{K} A_{ik} \cdot B^T_{jk}$$ When computing $C_{ij}$, we iterate through the $i$-th row of $A$ and the $j$-th row of $B^T$. Both are sequential memory accesses, which is much better for CPU caches! (A sketch of the transpose helper itself appears after the code below.)
- Parallelize Row Computation: We divide the $M$ rows of the output matrix $C$ among `runtime.NumCPU()` goroutines. Each goroutine calculates its assigned chunk of rows using the transposed $B$. A `sync.WaitGroup` ensures all goroutines complete before we proceed.
```go
func MatMulParallel(t1, t2 *Tensor) (*Tensor, error) {
	// ... (shape checks) ...
	M, K1 := t1.shape[0], t1.shape[1]
	K2, N := t2.shape[0], t2.shape[1]
	// ... (K1 == K2 check) ...
	K := K1

	outData := make([]float64, M*N)
	t1Data := t1.GetData()

	t2Transposed, _ := Transpose(t2) // Step 1: Transpose
	t2TransposedData := t2Transposed.GetData()

	numGoroutines := runtime.NumCPU()
	rowsPerGoroutine := (M + numGoroutines - 1) / numGoroutines

	var wg sync.WaitGroup
	for i := 0; i < numGoroutines; i++ { // Step 2: Parallelize
		startRow := i * rowsPerGoroutine
		endRow := startRow + rowsPerGoroutine
		if endRow > M {
			endRow = M
		}
		if startRow >= endRow {
			break
		}

		wg.Add(1)
		go func(sR, eR int) {
			defer wg.Done()
			for rIdx := sR; rIdx < eR; rIdx++ {
				for cIdx := 0; cIdx < N; cIdx++ {
					sum := 0.0
					t1RowOffset := rIdx * K
					t2TRowOffset := cIdx * K // Accessing row of B^T
					for kIdx := 0; kIdx < K; kIdx++ {
						sum += t1Data[t1RowOffset+kIdx] * t2TransposedData[t2TRowOffset+kIdx]
					}
					outData[rIdx*N+cIdx] = sum
				}
			}
		}(startRow, endRow)
	}
	wg.Wait()

	return NewTensor([]int{M, N}, outData)
}
```
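The code above relies on a `Transpose` helper that isn’t shown in the post. Here is a minimal sketch of what it could look like, assuming the same flat, row-major `[]float64` storage and the `NewTensor` constructor used elsewhere in this post (treat those details as assumptions):

```go
// Transpose returns a new 2D tensor whose (j, i) element is the (i, j)
// element of the input, keeping the flat row-major layout.
func Transpose(t *Tensor) (*Tensor, error) {
	rows, cols := t.shape[0], t.shape[1]
	data := t.GetData()

	outData := make([]float64, rows*cols)
	for i := 0; i < rows; i++ {
		for j := 0; j < cols; j++ {
			outData[j*rows+i] = data[i*cols+j]
		}
	}
	return NewTensor([]int{cols, rows}, outData)
}
```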
Was it good? This first round of optimization brought the 1024x1024 MatMul time down significantly: from ~15 seconds to around 720 milliseconds (a ~20x improvement!). But it was still slow, at least nowhere near actual PyTorch. So I decided to do some work with BLAS.
Optimization Step 2: Enter BLAS - doing actually cool stuff
While our Go-parallel version was much better, it still didn’t come close to PyTorch, which does the same 1024x1024 MatMul in about 0.8 ms.
The 720 ms, while a good Go result, still left a lot of performance on the table. The gap comes down to techniques that highly tuned libraries use and our hand-written version doesn’t:
- SIMD (Single Instruction, Multiple Data) Instructions: Modern CPUs have special instructions (like AVX, SSE) that can perform the same operation on multiple data elements simultaneously.
- Advanced Blocking/Tiling: Sophisticated algorithms that break matrices into cache-sized blocks to minimize RAM access (sketched below).
- Fine-tuned Assembly: Hand-optimized assembly code for critical loops.
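To make the blocking/tiling idea concrete, here is a purely illustrative sketch (not go-torch code, and nowhere near what real BLAS kernels actually do) of a tiled triple loop over flat row-major data; the block size is an arbitrary assumption:

```go
const blockSize = 64 // assumed tile size; real libraries tune this per CPU

// matMulTiled accumulates A (MxK) times B (KxN) into C (MxN), all stored
// flat in row-major order, processing blockSize-sized tiles so each tile's
// working set stays in cache. C must be zero-initialized.
// min is the Go 1.21+ builtin.
func matMulTiled(A, B, C []float64, M, K, N int) {
	for i0 := 0; i0 < M; i0 += blockSize {
		for k0 := 0; k0 < K; k0 += blockSize {
			for j0 := 0; j0 < N; j0 += blockSize {
				iMax := min(i0+blockSize, M)
				kMax := min(k0+blockSize, K)
				jMax := min(j0+blockSize, N)
				for i := i0; i < iMax; i++ {
					for k := k0; k < kMax; k++ {
						a := A[i*K+k]
						for j := j0; j < jMax; j++ {
							C[i*N+j] += a * B[k*N+j]
						}
					}
				}
			}
		}
	}
}
```

Real libraries combine this with SIMD kernels and per-CPU tuning, which is exactly why reusing them pays off instead of reimplementing it all.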
Use BLAS (Basic Linear Algebra Subprograms)
BLAS libraries are collections of routines that provide standard building blocks for basic vector and matrix operations. They are typically written in Fortran/C/Assembly and heavily optimized by CPU vendors and the HPC community.
Go can interface with C libraries using CGo. The `gonum` suite of scientific libraries for Go provides `gonum/mat`, which can use a system-installed BLAS library (I installed OpenBLAS via MSYS2) for its heavy matrix operations.
Implementation with `gonum/mat`:
- Installation: We installed OpenBLAS on the system and ensured the MinGW-w64 toolchain (for GCC on Windows) was set up.
- Code Change: Our `MatMulTensor` function was modified:
  - For sufficiently large matrices (determined by a `blasThreshold`), we convert our go-torch tensors into `gonum/mat.Dense` matrices.
  - We then call the `Mul` method on these `gonum` matrices (e.g., `gonumResult.Mul(gonumT1, gonumT2)`). This `Mul` call, under the hood, invokes the highly optimized `DGEMM` (Double-precision General Matrix Multiply) routine from the system’s OpenBLAS library via CGo.
  - The result from `gonum` is then copied back into our go-torch tensor format.
  - For smaller matrices (below `blasThreshold`), we fall back to our Go-parallel version to avoid CGo call overhead and data conversion costs.
import "gonum.org/v1/gonum/mat"
const blasThreshold = 64 // modify as needed
func MatMulWithBLAS(t1, t2 *Tensor) (*Tensor, error) { M, K, N := t1.shape[0], t1.shape[1], t2.shape[1] // Assuming K1==K2
if M > blasThreshold && K > blasThreshold && N > blasThreshold { // Heuristic // BLAS Path gonumT1 := mat.NewDense(M, K, t1.GetData()) gonumT2 := mat.NewDense(K, N, t2.GetData()) gonumOut := mat.NewDense(M, N, nil)
gonumOut.Mul(gonumT1, gonumT2) // This calls OpenBLAS DGEMM
outData := make([]float64, M*N) copy(outData, gonumOut.RawMatrix().Data) return NewTensor([]int{M, N}, outData) } else { // if blas not available, fall back to tensor paralleliztion (step 1) return MatMulParallel(t1, t2) }}
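One detail worth flagging: out of the box, `gonum/mat` uses gonum’s pure-Go BLAS implementation. Routing `Mul` through the system OpenBLAS requires registering the CGo-backed netlib bindings once at startup. The post doesn’t show that wiring, so the snippet below is an assumption about how the setup typically looks rather than go-torch’s actual code:

```go
import (
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/netlib/blas/netlib"
)

func init() {
	// Route all gonum float64 BLAS calls (including mat.Dense.Mul)
	// through the CGo netlib bindings, which link against the system
	// BLAS (OpenBLAS here) instead of the pure-Go default.
	blas64.Use(netlib.Implementation{})
}
```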
The BLAS speedup: the 1024x1024 MatMul time dropped from ~720 ms (our Go-parallel best) to ~164 milliseconds, and to ~130 ms after tuning OpenBLAS threading (see below)!
- Initial (Naive): ~15,000 ms
- After Go Parallelism & Transpose: ~720 ms (~20x faster than naive)
- After BLAS Integration: ~130 ms (~5.5x faster than Go-parallel, and ~115x faster than the initial naive version!)
Fine-tuning BLAS: Threading
BLAS libraries like OpenBLAS have their own internal threading. I did some experimentation by setting the `OPENBLAS_NUM_THREADS` environment variable:

- `OPENBLAS_NUM_THREADS=1`: 210 ms
- `OPENBLAS_NUM_THREADS=4`: 130.50 ms (optimal for the test machine, likely matching physical cores)
- `OPENBLAS_NUM_THREADS=8`: 172 ms (slightly slower, possibly due to hyperthreading overhead)
Current Benchmarks
After these optimizations, go-torch now scores:

| Operation | Time (1024x1024) |
|---|---|
| Matrix Multiply (go-torch with BLAS) | 130.50 ms |
| Matrix Multiply (Previous Go-Parallel) | ~720 ms |
| Matrix Multiply (Initial Naive) | ~15,000 ms |
| Benchmark Detail | 128x128 | 512x512 | 1024x1024 |
|---|---|---|---|
| Matrix Multiply | 510.33 µs | 13.54 ms | 130.50 ms |
| Element-wise Add | 71.72 µs | 1.29 ms | 4.13 ms |
| Element-wise Mul | 47.83 µs | 1.63 ms | 3.91 ms |
| ReLU Activation | 121.18 µs | 1.75 ms | 6.45 ms |
| Linear Layer Forward (B32,I128,O10) | 71.93 µs | | |
| CrossEntropyLoss (B32,C10) | 11.16 µs | | |
| Full Fwd-Bwd (Net:128-256-10, B32) | 4.02 ms | | |
Next Steps and The Quest for More Speed
Even though 130 ms is great, there is still memory overhead when computing large matrices, especially because we copy the matrices and compute on their duplicates. I’m now looking into `sync.Pool` to see how much of that data-transfer and allocation overhead can be avoided by reusing buffers. Let’s see how far we can push this.
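For reference, here is a minimal sketch of the kind of buffer reuse `sync.Pool` enables. This is an idea I’m experimenting with, not code that’s in go-torch yet, and the helper names and preallocated capacity are assumptions:

```go
import "sync"

// matBufPool hands out reusable []float64 buffers so repeated MatMul calls
// don't each allocate a fresh output slice. The 1024*1024 capacity is an
// arbitrary starting point for this sketch.
var matBufPool = sync.Pool{
	New: func() any {
		buf := make([]float64, 0, 1024*1024)
		return &buf
	},
}

// getBuffer returns a length-n slice backed by pooled storage when possible.
// Note: reused buffers are not zeroed; callers must overwrite every element.
func getBuffer(n int) *[]float64 {
	bufPtr := matBufPool.Get().(*[]float64)
	if cap(*bufPtr) < n {
		b := make([]float64, n) // pooled buffer too small; allocate fresh
		return &b
	}
	*bufPtr = (*bufPtr)[:n]
	return bufPtr
}

// putBuffer hands the buffer's storage back to the pool for reuse.
func putBuffer(bufPtr *[]float64) {
	*bufPtr = (*bufPtr)[:0]
	matBufPool.Put(bufPtr)
}
```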