Skip to content

mtmd : add qwen2vl and qwen2.5vl #13141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 29, 2025
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
## Hot topics

- **GGML developer experience survey (organized and reviewed by NVIDIA):** [link](https://forms.gle/Gasw3cRgyhNEnrwK9)
- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli` and `gemma3-cli` https://github.com/ggml-org/llama.cpp/pull/13012, `libllava` will be deprecated
- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141]((https://github.com/ggml-org/llama.cpp/pull/13141))), `libllava` will be deprecated
- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
Expand Down
8 changes: 1 addition & 7 deletions examples/llava/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,7 @@ endif()
add_executable(llama-llava-cli deprecation-warning.cpp)
add_executable(llama-gemma3-cli deprecation-warning.cpp)
add_executable(llama-minicpmv-cli deprecation-warning.cpp)

set(TARGET llama-qwen2vl-cli)
add_executable(${TARGET} qwen2vl-cli.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
add_executable(llama-qwen2vl-cli deprecation-warning.cpp)

set(TARGET llama-mtmd-cli)
add_executable(${TARGET} mtmd-cli.cpp)
Expand Down
34 changes: 30 additions & 4 deletions examples/llava/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2825,15 +2825,18 @@ void clip_free(clip_ctx * ctx) {
delete ctx;
}

// deprecated
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
const int32_t nx = ctx->vision_model.hparams.image_size;
const int32_t ny = ctx->vision_model.hparams.image_size;
return clip_embd_nbytes_by_img(ctx, nx, ny);
}

size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h) {
clip_image_f32 img;
img.nx = img_w;
img.ny = img_h;
return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
return clip_n_output_tokens(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
}

int32_t clip_get_image_size(const struct clip_ctx * ctx) {
Expand Down Expand Up @@ -2863,14 +2866,37 @@ size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_grid_pinpoints.size();
}

// deprecated
int clip_n_patches(const struct clip_ctx * ctx) {
clip_image_f32 img;
img.nx = ctx->vision_model.hparams.image_size;
img.ny = ctx->vision_model.hparams.image_size;
return clip_n_patches_by_img(ctx, &img);
return clip_n_output_tokens(ctx, &img);
}

// deprecated
int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
return clip_n_output_tokens(ctx, img);
}

int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
const auto & params = ctx->vision_model.hparams;
const int n_total = clip_n_output_tokens(ctx, img);
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
}
return n_total;
}

int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
const auto & params = ctx->vision_model.hparams;
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
}
return 1;
}

int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
const auto & params = ctx->vision_model.hparams;

int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
Expand Down
19 changes: 15 additions & 4 deletions examples/llava/clip.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_par
CLIP_API void clip_free(struct clip_ctx * ctx);

CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);

CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);
Expand All @@ -59,9 +59,20 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);

CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx),
"use clip_n_output_tokens instead");
GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img),
"use clip_n_output_tokens instead");

CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);

// for M-RoPE, this will be the number of token positions in X and Y directions
// for other models, X will be the total number of tokens and Y will be 1
CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);

// this should be equal to the embedding dimension of the text model
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);

CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
Expand Down
15 changes: 8 additions & 7 deletions examples/llava/llava.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<
}

// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) {
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out, clip_image_f32 * img_input) {
struct {
struct ggml_context * ctx;
} model;
Expand Down Expand Up @@ -175,7 +175,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>

model.ctx = ggml_init(params);

struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_output_tokens(ctx_clip, img_input), num_images - 1); // example: 4096 x 576 x 4
// ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
// fill it with the image embeddings, ignoring the base
for (size_t i = 1; i < num_images; i++) {
Expand Down Expand Up @@ -214,8 +214,8 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>

memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
// append without newline tokens (default behavior in llava_arch when not using unpad ):
memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_patches(ctx_clip));
memcpy(image_embd_out + clip_n_output_tokens(ctx_clip, img_input) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches
*n_img_pos_out = static_cast<int>(result->ne[1]+clip_n_output_tokens(ctx_clip, img_input));

// Debug: Test single segments
// Current findings: sending base image, sending a segment embedding all works similar to python
Expand Down Expand Up @@ -313,7 +313,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
image_embd_v[i],
clip_embd_nbytes_by_img(ctx_clip, nx, ny));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the other call.

n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res);
n_img_pos_out += clip_n_output_tokens(ctx_clip, img_res);
}
*n_img_pos = n_img_pos_out;
for (size_t i = 0; i < image_embd_v.size(); i++) {
Expand Down Expand Up @@ -342,8 +342,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
}
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
// flat / default llava-1.5 type embedding
*n_img_pos = clip_n_patches(ctx_clip);
clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0);
*n_img_pos = clip_n_output_tokens(ctx_clip, img_res);
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096
if (!encoded) {
LOG_ERR("Unable to encode image\n");
Expand Down Expand Up @@ -381,7 +381,8 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);

int n_img_pos_out;
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
clip_image_f32 * img_input = clip_image_f32_get_img(img_res_v.get(), 0);
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out, img_input);
*n_img_pos = n_img_pos_out;

for (size_t i = 0; i < image_embd_v.size(); i++) {
Expand Down
36 changes: 2 additions & 34 deletions examples/llava/mtmd-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,39 +136,6 @@ struct mtmd_cli_context {
}
};

struct decode_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};

static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
llama_tokens generated_tokens;
for (int i = 0; i < n_predict; i++) {
Expand Down Expand Up @@ -243,7 +210,7 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
return 1;
}

ctx.n_past += mtmd_helper_get_n_tokens(chunks);
ctx.n_past += mtmd_helper_get_n_pos(chunks);

return 0;
}
Expand Down Expand Up @@ -371,6 +338,7 @@ int main(int argc, char ** argv) {
}
}
if (g_is_interrupted) LOG("\nInterrupted by user\n");
LOG("\n\n");
llama_perf_context_print(ctx.lctx);
return g_is_interrupted ? 130 : 0;
}
Loading
Loading