# basic-sgemm

**Repository Path**: haukzero/basic-sgemm

## Basic Information

- **Project Name**: basic-sgemm
- **Description**: Some fairly intuitive optimizations of single-precision matrix multiplication (SGEMM) on CUDA
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 1
- **Forks**: 0
- **Created**: 2025-02-25
- **Last Updated**: 2025-02-25

## Categories & Tags

**Categories**: Uncategorized

**Tags**: cuda-programming, sgemm

## README

# CUDA SGEMM Basic Optimizations

A set of fairly intuitive optimization techniques for single-precision matrix multiplication (SGEMM).

Measured results on a 4060:

```
===== M = 256, N = 256, K = 1024 =====
cublas | time: 0.038912 ms | max diff: N/A
v1     | time: 1.213440 ms | max diff: 0.000046
v2     | time: 0.249856 ms | max diff: 0.000046
v3     | time: 0.196448 ms | max diff: 0.000046
v4     | time: 0.116736 ms | max diff: 0.000046
v5     | time: 0.132000 ms | max diff: 0.000046
v6     | time: 0.121856 ms | max diff: 0.000046
v7     | time: 0.139264 ms | max diff: 0.000046
v8     | time: 0.155648 ms | max diff: 0.000046
===== M = 512, N = 512, K = 1024 =====
cublas | time: 0.114688 ms | max diff: N/A
v1     | time: 0.908224 ms | max diff: 0.000061
v2     | time: 0.698368 ms | max diff: 0.000061
v3     | time: 0.621568 ms | max diff: 0.000061
v4     | time: 0.149472 ms | max diff: 0.000061
v5     | time: 0.158720 ms | max diff: 0.000061
v6     | time: 0.220160 ms | max diff: 0.000061
v7     | time: 0.186304 ms | max diff: 0.000061
v8     | time: 0.176128 ms | max diff: 0.000061
===== M = 1024, N = 1024, K = 1024 =====
cublas | time: 0.401408 ms | max diff: N/A
v1     | time: 3.567616 ms | max diff: 0.000072
v2     | time: 2.742272 ms | max diff: 0.000072
v3     | time: 2.458624 ms | max diff: 0.000072
v4     | time: 0.491520 ms | max diff: 0.000072
v5     | time: 0.535296 ms | max diff: 0.000072
v6     | time: 0.604160 ms | max diff: 0.000072
v7     | time: 0.578560 ms | max diff: 0.000072
v8     | time: 0.616448 ms | max diff: 0.000072
===== M = 2048, N = 2048, K = 1024 =====
cublas | time: 0.966656 ms | max diff: N/A
v1     | time: 8.936448 ms | max diff: 0.000000
v2     | time: 6.849536 ms | max diff: 0.000000
v3     | time: 6.164480 ms | max diff: 0.000000
v4     | time: 1.184768 ms | max diff: 0.000000
v5     | time: 1.287168 ms | max diff: 0.000000
v6     | time: 1.438720 ms | max diff: 0.000000
v7     | time: 1.427456 ms | max diff: 0.000000
v8     | time: 1.478656 ms | max diff: 0.000000
===== M = 4096, N = 4096, K = 1024 =====
cublas | time: 3.634176 ms | max diff: N/A
v1     | time: 36.820992 ms | max diff: 0.000000
v2     | time: 30.008320 ms | max diff: 0.000000
v3     | time: 27.866112 ms | max diff: 0.000000
v4     | time: 4.813824 ms | max diff: 0.000000
v5     | time: 5.111808 ms | max diff: 0.000000
v6     | time: 5.800960 ms | max diff: 0.000000
v7     | time: 5.719040 ms | max diff: 0.000000
v8     | time: 5.867520 ms | max diff: 0.000000
===== M = 8192, N = 8192, K = 1024 =====
cublas | time: 14.385152 ms | max diff: N/A
v1     | time: 179.564545 ms | max diff: 0.000000
v2     | time: 136.337402 ms | max diff: 0.000000
v3     | time: 112.858109 ms | max diff: 0.000000
v4     | time: 19.419136 ms | max diff: 0.000000
v5     | time: 20.897793 ms | max diff: 0.000000
v6     | time: 26.982401 ms | max diff: 0.000000
v7     | time: 25.618431 ms | max diff: 0.000000
v8     | time: 24.745983 ms | max diff: 0.000000
```

## Main Optimizations

- Use shared memory to speed up data reads from global memory
- Use vectorized `float4` loads
- Use registers to speed up data reads from shared memory
- Use double buffering to hide the latency of loading data from gmem into smem
- Mitigate shared-memory bank conflicts in three ways: padding, swizzling, and changing the data-loading pattern
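The first optimization, staging tiles in shared memory, can be illustrated without a GPU. The sketch below is a CPU-side C++ model, not the repository's kernel: each `(bi, bj)` pair plays the role of a thread block, the `As`/`Bs` arrays stand in for `__shared__` tiles, and the `(ty, tx)` loops stand in for threads. The names `sgemm_naive`, `sgemm_tiled` and the tile size `TILE` are hypothetical, and edge handling is omitted by assuming every dimension is a multiple of `TILE`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical tile size; the README does not state the repository's block sizes.
constexpr int TILE = 4;

// Reference: C = A * B, row-major, A is M x K, B is K x N.
std::vector<float> sgemm_naive(const std::vector<float>& A, const std::vector<float>& B,
                               int M, int N, int K) {
    std::vector<float> C(M * N, 0.0f);
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
            C[i * N + j] = acc;
        }
    return C;
}

// CPU model of a shared-memory tiled SGEMM kernel.
std::vector<float> sgemm_tiled(const std::vector<float>& A, const std::vector<float>& B,
                               int M, int N, int K) {
    assert(M % TILE == 0 && N % TILE == 0 && K % TILE == 0);  // simplification: no edge tiles
    std::vector<float> C(M * N, 0.0f);
    float As[TILE][TILE], Bs[TILE][TILE];  // stand-ins for __shared__ tiles
    for (int bi = 0; bi < M; bi += TILE)
        for (int bj = 0; bj < N; bj += TILE) {
            float acc[TILE][TILE] = {};    // one accumulator per "thread"
            for (int bk = 0; bk < K; bk += TILE) {
                // Load phase: each "thread" copies one element of A and one of B from global memory.
                for (int ty = 0; ty < TILE; ++ty)
                    for (int tx = 0; tx < TILE; ++tx) {
                        As[ty][tx] = A[(bi + ty) * K + (bk + tx)];
                        Bs[ty][tx] = B[(bk + ty) * N + (bj + tx)];
                    }
                // (On the GPU, __syncthreads() goes here.)
                // Compute phase: every global element loaded above is reused TILE times from the tile.
                for (int ty = 0; ty < TILE; ++ty)
                    for (int tx = 0; tx < TILE; ++tx)
                        for (int k = 0; k < TILE; ++k)
                            acc[ty][tx] += As[ty][k] * Bs[k][tx];
                // (On the GPU, a second __syncthreads() goes here before the tiles are overwritten.)
            }
            for (int ty = 0; ty < TILE; ++ty)
                for (int tx = 0; tx < TILE; ++tx)
                    C[(bi + ty) * N + (bj + tx)] = acc[ty][tx];
        }
    return C;
}
```

The payoff on the GPU is that each global-memory element is fetched once per tile instead of once per output element that uses it.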
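For the vectorized `float4` loads, the main design question is how threads map onto 4-float chunks of a tile. Below is a CPU-side C++ sketch of one common mapping, under assumed sizes: a `BM x BK = 128 x 8` tile of A loaded by 256 threads, so each thread issues exactly one 16-byte load. `Float4`, `load4` and `load_tile_float4` are hypothetical stand-ins; on the GPU, `load4` would be the usual `reinterpret_cast<const float4*>(ptr)[0]` idiom, which requires 16-byte alignment.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical tile configuration (not taken from the repository).
constexpr int BM = 128, BK = 8, NUM_THREADS = 256;

// Stand-in for CUDA's float4: 16-byte aligned, so one load is one 128-bit transaction on the GPU.
struct alignas(16) Float4 { float x, y, z, w; };

// Copy 4 consecutive floats; on the GPU: reinterpret_cast<const float4*>(p)[0].
inline Float4 load4(const float* p) {
    Float4 v;
    std::memcpy(&v, p, sizeof(v));
    return v;
}

// Thread `tid` handles row tid / (BK / 4) and columns starting at (tid % (BK / 4)) * 4.
void load_tile_float4(const float* A, int lda, int row0, int col0, float* smem /* BM * BK */) {
    assert(lda % 4 == 0 && col0 % 4 == 0);  // needed for aligned float4 loads on the GPU
    constexpr int VEC_PER_ROW = BK / 4;                      // float4s per tile row
    constexpr int ROWS_PER_PASS = NUM_THREADS / VEC_PER_ROW; // rows covered by one pass of all threads
    for (int tid = 0; tid < NUM_THREADS; ++tid) {            // models the threads of one block
        int r = tid / VEC_PER_ROW;
        int c = (tid % VEC_PER_ROW) * 4;
        for (int rr = r; rr < BM; rr += ROWS_PER_PASS) {
            Float4 v = load4(&A[(row0 + rr) * lda + col0 + c]);
            std::memcpy(&smem[rr * BK + c], &v, sizeof(v)); // one 16-byte store into the "smem" tile
        }
    }
}
```

With these sizes `ROWS_PER_PASS` equals `BM`, so the inner loop runs once per thread: 256 loads of 4 floats cover the 1024-element tile exactly.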
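Using registers to cut shared-memory reads usually means each thread computes a small micro-tile of C as a sequence of outer products. The following CPU-side C++ sketch assumes a hypothetical `TM x TN = 4 x 4` micro-tile; `micro_tile` is an illustrative name, and `As`/`Bs` are plain row-major arrays standing in for shared-memory tiles.

```cpp
#include <cassert>

// Hypothetical micro-tile: each "thread" owns a TM x TN block of C held in registers.
constexpr int TM = 4, TN = 4;

// Computes C[row0 + i][col0 + j] = sum_k As[row0 + i][k] * Bs[k][col0 + j] for one thread.
void micro_tile(const float* As, int lda, const float* Bs, int ldb, int K,
                int row0, int col0, float* C, int ldc) {
    assert(K >= 0 && row0 >= 0 && col0 >= 0);
    float acc[TM][TN] = {};     // accumulators: registers on the GPU
    float regA[TM], regB[TN];   // per-k fragments copied out of "shared memory"
    for (int k = 0; k < K; ++k) {
        for (int i = 0; i < TM; ++i) regA[i] = As[(row0 + i) * lda + k];
        for (int j = 0; j < TN; ++j) regB[j] = Bs[k * ldb + col0 + j];
        // Outer product: TM + TN shared-memory reads feed TM * TN multiply-adds.
        for (int i = 0; i < TM; ++i)
            for (int j = 0; j < TN; ++j)
                acc[i][j] += regA[i] * regB[j];
    }
    for (int i = 0; i < TM; ++i)
        for (int j = 0; j < TN; ++j)
            C[(row0 + i) * ldc + col0 + j] = acc[i][j];
}
```

With one output per thread, each multiply-add needs two shared-memory reads; with a 4 x 4 micro-tile, 8 reads feed 16 multiply-adds, which is the ratio improvement this technique targets.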
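Double buffering keeps two shared-memory stages and alternates between them along the K loop: while one stage is consumed, the next K-slab is fetched into the other. The C++ sketch below models only the control flow on the CPU, where nothing actually overlaps; the sizes `BM`, `BN`, `BK` and the function name `tile_double_buffered` are hypothetical. On the GPU, the global loads for slab `t + 1` are issued before the math on slab `t` so their latency can overlap it; many implementations stage those loads through registers and write them to shared memory after the math.

```cpp
#include <cassert>

// Hypothetical sizes: one thread block's BM x BN output tile and the K-slab width BK.
constexpr int BM = 4, BN = 4, BK = 4;

// CPU model of a double-buffered K loop for a single output tile of C = A * B
// (A is BM x K, B is K x BN, row-major).
void tile_double_buffered(const float* A, const float* B, float* C, int K) {
    assert(K % BK == 0);
    float As[2][BM][BK], Bs[2][BK][BN];  // two shared-memory stages (ping-pong)
    float acc[BM][BN] = {};
    auto load = [&](int stage, int k0) { // global -> shared copy of one K-slab
        for (int i = 0; i < BM; ++i)
            for (int k = 0; k < BK; ++k) As[stage][i][k] = A[i * K + k0 + k];
        for (int k = 0; k < BK; ++k)
            for (int j = 0; j < BN; ++j) Bs[stage][k][j] = B[(k0 + k) * BN + j];
    };
    const int num_slabs = K / BK;
    load(0, 0);                                   // prologue: fill stage 0
    // (On the GPU, __syncthreads() here.)
    for (int t = 0; t < num_slabs; ++t) {
        const int cur = t & 1;
        if (t + 1 < num_slabs) load(cur ^ 1, (t + 1) * BK);  // prefetch next slab into the other stage
        for (int i = 0; i < BM; ++i)                          // compute on the current stage
            for (int j = 0; j < BN; ++j)
                for (int k = 0; k < BK; ++k) acc[i][j] += As[cur][i][k] * Bs[cur][k][j];
        // (On the GPU, __syncthreads() here: the prefetched stage becomes visible, and
        //  no thread overwrites `cur` in the next iteration until everyone has finished reading it.)
    }
    for (int i = 0; i < BM; ++i)
        for (int j = 0; j < BN; ++j) C[i * BN + j] = acc[i][j];
}
```

The cost is twice the shared memory per block, which can lower occupancy; whether the overlap pays for that depends on the problem size.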
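Two of the bank-conflict remedies listed above, padding and swizzling, can be checked by counting how many distinct 4-byte words a warp's request puts in each bank. The C++ sketch below assumes shared memory's default layout of 32 banks, each 4 bytes wide, and a hypothetical 32 x 32 float tile read column-wise (lane `t` reads row `t`). The swizzle is a simple element-level XOR chosen for illustration; the repository's actual swizzle, and its third remedy (changing the load pattern), are not modeled.

```cpp
#include <cassert>
#include <set>

constexpr int NUM_BANKS = 32;  // 32 banks of 4-byte words (default configuration)
constexpr int WARP = 32;
constexpr int TILE = 32;       // hypothetical 32 x 32 float tile

// Word offset of element (r, c) under three layouts.
int offset_plain(int r, int c)    { return r * TILE + c; }
int offset_padded(int r, int c)   { return r * (TILE + 1) + c; }  // one float of padding per row
int offset_swizzled(int r, int c) { return r * TILE + (c ^ r); }  // XOR permutes columns within a row

// Worst-case number of distinct words any single bank must serve when lane t reads
// element (t, col): 1 means conflict-free, 32 means the access is fully serialized.
template <typename Layout>
int conflict_degree(Layout offset, int col) {
    assert(col >= 0 && col < TILE);
    std::set<int> words[NUM_BANKS];  // same word in a bank is a broadcast, not a conflict
    for (int lane = 0; lane < WARP; ++lane) {
        int w = offset(lane, col);
        words[w % NUM_BANKS].insert(w);
    }
    int worst = 0;
    for (const auto& s : words)
        if (static_cast<int>(s.size()) > worst) worst = static_cast<int>(s.size());
    return worst;
}
```

Without padding, every lane hits bank `col`, a 32-way conflict. Padding shifts row `t` by `t` banks, and the XOR swizzle sends lane `t` to bank `col ^ t`; both spread the warp across all 32 banks. Padding costs one extra float per row, while swizzling costs nothing in memory but complicates every index computation.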
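The benchmark timings are easier to compare across sizes as throughput. An M x N x K SGEMM performs 2·M·N·K floating-point operations (one multiply and one add per inner-product term). The helper names below are hypothetical and not part of the repository.

```cpp
#include <cassert>

// Floating-point operations in C = A * B with A: M x K and B: K x N.
double sgemm_flops(long long M, long long N, long long K) {
    return 2.0 * static_cast<double>(M) * static_cast<double>(N) * static_cast<double>(K);
}

// Throughput in TFLOPS for a kernel time given in milliseconds.
double sgemm_tflops(long long M, long long N, long long K, double time_ms) {
    assert(time_ms > 0.0);
    return sgemm_flops(M, N, K) / (time_ms * 1e-3) / 1e12;
}
```

For M = N = 8192, K = 1024 this is 2^37 ≈ 1.37 × 10^11 FLOP, so the measured 14.385152 ms for cuBLAS corresponds to about 9.55 TFLOPS and the 19.419136 ms for v4 to about 7.08 TFLOPS, roughly 74% of cuBLAS at that size.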