
convert : support non-mxfp4 HF model #15153


Merged
merged 3 commits into ggml-org:master from ngxson:xsn/gptoss_non_mxfp4_conversion on Aug 7, 2025

Conversation

ngxson
Collaborator

@ngxson ngxson commented Aug 7, 2025

The goal is to fix conversion for https://huggingface.co/huihui-ai/Huihui-gpt-oss-20b-BF16-abliterated

This partially reverts e2c1beb

Closes #15146

@github-actions github-actions bot added the python python script changes label Aug 7, 2025
@gabriellarson
Contributor

Got this error:

(main) [email protected]:/workspace$ python llama.cpp/convert_hf_to_gguf.py Huihui-gpt-oss-20b-BF16-abliterated
INFO:hf-to-gguf:Loading model: Huihui-gpt-oss-20b-BF16-abliterated
INFO:hf-to-gguf:Model architecture: GptOssForCausalLM
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Exporting model...
INFO:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'
INFO:hf-to-gguf:gguf: loading model part 'model-00001-of-00009.safetensors'
INFO:hf-to-gguf:gguf: loading model part 'model-00002-of-00009.safetensors'
INFO:hf-to-gguf:gguf: loading model part 'model-00003-of-00009.safetensors'
INFO:hf-to-gguf:gguf: loading model part 'model-00004-of-00009.safetensors'
INFO:hf-to-gguf:gguf: loading model part 'model-00005-of-00009.safetensors'
INFO:hf-to-gguf:gguf: loading model part 'model-00006-of-00009.safetensors'
INFO:hf-to-gguf:gguf: loading model part 'model-00007-of-00009.safetensors'
INFO:hf-to-gguf:gguf: loading model part 'model-00008-of-00009.safetensors'
INFO:hf-to-gguf:gguf: loading model part 'model-00009-of-00009.safetensors'
Traceback (most recent call last):
  File "/workspace/llama.cpp/convert_hf_to_gguf.py", line 8495, in <module>
    main()
  File "/workspace/llama.cpp/convert_hf_to_gguf.py", line 8489, in main
    model_instance.write()
  File "/workspace/llama.cpp/convert_hf_to_gguf.py", line 410, in write
    self.prepare_tensors()
  File "/workspace/llama.cpp/convert_hf_to_gguf.py", line 259, in prepare_tensors
    for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()):
  File "/workspace/llama.cpp/convert_hf_to_gguf.py", line 8019, in generate_extra_tensors
    raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")
ValueError: No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.

Tried with this script:

git clone https://github.com/ngxson/llama.cpp -b xsn/gptoss_non_mxfp4_conversion
pip install -r llama.cpp/requirements.txt
pip install -U "huggingface_hub[cli]"
pip install hf_transfer
export HF_HUB_ENABLE_HF_TRANSFER=1

huggingface-cli login --token $HF_TOKEN

huggingface-cli download $HF_REPO --local-dir $(basename "$HF_REPO")

python llama.cpp/convert_hf_to_gguf.py $(basename "$HF_REPO")

@gabriellarson
Contributor

I deleted this section:

if not found_mxfp4_tensors:
    raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")

The script ran without it; I'm uploading the result here and testing it.
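
For reference, the direction of the fix is to make the MXFP4 repacking path conditional instead of mandatory. A rough sketch of that control flow in Python (hypothetical tensor-name pattern; the real convert_hf_to_gguf.py logic may differ):

def is_mxfp4_checkpoint(tensor_names: set[str]) -> bool:
    # The packed MXFP4 release stores each expert weight as a "*_blocks"
    # (packed 4-bit values) plus "*_scales" pair; a BF16 fine-tune only has
    # the plain "*.weight" tensors. Hypothetical name pattern.
    return any(name.endswith("_blocks") for name in tensor_names)

# The converter can then branch instead of hard-failing:
#   if is_mxfp4_checkpoint(names): repack the MXFP4 expert tensors
#   else:                          treat the experts as ordinary BF16/F16 weights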

@gabriellarson
Contributor

looking good

llama.cpp/build/bin/llama-cli -c 4000 -fa --jinja --reasoning-format none -m Huihui-gpt-oss-20b-BF16-abliterated/Huihui-gpt-oss-20B-abliterated-F16.gguf -p "hello, how are you doing"
build: 6112 (b1cbcdd9) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_loader: loaded meta data with 40 key-value pairs and 459 tensors from Huihui-gpt-oss-20b-BF16-abliterated/Huihui-gpt-oss-20B-abliterated-F16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = gpt-oss
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Huihui Gpt Oss 20b BF16 Abliterated
llama_model_loader: - kv   3:                           general.finetune str              = abliterated
llama_model_loader: - kv   4:                           general.basename str              = Huihui-gpt-oss
llama_model_loader: - kv   5:                         general.size_label str              = 20B
llama_model_loader: - kv   6:                            general.license str              = apache-2.0
llama_model_loader: - kv   7:                   general.base_model.count u32              = 1
llama_model_loader: - kv   8:                  general.base_model.0.name str              = Gpt Oss 20b BF16
llama_model_loader: - kv   9:          general.base_model.0.organization str              = Unsloth
llama_model_loader: - kv  10:              general.base_model.0.repo_url str              = https://huggingface.co/unsloth/gpt-os...
llama_model_loader: - kv  11:                               general.tags arr[str,5]       = ["vllm", "unsloth", "abliterated", "u...
llama_model_loader: - kv  12:                        gpt-oss.block_count u32              = 24
llama_model_loader: - kv  13:                     gpt-oss.context_length u32              = 131072
llama_model_loader: - kv  14:                   gpt-oss.embedding_length u32              = 2880
llama_model_loader: - kv  15:                gpt-oss.feed_forward_length u32              = 2880
llama_model_loader: - kv  16:               gpt-oss.attention.head_count u32              = 64
llama_model_loader: - kv  17:            gpt-oss.attention.head_count_kv u32              = 8
llama_model_loader: - kv  18:                     gpt-oss.rope.freq_base f32              = 150000.000000
llama_model_loader: - kv  19:   gpt-oss.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  20:                       gpt-oss.expert_count u32              = 32
llama_model_loader: - kv  21:                  gpt-oss.expert_used_count u32              = 4
llama_model_loader: - kv  22:               gpt-oss.attention.key_length u32              = 64
llama_model_loader: - kv  23:             gpt-oss.attention.value_length u32              = 64
llama_model_loader: - kv  24:                          general.file_type u32              = 1
llama_model_loader: - kv  25:           gpt-oss.attention.sliding_window u32              = 128
llama_model_loader: - kv  26:         gpt-oss.expert_feed_forward_length u32              = 2880
llama_model_loader: - kv  27:                  gpt-oss.rope.scaling.type str              = yarn
llama_model_loader: - kv  28:                gpt-oss.rope.scaling.factor f32              = 32.000000
llama_model_loader: - kv  29: gpt-oss.rope.scaling.original_context_length u32              = 4096
llama_model_loader: - kv  30:               general.quantization_version u32              = 2
llama_model_loader: - kv  31:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  32:                         tokenizer.ggml.pre str              = gpt-4o
llama_model_loader: - kv  33:                      tokenizer.ggml.tokens arr[str,201088]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  34:                  tokenizer.ggml.token_type arr[i32,201088]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  35:                      tokenizer.ggml.merges arr[str,446189]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  36:                tokenizer.ggml.bos_token_id u32              = 199998
llama_model_loader: - kv  37:                tokenizer.ggml.eos_token_id u32              = 200002
llama_model_loader: - kv  38:            tokenizer.ggml.padding_token_id u32              = 199999
llama_model_loader: - kv  39:                    tokenizer.chat_template str              = {# Copyright 2025-present Unsloth. Ap...
llama_model_loader: - type  f32:  289 tensors
llama_model_loader: - type  f16:  170 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = F16
print_info: file size   = 38.97 GiB (16.01 BPW) 
load: printing all EOG tokens:
load:   - 199999 ('<|endoftext|>')
load:   - 200002 ('<|return|>')
load:   - 200007 ('<|end|>')
load:   - 200012 ('<|call|>')
load: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list
load: special tokens cache size = 21
load: token to piece cache size = 1.3332 MB
print_info: arch             = gpt-oss
print_info: vocab_only       = 0
print_info: n_ctx_train      = 131072
print_info: n_embd           = 2880
print_info: n_layer          = 24
print_info: n_head           = 64
print_info: n_head_kv        = 8
print_info: n_rot            = 64
print_info: n_swa            = 128
print_info: is_swa_any       = 1
print_info: n_embd_head_k    = 64
print_info: n_embd_head_v    = 64
print_info: n_gqa            = 8
print_info: n_embd_k_gqa     = 512
print_info: n_embd_v_gqa     = 512
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-05
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 2880
print_info: n_expert         = 32
print_info: n_expert_used    = 4
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 2
print_info: rope scaling     = yarn
print_info: freq_base_train  = 150000.0
print_info: freq_scale_train = 0.03125
print_info: n_ctx_orig_yarn  = 4096
print_info: rope_finetuned   = unknown
print_info: model type       = ?B
print_info: model params     = 20.91 B
print_info: general.name     = Huihui Gpt Oss 20b BF16 Abliterated
print_info: n_ff_exp         = 2880
print_info: vocab type       = BPE
print_info: n_vocab          = 201088
print_info: n_merges         = 446189
print_info: BOS token        = 199998 '<|startoftext|>'
print_info: EOS token        = 200002 '<|return|>'
print_info: EOT token        = 199999 '<|endoftext|>'
print_info: PAD token        = 199999 '<|endoftext|>'
print_info: LF token         = 198 'Ċ'
print_info: EOG token        = 199999 '<|endoftext|>'
print_info: EOG token        = 200002 '<|return|>'
print_info: EOG token        = 200012 '<|call|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors:   CPU_Mapped model buffer size = 39909.25 MiB
.............................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4000
llama_context: n_ctx_per_seq = 4000
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 1
llama_context: kv_unified    = false
llama_context: freq_base     = 150000.0
llama_context: freq_scale    = 0.03125
llama_context: n_ctx_per_seq (4000) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     0.77 MiB
llama_kv_cache_unified_iswa: creating non-SWA KV cache, size = 4096 cells
llama_kv_cache_unified:        CPU KV buffer size =    96.00 MiB
llama_kv_cache_unified: size =   96.00 MiB (  4096 cells,  12 layers,  1/1 seqs), K (f16):   48.00 MiB, V (f16):   48.00 MiB
llama_kv_cache_unified_iswa: creating     SWA KV cache, size = 768 cells
llama_kv_cache_unified:        CPU KV buffer size =    18.00 MiB
llama_kv_cache_unified: size =   18.00 MiB (   768 cells,  12 layers,  1/1 seqs), K (f16):    9.00 MiB, V (f16):    9.00 MiB
llama_context:        CPU compute buffer size =   398.38 MiB
llama_context: graph nodes  = 1352
llama_context: graph splits = 1
common_init_from_params: KV cache shifting is not supported for this context, disabling KV cache shifting
common_init_from_params: added <|endoftext|> logit bias = -inf
common_init_from_params: added <|return|> logit bias = -inf
common_init_from_params: added <|call|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 128
main: chat template is available, enabling conversation mode (disable it with -no-cnv)
*** User-specified prompt will pre-start conversation, did you mean to set --system-prompt (-sys) instead?
main: chat template example:
<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-08-07

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

You are a helpful assistant<|end|><|start|>user<|message|>Hello<|end|><|start|>assistant<|message|>Hi there<|end|><|start|>user<|message|>How are you?<|end|><|start|>assistant

system_info: n_threads = 128 (n_threads_batch = 128) / 255 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 

main: interactive mode on.
sampler seed: 275774358
sampler params: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
        top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-n-sigma -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist 
generate: n_ctx = 4096, n_batch = 2048, n_predict = -1, n_keep = 0

== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to the AI.
 - To return control without starting a new line, end your input with '/'.
 - If you want to submit another line, end your input with '\'.
 - Not using system message. To change it, set a different value via -sys PROMPT

<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2025-08-07

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.<|start|>user<|message|>hello, how are you doing<|start|>assistant<|channel|>analysis<|message|>The user says "hello, how are you doing". They want a response. Likely a friendly reply. Should respond politely. Also ask about them? The conversation is short, so a brief friendly response. Could ask how they're doing, mention I'm an AI, but I'm doing well. Could ask about their day. Maybe incorporate something about the user. Probably respond in a friendly tone, mention I'm an AI. Let's produce.<|start|>assistant<|channel|>final<|message|>Hello! I’m doing great—thanks for asking. How about you? Anything interesting going on today that you’d like to chat about?

@gabriellarson
Contributor

Trying to quantize down to MXFP4 prints a long stream of mismatch lines and then fails:

deq[8294350] = 0.003906, x[8294350] = 0.003998
deq[8294351] = -0.007812, x[8294351] = -0.007874
deq[8294353] = -0.003906, x[8294353] = -0.003937
deq[8294354] = 0.011719, x[8294354] = 0.011780
deq[8294357] = 0.000000, x[8294357] = -0.000008
deq[8294361] = 0.007812, x[8294361] = 0.007721
deq[8294362] = 0.011719, x[8294362] = 0.011780
deq[8294363] = 0.000000, x[8294363] = -0.000064
deq[8294365] = -0.003906, x[8294365] = -0.003876
deq[8294366] = 0.000000, x[8294366] = 0.000078
deq[8294367] = 0.023438, x[8294367] = 0.023315
deq[8294371] = 0.000000, x[8294371] = 0.000016
deq[8294372] = 0.000000, x[8294372] = 0.000031
deq[8294373] = 0.015625, x[8294373] = 0.015747
deq[8294376] = 0.007812, x[8294376] = 0.007721
deq[8294377] = 0.007812, x[8294377] = 0.007874
deq[8294379] = -0.015625, x[8294379] = -0.015747
deq[8294388] = -0.007812, x[8294388] = -0.007782
deq[8294389] = -0.007812, x[8294389] = -0.007874
deq[8294394] = 0.000000, x[8294394] = -0.000007
deq[8294395] = -0.015625, x[8294395] = -0.015564
deq[8294396] = -0.023438, x[8294396] = -0.023560
deq[8294398] = -0.007812, x[8294398] = -0.007782
/workspace/llama.cpp/src/llama-quant.cpp:1020: GGML_ASSERT(err == 0.00000) failed
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007f348ca4842f in wait4 () from /usr/lib/x86_64-linux-gnu/libc.so.6
#0  0x00007f348ca4842f in wait4 () from /usr/lib/x86_64-linux-gnu/libc.so.6
#1  0x00007f348cedd1eb in ggml_print_backtrace () from /workspace/llama.cpp/build/bin/libggml-base.so
#2  0x00007f348cedd382 in ggml_abort () from /workspace/llama.cpp/build/bin/libggml-base.so
#3  0x00007f348d0d05e3 in llama_model_quantize_impl(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, llama_model_quantize_params const*) () from /workspace/llama.cpp/build/bin/libllama.so
#4  0x00007f348d0d08e4 in llama_model_quantize () from /workspace/llama.cpp/build/bin/libllama.so
#5  0x0000557b4d1f630e in main ()
[Inferior 1 (process 1240365) detached]
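
For context, the assert at llama-quant.cpp:1020 comes from the temporary sanity check that quantizes the F16 data to MXFP4, dequantizes it back, and requires the round trip to be exact; the deq[...]/x[...] lines above are the mismatching elements. A minimal Python sketch of that kind of check, with placeholder quantize/dequantize callables rather than the real ggml functions:

import numpy as np

def roundtrip_is_lossless(x, quantize, dequantize) -> bool:
    # quantize/dequantize stand in for the MXFP4 codec; the check only passes
    # when the source weights already lie exactly on the MXFP4 grid (as in the
    # original release), not for arbitrary BF16 fine-tunes.
    deq = dequantize(quantize(x))
    err = float(np.max(np.abs(deq.astype(np.float64) - x.astype(np.float64))))
    return err == 0.0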

@ShyloCliffe

ShyloCliffe commented Aug 7, 2025

Tried your solution @gabriellarson, but it only seems to produce a 2 GB file using q8_0, so I think there's an issue somewhere.

@ngxson
Collaborator Author

ngxson commented Aug 7, 2025

@gabriellarson thanks for testing; please retry to see if llama-quantize works.

Comment on lines 1001 to +1002
 // TODO: temporary sanity check that the F16 -> MXFP4 is lossless
-#if 1
+#if 0
Collaborator Author


For visibility @ggerganov, I disabled this check because most users will now be using this code branch to convert fine-tuned models to MXFP4, which will no longer be lossless.

That said, I have some doubt about whether fine-tuned models like the abliterated version should be quantized to something other than MXFP4.

@gabriellarson Could you also try converting it to Q4_K_M to see if it impacts the quality?

Collaborator Author


Ah, never mind: it's not possible to quantize to Q4_K since the tensor shape is not divisible by 256.
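
A quick sanity check of the arithmetic, assuming the usual 256-element K-quant super-block:

row_width = 2880            # gpt-oss expert FFN width
print(row_width % 256)      # 64 -> not a multiple of 256, so Q4_K cannot tile the row
print(row_width % 32)       # 0  -> 32-element blocks (MXFP4, Q8_0) still fit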

@gabriellarson
Contributor

Quantizing works now

Q4_K_M and MXFP4 both produce decent output; Q4_K_M has lower perplexity.

llama.cpp/build/bin/llama-perplexity -m model.gguf -f wikitext-2-raw/wiki.test.raw -ngl 99

MXFP4:

Final estimate: PPL = 131.3544 +/- 1.10242

Q4_K_M

Final estimate: PPL = 93.1469 +/- 0.78081

@gabriellarson
Contributor

Q4_K_M is also larger in size

[screenshot: file sizes of the two quantized GGUFs]

@ngxson
Collaborator Author

ngxson commented Aug 7, 2025

Q4_K_M is also larger in size

Yes, that's expected: the big FFN tensors cannot be quantized to anything other than Q8_0 or MXFP4. For Q4_K_M, these tensors fall back to Q8_0.
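
A simplified sketch of that fallback rule (the real per-tensor selection in llama-quant.cpp is more involved; the type names here are just labels):

def pick_expert_ffn_type(requested: str, row_width: int) -> str:
    # K-quants pack rows in 256-element super-blocks; when the row width
    # does not divide evenly, drop to Q8_0, which uses 32-element blocks.
    k_quants = {"Q2_K", "Q3_K", "Q4_K", "Q5_K", "Q6_K"}
    if requested in k_quants and row_width % 256 != 0:
        return "Q8_0"
    return requested

print(pick_expert_ffn_type("Q4_K", 2880))   # -> "Q8_0"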

@ngxson ngxson marked this pull request as ready for review August 7, 2025 20:52
@ngxson ngxson requested review from CISC and ggerganov August 7, 2025 20:53
@ngxson ngxson merged commit 50aa938 into ggml-org:master Aug 7, 2025
51 checks passed
@pwilkin
Contributor

pwilkin commented Aug 7, 2025

Q4_K_M is also larger in size

Yes, that's expected: the big FFN tensors cannot be quantized to anything other than Q8_0 or MXFP4. For Q4_K_M, these tensors fall back to Q8_0.

I guess we need a new quantization scheme, "Q4_K_FX" or something, that uses MXFP4 as the fallback.

Successfully merging this pull request may close these issues: Convert to Gguf - Errors with OpenAI finetunes to GGUF.