
HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA 3 #14624


Open · wants to merge 21 commits into master

Conversation


@deepsek deepsek commented Jul 10, 2025

  • Added Matrix core support (MFMA instructions) for MMQ kernels.

  • Enabled stream-K for CDNA3 to work with MMQ kernels.

  • Removed usage of the hardcoded WARP_SIZE constant in MMQ kernels.

  • NOTE: Thoughts on removing all uses of hardcoded constants specific only to NVIDIA (like WARP_SIZE) in order to support other GPUs? (A sketch of one possible approach is below.)
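A minimal sketch of one way such a target-dependent constant could be expressed, assuming a wave64 build on CDNA; the constant name is illustrative and not this PR's actual code:

```cpp
// Illustrative sketch only (not the PR's code): derive the warp/wavefront
// width from the compile target instead of hardcoding the NVIDIA value of 32.
#if defined(__HIP_PLATFORM_AMD__)
static constexpr int ggml_mmq_warp_size = 64;   // CDNA/GCN wavefront (wave64 build assumed)
#else
static constexpr int ggml_mmq_warp_size = 32;   // NVIDIA warp
#endif
```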

@JohannesGaessler @ggerganov
P.S. I am part of an AMD team actively working on enabling the AMD feature set in llama.cpp. We would like to get on a call to discuss some future PR plans for additional backends, flash attention changes, etc.

EDIT:
Updated to add some performance charts for the DeepSeekV3 model.

Upstream vs ROCm Fork Development: [chart]

MI300X vs H100 Throughput Test: [chart]

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels Jul 10, 2025
@JohannesGaessler
Collaborator

I would be happy to get on a call with you to discuss AMD hardware support; my email address can be found on my GitHub page.

@ggerganov
Member

P.S. I am part of an AMD team actively working on enabling the AMD feature set in llama.cpp. We would like to get on a call to discuss some future PR plans for additional backends, flash attention changes, etc.

@deepsek Thanks for the contribution and for reaching out. On topics related to the CUDA backend, @JohannesGaessler is the best person to consult with. For additional backends, @slaren can provide guidelines and advice. I'll be happy to provide input on any matters as well.

I am also available for a call - feel free to contact me.

@Dampfinchen

Dampfinchen commented Jul 11, 2025

Very nice to see the initiative. I assume improvements made for CDNA will also carry over to the consumer side next year when UDNA releases. So this is exciting news for the future of AMD products!

@IMbackK
Collaborator

IMbackK commented Jul 12, 2025

This certainly is good news

@JohannesGaessler
Collaborator

Sorry, I wanted to ask: @IMbackK since you've been working on AMD support, are you interested in joining the discussion?

@IMbackK
Collaborator

IMbackK commented Jul 14, 2025

Sorry, I wanted to ask: @IMbackK since you've been working on AMD support, are you interested in joining the discussion?

Yes, certainly. It would help to avoid duplication of effort. I can be reached via email at uvos.xyz, user carl.

@deepsek deepsek requested a review from ngxson as a code owner July 15, 2025 16:53
@github-actions github-actions bot added the devops (improvements to build systems and github actions) label Jul 15, 2025
@deepsek
Author

deepsek commented Jul 21, 2025

Hi @JohannesGaessler, is there any blocker for merging this PR into the main branch?

@IMbackK
Collaborator

IMbackK commented Jul 21, 2025

@deepsek There are a few small things, as discussed: better naming for this MFMA path, so that an RDNA WMMA solution can be added later without the naming being strange, is one thing; use of two V_MFMA_I32_16X16X16I8 instructions on gfx908 and gfx90a, even if this path is not chosen for those, to ease maintainability is another.
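For reference, a minimal hypothetical sketch of issuing V_MFMA_I32_16X16X16I8 through the clang builtin on gfx908/gfx90a; the wrapper name and the zero modifier arguments are illustrative, not this PR's kernel code:

```cpp
// Illustration only: one 16x16x16 int8 MFMA on a 64-wide CDNA wavefront.
// Each lane supplies 4 packed int8 A values and 4 packed int8 B values as
// 32-bit integers; the per-lane accumulator is a 4 x int32 vector.
typedef int int32x4_t __attribute__((ext_vector_type(4)));

__device__ int32x4_t mfma_i32_16x16x16_i8(int a_packed, int b_packed, int32x4_t acc) {
#if defined(__gfx908__) || defined(__gfx90a__)
    // cbsz/abid/blgp = 0: no block broadcast or lane-group swizzling.
    return __builtin_amdgcn_mfma_i32_16x16x16i8(a_packed, b_packed, acc, 0, 0, 0);
#else
    return acc; // placeholder on other targets
#endif
}
```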

I would also like to try this myself on gfx94x somehow, and I am not sure what the state is with regard to access to AMD's cloud for maintenance of a gfx94x-specific code path; maybe @ggerganov can also comment on that. A problem here is that after CDNA2/gfx90a/MI210, AMD has not made any further CDNA devices in a PCIe add-in board form factor, so outside of acquiring an entire MI300 OAM machine, no one can simply add a CDNA3/gfx94x/MI3xx-compatible card to their system.

@IMbackK
Collaborator

IMbackK commented Jul 23, 2025

I'm not sure it's faster at any batch size than the combination of dp4a and bf16 rocblas. This needs further examination, which I can't do due to the above denial.

Also, yes, the q2_K issue could simply be resolved by adjusting ggml_cuda_should_use_mmq, since it exists at large batch sizes.

@deepsek
Author

deepsek commented Jul 23, 2025

Further, I think it needs to be examined whether this PR is necessary at all, given that it would seem that the current code path taken (dequant->rocblas) could perform better than this PR if run with export ROCBLAS_USE_HIPBLASLT=1, which causes rocblas to cross-dispatch to hipblaslt, or if bf16 is used as a dequant target.

I'm leaning toward agreeing with @JohannesGaessler on this. While performance at prompt sizes 512+ for q2_K is lower right now, that should be fixed soon. But we can fall back to dp4a if I'm not able to resolve it soon.

Additionally, there is an MFMA instruction with 4x4x4_16B tiles, which could potentially give better performance than the hipblaslt idea. But I don't have the bandwidth to try that, since we are on par with the current 32x32x16 tile right now.

Also, regarding using hipblaslt right now: there is an issue with it that causes a core dump; we have internal tickets investigating. I also need to look further at test-backend-ops, which seems to jump from size 8 directly to 512, skipping the other sizes.

Anyway, here are some numbers for q2_K:

PR:
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp16 |        665.97 ± 1.32 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp32 |       1290.22 ± 1.64 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp64 |       3041.69 ± 7.28 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp96 |       3231.11 ± 2.29 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |           pp128 |       3886.78 ± 7.50 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |           pp256 |       4737.21 ± 4.33 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |           pp512 |       5095.03 ± 4.94 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |          pp1024 |      4946.90 ± 62.20 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |          pp2048 |      4714.18 ± 18.66 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |          pp4096 |       4160.67 ± 3.34 |

Main branch (rocblas):
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp16 |        402.83 ± 0.59 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp32 |        510.58 ± 0.13 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp64 |        981.36 ± 0.51 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp96 |       1451.36 ± 0.88 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |           pp128 |       1911.94 ± 1.80 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |           pp256 |       3513.48 ± 3.58 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |           pp512 |       6198.16 ± 8.92 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |          pp1024 |      6000.75 ± 83.56 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |          pp2048 |      5622.92 ± 23.35 |
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |          pp4096 |      4906.49 ± 10.95 |

Main branch(hipblaslt):
| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
hipBLASLt error: Heuristic Fetch Failed!
This message will be only be displayed once, unless the ROCBLAS_VERBOSE_HIPBLASLT_ERROR environment variable is set.

rocBLAS warning: hipBlasLT failed, falling back to tensile. 
This message will be only be displayed once, unless the ROCBLAS_VERBOSE_TENSILE_ERROR environment variable is set.
| llama 8B Q2_K - Medium         |   2.95 GiB |     8.03 B | ROCm       |  99 |            pp16 |        254.02 ± 0.70 |
Memory access fault by GPU node-9 (Agent handle: 0x617180ea0450) on address 0x7c4a13401000. Reason: Unknown.
GPU core dump created: gpucore.218659
Aborted (core dumped)

test-backend-ops (PR):
 MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               848 runs -  1180.60 us/run -  60.13 GFLOP/run -  50.93 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              5412 runs -   184.81 us/run -  60.13 GFLOP/run - 325.36 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     6242 runs -   160.21 us/run -  60.13 GFLOP/run - 375.31 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3736 runs -   267.77 us/run -  60.13 GFLOP/run - 224.56 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3746 runs -   267.07 us/run -  60.13 GFLOP/run - 225.15 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3368 runs -   297.88 us/run -  60.13 GFLOP/run - 201.86 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3452 runs -   289.73 us/run -  60.13 GFLOP/run - 207.53 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3702 runs -   270.17 us/run -  60.13 GFLOP/run - 222.56 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1190 runs -   841.06 us/run -  60.13 GFLOP/run -  71.49 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2110 runs -   474.26 us/run -  60.13 GFLOP/run - 126.79 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3588 runs -   278.75 us/run -  60.13 GFLOP/run - 215.71 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3102 runs -   322.38 us/run -  60.13 GFLOP/run - 186.51 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2132 runs -   469.15 us/run -  60.13 GFLOP/run - 128.17 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3274 runs -   305.44 us/run -  60.13 GFLOP/run - 196.86 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2022 runs -   494.72 us/run -  60.13 GFLOP/run - 121.54 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    2314 runs -   432.47 us/run -  60.13 GFLOP/run - 139.04 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3290 runs -   303.99 us/run -  60.13 GFLOP/run - 197.80 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3438 runs -   290.90 us/run -  60.13 GFLOP/run - 206.70 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3128 runs -   319.84 us/run -  60.13 GFLOP/run - 188.00 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2848 runs -   351.19 us/run -  60.13 GFLOP/run - 171.22 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3022 runs -   331.09 us/run -  60.13 GFLOP/run - 181.61 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2828 runs -   353.66 us/run -  60.13 GFLOP/run - 170.02 TFLOPS

test-backend-ops (main rocblas):
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               278 runs -  3600.55 us/run -  60.13 GFLOP/run -  16.70 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              1672 runs -   598.72 us/run -  60.13 GFLOP/run - 100.43 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     6190 runs -   161.55 us/run -  60.13 GFLOP/run - 372.19 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1394 runs -   717.78 us/run -  60.13 GFLOP/run -  83.77 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1394 runs -   717.75 us/run -  60.13 GFLOP/run -  83.77 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1470 runs -   680.61 us/run -  60.13 GFLOP/run -  88.35 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1474 runs -   678.46 us/run -  60.13 GFLOP/run -  88.63 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1564 runs -   639.65 us/run -  60.13 GFLOP/run -  94.00 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1392 runs -   719.33 us/run -  60.13 GFLOP/run -  83.59 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1394 runs -   717.84 us/run -  60.13 GFLOP/run -  83.76 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1390 runs -   719.74 us/run -  60.13 GFLOP/run -  83.54 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1394 runs -   717.56 us/run -  60.13 GFLOP/run -  83.80 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1392 runs -   718.58 us/run -  60.13 GFLOP/run -  83.68 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  1400 runs -   715.26 us/run -  60.13 GFLOP/run -  84.07 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   1390 runs -   720.09 us/run -  60.13 GFLOP/run -  83.50 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1398 runs -   715.94 us/run -  60.13 GFLOP/run -  83.99 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  1398 runs -   715.32 us/run -  60.13 GFLOP/run -  84.06 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1398 runs -   715.35 us/run -  60.13 GFLOP/run -  84.06 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1398 runs -   715.48 us/run -  60.13 GFLOP/run -  84.04 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   1394 runs -   718.19 us/run -  60.13 GFLOP/run -  83.72 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1398 runs -   716.01 us/run -  60.13 GFLOP/run -  83.98 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   1398 runs -   715.74 us/run -  60.13 GFLOP/run -  84.01 TFLOPS

test-backend-ops (main hipblaslt):
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               826 runs -  1211.11 us/run -  60.13 GFLOP/run -  49.65 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              5376 runs -   186.04 us/run -  60.13 GFLOP/run - 323.21 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     6216 runs -   160.91 us/run -  60.13 GFLOP/run - 373.68 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3118 runs -   320.75 us/run -  60.13 GFLOP/run - 187.47 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3124 runs -   320.21 us/run -  60.13 GFLOP/run - 187.78 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3616 runs -   276.57 us/run -  60.13 GFLOP/run - 217.41 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3674 runs -   272.25 us/run -  60.13 GFLOP/run - 220.86 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     4244 runs -   235.64 us/run -  60.13 GFLOP/run - 255.18 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3130 runs -   319.53 us/run -  60.13 GFLOP/run - 188.18 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3114 runs -   321.17 us/run -  60.13 GFLOP/run - 187.22 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3062 runs -   326.76 us/run -  60.13 GFLOP/run - 184.01 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3100 runs -   322.71 us/run -  60.13 GFLOP/run - 186.33 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3112 runs -   321.39 us/run -  60.13 GFLOP/run - 187.09 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3126 runs -   319.95 us/run -  60.13 GFLOP/run - 187.93 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   3140 runs -   318.61 us/run -  60.13 GFLOP/run - 188.72 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3140 runs -   318.60 us/run -  60.13 GFLOP/run - 188.73 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3124 runs -   320.13 us/run -  60.13 GFLOP/run - 187.83 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3146 runs -   318.05 us/run -  60.13 GFLOP/run - 189.06 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3128 runs -   319.80 us/run -  60.13 GFLOP/run - 188.02 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   3116 runs -   321.04 us/run -  60.13 GFLOP/run - 187.30 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3116 runs -   321.05 us/run -  60.13 GFLOP/run - 187.29 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   3112 runs -   321.45 us/run -  60.13 GFLOP/run - 187.06 TFLOPS

@IMbackK
Collaborator

IMbackK commented Jul 23, 2025

Further, I think it needs to be examined whether this PR is necessary at all, given that it would seem that the current code path taken (dequant->rocblas) could perform better than this PR if run with export ROCBLAS_USE_HIPBLASLT=1, which causes rocblas to cross-dispatch to hipblaslt, or if bf16 is used as a dequant target.

I'm leaning toward agreeing with @JohannesGaessler on this. While performance at prompt sizes 512+ for q2_K is lower right now, that should be fixed soon. But we can fall back to dp4a if I'm not able to resolve it soon.

The low performance with q2_K is mainly vs rocblas/hipblaslt, not dp4a, as dp4a is only used for small batch sizes and the issue is mainly at large batch sizes. But yes, we can just use the rocblas/hipblaslt code path for this datatype.

Your benchmark results suggest that, like CDNA1/2 with rocblas, on CDNA3 with hipblaslt the performance for large batch sizes varies with datatype: sometimes hipblaslt wins and sometimes this PR wins. This is not a big problem, but ideally this PR would also change ggml_cuda_should_use_mmq to limit its own applicability to the fast datatypes. Since this is now the same across all CDNA devices, the code path should ideally be the same across all CDNA devices. But since both hipblaslt and rocblas are currently also broken on CDNA3, we might want to skip the BLAS path on CDNA3 unconditionally, as this PR does.

Please lean on your BLAS colleagues to have rocblas cross-dispatch to hipblaslt automatically on GFX942 where this is faster, as it already does on GFX12.

So to recap, these are the steps to take in my opinion:

  1. change ggml_cuda_should_use_mmq to return false for q2_K in the amd_mfma case
  2. hopefully fix the gfx906 regression by trying nwarps = 4
  3. merge this PR
  4. follow-up PR enables this PR for gfx908/gfx90a on datatypes where it outperforms rocblas
    • I will do this if you like
  5. fix the hipblaslt bug
  6. limit the applicability of this PR to the same fast datatypes on gfx942 as on gfx908/90a from step 4, and use hipblaslt on gfx942

@JohannesGaessler
Collaborator

change ggml_cuda_should_use_mmq to return false for q2_K in the amd_mfma case

To clarify, the function should return true for small batch sizes and false for large batch sizes, with the boundary chosen so as to maximize performance.
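A rough sketch of the kind of gate being described; the helper name, parameter, and boundary value are assumptions for illustration, not measured values or this PR's code:

```cpp
#include <cstdint>

// Hedged sketch: keep MMQ for small batches and defer to the dequant + BLAS
// path above a tuned batch-size boundary for the slow quantization types.
static bool mmq_wins_over_blas_q2_K(const int64_t ne11 /* src1 columns = batch size */) {
    const int64_t mmq_batch_boundary = 512;  // assumption: to be tuned per device
    return ne11 < mmq_batch_boundary;        // true -> MMQ, false -> dequant + BLAS
}
```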

@deepsek
Author

deepsek commented Jul 24, 2025

@JohannesGaessler @IMbackK

  • I've revised my design. All quants now use the same tile size, number of warps, MFMA instruction, etc. With this revision, for ALL quants (including q2_K), the PR code path with MFMA will outperform the main branch on CDNA3. All checks pass. Numbers and results are in the drop-down below:
  • I have added the code path for gfx906 with nwarps == 4 in this commit too (see the sketch after this list).
    • I have requested access to a Vega20 system, but it might be delayed. @0cc4m, I would appreciate it if you could run this test on your setup. Also, gfx906 support is discontinued in ROCm 6.4.0; which ROCm version are you using?
    • If the nwarps change doesn't fix this, I would have to look at the dp4a code path. I only enabled the 64 active threads, and all optimization effort for this PR is on CDNA3. Unless we are okay with 2-3 quants dropping slightly in performance while all the other quants improve, in the short term (as this is not on the roadmap) I can only revert dp4a back to working on 32 threads, keeping all quants slower - which, IMO, is not the right approach.
    • If nwarps == 4 fixes the issue for dp4a, then I believe this PR is ready to merge.
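A minimal sketch of the kind of per-target nwarps selection referred to above; the function name and the non-gfx906 default are assumptions, not this PR's code:

```cpp
// Illustration only: choose the number of warps per block for the MMQ kernel
// based on the compile target, with the reduced count being tried for gfx906.
static constexpr int mmq_nwarps() {
#if defined(__gfx906__)
    return 4;   // value being tried to address the gfx906 (Vega20) regression
#else
    return 8;   // placeholder default for other targets (assumption)
#endif
}
```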

change ggml_cuda_should_use_mmq to return false for q2_K in the amd_mfma case

To clarify, the function should return true for small batch sizes and false for large batch sizes, with the boundary chosen so as to maximize performance.

  • For AMD, with this PR and the latest commit, the MFMA code path outperforms the default path for all quants, so ggml_cuda_should_use_mmq doesn't need to be changed. We can revisit when hipblaslt support is fixed.

  • More importantly, I tested the main branch on an H100, with the default code path (MMQ with mma instructions) and with ggml_cuda_should_use_mmq returning false.

  • See below for results. In summary, the default path with MMQ always underperforms dequant + cuBLAS at size 512. And for q2_K, it is at ~60 TFLOPS vs ~221 TFLOPS with cuBLAS.

    • I'm curious to understand why this code path is enabled by default. Shouldn't the behaviour be the same irrespective of the underlying hardware, i.e. pick the code path with higher performance? Am I missing something here?
    • Also, the same observation that @IMbackK made for AMD, with bf16 being far superior to the rest, can be seen in this case too. I haven't looked into why this is happening on MI300X yet. But if this is the case, shouldn't dequant to bf16 + BLAS be the way to go over all other options?
MI300X TFLOPS - Main Branch vs PR
Main Branch:
MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               274 runs -  3661.16 us/run -  60.13 GFLOP/run -  16.42 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              1664 runs -   601.67 us/run -  60.13 GFLOP/run -  99.94 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     6080 runs -   164.48 us/run -  60.13 GFLOP/run - 365.57 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1364 runs -   734.15 us/run -  60.13 GFLOP/run -  81.90 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1364 runs -   733.22 us/run -  60.13 GFLOP/run -  82.01 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1446 runs -   692.51 us/run -  60.13 GFLOP/run -  86.83 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1436 runs -   697.16 us/run -  60.13 GFLOP/run -  86.25 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1526 runs -   655.69 us/run -  60.13 GFLOP/run -  91.70 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1368 runs -   731.03 us/run -  60.13 GFLOP/run -  82.25 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1370 runs -   730.67 us/run -  60.13 GFLOP/run -  82.29 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1374 runs -   728.80 us/run -  60.13 GFLOP/run -  82.51 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1372 runs -   729.31 us/run -  60.13 GFLOP/run -  82.45 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1376 runs -   727.33 us/run -  60.13 GFLOP/run -  82.67 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  1266 runs -   790.53 us/run -  60.13 GFLOP/run -  76.06 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   1374 runs -   728.63 us/run -  60.13 GFLOP/run -  82.52 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1372 runs -   729.26 us/run -  60.13 GFLOP/run -  82.45 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  1330 runs -   752.33 us/run -  60.13 GFLOP/run -  79.92 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1372 runs -   729.30 us/run -  60.13 GFLOP/run -  82.45 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1366 runs -   732.92 us/run -  60.13 GFLOP/run -  82.04 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   1374 runs -   728.28 us/run -  60.13 GFLOP/run -  82.56 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1374 runs -   728.63 us/run -  60.13 GFLOP/run -  82.52 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   1370 runs -   730.72 us/run -  60.13 GFLOP/run -  82.29 TFLOPS

PR:
 MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               280 runs -  3592.92 us/run -  60.13 GFLOP/run -  16.74 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              1646 runs -   608.25 us/run -  60.13 GFLOP/run -  98.86 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     6154 runs -   162.50 us/run -  60.13 GFLOP/run - 370.02 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3718 runs -   269.01 us/run -  60.13 GFLOP/run - 223.52 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3694 runs -   270.81 us/run -  60.13 GFLOP/run - 222.04 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3346 runs -   298.98 us/run -  60.13 GFLOP/run - 201.11 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3442 runs -   290.55 us/run -  60.13 GFLOP/run - 206.95 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3706 runs -   269.94 us/run -  60.13 GFLOP/run - 222.75 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1406 runs -   711.80 us/run -  60.13 GFLOP/run -  84.48 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2416 runs -   413.94 us/run -  60.13 GFLOP/run - 145.26 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3618 runs -   276.47 us/run -  60.13 GFLOP/run - 217.49 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3086 runs -   324.07 us/run -  60.13 GFLOP/run - 185.54 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2018 runs -   495.91 us/run -  60.13 GFLOP/run - 121.25 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3250 runs -   307.82 us/run -  60.13 GFLOP/run - 195.34 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2110 runs -   474.03 us/run -  60.13 GFLOP/run - 126.85 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    2010 runs -   497.58 us/run -  60.13 GFLOP/run - 120.84 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3290 runs -   303.99 us/run -  60.13 GFLOP/run - 197.80 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3436 runs -   291.15 us/run -  60.13 GFLOP/run - 206.53 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    1406 runs -   711.53 us/run -  60.13 GFLOP/run -  84.51 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2826 runs -   354.01 us/run -  60.13 GFLOP/run - 169.85 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    2956 runs -   338.30 us/run -  60.13 GFLOP/run - 177.74 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2738 runs -   365.25 us/run -  60.13 GFLOP/run - 164.62 TFLOPS
backend ops test - ALL PASSED
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Instinct MI300X, gfx942:sramecc+:xnack- (0x942), VMM: no, Wave Size: 64
load_backend: loaded ROCm backend from /app/new_rocm/build/bin/libggml-hip.so
load_backend: loaded CPU backend from /app/new_rocm/build/bin/libggml-cpu-icelake.so
Testing 2 devices

Backend 1/2: ROCm0
  Device description: AMD Instinct MI300X
  Device memory: 196592 MB (196100 MB free)

  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=4,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=5,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=6,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=7,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=8,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=9,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f32,type_b=f16,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=4,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q8_0,type_b=f16,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_0,type_b=f16,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f16,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_K,type_b=f16,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[3,1],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[1,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[3,2],nr=[2,2],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=1,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=8,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=iq2_xxs,type_b=f16,m=16,n=16,k=1024,bs=[3,2],nr=[1,1],per=[0,1,2,3],v=0): not supported [ROCm0] 
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=32,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=32,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=32,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_1,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=32,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q2_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q3_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q5_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=q6_K,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq2_s,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_s,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq1_m,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=1,k=32,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq3_s,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=1,k=1,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=64,n=2,k=128,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=83,n=2,k=128,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=64,n=2,k=64,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=83,n=2,k=64,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=64,n=45,k=128,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=45,k=64,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=193,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=67,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=128,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1057,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=129,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1057,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=129,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1057,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=128,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=128,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=128,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1056,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1057,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=129,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1057,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=129,bs=[1,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1057,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=128,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1056,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1057,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1057,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1057,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=128,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1056,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=128,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=128,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1056,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1057,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1057,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=129,bs=[1,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1057,bs=[1,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=128,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1056,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1057,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=129,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1057,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=129,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1057,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=128,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1056,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=128,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=128,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1056,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1057,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=129,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1057,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=129,bs=[2,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1057,bs=[2,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=128,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1056,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1057,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=129,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1057,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=129,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1057,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=128,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1056,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=128,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=128,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1056,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1057,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=129,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1057,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=129,bs=[2,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1057,bs=[2,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=128,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1056,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1057,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=129,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1057,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=129,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1057,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=128,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1056,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=128,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=128,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1056,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1057,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=129,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1057,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=129,bs=[4,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1057,bs=[4,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=128,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1056,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1057,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=129,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1057,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=129,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1057,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=128,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1056,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=128,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=128,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1056,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1057,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=129,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1057,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=129,bs=[4,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1057,bs=[4,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=128,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1056,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1057,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=129,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1057,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=129,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1057,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=128,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1056,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=128,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=128,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1056,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1057,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=129,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1057,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=129,bs=[8,1],nr=[1,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1057,bs=[8,1],nr=[1,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1056,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1056,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1056,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1056,n=1,k=129,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=1057,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1056,n=1,k=129,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=128,n=1,k=1057,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1056,n=1,k=129,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=128,n=1,k=1057,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1056,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1056,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1056,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f16,type_b=f32,m=1057,n=1,k=129,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f16,type_b=f32,m=129,n=1,k=1057,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=1057,n=1,k=129,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=bf16,type_b=f32,m=129,n=1,k=1057,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  MUL_MAT(type_a=f32,type_b=f32,m=1057,n=1,k=129,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0): OK
  MUL_MAT(type_a=f32,type_b=f32,m=129,n=1,k=1057,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1): OK
  6534/6534 tests passed
  Backend ROCm0: OK
Backend 2/2: CPU
  Skipping CPU backend
2/2 backends passed
OK
H100 TFLOPS - Main Branch vs PR
Default (MMA enabled):
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              4920 runs -   203.28 us/run -  60.13 GFLOP/run - 295.79 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              7358 runs -   135.93 us/run -  60.13 GFLOP/run - 442.35 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     8006 runs -   124.93 us/run -  60.13 GFLOP/run - 481.29 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3250 runs -   307.85 us/run -  60.13 GFLOP/run - 195.32 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2932 runs -   341.10 us/run -  60.13 GFLOP/run - 176.28 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2908 runs -   343.98 us/run -  60.13 GFLOP/run - 174.80 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2658 runs -   376.49 us/run -  60.13 GFLOP/run - 159.71 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3280 runs -   305.02 us/run -  60.13 GFLOP/run - 197.13 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     1046 runs -   957.45 us/run -  60.13 GFLOP/run -  62.80 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2248 runs -   445.12 us/run -  60.13 GFLOP/run - 135.09 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2890 runs -   346.03 us/run -  60.13 GFLOP/run - 173.77 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2740 runs -   365.10 us/run -  60.13 GFLOP/run - 164.69 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     2292 runs -   436.53 us/run -  60.13 GFLOP/run - 137.75 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3048 runs -   328.15 us/run -  60.13 GFLOP/run - 183.24 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2292 runs -   436.53 us/run -  60.13 GFLOP/run - 137.74 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    2274 runs -   440.12 us/run -  60.13 GFLOP/run - 136.62 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  2912 runs -   343.60 us/run -  60.13 GFLOP/run - 175.00 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    2734 runs -   365.82 us/run -  60.13 GFLOP/run - 164.37 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3680 runs -   271.79 us/run -  60.13 GFLOP/run - 221.24 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2798 runs -   357.54 us/run -  60.13 GFLOP/run - 168.18 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    2914 runs -   343.18 us/run -  60.13 GFLOP/run - 175.21 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   2966 runs -   337.16 us/run -  60.13 GFLOP/run - 178.34 TFLOPS

MMA disabled:
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              4918 runs -   203.35 us/run -  60.13 GFLOP/run - 295.69 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              7346 runs -   136.16 us/run -  60.13 GFLOP/run - 441.61 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     8002 runs -   125.00 us/run -  60.13 GFLOP/run - 481.04 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3684 runs -   271.47 us/run -  60.13 GFLOP/run - 221.50 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3690 runs -   271.10 us/run -  60.13 GFLOP/run - 221.80 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3730 runs -   268.12 us/run -  60.13 GFLOP/run - 224.26 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3744 runs -   267.14 us/run -  60.13 GFLOP/run - 225.08 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     5148 runs -   194.25 us/run -  60.13 GFLOP/run - 309.54 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3690 runs -   271.08 us/run -  60.13 GFLOP/run - 221.82 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3692 runs -   270.99 us/run -  60.13 GFLOP/run - 221.89 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3680 runs -   271.75 us/run -  60.13 GFLOP/run - 221.27 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3686 runs -   271.42 us/run -  60.13 GFLOP/run - 221.53 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     3696 runs -   270.70 us/run -  60.13 GFLOP/run - 222.13 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3682 runs -   271.69 us/run -  60.13 GFLOP/run - 221.32 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   3674 runs -   272.23 us/run -  60.13 GFLOP/run - 220.88 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3676 runs -   272.05 us/run -  60.13 GFLOP/run - 221.02 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                  3678 runs -   271.96 us/run -  60.13 GFLOP/run - 221.10 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3682 runs -   271.62 us/run -  60.13 GFLOP/run - 221.37 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3678 runs -   271.96 us/run -  60.13 GFLOP/run - 221.10 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   3676 runs -   272.06 us/run -  60.13 GFLOP/run - 221.01 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    3680 runs -   271.87 us/run -  60.13 GFLOP/run - 221.17 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   3686 runs -   271.34 us/run -  60.13 GFLOP/run - 221.60 TFLOPS

@IMbackK
Copy link
Collaborator

IMbackK commented Jul 24, 2025

@JohannesGaessler @IMbackK
* I have requested access to a Vega20 system, but it might get delayed. @0cc4m, would appreciate it if you can run this test on your setup. Also, gfx906 support is discontinued in rocm 6.4.0. Which rocm version are you using?

Since CDNA is mostly just GCN renamed, with MFMA added and the gfx pipeline removed, a pretty good facsimile of gfx906 can be constructed by forcing a CDNA device to take the dp4a path. We can add a compile-time option similar to GGML_CUDA_FORCE_MMQ to make this happen, so you can more easily check whether your changes negatively affect GCN.
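
A minimal sketch of what such a compile-time switch could look like (the GGML_HIP_FORCE_DP4A name and the helper function are assumptions for illustration, not part of this PR):

static bool use_mfma_path(const bool device_has_mfma) {
#ifdef GGML_HIP_FORCE_DP4A // hypothetical option, analogous to GGML_CUDA_FORCE_MMQ
    (void) device_has_mfma;
    return false;            // always fall back to the dp4a (GCN-style) kernels
#else
    return device_has_mfma;  // use the matrix-core path when MFMA is available
#endif // GGML_HIP_FORCE_DP4A
}

Building with such a flag defined would make a CDNA card behave like gfx906 as far as MMQ kernel selection is concerned.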

Generally, in the llama.cpp community, distro packages such as those provided by Arch Linux are the most popular. These provide a greatly superior experience to AMD's official install method, still support gfx906 and older, and will likely continue to do so for as long as possible.

* **_More importantly_**, I tested the main branch on a H100, With the default code path (MMQ with mma instr) and with `ggml_cuda_should_use_mmq` returning false.

The code is mainly targeted at consumer devices. I don't think anyone has ever tuned llama.cpp/ggml for the H100.
It also makes sense that cuBLAS is more likely to win on the H100 than on consumer devices, as presumably NV spends more time tuning its BLAS libraries for these devices.

I'd also like to note that even with hipBLASLt, the FP32 result on MI300 seems way too low and is something someone from the BLAS libraries team should investigate. I get that, since CDNA3 has 304 CUs, it can be hard to fill all the CUs due to GCN/CDNA's 4x SIMD16 nature requiring a lot of threads, but MI300 barely outperforming MI100 despite having over twice the CUs and a large clock speed advantage seems wrong.

Even for the 16-bit-wide datatypes, I have to say I find it a bit disappointing that MI300 manages to achieve barely 15% of its roofline here while MI100 is closer to 50%, but this is likely due to the problem size being too small.

If your performance assertions with regard to this PR prove true and the gfx906 regression has been addressed, I will be happy to merge this PR.

@0cc4m
Copy link
Collaborator

0cc4m commented Jul 24, 2025

I have requested access to a Vega20 system, but it might get delayed. @0cc4m, would appreciate it if you can run this test on your setup. Also, gfx906 support is discontinued in rocm 6.4.0. Which rocm version are you using?

I'm using ROCm 6.2.4 currently, on Ubuntu Server 24.04. I think it is not getting updated automatically.

Here are updated Vega20 results:

Device 0: AMD Radeon (TM) Pro VII, gfx906:sramecc+:xnack- (0x906), VMM: no, Wave Size: 64

Master:

| model         |     size | params | backend | ngl | n_ubatch |   test |           t/s |
| ------------- | -------: | -----: | ------- | --: | -------: | ------ | ------------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm    |  99 |      256 | pp4096 | 558.52 ± 1.02 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm    |  99 |      512 | pp4096 | 636.51 ± 1.42 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm    |  99 |     1024 | pp4096 | 591.55 ± 0.94 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm    |  99 |     2048 | pp4096 | 479.18 ± 0.15 |

PR:

| model         |     size | params | backend | ngl | n_ubatch |   test |           t/s |
| ------------- | -------: | -----: | ------- | --: | -------: | ------ | ------------: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm    |  99 |      256 | pp4096 | 535.23 ± 1.56 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm    |  99 |      512 | pp4096 | 611.93 ± 0.84 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm    |  99 |     1024 | pp4096 | 572.52 ± 1.39 |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | ROCm    |  99 |     2048 | pp4096 | 465.19 ± 0.56 |
Master:
  MUL_MAT(type_a=f16,type_b=f32,m=16416,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0):                 1488 runs -  1255.63 us/run - 134.48 MFLOP/run - 107.10 GFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=16416,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1):                 5208 runs -   213.64 us/run - 134.48 MFLOP/run - 629.47 GFLOPS
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               118 runs -  8517.61 us/run -  60.13 GFLOP/run -   7.06 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               150 runs -  6710.00 us/run -  60.13 GFLOP/run -   8.96 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                       44 runs - 22873.86 us/run -  60.13 GFLOP/run -   2.63 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      254 runs -  3968.00 us/run -  60.13 GFLOP/run -  15.15 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      242 runs -  4157.12 us/run -  60.13 GFLOP/run -  14.46 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      126 runs -  8032.88 us/run -  60.13 GFLOP/run -   7.49 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      118 runs -  8619.28 us/run -  60.13 GFLOP/run -   6.98 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                       90 runs - 11347.61 us/run -  60.13 GFLOP/run -   5.30 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                       32 runs - 32505.75 us/run -  60.13 GFLOP/run -   1.85 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      126 runs -  8037.80 us/run -  60.13 GFLOP/run -   7.48 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      110 runs -  9248.24 us/run -  60.13 GFLOP/run -   6.50 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                       96 runs - 10555.08 us/run -  60.13 GFLOP/run -   5.70 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                       96 runs - 10558.56 us/run -  60.13 GFLOP/run -   5.69 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   150 runs -  6668.18 us/run -  60.13 GFLOP/run -   9.02 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    156 runs -  6418.13 us/run -  60.13 GFLOP/run -   9.37 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     150 runs -  6719.71 us/run -  60.13 GFLOP/run -   8.95 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   152 runs -  6652.37 us/run -  60.13 GFLOP/run -   9.04 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     144 runs -  7038.55 us/run -  60.13 GFLOP/run -   8.54 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     138 runs -  7322.35 us/run -  60.13 GFLOP/run -   8.21 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    118 runs -  8577.08 us/run -  60.13 GFLOP/run -   7.01 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     156 runs -  6437.51 us/run -  60.13 GFLOP/run -   9.34 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     98 runs - 10413.84 us/run -  60.13 GFLOP/run -   5.77 TFLOPS

PR:
  MUL_MAT(type_a=f16,type_b=f32,m=16416,n=1,k=128,bs=[8,1],nr=[4,1],per=[0,2,1,3],v=0):                 1488 runs -  1255.64 us/run - 134.48 MFLOP/run - 107.10 GFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=128,n=1,k=16416,bs=[8,1],nr=[4,1],per=[0,1,2,3],v=1):                 5208 runs -   213.57 us/run - 134.48 MFLOP/run - 629.68 GFLOPS
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               118 runs -  8530.25 us/run -  60.13 GFLOP/run -   7.05 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               150 runs -  6679.61 us/run -  60.13 GFLOP/run -   9.00 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                       44 runs - 22823.09 us/run -  60.13 GFLOP/run -   2.63 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      232 runs -  4323.67 us/run -  60.13 GFLOP/run -  13.91 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      236 runs -  4253.58 us/run -  60.13 GFLOP/run -  14.14 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      164 runs -  6137.27 us/run -  60.13 GFLOP/run -   9.80 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      162 runs -  6204.90 us/run -  60.13 GFLOP/run -   9.69 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      110 runs -  9100.52 us/run -  60.13 GFLOP/run -   6.61 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                       46 runs - 22463.74 us/run -  60.13 GFLOP/run -   2.68 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      126 runs -  7962.72 us/run -  60.13 GFLOP/run -   7.55 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      224 runs -  4466.87 us/run -  60.13 GFLOP/run -  13.46 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      142 runs -  7130.94 us/run -  60.13 GFLOP/run -   8.43 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      158 runs -  6361.59 us/run -  60.13 GFLOP/run -   9.45 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   162 runs -  6220.20 us/run -  60.13 GFLOP/run -   9.67 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    150 runs -  6732.95 us/run -  60.13 GFLOP/run -   8.93 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     142 runs -  7060.12 us/run -  60.13 GFLOP/run -   8.52 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   162 runs -  6203.98 us/run -  60.13 GFLOP/run -   9.69 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     140 runs -  7143.69 us/run -  60.13 GFLOP/run -   8.42 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     138 runs -  7334.96 us/run -  60.13 GFLOP/run -   8.20 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    112 runs -  9032.46 us/run -  60.13 GFLOP/run -   6.66 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     156 runs -  6448.32 us/run -  60.13 GFLOP/run -   9.32 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    114 runs -  8820.08 us/run -  60.13 GFLOP/run -   6.82 TFLOPS

The regression is still there, but greatly reduced. I'd prefer if there was no regression at all, but I understand that might not be feasible if you can't think of another thing to try.

@IMbackK
Copy link
Collaborator

IMbackK commented Jul 24, 2025

I think this degree of slowdown is acceptable for the increased perf in other formats.

@IMbackK
Copy link
Collaborator

IMbackK commented Jul 24, 2025

@deepsek Currently this doesn't compile on CI due to -Werror.

@deepsek
Copy link
Author

deepsek commented Jul 24, 2025

Fixed the unused-parameter -Werror failure.

The code is mainly targeted at consumer devices. I don't think anyone has ever tuned llama.cpp/ggml for the H100.
It also makes sense that cuBLAS is more likely to win on the H100 than on consumer devices, as presumably NV spends more time tuning its BLAS libraries for these devices.

I quickly ran the test on an RTX 2080 Ti. I see the same results, with the cuBLAS path outperforming the MMQ path, same as on the H100.
Again, while this is not a blocker for this PR, I'm looking to see whether there is a reason why a different design choice is made here between AMD and NV in terms of code paths, or if I'm just doing something wrong with my testing. Or maybe I'm just overthinking this and it slipped through the cracks during testing and needs to be fixed?

MMA Disabled (cublas path)
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               206 runs -  4890.33 us/run -  60.13 GFLOP/run -  12.30 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              1232 runs -   812.32 us/run -  60.13 GFLOP/run -  74.02 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      130 runs -  7757.88 us/run -  60.13 GFLOP/run -   7.75 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      926 runs -  1081.20 us/run -  60.13 GFLOP/run -  55.61 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      930 runs -  1075.81 us/run -  60.13 GFLOP/run -  55.89 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      868 runs -  1154.16 us/run -  60.13 GFLOP/run -  52.10 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      870 runs -  1150.04 us/run -  60.13 GFLOP/run -  52.28 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      898 runs -  1113.99 us/run -  60.13 GFLOP/run -  53.98 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      950 runs -  1053.73 us/run -  60.13 GFLOP/run -  57.06 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      932 runs -  1075.19 us/run -  60.13 GFLOP/run -  55.92 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      936 runs -  1069.79 us/run -  60.13 GFLOP/run -  56.21 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      938 runs -  1067.14 us/run -  60.13 GFLOP/run -  56.35 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      932 runs -  1074.55 us/run -  60.13 GFLOP/run -  55.96 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   660 runs -  1517.49 us/run -  60.13 GFLOP/run -  39.62 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    598 runs -  1674.09 us/run -  60.13 GFLOP/run -  35.92 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     650 runs -  1542.34 us/run -  60.13 GFLOP/run -  38.99 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   658 runs -  1520.42 us/run -  60.13 GFLOP/run -  39.55 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     780 runs -  1283.44 us/run -  60.13 GFLOP/run -  46.85 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     616 runs -  1626.38 us/run -  60.13 GFLOP/run -  36.97 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    758 runs -  1319.47 us/run -  60.13 GFLOP/run -  45.57 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     696 runs -  1439.55 us/run -  60.13 GFLOP/run -  41.77 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    784 runs -  1277.54 us/run -  60.13 GFLOP/run -  47.07 TFLOPS
  
MMA Enabled (Default)
  MUL_MAT(type_a=f32,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):               206 runs -  4860.04 us/run -  60.13 GFLOP/run -  12.37 TFLOPS
  MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):              1262 runs -   793.19 us/run -  60.13 GFLOP/run -  75.81 TFLOPS
  MUL_MAT(type_a=bf16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      134 runs -  7536.22 us/run -  60.13 GFLOP/run -   7.98 TFLOPS
  MUL_MAT(type_a=q4_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      840 runs -  1191.10 us/run -  60.13 GFLOP/run -  50.48 TFLOPS
  MUL_MAT(type_a=q4_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      760 runs -  1317.36 us/run -  60.13 GFLOP/run -  45.64 TFLOPS
  MUL_MAT(type_a=q5_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      774 runs -  1294.64 us/run -  60.13 GFLOP/run -  46.44 TFLOPS
  MUL_MAT(type_a=q5_1,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      718 runs -  1395.47 us/run -  60.13 GFLOP/run -  43.09 TFLOPS
  MUL_MAT(type_a=q8_0,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      702 runs -  1425.90 us/run -  60.13 GFLOP/run -  42.17 TFLOPS
  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      448 runs -  2237.44 us/run -  60.13 GFLOP/run -  26.87 TFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      652 runs -  1535.98 us/run -  60.13 GFLOP/run -  39.15 TFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      768 runs -  1303.45 us/run -  60.13 GFLOP/run -  46.13 TFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      722 runs -  1388.57 us/run -  60.13 GFLOP/run -  43.30 TFLOPS
  MUL_MAT(type_a=q6_K,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                      634 runs -  1577.74 us/run -  60.13 GFLOP/run -  38.11 TFLOPS
  MUL_MAT(type_a=iq2_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   622 runs -  1609.05 us/run -  60.13 GFLOP/run -  37.37 TFLOPS
  MUL_MAT(type_a=iq2_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    412 runs -  2435.66 us/run -  60.13 GFLOP/run -  24.69 TFLOPS
  MUL_MAT(type_a=iq2_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     490 runs -  2047.03 us/run -  60.13 GFLOP/run -  29.37 TFLOPS
  MUL_MAT(type_a=iq3_xxs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                   570 runs -  1756.19 us/run -  60.13 GFLOP/run -  34.24 TFLOPS
  MUL_MAT(type_a=iq1_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     490 runs -  2043.27 us/run -  60.13 GFLOP/run -  29.43 TFLOPS
  MUL_MAT(type_a=iq1_m,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     644 runs -  1554.53 us/run -  60.13 GFLOP/run -  38.68 TFLOPS
  MUL_MAT(type_a=iq4_nl,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    654 runs -  1532.06 us/run -  60.13 GFLOP/run -  39.25 TFLOPS
  MUL_MAT(type_a=iq3_s,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                     626 runs -  1597.93 us/run -  60.13 GFLOP/run -  37.63 TFLOPS
  MUL_MAT(type_a=iq4_xs,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0):                    648 runs -  1545.26 us/run -  60.13 GFLOP/run -  38.91 TFLOPS

@IMbackK
Copy link
Collaborator

IMbackK commented Jul 24, 2025

I quickly ran the test on an RTX 2080 Ti. I see the same results, with the cuBLAS path outperforming the MMQ path, same as on the H100. Again, while this is not a blocker for this PR, I'm looking to see whether there is a reason why a different design choice is made here between AMD and NV in terms of code paths, or if I'm just doing something wrong with my testing. Or maybe I'm just overthinking this and it slipped through the cracks during testing and needs to be fixed?

I only use AMD devices, so I can't comment too much on NV performance, but one reason why the MMQ path is chosen on NV devices while AMD devices mostly choose the BLAS path is simply that @JohannesGaessler wrote most of the MMQ path while working on and optimizing purely for NV devices, whereas I work purely on AMD devices and favored optimizing the BLAS path, as this was the path of least resistance.

If the BLAS path also performs better on NV devices, it should be favored there in the cases where this is true, but this needs more detailed examination, as possibly this is just a spike at nb=512 or something specific to Turing and Hopper, etc.

@slaren
Copy link
Member

slaren commented Jul 24, 2025

The lower memory usage of MMQ is also a significant factor on why it is the default, even if it is not always faster than cuBLAS. The size of the buffer used to store the dequantized weights can be significant on consumer GPUs.

@ggml-org ggml-org deleted a comment from deepsek Jul 24, 2025
@IMbackK
Copy link
Collaborator

IMbackK commented Jul 24, 2025

@deepsek Oops, sorry, I accidentally edited your post instead of quoting it; please repost.

From my side nothing further is missing; I'd just like to give it another spin to test for regressions, and I will approve after that.

@deepsek
Copy link
Author

deepsek commented Jul 24, 2025

No worries, haha. I was just saying that, based on all the comments so far, it looks like ggml_cuda_should_use_mmq needs a lot more changes in order to choose MMQ based on architecture and sizes across all relevant backends.

@JohannesGaessler

To clarify, the function should return true for small batch sizes and false for large batch sizes with the boundary chosen in such a way that maximizes performance.

@slaren

The lower memory usage of MMQ is also a significant factor on why it is the default, even if it is not always faster than cuBLAS. The size of the buffer used to store the dequantized weights can be significant on consumer GPUs.

Is there actual guidance as to when performance is preferred over memory usage? It looks like there are conflicting viewpoints. It would be great to have this information documented so that, when we contribute and add other architectures, there is a common design principle.
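
As an illustration of the batch-size boundary quoted above, here is a minimal sketch; the threshold value and signature are assumptions for illustration, not the actual ggml_cuda_should_use_mmq logic:

static bool should_use_mmq_sketch(const int64_t batch_size, const bool vram_is_tight) {
    const int64_t mmq_batch_threshold = 512; // assumed crossover point; in practice tuned per architecture
    if (vram_is_tight) {
        return true;                         // MMQ avoids the large dequantization buffer
    }
    return batch_size < mmq_batch_threshold; // MMQ tends to win at small batch sizes
}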

P.S.,

  • I do have to say, a closer look needs to be taken at the documentation for llama.cpp. When I was trying to add support for AMD in the MMQ kernels, there were so many implicit design choices made without explanation in comments/documentation. It makes it extremely hard to understand the code, and thereby hard to contribute more.

Comment on lines +191 to +200
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
} else {
    int64_t * xi = (int64_t *) t.x;
    const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
    xi[0] = xs[0];
}
Copy link
Collaborator


Suggested change
Removed:
    if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
    #pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
    } else {
        int64_t * xi = (int64_t *) t.x;
        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
    }

Suggested:
    if constexpr (I != 64 || J != 2) {
        int64_t * xi = (int64_t *) t.x;
        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
        return;
    }

I think this would be simpler.

Copy link
Author

@deepsek deepsek Jul 24, 2025


Do you mean without the preprocessor directives?
That would affect the NV code path when we call load_generic, though. I see some instances where load_generic is called:

static __device__ __forceinline__ void load_generic(...) {
        if constexpr (I != 64 || J != 2) {
            int64_t * xi = (int64_t *) t.x;
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
            xi[0] = xs[0];
            return;
        }

#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
}

Copy link
Collaborator


I basically meant to have the instructions for loading data as 64 bit encapsulated in an #ifdef AMD_MFMA_AVAILABLE ... #endif block and to use the generic implementation if the preconditions aren't met. But if this is going to be refactored anyway, it doesn't matter.
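
A rough sketch of the structure described above, mirroring the snippets earlier in this thread (the tile member names are taken from those snippets; this is not the PR's final code):

template <typename tile_t>
static __device__ __forceinline__ void load_tile_sketch(tile_t & t, const int * xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
    // 64-bit vectorized load used by the MFMA path
    int64_t * xi       = (int64_t *) t.x;
    const int64_t * xs = (const int64_t *) (xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
    xi[0] = xs[0];
#else
    // generic per-element fallback
#pragma unroll
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
#endif // defined(AMD_MFMA_AVAILABLE)
}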

@JohannesGaessler
Copy link
Collaborator

I don't understand what you're doing with tile_load, please explain it.

@slaren
Copy link
Member

slaren commented Jul 24, 2025

Is there actual guidance as to when performance is preferred over memory usage? Looks like there are conflicting viewpoints. Would be great to have this information documentation for when we contribute and add other architectures, there is a common design principle.

It's very hard to have a single design principle for every situation. Historically, some models would need as much as 6 GB of VRAM for the dequantization buffer. This would cause a lot of problems for people using consumer GPUs, who would not understand why they could not offload as many layers of the model to the GPU as they would expect. This was one of the reasons why MMQ was made the default, even though it was not faster than cuBLAS in every situation. Ultimately, it doesn't matter if cuBLAS is 10% faster if using it means that you need to keep a large portion of the model on a CPU that is 10 times slower. For a data center GPU where VRAM is not so limited, the calculus may be different.
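
A back-of-the-envelope illustration of that trade-off (all numbers assumed, not measurements from this PR): a 10% faster GEMM path does not help if its extra VRAM use forces even 10% of the layers onto a CPU that is 10x slower.

#include <cstdio>

int main() {
    const double gpu_tps = 100.0;          // assumed GPU throughput (arbitrary units)
    const double cpu_tps = gpu_tps / 10.0; // CPU assumed to be 10x slower
    const double all_gpu = gpu_tps;        // everything offloaded, using the slower (MMQ) kernels
    // 90% of the work on a 10% faster GPU path, 10% spilled to the CPU:
    const double mixed   = 1.0 / (0.9 / (gpu_tps * 1.1) + 0.1 / cpu_tps);
    std::printf("all-GPU: %.1f  mixed: %.1f\n", all_gpu, mixed); // prints ~100 vs ~55
    return 0;
}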

@JohannesGaessler
Copy link
Collaborator

I ran the following tests on my RTX 3090/4090:

for q in q4_0 q4_1 q5_0 q5_1 q8_0 q2_k q3_k_s q4_k_s q5_k_s q6_k iq1_s iq2_xxs iq2_xs iq2_s iq3_xxs iq3_xs iq3_s iq3_m iq4_nl iq4_xs; do echo $q; ./bench --model models/opt/llama_3-${q}.gguf -r 1 -fa 1 -n 0 -p 2048 -ub 16,32,64,128,256,512,1024,2048 -o sql|sqlite3 llama-bench.sqlite; sleep 10; done

With compare-llama-bench.py I then get the following changes in performance when unconditionally using cuBLAS:

GPU Model Microbatch size Test t/s master t/s cuda-cublas-test Speedup
RTX 3090 llama 8B IQ1_S - 1.5625 bpw 16 pp2048 1269.16 363.33 0.29
RTX 3090 llama 8B IQ1_S - 1.5625 bpw 32 pp2048 2123.61 700.72 0.33
RTX 3090 llama 8B IQ1_S - 1.5625 bpw 64 pp2048 3035.47 1342.33 0.44
RTX 3090 llama 8B IQ1_S - 1.5625 bpw 128 pp2048 3725.88 2319.45 0.62
RTX 3090 llama 8B IQ1_S - 1.5625 bpw 256 pp2048 4410.33 3462.68 0.79
RTX 3090 llama 8B IQ1_S - 1.5625 bpw 512 pp2048 4721.20 4110.23 0.87
RTX 3090 llama 8B IQ1_S - 1.5625 bpw 1024 pp2048 4838.83 4874.03 1.01
RTX 3090 llama 8B IQ1_S - 1.5625 bpw 2048 pp2048 4776.52 4950.20 1.04
RTX 3090 llama 8B IQ2_S - 2.5 bpw 16 pp2048 1177.24 341.22 0.29
RTX 3090 llama 8B IQ2_S - 2.5 bpw 32 pp2048 1844.47 658.82 0.36
RTX 3090 llama 8B IQ2_S - 2.5 bpw 64 pp2048 2756.34 1262.28 0.46
RTX 3090 llama 8B IQ2_S - 2.5 bpw 128 pp2048 3397.38 2188.85 0.64
RTX 3090 llama 8B IQ2_S - 2.5 bpw 256 pp2048 3974.27 3305.33 0.83
RTX 3090 llama 8B IQ2_S - 2.5 bpw 512 pp2048 4188.95 3972.21 0.95
RTX 3090 llama 8B IQ2_S - 2.5 bpw 1024 pp2048 4218.73 4702.10 1.11
RTX 3090 llama 8B IQ2_S - 2.5 bpw 2048 pp2048 4133.44 4725.80 1.14
RTX 3090 llama 8B IQ2_XS - 2.3125 bpw 16 pp2048 1200.91 340.67 0.28
RTX 3090 llama 8B IQ2_XS - 2.3125 bpw 32 pp2048 1857.67 658.56 0.35
RTX 3090 llama 8B IQ2_XS - 2.3125 bpw 64 pp2048 2730.49 1263.58 0.46
RTX 3090 llama 8B IQ2_XS - 2.3125 bpw 128 pp2048 3300.26 2197.07 0.67
RTX 3090 llama 8B IQ2_XS - 2.3125 bpw 256 pp2048 3861.49 3315.19 0.86
RTX 3090 llama 8B IQ2_XS - 2.3125 bpw 512 pp2048 4100.01 3998.03 0.98
RTX 3090 llama 8B IQ2_XS - 2.3125 bpw 1024 pp2048 4180.07 4801.08 1.15
RTX 3090 llama 8B IQ2_XS - 2.3125 bpw 2048 pp2048 4168.49 4919.88 1.18
RTX 3090 llama 8B IQ2_XXS - 2.0625 bpw 16 pp2048 1195.23 355.18 0.30
RTX 3090 llama 8B IQ2_XXS - 2.0625 bpw 32 pp2048 2004.35 685.69 0.34
RTX 3090 llama 8B IQ2_XXS - 2.0625 bpw 64 pp2048 3042.87 1314.73 0.43
RTX 3090 llama 8B IQ2_XXS - 2.0625 bpw 128 pp2048 3890.35 2278.40 0.59
RTX 3090 llama 8B IQ2_XXS - 2.0625 bpw 256 pp2048 4561.47 3413.11 0.75
RTX 3090 llama 8B IQ2_XXS - 2.0625 bpw 512 pp2048 4825.37 4082.01 0.85
RTX 3090 llama 8B IQ2_XXS - 2.0625 bpw 1024 pp2048 4971.60 4832.21 0.97
RTX 3090 llama 8B IQ2_XXS - 2.0625 bpw 2048 pp2048 4898.16 4951.49 1.01
RTX 3090 llama 8B IQ3_S - 3.4375 bpw 16 pp2048 1063.41 347.95 0.33
RTX 3090 llama 8B IQ3_S - 3.4375 bpw 32 pp2048 1821.60 672.38 0.37
RTX 3090 llama 8B IQ3_S - 3.4375 bpw 64 pp2048 2907.41 1290.48 0.44
RTX 3090 llama 8B IQ3_S - 3.4375 bpw 128 pp2048 3822.82 2231.33 0.58
RTX 3090 llama 8B IQ3_S - 3.4375 bpw 256 pp2048 4488.48 3343.12 0.74
RTX 3090 llama 8B IQ3_S - 3.4375 bpw 512 pp2048 4762.17 4002.22 0.84
RTX 3090 llama 8B IQ3_S - 3.4375 bpw 1024 pp2048 4813.77 4718.51 0.98
RTX 3090 llama 8B IQ3_S - 3.4375 bpw 2048 pp2048 4667.00 4726.04 1.01
RTX 3090 llama 8B IQ3_S mix - 3.66 bpw 16 pp2048 1081.40 348.93 0.32
RTX 3090 llama 8B IQ3_S mix - 3.66 bpw 32 pp2048 1855.74 674.47 0.36
RTX 3090 llama 8B IQ3_S mix - 3.66 bpw 64 pp2048 2930.98 1292.52 0.44
RTX 3090 llama 8B IQ3_S mix - 3.66 bpw 128 pp2048 3809.37 2237.84 0.59
RTX 3090 llama 8B IQ3_S mix - 3.66 bpw 256 pp2048 4470.69 3346.66 0.75
RTX 3090 llama 8B IQ3_S mix - 3.66 bpw 512 pp2048 4774.10 4002.66 0.84
RTX 3090 llama 8B IQ3_S mix - 3.66 bpw 1024 pp2048 4804.29 4734.53 0.99
RTX 3090 llama 8B IQ3_S mix - 3.66 bpw 2048 pp2048 4646.86 4727.90 1.02
RTX 3090 llama 8B IQ3_XS - 3.3 bpw 16 pp2048 1116.47 351.51 0.31
RTX 3090 llama 8B IQ3_XS - 3.3 bpw 32 pp2048 1836.45 679.27 0.37
RTX 3090 llama 8B IQ3_XS - 3.3 bpw 64 pp2048 2949.60 1302.48 0.44
RTX 3090 llama 8B IQ3_XS - 3.3 bpw 128 pp2048 3820.79 2255.19 0.59
RTX 3090 llama 8B IQ3_XS - 3.3 bpw 256 pp2048 4496.96 3367.13 0.75
RTX 3090 llama 8B IQ3_XS - 3.3 bpw 512 pp2048 4756.66 4025.54 0.85
RTX 3090 llama 8B IQ3_XS - 3.3 bpw 1024 pp2048 4804.88 4746.55 0.99
RTX 3090 llama 8B IQ3_XS - 3.3 bpw 2048 pp2048 4654.84 4728.24 1.02
RTX 3090 llama 8B IQ3_XXS - 3.0625 bpw 16 pp2048 1152.30 351.00 0.30
RTX 3090 llama 8B IQ3_XXS - 3.0625 bpw 32 pp2048 1860.74 678.43 0.36
RTX 3090 llama 8B IQ3_XXS - 3.0625 bpw 64 pp2048 2937.99 1300.81 0.44
RTX 3090 llama 8B IQ3_XXS - 3.0625 bpw 128 pp2048 3766.46 2256.70 0.60
RTX 3090 llama 8B IQ3_XXS - 3.0625 bpw 256 pp2048 4426.53 3381.20 0.76
RTX 3090 llama 8B IQ3_XXS - 3.0625 bpw 512 pp2048 4699.49 4031.36 0.86
RTX 3090 llama 8B IQ3_XXS - 3.0625 bpw 1024 pp2048 4739.41 4743.37 1.00
RTX 3090 llama 8B IQ3_XXS - 3.0625 bpw 2048 pp2048 4588.29 4740.71 1.03
RTX 3090 llama 8B IQ4_NL - 4.5 bpw 16 pp2048 1104.76 355.22 0.32
RTX 3090 llama 8B IQ4_NL - 4.5 bpw 32 pp2048 1855.29 686.04 0.37
RTX 3090 llama 8B IQ4_NL - 4.5 bpw 64 pp2048 2835.30 1315.09 0.46
RTX 3090 llama 8B IQ4_NL - 4.5 bpw 128 pp2048 3739.72 2281.96 0.61
RTX 3090 llama 8B IQ4_NL - 4.5 bpw 256 pp2048 4384.26 3418.87 0.78
RTX 3090 llama 8B IQ4_NL - 4.5 bpw 512 pp2048 4578.77 4067.92 0.89
RTX 3090 llama 8B IQ4_NL - 4.5 bpw 1024 pp2048 4657.31 4813.85 1.03
RTX 3090 llama 8B IQ4_NL - 4.5 bpw 2048 pp2048 4539.71 4862.84 1.07
RTX 3090 llama 8B IQ4_XS - 4.25 bpw 16 pp2048 1113.46 353.79 0.32
RTX 3090 llama 8B IQ4_XS - 4.25 bpw 32 pp2048 1881.63 682.58 0.36
RTX 3090 llama 8B IQ4_XS - 4.25 bpw 64 pp2048 2837.24 1311.79 0.46
RTX 3090 llama 8B IQ4_XS - 4.25 bpw 128 pp2048 3743.22 2277.19 0.61
RTX 3090 llama 8B IQ4_XS - 4.25 bpw 256 pp2048 4403.62 3416.36 0.78
RTX 3090 llama 8B IQ4_XS - 4.25 bpw 512 pp2048 4670.17 4081.44 0.87
RTX 3090 llama 8B IQ4_XS - 4.25 bpw 1024 pp2048 4792.84 4832.52 1.01
RTX 3090 llama 8B IQ4_XS - 4.25 bpw 2048 pp2048 4695.20 4897.09 1.04
RTX 3090 llama 8B Q2_K_M 16 pp2048 1243.34 348.43 0.28
RTX 3090 llama 8B Q2_K_M 32 pp2048 1905.93 674.33 0.35
RTX 3090 llama 8B Q2_K_M 64 pp2048 2602.53 1294.78 0.50
RTX 3090 llama 8B Q2_K_M 128 pp2048 2744.97 2244.53 0.82
RTX 3090 llama 8B Q2_K_M 256 pp2048 3360.81 3388.03 1.01
RTX 3090 llama 8B Q2_K_M 512 pp2048 3648.98 4063.21 1.11
RTX 3090 llama 8B Q2_K_M 1024 pp2048 3807.19 4851.76 1.27
RTX 3090 llama 8B Q2_K_M 2048 pp2048 3806.04 4956.14 1.30
RTX 3090 llama 8B Q3_K_S 16 pp2048 1204.59 318.23 0.26
RTX 3090 llama 8B Q3_K_S 32 pp2048 1950.33 616.45 0.32
RTX 3090 llama 8B Q3_K_S 64 pp2048 2857.42 1182.46 0.41
RTX 3090 llama 8B Q3_K_S 128 pp2048 3399.77 2067.40 0.61
RTX 3090 llama 8B Q3_K_S 256 pp2048 4007.47 3183.37 0.79
RTX 3090 llama 8B Q3_K_S 512 pp2048 4261.17 3922.93 0.92
RTX 3090 llama 8B Q3_K_S 1024 pp2048 4377.78 4765.92 1.09
RTX 3090 llama 8B Q3_K_S 2048 pp2048 4331.94 4937.37 1.14
RTX 3090 llama 8B Q4_0 16 pp2048 1240.66 362.31 0.29
RTX 3090 llama 8B Q4_0 32 pp2048 2048.14 699.37 0.34
RTX 3090 llama 8B Q4_0 64 pp2048 3139.64 1344.21 0.43
RTX 3090 llama 8B Q4_0 128 pp2048 4045.92 2340.01 0.58
RTX 3090 llama 8B Q4_0 256 pp2048 4810.84 3532.32 0.73
RTX 3090 llama 8B Q4_0 512 pp2048 5139.21 4187.53 0.81
RTX 3090 llama 8B Q4_0 1024 pp2048 5207.74 4957.56 0.95
RTX 3090 llama 8B Q4_0 2048 pp2048 5104.67 5078.88 0.99
RTX 3090 llama 8B Q4_1 16 pp2048 1235.90 358.78 0.29
RTX 3090 llama 8B Q4_1 32 pp2048 2160.04 693.92 0.32
RTX 3090 llama 8B Q4_1 64 pp2048 3030.63 1332.42 0.44
RTX 3090 llama 8B Q4_1 128 pp2048 3759.64 2321.48 0.62
RTX 3090 llama 8B Q4_1 256 pp2048 4503.89 3485.79 0.77
RTX 3090 llama 8B Q4_1 512 pp2048 4817.82 4127.34 0.86
RTX 3090 llama 8B Q4_1 1024 pp2048 4877.41 4917.45 1.01
RTX 3090 llama 8B Q4_1 2048 pp2048 4807.40 5014.72 1.04
RTX 3090 llama 8B Q4_K_S 16 pp2048 1239.95 354.17 0.29
RTX 3090 llama 8B Q4_K_S 32 pp2048 2172.58 685.07 0.32
RTX 3090 llama 8B Q4_K_S 64 pp2048 3106.63 1316.57 0.42
RTX 3090 llama 8B Q4_K_S 128 pp2048 3820.59 2286.16 0.60
RTX 3090 llama 8B Q4_K_S 256 pp2048 4509.28 3433.22 0.76
RTX 3090 llama 8B Q4_K_S 512 pp2048 4794.53 4082.77 0.85
RTX 3090 llama 8B Q4_K_S 1024 pp2048 4902.25 4880.85 1.00
RTX 3090 llama 8B Q4_K_S 2048 pp2048 4814.05 4977.22 1.03
RTX 3090 llama 8B Q5_0 16 pp2048 1081.15 296.57 0.27
RTX 3090 llama 8B Q5_0 32 pp2048 1823.62 575.44 0.32
RTX 3090 llama 8B Q5_0 64 pp2048 2901.53 1106.99 0.38
RTX 3090 llama 8B Q5_0 128 pp2048 3821.11 1947.81 0.51
RTX 3090 llama 8B Q5_0 256 pp2048 4520.28 3031.59 0.67
RTX 3090 llama 8B Q5_0 512 pp2048 4817.86 3791.02 0.79
RTX 3090 llama 8B Q5_0 1024 pp2048 4896.94 4643.59 0.95
RTX 3090 llama 8B Q5_0 2048 pp2048 4796.51 4843.25 1.01
RTX 3090 llama 8B Q5_1 16 pp2048 1139.31 297.59 0.26
RTX 3090 llama 8B Q5_1 32 pp2048 2021.32 576.44 0.29
RTX 3090 llama 8B Q5_1 64 pp2048 2841.82 1108.70 0.39
RTX 3090 llama 8B Q5_1 128 pp2048 3570.52 1944.53 0.54
RTX 3090 llama 8B Q5_1 256 pp2048 4245.90 3022.05 0.71
RTX 3090 llama 8B Q5_1 512 pp2048 4575.99 3794.22 0.83
RTX 3090 llama 8B Q5_1 1024 pp2048 4649.67 4615.96 0.99
RTX 3090 llama 8B Q5_1 2048 pp2048 4571.11 4823.91 1.06
RTX 3090 llama 8B Q5_K_S 16 pp2048 1190.53 348.42 0.29
RTX 3090 llama 8B Q5_K_S 32 pp2048 2041.08 673.12 0.33
RTX 3090 llama 8B Q5_K_S 64 pp2048 2953.62 1293.24 0.44
RTX 3090 llama 8B Q5_K_S 128 pp2048 3673.18 2237.51 0.61
RTX 3090 llama 8B Q5_K_S 256 pp2048 4359.61 3368.89 0.77
RTX 3090 llama 8B Q5_K_S 512 pp2048 4657.81 4037.42 0.87
RTX 3090 llama 8B Q5_K_S 1024 pp2048 4754.54 4840.55 1.02
RTX 3090 llama 8B Q5_K_S 2048 pp2048 4676.00 4948.82 1.06
RTX 3090 llama 8B Q6_K 16 pp2048 1028.59 349.42 0.34
RTX 3090 llama 8B Q6_K 32 pp2048 1827.49 675.05 0.37
RTX 3090 llama 8B Q6_K 64 pp2048 2687.06 1297.17 0.48
RTX 3090 llama 8B Q6_K 128 pp2048 3377.85 2250.68 0.67
RTX 3090 llama 8B Q6_K 256 pp2048 3973.83 3382.65 0.85
RTX 3090 llama 8B Q6_K 512 pp2048 4207.99 4054.54 0.96
RTX 3090 llama 8B Q6_K 1024 pp2048 4288.43 4827.50 1.13
RTX 3090 llama 8B Q6_K 2048 pp2048 4227.47 4900.35 1.16
RTX 3090 llama 8B Q8_0 16 pp2048 973.31 341.18 0.35
RTX 3090 llama 8B Q8_0 32 pp2048 1683.00 658.35 0.39
RTX 3090 llama 8B Q8_0 64 pp2048 2928.68 1264.50 0.43
RTX 3090 llama 8B Q8_0 128 pp2048 3919.66 2220.24 0.57
RTX 3090 llama 8B Q8_0 256 pp2048 4636.38 3376.11 0.73
RTX 3090 llama 8B Q8_0 512 pp2048 4881.97 4052.21 0.83
RTX 3090 llama 8B Q8_0 1024 pp2048 5005.15 4837.42 0.97
RTX 3090 llama 8B Q8_0 2048 pp2048 4940.33 4961.92 1.00
RTX 4090 llama 8B IQ1_S - 1.5625 bpw 16 pp2048 2258.70 513.36 0.23
RTX 4090 llama 8B IQ1_S - 1.5625 bpw 32 pp2048 3999.30 1022.46 0.26
RTX 4090 llama 8B IQ1_S - 1.5625 bpw 64 pp2048 6196.90 1919.15 0.31
RTX 4090 llama 8B IQ1_S - 1.5625 bpw 128 pp2048 8209.90 3750.54 0.46
RTX 4090 llama 8B IQ1_S - 1.5625 bpw 256 pp2048 10764.07 6074.51 0.56
RTX 4090 llama 8B IQ1_S - 1.5625 bpw 512 pp2048 12283.45 8088.07 0.66
RTX 4090 llama 8B IQ1_S - 1.5625 bpw 1024 pp2048 12146.24 9638.25 0.79
RTX 4090 llama 8B IQ1_S - 1.5625 bpw 2048 pp2048 11193.68 9873.72 0.88
RTX 4090 llama 8B IQ2_S - 2.5 bpw 16 pp2048 2048.38 505.74 0.25
RTX 4090 llama 8B IQ2_S - 2.5 bpw 32 pp2048 3393.96 1011.03 0.30
RTX 4090 llama 8B IQ2_S - 2.5 bpw 64 pp2048 5600.50 1893.46 0.34
RTX 4090 llama 8B IQ2_S - 2.5 bpw 128 pp2048 7574.04 3677.54 0.49
RTX 4090 llama 8B IQ2_S - 2.5 bpw 256 pp2048 9738.47 5940.02 0.61
RTX 4090 llama 8B IQ2_S - 2.5 bpw 512 pp2048 10639.67 7841.15 0.74
RTX 4090 llama 8B IQ2_S - 2.5 bpw 1024 pp2048 10290.99 9033.11 0.88
RTX 4090 llama 8B IQ2_S - 2.5 bpw 2048 pp2048 9090.10 8752.14 0.96
RTX 4090 llama 8B IQ2_XS - 2.3125 bpw 16 pp2048 2120.06 506.95 0.24
RTX 4090 llama 8B IQ2_XS - 2.3125 bpw 32 pp2048 3500.25 1011.94 0.29
RTX 4090 llama 8B IQ2_XS - 2.3125 bpw 64 pp2048 5601.59 1900.01 0.34
RTX 4090 llama 8B IQ2_XS - 2.3125 bpw 128 pp2048 7557.40 3692.89 0.49
RTX 4090 llama 8B IQ2_XS - 2.3125 bpw 256 pp2048 9618.61 6010.88 0.62
RTX 4090 llama 8B IQ2_XS - 2.3125 bpw 512 pp2048 10853.05 8030.11 0.74
RTX 4090 llama 8B IQ2_XS - 2.3125 bpw 1024 pp2048 10760.68 9576.75 0.89
RTX 4090 llama 8B IQ2_XS - 2.3125 bpw 2048 pp2048 10043.13 9863.42 0.98
RTX 4090 llama 8B IQ2_XXS - 2.0625 bpw 16 pp2048 2136.25 509.27 0.24
RTX 4090 llama 8B IQ2_XXS - 2.0625 bpw 32 pp2048 3774.52 1016.24 0.27
RTX 4090 llama 8B IQ2_XXS - 2.0625 bpw 64 pp2048 6198.03 1906.25 0.31
RTX 4090 llama 8B IQ2_XXS - 2.0625 bpw 128 pp2048 8569.45 3729.15 0.44
RTX 4090 llama 8B IQ2_XXS - 2.0625 bpw 256 pp2048 11166.74 6033.03 0.54
RTX 4090 llama 8B IQ2_XXS - 2.0625 bpw 512 pp2048 12687.06 8063.58 0.64
RTX 4090 llama 8B IQ2_XXS - 2.0625 bpw 1024 pp2048 12497.25 9592.90 0.77
RTX 4090 llama 8B IQ2_XXS - 2.0625 bpw 2048 pp2048 11493.85 9872.21 0.86
RTX 4090 llama 8B IQ3_S - 3.4375 bpw 16 pp2048 1650.19 492.99 0.30
RTX 4090 llama 8B IQ3_S - 3.4375 bpw 32 pp2048 2883.47 984.68 0.34
RTX 4090 llama 8B IQ3_S - 3.4375 bpw 64 pp2048 5466.87 1846.74 0.34
RTX 4090 llama 8B IQ3_S - 3.4375 bpw 128 pp2048 8093.26 3578.02 0.44
RTX 4090 llama 8B IQ3_S - 3.4375 bpw 256 pp2048 10642.97 5835.55 0.55
RTX 4090 llama 8B IQ3_S - 3.4375 bpw 512 pp2048 11861.45 7760.24 0.65
RTX 4090 llama 8B IQ3_S - 3.4375 bpw 1024 pp2048 11412.11 8971.85 0.79
RTX 4090 llama 8B IQ3_S - 3.4375 bpw 2048 pp2048 9988.09 8732.89 0.87
RTX 4090 llama 8B IQ3_S mix - 3.66 bpw 16 pp2048 1673.63 492.89 0.29
RTX 4090 llama 8B IQ3_S mix - 3.66 bpw 32 pp2048 2935.60 979.58 0.33
RTX 4090 llama 8B IQ3_S mix - 3.66 bpw 64 pp2048 5487.45 1845.35 0.34
RTX 4090 llama 8B IQ3_S mix - 3.66 bpw 128 pp2048 8073.59 3579.77 0.44
RTX 4090 llama 8B IQ3_S mix - 3.66 bpw 256 pp2048 10682.04 5830.44 0.55
RTX 4090 llama 8B IQ3_S mix - 3.66 bpw 512 pp2048 11855.61 7739.69 0.65
RTX 4090 llama 8B IQ3_S mix - 3.66 bpw 1024 pp2048 11414.45 8955.80 0.78
RTX 4090 llama 8B IQ3_S mix - 3.66 bpw 2048 pp2048 9961.58 8733.25 0.88
RTX 4090 llama 8B IQ3_XS - 3.3 bpw 16 pp2048 1810.46 495.51 0.27
RTX 4090 llama 8B IQ3_XS - 3.3 bpw 32 pp2048 3075.65 989.76 0.32
RTX 4090 llama 8B IQ3_XS - 3.3 bpw 64 pp2048 5593.56 1855.39 0.33
RTX 4090 llama 8B IQ3_XS - 3.3 bpw 128 pp2048 7908.67 3605.42 0.46
RTX 4090 llama 8B IQ3_XS - 3.3 bpw 256 pp2048 10535.29 5860.60 0.56
RTX 4090 llama 8B IQ3_XS - 3.3 bpw 512 pp2048 11930.76 7784.26 0.65
RTX 4090 llama 8B IQ3_XS - 3.3 bpw 1024 pp2048 11487.46 8979.04 0.78
RTX 4090 llama 8B IQ3_XS - 3.3 bpw 2048 pp2048 9977.57 8752.45 0.88
RTX 4090 llama 8B IQ3_XXS - 3.0625 bpw 16 pp2048 1942.41 493.14 0.25
RTX 4090 llama 8B IQ3_XXS - 3.0625 bpw 32 pp2048 3246.69 984.55 0.30
RTX 4090 llama 8B IQ3_XXS - 3.0625 bpw 64 pp2048 5676.98 1852.24 0.33
RTX 4090 llama 8B IQ3_XXS - 3.0625 bpw 128 pp2048 7776.55 3608.04 0.46
RTX 4090 llama 8B IQ3_XXS - 3.0625 bpw 256 pp2048 10475.05 5855.00 0.56
RTX 4090 llama 8B IQ3_XXS - 3.0625 bpw 512 pp2048 11779.42 7790.70 0.66
RTX 4090 llama 8B IQ3_XXS - 3.0625 bpw 1024 pp2048 11286.48 8991.37 0.80
RTX 4090 llama 8B IQ3_XXS - 3.0625 bpw 2048 pp2048 9882.69 8754.41 0.89
RTX 4090 llama 8B IQ4_NL - 4.5 bpw 16 pp2048 1845.89 484.18 0.26
RTX 4090 llama 8B IQ4_NL - 4.5 bpw 32 pp2048 3363.92 967.28 0.29
RTX 4090 llama 8B IQ4_NL - 4.5 bpw 64 pp2048 5514.80 1815.48 0.33
RTX 4090 llama 8B IQ4_NL - 4.5 bpw 128 pp2048 8083.99 3536.87 0.44
RTX 4090 llama 8B IQ4_NL - 4.5 bpw 256 pp2048 10629.63 5787.63 0.54
RTX 4090 llama 8B IQ4_NL - 4.5 bpw 512 pp2048 11827.68 7834.73 0.66
RTX 4090 llama 8B IQ4_NL - 4.5 bpw 1024 pp2048 11715.52 9300.43 0.79
RTX 4090 llama 8B IQ4_NL - 4.5 bpw 2048 pp2048 10658.68 9474.71 0.89
RTX 4090 llama 8B IQ4_XS - 4.25 bpw 16 pp2048 1895.00 487.59 0.26
RTX 4090 llama 8B IQ4_XS - 4.25 bpw 32 pp2048 3456.82 971.04 0.28
RTX 4090 llama 8B IQ4_XS - 4.25 bpw 64 pp2048 5699.30 1826.83 0.32
RTX 4090 llama 8B IQ4_XS - 4.25 bpw 128 pp2048 8193.01 3540.37 0.43
RTX 4090 llama 8B IQ4_XS - 4.25 bpw 256 pp2048 10743.01 5822.09 0.54
RTX 4090 llama 8B IQ4_XS - 4.25 bpw 512 pp2048 12062.46 7858.77 0.65
RTX 4090 llama 8B IQ4_XS - 4.25 bpw 1024 pp2048 12010.94 9391.68 0.78
RTX 4090 llama 8B IQ4_XS - 4.25 bpw 2048 pp2048 11016.93 9632.84 0.87
RTX 4090 llama 8B Q2_K_M 16 pp2048 2038.29 501.50 0.25
RTX 4090 llama 8B Q2_K_M 32 pp2048 3468.57 1003.92 0.29
RTX 4090 llama 8B Q2_K_M 64 pp2048 5390.14 1881.31 0.35
RTX 4090 llama 8B Q2_K_M 128 pp2048 5553.53 3681.68 0.66
RTX 4090 llama 8B Q2_K_M 256 pp2048 7681.81 5948.97 0.77
RTX 4090 llama 8B Q2_K_M 512 pp2048 9513.04 8012.94 0.84
RTX 4090 llama 8B Q2_K_M 1024 pp2048 9820.82 9561.64 0.97
RTX 4090 llama 8B Q2_K_M 2048 pp2048 9281.51 9865.96 1.06
RTX 4090 llama 8B Q3_K_S 16 pp2048 1855.05 489.78 0.26
RTX 4090 llama 8B Q3_K_S 32 pp2048 3349.46 973.59 0.29
RTX 4090 llama 8B Q3_K_S 64 pp2048 5620.07 1825.17 0.32
RTX 4090 llama 8B Q3_K_S 128 pp2048 7756.72 3562.66 0.46
RTX 4090 llama 8B Q3_K_S 256 pp2048 9924.35 5828.74 0.59
RTX 4090 llama 8B Q3_K_S 512 pp2048 11032.31 7961.34 0.72
RTX 4090 llama 8B Q3_K_S 1024 pp2048 11147.46 9633.22 0.86
RTX 4090 llama 8B Q3_K_S 2048 pp2048 10458.19 10069.99 0.96
RTX 4090 llama 8B Q4_0 16 pp2048 1823.73 483.19 0.26
RTX 4090 llama 8B Q4_0 32 pp2048 3326.25 959.73 0.29
RTX 4090 llama 8B Q4_0 64 pp2048 5823.55 1805.26 0.31
RTX 4090 llama 8B Q4_0 128 pp2048 8776.15 3512.86 0.40
RTX 4090 llama 8B Q4_0 256 pp2048 11726.54 5807.80 0.50
RTX 4090 llama 8B Q4_0 512 pp2048 13241.36 8003.77 0.60
RTX 4090 llama 8B Q4_0 1024 pp2048 13302.06 9691.22 0.73
RTX 4090 llama 8B Q4_0 2048 pp2048 12395.89 10151.28 0.82
RTX 4090 llama 8B Q4_1 16 pp2048 1735.86 477.01 0.27
RTX 4090 llama 8B Q4_1 32 pp2048 3316.52 947.21 0.29
RTX 4090 llama 8B Q4_1 64 pp2048 5590.41 1786.01 0.32
RTX 4090 llama 8B Q4_1 128 pp2048 8266.54 3477.50 0.42
RTX 4090 llama 8B Q4_1 256 pp2048 11008.21 5749.13 0.52
RTX 4090 llama 8B Q4_1 512 pp2048 12390.34 7874.50 0.64
RTX 4090 llama 8B Q4_1 1024 pp2048 12545.96 9542.28 0.76
RTX 4090 llama 8B Q4_1 2048 pp2048 11768.92 10010.06 0.85
RTX 4090 llama 8B Q4_K_S 16 pp2048 1811.12 482.90 0.27
RTX 4090 llama 8B Q4_K_S 32 pp2048 3518.97 958.08 0.27
RTX 4090 llama 8B Q4_K_S 64 pp2048 5942.06 1803.60 0.30
RTX 4090 llama 8B Q4_K_S 128 pp2048 8363.02 3519.98 0.42
RTX 4090 llama 8B Q4_K_S 256 pp2048 11060.81 5800.22 0.52
RTX 4090 llama 8B Q4_K_S 512 pp2048 12470.52 7902.08 0.63
RTX 4090 llama 8B Q4_K_S 1024 pp2048 12600.90 9540.61 0.76
RTX 4090 llama 8B Q4_K_S 2048 pp2048 11763.92 10007.35 0.85
RTX 4090 llama 8B Q5_0 16 pp2048 1594.98 458.26 0.29
RTX 4090 llama 8B Q5_0 32 pp2048 2970.63 905.00 0.30
RTX 4090 llama 8B Q5_0 64 pp2048 5211.66 1698.45 0.33
RTX 4090 llama 8B Q5_0 128 pp2048 8122.97 3296.47 0.41
RTX 4090 llama 8B Q5_0 256 pp2048 10861.65 5482.80 0.50
RTX 4090 llama 8B Q5_0 512 pp2048 12236.20 7593.84 0.62
RTX 4090 llama 8B Q5_0 1024 pp2048 12297.34 9241.81 0.75
RTX 4090 llama 8B Q5_0 2048 pp2048 11317.52 9673.95 0.85
RTX 4090 llama 8B Q5_1 16 pp2048 1578.65 456.87 0.29
RTX 4090 llama 8B Q5_1 32 pp2048 3046.68 906.34 0.30
RTX 4090 llama 8B Q5_1 64 pp2048 5176.37 1694.54 0.33
RTX 4090 llama 8B Q5_1 128 pp2048 7753.29 3286.86 0.42
RTX 4090 llama 8B Q5_1 256 pp2048 10272.38 5461.69 0.53
RTX 4090 llama 8B Q5_1 512 pp2048 11674.35 7539.37 0.65
RTX 4090 llama 8B Q5_1 1024 pp2048 11823.78 9163.92 0.78
RTX 4090 llama 8B Q5_1 2048 pp2048 10994.51 9578.97 0.87
RTX 4090 llama 8B Q5_K_S 16 pp2048 1667.07 476.82 0.29
RTX 4090 llama 8B Q5_K_S 32 pp2048 3169.82 955.20 0.30
RTX 4090 llama 8B Q5_K_S 64 pp2048 5532.52 1793.41 0.32
RTX 4090 llama 8B Q5_K_S 128 pp2048 8040.64 3483.21 0.43
RTX 4090 llama 8B Q5_K_S 256 pp2048 10581.57 5724.62 0.54
RTX 4090 llama 8B Q5_K_S 512 pp2048 11963.16 7829.51 0.65
RTX 4090 llama 8B Q5_K_S 1024 pp2048 12252.54 9446.07 0.77
RTX 4090 llama 8B Q5_K_S 2048 pp2048 11462.50 9948.84 0.87
RTX 4090 llama 8B Q6_K 16 pp2048 1417.83 467.42 0.33
RTX 4090 llama 8B Q6_K 32 pp2048 2779.86 934.50 0.34
RTX 4090 llama 8B Q6_K 64 pp2048 4787.15 1758.05 0.37
RTX 4090 llama 8B Q6_K 128 pp2048 7103.90 3398.73 0.48
RTX 4090 llama 8B Q6_K 256 pp2048 9457.93 5629.29 0.60
RTX 4090 llama 8B Q6_K 512 pp2048 10727.17 7685.62 0.72
RTX 4090 llama 8B Q6_K 1024 pp2048 10923.56 9204.79 0.84
RTX 4090 llama 8B Q6_K 2048 pp2048 10093.00 9528.85 0.94
RTX 4090 llama 8B Q8_0 16 pp2048 1306.56 457.63 0.35
RTX 4090 llama 8B Q8_0 32 pp2048 2408.03 914.73 0.38
RTX 4090 llama 8B Q8_0 64 pp2048 4354.06 1711.24 0.39
RTX 4090 llama 8B Q8_0 128 pp2048 7153.27 3334.15 0.47
RTX 4090 llama 8B Q8_0 256 pp2048 10286.79 5520.72 0.54
RTX 4090 llama 8B Q8_0 512 pp2048 12371.58 7610.56 0.62
RTX 4090 llama 8B Q8_0 1024 pp2048 12631.21 9248.36 0.73
RTX 4090 llama 8B Q8_0 2048 pp2048 11870.43 9797.81 0.83

On my RTX 4090, MMQ is faster for all quantization formats except q2_K at batch sizes <= 2048; for batch sizes <= 1024 MMQ is always faster. On my RTX 3090, MMQ is faster for all quantization formats except q2_K at batch sizes <= 512.

@JohannesGaessler
Copy link
Collaborator

More importantly, I tested the main branch on an H100, with the default code path (MMQ with mma instructions) and with ggml_cuda_should_use_mmq returning false.

The MMQ code is designed around tensor core instructions that were introduced with Ampere. Hopper can also make use of these instructions, but it additionally has tensor core instructions that to my knowledge exist on no earlier or later generation. Presumably the cuBLAS code makes use of those Hopper-specific instructions; the MMQ code definitely does not. I have never tested the code on or tuned it for Hopper.

I quickly ran the test on an RTX 2080 Ti. I see the same results, with the cuBLAS path outperforming the MMQ path, same as on the H100.

The tensor core instructions that MMQ was written around are not available on Turing. However, there are similar tensor core instructions which work on tiles that are exactly half as large, so the same numerical result can be obtained by executing two tensor core instructions instead. I do not own any Turing hardware and have not tuned the code specifically for this architecture.
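As a rough scalar illustration of why splitting a tile preserves the result (a minimal sketch with made-up sizes, assuming the split happens along the K dimension; this is not the actual tensor core code):

    #include <cstdint>
    #include <cstdio>

    // Scalar stand-in for a single output element of a tensor core tile:
    // one K=32 multiply-accumulate and two K=16 ones accumulate to the same value.
    int main() {
        int8_t a[32], b[32];
        for (int k = 0; k < 32; ++k) { a[k] = int8_t(k - 16); b[k] = int8_t(3*k - 40); }

        int32_t acc_full = 0;                  // hypothetical full-size instruction
        for (int k = 0; k < 32; ++k) acc_full += a[k] * b[k];

        int32_t acc_half = 0;                  // two half-size instructions
        for (int k = 0; k < 16; ++k) acc_half += a[k] * b[k];
        for (int k = 16; k < 32; ++k) acc_half += a[k] * b[k];

        printf("full: %d, split: %d\n", acc_full, acc_half); // prints the same value
        return 0;
    }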

I'm looking to see if there is a reason why a different design choice is made here between AMD & NVIDIA in terms of code paths, or if I'm just doing something wrong with my testing. Or maybe I'm overthinking this and it slipped through the cracks during testing and needs to be fixed?

Is there actual guidance as to when performance is preferred over memory usage? It looks like there are conflicting viewpoints. It would be great to have this information documented so that when we contribute and add other architectures, there is a common design principle.

No one wrote down the exact criteria for when to use cuBLAS vs. MMQ. My general opinion is that if you use anything below q8_0 you are already trading quality for lower memory usage. The hardware that I have been focusing on is the RTX 3090/4090 because those are in my opinion the best "cheap" GPUs with 24 GB VRAM; on those MMQ performs well enough that I think using it unconditionally is the correct choice. On an RTX 2080 Ti with only 11 GB VRAM it's even more important to keep memory usage low, so I decided to enable MMQ unconditionally as well under the assumption that the tradeoffs would be similar to Ampere/Ada Lovelace.

The logic on master for AMD was written by @IMbackK . I don't remember whether he posted the performance numbers upon which his decisions were based in the relevant PRs. In any case I don't have a good overview of the AMD hardware stack and decided to go with his judgement.

For CDNA3 in particular my understanding is that all GPUs using that architecture have at least 128 GB of memory. For that particular hardware I therefore think that the correct choice is to simply maximize speed.
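For illustration, a hypothetical sketch of the kind of per-architecture trade-off described above; the enum, function name, and structure are invented for this example and are not the actual ggml_cuda_should_use_mmq() logic:

    // Hypothetical sketch only: weighs MMQ's lower memory footprint against
    // peak throughput per architecture. Not the real llama.cpp heuristic.
    enum class GpuArch { AMPERE, ADA_LOVELACE, TURING, CDNA3, OTHER };

    bool prefer_mmq_example(GpuArch arch) {
        switch (arch) {
            // Consumer cards (e.g. RTX 2080 Ti / 3090 / 4090): the lower memory
            // footprint of MMQ is considered worth it, so use it unconditionally.
            case GpuArch::AMPERE:
            case GpuArch::ADA_LOVELACE:
            case GpuArch::TURING:
                return true;
            // CDNA3 parts ship with at least 128 GB of memory, so the choice is
            // simply whichever path is fastest.
            case GpuArch::CDNA3:
                return true; // assuming the MFMA-based MMQ path is the faster one
            default:
                return false; // fall back to cuBLAS/hipBLAS elsewhere (assumption)
        }
    }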

I do have to say, a closer look needs to be taken at documentation for llama.cpp. When I was trying to add AMD support to the MMQ kernels, there were so many implicit design choices made without explanation in comments or documentation. That makes the code extremely hard to understand and thereby hard to contribute to.

I have documented the design decisions that seemed unintuitive to me at the time. However, I think that it is generally difficult to judge which parts of your own work need to be documented in order to make it understandable to third parties. Since you have already gone through the trouble of understanding the code as-is I would be grateful if you could add comments in those places where they would have helped your understanding.

@deepsek
Copy link
Author

deepsek commented Jul 24, 2025

Sounds good. Great to hear the reasoning behind these choices.
My primary goal is to enable the full feature set. I'll leave the code path, design choices, and subsequent modifications to the ggml_cuda_should_use_mmq code to the community.
I am here if you need to run perf/bench etc. on the hardware that is openly available (primarily CDNA3, my focus at the moment).

@deepsek
Copy link
Author

deepsek commented Jul 24, 2025

I don't understand what you're doing with tile_load, please explain it.

The purpose of this is to leverage the 16x16x32 MFMA instruction (16x8 tile) over the 32x32x16 one (32x4 tile). This gives some performance increase and also fixes nwarps to 8 for all quants.

I use this specific 'placeholder' tile to load the same <16x4> tile twice as a <16x8> tile, since the matrix cores on the current architecture don't support a 16x16x16 instruction. With this compute, the result is basically double the value needed, hence in these cases the scale value is halved in the code.

load_tile is set to <64,2> only because the tile type computes the number of elements per thread and I need it to stay at 2 (== 64*2/64), and <16,8> and <32,4> are already taken. Hence the tag marking it as a special tile used to achieve this.
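A minimal scalar sketch of the duplication trick described above (illustrative only, with made-up values and block sizes; the real code operates on MFMA tiles rather than plain arrays):

    #include <cstdio>

    // The quantized block provides one scale per 4 values (K = 4), but the
    // instruction consumes a K = 8 fragment. Feeding the same 4 values twice
    // doubles the accumulated dot product, so the scale is multiplied by 0.5f.
    int main() {
        const int a[4] = {1, -2, 3, 4};   // made-up quantized values
        const int b[4] = {5, 6, -7, 8};
        const float scale = 0.25f;        // made-up per-block scale

        int acc_k4 = 0;                   // what a K=4 instruction would produce
        for (int k = 0; k < 4; ++k) acc_k4 += a[k] * b[k];

        int acc_k8 = 0;                   // K=8 instruction fed the duplicated fragment
        for (int rep = 0; rep < 2; ++rep)
            for (int k = 0; k < 4; ++k) acc_k8 += a[k] * b[k];

        printf("%f %f\n", scale * acc_k4, 0.5f * scale * acc_k8); // identical results
        return 0;
    }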

@JohannesGaessler
Copy link
Collaborator

If I understand you correctly, you are solving the problem where some of the quantization types have one scale per 4 32-bit integers, so the 16x8 tiles are too long and don't yield the correct results. Have you considered loading the data as 16x8 tiles as you do in e.g. vec_dot_q8_0_q8_1_mma and then simply running the matrix multiply accumulate twice with half of the values masked? To me that seems like a simpler and more intuitive solution. In terms of implementation detail we could maybe extend the interface in mma.cuh with functions like

    static __device__ __forceinline__ void mma_low(
            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {

and a corresponding function mma_high that does the matrix multiply accumulate for only one half of the A/B tiles. It's not necessary to write an NVIDIA implementation for this in this PR, but it wouldn't be difficult.
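For clarity, a scalar reference of the intended mma_low/mma_high semantics (a sketch under the assumption that "half" means the lower/upper half of the K dimension; plain arrays stand in for the tile types, and no real intrinsics are used):

    // Sketch: D += A_low * B_low^T, where "low" is K elements 0..7 of a
    // 16x8 half2 tile (16 half values per row), and mma_high covers K 8..15.
    static void mma_low_ref(float D[16][16], const float A[16][16], const float B[16][16]) {
        for (int i = 0; i < 16; ++i)
            for (int j = 0; j < 16; ++j)
                for (int k = 0; k < 8; ++k)      // only the lower half of K
                    D[i][j] += A[i][k] * B[j][k];
    }

    static void mma_high_ref(float D[16][16], const float A[16][16], const float B[16][16]) {
        for (int i = 0; i < 16; ++i)
            for (int j = 0; j < 16; ++j)
                for (int k = 8; k < 16; ++k)     // only the upper half of K
                    D[i][j] += A[i][k] * B[j][k];
    }

Calling both on the same tiles reproduces the full 16x16x16 product, which is what would allow a different scale to be applied to each half.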

@deepsek
Copy link
Author

deepsek commented Jul 24, 2025

Yeah, I was initially going down that same road when I started removing the larger (32x4) tiles. I was going to use a bitmask to simply clear unused threads, but that would require quite a few more changes to the code and to how the loops are written right now. In the interest of saving time and preventing this PR from dangling too long, I just chose the other route to achieve the same result with minimal changes.

We can revisit this as part of a larger redesign at a later date if you'd like.

@JohannesGaessler
Copy link
Collaborator

I don't think the current solution is good but I'm willing to approve the PR as-is as long as you promise to refactor the code in the near future.

@JohannesGaessler JohannesGaessler requested a review from IMbackK July 24, 2025 20:27
@xbezdick
Copy link

Needed a manual rebase, but works great; on multiple AMD Radeon RX 7900 XT cards it resolves the "Page not present or supervisor privilege" GPU crash on K-shift.
