
Conversation

mnehete32
Contributor

Follow-up of PR #15635.

Convolution Performance Results (Old)

FP32 (float32) Performance

| Input Shape | Kernel Shape | Runs | Time/Run (µs) | FLOPs/Run | GFLOPS |
|---|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 3 | 368871.67 | 137.42 GFLOP | 372.55 |
| [19,19,8,16] | [4,4,8,128] | 2992 | 399.90 | 133.69 MFLOP | 334.32 |
| [19,19,8,16] | [4,4,8,130] | 2948 | 407.43 | 135.78 MFLOP | 333.27 |
| [19,19,4,16] | [2,2,4,4] | 131072 | 8.05 | 642.82 kFLOP | 79.89 |
| [224,224,3,1] | [3,3,3,8] | 14358 | 103.77 | 20.90 MFLOP | 201.38 |
| [224,224,1,1] | [2,2,1,8] | 24576 | 55.27 | 2.78 MFLOP | 50.39 |
| [224,224,1,8] | [2,2,1,8] | 4489 | 437.02 | 22.28 MFLOP | 50.98 |
| [58,58,32,1] | [3,3,32,64] | 3468 | 368.43 | 115.40 MFLOP | 313.23 |
| [58,58,32,8] | [3,3,32,64] | 436 | 2888.17 | 923.24 MFLOP | 319.66 |
| [16,16,128,8] | [3,3,128,512] | 220 | 5489.08 | 1.85 GFLOP | 336.83 |

FP16 (float16) Performance

| Input Shape | Kernel Shape | Runs | Time/Run (µs) | FLOPs/Run | GFLOPS |
|---|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 3 | 403320.33 | 137.42 GFLOP | 340.73 |
| [19,19,8,16] | [4,4,8,128] | 2244 | 448.02 | 133.69 MFLOP | 298.41 |
| [19,19,8,16] | [4,4,8,130] | 2211 | 455.74 | 135.78 MFLOP | 297.94 |
| [19,19,4,16] | [2,2,4,4] | 122880 | 8.63 | 642.82 kFLOP | 74.47 |
| [224,224,3,1] | [3,3,3,8] | 9572 | 116.88 | 20.90 MFLOP | 178.78 |
| [224,224,1,1] | [2,2,1,8] | 24576 | 60.58 | 2.78 MFLOP | 45.97 |
| [224,224,1,8] | [2,2,1,8] | 4489 | 474.92 | 22.28 MFLOP | 46.91 |
| [58,58,32,1] | [3,3,32,64] | 2601 | 411.38 | 115.40 MFLOP | 280.53 |
| [58,58,32,8] | [3,3,32,64] | 327 | 3302.93 | 923.24 MFLOP | 279.52 |
| [16,16,128,8] | [3,3,128,512] | 165 | 6181.65 | 1.85 GFLOP | 299.09 |

Convolution Performance Results (New)

FP32 (float32) Performance

| Input Shape | Kernel Shape | Runs | Time/Run (µs) | FLOPs/Run | TFLOPS |
|---|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 12 | 87193.67 | 137.42 GFLOP | 1.58 |
| [19,19,8,16] | [4,4,8,128] | 10472 | 97.68 | 133.69 MFLOP | 1.37 |
| [19,19,8,16] | [4,4,8,130] | 5896 | 180.83 | 135.78 MFLOP | 0.75 |
| [19,19,4,16] | [2,2,4,4] | 40960 | 24.69 | 642.82 kFLOP | 0.026 |
| [224,224,3,1] | [3,3,3,8] | 4786 | 254.30 | 20.90 MFLOP | 0.082 |
| [224,224,1,1] | [2,2,1,8] | 8192 | 130.07 | 2.78 MFLOP | 0.021 |
| [224,224,1,8] | [2,2,1,8] | 4489 | 1038.51 | 22.28 MFLOP | 0.021 |
| [58,58,32,1] | [3,3,32,64] | 5202 | 199.36 | 115.40 MFLOP | 0.579 |
| [58,58,32,8] | [3,3,32,64] | 872 | 1248.11 | 923.24 MFLOP | 0.740 |
| [16,16,128,8] | [3,3,128,512] | 660 | 1631.43 | 1.85 GFLOP | 1.13 |

FP16 (float16) Performance

| Input Shape | Kernel Shape | Runs | Time/Run (µs) | FLOPs/Run | TFLOPS |
|---|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 26 | 38639.81 | 137.42 GFLOP | 3.56 |
| [19,19,8,16] | [4,4,8,128] | 19448 | 51.80 | 133.69 MFLOP | 2.58 |
| [19,19,8,16] | [4,4,8,130] | 11792 | 88.61 | 135.78 MFLOP | 1.53 |
| [19,19,4,16] | [2,2,4,4] | 81920 | 13.16 | 642.82 kFLOP | 0.049 |
| [224,224,3,1] | [3,3,3,8] | 9572 | 126.75 | 20.90 MFLOP | 0.165 |
| [224,224,1,1] | [2,2,1,8] | 16384 | 67.38 | 2.78 MFLOP | 0.041 |
| [224,224,1,8] | [2,2,1,8] | 4489 | 529.18 | 22.28 MFLOP | 0.042 |
| [58,58,32,1] | [3,3,32,64] | 11271 | 92.85 | 115.40 MFLOP | 1.24 |
| [58,58,32,8] | [3,3,32,64] | 1744 | 588.97 | 923.24 MFLOP | 1.57 |
| [16,16,128,8] | [3,3,128,512] | 1320 | 781.39 | 1.85 GFLOP | 2.37 |

Convolution Performance Comparison (Old vs New)

FP32 (float32)

| Input Shape | Kernel Shape | Old GFLOPS | New GFLOPS | Improvement |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 372.55 | 1580 | +4.2× |
| [19,19,8,16] | [4,4,8,128] | 334.32 | 1370 | +4.1× |
| [19,19,8,16] | [4,4,8,130] | 333.27 | 751 | +2.3× |
| [19,19,4,16] | [2,2,4,4] | 79.89 | 26.04 | -67% |
| [224,224,3,1] | [3,3,3,8] | 201.38 | 82.17 | -59% |
| [224,224,1,1] | [2,2,1,8] | 50.39 | 21.41 | -57% |
| [224,224,1,8] | [2,2,1,8] | 50.98 | 21.45 | -58% |
| [58,58,32,1] | [3,3,32,64] | 313.23 | 578.89 | +1.85× |
| [58,58,32,8] | [3,3,32,64] | 319.66 | 739.71 | +2.3× |
| [16,16,128,8] | [3,3,128,512] | 336.83 | 1130 | +3.36× |

FP16 (float16)

| Input Shape | Kernel Shape | Old GFLOPS | New GFLOPS | Improvement |
|---|---|---|---|---|
| [19,19,256,16] | [4,4,256,4096] | 340.73 | 3560 | +10.5× |
| [19,19,8,16] | [4,4,8,128] | 298.41 | 2580 | +8.6× |
| [19,19,8,16] | [4,4,8,130] | 297.94 | 1530 | +5.1× |
| [19,19,4,16] | [2,2,4,4] | 74.47 | 48.84 | -34% |
| [224,224,3,1] | [3,3,3,8] | 178.78 | 164.86 | -8% |
| [224,224,1,1] | [2,2,1,8] | 45.97 | 41.33 | -10% |
| [224,224,1,8] | [2,2,1,8] | 46.91 | 42.10 | -10% |
| [58,58,32,1] | [3,3,32,64] | 280.53 | 1240 | +4.4× |
| [58,58,32,8] | [3,3,32,64] | 279.52 | 1570 | +5.6× |
| [16,16,128,8] | [3,3,128,512] | 299.09 | 2370 | +7.9× |

Summary:

  • Large convolutions now see 3–10× improvement in GFLOPS.
  • Small convolutions may see lower GFLOPS because they are memory-bound (dominated by shared-memory traffic) rather than compute-bound.
  • FP16 gains are larger than FP32 gains on large kernels.
  • Used ggml_cuda_cast<T> for the type conversions to make sure the build does not break (see the sketch below).
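
To illustrate that last point, here is a minimal sketch of the idea; the helper `cast_to` and the kernel are hypothetical stand-ins rather than the PR's actual code. Routing every conversion through one templated cast lets the same kernel body compile for both float and half outputs.

```cuda
#include <cuda_fp16.h>

// Hypothetical stand-in for a templated cast helper in the spirit of
// ggml_cuda_cast<T>; not the actual ggml implementation.
template <typename dst_t, typename src_t>
__device__ __forceinline__ dst_t cast_to(src_t x) {
    return dst_t(x); // relies on CUDA's built-in half<->float conversions
}

// The same kernel body compiles for T = float and T = half because every
// conversion goes through the single templated cast above.
template <typename T>
__global__ void scale_to(const float * src, T * dst, const float s, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = cast_to<T>(src[i] * s);
    }
}
```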

github-actions bot added the labels Nvidia GPU (issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Sep 5, 2025.
* removed flash-attention definition
@JohannesGaessler
Collaborator

If you're going to make your own primitives anyways, take a look at mma.cuh. The WMMA interface NVIDIA provides for the "high-level" CUDA code is quite frankly terrible, so I exposed the tensor core PTX instructions (assembly equivalent). The practical upside is that you get a defined memory layout (important for mul_mat_q and FlashAttention but I think not here) and that you can use smaller matrix tiles (minimum is 16x8). The downside is that Volta and AMD are still lacking an implementation.
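
For reference, a minimal sketch of the high-level WMMA interface being discussed: one warp computing one 16×16 tile with half inputs and float accumulators. The tile sizes and layouts here are illustrative only.

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// One warp computes a single 16x16 output tile: C = A * B, half inputs,
// float accumulators. ldm is the leading dimension of each matrix.
__global__ void wmma_tile(const half * A, const half * B, float * C, const int ldm) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c;

    wmma::fill_fragment(c, 0.0f);
    wmma::load_matrix_sync(a, A, ldm); // which thread ends up holding which
    wmma::load_matrix_sync(b, B, ldm); // element is unspecified by the API
    wmma::mma_sync(c, a, b, c);
    wmma::store_matrix_sync(C, c, ldm, wmma::mem_row_major);
}
```

The unspecified fragment layout is the problem referred to above, and the defined memory layout is what the mma.cuh PTX wrappers provide instead.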

@Green-Sky
Collaborator

Green-Sky commented Sep 5, 2025

It is becoming increasingly hard to test these kinds of changes with sd.cpp. @ggerganov please sync ggml, it has been 2.5 weeks of rapid convolution development. :)

@JohannesGaessler
Collaborator

I only monitor the llama.cpp and ggml repositories, but when it comes to convolution kernels such as this one it would also be fine for me if you open PRs in sd.cpp and tag me.

@ggerganov
Member

ggml repo is up-to-date now

@Green-Sky
Collaborator

Green-Sky commented Sep 5, 2025

@mnehete32 please run the tests; the output seems to be broken.

(attached screenshot: output_2)

@mnehete32
Contributor Author

@Green-Sky Checking

@mnehete32
Contributor Author

I think the kernel was not able to launch all of its threads, since the launcher launches one warp per WMMA_M × WMMA_N tile. I will try launching fewer threads per block. Also, with mma.cuh it looks like I don't need shared memory to store results; I haven't checked that completely yet and am still looking into it.
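
For illustration, a hedged sketch of the kind of launch configuration described above; WMMA_M/WMMA_N values, the function name, and the parameters are placeholders, not this PR's actual code.

```cuda
#include <cuda_runtime.h>

constexpr int WMMA_M    = 16;
constexpr int WMMA_N    = 16;
constexpr int WARP_SIZE = 32;

// One warp per WMMA_M x WMMA_N output tile, with the grid rounded up so that
// partial tiles at the borders are still covered by a full warp.
static void launch_conv_wmma(const int out_rows, const int out_cols, cudaStream_t stream) {
    const int  tiles_m = (out_rows + WMMA_M - 1) / WMMA_M;
    const int  tiles_n = (out_cols + WMMA_N - 1) / WMMA_N;
    const dim3 grid(tiles_m, tiles_n, 1);
    const dim3 block(WARP_SIZE, 1, 1); // a single warp per block, i.e. fewer
                                       // threads per block than a larger launcher
    // conv2d_wmma_kernel<<<grid, block, 0, stream>>>(/* tensors, strides, ... */);
}
```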

@JohannesGaessler
Collaborator

> it looks like I don't need shared memory to store results

FYI, for tensor cores you theoretically don't need shared memory at all. Each thread in a warp holds fractions of the input and output tiles in its registers. You only need shared memory to organize the data in such a way that the global memory accesses are coalesced (see mmf.cu) or in the case of WMMA to work around the memory layout being undefined.
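
To make the register-resident output concrete, here is a hedged sketch of storing an m16n8 float accumulator tile straight from registers to global memory, using the thread-to-element mapping the PTX ISA documents for m16n8k16 with f32 accumulators; the pointer layout is illustrative.

```cuda
// Each thread of the warp holds 4 floats of the 16x8 accumulator tile.
// ld is the leading dimension (row stride) of the output in global memory.
__device__ void store_tile_direct(float * dst, const int ld, const float (&c)[4]) {
    const int lane     = threadIdx.x % 32;
    const int group    = lane / 4; // 8 groups of 4 threads
    const int in_group = lane % 4;

#pragma unroll
    for (int i = 0; i < 4; ++i) {
        const int row = group + (i < 2 ? 0 : 8); // c0,c1 -> rows 0..7; c2,c3 -> rows 8..15
        const int col = 2*in_group + (i & 1);    // each thread covers 2 adjacent columns
        dst[row*ld + col] = c[i]; // written straight from registers; shared-memory
                                  // staging would only be needed to coalesce accesses
    }
}
```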

@mnehete32
Contributor Author

I thought that because the output-to-thread mapping is unknown and changes based on architecture, I first need to stage the output in shared memory before storing it.

@JohannesGaessler
Collaborator

If you read the PTX documentation you'll find that all tensor core instructions have a well-defined memory layout. It's only when you try to cover all tensor core instructions with a single simple interface that you run into problems. Volta has 8x8 tensor cores. Turing, Ampere, and Ada Lovelace have 16x8 tensor cores (used by mma.cuh). Hopper has special asynchronous tensor core instructions, and Blackwell has yet another set of instructions. I don't think either of the latter two would fit WMMA, but the 16x8 instructions are still available.
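
As a hedged sketch of what one of those 16x8 instructions looks like when wrapped directly, here is m16n8k16 with half inputs and float accumulators (Ampere or newer); this is illustrative and not the actual mma.cuh wrapper.

```cuda
// D = A*B + C for one 16x8 tile per warp. A is 4 packed .b32 registers
// (8 halves), B is 2 packed .b32 registers (4 halves), C and D are 4 floats
// per thread, in the layout documented by the PTX ISA.
__device__ void mma_m16n8k16(float (&d)[4], const unsigned (&a)[4],
                             const unsigned (&b)[2], const float (&c)[4]) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
```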
