diff --git a/.dockerignore b/.dockerignore old mode 100644 new mode 100755 diff --git a/.editorconfig b/.editorconfig old mode 100644 new mode 100755 diff --git a/.gitattributes b/.gitattributes old mode 100644 new mode 100755 diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 index 6ff6a3dc8..7d39e422a --- a/.gitignore +++ b/.gitignore @@ -7,25 +7,31 @@ dist # Log *.log *.log.* -*.json -*.jsonl +# *.json +# *.jsonl # Data !**/alpaca-data-conversation.json - # Editor .idea *.swp +.vscode # Other .DS_Store wandb output +llavavid checkpoints +project_checkpoints +debug_checkpoints +playground/data +playground/cc3m_llava34b_cap ckpts* .ipynb_checkpoints +chunyl_scripts *.ipynb # DevContainer @@ -33,3 +39,35 @@ ckpts* # Demo serve_images/ +notebooks/ +logs +scripts/dist_* +logs/ +submissions/ +cn_scripts/ +internal_project_checkpoints/ +work_dirs +scripts/i18n/* +playground/.nfs028b000000010add00000001 +HIP +playground/.nfs028b0000017bff2c00000012 +scripts/qwen +scripts/vicuna +scripts/mistral +scripts/baseline_rep +scripts/cn_boli01_hl +scripts/cn_boli01_lf +scripts/cn_lf +scripts/cn_lq +scripts/cn_yg +scripts/cn_yg_hao +scripts/eva_encoder +scripts/i18n +scripts/i18n_higher_res +scripts/multi-images +scratchpad +build/ +playground/*.json +mlx_configs/ +data_processing/ +# demo/ diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 794ce1b27..053269c2e --- a/README.md +++ b/README.md @@ -1,427 +1,202 @@ -# 🌋 LLaVA: Large Language and Vision Assistant +

+ +

+ +# LLaVA-NeXT: Open Large Multimodal Models +[![Static Badge](https://img.shields.io/badge/llava_video-paper-green)](http://arxiv.org/abs/2410.02713) +[![Static Badge](https://img.shields.io/badge/llava_onevision-paper-green)](https://arxiv.org/abs/2408.03326) +[![llava_next-blog](https://img.shields.io/badge/llava_next-blog-green)](https://llava-vl.github.io/blog/) + +[![llava_onevision-demo](https://img.shields.io/badge/llava_onevision-demo-red)](https://llava-onevision.lmms-lab.com/) +[![llava_next-video_demo](https://img.shields.io/badge/llava_video-demo-red)](https://huggingface.co/spaces/WildVision/vision-arena) +[![llava_next-interleave_demo](https://img.shields.io/badge/llava_next-interleave_demo-red)](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo) +[![Openbayes Demo](https://img.shields.io/static/v1?label=Demo&message=OpenBayes%E8%B4%9D%E5%BC%8F%E8%AE%A1%E7%AE%97&color=green)](https://openbayes.com/console/public/tutorials/gW0ng9jKXfO) + +[![llava_video-checkpoints](https://img.shields.io/badge/llava_video-checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944) +[![llava_onevision-checkpoints](https://img.shields.io/badge/llava_onevision-checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-onevision-66a259c3526e15166d6bba37) +[![llava_next-interleave_checkpoints](https://img.shields.io/badge/llava_next-interleave_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1) +[![llava_next-image_checkpoints](https://img.shields.io/badge/llava_next-image_checkpoints-blue)](https://huggingface.co/lmms-lab) + +## Release Notes + +- **[2024/10/04] 🔥 LLaVA-Video** (formerly LLaVA-NeXT-Video) has undergone a major upgrade! We are excited to release **LLaVA-Video-178K**, a high-quality synthetic dataset for video instruction tuning. This dataset includes: + + - 178,510 caption entries + - 960,792 open-ended Q&A pairs + - 196,198 multiple-choice Q&A items -*Visual instruction tuning towards large language and vision models with GPT-4 level capabilities.* + Along with this, we’re also releasing the **LLaVA-Video 7B/72B models**, which deliver competitive performance on the latest video benchmarks, including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard), [LongVideoBench](https://longvideobench.github.io/), and [Dream-1K](https://tarsier-vlm.github.io/). -[📢 [LLaVA-NeXT Blog](https://llava-vl.github.io/blog/2024-01-30-llava-next/)] [[Project Page](https://llava-vl.github.io/)] [[Demo](https://llava.hliu.cc/)] [[Data](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md)] [[Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] + 📄 **Explore more**: + - [LLaVA-Video-178K Dataset](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K): Download the dataset. + - [LLaVA-Video Models](https://huggingface.co/collections/lmms-lab/llava-video-661e86f5e8dabc3ff793c944): Access model checkpoints. + - [Paper](http://arxiv.org/abs/2410.02713): Detailed information about LLaVA-Video. + - [LLaVA-Video Documentation](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_Video_1003.md): Guidance on training, inference and evaluation. -🤝Community Contributions: [[llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436)] [[Colab](https://github.com/camenduru/LLaVA-colab)] [[🤗Space](https://huggingface.co/spaces/badayvedat/LLaVA)] [[Replicate](https://replicate.com/yorickvp/llava-13b)] [[AutoGen](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_lmm_llava.ipynb)] [[BakLLaVA](https://github.com/SkunkworksAI/BakLLaVA)] +- [2024/09/13] 🔥 **🚀 [LLaVA-OneVision-Chat](docs/LLaVA_OneVision_Chat.md)**. The new LLaVA-OV-Chat (7B/72B) significantly improves the chat experience of LLaVA-OV. 📄 + + ![](docs/ov_chat_images/chat_results.png) -**Improved Baselines with Visual Instruction Tuning** [[Paper](https://arxiv.org/abs/2310.03744)] [[HF](https://huggingface.co/papers/2310.03744)]
-[Haotian Liu](https://hliu.cc), [Chunyuan Li](https://chunyuan.li/), [Yuheng Li](https://yuheng-li.github.io/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/) +- [2024/08/06] 🔥 **🚀 [LLaVA-OneVision (OV)](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)!** The new LLaVA-OV models (0.5B/7B/72B) achieve new state-of-the-art performance across single-image, multi-image, and video benchmarks, sometimes rivaling top commercial models on 47 diverse benchmarks. 📄 Explore More: + * [[Paper]](https://arxiv.org/abs/2408.03326): In-depth insights, new emegerging scenarios, ie, strong video understadning through task transfer from images. + * [[LLaVA-OV Doc]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md): Model inference and evaluation guidance. + * [[Scripts]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/scripts/train): Start training models on your single-image/multi-image/video data. + +- [2024/07/16] 🔥 **LLaVA-NeXT-Video** has been upgraded. The new 32B model achieves the best open-source performance on several video benchmarks, including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard). Please refer to [this page](docs/LLaVA-NeXT-Video_0716.md) for details, refer to [llava_next-video_demo](https://huggingface.co/spaces/WildVision/vision-arena) for demo. -**Visual Instruction Tuning** (NeurIPS 2023, **Oral**) [[Paper](https://arxiv.org/abs/2304.08485)] [[HF](https://huggingface.co/papers/2304.08485)]
-[Haotian Liu*](https://hliu.cc), [Chunyuan Li*](https://chunyuan.li/), [Qingyang Wu](https://scholar.google.ca/citations?user=HDiw-TsAAAAJ&hl=en/), [Yong Jae Lee](https://pages.cs.wisc.edu/~yongjaelee/) (*Equal Contribution) - +- [2024/06/23] 🔥 **LLaVA-NeXT-Interleave** is released. We utilize image-text interleaved format to unify multi-image, video, and 3D tasks in one LLM and achieve **SoTA** performance on a wide range of benchmarks. Check out [paper](https://arxiv.org/pdf/2407.07895), [blog](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/), and [checkpoints](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1) to see new capabilities and improved performance! We have released 0.5b, 7b, and 7b-dpo models. + * An all-round LLM for multi-image, video, and 3D with strong performance \[[demo](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo)\] + * Construct interleave training data [**M4-Instruct**](https://huggingface.co/datasets/lmms-lab/M4-Instruct-Data) + * Construct multi-image benchmark [**LLaVA-Interleave Bench**](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Interleave-Bench) -## Release -- [03/10] Releasing **LMMs-Eval**, a highly efficient evaluation pipeline we used when developing LLaVA-NeXT. It supports the evaluation of LMMs on dozens of public datasets and allows new dataset onboarding, making the dev of new LMMs much faster. [[Blog](https://lmms-lab.github.io/lmms-eval-blog/lmms-eval-0.1/)] [[Codebase](https://github.com/EvolvingLMMs-Lab/lmms-eval)] -- [1/30] 🔥 LLaVA-NeXT (LLaVA-1.6) is out! With additional scaling to LLaVA-1.5, LLaVA-NeXT-34B outperforms Gemini Pro on some benchmarks. It can now process 4x more pixels and perform more tasks/applications than before. Check out the [blog post](https://llava-vl.github.io/blog/2024-01-30-llava-next/), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). Training/eval data and scripts coming soon. -- [11/10] [LLaVA-Plus](https://llava-vl.github.io/llava-plus/) is released: Learning to Use Tools for Creating Multimodal Agents, with LLaVA-Plus (LLaVA that Plug and Learn to Use Skills). [[Project Page](https://llava-vl.github.io/llava-plus/)] [[Demo](https://llavaplus.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Plus-Codebase)] [[Paper](https://arxiv.org/abs/2311.05437)] -- [11/2] [LLaVA-Interactive](https://llava-vl.github.io/llava-interactive/) is released: Experience the future of human-AI multimodal interaction with an all-in-one demo for Image Chat, Segmentation, Generation and Editing. [[Project Page](https://llava-vl.github.io/llava-interactive/)] [[Demo](https://llavainteractive.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Interactive-Demo)] [[Paper](https://arxiv.org/abs/2311.00571)] -- [10/26] 🔥 LLaVA-1.5 with LoRA achieves comparable performance as full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA. -- [10/12] Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)] -- [10/5] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)! -- [9/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/) -- [9/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**. +- [2024/05/25] 🔥 Wondering "[What Else Influences Visual Instruction Tuning Beyond Data?](https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/)" Our new [blog](https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/) summarizes empirical explorations to ablate the various design choices in improving LMMs except instruct data itself. Meanwhile, open-source the recapioned high-quality data using LLaVA-NeXT-34B on [[COCO]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-118K) [[LCS]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-558K) [[CC3M]](https://huggingface.co/datasets/lmms-lab/LLaVA-ReCap-CC3M). + * Architectures (LMM & Vision Encoder) + * Visual Representations (Resolution & # Tokens) + * Training Strategies (High-quality data & Trainable modules) +- [2024/05/10] 🔥 **LLaVA-NeXT** (Stronger) models are released, with support of stronger LMM inlcuding LLama-3 (8B) and Qwen-1.5 (72B/110B) Check out [[blog](https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/)] and [[checkpoints](https://huggingface.co/lmms-lab)] to see improved performance! +- [2024/05/10] 🔥 **LLaVA-NeXT** (Video) is released. The image-only-trained LLaVA-NeXT model is surprisingly strong on video tasks with zero-shot modality transfer. DPO training with AI feedback on videos can yield significant improvement. [[Blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)], [[checkpoints](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)] and [[sglang](https://github.com/sgl-project/sglang)] +- [2024/01/30] 🔥 **LLaVA-NeXT** is out! With additional scaling to LLaVA-1.5, LLaVA-NeXT-34B outperforms Gemini Pro on some benchmarks. It can now process 4x more pixels and perform more tasks/applications than before. Check out the [blog post](https://llava-vl.github.io/blog/2024-01-30-llava-next/), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). Training/eval data and scripts coming soon.
More - -- [11/6] Support **Intel** dGPU and CPU platforms. [More details here.](https://github.com/haotian-liu/LLaVA/tree/intel/docs/intel) -- [10/12] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support! -- [10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)! -- [10/10] [Roboflow Deep Dive](https://blog.roboflow.com/first-impressions-with-llava-1-5/): First Impressions with LLaVA-1.5. -- [9/20] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020) + +- [2024/03/10] 🔥 Releasing **LMMs-Eval**, a highly efficient evaluation pipeline we used when developing LLaVA-NeXT. It supports the evaluation of LMMs on dozens of public datasets and allows new dataset onboarding, making the dev of new LMMs much faster. [[Blog](https://lmms-lab.github.io/lmms-eval-blog/lmms-eval-0.1/)] [[Codebase](https://github.com/EvolvingLMMs-Lab/lmms-eval)] + +- [2023/11/10] [LLaVA-Plus](https://llava-vl.github.io/llava-plus/) is released: Learning to Use Tools for Creating Multimodal Agents, with LLaVA-Plus (LLaVA that Plug and Learn to Use Skills). [[Project Page](https://llava-vl.github.io/llava-plus/)] [[Demo](https://llavaplus.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Plus-Codebase)] [[Paper](https://arxiv.org/abs/2311.05437)] +- [2023/11/02] [LLaVA-Interactive](https://llava-vl.github.io/llava-interactive/) is released: Experience the future of human-AI multimodal interaction with an all-in-one demo for Image Chat, Segmentation, Generation and Editing. [[Project Page](https://llava-vl.github.io/llava-interactive/)] [[Demo](https://llavainteractive.ngrok.io/)] [[Code](https://github.com/LLaVA-VL/LLaVA-Interactive-Demo)] [[Paper](https://arxiv.org/abs/2311.00571)] +- [2023/10/26] 🔥 LLaVA-1.5 with LoRA achieves comparable performance as full-model finetuning, with a reduced GPU RAM requirement ([ckpts](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md#llava-v15), [script](https://github.com/haotian-liu/LLaVA#train)). We also provide a [doc](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md) on how to finetune LLaVA-1.5 on your own dataset with LoRA. +- [2023/10/12] Check out the Korean LLaVA (Ko-LLaVA), created by ETRI, who has generously supported our research! [[🤗 Demo](https://huggingface.co/spaces/etri-vilab/Ko-LLaVA)] +- [2023/10/05] 🔥 LLaVA-1.5 is out! Achieving SoTA on 11 benchmarks, with just simple modifications to the original LLaVA, utilizes all public data, completes training in ~1 day on a single 8-A100 node, and surpasses methods like Qwen-VL-Chat that use billion-scale data. Check out the [technical report](https://arxiv.org/abs/2310.03744), and explore the [demo](https://llava.hliu.cc/)! Models are available in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)! +- [2023/09/26] LLaVA is improved with reinforcement learning from human feedback (RLHF) to improve fact grounding and reduce hallucination. Check out the new SFT and RLHF checkpoints at project [[LLavA-RLHF]](https://llava-rlhf.github.io/) +- [2023/09/22] [LLaVA](https://arxiv.org/abs/2304.08485) is accepted by NeurIPS 2023 as **oral presentation**, and [LLaVA-Med](https://arxiv.org/abs/2306.00890) is accepted by NeurIPS 2023 Datasets and Benchmarks Track as **spotlight presentation**. +- [2023/11/06] Support **Intel** dGPU and CPU platforms. [More details here.](https://github.com/haotian-liu/LLaVA/tree/intel/docs/intel) +- [2023/10/12] LLaVA is now supported in [llama.cpp](https://github.com/ggerganov/llama.cpp/pull/3436) with 4-bit / 5-bit quantization support! +- [2023/10/11] The training data and scripts of LLaVA-1.5 are released [here](https://github.com/haotian-liu/LLaVA#train), and evaluation scripts are released [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md)! +- [2023/10/10] [Roboflow Deep Dive](https://blog.roboflow.com/first-impressions-with-llava-1-5/): First Impressions with LLaVA-1.5. +- [2023/09/20] We summarize our empirical study of training 33B and 65B LLaVA models in a [note](https://arxiv.org/abs/2309.09958). Further, if you are interested in the comprehensive review, evolution and trend of multimodal foundation models, please check out our recent survey paper [``Multimodal Foundation Models: From Specialists to General-Purpose Assistants''.](https://arxiv.org/abs/2309.10020)

-- [7/19] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)! -- [6/26] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out [[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)] [[Notes](https://arxiv.org/abs/2306.14895)] [[YouTube](https://youtu.be/mkI7EPD1vp8)] [[Bilibli](https://www.bilibili.com/video/BV1Ng4y1T7v3/)]. -- [6/11] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see documentations [here](./docs/LoRA.md). -- [6/1] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med). -- [5/6] We are releasing [LLaVA-Lighting-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details. -- [5/2] 🔥 We are releasing LLaVA-Lighting! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details. -- [4/27] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava). -- [4/17] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/). +- [2023/07/19] 🔥 We release a major upgrade, including support for LLaMA-2, LoRA training, 4-/8-bit inference, higher resolution (336x336), and a lot more. We release [LLaVA Bench](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_Bench.md) for benchmarking open-ended visual chat with results from Bard and Bing-Chat. We also support and verify training with RTX 3090 and RTX A6000. Check out [LLaVA-from-LLaMA-2](https://github.com/haotian-liu/LLaVA/blob/main/docs/LLaVA_from_LLaMA2.md), and our [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)! +- [2023/06/26] [CVPR 2023 Tutorial](https://vlp-tutorial.github.io/) on **Large Multimodal Models: Towards Building and Surpassing Multimodal GPT-4**! Please check out [[Slides](https://datarelease.blob.core.windows.net/tutorial/vision_foundation_models_2023/slides/Chunyuan_cvpr2023_tutorial_lmm.pdf)] [[Notes](https://arxiv.org/abs/2306.14895)] [[YouTube](https://youtu.be/mkI7EPD1vp8)] [[Bilibli](https://www.bilibili.com/video/BV1Ng4y1T7v3/)]. +- [2023/06/11] We released the preview for the most requested feature: DeepSpeed and LoRA support! Please see documentations [here](./docs/LoRA.md). +- [2023/06/01] We released **LLaVA-Med: Large Language and Vision Assistant for Biomedicine**, a step towards building biomedical domain large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2306.00890) and [page](https://github.com/microsoft/LLaVA-Med). +- [2023/05/06] We are releasing [LLaVA-Lighting-MPT-7B-preview](https://huggingface.co/liuhaotian/LLaVA-Lightning-MPT-7B-preview), based on MPT-7B-Chat! See [here](#LLaVA-MPT-7b) for more details. +- [2023/05/02] 🔥 We are releasing LLaVA-Lighting! Train a lite, multimodal GPT-4 with just $40 in 3 hours! See [here](#train-llava-lightning) for more details. +- [2023/04/27] Thanks to the community effort, LLaVA-13B with 4-bit quantization allows you to run on a GPU with as few as 12GB VRAM! Try it out [here](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava). +- [2023/04/17] 🔥 We released **LLaVA: Large Language and Vision Assistant**. We propose visual instruction tuning, towards building large language and vision models with GPT-4 level capabilities. Checkout the [paper](https://arxiv.org/abs/2304.08485) and [demo](https://llava.hliu.cc/).
-[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE) -**Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations. - +**Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama-1/2 community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5, [Tongyi Qianwen RESEARCH LICENSE AGREEMENT](https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat/blob/main/LICENSE) and [Llama-3 Research License](https://llama.meta.com/llama3/license/)). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations. -## Contents -- [Install](#install) -- [LLaVA Weights](#llava-weights) -- [Demo](#Demo) -- [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) -- [Dataset](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md) -- [Train](#train) -- [Evaluation](#evaluation) +## Models & Scripts -## Install +### Installation -If you are not using Linux, do *NOT* proceed, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md). - -1. Clone this repository and navigate to LLaVA folder +#### 1. **Clone this repository and navigate to the LLaVA folder:** ```bash -git clone https://github.com/haotian-liu/LLaVA.git -cd LLaVA +git clone https://github.com/LLaVA-VL/LLaVA-NeXT +cd LLaVA-NeXT ``` -2. Install Package -```Shell +#### 2. **Install the inference package:** +```bash conda create -n llava python=3.10 -y conda activate llava -pip install --upgrade pip # enable PEP 660 support -pip install -e . -``` - -3. Install additional packages for training cases -``` +pip install --upgrade pip # Enable PEP 660 support. pip install -e ".[train]" -pip install flash-attn --no-build-isolation -``` - -### Upgrade to latest code base - -```Shell -git pull -pip install -e . - -# if you see some import errors when you upgrade, -# please try running the command below (without #) -# pip install flash-attn --no-build-isolation --no-cache-dir -``` - -### Quick Start With HuggingFace - -
-Example Code - -```Python -from llava.model.builder import load_pretrained_model -from llava.mm_utils import get_model_name_from_path -from llava.eval.run_llava import eval_model - -model_path = "liuhaotian/llava-v1.5-7b" - -tokenizer, model, image_processor, context_len = load_pretrained_model( - model_path=model_path, - model_base=None, - model_name=get_model_name_from_path(model_path) -) -``` - -Check out the details wth the `load_pretrained_model` function in `llava/model/builder.py`. - -You can also use the `eval_model` function in `llava/eval/run_llava.py` to get the output easily. By doing so, you can use this code on Colab directly after downloading this repository. - -``` python -model_path = "liuhaotian/llava-v1.5-7b" -prompt = "What are the things I should be cautious about when I visit here?" -image_file = "https://llava-vl.github.io/static/images/view.jpg" - -args = type('Args', (), { - "model_path": model_path, - "model_base": None, - "model_name": get_model_name_from_path(model_path), - "query": prompt, - "conv_mode": None, - "image_file": image_file, - "sep": ",", - "temperature": 0, - "top_p": None, - "num_beams": 1, - "max_new_tokens": 512 -})() - -eval_model(args) -``` -
- -## LLaVA Weights -Please check out our [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md) for all public LLaVA checkpoints, and the instructions of how to use the weights. - -## Demo - -### Gradio Web UI - -To launch a Gradio demo locally, please run the following commands one by one. If you plan to launch multiple model workers to compare between different checkpoints, you only need to launch the controller and the web server *ONCE*. - -```mermaid -flowchart BT - %% Declare Nodes - gws("Gradio (UI Server)") - c("Controller (API Server):
PORT: 10000") - mw7b("Model Worker:
llava-v1.5-7b
PORT: 40000") - mw13b("Model Worker:
llava-v1.5-13b
PORT: 40001") - sglw13b("SGLang Backend:
llava-v1.6-34b
http://localhost:30000") - lsglw13b("SGLang Worker:
llava-v1.6-34b
PORT: 40002") - - %% Declare Styles - classDef data fill:#3af,stroke:#48a,stroke-width:2px,color:#444 - classDef success fill:#8f8,stroke:#0a0,stroke-width:2px,color:#444 - classDef failure fill:#f88,stroke:#f00,stroke-width:2px,color:#444 - - %% Assign Styles - class id,od data; - class cimg,cs_s,scsim_s success; - class ncimg,cs_f,scsim_f failure; - - subgraph Demo Connections - direction BT - c<-->gws - - mw7b<-->c - mw13b<-->c - lsglw13b<-->c - sglw13b<-->lsglw13b - end -``` - -#### Launch a controller -```Shell -python -m llava.serve.controller --host 0.0.0.0 --port 10000 -``` - -#### Launch a gradio web server. -```Shell -python -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload -``` -You just launched the Gradio web interface. Now, you can open the web interface with the URL printed on the screen. You may notice that there is no model in the model list. Do not worry, as we have not launched any model worker yet. It will be automatically updated when you launch a model worker. - -#### Launch a SGLang worker - -This is the recommended way to serve LLaVA model with high throughput, and you need to install SGLang first. Note that currently `4-bit` quantization is not supported yet on SGLang-LLaVA, and if you have limited GPU VRAM, please check out model worker with [quantization](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#launch-a-model-worker-4-bit-8-bit-inference-quantized). - -```Shell -pip install "sglang[all]" -``` - -You'll first launch a SGLang backend worker which will execute the models on GPUs. Remember the `--port` you've set and you'll use that later. - -```Shell -# Single GPU -CUDA_VISIBLE_DEVICES=0 python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 - -# Multiple GPUs with tensor parallel -CUDA_VISIBLE_DEVICES=0,1 python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-13b --tokenizer-path llava-hf/llava-1.5-13b-hf --port 30000 --tp 2 -``` - -Tokenizers (temporary): `llava-hf/llava-1.5-7b-hf`, `llava-hf/llava-1.5-13b-hf`, `liuhaotian/llava-v1.6-34b-tokenizer`. - -You'll then launch a LLaVA-SGLang worker that will communicate between LLaVA controller and SGLang backend to route the requests. Set `--sgl-endpoint` to `http://127.0.0.1:port` where `port` is the one you just set (default: 30000). - -```Shell -python -m llava.serve.sglang_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --sgl-endpoint http://127.0.0.1:30000 -``` - -#### Launch a model worker - -This is the actual *worker* that performs the inference on the GPU. Each worker is responsible for a single model specified in `--model-path`. - -```Shell -python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b -``` -Wait until the process finishes loading the model and you see "Uvicorn running on ...". Now, refresh your Gradio web UI, and you will see the model you just launched in the model list. - -You can launch as many workers as you want, and compare between different model checkpoints in the same Gradio interface. Please keep the `--controller` the same, and modify the `--port` and `--worker` to a different port number for each worker. -```Shell -python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port --worker http://localhost: --model-path -``` - -If you are using an Apple device with an M1 or M2 chip, you can specify the mps device by using the `--device` flag: `--device mps`. - -#### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB) - -If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090, etc.), you may try running it with multiple GPUs. Our latest code base will automatically try to use multiple GPUs if you have more than one GPU. You can specify which GPUs to use with `CUDA_VISIBLE_DEVICES`. Below is an example of running with the first two GPUs. - -```Shell -CUDA_VISIBLE_DEVICES=0,1 python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b -``` - -#### Launch a model worker (4-bit, 8-bit inference, quantized) - -You can launch the model worker with quantized bits (4-bit, 8-bit), which allows you to run the inference with reduced GPU memory footprint, potentially allowing you to run on a GPU with as few as 12GB VRAM. Note that inference with quantized bits may not be as accurate as the full-precision model. Simply append `--load-4bit` or `--load-8bit` to the **model worker** command that you are executing. Below is an example of running with 4-bit quantization. - -```Shell -python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1.5-13b --load-4bit -``` - -#### Launch a model worker (LoRA weights, unmerged) - -You can launch the model worker with LoRA weights, without merging them with the base checkpoint, to save disk space. There will be additional loading time, while the inference speed is the same as the merged checkpoints. Unmerged LoRA checkpoints do not have `lora-merge` in the model name, and are usually much smaller (less than 1GB) than the merged checkpoints (13G for 7B, and 25G for 13B). - -To load unmerged LoRA weights, you simply need to pass an additional argument `--model-base`, which is the base LLM that is used to train the LoRA weights. You can check the base LLM of each LoRA weights in the [model zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). - -```Shell -python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-v1-0719-336px-lora-vicuna-13b-v1.3 --model-base lmsys/vicuna-13b-v1.3 -``` - -### CLI Inference - -Chat about images using LLaVA without the need of Gradio interface. It also supports multiple GPUs, 4-bit and 8-bit quantized inference. With 4-bit quantization, for our LLaVA-1.5-7B, it uses less than 8GB VRAM on a single GPU. - -```Shell -python -m llava.serve.cli \ - --model-path liuhaotian/llava-v1.5-7b \ - --image-file "https://llava-vl.github.io/static/images/view.jpg" \ - --load-4bit -``` - - - -## Train - -*Below is the latest training configuration for LLaVA v1.5. For legacy models, please refer to README of [this](https://github.com/haotian-liu/LLaVA/tree/v1.0.1) version for now. We'll add them in a separate doc later.* - -LLaVA training consists of two stages: (1) feature alignment stage: use our 558K subset of the LAION-CC-SBU dataset to connect a *frozen pretrained* vision encoder to a *frozen LLM*; (2) visual instruction tuning stage: use 150K GPT-generated multimodal instruction-following data, plus around 515K VQA data from academic-oriented tasks, to teach the model to follow multimodal instructions. - -LLaVA is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`. - -### Hyperparameters -We use a similar set of hyperparameters as Vicuna in finetuning. Both hyperparameters used in pretraining and finetuning are provided below. - -1. Pretraining - -| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay | -| --- | ---: | ---: | ---: | ---: | ---: | -| LLaVA-v1.5-13B | 256 | 1e-3 | 1 | 2048 | 0 | - -2. Finetuning - -| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay | -| --- | ---: | ---: | ---: | ---: | ---: | -| LLaVA-v1.5-13B | 128 | 2e-5 | 1 | 2048 | 0 | - -### Download Vicuna checkpoints (automatically) - -Our base model Vicuna v1.5, which is an instruction-tuned chatbot, will be downloaded automatically when you run our provided training scripts. No action is needed. - -### Pretrain (feature alignment) - -Please download the 558K subset of the LAION-CC-SBU dataset with BLIP captions we use in the paper [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain). - -Pretrain takes around 5.5 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 3.5 hours for LLaVA-v1.5-7B. - -Training script with DeepSpeed ZeRO-2: [`pretrain.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/pretrain.sh). - -- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector. -- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px. - -
-Pretrain takes around 20 hours for LLaVA-7B on 8x V100 (32G) - - We provide training script with DeepSpeed [here](https://github.com/haotian-liu/LLaVA/blob/main/scripts/pretrain_xformers.sh). -Tips: -- If you are using V100 which is not supported by FlashAttention, you can use the [memory-efficient attention](https://arxiv.org/abs/2112.05682) implemented in [xFormers](https://github.com/facebookresearch/xformers). Install xformers and replace `llava/train/train_mem.py` above with [llava/train/train_xformers.py](llava/train/train_xformers.py). -
- -### Visual Instruction Tuning - -1. Prepare data - -Please download the annotation of the final mixture our instruction tuning data [llava_v1_5_mix665k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json), and download the images from constituting datasets: - -- COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip) -- GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip) -- OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing), **we save all files as `.jpg`** -- TextVQA: [train_val_images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) -- VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip) - -After downloading all of them, organize the data as follows in `./playground/data`, - ``` -├── coco -│ └── train2017 -├── gqa -│ └── images -├── ocr_vqa -│ └── images -├── textvqa -│ └── train_images -└── vg - ├── VG_100K - └── VG_100K_2 -``` - -2. Start training! -You may download our pretrained projectors in [Model Zoo](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md). It is not recommended to use legacy projectors, as they may be trained with a different version of the codebase, and if any option is off, the model will not function/train as we expected. +### Project Navigation +Please checkout the following page for more inference & evaluation details. -Visual instruction tuning takes around 20 hours for LLaVA-v1.5-13B on 8x A100 (80G), due to the increased resolution to 336px. It takes around 10 hours for LLaVA-v1.5-7B on 8x A100 (40G). +#### - **LLaVA-OneVision: Easy Task Transfer** +- [LLaVA-OneVision]([./docs/LLaVA-NeXT.md](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision.md)): for demo inference. The evaluation code is in [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval). -Training script with DeepSpeed ZeRO-3: [`finetune.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune.sh). +#### - **LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild** +- [LLaVA-NeXT-Image](./docs/LLaVA-NeXT.md): for image demo inference and evaluation of stronger LMMs using [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval). -If you are do not have enough GPU memory: -- Use LoRA: [`finetune_lora.sh`](https://github.com/haotian-liu/LLaVA/blob/main/scripts/v1_5/finetune_lora.sh). We are able to fit 13B training in 8-A100-40G/8-A6000, and 7B training in 8-RTX3090. Make sure `per_device_train_batch_size*gradient_accumulation_steps` is the same as the provided script for best reproducibility. -- Replace `zero3.json` with `zero3_offload.json` which offloads some parameters to CPU RAM. This slows down the training speed. +#### - LLaVA-NeXT: A Strong Zero-shot Video Understanding Model +- [LLaVA-NeXT-Video](./docs/LLaVA-NeXT-Video.md): for video inference and evaluation scripts. We recommend to use [LMMs-video](https://lmms-lab.github.io/posts/lmms-eval-0.2/) for evaluation. -If you are interested in finetuning LLaVA model to your own task/data, please check out [`Finetune_Custom_Data.md`](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md)。 +#### - LLaVA-NeXT: Tackling Multi-image, Video, and 3D in Large Multimodal Models +- [LLaVA-NeXT-Interleave](./docs/LLaVA-NeXT-Interleave.md): for multi-image demo and evaluation scripts. -New options to note: +## SGLang for SpeedUp Inference and Deployment -- `--mm_projector_type mlp2x_gelu`: the two-layer MLP vision-language connector. -- `--vision_tower openai/clip-vit-large-patch14-336`: CLIP ViT-L/14 336px. -- `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination. -- `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct). It makes the training sampler only sample a single modality (either image or language) during training, which we observe to speed up training by ~25%, and does not affect the final outcome. +We use [SGLang](https://github.com/sgl-project/sglang) to speed up inference and deployment of LLaVA-NeXT. You could make LLaVA-NeXT as a backend API service with SGLang. -## Evaluation +**Prepare Environment**: + Following the instruction in the [sglang](https://github.com/sgl-project/sglang?tab=readme-ov-file#install) -In LLaVA-1.5, we evaluate models on a diverse set of 12 benchmarks. To ensure the reproducibility, we evaluate the models with greedy decoding. We do not evaluate using beam search to make the inference process consistent with the chat demo of real-time outputs. +### LLaVA-NeXT/OneVision -See [Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md). +Checkout the HTTP Post/Get and SRT usage at [sglang/examples/runtime/llava_onevision](https://github.com/sgl-project/sglang/tree/main/examples/runtime/llava_onevision) -### GPT-assisted Evaluation +### LLaVA-NeXT (Video) -Our GPT-assisted evaluation pipeline for multimodal modeling is provided for a comprehensive understanding of the capabilities of vision-language models. Please see our paper for more details. +**Launch and Run on (K) Nodes**: +- Go to sglang project + ``` + cd PATH_TO/sglang + ``` +- First node: + ```sh + bash examples/usage/llava_video/srt_example_llava_v.sh K 0 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO + (e.g. bash examples/usage/llava_video/srt_example_llava_v.sh K 0 examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 lmms-lab/LLaVA-NeXT-Video-7B-DPO 16) + ``` +- Second node: + ```sh + bash examples/usage/llava_video/srt_example_llava_v.sh K 1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO + ``` +- The K node: + ```sh + bash examples/usage/llava_video/srt_example_llava_v.sh K K-1 YOUR_VIDEO_PATH YOUR_MODEL_PATH FRAMES_PER_VIDEO + ``` -1. Generate LLaVA responses -```Shell -python model_vqa.py \ - --model-path ./checkpoints/LLaVA-13B-v0 \ - --question-file \ - playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \ - --image-folder \ - /path/to/coco2014_val \ - --answers-file \ - /path/to/answer-file-our.jsonl -``` +## Citation -2. Evaluate the generated responses. In our case, [`answer-file-ref.jsonl`](./playground/data/coco2014_val_qa_eval/qa90_gpt4_answer.jsonl) is the response generated by text-only GPT-4 (0314), with the context captions/boxes provided. - -```Shell -OPENAI_API_KEY="sk-***********************************" python llava/eval/eval_gpt_review_visual.py \ - --question playground/data/coco2014_val_qa_eval/qa90_questions.jsonl \ - --context llava/eval/table/caps_boxes_coco2014_val_80.jsonl \ - --answer-list \ - /path/to/answer-file-ref.jsonl \ - /path/to/answer-file-our.jsonl \ - --rule llava/eval/table/rule.json \ - --output /path/to/review.json -``` +If you find it useful for your research and applications, please cite related papers/blogs using this BibTeX: +```bibtex +@article{li2024llava, + title={LLaVA-NeXT-Interleave: Tackling Multi-image, Video, and 3D in Large Multimodal Models}, + author={Li, Feng and Zhang, Renrui and Zhang, Hao and Zhang, Yuanhan and Li, Bo and Li, Wei and Ma, Zejun and Li, Chunyuan}, + journal={arXiv preprint arXiv:2407.07895}, + year={2024} +} -3. Summarize the evaluation results +@misc{li2024llavanext-ablations, + title={LLaVA-NeXT: What Else Influences Visual Instruction Tuning Beyond Data?}, + url={https://llava-vl.github.io/blog/2024-05-25-llava-next-ablations/}, + author={Li, Bo and Zhang, Hao and Zhang, Kaichen and Guo, Dong and Zhang, Yuanhan and Zhang, Renrui and Li, Feng and Liu, Ziwei and Li, Chunyuan}, + month={May}, + year={2024} +} -```Shell -python summarize_gpt_review.py -``` +@misc{li2024llavanext-strong, + title={LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild}, + url={https://llava-vl.github.io/blog/2024-05-10-llava-next-stronger-llms/}, + author={Li, Bo and Zhang, Kaichen and Zhang, Hao and Guo, Dong and Zhang, Renrui and Li, Feng and Zhang, Yuanhan and Liu, Ziwei and Li, Chunyuan}, + month={May}, + year={2024} +} -## Citation +@misc{zhang2024llavanext-video, + title={LLaVA-NeXT: A Strong Zero-shot Video Understanding Model}, + url={https://llava-vl.github.io/blog/2024-04-30-llava-next-video/}, + author={Zhang, Yuanhan and Li, Bo and Liu, haotian and Lee, Yong jae and Gui, Liangke and Fu, Di and Feng, Jiashi and Liu, Ziwei and Li, Chunyuan}, + month={April}, + year={2024} +} -If you find LLaVA useful for your research and applications, please cite using this BibTeX: -```bibtex @misc{liu2024llavanext, title={LLaVA-NeXT: Improved reasoning, OCR, and world knowledge}, url={https://llava-vl.github.io/blog/2024-01-30-llava-next/}, @@ -448,6 +223,8 @@ If you find LLaVA useful for your research and applications, please cite using t ## Acknowledgement - [Vicuna](https://github.com/lm-sys/FastChat): the codebase we built upon, and our base model Vicuna-13B that has the amazing language capabilities! +- The LLaVA-NeXT project is currently maintained by the team along with our contributors (listed alphabetically by the first names): [Bo Li](https://brianboli.com/), [Dong Guo](https://www.linkedin.com/in/dongguoset/), [Feng Li](https://scholar.google.com/citations?hl=zh-CN&user=ybRe9GcAAAAJ&view_op=list_works&sortby=pubdate), [Hao Zhang](https://scholar.google.com/citations?user=B8hPxMQAAAAJ&hl=en), [Kaichen Zhang](https://www.linkedin.com/in/kaichen-zhang-014b17219/?originalSubdomain=sg), [Renrui Zhang](https://zrrskywalker.github.io/), [Yuanhan Zhang](https://zhangyuanhan-ai.github.io/), led by [Chunyuan Li](https://chunyuan.li/) and with the guidance and help from [Haotian Liu](https://hliu.cc/). +- The `lmms-eval` framework and its core contributors, including Peiyuan Zhang, Fanyi Pu, Joshua Adrian Cahyono, and Kairui Hu, for their support on the evaluation side. ## Related Projects diff --git a/docs/LLaVA-NeXT-Interleave.md b/docs/LLaVA-NeXT-Interleave.md new file mode 100644 index 000000000..c05183726 --- /dev/null +++ b/docs/LLaVA-NeXT-Interleave.md @@ -0,0 +1,53 @@ + +# LLaVA-NeXT: Tackling Multi-image, Video, and 3D in Large Multimodal Models + +## Contents +- [Demo](#demo) +- [Evaluation](#evaluation) + +## Demo + +> make sure you installed the LLaVA-NeXT model files via outside REAME.md + +1. **Example model:** `lmms-lab/llava-next-interleave-7b` + + +To run a demo, execute: +```bash +# If you find any bug when running the demo, please make sure checkpoint path contains 'qwen'. +# You can try command like 'mv llava-next-interleave-7b llava-next-interleave-qwen-7b' +python playground/demo/interleave_demo.py --model_path path/to/ckpt +``` + +## Evaluation + +### Preparation + +Please download the evaluation data and its metadata from the following links: + +1. **llava-interleave-bench:** [here](https://huggingface.co/datasets/lmms-lab/llava-interleave-bench). + +Unzip eval_images.zip and there are Split1 and Split2 in it. +Organize the downloaded data into the following structure: +``` + +interleave_data +├── Split1 +│ ├── ... +│ └── ... +| +├── Split2 +| ├── ... +│ └── ... +├── multi_image_in_domain.json +├── multi_image_out_domain.json +└── multi_view_in_domain.json +``` + +### Inference and Evaluation +Example: +Please first edit /path/to/ckpt to the path of checkpoint, /path/to/images to the path of "interleave_data" in scripts/interleave/eval_all.sh and then run +```bash +bash scripts/interleave/eval_all.sh +``` + diff --git a/docs/LLaVA-NeXT-Video.md b/docs/LLaVA-NeXT-Video.md new file mode 100644 index 000000000..0f275f380 --- /dev/null +++ b/docs/LLaVA-NeXT-Video.md @@ -0,0 +1,81 @@ + +# LLaVA-NeXT: A Strong Zero-shot Video Understanding Model + +## Contents +- [Demo](#demo) +- [Evaluation](#evaluation) + +## Demo + +> make sure you installed the LLaVA-NeXT model files via outside REAME.md + +1. **Example model:** `lmms-lab/LLaVA-NeXT-Video-7B-DPO` + +2. **Prompt mode:** `vicuna_v1` (use `mistral_direct` for `lmms-lab/LLaVA-NeXT-Video-34B-DPO`) + +3. **Sampled frames:** `32` (Defines how many frames to sample from the video.) + +4. **Spatial pooling stride:** `2` (With original tokens for one frame at 24x24, if stride=2, then the tokens for one frame are 12x12.) + +5. **Spatial pooling mode:** `average` (Options: `average`, `max`.) + +6. **Local video path:** `./data/llava_video/video-chatgpt/evaluation/Test_Videos/v_Lf_7RurLgp0.mp4` + +To run a demo, execute: +```bash +bash scripts/video/demo/video_demo.sh ${Example model} ${Prompt mode} ${Sampled frames} ${Spatial pooling stride} ${Spatial pooling mode} grid True ${Video path at local} +``` +Example: +```bash +bash scripts/video/demo/video_demo.sh lmms-lab/LLaVA-NeXT-Video-7B-DPO vicuna_v1 32 2 average no_token True playground/demo/xU25MMA2N4aVtYay.mp4 +``` + +**IMPORTANT** Please refer to [Latest video model](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/inference/docs/LLaVA-NeXT-Video_0716.md) for the runnning of the latest model. + +## Evaluation + +### Preparation + +Please download the evaluation data and its metadata from the following links: + +1. **video-chatgpt:** [here](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/quantitative_evaluation/README.md#video-based-generative-performance-benchmarking). +2. **video_detail_description:** [here](https://mbzuaiac-my.sharepoint.com/personal/hanoona_bangalath_mbzuai_ac_ae/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fhanoona%5Fbangalath%5Fmbzuai%5Fac%5Fae%2FDocuments%2FVideo%2DChatGPT%2FData%5FCode%5FModel%5FRelease%2FQuantitative%5FEvaluation%2Fbenchamarking%2FTest%5FHuman%5FAnnotated%5FCaptions%2Ezip&parent=%2Fpersonal%2Fhanoona%5Fbangalath%5Fmbzuai%5Fac%5Fae%2FDocuments%2FVideo%2DChatGPT%2FData%5FCode%5FModel%5FRelease%2FQuantitative%5FEvaluation%2Fbenchamarking&ga=1). +3. **activity_qa:** [here](https://mbzuaiac-my.sharepoint.com/personal/hanoona_bangalath_mbzuai_ac_ae/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fhanoona%5Fbangalath%5Fmbzuai%5Fac%5Fae%2FDocuments%2FVideo%2DChatGPT%2FData%5FCode%5FModel%5FRelease%2FData%2FActivityNet%5FTest%2D1%2D3%5Fvideos%2Ezip&parent=%2Fpersonal%2Fhanoona%5Fbangalath%5Fmbzuai%5Fac%5Fae%2FDocuments%2FVideo%2DChatGPT%2FData%5FCode%5FModel%5FRelease%2FData&ga=1) and [here](https://github.com/MILVLG/activitynet-qa/tree/master/dataset). + +Organize the downloaded data into the following structure: +``` +LLaVA-NeXT +├── llava +├── scripts +└── data + └── llava_video + ├── video-chatgpt + │ ├── Test_Videos + │ ├── consistency_qa.json + │ ├── consistency_qa_test.json + │ ├── consistency_qa_train.json + ├── video_detail_description + │ └── Test_Human_Annotated_Captions + └── ActivityNet-QA + ├── all_test + ├── test_a.json + └── test_b.json +``` + +### Inference and Evaluation + +Example for video detail description evaluation (additional scripts are available in `scripts/eval`): +```bash +bash scripts/video/eval/video_detail_description_eval_shard.sh ${Example model} ${Prompt mode} ${Sampled frames} ${Spatial pooling stride} True 8 +``` +Example: +```bash +bash scripts/eval/video_detail_description_eval_shard.sh liuhaotian/llava-v1.6-vicuna-7b vicuna_v1 32 2 True 8 +``` + +### GPT Evaluation Example (Optional if the above step is completed) + +Assuming you have `pred.json` (model-generated predictions) for model `llava-v1.6-vicuna-7b` at `./work_dirs/eval_video_detail_description/llava-v1.6-vicuna-7b_vicuna_v1_frames_32_stride_2`: +```bash +bash scripts/video/eval/video_description_eval_only.sh llava-v1.6-vicuna-7b_vicuna_v1_frames_32_stride_2 +``` diff --git a/docs/LLaVA-NeXT-Video_0716.md b/docs/LLaVA-NeXT-Video_0716.md new file mode 100644 index 000000000..88460bdfa --- /dev/null +++ b/docs/LLaVA-NeXT-Video_0716.md @@ -0,0 +1,42 @@ +## LLaVA-NeXT-Video is upgraded 🚀 + +In our [LLaVA-Video blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/) released this April, we shared two key observations: +- 🎬 AnyRes provides a shared and flexible representation between images and videos, and thus accommodates capability transfer between the two most common vision signals. Therefore, stronger image LMMs can naturally lead to stronger zero-shot video LMMs. +- 🗂️ There is a lack of high-quality language-video data, including video instruction-following data, and thus naive tuning on existing public data at that time results in performance degradation. Therefore, there is an urgent need to build high-quality video captions and QA datasets to train LMMs for improved video performance. + +Based on the insights, the new LLaVA-NeXT-Video in this release improves from two aspects: + +- 🎬 A stronger image LMMs ([LLaVA-NeXT-32B-Qwen](https://huggingface.co/lmms-lab/llava-next-qwen-32b)), which is built by initializing from Qwen-1.5 32B LLM. We further initialize our video training from this image checkpoint. +- 🗂️ A new high-quality video dataset with 830k samples. It is combined with LLaVA-1.6 image training data, and applying the same image-video mixed training procedure leads to the new video model. +The new model achieves the best open-source performance in several video benchmarks including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard). + +### Resources +- **Model Card**: [LLaVA-NeXT-Video-32B-Qwen on Hugging Face](https://huggingface.co/lmms-lab/LLaVA-NeXT-Video-32B-Qwen) +- **Inference Script**: + ```bash + bash scripts/video/demo/video_demo.sh lmms-lab/LLaVA-NeXT-Video-32B-Qwen qwen_1_5 32 2 average grid True playground/demo/xU25MMA2N4aVtYay.mp4 + ``` + +### Evaluation Results +| Model | NextQA-MC | video-mme(overall) | | Egochema | Perception Test (val) | +|-----------------------------|-----------|--------------------|--------|----------|------------------------| +| | | w/o subs | w subs | | | +| **Proprietary** | | | | | | +| GPT-4o | - | 71.9 | 77.2 | 72.2 | - | +| Gemini 1.5 Pro | - | 75.0 | 81.3 | 72.2 | - | +| **Open-Source** | | | | | | +| VideoLLaMA 2 (8x7B) | 76.3* | 47.9 | 50.3 | 53.3 | 51.2* | +| VILA-1.5-34B | 67.89* | 60.1 | 61.1 | 58.04* | 54 | +| LLaVA-NeXT-Video (Qwen-32B) | 77.31 | 60.2 | 63.0 | 60.85 | 59.38 | + +_*Results are reproduced by [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval). Please refer to the lmms-eval to reproduce the results._ + +### Citations +```bibtex +@misc{zhang2024llavanextvideo, + title={LLaVA-NeXT: A Strong Zero-shot Video Understanding Model}, + url={https://llava-vl.github.io/blog/2024-04-30-llava-next-video/}, + author={Zhang, Yuanhan and Li, Bo and Liu, haotian and Lee, Yong jae and Gui, Liangke and Fu, Di and Feng, Jiashi and Liu, Ziwei and Li, Chunyuan}, + month={April}, + year={2024} +} diff --git a/docs/LLaVA-NeXT.md b/docs/LLaVA-NeXT.md new file mode 100644 index 000000000..777a961f3 --- /dev/null +++ b/docs/LLaVA-NeXT.md @@ -0,0 +1,91 @@ +# LLaVA-NeXT: Stronger LLMs Supercharge Multimodal Capabilities in the Wild + +## Quick Start With HuggingFace +First please install our repo with code and environments: `pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git` + +Here is a quick inference code using [`llavanext-llama3-8B`](https://huggingface.co/lmms-lab/llama3-llava-next-8b) as an example. You will need to install [`flash-attn`](https://github.com/Dao-AILab/flash-attention) to use this code snippet. If you don't want to install it, you can set `attn_implementation=None` when load_pretrained_model +```python +from llava.model.builder import load_pretrained_model +from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX +from llava.conversation import conv_templates, SeparatorStyle + +from PIL import Image +import requests +import copy +import torch + +pretrained = "lmms-lab/llama3-llava-next-8b" +model_name = "llava_llama3" +device = "cuda" +device_map = "auto" +tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map) # Add any other thing you want to pass in llava_model_args + +model.eval() +model.tie_weights() + +url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" +image = Image.open(requests.get(url, stream=True).raw) +image_tensor = process_images([image], image_processor, model.config) +image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor] + +conv_template = "llava_llama_3" # Make sure you use correct chat template for different models +question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?" +conv = copy.deepcopy(conv_templates[conv_template]) +conv.append_message(conv.roles[0], question) +conv.append_message(conv.roles[1], None) +prompt_question = conv.get_prompt() + +input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) +image_sizes = [image.size] + + +cont = model.generate( + input_ids, + images=image_tensor, + image_sizes=image_sizes, + do_sample=False, + temperature=0, + max_new_tokens=256, +) +text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True) +print(text_outputs) +# The image shows a radar chart, also known as a spider chart or a web chart, which is a type of graph used to display multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "LLaVA-Bench," "SEED-Bench," "MMBench-CN," "MMBench," "TextVQA," "VizWiz," "GQA," "BLIP-2," "InstructBLIP," "Owen-VL-Chat," and "LLaVA-1.5." These labels suggest that the chart is comparing the performance of different models or systems across various benchmarks or tasks, such as machine translation, visual question answering, and text-based question answering.\n\nThe chart is color-coded, with each color representing a different model or system. The points on the chart are connected to form a polygon, which shows the relative performance of each model across the different benchmarks. The closer the point is to the outer edge of the +``` + +## Evaluation + +**Install the evaluation package:** +```bash +# make sure you installed the LLaVA-NeXT model files via outside REAME.md +pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git +``` + +### Check the evaluation results with [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) +Our models' evaluation results can be fully reproduced by using the [`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval) toolkit. After you install lmms-eval and llava, you can run the evaluation using the following commands. To run following commands, you will have to install [`flash-attn`](https://github.com/Dao-AILab/flash-attention). If you do not want to install it, you can disable the flash-attn by specifying it in `--model_args pretrained=lmms-lab/llama3-llava-next-8b,conv_template=llava_llama_3,attn_implementation=None`. + +Please note that different torch versions might causing the results to vary. + +```shell +# Evaluating Llama-3-LLaVA-NeXT-8B on multiple datasets +accelerate launch --num_processes=8 \ + -m lmms_eval \ + --model llava \ + --model_args pretrained=lmms-lab/llama3-llava-next-8b,conv_template=llava_llama_3 \ + --tasks ai2d,chartqa,docvqa_val,mme,mmbench_en_dev \ + --batch_size 1 \ + --log_samples \ + --log_samples_suffix llava_next \ + --output_path ./logs/ + +# Evaluating LLaVA-NeXT-72B on multiple datasets +accelerate launch --num_processes=1 \ + -m lmms_eval \ + --model llava \ + --model_args pretrained=lmms-lab/llava-next-72b,conv_template=qwen_1_5,model_name=llava_qwen,device_map=auto \ + --tasks ai2d,chartqa,docvqa_val,mme,mmbench_en_dev \ + --batch_size 1 \ + --log_samples \ + --log_samples_suffix llava_next \ + --output_path ./logs/ +``` diff --git a/docs/LLaVA_OneVision.md b/docs/LLaVA_OneVision.md new file mode 100644 index 000000000..3aa9a8d5b --- /dev/null +++ b/docs/LLaVA_OneVision.md @@ -0,0 +1,94 @@ +# LLaVA OneVision + +## Model Details + +LLaVA OneVision is a multi-modal model capable of processing images, text, image-text interleaved inputs, and videos. The model is trained in multiple stages: + +1. Stage-1: Initial training on 558K samples from the LCS dataset. +2. Stage-1.5: Training on 4M high-quality samples with detailed captions, OCR and knowledge data. +3. Stage-2: + - Single-Image: Training on 3.2M instruction-following image samples. + - OneVision: Training on 1.6M single-image, multi-image and video samples with instructions. + +Key features: +- Supports various input resolutions up to 2304 * 2304 pixels. +- Single image input is represented by 729 * (9+1) tokens at most under `anyres_max_9` mode. +- Supports multi-image and video inputs. Multi-image input is represented by 729 token for each image, and video input is represented by 196 token for each frame. +- Available in three sizes: 0.5B, 7B and 72B parameter versions, fit for different memory and inference latency requirements. + +Some Implementation Details: +- Trained using a combination of vision-specific (AdamW, 2e-6) and language model (AdamW, 1e-5) learning rates. +- Each stage is trained for 1 epoch. + +The model uses [SO400M](https://huggingface.co/collections/google/siglip-659d5e62f0ae1a57ae0e83ba) as the vision encoder and [Qwen-2.0](https://huggingface.co/docs/transformers/model_doc/qwen2) as the language model, with trainable components including a projector and the full model in later stages. + +We recommend to use the scripts in [training](../scripts/) to get the details of the training process. + +## Inference Guidance + +We recommend to follow the [tutorial](./LLaVA_OneVision_Tutorials.ipynb) to get started on using our most basic 0.5B model for image, text, image-text interleaved, and video input. We use our 0.5B version as an example. This could be running on a GPU with 4GB memory. And with the following examples, you could see it's surprisingly have promising performance on understanding the image, interleaved image-text, and video. Tiny but mighty! + +## Evaluation Guidance + +We use the [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) toolkit to evaluate our models. Ensure you have installed the LLaVA-NeXT model files as per the instructions in the main README.md. + +Install lmms-eval: + +> pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git + +### Reproducing Evaluation Results + +Our models' evaluation results can be fully reproduced using the lmms-eval toolkit. After installing lmms-eval and llava, you can run the evaluation using the following commands. + +Note: These commands require flash-attn. If you prefer not to install it, disable flash-attn by adding `attn_implementation=None` to the `--model_args` parameter. + +Important: Different torch versions may cause slight variations in results. By default in `lmms-eval`, the requirement for torch version is set to the latest version. In `llava` repo, the torch version is set to `2.1.2`. Torch version `2.1.2` would be stable for both `llava` and `lmms-eval` + +### Evaluating LLaVA-OneVision on multiple datasets + +We recommend the developers and researchers to thoroughly evaluate the models on more datasets to get a comprehensive understanding of their performance in different scenarios. So we provide a comprehensive list of datasets for evaluation, and welcome to incoporate more evaluation tasks. Please refer to the [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) for more details. + + +Task: single-image tasks. + +```bash +# image tasks +accelerate launch --num_processes=8 \ +-m lmms_eval \ +--model llava_onevision \ +--model_args pretrained=lmms-lab/llava-onevision-qwen2-0.5b-si,conv_template=qwen_1_5,model_name=llava_qwen \ +--tasks ai2d,chartqa,docvqa_val,infovqa_val,mme,realworldqa,mathvista_testmini,llava_in_the_wild,mmvet,mmbench_en_dev,ocrbench,mmmu,mathverse_testmini_vision_intensive,mathverse_testmini_vision_only,seedbench,scienceqa_img,mmstar \ +--batch_size 1 \ +--log_samples \ +--log_samples_suffix llava_onevision \ +--output_path ./logs/ +``` + +Task: video tasks. The video tasks are more computationally expensive. We recommend running them on a machine with a GPU with at least 16GB memory. + +```bash +# video tasks +accelerate launch --num_processes=8 \ +-m lmms_eval \ +--model llava_onevision \ +--model_args pretrained=lmms-lab/llava-onevision-qwen2-0.5b-ov,conv_template=qwen_1_5,model_name=llava_qwen \ +--tasks activitynetqa,videochatgpt,nextqa_mc_test,egoschema,video_dc499,videmme,videomme_w_subtitle,perceptiontest_val_mc \ +--batch_size 1 \ +--log_samples \ +--log_samples_suffix llava_onevision \ +--output_path ./logs/ +``` + +Task: interleave tasks (`llava-interleave-bench` already contains most of existing image-text tasks). `mmmu_test` contains single image and multiple images as input, we run the model to obtain a submission file and you need to submit it to the [leaderboard](https://eval.ai/web/challenges/challenge-page/1700/overview) to get the accuracy for MMMU (multi-image) result. + +```bash +accelerate launch --num_processes=8 \ +-m lmms_eval \ +--model llava_onevision \ +--model_args pretrained=lmms-lab/llava-onevision-qwen2-0.5b-ov,conv_template=qwen_1_5,model_name=llava_qwen \ +--tasks llava-interleave-bench,muirbench,mmmu_test \ +--batch_size 1 \ +--log_samples \ +--log_samples_suffix llava_onevision \ +--output_path ./logs/ +``` diff --git a/docs/LLaVA_OneVision_Chat.md b/docs/LLaVA_OneVision_Chat.md new file mode 100644 index 000000000..9c31698ab --- /dev/null +++ b/docs/LLaVA_OneVision_Chat.md @@ -0,0 +1,157 @@ +# LLaVA-OneVision-Chat: Improving Chat with Preference Learning + +[LLaVA-OneVision](https://arxiv.org/abs/2408.03326) has demonstrated strong multimodal capabilities, showing excellent performance on various benchmarks in single-image, multi-image and video scenarios. However, we see potential for further improvement, particularly in its visual chat abilities. To achieve this, we've focused on enhancing the model through preference alignment, and our early experiments have produced some promising insights. + +### Key Observations: + +- **Impact of Preference Learning**: By incorporating alignment learning—whether through human feedback or AI-generated feedback—we've observed a notable improvement in LLaVA-OneVision's chat experience. This progress is reflected in the significant performance gains recorded on both the LLaVA-W and WildVision benchmarks. +- **Success of Self-Generated Feedback**: In LLaVA-OneVision's case, leveraging self-generated feedback data has proven to be a highly effective strategy for enhancing its visual chat capabilities. Specifically, [LLaVA-Critic](https://llava-vl.github.io/blog/2024-10-03-llava-critic/) is a utilized as a generalist evaluator to generate the scoring feedback for preference learning. This approach allows the model to refine its responses autonomously, leading to more natural and coherent conversations. + +---- + +### Release + +- 🤗 Model Checkpoints: [[OV-7b-Chat]](https://huggingface.co/lmms-lab/llava-onevision-qwen2-7b-ov-chat) | [[OV-72b-Chat]](https://huggingface.co/lmms-lab/llava-onevision-qwen2-72b-ov-chat) +- 💬 Demo: [https://llava-onevision.lmms-lab.com](https://llava-onevision.lmms-lab.com/) + +---- + + +### Results + +The figure below illustrates the performance gain of LLaVA-OV-Chat across 5 benchmarks. The delta numbers shown on top of the bars indicate the improvement of the chat model variant(7b/72b) over its base model LLaVA-OV. + +![](ov_chat_images/chat_results.png) + +| Model Name | WildVision | LLaVA-W | LLaVA-Wilder | LiveBench | Video Detailed Description | +|---------------------|------------|---------|--------------|-----------|----------------------------| +| LLaVA-OV-7b | 54.0 | 90.7 | 67.8 | 77.1 | 3.75 | +| LLaVA-OV-7b-Chat | 67.3 | 100.3 | 71.6 | 84.5 | 3.87 | +| LLaVA-OV-72b | 51.7 | 93.5 | 72.0 | 81.5 | 3.60 | +| LLaVA-OV-72b-Chat | 70.0 | 104.4 | 75.9 | 88.5 | 3.86 | +| GPT-4V | 79.8 | 98.0 | 81.0 | 73.7 | 4.00 | + +LLaVA-OV-Chat consistently showcases exceptional visual chat capabilities across five open-ended multimodal tasks. Notably, compared to its base version, LLaVA-OV, which was trained using next-token prediction, LLaVA-OV-Chat—enhanced through alignment learning—receives significantly higher preference on four image-based tasks. The *[WildVision Bench](https://huggingface.co/datasets/WildVision/wildvision-bench)* simulates real-world multimodal interactions based on human preferences, mirroring scenarios from the online platform [WildVision-Arena (WV-Arena)](https://huggingface.co/spaces/WildVision/vision-arena). Both *[LLaVA-W](https://huggingface.co/datasets/lmms-lab/llava-bench-in-the-wild)* and *[LLaVA-Wilder](https://huggingface.co/datasets/lmms-lab/LLaVA-Bench-Wilder)* tackle daily-life visual tasks, with the former focusing on simpler cases and the latter on more challenging ones. The *[Multimodal Live-Bench](https://huggingface.co/spaces/lmms-lab/LiveBench)* leverages continuously updating news and online forums to test models' ability to generalize unobserved knowledge. Despite the alignment learning being conducted solely with images, LLaVA-OV-Chat also excels in *[Video Detailed Captioning](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)*, demonstrating that its improved chat capabilities generalize effectively from image to video contexts. + +#### Examples + +| Example 1 | +| -- | +|
[Question]: How likely is it to snow after this picture was taken? What would change with this type of tree before it's likely to snow?
| +| **[LLaVA-OneVision]**: It is highly unlikely that it would snow immediately after this picture was taken, as the lush green leaves suggest it's either late spring or summer. Before snowfall, which typically occurs in late autumn or winter, this type of tree, likely an oak, would shed its leaves, and the branches would appear bare. | +| **[LLaVA-OneVision-Chat]**: Given the lush green leaves and the overall vibrancy of the tree in the image, it is highly unlikely that snow would follow immediately. This type of tree, which appears to be a deciduous species, typically sheds its leaves in autumn as a preparation for winter. Before it's likely to snow, the tree would undergo several changes:

1. **Color Change**: The leaves would change color from green to shades of yellow, orange, red, or brown as the chlorophyll breaks down, revealing other pigments.
2. **Leaf Fall**: As the days shorten and temperatures cool, the tree would gradually shed its leaves. This process is triggered by hormonal changes within the tree that cause the cells at the base of each leaf to weaken and eventually break, allowing the leaves to fall.
3. **Dormancy**: The tree would enter a state of dormancy, during which its metabolic processes slow down significantly. This helps the tree conserve energy and survive the cold winter months.
4. **Bark Changes**: The bark may become more rugged and textured as the tree ages and prepares for winter. This is a natural defense mechanism against cold temperatures and potential pests.
5. **Root Activity**: The roots would continue to absorb water and nutrients until the ground freezes, after which they would go dormant as well.
6. **Snow Preparation**: Once the tree has shed its leaves and entered dormancy, it would be better prepared to withstand snow and ice accumulation on its branches without the added weight of leaves.

In summary, the tree would need to go through the full autumn cycle of leaf color change and fall, followed by entering dormancy, before it would be likely to snow. This process can take several weeks to a few months, depending on the specific species of the tree and the local climate conditions.| + +| Example 2 | +| -- | +|
[Question]: Discuss how this creative twist on a classic work of art might be interpreted differently by various audiences.
| +| **[LLaVA-OneVision]**: Different audiences might interpret this creative twist on the Mona Lisa in various ways. Art enthusiasts might appreciate the skillful execution and the playful reinterpretation of a classic, seeing it as a commentary on the universality of beauty and expression across species. Dog lovers could find joy in seeing their favorite breed portrayed with such dignity and elegance. Some might view it as a humorous take on high art, making it more accessible and relatable. Others might see it as a statement on the anthropomorphism of pets, reflecting how we often attribute human-like qualities to animals. Overall, the varied interpretations highlight the versatility of art in engaging diverse perspectives and emotions. | +| **[LLaVA-OneVision-Chat]**: This creative twist on the Mona Lisa, featuring a dog's face, can be interpreted in various ways by different audiences. For art enthusiasts, it might be seen as a playful commentary on the nature of portraiture and the human tendency to anthropomorphize animals. It challenges traditional notions of beauty and subjectivity in art, inviting viewers to reconsider what makes a portrait compelling.

For dog lovers and pet owners, the image could evoke a sense of joy and familiarity, as they see their beloved pets reflected in one of the most iconic artworks in history. It might also highlight the deep bond between humans and their pets, suggesting that animals share in our cultural heritage.

In a broader social context, this image could be interpreted as a reflection of contemporary society's obsession with internet culture and memes. The substitution of the Mona Lisa's face with a dog's could be seen as a satirical take on how viral content often overshadows classical art in modern media.

Academically, this piece could spark discussions about copyright, artistic originality, and the public domain. It raises questions about the boundaries between homage and appropriation in art, encouraging critical thinking about intellectual property rights and creative freedom.

Overall, the diverse interpretations of this image underscore the power of art to engage, provoke, and connect people across different backgrounds and perspectives. | + +******** + +### Iterative DPO for Enhanced Performance + +#### Feedback Data + +To optimize LLaVA-OneVision’s in-the-wild conversational abilities, we've employed an iterative Direct Preference Optimization (DPO) process. Through this method, we found that feedback from two primary sources is particularly effective: + +1. **Human Feedback from [LLaVA-RLHF](https://llava-rlhf.github.io/)**: Real-world human input plays a crucial role in guiding the model toward more intuitive and user-friendly responses. + +2. **AI Feedback from LLaVA-OV’s Self-Generated Responses**: Additionally, the AI's own self-generated feedback allows it to continuously improve and adapt, making this a valuable source for iterative learning. [LLaVA-Critic](https://llava-vl.github.io/blog/2024-10-03-llava-critic/) is a utilized as a generalist evaluator to generate the scoring feedback for preference learning + +By experimenting with either of these two forms of feedback, we've been able to significantly enhance LLaVA-OneVision's conversation capabilities, bringing it closer to achieving seamless visual chat interactions in dynamic, real-world environments. + +#### Alignment Learning with Iterative DPO + +We provide a breakdown of the process for enhancing LLaVA-OneVision’s visual chat capabilities through iterative DPO. + +##### Requirements: + +1. **SFT Checkpoint**: We begin with a pretrained LLaVA-OneVision SFT (Supervised Fine-Tuning) model as the initial checkpoint for response generation. +2. **Preference Data**: The dataset used in our experiments consists of (language-image prompt, response, preference) pairs sourced from human feedback or AI feedback, which serves as the training data for the model to align with user preference to improve chat experience. + +##### Step 1: Response Generation + +For each langauge-image prompt in the dataset, we randomly generate `k = 5` candidate responses from the starting checkpoint. To ensure diversity in the generated responses, we employ random decoding with the following parameters: `temperature = 0.7`, `top-p (nucleus sampling) = 0.9`. These settings encourage the generation of varied responses by balancing randomness and precision, giving us a broad spectrum of potential answers for further evaluation. + +##### Step 2: Scoring and Acquiring Feedback Data + +Once the candidate responses are generated, we utilize a feedback source (e.g., the reward signals from LLaVA-RLHF or reward signals from LLaVA-Critic) to score each of them. The reward model is responsible for evaluating the quality of the responses based on relevance, coherence, and appropriateness in relation to the given image-question pair. From the scored responses, we then select: + +- The **best** response (highest score) +- The **worst** response (lowest score) + +These two responses serve as **pairwise feedback data** for the next phase of the training process. + +##### Step 3: Training with Iterative DPO + +Using the feedback data obtained in `Step 2`, we conduct DPO training in an iterative fashion. The process is as follows: + +1. In the $i^{th}$ round of training, we start with the pretrained model from the previous round $(i-1)^{th}$. +2. We generate new candidate responses by repeating the response generation process outlined in `Step 1`. +3. The reward source evaluates these new responses, and pairwise feedback data is acquired, as described in `Step 2`. +4. Finally, we apply DPO training to the model using the feedback data. Each round of DPO training lasts for **`1 epoch`**. + +This iterative process is repeated for `N=3` rounds in total, with each round refining the model’s ability to generate high-quality visual chat responses by progressively incorporating feedback from both human and AI assessments. + +**Training script and data format** + +- Example training script: [`/scripts/train/dpo_ov7b.sh`](../scripts/train/dpo_ov7b.sh) +- Format of training data: +~~~json +{ + "id": "", + "image": "", + "prompt": "", + "chosen": "", + "rejected": "" +} +~~~ + + +------ + +Check out on how we develop AI feedback for self-improvement LMMs, using [LLaVA-Critic](https://llava-vl.github.io/blog/2024-10-03-llava-critic/) as a generalist evaluator to generate the scoring feedback for preference learning! + +*Contributors to LLaVA-OneVision-Chat: [Tianyi Xiong](https://tyxiong23.github.io/), [Bo Li](https://brianboli.com/), [Dong Guo](https://www.linkedin.com/in/dongguoset/), [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/), [Chunyuan Li](https://scholar.google.com/citations?user=Zd7WmXUAAAAJ)* + + + +### Citation + +If you find it useful for your research and applications, please cite related papers/blogs using this BibTeX: +```bibtex +@misc{xiong2024llavaovchat, + title={LLaVA-OneVision-Chat: Improving Chat with Preference Learning}, + url={https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_OneVision_Chat.md}, + author={Xiong, Tianyi and Li, Bo and Guo, Dong and Yuan, Huizhuo and Gu, Quanquan and Li, Chunyuan}, + month={September}, + year={2024} +} + +@article{xiong2024llavacritic, + title={LLaVA-Critic: Learning to Evaluate Multimodal Models}, + author={Xiong, Tianyi and Wang, Xiyao and Guo, Dong and Ye, Qinghao and Fan, Haoqi and Gu, Quanquan and Huang, Heng and Li, Chunyuan}, + year={2024}, + eprint={2410.02712}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2410.02712}, +} + +@article{li2024llavaov, + title={Llava-onevision: Easy visual task transfer}, + author={Li, Bo and Zhang, Yuanhan and Guo, Dong and Zhang, Renrui and Li, Feng and Zhang, Hao and Zhang, Kaichen and Li, Yanwei and Liu, Ziwei and Li, Chunyuan}, + journal={arXiv preprint arXiv:2408.03326}, + year={2024} +} + +@article{sun2023aligning, + title={Aligning large multimodal models with factually augmented rlhf}, + author={Sun, Zhiqing and Shen, Sheng and Cao, Shengcao and Liu, Haotian and Li, Chunyuan and Shen, Yikang and Gan, Chuang and Gui, Liang-Yan and Wang, Yu-Xiong and Yang, Yiming and Keutzer, Kurt and Darrell, Trevor}, + journal={arXiv preprint arXiv:2309.14525}, + year={2023} +} +``` + + diff --git a/docs/LLaVA_OneVision_Tutorials.ipynb b/docs/LLaVA_OneVision_Tutorials.ipynb new file mode 100644 index 000000000..366868911 --- /dev/null +++ b/docs/LLaVA_OneVision_Tutorials.ipynb @@ -0,0 +1,397 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# (Frustratingly Easy) LLaVA OneVision Tutorial\n", + "\n", + "We know that it's always beneficial to have a unified interface for different tasks. So we are trying to unify the interface for image, text, image-text interleaved, and video input. And in this tutorial, we aim to provide the most straightforward way to use our model. \n", + "\n", + "We use our 0.5B version as an example. This could be running on a GPU with 4GB memory. And with the following examples, you could see it's surprisingly have promising performance on understanding the image, interleaved image-text, and video. Tiny but mighty!\n", + "\n", + "The same code could be used for 7B model as well.\n", + "\n", + "## Inference Guidance\n", + "\n", + "First please install our repo with code and environments: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git\n", + "\n", + "Here is a quick inference code using [lmms-lab/qwen2-0.5b-si](https://huggingface.co/lmms-lab/llava-onevision-qwen2-0.5b-si) as an example. You will need to install `flash-attn` to use this code snippet. If you don't want to install it, you can set `attn_implementation=None` when load_pretrained_model\n", + "\n", + "### Image Input\n", + "Tackling the single image input with LLaVA OneVision is pretty straightforward." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llava.model.builder import load_pretrained_model\n", + "from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token\n", + "from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX\n", + "from llava.conversation import conv_templates, SeparatorStyle\n", + "\n", + "from PIL import Image\n", + "import requests\n", + "import copy\n", + "import torch\n", + "\n", + "import sys\n", + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "pretrained = \"lmms-lab/llava-onevision-qwen2-0.5b-si\"\n", + "model_name = \"llava_qwen\"\n", + "device = \"cuda\"\n", + "device_map = \"auto\"\n", + "llava_model_args = {\n", + " \"multimodal\": True,\n", + " \"attn_implementation\": \"sdpa\",\n", + "}\n", + "tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, **llava_model_args) # Add any other thing you want to pass in llava_model_args\n", + "\n", + "model.eval()\n", + "\n", + "url = \"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true\"\n", + "image = Image.open(requests.get(url, stream=True).raw)\n", + "image_tensor = process_images([image], image_processor, model.config)\n", + "image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]\n", + "\n", + "conv_template = \"qwen_1_5\" # Make sure you use correct chat template for different models\n", + "question = DEFAULT_IMAGE_TOKEN + \"\\nWhat is shown in this image?\"\n", + "conv = copy.deepcopy(conv_templates[conv_template])\n", + "conv.append_message(conv.roles[0], question)\n", + "conv.append_message(conv.roles[1], None)\n", + "prompt_question = conv.get_prompt()\n", + "\n", + "input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=\"pt\").unsqueeze(0).to(device)\n", + "image_sizes = [image.size]\n", + "\n", + "\n", + "cont = model.generate(\n", + " input_ids,\n", + " images=image_tensor,\n", + " image_sizes=image_sizes,\n", + " do_sample=False,\n", + " temperature=0,\n", + " max_new_tokens=4096,\n", + ")\n", + "text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)\n", + "print(text_outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You could use the following code to make it streaming in terminal, this would be pretty useful when creating a chatbot." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from threading import Thread\n", + "from transformers import TextIteratorStreamer\n", + "import json\n", + "\n", + "url = \"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true\"\n", + "image = Image.open(requests.get(url, stream=True).raw)\n", + "image_tensor = process_images([image], image_processor, model.config)\n", + "image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor]\n", + "\n", + "conv_template = \"qwen_1_5\"\n", + "question = DEFAULT_IMAGE_TOKEN + \"\\nWhat is shown in this image?\"\n", + "conv = copy.deepcopy(conv_templates[conv_template])\n", + "conv.append_message(conv.roles[0], question)\n", + "conv.append_message(conv.roles[1], None)\n", + "prompt_question = conv.get_prompt()\n", + "\n", + "input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=\"pt\").unsqueeze(0).to(device)\n", + "image_sizes = [image.size]\n", + "\n", + "max_context_length = getattr(model.config, \"max_position_embeddings\", 2048)\n", + "num_image_tokens = question.count(DEFAULT_IMAGE_TOKEN) * model.get_vision_tower().num_patches\n", + "\n", + "streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)\n", + "\n", + "max_new_tokens = min(4096, max_context_length - input_ids.shape[-1] - num_image_tokens)\n", + "\n", + "if max_new_tokens < 1:\n", + " print(\n", + " json.dumps(\n", + " {\n", + " \"text\": question + \"Exceeds max token length. Please start a new conversation, thanks.\",\n", + " \"error_code\": 0,\n", + " }\n", + " )\n", + " )\n", + "else:\n", + " gen_kwargs = {\n", + " \"do_sample\": False,\n", + " \"temperature\": 0,\n", + " \"max_new_tokens\": max_new_tokens,\n", + " \"images\": image_tensor,\n", + " \"image_sizes\": image_sizes,\n", + " }\n", + "\n", + " thread = Thread(\n", + " target=model.generate,\n", + " kwargs=dict(\n", + " inputs=input_ids,\n", + " streamer=streamer,\n", + " **gen_kwargs,\n", + " ),\n", + " )\n", + " thread.start()\n", + "\n", + " generated_text = \"\"\n", + " for new_text in streamer:\n", + " generated_text += new_text\n", + " print(generated_text, flush=True)\n", + " # print(json.dumps({\"text\": generated_text, \"error_code\": 0}), flush=True)\n", + "\n", + " print(\"Final output:\", generated_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Image-Text Interleaved Input\n", + "\n", + "Now switching to our onevision model for more complex tasks. You should start to use `llava-onevision-qwen2-0.5b-ov` for image-text interleaved input and video input.\n", + "\n", + "Processing image-text interleaved input is a bit more complicated. But following the code below should work." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load model\n", + "pretrained = \"lmms-lab/llava-onevision-qwen2-0.5b-ov\"\n", + "model_name = \"llava_qwen\"\n", + "device = \"cuda\"\n", + "device_map = \"auto\"\n", + "llava_model_args = {\n", + " \"multimodal\": True,\n", + " }\n", + "overwrite_config = {}\n", + "overwrite_config[\"image_aspect_ratio\"] = \"pad\"\n", + "llava_model_args[\"overwrite_config\"] = overwrite_config\n", + "tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, **llava_model_args)\n", + "\n", + "model.eval()\n", + "\n", + "# Load two images\n", + "url1 = \"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true\"\n", + "url2 = \"https://raw.githubusercontent.com/haotian-liu/LLaVA/main/images/llava_logo.png\"\n", + "\n", + "image1 = Image.open(requests.get(url1, stream=True).raw)\n", + "image2 = Image.open(requests.get(url2, stream=True).raw)\n", + "\n", + "images = [image1, image2]\n", + "image_tensors = process_images(images, image_processor, model.config)\n", + "image_tensors = [_image.to(dtype=torch.float16, device=device) for _image in image_tensors]\n", + "\n", + "# Prepare interleaved text-image input\n", + "conv_template = \"qwen_1_5\"\n", + "question = f\"{DEFAULT_IMAGE_TOKEN} This is the first image. Can you describe what you see?\\n\\nNow, let's look at another image: {DEFAULT_IMAGE_TOKEN}\\nWhat's the difference between these two images?\"\n", + "\n", + "conv = copy.deepcopy(conv_templates[conv_template])\n", + "conv.append_message(conv.roles[0], question)\n", + "conv.append_message(conv.roles[1], None)\n", + "prompt_question = conv.get_prompt()\n", + "\n", + "input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=\"pt\").unsqueeze(0).to(device)\n", + "image_sizes = [image.size for image in images]\n", + "\n", + "# Generate response\n", + "cont = model.generate(\n", + " input_ids,\n", + " images=image_tensors,\n", + " image_sizes=image_sizes,\n", + " do_sample=False,\n", + " temperature=0,\n", + " max_new_tokens=4096,\n", + ")\n", + "text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)\n", + "print(text_outputs[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Video Input\n", + "\n", + "Now let's try video input. It's the same as image input, but you need to pass in a list of video frames. And remember to set the `` token only once in the prompt, e.g. \"\\nWhat is shown in this video?\", not \"\\n\\n\\nWhat is shown in this video?\". Since we trained on this format, it's important to keep the format consistent." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/tiger/miniconda3/envs/public_llava/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/tiger/miniconda3/envs/public_llava/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded LLaVA model: lmms-lab/llava-onevision-qwen2-7b-ov\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", + "You are using a model of type llava to instantiate a model of type llava_qwen. This is not supported for all configurations of models and can yield errors.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading vision tower: google/siglip-so400m-patch14-384\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 4/4 [00:08<00:00, 2.07s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model Class: LlavaQwenForCausalLM\n", + "(16, 1024, 576, 3)\n", + "The video features a person standing on a stage, dressed in a black shirt and dark pants. A large hand appears from the background, reaching towards the person's pocket. The text 'Source: Joshua AG' is displayed at the top left corner of the frames, and 'EVAN CARMICHAEL' is shown in the top right corner. The text 'Anyone know what this pocket is for?' appears as the hand continues to reach into the pocket. The person then looks down at their pocket, and the text 'I've always wondered that' appears. The hand finally pulls out a small white device labeled 'iPod Nano'. The person holds up the iPod Nano, and the text 'is the new iPod Nano' appears. The video concludes with a close-up of the person holding the iPod Nano, showing it from different angles.\n" + ] + } + ], + "source": [ + "from operator import attrgetter\n", + "from llava.model.builder import load_pretrained_model\n", + "from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token\n", + "from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX\n", + "from llava.conversation import conv_templates, SeparatorStyle\n", + "\n", + "import torch\n", + "import cv2\n", + "import numpy as np\n", + "from PIL import Image\n", + "import requests\n", + "import copy\n", + "import warnings\n", + "from decord import VideoReader, cpu\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "# Load the OneVision model\n", + "pretrained = \"lmms-lab/llava-onevision-qwen2-7b-ov\"\n", + "model_name = \"llava_qwen\"\n", + "device = \"cuda\"\n", + "device_map = \"auto\"\n", + "llava_model_args = {\n", + " \"multimodal\": True,\n", + "}\n", + "tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation=\"sdpa\", **llava_model_args)\n", + "\n", + "model.eval()\n", + "\n", + "\n", + "# Function to extract frames from video\n", + "def load_video(video_path, max_frames_num):\n", + " if type(video_path) == str:\n", + " vr = VideoReader(video_path, ctx=cpu(0))\n", + " else:\n", + " vr = VideoReader(video_path[0], ctx=cpu(0))\n", + " total_frame_num = len(vr)\n", + " uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)\n", + " frame_idx = uniform_sampled_frames.tolist()\n", + " spare_frames = vr.get_batch(frame_idx).asnumpy()\n", + " return spare_frames # (frames, height, width, channels)\n", + "\n", + "\n", + "# Load and process video\n", + "video_path = \"jobs.mp4\"\n", + "video_frames = load_video(video_path, 16)\n", + "print(video_frames.shape) # (16, 1024, 576, 3)\n", + "image_tensors = []\n", + "frames = image_processor.preprocess(video_frames, return_tensors=\"pt\")[\"pixel_values\"].half().cuda()\n", + "image_tensors.append(frames)\n", + "\n", + "# Prepare conversation input\n", + "conv_template = \"qwen_1_5\"\n", + "question = f\"{DEFAULT_IMAGE_TOKEN}\\nDescribe what's happening in this video.\"\n", + "\n", + "conv = copy.deepcopy(conv_templates[conv_template])\n", + "conv.append_message(conv.roles[0], question)\n", + "conv.append_message(conv.roles[1], None)\n", + "prompt_question = conv.get_prompt()\n", + "\n", + "input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=\"pt\").unsqueeze(0).to(device)\n", + "image_sizes = [frame.size for frame in video_frames]\n", + "\n", + "# Generate response\n", + "cont = model.generate(\n", + " input_ids,\n", + " images=image_tensors,\n", + " image_sizes=image_sizes,\n", + " do_sample=False,\n", + " temperature=0,\n", + " max_new_tokens=4096,\n", + " modalities=[\"video\"],\n", + ")\n", + "text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)\n", + "print(text_outputs[0])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.2 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "vscode": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/LLaVA_Video_1003.md b/docs/LLaVA_Video_1003.md new file mode 100644 index 000000000..5810e51ac --- /dev/null +++ b/docs/LLaVA_Video_1003.md @@ -0,0 +1,134 @@ +# LLaVA Video + +## Table of Contents + +1. [Model Summary](##model-summary) +2. [Inference](##inference) +3. [Training](##training) +4. [Evaluation](##evaluation-guidance) +6. [Citation](##citation) + +## Model Summary + +The LLaVA-Video models are 7/72B parameter models trained on [LLaVA-Video-178K](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K) and [LLaVA-OneVision Dataset](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Data), based on Qwen2 language model with a context window of 32K tokens. + + +## Inference + +We provide the simple generation process for using our model. For more details, you could refer to [Github](https://github.com/LLaVA-VL/LLaVA-NeXT). + +```python +# pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git +from llava.model.builder import load_pretrained_model +from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX +from llava.conversation import conv_templates, SeparatorStyle +from PIL import Image +import requests +import copy +import torch +import sys +import warnings +from decord import VideoReader, cpu +import numpy as np +warnings.filterwarnings("ignore") +def load_video(self, video_path, max_frames_num,fps=1,force_sample=False): + if max_frames_num == 0: + return np.zeros((1, 336, 336, 3)) + vr = VideoReader(video_path, ctx=cpu(0),num_threads=1) + total_frame_num = len(vr) + video_time = total_frame_num / vr.get_avg_fps() + fps = round(vr.get_avg_fps()/fps) + frame_idx = [i for i in range(0, len(vr), fps)] + frame_time = [i/fps for i in frame_idx] + if len(frame_idx) > max_frames_num or force_sample: + sample_fps = max_frames_num + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + frame_time = [i/vr.get_avg_fps() for i in frame_idx] + frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) + spare_frames = vr.get_batch(frame_idx).asnumpy() + # import pdb;pdb.set_trace() + return spare_frames,frame_time,video_time +pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2" +model_name = "llava_qwen" +device = "cuda" +device_map = "auto" +tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map) # Add any other thing you want to pass in llava_model_args +model.eval() +video_path = "XXXX" +max_frames_num = "64" +video,frame_time,video_time = load_video(video_path, max_frames_num, 1, force_sample=True) +video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].cuda().bfloat16() +video = [video] +conv_template = "qwen_1_5" # Make sure you use correct chat template for different models +time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video." +question = DEFAULT_IMAGE_TOKEN + f"{time_instruciton}\nPlease describe this video in detail." +conv = copy.deepcopy(conv_templates[conv_template]) +conv.append_message(conv.roles[0], question) +conv.append_message(conv.roles[1], None) +prompt_question = conv.get_prompt() +input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) +cont = model.generate( + input_ids, + images=video, + modalities= ["video"], + do_sample=False, + temperature=0, + max_new_tokens=4096, +) +text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() +print(text_outputs) +``` + +## Data Preparation + +1. **Download LLaVA-OneVision** + Refer to the official instructions here: [LLaVA-OneVision Data](https://github.com/LLaVA-VL/LLaVA-NeXT/tree/main/scripts/train#about-the-llava-onevision-data). Make sure to follow the guidelines provided to obtain and organize the data correctly. + +2. **Download LLaVA-Video-178K** + The dataset is available on Hugging Face: [LLaVA-Video-178K](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K). After downloading, place it in your desired directory. + +3. **Update `exp.yaml`** + In the [`exp.yaml` file](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/scripts/video/train/exp.yaml), update the file paths to point to the directories where you stored the datasets: + - **Line 186-Line 263**: Specify the path for the LLaVA-Video-178K dataset. + - For other data references, update them to point to your local LLaVA-OneVision data directory. + +## Training + +[[Scripts]](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/yhzhang/video_dev/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9_aug6.sh): Start training models on your single-image/multi-image/video data. + + +## Evaluation Guidance + +We use the [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) toolkit to evaluate our models. Ensure you have installed the LLaVA-NeXT model files as per the instructions in the main README.md. + +Install lmms-eval: + +> pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git + +### Reproducing Evaluation Results + +Our models' evaluation results can be fully reproduced using the lmms-eval toolkit. After installing lmms-eval and llava, you can run the evaluation using the following commands. + +Note: These commands require flash-attn. If you prefer not to install it, disable flash-attn by adding `attn_implementation=None` to the `--model_args` parameter. + +Important: Different torch versions may cause slight variations in results. By default in `lmms-eval`, the requirement for torch version is set to the latest version. In `llava` repo, the torch version is set to `2.1.2`. Torch version `2.1.2` would be stable for both `llava` and `lmms-eval` + +### Evaluating LLaVA-Video on multiple datasets + +We recommend the developers and researchers to thoroughly evaluate the models on more datasets to get a comprehensive understanding of their performance in different scenarios. So we provide a comprehensive list of datasets for evaluation, and welcome to incoporate more evaluation tasks. Please refer to the [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) for more details. + +```bash +# video tasks +accelerate launch --num_processes=8 \ +-m lmms_eval \ +--model llava_vid \ +--model_args pretrained=lmms-lab/LLaVA-Video-7B-Qwen2,conv_template=qwen_1_5,max_frames_num=64,mm_spatial_pool_mode=average \ +--tasks activitynetqa,videochatgpt,nextqa_mc_test,egoschema,video_dc499,videmme,videomme_w_subtitle,perceptiontest_val_mc \ +--batch_size 1 \ +--log_samples \ +--log_samples_suffix llava_vid \ +--output_path ./logs/ +``` + diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..12907e809 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,17 @@ +# LLaVA-NeXT Documentation + +Welcome to the LLaVA-NeXT documentation. This guide provides an overview of the different components and features of LLaVA-NeXT. Please refer to the following documents for detailed information on specific topics: + +1. [LLaVA OneVision](LLaVA_OneVision.md): Learn about the most advanced and unified version: LLaVA OneVision. + - [LLaVA OneVision: Inference Tutorials](LLaVA_OneVision_Tutorials.ipynb): Learn how to use LLaVA OneVision for inference. + - [LLaVA Onevision Chat](LLaVA_OneVision_Chat.md): Improving Chat with Preference Learning + +2. [LLaVA-NeXT Interleave](LLaVA-NeXT-Interleave.md): Explore the interleaved training approach used in LLaVA-NeXT. + +3. [LLaVA-NeXT Video (0716)](LLaVA-NeXT-Video_0716.md): Discover the video processing capabilities of LLaVA-NeXT (version 0716). + +4. [LLaVA-NeXT Video](LLaVA-NeXT-Video.md): Get information about the latest video processing features in LLaVA-NeXT. + +5. [LLaVA-NeXT Overview](LLaVA-NeXT.md): Read a comprehensive overview of the LLaVA-NeXT project, including its architecture, features, and capabilities. + +These documents provide in-depth information on various aspects of LLaVA-NeXT. Please refer to them for detailed explanations, implementation details, and usage instructions. \ No newline at end of file diff --git a/docs/jobs.mp4 b/docs/jobs.mp4 new file mode 100644 index 000000000..22b337268 Binary files /dev/null and b/docs/jobs.mp4 differ diff --git a/docs/onevision_trial.py b/docs/onevision_trial.py new file mode 100644 index 000000000..752345e57 --- /dev/null +++ b/docs/onevision_trial.py @@ -0,0 +1,111 @@ +from llava.model.builder import load_pretrained_model +from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX +from llava.conversation import conv_templates, SeparatorStyle + +from PIL import Image +import requests +import copy +import torch + +import sys +import warnings + +warnings.filterwarnings("ignore") + +pretrained = "lmms-lab/llava-onevision-qwen2-0.5b-si" +model_name = "llava_qwen" +device = "cuda" +device_map = "auto" +tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map) # Add any other thing you want to pass in llava_model_args + +model.eval() + +url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" +image = Image.open(requests.get(url, stream=True).raw) +image_tensor = process_images([image], image_processor, model.config) +image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor] + +conv_template = "qwen_1_5" # Make sure you use correct chat template for different models +question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?" +conv = copy.deepcopy(conv_templates[conv_template]) +conv.append_message(conv.roles[0], question) +conv.append_message(conv.roles[1], None) +prompt_question = conv.get_prompt() + +input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) +image_sizes = [image.size] + + +cont = model.generate( + input_ids, + images=image_tensor, + image_sizes=image_sizes, + do_sample=False, + temperature=0, + max_new_tokens=4096, +) +text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True) +print(text_outputs) + +from threading import Thread +from transformers import TextIteratorStreamer +import json + +url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" +image = Image.open(requests.get(url, stream=True).raw) +image_tensor = process_images([image], image_processor, model.config) +image_tensor = [_image.to(dtype=torch.float16, device=device) for _image in image_tensor] + +conv_template = "qwen_1_5" +question = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?" +conv = copy.deepcopy(conv_templates[conv_template]) +conv.append_message(conv.roles[0], question) +conv.append_message(conv.roles[1], None) +prompt_question = conv.get_prompt() + +input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device) +image_sizes = [image.size] + +max_context_length = getattr(model.config, "max_position_embeddings", 2048) +num_image_tokens = question.count(DEFAULT_IMAGE_TOKEN) * model.get_vision_tower().num_patches + +streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) + +max_new_tokens = min(4096, max_context_length - input_ids.shape[-1] - num_image_tokens) + +if max_new_tokens < 1: + print( + json.dumps( + { + "text": question + "Exceeds max token length. Please start a new conversation, thanks.", + "error_code": 0, + } + ) + ) +else: + gen_kwargs = { + "do_sample": False, + "temperature": 0, + "max_new_tokens": max_new_tokens, + "images": image_tensor, + "image_sizes": image_sizes, + } + + thread = Thread( + target=model.generate, + kwargs=dict( + inputs=input_ids, + streamer=streamer, + **gen_kwargs, + ), + ) + thread.start() + + generated_text = "" + for new_text in streamer: + generated_text += new_text + sys.stdout.write(new_text) + sys.stdout.flush() + + print("\nFinal output:", generated_text) diff --git a/docs/ov_chat_images/chat_results.png b/docs/ov_chat_images/chat_results.png new file mode 100644 index 000000000..0416c7407 Binary files /dev/null and b/docs/ov_chat_images/chat_results.png differ diff --git a/docs/ov_chat_images/example1_tree.png b/docs/ov_chat_images/example1_tree.png new file mode 100644 index 000000000..0b1b015ce Binary files /dev/null and b/docs/ov_chat_images/example1_tree.png differ diff --git a/docs/ov_chat_images/example2_dog.jpg b/docs/ov_chat_images/example2_dog.jpg new file mode 100644 index 000000000..38b3fb719 Binary files /dev/null and b/docs/ov_chat_images/example2_dog.jpg differ diff --git a/llava/__init__.py b/llava/__init__.py old mode 100644 new mode 100755 diff --git a/llava/constants.py b/llava/constants.py old mode 100644 new mode 100755 diff --git a/llava/conversation.py b/llava/conversation.py old mode 100644 new mode 100755 index a1efd5bba..66ef69632 --- a/llava/conversation.py +++ b/llava/conversation.py @@ -87,7 +87,7 @@ def get_prompt(self): for role, message in messages: if message: if type(message) is tuple: - message, images = message + message, images, _ = message message = "" * len(images) + message ret += role + "\n" + message + self.sep + "\n" else: @@ -95,6 +95,8 @@ def get_prompt(self): return ret elif self.sep_style == SeparatorStyle.LLAMA_3: + if self.tokenizer is None: + raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.") chat_template_messages = [{"role": "system", "content": self.system}] for role, message in messages: if message: @@ -103,8 +105,8 @@ def get_prompt(self): message = "" * len(images) + message chat_template_messages.append({"role": role, "content": message}) - print(chat_template_messages) - return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False) + # print(chat_template_messages) + return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True) # ret = "" if self.system == "" else self.system + self.sep + "\n" # for role, message in messages: # if message: @@ -202,6 +204,9 @@ def expand2square(pil_img, background_color=(122, 116, 104)): else: raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + if type(image) is not Image.Image: + image = Image.open(image).convert("RGB") + max_hw, min_hw = max(image.size), min(image.size) aspect_ratio = max_hw / min_hw max_len, min_len = 672, 448 @@ -221,7 +226,7 @@ def expand2square(pil_img, background_color=(122, 116, 104)): img_b64_str = base64.b64encode(buffered.getvalue()).decode() return img_b64_str - def get_images(self, return_pil=False): + def get_images(self, return_pil=False, return_path=False): images = [] for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: @@ -230,10 +235,20 @@ def get_images(self, return_pil=False): if type(image) != list: image = [image] for img in image: - img = self.process_image(img, image_process_mode, return_pil=return_pil) - images.append(img) + if not return_path and self.is_image_file(img): + img = self.process_image(img, image_process_mode, return_pil=return_pil) + else: + images.append(img) return images + def is_image_file(self, filename): + image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"] + return any(filename.lower().endswith(ext) for ext in image_extensions) + + def is_video_file(self, filename): + video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"] + return any(filename.lower().endswith(ext) for ext in video_extensions) + def to_gradio_chatbot(self): ret = [] for i, (role, msg) in enumerate(self.messages[self.offset :]): @@ -246,10 +261,24 @@ def to_gradio_chatbot(self): msg = "\n" + msg.replace("", "").strip() else: msg = re.sub(r"()\n(?=)", r"\1 ", msg) + + img_str_list = [] for img in image: - img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG") - img_str = f'' - msg = msg.replace("", img_str, 1).strip() + if self.is_image_file(img): + img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG") + img_str = f'' + img_str_list.append(img_str) + elif self.is_video_file(img): + ret.append(((img,), None)) + + msg = msg.strip() + img_place_holder = "" + for img_str in img_str_list: + img_place_holder += f"{img_str}\n\n" + + if len(img_str_list) > 0: + msg = f"{img_place_holder}\n\n{msg}" + if len(msg) > 0: ret.append([msg, None]) else: @@ -284,9 +313,9 @@ def dict(self): conv_vicuna_v0 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), - messages=( - ("Human", "What are the key differences between renewable and non-renewable energy sources?"), - ( + messages=[ + ["Human", "What are the key differences between renewable and non-renewable energy sources?"], + [ "Assistant", "Renewable energy sources are those that can be replenished naturally in a relatively " "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " @@ -306,8 +335,8 @@ def dict(self): "situations and needs, while non-renewable sources are more rigid and inflexible.\n" "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " "non-renewable sources are not, and their depletion can lead to economic and social instability.\n", - ), - ), + ], + ], offset=2, sep_style=SeparatorStyle.SINGLE, sep="###", @@ -317,7 +346,7 @@ def dict(self): system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), version="v1", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.TWO, sep=" ", @@ -330,7 +359,7 @@ def dict(self): If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", roles=("USER", "ASSISTANT"), version="llama_v2", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", @@ -341,21 +370,29 @@ def dict(self): system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.", roles=("USER", "ASSISTANT"), version="llama_v2", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", sep2="", ) +def safe_load_tokenizer(tokenizer_id): + try: + return AutoTokenizer.from_pretrained(tokenizer_id) + except Exception: + return None + conv_llava_llama_3 = Conversation( system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.", - roles=("<|start_header_id|>user", "<|start_header_id|>assistant"), + roles=("user", "assistant"), version="llama_v3", - messages=(), + messages=[], offset=0, + sep="<|eot_id|>", sep_style=SeparatorStyle.LLAMA_3, tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer=safe_load_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct"), stop_token_ids=[128009], ) @@ -363,7 +400,7 @@ def dict(self): system="", roles=("USER", "ASSISTANT"), version="llama_v2", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", @@ -374,7 +411,7 @@ def dict(self): system="Answer the questions about the visual content that the user provides.", roles=("USER", "ASSISTANT"), version="llama_v2", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", @@ -385,7 +422,7 @@ def dict(self): system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: visual content.", roles=("USER", "ASSISTANT"), version="llama_v2_mmtag", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.LLAMA_2, sep="", @@ -397,7 +434,7 @@ def dict(self): A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="mpt", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.MPT, sep="<|im_end|>", @@ -408,18 +445,29 @@ def dict(self): You are a helpful assistant.""", roles=("<|im_start|>user", "<|im_start|>assistant"), version="qwen", - messages=(), + messages=[], + offset=0, + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", +) + +conv_qwen_2_5 = Conversation( + system="""<|im_start|>system +You are Qwen, created by Alibaba Cloud. You are a helpful assistant.""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + version="qwen", + messages=[], offset=0, sep_style=SeparatorStyle.CHATML, sep="<|im_end|>", ) -conv_gemma_instruct = Conversation(system="", roles=("user\n", "model\n"), version="gemma", messages=(), offset=0, sep_style=SeparatorStyle.GEMMA, sep="\n") +conv_gemma_instruct = Conversation(system="", roles=("user\n", "model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="\n") conv_llava_plain = Conversation( system="", roles=("", ""), - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.PLAIN, sep="\n", @@ -428,7 +476,7 @@ def dict(self): conv_llava_v0 = Conversation( system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.SINGLE, sep="###", @@ -439,7 +487,7 @@ def dict(self): "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." "The visual content will be provided with the following format: visual content.", roles=("Human", "Assistant"), - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.SINGLE, sep="###", @@ -450,7 +498,7 @@ def dict(self): system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("USER", "ASSISTANT"), version="v1", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.TWO, sep=" ", @@ -462,7 +510,7 @@ def dict(self): "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." "The visual content will be provided with the following format: visual content.", roles=("USER", "ASSISTANT"), - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.TWO, sep=" ", @@ -475,7 +523,7 @@ def dict(self): You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="mpt", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.MPT, sep="<|im_end|>", @@ -486,7 +534,7 @@ def dict(self): You are a helpful AI assistant.""", roles=("<|user|>\n", "<|assistant|>\n"), version="mpt", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.MPT, sep="", @@ -497,7 +545,7 @@ def dict(self): Answer the questions.""", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="mpt", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.MPT, sep="<|im_end|>", @@ -508,7 +556,7 @@ def dict(self): Answer the questions.""", roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), version="mpt", - messages=(), + messages=[], offset=0, sep_style=SeparatorStyle.MPT, sep="<|im_end|>", @@ -539,6 +587,8 @@ def dict(self): "llava_mistral_instruct": conv_mistral_instruct, "mpt": conv_mpt, "qwen_1_5": conv_qwen, + "qwen_2": conv_qwen, + "qwen_2_5": conv_qwen_2_5, "gemma_instruct": conv_gemma_instruct, } diff --git a/llava/eval/evaluate_interleave.py b/llava/eval/evaluate_interleave.py new file mode 100644 index 000000000..b00d32f2d --- /dev/null +++ b/llava/eval/evaluate_interleave.py @@ -0,0 +1,339 @@ +import re +from rouge import Rouge +import argparse +import os +import json +import numpy as np +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity + + +spot_the_diff = ["Spot-the-Diff", "Birds-to-Words", "CLEVR-Change"] +image_edit_instruct = ["IEdit", "HQ-Edit", "MagicBrush"] +visual_story_telling = ["AESOP", "FlintstonesSV", "PororoSV", "VIST"] +visual_cloze = ["COMICS_Dialogue", "RecipeQA_VisualCloze"] +text_rich_vqa = ["WebQA", "TQA", "OCR-VQA", "DocVQA"] +multi_image_vqa = ["MIT-States_StateCoherence", "MIT-States_PropertyCoherence", "VISION", "RecipeQA_ImageCoherence"] + +puzzle = ["RAVEN"] +nlrv2 = ["NLVR2_Mantis"] +qbench = ["QBench"] + +class Eval: + def __init__(self): + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(\,)(\d)") + self.punct = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + " " in inText or " " + p in inText) or ( + re.search(self.commaStrip, inText) != None + ): + outText = outText.replace(p, "") + else: + outText = outText.replace(p, " ") + outText = self.periodStrip.sub("", outText, re.UNICODE) + return outText + + def process(self, answer): + answer = answer.replace("\n", " ") + answer = answer.replace("\t", " ") + answer = answer.strip() + answer = self.processPunctuation(answer) + answer = answer.strip('\'') + answer = answer.strip('\"') + answer = answer.strip(')') + answer = answer.strip('(') + answer = answer.strip().lower() + return answer + + def evaluate_rouge(self,preds): + rouge = Rouge() + acc = {'f': []} + eval_list = [] + for i, res in enumerate(preds): + sample_id = res['sample_id'] + # print(sample_id) + gt_ans = self.process(res["gt_response"]) + pred_ans = self.process(res["pred_response"]) + # assert gt_ans != '' + + if gt_ans == '': + continue + + if pred_ans == '': + s = 0 + else: + if len(pred_ans) > 512: + pred_ans = pred_ans[0: 512] + s = rouge.get_scores(pred_ans, gt_ans)[0]['rouge-l']['f'] + acc['f'].append(s) + eval_list.append({'id':str(sample_id),'score':str(round(s,3))}) + results = {'Rouge-L f': np.mean(acc['f'])} + return results,eval_list + + + def judge_multi_choice(self,sample): + sample_id = sample['sample_id'] + gt_ans = sample["gt_response"] + pred_ans = sample["pred_response"] + + if ":" in pred_ans: + a_list = pred_ans.split(":") + a_list = [a.strip() for a in a_list ] + for a in a_list: + if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]: + pred_ans = a + + if pred_ans == gt_ans: + return 1 + else: + return 0 + + def process_sample(self,sample): + sample["gt_response"] = self.process(sample["gt_response"]) + sample["pred_response"] = self.process(sample["pred_response"]) + + def evaluate_multichoice(self, preditions): + correct = 0 + eval_list = [] + for i, sample in enumerate(preditions): + self.process_sample(sample) + score = self.judge_multi_choice(sample) + sample_id = sample['sample_id'] + sample['result'] = score + eval_list.append({'id':str(sample_id),'score':str(score)}) + correct+=score + return {'Accuracy':correct/len(preditions)},eval_list + + def evaluate_multi_choice_image(self,preditions): + correct = 0 + eval_list = [] + for i,sample in enumerate(preditions): + gt_ans = self.process(sample["gt_response"]) + pred_ans = self.process(sample["pred_response"]) + sample_id = sample['sample_id'] + + if ":" in pred_ans: + a_list = pred_ans.split(":") + a_list = [a.strip() for a in a_list ] + for a in a_list: + if len(a) == 1 and a[-1] in ["a", "b", "c", "d", "e", "f", "g", "h"]: + pred_ans = a + + if gt_ans == pred_ans: + score = 1 + else: + score = 0 + sample_id = sample['sample_id'] + sample['result'] = score + eval_list.append({'id':str(sample_id),'score':str(score)}) + correct+=score + return {'Accuracy':correct/len(preditions)},eval_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--result-dir', type=str, required=True) + + args = parser.parse_args() + + result_file = os.path.join(args.result_dir, "result.jsonl") + + if not os.path.exists(result_file): + print('No prediction file found') + exit(0) + with open(result_file, 'r') as f: + preds_all = [json.loads(line) for line in f] + + preds_all_dict = dict() + for pred in preds_all: + if pred["dataset"] not in preds_all_dict: + preds_all_dict[pred["dataset"]] = list() + preds_all_dict[pred["dataset"]].append(pred) + + image_choice_dataset_list = ["recipeqa-RecipeQA_VisualCloze", "RecipeQA_ImageCoherence", "COMICS_Panel"] + E = Eval() + + eval_result_list = dict() + eval_result_list_detail = dict() + + for dataset in preds_all_dict: + + preds = preds_all_dict[dataset] + question_type = preds[0]["question_type"] + + if question_type == 'open-ended': + eval_result, eval_list = E.evaluate_rouge(preds) + + elif question_type == 'multi-choice' or dataset == 'nlrv2': + if dataset in image_choice_dataset_list: + eval_result, eval_list = E.evaluate_multi_choice_image(preds) + else: + eval_result, eval_list = E.evaluate_multichoice(preds) + + else: + eval_result = 'Dataset not supported' + print('Dataset not supported') + exit(0) + + print(dataset, end = ': ') + print(eval_result) + + eval_result_list[dataset] = eval_result + eval_result_list_detail[dataset] = eval_list + + os.makedirs(args.result_dir, exist_ok=True) + with open(os.path.join(args.result_dir, 'eval_dataset.json'), 'w') as f: + json.dump(eval_result_list, f, indent=4) + + with open(os.path.join(args.result_dir,'eval_dataset_details.json'), 'w') as f: + json.dump(eval_result_list_detail, f, indent=4) + + + eval_cat_list = dict() + print() + + # spot_the_diff + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in spot_the_diff: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["spot_the_diff"] = score + print("spot_the_diff", end = ': ') + print('{:.2f}'.format(100 * score)) + + # image_edit_instruct + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in image_edit_instruct: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["image_edit_instruct"] = score + print("image_edit_instruct", end = ': ') + print('{:.2f}'.format(100 * score)) + + # visual_story_telling + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in visual_story_telling: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["visual_story_telling"] = score + print("visual_story_telling", end = ': ') + print('{:.2f}'.format(100 * score)) + + # visual_cloze + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in visual_cloze: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["visual_cloze"] = score + print("visual_cloze", end = ': ') + print('{:.2f}'.format(100 * score)) + + # text_rich_vqa + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in text_rich_vqa: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["text_rich_vqa"] = score + print("text_rich_vqa", end = ': ') + print('{:.2f}'.format(100 * score)) + + # multi_image_vqa + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in multi_image_vqa: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["multi_image_vqa"] = score + print("multi_image_vqa", end = ': ') + print('{:.2f}'.format(100 * score)) + + # puzzle + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in puzzle: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["puzzle"] = score + print("puzzle", end = ': ') + print('{:.2f}'.format(100 * score)) + + # nlrv2 + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in nlrv2: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["nlrv2"] = score + print("nlrv2", end = ': ') + print('{:.2f}'.format(100 * score)) + + # qbench + score = 0 + count = 0 + for dataset in eval_result_list: + if dataset in qbench: + count += 1 + score += list(eval_result_list[dataset].values())[0] + if count > 0: + score /= count + eval_cat_list["qbench"] = score + print("qbench", end = ': ') + print('{:.2f}'.format(100 * score)) + + with open(os.path.join(args.result_dir,'eval_cat.json'), 'w') as f: + json.dump(eval_cat_list, f, indent=4) \ No newline at end of file diff --git a/llava/eval/model_vqa.py b/llava/eval/model_vqa.py new file mode 100644 index 000000000..2ebceedaf --- /dev/null +++ b/llava/eval/model_vqa.py @@ -0,0 +1,240 @@ +import argparse +import torch +import os +import json +from tqdm import tqdm +import shortuuid + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX +from typing import Dict, Optional, Sequence, List +import transformers +import re + +from PIL import Image +import math + + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + +def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict: + roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"} + + im_start, im_end = tokenizer.additional_special_tokens_ids + nl_tokens = tokenizer("\n").input_ids + _system = tokenizer("system").input_ids + nl_tokens + _user = tokenizer("user").input_ids + nl_tokens + _assistant = tokenizer("assistant").input_ids + nl_tokens + + # Apply prompt templates + input_ids, targets = [], [] + + source = sources + if roles[source[0]["from"]] != roles["human"]: + source = source[1:] + + input_id, target = [], [] + system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens + input_id += system + target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens + assert len(input_id) == len(target) + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + if has_image and sentence["value"] is not None and "" in sentence["value"]: + num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"])) + texts = sentence["value"].split('') + _input_id = tokenizer(role).input_ids + nl_tokens + for i,text in enumerate(texts): + _input_id += tokenizer(text).input_ids + if iuser": + _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens + elif role == "<|im_start|>assistant": + _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens + else: + raise NotImplementedError + target += _target + + input_ids.append(input_id) + targets.append(target) + input_ids = torch.tensor(input_ids, dtype=torch.long) + targets = torch.tensor(targets, dtype=torch.long) + return input_ids + +def eval_model(args): + + # Model + disable_torch_init() + model_path = os.path.expanduser(args.model_path) + model_name = get_model_name_from_path(model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) + + # Data + with open(os.path.expanduser(args.question_file)) as f: + questions = json.load(f) + questions = get_chunk(questions, args.num_chunks, args.chunk_idx) + answers_file = os.path.expanduser(args.answers_file) + os.makedirs(os.path.dirname(answers_file), exist_ok=True) + ans_file = open(answers_file, "w") + + for line in tqdm(questions): + idx = line["sample_id"] + question_type = line["metadata"]["question_type"] + dataset_name = line["metadata"]["dataset"] + gt = line["conversations"][1]["value"] + + image_files = line["image"] + qs = line["conversations"][0]["value"] + cur_prompt = args.extra_prompt + qs + + args.conv_mode = "qwen_1_5" + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = preprocess_qwen([line["conversations"][0],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda() + img_num = list(input_ids.squeeze()).count(IMAGE_TOKEN_INDEX) + + image_tensors = [] + for image_file in image_files: + image = Image.open(os.path.join(args.image_folder, image_file)) + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'] + image_tensors.append(image_tensor.half().cuda()) + # image_tensors = torch.cat(image_tensors, dim=0) + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensors, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + # no_repeat_ngram_size=3, + max_new_tokens=1024, + use_cache=True) + + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({ + "dataset": dataset_name, + "sample_id": idx, + "prompt": cur_prompt, + "pred_response": outputs, + "gt_response": gt, + "shortuuid": ans_id, + "model_id": model_name, + "question_type": question_type, + }) + "\n") + ans_file.flush() + + if len(line["conversations"]) > 2: + + for i in range(2, len(line["conversations"]), 2): + input_ids = torch.cat((input_ids, output_ids), dim=1) + + gt = line["conversations"][i + 1]["value"] + qs = line["conversations"][i]["value"] + cur_prompt = args.extra_prompt + qs + + args.conv_mode = "qwen_1_5" + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids_new = preprocess_qwen([line["conversations"][i],{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda() + input_ids = torch.cat((input_ids, input_ids_new), dim=1) + img_num = list(input_ids_new.squeeze()).count(IMAGE_TOKEN_INDEX) + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + with torch.inference_mode(): + output_ids = model.generate( + input_ids, + images=image_tensors, + do_sample=True if args.temperature > 0 else False, + temperature=args.temperature, + top_p=args.top_p, + num_beams=args.num_beams, + # no_repeat_ngram_size=3, + max_new_tokens=1024, + use_cache=True) + + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] + outputs = outputs.strip() + if outputs.endswith(stop_str): + outputs = outputs[:-len(stop_str)] + outputs = outputs.strip() + + ans_id = shortuuid.uuid() + ans_file.write(json.dumps({ + "dataset": dataset_name, + "sample_id": idx, + "prompt": cur_prompt, + "pred_response": outputs, + "gt_response": gt, + "shortuuid": ans_id, + "model_id": model_name, + "question_type": question_type, + }) + "\n") + ans_file.flush() + + + ans_file.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-folder", type=str, default="") + parser.add_argument("--extra-prompt", type=str, default="") + parser.add_argument("--question-file", type=str, default="tables/question.jsonl") + parser.add_argument("--answers-file", type=str, default="answer.jsonl") + parser.add_argument("--conv-mode", type=str, default="llava_v1") + parser.add_argument("--num-chunks", type=int, default=1) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--num_beams", type=int, default=1) + parser.add_argument("--test_size", type=int, default=10000000) + args = parser.parse_args() + + eval_model(args) \ No newline at end of file diff --git a/llava/mm_utils.py b/llava/mm_utils.py old mode 100644 new mode 100755 index 3e5c8b0c2..62a3e5090 --- a/llava/mm_utils.py +++ b/llava/mm_utils.py @@ -3,7 +3,7 @@ import base64 import math import ast - +import re import torch from transformers import StoppingCriteria from llava.constants import IMAGE_TOKEN_INDEX @@ -222,11 +222,16 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): Returns: tuple: The shape of the image patch grid in the format (width, height). """ - if isinstance(grid_pinpoints, str): + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" - grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(") - grid_pinpoints = [[int(x) * patch_size for x in item.split(",")] for item in grid_pinpoints] - + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: @@ -248,11 +253,20 @@ def process_anyres_image(image, processor, grid_pinpoints): torch.Tensor: A tensor containing the processed image patches. """ # Convert grid_pinpoints from string to list - if isinstance(grid_pinpoints, str): - vis_encoder_size = processor.size[0] - assert vis_encoder_size in [224, 336, 384, 448, 512], "vis_encoder_size should be in [224, 336, 384, 448, 512]" - grid_pinpoints = grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(") - grid_pinpoints = [[int(x) * vis_encoder_size for x in item.split(",")] for item in grid_pinpoints] + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + try: + patch_size = processor.size[0] + except Exception as e: + patch_size = processor.size["shortest_edge"] + assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints @@ -304,7 +318,7 @@ def process_images(images, image_processor, model_cfg): for image in images: image = process_highres_image(image, image_processor, model_cfg.image_grid_pinpoints) new_images.append(image) - elif image_aspect_ratio == "anyres": + elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: for image in images: image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) new_images.append(image) @@ -318,7 +332,7 @@ def process_images(images, image_processor, model_cfg): image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] new_images.append(image) else: - return image_processor(images, return_tensors="pt")["pixel_values"] + return image_processor.preprocess(images, return_tensors="pt")["pixel_values"] if all(x.shape == new_images[0].shape for x in new_images): new_images = torch.stack(new_images, dim=0) return new_images diff --git a/llava/model/__init__.py b/llava/model/__init__.py old mode 100644 new mode 100755 index b6f5f70bd..1abf0b362 --- a/llava/model/__init__.py +++ b/llava/model/__init__.py @@ -2,20 +2,15 @@ AVAILABLE_MODELS = { "llava_llama": "LlavaLlamaForCausalLM, LlavaConfig", - "llava_gemma": "LlavaGemmaForCausalLM, LlavaGemmaConfig", "llava_qwen": "LlavaQwenForCausalLM, LlavaQwenConfig", - # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", "llava_mistral": "LlavaMistralForCausalLM, LlavaMistralConfig", "llava_mixtral": "LlavaMixtralForCausalLM, LlavaMixtralConfig", + # "llava_qwen_moe": "LlavaQwenMoeForCausalLM, LlavaQwenMoeConfig", # Add other models as needed } for model_name, model_classes in AVAILABLE_MODELS.items(): try: exec(f"from .language_model.{model_name} import {model_classes}") - except ImportError: - import traceback - - traceback.print_exc() - print(f"Failed to import {model_name} from llava.language_model.{model_name}") - pass + except Exception as e: + print(f"Failed to import {model_name} from llava.language_model.{model_name}. Error: {e}") diff --git a/llava/model/apply_delta.py b/llava/model/apply_delta.py old mode 100644 new mode 100755 diff --git a/llava/model/builder.py b/llava/model/builder.py old mode 100644 new mode 100755 index 7191fc44b..704b960a6 --- a/llava/model/builder.py +++ b/llava/model/builder.py @@ -24,21 +24,32 @@ from llava.utils import rank0_print -def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None): - kwargs = {"device_map": device_map} +def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16",attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs): + kwargs["device_map"] = device_map if load_8bit: kwargs["load_in_8bit"] = True elif load_4bit: kwargs["load_in_4bit"] = True kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4") - else: + elif torch_dtype == "float16": kwargs["torch_dtype"] = torch.float16 + elif torch_dtype == "bfloat16": + kwargs["torch_dtype"] = torch.bfloat16 + else: + import pdb;pdb.set_trace() if customized_config is not None: kwargs["config"] = customized_config - if "llava" in model_name.lower(): + if "multimodal" in kwargs: + if kwargs["multimodal"] is True: + is_multimodal = True + kwargs.pop("multimodal") + else: + is_multimodal = False + + if "llava" in model_name.lower() or is_multimodal: # Load LLaVA model if "lora" in model_name.lower() and model_base is None: warnings.warn( @@ -102,8 +113,7 @@ def load_from_hf(repo_id, filename, subfolder=None): rank0_print("Merging LoRA weights...") model = model.merge_and_unload() rank0_print("Model is loaded...") - elif model_base is not None: - # this may be mm projector only + elif model_base is not None: # this may be mm projector only, loading projector with preset language mdoel rank0_print(f"Loading LLaVA from base model {model_base}...") if "mixtral" in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) @@ -148,8 +158,22 @@ def load_from_hf(repo_id, filename, subfolder=None): else: rank0_print(f"Loaded LLaVA model: {model_path}") if "mixtral" in model_name.lower(): + from llava.model.language_model.llava_mixtral import LlavaMixtralConfig + + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + if customized_config is None: + llava_cfg = LlavaMixtralConfig.from_pretrained(model_path) + else: + llava_cfg = customized_config + + if overwrite_config is not None: + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(llava_cfg, k, v) + tokenizer = AutoTokenizer.from_pretrained(model_path) - model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs) + model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs) + elif "mistral" in model_name.lower() or "zephyr" in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path) model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs) @@ -172,16 +196,60 @@ def load_from_hf(repo_id, filename, subfolder=None): else: llava_cfg = customized_config + if overwrite_config is not None: + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(llava_cfg, k, v) + model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs) - elif "qwen" in model_name.lower(): - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) - model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs) + + elif "qwen" in model_name.lower() or "quyen" in model_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_path) + if "moe" in model_name.lower() or "A14B" in model_name.lower(): + from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig + if overwrite_config is not None: + llava_cfg = LlavaQwenMoeConfig.from_pretrained(model_path) + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(llava_cfg, k, v) + model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs) + else: + model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs) + + else: + from llava.model.language_model.llava_qwen import LlavaQwenConfig + if overwrite_config is not None: + llava_cfg = LlavaQwenConfig.from_pretrained(model_path) + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(llava_cfg, k, v) + model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs) + else: + model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs) + elif "gemma" in model_name.lower(): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) cfg_pretrained = AutoConfig.from_pretrained(model_path) model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs) else: - raise ValueError(f"Model {model_name} not supported") + try: + from llava.model.language_model.llava_llama import LlavaConfig + + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + if customized_config is None: + llava_cfg = LlavaConfig.from_pretrained(model_path) + if "v1.5" in model_path.lower(): + llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models + else: + llava_cfg = customized_config + + if overwrite_config is not None: + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(llava_cfg, k, v) + model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs) + except: + raise ValueError(f"Model {model_name} not supported") else: # Load language model @@ -209,7 +277,7 @@ def load_from_hf(repo_id, filename, subfolder=None): rank0_print(f"Model Class: {model.__class__.__name__}") image_processor = None - if "llava" in model_name.lower(): + if "llava" in model_name.lower() or is_multimodal: mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) if mm_use_im_patch_token: diff --git a/llava/model/consolidate.py b/llava/model/consolidate.py old mode 100644 new mode 100755 diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py old mode 100644 new mode 100755 index b00dd29eb..1a930e97d --- a/llava/model/language_model/llava_llama.py +++ b/llava/model/language_model/llava_llama.py @@ -20,6 +20,9 @@ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig +from torch.nn import CrossEntropyLoss + + # , LlamaModel, LlamaForCausalLM, GenerationConfig # from .modeling_llama import LlamaModel, LlamaForCausalLM from transformers import LlamaModel, LlamaForCausalLM @@ -35,7 +38,7 @@ class LlavaConfig(LlamaConfig): max_new_tokens: int = 1024 do_sample: bool = False top_p: Optional[float] = None - rope_scaling: Optional[dict] = {} + # rope_scaling: Optional[dict] = {} class LlavaLlamaModel(LlavaMetaModel, LlamaModel): @@ -53,7 +56,7 @@ def __init__(self, config): # configure default generation settings config.model_type = "llava_llama" - config.rope_scaling = None + # config.rope_scaling = None self.model = LlavaLlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -77,24 +80,44 @@ def forward( images: Optional[torch.FloatTensor] = None, image_sizes: Optional[List[List[int]]] = None, return_dict: Optional[bool] = None, + modalities: Optional[List[str]] = ["image"], + dpo_forward: Optional[bool] = None, cache_position=None, ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) + + if dpo_forward: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + return logits, labels + + else: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) @torch.no_grad() def generate( @@ -102,15 +125,17 @@ def generate( inputs: Optional[torch.Tensor] = None, images: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: + modalities = kwargs.pop("modalities", None) if "modalities" in kwargs and modalities is None else modalities position_ids = kwargs.pop("position_ids", None) attention_mask = kwargs.pop("attention_mask", None) if "inputs_embeds" in kwargs: raise NotImplementedError("`inputs_embeds` is not supported") if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) + (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) else: inputs_embeds = self.get_model().embed_tokens(inputs) diff --git a/llava/model/language_model/llava_mistral.py b/llava/model/language_model/llava_mistral.py old mode 100644 new mode 100755 diff --git a/llava/model/language_model/llava_mixtral.py b/llava/model/language_model/llava_mixtral.py index a9090abc2..ca6c25da4 100755 --- a/llava/model/language_model/llava_mixtral.py +++ b/llava/model/language_model/llava_mixtral.py @@ -68,24 +68,44 @@ def forward( images: Optional[torch.FloatTensor] = None, image_sizes: Optional[List[List[int]]] = None, return_dict: Optional[bool] = None, + modalities: Optional[List[str]] = ["image"], + dpo_forward: Optional[bool] = None, cache_position=None, ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) + + if dpo_forward: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + return logits, labels + + else: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) @torch.no_grad() def generate( @@ -93,6 +113,7 @@ def generate( inputs: Optional[torch.Tensor] = None, images: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: position_ids = kwargs.pop("position_ids", None) @@ -101,7 +122,7 @@ def generate( raise NotImplementedError("`inputs_embeds` is not supported") if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) + (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) else: inputs_embeds = self.get_model().embed_tokens(inputs) diff --git a/llava/model/language_model/llava_mpt.py b/llava/model/language_model/llava_mpt.py old mode 100644 new mode 100755 diff --git a/llava/model/language_model/llava_qwen.py b/llava/model/language_model/llava_qwen.py index 1f681cef2..c19474761 100755 --- a/llava/model/language_model/llava_qwen.py +++ b/llava/model/language_model/llava_qwen.py @@ -74,24 +74,44 @@ def forward( images: Optional[torch.FloatTensor] = None, image_sizes: Optional[List[List[int]]] = None, return_dict: Optional[bool] = None, + modalities: Optional[List[str]] = ["image"], + dpo_forward: Optional[bool] = False, cache_position=None, ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) + + if dpo_forward: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + return logits, labels + + else: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) @torch.no_grad() def generate( @@ -99,6 +119,7 @@ def generate( inputs: Optional[torch.Tensor] = None, images: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: position_ids = kwargs.pop("position_ids", None) @@ -107,7 +128,7 @@ def generate( raise NotImplementedError("`inputs_embeds` is not supported") if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) + (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) else: inputs_embeds = self.get_model().embed_tokens(inputs) diff --git a/llava/model/language_model/llava_qwen_moe.py b/llava/model/language_model/llava_qwen_moe.py index 08c39667e..618a48223 100755 --- a/llava/model/language_model/llava_qwen_moe.py +++ b/llava/model/language_model/llava_qwen_moe.py @@ -74,24 +74,44 @@ def forward( images: Optional[torch.FloatTensor] = None, image_sizes: Optional[List[List[int]]] = None, return_dict: Optional[bool] = None, + modalities: Optional[List[str]] = ["image"], + dpo_forward: Optional[bool] = False, cache_position=None, ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: - (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes) - - return super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes) + + if dpo_forward: + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + return logits, labels + + else: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) @torch.no_grad() def generate( @@ -99,6 +119,7 @@ def generate( inputs: Optional[torch.Tensor] = None, images: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None, + modalities: Optional[List[str]] = ["image"], **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: position_ids = kwargs.pop("position_ids", None) @@ -107,7 +128,7 @@ def generate( raise NotImplementedError("`inputs_embeds` is not supported") if images is not None: - (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes) + (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes) else: inputs_embeds = self.get_model().embed_tokens(inputs) diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py old mode 100644 new mode 100755 index 688c3bcb7..ea98ae850 --- a/llava/model/llava_arch.py +++ b/llava/model/llava_arch.py @@ -15,9 +15,11 @@ from abc import ABC, abstractmethod +import math +import re +import time import torch import torch.nn as nn - from .multimodal_encoder.builder import build_vision_tower from .multimodal_resampler.builder import build_vision_resampler from .multimodal_projector.builder import build_vision_projector @@ -25,7 +27,8 @@ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN from llava.mm_utils import get_anyres_image_grid_shape -from llava.utils import rank0_print +from llava.utils import rank0_print, rank_print +import random class LlavaMetaModel: @@ -90,6 +93,14 @@ def initialize_vision_modules(self, model_args, fsdp=None): self.config.mm_vision_select_feature = mm_vision_select_feature self.config.mm_patch_merge_type = mm_patch_merge_type + + if not hasattr(self.config, 'add_faster_video'): + if model_args.add_faster_video: + embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) + self.faster_token = nn.Parameter( + torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std + ) + if getattr(self, "mm_projector", None) is None: self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config) @@ -157,28 +168,136 @@ def get_model(self): def get_vision_tower(self): return self.get_model().get_vision_tower() + def get_2dPool(self, image_feature, stride=2): + height = width = self.get_vision_tower().num_patches_per_side + num_frames, num_tokens, num_dim = image_feature.shape + image_feature = image_feature.view(num_frames, height, width, -1) + image_feature = image_feature.permute(0, 3, 1, 2).contiguous() + # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride) + if self.config.mm_spatial_pool_mode == "average": + image_feature = nn.functional.avg_pool2d(image_feature, stride) + elif self.config.mm_spatial_pool_mode == "max": + image_feature = nn.functional.max_pool2d(image_feature, stride) + elif self.config.mm_spatial_pool_mode == "bilinear": + height, width = image_feature.shape[2:] + scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)] + image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear') + + else: + raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}") + image_feature = image_feature.permute(0, 2, 3, 1) + image_feature = image_feature.view(num_frames, -1, num_dim) + return image_feature + def encode_images(self, images): image_features = self.get_model().get_vision_tower()(images) - image_features = self.get_model().vision_resampler(image_features, images=images) + # image_features = self.get_model().vision_resampler(image_features, images=images) image_features = self.get_model().mm_projector(image_features) return image_features - - def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None): + + def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None): + videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images) + per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096) + all_videos_or_images_features = [] + all_faster_video_features = [] + cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride + + for idx, feat in enumerate(per_videos_or_images_features): + + feat = self.get_model().mm_projector(feat) + faster_video_feature = 0 + slower_img_feat = 0 + if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1: + slower_img_feat = self.get_2dPool(feat,cur_mm_spatial_pool_stride) + if self.config.add_faster_video: + cur_mm_spatial_pool_stride = cur_mm_spatial_pool_stride * 2 + faster_video_feature = self.get_2dPool(feat,cur_mm_spatial_pool_stride) + if slower_img_feat is not 0: + all_videos_or_images_features.append(slower_img_feat) + else: + all_videos_or_images_features.append(feat) + all_faster_video_features.append(faster_video_feature) + return all_videos_or_images_features,all_faster_video_features + + def add_token_per_grid(self, image_feature): + resize_h = int(math.sqrt(image_feature.shape[1])) + num_frames = image_feature.shape[0] + feature_dim = image_feature.shape[-1] + + image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + if getattr(self.config, "add_faster_video", False): + # import pdb; pdb.set_trace() + # (3584, 832, 14) -> (3584, 64, 13, 14) + image_feature = image_feature.view(feature_dim, num_frames,resize_h, -1) + # (3584, 64, 13, 14) -> (64, 13, 14, 3584) + image_feature = image_feature.permute(1, 2, 3, 0).contiguous() + # (64, 13, 14, 3584) -> (64, 13*14, 3584) + image_feature = image_feature.flatten(1, 2) + # import pdb; pdb.set_trace() + return image_feature + # import pdb; pdb.set_trace() + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + return image_feature + + def add_token_per_frame(self, image_feature): + image_feature = image_feature.permute(2, 0, 1).contiguous() + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.permute(1, 2, 0).contiguous() + return image_feature + + def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None): vision_tower = self.get_vision_tower() + # rank_print(modalities) if vision_tower is None or images is None or input_ids.shape[1] == 1: return input_ids, position_ids, attention_mask, past_key_values, None, labels + if isinstance(modalities, str): + modalities = [modalities] + + # import pdb; pdb.set_trace() if type(images) is list or images.ndim == 5: if type(images) is list: images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] - concat_images = torch.cat([image for image in images], dim=0) - image_features = self.encode_images(concat_images) - split_sizes = [image.shape[0] for image in images] - image_features = torch.split(image_features, split_sizes, dim=0) + + video_idx_in_batch = [] + for _ in range(len(modalities)): + if modalities[_] == "video": + video_idx_in_batch.append(_) + + images_list = [] + for image in images: + if image.ndim == 4: + images_list.append(image) + else: + images_list.append(image.unsqueeze(0)) + + concat_images = torch.cat([image for image in images_list], dim=0) + split_sizes = [image.shape[0] for image in images_list] + encoded_image_features = self.encode_images(concat_images) + # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + + # This is a list, each element is [num_images, patch * patch, dim] + # rank_print(f"Concat images : {concat_images.shape}") + encoded_image_features = torch.split(encoded_image_features, split_sizes) + image_features = [] + for idx, image_feat in enumerate(encoded_image_features): + if idx in video_idx_in_batch: + image_features.append(self.get_2dPool(image_feat)) + else: + image_features.append(image_feat) + # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) + # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}") + # image_features = torch.split(image_features, split_sizes, dim=0) mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + mm_newline_position = getattr(self.config, "mm_newline_position", "one_token") + if mm_patch_merge_type == "flat": image_features = [x.flatten(0, 1) for x in image_features] + elif mm_patch_merge_type.startswith("spatial"): new_image_features = [] for image_idx, image_feature in enumerate(image_features): @@ -186,26 +305,91 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio # num_patches = h * w, where h = w = sqrt(num_patches) # currently image_feature is a tensor of shape (4, num_patches, hidden_size) # we want to first unflatten it to (2, 2, h, w, hidden_size) - - if image_feature.shape[0] > 1: + # rank0_print("At least we are reaching here") + # import pdb; pdb.set_trace() + if image_idx in video_idx_in_batch: # video operations + # rank0_print("Video") + if mm_newline_position == "grid": + # Grid-wise + image_feature = self.add_token_per_grid(image_feature) + if getattr(self.config, "add_faster_video", False): + faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx]) + # Add a token for each frame + concat_slow_fater_token = [] + # import pdb; pdb.set_trace() + for _ in range(image_feature.shape[0]): + if _ % self.config.faster_token_stride == 0: + concat_slow_fater_token.append(torch.cat((image_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0)) + else: + concat_slow_fater_token.append(torch.cat((faster_video_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0)) + # import pdb; pdb.set_trace() + image_feature = torch.cat(concat_slow_fater_token) + + # print("!!!!!!!!!!!!") + + new_image_features.append(image_feature) + elif mm_newline_position == "frame": + # Frame-wise + image_feature = self.add_token_per_frame(image_feature) + + new_image_features.append(image_feature.flatten(0, 1)) + + elif mm_newline_position == "one_token": + # one-token + image_feature = image_feature.flatten(0, 1) + if 'unpad' in mm_patch_merge_type: + image_feature = torch.cat(( + image_feature, + self.model.image_newline[None].to(image_feature.device) + ), dim=0) + new_image_features.append(image_feature) + elif mm_newline_position == "no_token": + new_image_features.append(image_feature.flatten(0, 1)) + else: + raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}") + elif image_feature.shape[0] > 1: # multi patches and multi images operations + # rank0_print("Single-images") base_image_feature = image_feature[0] image_feature = image_feature[1:] height = width = self.get_vision_tower().num_patches_per_side assert height * width == base_image_feature.shape[0] - if image_aspect_ratio == "anyres": + + if "anyres_max" in image_aspect_ratio: + matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio) + if matched_anyres_max_num_patches: + max_num_patches = int(matched_anyres_max_num_patches.group(1)) + + if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: if hasattr(self.get_vision_tower(), "image_size"): vision_tower_image_size = self.get_vision_tower().image_size else: raise ValueError("vision_tower_image_size is not found in the vision tower.") - num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size) + try: + num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size) + except Exception as e: + rank0_print(f"Error: {e}") + num_patch_width, num_patch_height = 2, 2 image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) else: image_feature = image_feature.view(2, 2, height, width, -1) + if "maxpool2x2" in mm_patch_merge_type: image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) image_feature = nn.functional.max_pool2d(image_feature, 2) image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches: + unit = image_feature.shape[2] + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + c, h, w = image_feature.shape + times = math.sqrt(h * w / (max_num_patches * unit**2)) + if times > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0] + image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) elif "unpad" in mm_patch_merge_type: image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) @@ -219,11 +403,13 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio pass else: image_feature = torch.cat((base_image_feature, image_feature), dim=0) - else: + new_image_features.append(image_feature) + else: # single image operations image_feature = image_feature[0] if "unpad" in mm_patch_merge_type: image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0) - new_image_features.append(image_feature) + + new_image_features.append(image_feature) image_features = new_image_features else: raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") @@ -233,6 +419,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio # TODO: image start / end is not implemented here to support pretraining. if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False): raise NotImplementedError + # rank_print(f"Total images : {len(image_features)}") # Let's just add dummy tensors if they do not exist, # it is a headache to deal with None all the time. @@ -258,8 +445,10 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio new_input_embeds = [] new_labels = [] cur_image_idx = 0 + # rank_print("Inserting Images embedding") for batch_idx, cur_input_ids in enumerate(input_ids): num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + # rank0_print(num_images) if num_images == 0: cur_image_features = image_features[cur_image_idx] cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) @@ -286,13 +475,17 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio cur_new_input_embeds.append(cur_input_embeds_no_im[i]) cur_new_labels.append(cur_labels_noim[i]) if i < num_images: - cur_image_features = image_features[cur_image_idx] + try: + cur_image_features = image_features[cur_image_idx] + except IndexError: + cur_image_features = image_features[cur_image_idx - 1] cur_image_idx += 1 cur_new_input_embeds.append(cur_image_features) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + # import pdb; pdb.set_trace() cur_new_input_embeds = torch.cat(cur_new_input_embeds) cur_new_labels = torch.cat(cur_new_labels) @@ -301,9 +494,14 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio # Truncate sequences to max length as image embeddings can make the sequence longer tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) - if tokenizer_model_max_length is not None: - new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] - new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + # rank_print("Finishing Inserting") + + new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)] + new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)] + # TODO: Hard code for control loss spike + # if tokenizer_model_max_length is not None: + # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)] + # new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)] # Combine them max_len = max(x.shape[0] for x in new_input_embeds) @@ -313,6 +511,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + # rank0_print("Prepare pos id") for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): cur_len = cur_new_embed.shape[0] @@ -330,6 +529,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + # rank0_print("tokenizer padding") if _labels is None: new_labels = None @@ -343,7 +543,15 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio if _position_ids is None: position_ids = None - + if getattr(self.config, "use_pos_skipping", False) and self.training: + position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device) + split_position = random.randint(0, new_input_embeds.size(1)) + left_add = random.randint(0, self.config.pos_skipping_range) + right_add = random.randint(left_add, self.config.pos_skipping_range) + position_ids[:, :split_position] += left_add + position_ids[:, split_position:] += right_add + # import pdb; pdb.set_trace() + # rank0_print("Finish preparing") return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels def initialize_vision_tokenizer(self, model_args, tokenizer): diff --git a/llava/model/make_delta.py b/llava/model/make_delta.py old mode 100644 new mode 100755 diff --git a/llava/model/multimodal_encoder/builder.py b/llava/model/multimodal_encoder/builder.py old mode 100644 new mode 100755 index 1236b359b..0a7b1600f --- a/llava/model/multimodal_encoder/builder.py +++ b/llava/model/multimodal_encoder/builder.py @@ -1,7 +1,13 @@ import os from .clip_encoder import CLIPVisionTower +from .imagebind import ImageBindWrapper from .open_clip_encoder import OpenCLIPVisionTower +from .hf_vision import HFVisionTower +from .siglip_encoder import SigLipVisionTower from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 +from .mlcd_encoder import MLCDVisionTower, MLCDVisionTowerS2 +# from .eva_clip.eva_clip_encoder import EvaClipVisionTower +# from .dev_eva_clip.eva_vit import EvaViTWrapper def build_vision_tower(vision_tower_cfg, **kwargs): @@ -13,7 +19,23 @@ def build_vision_tower(vision_tower_cfg, **kwargs): return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) else: return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + elif "siglip" in vision_tower: + return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs) + elif vision_tower.startswith("hf:"): + return HFVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + elif vision_tower in ["imagebind_huge"]: + return ImageBindWrapper(vision_tower, args=vision_tower_cfg, **kwargs) elif vision_tower.startswith("open_clip_hub"): return OpenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + elif "mlcd-vit-bigG-patch14" in vision_tower: + if use_s2: + return MLCDVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) + else: + return MLCDVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + + # elif "internal-eva" in vision_tower.lower() or "eva02" in vision_tower.lower(): + # return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) + # elif vision_tower in ["EVA-CLIP-8B", "EVA-CLIP-8B-plus"]: + # return EvaViTWrapper(vision_tower, args=vision_tower_cfg, **kwargs) raise ValueError(f"Unknown vision tower: {vision_tower}") diff --git a/llava/model/multimodal_encoder/clip_encoder.py b/llava/model/multimodal_encoder/clip_encoder.py old mode 100644 new mode 100755 index 7462f9f1e..212b262b3 --- a/llava/model/multimodal_encoder/clip_encoder.py +++ b/llava/model/multimodal_encoder/clip_encoder.py @@ -152,21 +152,19 @@ def load_model(self, device_map=None): self.is_loaded = True - @torch.no_grad() def forward_feature(self, images): image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) image_features = self.feature_select(image_forward_outs).to(images.dtype) return image_features - @torch.no_grad() def forward(self, images): if type(images) is list: image_features = [] for image in images: - image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) + image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) image_features.append(image_feature) else: - image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) + image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) return image_features diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py new file mode 100755 index 000000000..ede690054 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/__init__.py @@ -0,0 +1,9 @@ +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss +from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100755 index 000000000..7b5088a52 Binary files /dev/null and b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/bpe_simple_vocab_16e6.txt.gz differ diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py new file mode 100755 index 000000000..a670bb3fa --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py new file mode 100755 index 000000000..23cb38c92 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/eva_vit_model.py @@ -0,0 +1,571 @@ +# -------------------------------------------------------- +# Adapted from https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import math +import os +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +except: + from timm.layers import drop_path, to_2tuple, trunc_normal_ + +from .transformer import PatchDropout +from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast + +if os.getenv("ENV_TYPE") == "deepspeed": + try: + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers.ops as xops +except ImportError: + xops = None + # print("Please 'pip install xformers'") + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0.0, + subln=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.ffn_ln(x) + + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLU(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, norm_layer=nn.LayerNorm, subln=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(in_features, hidden_features) + + self.act = act_layer() + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + self.w3 = nn.Linear(hidden_features, out_features) + + self.drop = nn.Dropout(drop) + + def forward(self, x): + x1 = self.w1(x) + x2 = self.w2(x) + hidden = self.act(x1) * x2 + x = self.ffn_ln(hidden) + x = self.w3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.subln = subln + if self.subln: + self.q_proj = nn.Linear(dim, all_head_dim, bias=False) + self.k_proj = nn.Linear(dim, all_head_dim, bias=False) + self.v_proj = nn.Linear(dim, all_head_dim, bias=False) + else: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity() + # self.proj = nn.Linear(all_head_dim, all_head_dim) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + self.rope = rope + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + B, N, C = x.shape + if self.subln: + q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) + k = F.linear(input=x, weight=self.k_proj.weight, bias=None) + v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) + + q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C + k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + else: + + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C + q, k, v = qkv[0], qkv[1], qkv[2] + + if self.rope: + # slightly fast impl + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] + ro_k_t = self.rope(k_t) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + + if self.xattn: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + x = xops.memory_efficient_attention( + q, + k, + v, + p=self.xattn_drop, + scale=self.scale, + ) + x = x.reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.relative_position_bias_table is not None: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0).type_as(attn) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias.type_as(attn) + + if attn_mask is not None: + attn_mask = attn_mask.bool() + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + xattn=False, + rope=None, + postnorm=False, + subln=False, + naiveswiglu=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if naiveswiglu: + self.mlp = SwiGLU( + in_features=dim, + hidden_features=mlp_hidden_dim, + subln=subln, + norm_layer=norm_layer, + ) + else: + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + self.postnorm = postnorm + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + if self.gamma_1 is None: + if self.postnorm: + x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + if self.postnorm: + x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class EVAVisionTransformer(nn.Module): + """Vision Transformer with support for patch or hybrid CNN input stage""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + init_values=None, + patch_dropout=0.0, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + rope=False, + use_mean_pooling=True, + init_scale=0.001, + grad_checkpointing=False, + xattn=False, + postnorm=False, + pt_hw_seq_len=16, + intp_freq=False, + naiveswiglu=False, + subln=False, + ): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + if rope: + half_head_dim = embed_dim // num_heads // 2 + hw_seq_len = img_size // patch_size + self.rope = VisionRotaryEmbeddingFast( + dim=half_head_dim, + pt_seq_len=pt_hw_seq_len, + ft_seq_len=hw_seq_len if intp_freq else None, + # patch_dropout=patch_dropout + ) + else: + self.rope = None + + self.naiveswiglu = naiveswiglu + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, + xattn=xattn, + rope=self.rope, + postnorm=postnorm, + subln=subln, + naiveswiglu=naiveswiglu, + ) + for i in range(depth) + ] + ) + self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + trunc_normal_(self.cls_token, std=0.02) + + self.apply(self._init_weights) + self.fix_init_weight() + + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=0.02) + self.head.weight.data.mul_(init_scale) + if self.head.bias is not None: + self.head.bias.data.mul_(init_scale) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity() + + self.grad_checkpointing = grad_checkpointing + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + if self.naiveswiglu: + rescale(layer.mlp.w3.weight.data, layer_id + 1) + else: + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_cast_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, return_all_features=False): + + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + # if os.getenv("RoPE") == "1": + # if self.training and not isinstance(self.patch_dropout, nn.Identity): + # x, patch_indices_keep = self.patch_dropout(x) + # self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep) + # else: + # self.rope.forward = partial(self.rope.forward, patch_indices_keep=None) + # x = self.patch_dropout(x) + # else: + x = self.patch_dropout(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing: + x = checkpoint(blk, x, (rel_pos_bias,)) + else: + x = blk(x, rel_pos_bias=rel_pos_bias) + + if not return_all_features: + x = self.norm(x) + if self.fc_norm is not None: + return self.fc_norm(x.mean(1)) + else: + return x[:, 0] + return x + + def forward(self, x, return_all_features=False): + if return_all_features: + return self.forward_features(x, return_all_features) + x = self.forward_features(x) + x = self.head(x) + return x diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py new file mode 100755 index 000000000..535eeacb7 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/factory.py @@ -0,0 +1,528 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Optional, Tuple, Union, Dict, Any +import torch + +try: + import deepspeed +except ImportError: + deepspeed = None + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, get_cast_dtype +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model +from .transform import image_transform +from .tokenizer import HFTokenizer, tokenize +from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed + + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + with open(cf, "r", encoding="utf8") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + config = get_model_config(model_name) + tokenizer = HFTokenizer(config["text_cfg"]["hf_tokenizer_name"]) if "hf_tokenizer_name" in config["text_cfg"] else tokenize + return tokenizer + + +# loading openai CLIP weights when is_openai=True for training +def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []): + if is_openai: + model = torch.jit.load(checkpoint_path, map_location="cpu").eval() + state_dict = model.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + else: + checkpoint = torch.load(checkpoint_path, map_location=map_location) + for mk in model_key.split("|"): + if isinstance(checkpoint, dict) and mk in checkpoint: + state_dict = checkpoint[mk] + break + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + for k in skip_list: + if k in list(state_dict.keys()): + logging.info(f"Removing key {k} from pretrained checkpoint") + del state_dict[k] + + if os.getenv("RoPE") == "1": + for k in list(state_dict.keys()): + if "freqs_cos" in k or "freqs_sin" in k: + del state_dict[k] + return state_dict + + +def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True): + state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False) + # detect old format and make compatible with new format + if "positional_embedding" in state_dict and not hasattr(model, "positional_embedding"): + state_dict = convert_to_custom_text_state_dict(state_dict) + if "text.logit_scale" in state_dict and hasattr(model, "logit_scale"): + state_dict["logit_scale"] = state_dict["text.logit_scale"] + del state_dict["text.logit_scale"] + + # resize_clip_pos_embed for CLIP and open CLIP + if "visual.positional_embedding" in state_dict: + resize_clip_pos_embed(state_dict, model) + # specified to eva_vit_model + elif "visual.pos_embed" in state_dict: + resize_evaclip_pos_embed(state_dict, model) + + # resize_clip_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}") + return incompatible_keys + + +def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + + for k in list(state_dict.keys()): + if not k.startswith("visual."): + del state_dict[k] + for k in list(state_dict.keys()): + if k.startswith("visual."): + new_k = k[7:] + state_dict[new_k] = state_dict[k] + del state_dict[k] + return state_dict + + +def load_clip_text_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + + for k in list(state_dict.keys()): + if k.startswith("visual."): + del state_dict[k] + return state_dict + + +def get_pretrained_tag(pretrained_model): + pretrained_model = pretrained_model.lower() + if "laion" in pretrained_model or "open_clip" in pretrained_model: + return "open_clip" + elif "openai" in pretrained_model: + return "clip" + elif "eva" in pretrained_model and "clip" in pretrained_model: + return "eva_clip" + else: + return "other" + + +def load_zero_partitions(model, state_dict, is_deepspeed_zero3_enabled, pretrained_model_path, ignore_mismatched_sizes=False): + """ + adept from pytorch lightning and transformers + with deepspeed.zero.Init(): + model = MyModel() + state_dict = torch.load(model_path, map_location="cpu") + load_zero_partitions(model, prefix="") + """ + + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + model_state_dict = model.state_dict() + expected_keys = list(model_state_dict.keys()) + loaded_keys = list(state_dict.keys()) + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape: + mismatched_keys.append((checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)) + del state_dict[checkpoint_key] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + if is_deepspeed_zero3_enabled: + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + load(model_to_load, prefix=start_prefix) + del state_dict + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + if len(unexpected_keys) > 0: + logging.warning( + f"Some weights of the model checkpoint at {pretrained_model_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logging.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logging.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logging.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join([f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" for key, shape1, shape2 in mismatched_keys]) + logging.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + +def load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=True, visual_model=None, text_model=None, model_key="model|module|state_dict", skip_list=[]): + visual_tag = get_pretrained_tag(visual_model) + text_tag = get_pretrained_tag(text_model) + + logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}") + visual_incompatible_keys, text_incompatible_keys = None, None + if visual_checkpoint_path: + if visual_tag == "eva_clip" or visual_tag == "open_clip": + visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list) + elif visual_tag == "clip": + visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list) + else: + visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) + + # resize_clip_pos_embed for CLIP and open CLIP + if "positional_embedding" in visual_state_dict: + resize_visual_pos_embed(visual_state_dict, model) + # specified to EVA model + elif "pos_embed" in visual_state_dict: + resize_eva_pos_embed(visual_state_dict, model) + + visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict) + logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}") + logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}") + + if text_checkpoint_path: + if text_tag == "eva_clip" or text_tag == "open_clip": + text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list) + elif text_tag == "clip": + text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list) + else: + text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list) + + text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict) + + logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}") + logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}") + + return visual_incompatible_keys, text_incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = "", + pretrained_text: str = "", + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == "openai": + logging.info(f"Loading pretrained {model_name} from OpenAI.") + model = load_openai_model( + model_name, + precision=precision, + device=device, + jit=jit, + cache_dir=cache_dir, + ) + else: + model_cfg = get_model_config(model_name) + if model_cfg is not None: + logging.info(f"Loaded {model_name} model config.") + else: + logging.error(f"Model config for {model_name} not found; available models {list_models()}.") + raise RuntimeError(f"Model config for {model_name} not found.") + + if "rope" in model_cfg.get("vision_cfg", {}): + if model_cfg["vision_cfg"]["rope"]: + os.environ["RoPE"] = "1" + else: + os.environ["RoPE"] = "0" + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + cast_dtype = get_cast_dtype(precision) + custom_clip = model_cfg.pop("custom_text", False) or force_custom_clip or ("hf_model_name" in model_cfg["text_cfg"]) + + if custom_clip: + if "hf_model_name" in model_cfg.get("text_cfg", {}): + model_cfg["text_cfg"]["hf_model_pretrained"] = pretrained_hf + model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + pretrained_cfg = {} + if pretrained: + checkpoint_path = "" + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f"Loading pretrained {model_name} weights ({pretrained}).") + load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=False) + else: + error_str = f"Pretrained weights ({pretrained}) not found for model {model_name}." f"Available pretrained tags ({list_pretrained_tags_by_model(model_name)}." + logging.warning(error_str) + raise RuntimeError(error_str) + else: + visual_checkpoint_path = "" + text_checkpoint_path = "" + + if pretrained_image: + pretrained_visual_model = pretrained_visual_model.replace("/", "-") # for callers using old naming with / in ViT names + pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image) + if "timm_model_name" in model_cfg.get("vision_cfg", {}): + # pretrained weight loading for timm models set via vision_cfg + model_cfg["vision_cfg"]["timm_model_pretrained"] = True + elif pretrained_image_cfg: + visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained_image): + visual_checkpoint_path = pretrained_image + else: + logging.warning(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.") + raise RuntimeError(f"Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.") + + if pretrained_text: + pretrained_text_model = pretrained_text_model.replace("/", "-") # for callers using old naming with / in ViT names + pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text) + if pretrained_image_cfg: + text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained_text): + text_checkpoint_path = pretrained_text + else: + logging.warning(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.") + raise RuntimeError(f"Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.") + + if visual_checkpoint_path: + logging.info(f"Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).") + if text_checkpoint_path: + logging.info(f"Loading pretrained {model_name}.text weights ({text_checkpoint_path}).") + + if visual_checkpoint_path or text_checkpoint_path: + load_pretrained_checkpoint(model, visual_checkpoint_path, text_checkpoint_path, strict=False, visual_model=pretrained_visual_model, text_model=pretrained_text_model, model_key="model|module|state_dict", skip_list=skip_list) + + if "fp16" in precision or "bf16" in precision: + logging.info(f"convert precision to {precision}") + model = model.to(torch.bfloat16) if "bf16" in precision else model.to(torch.float16) + + # model.to(device=device) + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get("mean", None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get("std", None) or OPENAI_DATASET_STD + + if jit: + model = torch.jit.script(model) + + return model + + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + pretrained_image: str = "", + pretrained_text: str = "", + pretrained_hf: bool = True, + pretrained_visual_model: str = None, + pretrained_text_model: str = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + skip_list: list = [], +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + pretrained_image=pretrained_image, + pretrained_text=pretrained_text, + pretrained_hf=pretrained_hf, + pretrained_visual_model=pretrained_visual_model, + pretrained_text_model=pretrained_text_model, + cache_dir=cache_dir, + skip_list=skip_list, + ) + + image_mean = image_mean or getattr(model.visual, "image_mean", None) + image_std = image_std or getattr(model.visual, "image_std", None) + preprocess_train = image_transform(model.visual.image_size, is_train=True, mean=image_mean, std=image_std) + preprocess_val = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: str, + precision: str = "fp32", + device: Union[str, torch.device] = "cpu", + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_clip: bool = False, + force_patch_dropout: Optional[float] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, + is_frozen: bool = False, +): + if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained): + raise RuntimeError(f"{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}." f" Use open_clip.list_pretrained() to find one.") + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_clip=force_custom_clip, + force_patch_dropout=force_patch_dropout, + cache_dir=cache_dir, + ) + + if is_frozen: + for param in model.parameters(): + param.requires_grad = False + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, "image_mean", None) + image_std = image_std or getattr(model.visual, "image_std", None) + preprocess = image_transform(model.visual.image_size, is_train=False, mean=image_mean, std=image_std) + + return model, preprocess diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py new file mode 100755 index 000000000..ddd2c672f --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_configs.py @@ -0,0 +1,57 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings", + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings", + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens", + }, + "pooler": "mean_pooler", + }, + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings", + }, + "pooler": "mean_pooler", + }, +} diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py new file mode 100755 index 000000000..a156624ba --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/hf_model.py @@ -0,0 +1,240 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" + +import re + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + class BaseModelOutput: + pass + + class PretrainedConfig: + pass + + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r"(? TensorType: + # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device) + # attn_mask = (x != self.config.pad_token_id).long() + # out = self.transformer( + # input_ids=x, + # attention_mask=attn_mask, + # encoder_hidden_states = image_embeds, + # encoder_attention_mask = image_atts, + # ) + # pooled_out = self.pooler(out, attn_mask) + + # return self.itm_proj(pooled_out) + + def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): + if masked_indices is None: + masked_indices = torch.bernoulli(probability_matrix).bool() + + masked_indices[input_ids == self.tokenizer.pad_token_id] = False + masked_indices[input_ids == self.tokenizer.cls_token_id] = False + + if targets is not None: + targets[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices + input_ids[indices_replaced] = self.tokenizer.mask_token_id + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) + input_ids[indices_random] = random_words[indices_random] + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + + if targets is not None: + return input_ids, targets + else: + return input_ids + + def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25): + labels = input_ids.clone() + attn_mask = (input_ids != self.config.pad_token_id).long() + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device) + vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"]) + probability_matrix = torch.full(labels.shape, mlm_probability) + input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix) + mlm_output = self.transformer( + input_ids, + attention_mask=attn_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + labels=labels, + ) + return mlm_output.loss + # mlm_output = self.transformer(input_ids, + # attention_mask = attn_mask, + # encoder_hidden_states = image_embeds, + # encoder_attention_mask = image_atts, + # return_dict = True, + # ).last_hidden_state + # logits = self.mlm_proj(mlm_output) + + # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size) + # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size) + # labels = labels[:, 1:].contiguous().view(-1) + + # mlm_loss = F.cross_entropy( + # logits, + # labels, + # # label_smoothing=0.1, + # ) + # return mlm_loss + + def forward(self, x: TensorType) -> TensorType: + attn_mask = (x != self.config.pad_token_id).long() + out = self.transformer(input_ids=x, attention_mask=attn_mask) + pooled_out = self.pooler(out, attn_mask) + + return self.proj(pooled_out) + + def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + if not unlocked_layers: # full freezing + for n, p in self.transformer.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + return + + encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer + layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) + print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") + embeddings = getattr(self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) + modules = [embeddings, *layer_list][:-unlocked_layers] + # freeze layers + for module in modules: + for n, p in module.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.gradient_checkpointing_enable() + + def get_num_layers(self): + encoder = self.transformer.encoder if hasattr(self.transformer, "encoder") else self.transformer + layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) + return len(layer_list) + + def init_parameters(self): + pass diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py new file mode 100755 index 000000000..ec4c99507 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/loss.py @@ -0,0 +1,123 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F + +try: + import torch.distributed.nn + from torch import distributed as dist + + has_distributed = True +except ImportError: + has_distributed = False + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from timm.loss import LabelSmoothingCrossEntropy + + +def gather_features(image_features, text_features, local_loss=False, gather_with_grad=False, rank=0, world_size=1, use_horovod=False): + assert has_distributed, "torch.distributed did not import correctly, please use a PyTorch version with support." + if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) + gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) + all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) + # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) + # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) + else: + gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] + gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class ClipLoss(nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + smoothing=0.0, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward(self, image_features, text_features, logit_scale=1.0): + device = image_features.device + if self.world_size > 1: + all_image_features, all_text_features = gather_features(image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + # calculated ground-truth and cache if enabled + num_logits = logits_per_image.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if self.label_smoothing_cross_entropy: + total_loss = (self.label_smoothing_cross_entropy(logits_per_image, labels) + self.label_smoothing_cross_entropy(logits_per_text, labels)) / 2 + else: + total_loss = (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels)) / 2 + + acc = None + i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) + t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) + acc = {"i2t": i2t_acc, "t2i": t2i_acc} + return total_loss, acc diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py new file mode 100755 index 000000000..b2f3a2317 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model.py @@ -0,0 +1,429 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from functools import partial + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +try: + from .hf_model import HFTextEncoder +except: + HFTextEncoder = None +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .eva_vit_model import EVAVisionTransformer +from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer + +try: + from apex.normalization import FusedLayerNorm +except: + FusedLayerNorm = LayerNorm + # print("Please 'pip install apex'") + +try: + import xformers.ops as xops +except ImportError: + xops = None + # print("Please 'pip install xformers'") + + +class RMSnorm(nn.Module): + """ + adepted from transformers T5LayerNorm + """ + + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + drop_path_rate: Optional[float] = None # drop path rate + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size + qkv_bias: bool = True + fusedLN: bool = False + xattn: bool = False + postnorm: bool = False + rope: bool = False + pt_hw_seq_len: int = 16 # 224/14 + intp_freq: bool = False + naiveswiglu: bool = False + subln: bool = False + use_rms_norm: bool = False + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = "mlp" + pooler_type: str = "mean_pooler" + masked_language_modeling: bool = False + fusedLN: bool = False + xattn: bool = False + attn_mask: bool = True + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == "bf16": + cast_dtype = torch.bfloat16 + elif precision == "fp16": + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower(embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.eva_model_name: + vision_heads = vision_cfg.width // vision_cfg.head_width + + norm_layer = RMSnorm if vision_cfg.use_rms_norm else LayerNorm + + visual = EVAVisionTransformer( + img_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + num_classes=embed_dim, + use_mean_pooling=vision_cfg.global_average_pool, # False + init_values=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + embed_dim=vision_cfg.width, + depth=vision_cfg.layers, + num_heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + qkv_bias=vision_cfg.qkv_bias, + drop_path_rate=vision_cfg.drop_path_rate, + norm_layer=partial(norm_layer, eps=1e-6), + xattn=vision_cfg.xattn, + rope=vision_cfg.rope, + postnorm=vision_cfg.postnorm, + pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14 + intp_freq=vision_cfg.intp_freq, + naiveswiglu=vision_cfg.naiveswiglu, + subln=vision_cfg.subln, + ) + elif vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, embed_dim=embed_dim, image_size=vision_cfg.image_size + ) + act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet(layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + global_average_pool=vision_cfg.global_average_pool, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder(text_cfg.hf_model_name, output_dim=embed_dim, tokenizer_name=text_cfg.hf_tokenizer_name, proj=text_cfg.proj, pooler_type=text_cfg.pooler_type, masked_language_modeling=text_cfg.masked_language_modeling) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=FusedLayerNorm if text_cfg.fusedLN else norm_layer, + xattn=text_cfg.xattn, + attn_mask=text_cfg.attn_mask, + ) + return text + + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer("attn_mask", text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {"logit_scale"} + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +class CustomCLIP(nn.Module): + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + itm_task: bool = False, + ): + super().__init__() + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + @torch.jit.ignore + def no_weight_decay(self): + return {"logit_scale"} + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr, None) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, nn.Parameter): + l.data = l.data.to(dtype) + + for name in ["text_projection", "proj"]: + if hasattr(l, name) and isinstance(l, nn.Parameter): + attr = getattr(l, name, None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if "text_projection" in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ("text_projection", "positional_embedding", "token_embedding", "transformer", "ln_final", "logit_scale")): + k = "text." + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg(context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device("cpu")): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module(model, inputs=dict(forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,))) + model.visual.image_size = image_size + return model diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json new file mode 100755 index 000000000..491755669 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-18B.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1536, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 5120, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-18b-14-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": true, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json new file mode 100755 index 000000000..7d843daa3 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B-plus.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 448, + "layers": 32, + "width": 4096, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-8b-14-plus-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": false, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json new file mode 100755 index 000000000..689492a25 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA-CLIP-8B.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 4096, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-8b-14-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": false, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json new file mode 100755 index 000000000..aad205800 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-B-16.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16, + "eva_model_name": "eva-clip-b-16", + "ls_init_value": 0.1, + "drop_path_rate": 0.0 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json new file mode 100755 index 000000000..100279572 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json new file mode 100755 index 000000000..5d338b4e6 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA01-CLIP-g-14.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0.4, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json new file mode 100755 index 000000000..e4a6e723f --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-B-16.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "head_width": 64, + "patch_size": 16, + "mlp_ratio": 2.6667, + "eva_model_name": "eva-clip-b-16-X", + "drop_path_rate": 0.0, + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "xattn": true, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json new file mode 100755 index 000000000..3e1d124e1 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14-336.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14-336", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json new file mode 100755 index 000000000..03b22ad3c --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-L-14.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json new file mode 100755 index 000000000..aa04e2545 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json new file mode 100755 index 000000000..747ffccc8 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/EVA02-CLIP-bigE-14.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json new file mode 100755 index 000000000..ad71aff86 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": 77, + "width": 2304, + "head_width": 144, + "mlp_ratio": 10.9722, + "patch_size": 14, + "eva_model_name": "eva-clip-10b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": false, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json new file mode 100755 index 000000000..21b206807 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 77, + "width": 2304, + "head_width": 144, + "mlp_ratio": 10.9722, + "patch_size": 14, + "eva_model_name": "eva-clip-10b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": false, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py new file mode 100755 index 000000000..9f29f8454 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/modified_resnet.py @@ -0,0 +1,179 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion))])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0.0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py new file mode 100755 index 000000000..9fbf6fc7e --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/openai.py @@ -0,0 +1,144 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag("openai") + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = "fp32" if device == "cpu" else "fp16" + + if get_pretrained_url(name, "openai"): + model_path = download_pretrained_from_url(get_pretrained_url(name, "openai"), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith("amp") or precision == "fp32": + model.float() + elif precision == "bf16": + convert_weights_to_lp(model, dtype=torch.bfloat16) + + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 (typically for CPU) + if precision == "fp32": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py new file mode 100755 index 000000000..a603b6047 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/pretrained.py @@ -0,0 +1,314 @@ +import hashlib +import os +import urllib +import warnings +from typing import Dict, Union + +from tqdm import tqdm + +try: + from huggingface_hub import hf_hub_download + + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url="", hf_hub="", filename="", mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_VITB32 = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-B-32-laion2B-s34B-b79K/"), +) + +_VITB32_quickgelu = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-B-16-laion2B-s34B-b88K/"), +) + +_EVAB16 = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"), + eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt"), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg("https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg(hf_hub="laion/CLIP-ViT-L-14-laion2B-s32B-b82K/", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +) + +_EVAL14 = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"), + eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_L_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt"), +) + +_VITL14_336 = dict( + openai=_pcfg("https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_EVAL14_336 = dict( + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt"), + eva_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"), + eva02_clip_224to336=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub="laion/CLIP-ViT-H-14-laion2B-s32B-b79K/"), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s12B-b42K/"), + laion2b_s34b_b88k=_pcfg(hf_hub="laion/CLIP-ViT-g-14-laion2B-s34B-b88K/"), +) + +_EVAg14 = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"), + eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"), + eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt"), +) + +_EVAg14_PLUS = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/"), + eva01=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_g_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"), + eva01_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt"), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/"), +) + +_EVAbigE14 = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), + eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt"), +) + +_EVAbigE14_PLUS = dict( + eva=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), + eva02=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_E_psz14.pt"), + eva_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"), + eva02_clip=_pcfg(hf_hub="QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt"), +) + +_EVA_8B = dict( + eva=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_8B_psz14.bin"), + eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B/EVA_CLIP_8B_psz14_s9B.pt"), +) + +_EVA_8B_PLUS = dict( + eva_clip=_pcfg(hf_hub="BAAI/EVA-CLIP-8B-448/EVA_CLIP_8B_psz14_plus_s0.6B.pt"), +) + + +_PRETRAINED = { + # "ViT-B-32": _VITB32, + "OpenaiCLIP-B-32": _VITB32, + "OpenCLIP-B-32": _VITB32, + # "ViT-B-32-quickgelu": _VITB32_quickgelu, + "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu, + "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu, + # "ViT-B-16": _VITB16, + "OpenaiCLIP-B-16": _VITB16, + "OpenCLIP-B-16": _VITB16, + "EVA02-B-16": _EVAB16, + "EVA02-CLIP-B-16": _EVAB16, + # "ViT-B-16-plus-240": _VITB16_PLUS_240, + "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240, + # "ViT-L-14": _VITL14, + "OpenaiCLIP-L-14": _VITL14, + "OpenCLIP-L-14": _VITL14, + "EVA02-L-14": _EVAL14, + "EVA02-CLIP-L-14": _EVAL14, + # "ViT-L-14-336": _VITL14_336, + "OpenaiCLIP-L-14-336": _VITL14_336, + "EVA02-CLIP-L-14-336": _EVAL14_336, + # "ViT-H-14": _VITH14, + # "ViT-g-14": _VITg14, + "OpenCLIP-H-14": _VITH14, + "OpenCLIP-g-14": _VITg14, + "EVA01-CLIP-g-14": _EVAg14, + "EVA01-CLIP-g-14-plus": _EVAg14_PLUS, + # "ViT-bigG-14": _VITbigG14, + "OpenCLIP-bigG-14": _VITbigG14, + "EVA02-CLIP-bigE-14": _EVAbigE14, + "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS, + "EVA-CLIP-8B": _EVA_8B, + "EVA-CLIP-8B-448": _EVA_8B_PLUS, + "EVA-CLIP-8B-plus": _EVA_8B_PLUS, +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace("-", "_") + + +def list_pretrained(as_str: bool = False): + """returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [":".join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """return all models having the specified pretrain tag""" + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """return all pretrain tags for the specified model architecture""" + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get("url", "") + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if "openaipublic" in url: + expected_sha256 = url.split("/")[-2] + elif "mlfoundations" in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = "" + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError("Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.") + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = "open_clip_pytorch_model.bin", + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = "" + if not cfg: + return target + + download_url = cfg.get("url", "") + download_hf_hub = cfg.get("hf_hub", "") + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = "" + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py new file mode 100755 index 000000000..5fb3cce54 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/rope.py @@ -0,0 +1,131 @@ +from math import pi +import torch +from torch import nn +from einops import rearrange, repeat +import logging + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum("..., f -> ... f", t, freqs) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + + freqs_w = torch.einsum("..., f -> ... f", t, freqs) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") + + def forward(self, t, start_index=0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + + return torch.cat((t_left, t, t_right), dim=-1) + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.patch_dropout = patch_dropout + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") + + def forward(self, t, patch_indices_keep=None): + if patch_indices_keep is not None: + batch = t.size()[0] + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) + freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) + + freqs_cos = freqs_cos[batch_indices, patch_indices_keep] + freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") + freqs_sin = freqs_sin[batch_indices, patch_indices_keep] + freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") + + return t * freqs_cos + rotate_half(t) * freqs_sin + + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py new file mode 100755 index 000000000..65de78df2 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/timm_model.py @@ -0,0 +1,114 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" + +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__(self, model_name, embed_dim, image_size=224, pool="avg", proj="linear", proj_bias=False, drop=0.0, pretrained=False): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get("pool_size", None) + feature_ndim = 1 if not feat_size else 2 + if pool in ("abs_attn", "rot_attn"): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool="") + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == "abs_attn": + head_layers["pool"] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, "projection layer needed if non-attention pooling is used." + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError("Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`") + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning("grad checkpointing not supported for this timm image tower, continuing without...") + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py new file mode 100755 index 000000000..5f753e69b --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/tokenizer.py @@ -0,0 +1,205 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + if not special_tokens: + special_tokens = ["", ""] + else: + special_tokens = ["", ""] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace").replace("", " ") + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, : len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + "HuggingFace tokenizer wrapper" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer(texts, return_tensors="pt", max_length=context_length, padding="max_length", truncation=True).input_ids + return input_ids diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py new file mode 100755 index 000000000..8cad45a16 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transform.py @@ -0,0 +1,104 @@ +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == "min" else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill) + return img + + +def _convert_to_rgb(image): + return image.convert("RGB") + + +# class CatGen(nn.Module): +# def __init__(self, num=4): +# self.num = num +# def mixgen_batch(image, text): +# batch_size = image.shape[0] +# index = np.random.permutation(batch_size) + +# cat_images = [] +# for i in range(batch_size): +# # image mixup +# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] +# # text concat +# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] +# text = torch.stack(text) +# return image, text + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose( + [ + RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + else: + if resize_longest_max: + transforms = [ResizeMaxSize(image_size, fill=fill_color)] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend( + [ + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + return Compose(transforms) diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py new file mode 100755 index 000000000..bd5ce4b6e --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/transformer.py @@ -0,0 +1,683 @@ +import os +import logging +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +try: + from timm.models.layers import trunc_normal_ +except: + from timm.layers import trunc_normal_ + +from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast +from .utils import to_2tuple + +if os.getenv("ENV_TYPE") == "deepspeed": + try: + import deepspeed + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + print("Please 'pip install deepspeed'") + deepspeed = None + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers.ops as xops +except ImportError: + xops = None + # print("Please 'pip install xformers'") + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor): + output = F.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(x) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1.0 + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}") + + def forward(self, x): + if not self.training or self.prob == 0.0: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + if self.training and os.getenv("RoPE") == "1": + return x, patch_indices_keep + + return x + + +def _in_projection_packed( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: Optional[torch.Tensor] = None, +): + """ + https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726 + """ + E = q.size(-1) + if k is v: + if q is k: + # self-attention + return F.linear(q, w, b).chunk(3, dim=-1) + else: + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1) + else: + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=False, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False, rope=False): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + self.rope = rope + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + if self.xattn: + q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1) + + x = xops.memory_efficient_attention( + q, + k, + v, + p=self.xattn_drop, + scale=self.scale if self.logit_scale is None else None, + attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None, + ) + else: + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class CustomAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=True, scaled_cosine=True, scale_heads=False, logit_scale_max=math.log(1.0 / 0.01), attn_drop=0.0, proj_drop=0.0, xattn=False): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias) + N_q, B_q, C_q = q.shape + N_k, B_k, C_k = k.shape + N_v, B_v, C_v = v.shape + if self.xattn: + # B, N, C -> B, N, num_heads, C + q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1) + k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1) + v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1) + + x = xops.memory_efficient_attention(q, k, v, p=self.xattn_drop, scale=self.scale if self.logit_scale is None else None, attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None) + else: + # B*H, L, C + q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + # B*H, N_q, N_k + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale + attn = attn.view(-1, N_q, N_k) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + + if self.head_scale is not None: + x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale + x = x.view(-1, N_q, C_q) + x = x.transpose(0, 1).reshape(N_q, B_q, C_q) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + cross_attn: bool = False, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1 + self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1 + self.attn = CustomAttention(d_model, n_head, qkv_bias=True, attn_drop=0.0, proj_drop=0.0, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, xattn=xattn) + + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("ln", norm_layer(mlp_width) if scale_fc else nn.Identity()), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))])) + + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask))) + q = q + self.ls_2(self.mlp(self.ln_2(q))) + return q + + +class CustomTransformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = True, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + cross_attn: bool = False, + xattn: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + self.xattn = xattn + + self.resblocks = nn.ModuleList( + [ + CustomResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + scale_cosine_attn=scale_cosine_attn, + scale_heads=scale_heads, + scale_attn=scale_attn, + scale_fc=scale_fc, + cross_attn=cross_attn, + xattn=xattn, + ) + for _ in range(layers) + ] + ) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None): + if k is None and v is None: + k = v = q + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + q = checkpoint(r, q, k, v, attn_mask) + else: + q = r(q, k, v, attn_mask=attn_mask) + return q + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + if xattn: + self.attn = Attention(d_model, n_head, xattn=True) + else: + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model))])) + + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + self.xattn = xattn + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + if self.xattn: + return self.attn(x, attn_mask=attn_mask) + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) for _ in range(layers)]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + patch_dropout: float = 0.0, + global_average_pool: bool = False, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + ): + super().__init__() + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1]) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity() + self.ln_pre = norm_layer(width) + + self.transformer = Transformer(width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) + + self.global_average_pool = global_average_pool + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def get_num_layers(self): + return self.transformer.layers + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {"positional_embedding", "class_embedding"} + + def forward(self, x: torch.Tensor, return_all_features: bool = False): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if not return_all_features: + if self.global_average_pool: + x = x.mean(dim=1) # x = x[:,1:,:].mean(dim=1) + else: + x = x[:, 0] + + x = self.ln_post(x) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class TextTransformer(nn.Module): + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + xattn: bool = False, + attn_mask: bool = True, + ): + super().__init__() + self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width)) + self.transformer = Transformer(width=width, layers=layers, heads=heads, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn) + + self.xattn = xattn + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if attn_mask: + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) + else: + self.attn_mask = None + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + # return {'positional_embedding', 'token_embedding'} + return {"positional_embedding"} + + def get_num_layers(self): + return self.transformer.layers + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, text, return_all_features: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + # x = self.transformer(x) # no attention mask is applied + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if not return_all_features: + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return x diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py new file mode 100755 index 000000000..73b6d6627 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_clip/utils.py @@ -0,0 +1,321 @@ +from itertools import repeat +import collections.abc +import logging +import math +import numpy as np + +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import torch.nn.functional as F + + +# open CLIP +def resize_clip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get("visual.positional_embedding", None) + if old_pos_embed is None or not hasattr(model.visual, "grid_size"): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict["visual.positional_embedding"] = new_pos_embed + + +def resize_visual_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get("positional_embedding", None) + if old_pos_embed is None or not hasattr(model.visual, "grid_size"): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info("Resizing position embedding grid-size from %s to %s", old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict["positional_embedding"] = new_pos_embed + + +def resize_evaclip_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + all_keys = list(state_dict.keys()) + # interpolate position embedding + if "visual.pos_embed" in state_dict: + pos_embed_checkpoint = state_dict["visual.pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict["visual.pos_embed"] = new_pos_embed + + patch_embed_proj = state_dict["visual.patch_embed.proj.weight"] + patch_size = model.visual.patch_embed.patch_size + state_dict["visual.patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False) + + +def resize_eva_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + all_keys = list(state_dict.keys()) + # interpolate position embedding + if "pos_embed" in state_dict: + pos_embed_checkpoint = state_dict["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + # num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + num_extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict["pos_embed"] = new_pos_embed + + patch_embed_proj = state_dict["patch_embed.proj.weight"] + patch_size = model.visual.patch_embed.patch_size + state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False) + + +def resize_rel_pos_embed(state_dict, model, interpolation: str = "bicubic", seq_dim=1): + all_keys = list(state_dict.keys()) + for key in all_keys: + if "relative_position_index" in key: + state_dict.pop(key) + + if "relative_position_bias_table" in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = model.visual.state_dict()[key].size() + dst_patch_shape = model.visual.patch_embed.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + if src_size != dst_size: + print("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size)) + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + print("Original positions = %s" % str(x)) + print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() + f = F.interpolate.interp2d(x, y, z, kind="cubic") + all_rel_pos_bias.append(torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + state_dict[key] = new_rel_pos_bias + + # interpolate position embedding + if "pos_embed" in state_dict: + pos_embed_checkpoint = state_dict["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.visual.patch_embed.num_patches + num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict["pos_embed"] = new_pos_embed + + patch_embed_proj = state_dict["patch_embed.proj.weight"] + patch_size = model.visual.patch_embed.patch_size + state_dict["patch_embed.proj.weight"] = torch.nn.functional.interpolate(patch_embed_proj.float(), size=patch_size, mode="bicubic", align_corners=False) + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = ".".join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + + +def is_logging(args): + def is_global_master(args): + return args.rank == 0 + + def is_local_master(args): + return args.local_rank == 0 + + def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + return is_master + + +class AllGather(torch.autograd.Function): + """An autograd function that performs allgather on a tensor. + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + + @staticmethod + def forward(ctx, tensor, rank, world_size): + tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(tensors_gather, tensor) + ctx.rank = rank + ctx.batch_size = tensor.shape[0] + return torch.cat(tensors_gather, 0) + + @staticmethod + def backward(ctx, grad_output): + return (grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], None, None) + + +allgather = AllGather.apply diff --git a/llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py b/llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py new file mode 100755 index 000000000..9bd746553 --- /dev/null +++ b/llava/model/multimodal_encoder/dev_eva_clip/eva_vit.py @@ -0,0 +1,141 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +# not tested yet +import math +from transformers import CLIPImageProcessor + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from .eva_clip import create_model_and_transforms, get_model_config +import torch +import torchvision +import time + +from llava.utils import rank0_print + + +class EvaViTWrapper(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + self.vision_tower_name = vision_tower + self.pretrained = args.vision_tower_pretrained + self.args = args + + self.select_layer = args.mm_vision_select_layer + if self.select_layer < -1: + self.select_layer += 1 + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + self.model_config = get_model_config(self.vision_tower_name) + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + + def load_model(self): + rank0_print(f"Loading: {self.vision_tower_name}") + rank0_print(f"Pretrained: {self.pretrained}") + time_start = time.time() + model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16") + time_end = time.time() + rank0_print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s") + self.device = next(model.parameters()).device + self.dtype = next(model.parameters()).dtype + if self.device.type != "meta": + model = model.to("cuda") + self.vision_tower = model.visual + resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] + normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] + self.resize_transform_size = resize_transform.size + self.image_processor = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-large-patch14", + crop_size=resize_transform.size, + size={"shortest_edge": resize_transform.size}, + image_mean=list(normalize_transform.mean), + image_std=list(normalize_transform.std), + ) + rank0_print(f"Loaded image processor: {self.image_processor}") + self.vision_tower.requires_grad_(False) + self.is_loaded = True + + def feature_select(self, image_features): + select_feature_type = self.select_feature + + # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + # select_every_k_layer = len(image_features) // 4 + # image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1) + # select_feature_type = select_feature_type.replace("slicefour_", "") + # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: + # select_layers = [-1, -4, -7, -10, 6] + # image_features = torch.cat([image_features[i] for i in select_layers], dim=-1) + # select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") + # else: + # image_features = image_features[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + return image_features + + def train(self, mode=True): + self.training = mode + + if self.is_loaded: + self.vision_tower.eval() + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_features = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True) + image_features = self.feature_select(image_features).to(self.dtype) + image_features.append(image_features) + else: + image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True) + image_features = self.feature_select(image_features).to(self.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def hidden_size(self): + return self.model_config["vision_cfg"]["width"] + + @property + def num_patches(self): + return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2 + + @property + def num_patches_per_side(self): + return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] + + @property + def config(self): + return self.model_config + + @property + def image_size(self): + return self.model_config["vision_cfg"]["image_size"] diff --git a/llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py b/llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py new file mode 100755 index 000000000..aaef516a9 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/eva_clip_encoder.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + +from .eva_clip_processors import EvaClipImageTrainProcessor +from .eva_vit import EVAEncoderWrapper +from .factory import list_models, add_model_config, get_model_config + +from llava.utils import rank0_print + + +class EvaClipVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + self.vision_tower_name = vision_tower + self.vision_tower_pretrained = args.vision_tower_pretrained + self.config = get_model_config(vision_tower) + + if not delay_load: + rank0_print(f"Loading EVA ViT: {self.vision_tower_name}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + else: + self.cfg_only = self.config + + def load_model(self, device_map=None): + rank0_print(f"Pretrained: {self.vision_tower_pretrained}") + self.image_processor = EvaClipImageTrainProcessor(self.config["vision_cfg"]["image_size"]) + self.vision_tower = EVAEncoderWrapper(self.vision_tower_pretrained, self.config) + rank0_print(f"Loaded image processor: {self.image_processor}") + self.vision_tower.requires_grad_(False) + self.is_loaded = True + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype) + image_features.append(image_feature) + else: + image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) + + return image_features + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def hidden_size(self): + return self.config["vision_cfg"]["width"] + + @property + def num_patches(self): + return (self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"]) ** 2 + + @property + def num_patches_per_side(self): + return self.config["vision_cfg"]["image_size"] // self.config["vision_cfg"]["patch_size"] + + @property + def image_size(self): + return self.config["vision_cfg"]["image_size"] diff --git a/llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py b/llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py new file mode 100755 index 000000000..7ee127315 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py @@ -0,0 +1,72 @@ +""" +# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP +""" + +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from transformers.image_processing_utils import BatchFeature +from PIL import Image +from transformers.image_transforms import convert_to_rgb + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + +class EvaClipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean + self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std + + self.normalize = transforms.Normalize(self.mean, self.std) + + @property + def image_mean(self): + return self.mean + + +class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + convert_to_rgb, + transforms.Resize( + image_size, + interpolation=InterpolationMode.BICUBIC, + ), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + self.normalize, + ] + ) + + self.image_size = image_size + + def preprocess(self, images, return_tensors): + if isinstance(images, Image.Image): + images = [images] + else: + assert isinstance(images, list) + + transformed_images = [self.transform(image).numpy() for image in images] + data = {"pixel_values": transformed_images} + + return BatchFeature(data=data, tensor_type=return_tensors) + + def __call__(self, item): + return self.transform(item) + + @property + def crop_size(self): + return {"height": self.image_size, "width": self.image_size} + + @property + def size(self): + return {"shortest_edge": self.image_size} diff --git a/llava/model/multimodal_encoder/eva_clip/eva_vit.py b/llava/model/multimodal_encoder/eva_clip/eva_vit.py new file mode 100755 index 000000000..d2330c322 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/eva_vit.py @@ -0,0 +1,856 @@ +""" +# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP +""" + +from math import pi +import torch +from torch import nn +from einops import rearrange, repeat +import logging +from llava.utils import rank0_print + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__(self, dim, pt_seq_len, ft_seq_len=None, custom_freqs=None, freqs_for="lang", theta=10000, max_freq=10, num_freqs=1, patch_dropout=0.0): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.patch_dropout = patch_dropout + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + logging.info(f"Shape of rope freq: {self.freqs_cos.shape}") + + def forward(self, t, patch_indices_keep=None): + if patch_indices_keep is not None: + batch = t.size()[0] + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + freqs_cos = repeat(self.freqs_cos, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) + freqs_sin = repeat(self.freqs_sin, "i j -> n i m j", n=t.shape[0], m=t.shape[1]) + + freqs_cos = freqs_cos[batch_indices, patch_indices_keep] + freqs_cos = rearrange(freqs_cos, "n i m j -> n m i j") + freqs_sin = freqs_sin[batch_indices, patch_indices_keep] + freqs_sin = rearrange(freqs_sin, "n i m j -> n m i j") + + return t * freqs_cos + rotate_half(t) * freqs_sin + + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1.0 + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}") + + def forward(self, x): + if not self.training or self.prob == 0.0: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + if self.training and os.getenv("RoPE") == "1": + return x, patch_indices_keep + + return x + + +# -------------------------------------------------------- +# Adapted from https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- +import math +import os +import torch.nn as nn +import torch.nn.functional as F + +try: + from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +except: + from timm.layers import drop_path, to_2tuple, trunc_normal_ + +if os.getenv("ENV_TYPE") == "deepspeed": + try: + from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint + except: + from torch.utils.checkpoint import checkpoint +else: + from torch.utils.checkpoint import checkpoint + +try: + import xformers.ops as xops +except ImportError: + xops = None + # print("Please 'pip install xformers'") + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + drop=0.0, + subln=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.ffn_ln(x) + + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLU(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.0, norm_layer=nn.LayerNorm, subln=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w1 = nn.Linear(in_features, hidden_features) + self.w2 = nn.Linear(in_features, hidden_features) + + self.act = act_layer() + self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() + self.w3 = nn.Linear(hidden_features, out_features) + + self.drop = nn.Dropout(drop) + + def forward(self, x): + x1 = self.w1(x) + x2 = self.w2(x) + hidden = self.act(x1) * x2 + x = self.ffn_ln(hidden) + x = self.w3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.subln = subln + if self.subln: + self.q_proj = nn.Linear(dim, all_head_dim, bias=False) + self.k_proj = nn.Linear(dim, all_head_dim, bias=False) + self.v_proj = nn.Linear(dim, all_head_dim, bias=False) + else: + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity() + # self.proj = nn.Linear(all_head_dim, all_head_dim) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.xattn = xattn + self.xattn_drop = attn_drop + + self.rope = rope + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + B, N, C = x.shape + if self.subln: + q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) + k = F.linear(input=x, weight=self.k_proj.weight, bias=None) + v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) + + q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C + k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + else: + + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C + q, k, v = qkv[0], qkv[1], qkv[2] + + if self.rope: + # slightly fast impl + q_t = q[:, :, 1:, :] + ro_q_t = self.rope(q_t) + q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) + + k_t = k[:, :, 1:, :] + ro_k_t = self.rope(k_t) + k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v) + + if self.xattn and xops is not None: + q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + x = xops.memory_efficient_attention( + q, + k, + v, + p=self.xattn_drop, + scale=self.scale, + ) + x = x.reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.relative_position_bias_table is not None: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0).type_as(attn) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias.type_as(attn) + + if attn_mask is not None: + attn_mask = attn_mask.bool() + attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.inner_attn_ln(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + init_values=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + window_size=None, + attn_head_dim=None, + xattn=False, + rope=None, + postnorm=False, + subln=False, + naiveswiglu=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if naiveswiglu: + self.mlp = SwiGLU( + in_features=dim, + hidden_features=mlp_hidden_dim, + subln=subln, + norm_layer=norm_layer, + ) + else: + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, subln=subln, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + self.postnorm = postnorm + + def forward(self, x, rel_pos_bias=None, attn_mask=None): + if self.gamma_1 is None: + if self.postnorm: + x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + if self.postnorm: + x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))) + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class EVAVisionTransformer(nn.Module): + """Vision Transformer with support for patch or hybrid CNN input stage""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + init_values=None, + patch_dropout=0.0, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + rope=False, + use_mean_pooling=True, + init_scale=0.001, + grad_checkpointing=False, + xattn=False, + postnorm=False, + pt_hw_seq_len=16, + intp_freq=False, + naiveswiglu=False, + subln=False, + ): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + if rope: + half_head_dim = embed_dim // num_heads // 2 + hw_seq_len = img_size // patch_size + self.rope = VisionRotaryEmbeddingFast( + dim=half_head_dim, + pt_seq_len=pt_hw_seq_len, + ft_seq_len=hw_seq_len if intp_freq else None, + # patch_dropout=patch_dropout + ) + else: + self.rope = None + + self.naiveswiglu = naiveswiglu + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, + xattn=xattn, + rope=self.rope, + postnorm=postnorm, + subln=subln, + naiveswiglu=naiveswiglu, + ) + for i in range(depth) + ] + ) + self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + trunc_normal_(self.cls_token, std=0.02) + # trunc_normal_(self.mask_token, std=.02) + + self.apply(self._init_weights) + self.fix_init_weight() + + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=0.02) + self.head.weight.data.mul_(init_scale) + self.head.bias.data.mul_(init_scale) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity() + + self.grad_checkpointing = grad_checkpointing + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + if self.naiveswiglu: + rescale(layer.mlp.w3.weight.data, layer_id + 1) + else: + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_cast_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x, return_all_features=False): + + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + if os.getenv("RoPE") == "1": + if self.training and not isinstance(self.patch_dropout, nn.Identity): + x, patch_indices_keep = self.patch_dropout(x) + # Directly pass patch_indices_keep to self.rope.forward + x = self.rope.forward(x, patch_indices_keep=patch_indices_keep) + else: + # Pass None or omit the patch_indices_keep argument for default behavior + x = self.rope.forward(x, patch_indices_keep=None) + x = self.patch_dropout(x) + else: + x = self.patch_dropout(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for i, blk in enumerate(self.blocks): + if i == len(self.blocks) - 1: + continue + if self.grad_checkpointing: + x = checkpoint(blk, x, (rel_pos_bias,)) + else: + x = blk(x, rel_pos_bias=rel_pos_bias) + + if not return_all_features: + x = self.norm(x) + if self.fc_norm is not None: + return self.fc_norm(x.mean(1)) + else: + return x[:, 0] + return x + + def forward(self, x, return_all_features=False): + if return_all_features: + return self.forward_features(x, return_all_features) + x = self.forward_features(x) + x = self.head(x) + return x + + +def load_state_dict(checkpoint_path: str, map_location: str = "cpu", model_key: str = "model|module|state_dict", is_openai: bool = False, skip_list: list = []): + if is_openai: + model = torch.jit.load(checkpoint_path, map_location="cpu").eval() + state_dict = model.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + else: + checkpoint = torch.load(checkpoint_path, map_location=map_location) + for mk in model_key.split("|"): + if isinstance(checkpoint, dict) and mk in checkpoint: + state_dict = checkpoint[mk] + break + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + for k in skip_list: + if k in list(state_dict.keys()): + logging.info(f"Removing key {k} from pretrained checkpoint") + del state_dict[k] + + if os.getenv("RoPE") == "1": + for k in list(state_dict.keys()): + if "freqs_cos" in k or "freqs_sin" in k: + del state_dict[k] + return state_dict + + +def load_clip_visual_state_dict(checkpoint_path: str, map_location: str = "cpu", is_openai: bool = False, skip_list: list = []): + state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list) + # for k in list(state_dict.keys()): + # if not k.startswith("visual."): + # del state_dict[k] + # for k in list(state_dict.keys()): + # if k.startswith("visual."): + # new_k = k[7:] + # state_dict[new_k] = state_dict[k] + # del state_dict[k] + return state_dict + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +try: + from apex.normalization import FusedLayerNorm +except: + FusedLayerNorm = LayerNorm + # print("Please build and install Nvidia apex package with option '--cuda_ext' according to https://github.com/NVIDIA/apex#from-source .") + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0.0 # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + drop_path_rate: Optional[float] = None # drop path rate + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = "linear" # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size + qkv_bias: bool = True + fusedLN: bool = False + xattn: bool = False + postnorm: bool = False + rope: bool = False + pt_hw_seq_len: int = 16 # 224/14 + intp_freq: bool = False + naiveswiglu: bool = False + subln: bool = False + + +def create_norm_layer_factory(use_fused_ln, eps=1e-6): + # Otherwise, use the standard LayerNorm + return lambda num_features: nn.LayerNorm(num_features, eps=eps) + + +def _build_vision_tower(vision_tower_path: str, embed_dim: int, vision_cfg: CLIPVisionCfg, **kwargs): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + if vision_cfg.eva_model_name: + vision_heads = vision_cfg.width // vision_cfg.head_width + # Determine the appropriate norm layer factory based on the configuration + norm_layer_factory = create_norm_layer_factory(vision_cfg.fusedLN, eps=1e-6) + + visual = EVAVisionTransformer( + img_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + num_classes=embed_dim, + use_mean_pooling=vision_cfg.global_average_pool, # False + init_values=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + embed_dim=vision_cfg.width, + depth=vision_cfg.layers, + num_heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + qkv_bias=vision_cfg.qkv_bias, + drop_path_rate=vision_cfg.drop_path_rate, + norm_layer=norm_layer_factory, + xattn=vision_cfg.xattn, + rope=vision_cfg.rope, + postnorm=vision_cfg.postnorm, + pt_hw_seq_len=vision_cfg.pt_hw_seq_len, # 224/14 + intp_freq=vision_cfg.intp_freq, + naiveswiglu=vision_cfg.naiveswiglu, + subln=vision_cfg.subln, + ) + + state_dict = load_clip_visual_state_dict(vision_tower_path) + incompatible_keys = visual.load_state_dict(state_dict, strict=False) + rank0_print("EVA-CLIP incompatible_keys:", incompatible_keys) + + return visual + + +class EVAEncoderWrapper(nn.Module): + def __init__(self, vision_tower_pretrained, config): + super(EVAEncoderWrapper, self).__init__() + self.config = config + self.config["vision_tower_path"] = vision_tower_pretrained + self.model = _build_vision_tower(**self.config) + + def forward(self, image, **kwargs): + encode = self.model(image, return_all_features=True)[:, 1:, :] # remove the CLS token + return encode + + @property + def dtype(self): + return list(self.parameters())[-1].dtype + + @property + def device(self): + return list(self.parameters())[-1].device diff --git a/llava/model/multimodal_encoder/eva_clip/factory.py b/llava/model/multimodal_encoder/eva_clip/factory.py new file mode 100755 index 000000000..6d3fafcfd --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/factory.py @@ -0,0 +1,60 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from typing import Optional, Tuple, Union, Dict, Any +import torch + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + with open(cf, "r", encoding="utf8") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json new file mode 100755 index 000000000..491755669 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-18B.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1536, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 5120, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-18b-14-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": true, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json new file mode 100755 index 000000000..7d843daa3 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B-plus.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 448, + "layers": 32, + "width": 4096, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-8b-14-plus-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": false, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json new file mode 100755 index 000000000..689492a25 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA-CLIP-8B.json @@ -0,0 +1,27 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 4096, + "head_width": 128, + "mlp_ratio": 5, + "patch_size": 14, + "eva_model_name": "eva-clip-8b-14-x", + "drop_path_rate": 0, + "qkv_bias": false, + "xattn": true, + "postnorm": false, + "fusedLN": false, + "use_rms_norm": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": false + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json new file mode 100755 index 000000000..aad205800 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-B-16.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16, + "eva_model_name": "eva-clip-b-16", + "ls_init_value": 0.1, + "drop_path_rate": 0.0 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json new file mode 100755 index 000000000..100279572 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json new file mode 100755 index 000000000..5d338b4e6 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA01-CLIP-g-14.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14, + "eva_model_name": "eva-clip-g-14-x", + "drop_path_rate": 0.4, + "xattn": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json new file mode 100755 index 000000000..e4a6e723f --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-B-16.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "head_width": 64, + "patch_size": 16, + "mlp_ratio": 2.6667, + "eva_model_name": "eva-clip-b-16-X", + "drop_path_rate": 0.0, + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "xattn": true, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json new file mode 100755 index 000000000..3e1d124e1 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14-336.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14-336", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json new file mode 100755 index 000000000..03b22ad3c --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-L-14.json @@ -0,0 +1,29 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "drop_path_rate": 0, + "head_width": 64, + "mlp_ratio": 2.6667, + "patch_size": 14, + "eva_model_name": "eva-clip-l-14", + "xattn": true, + "fusedLN": true, + "rope": true, + "pt_hw_seq_len": 16, + "intp_freq": true, + "naiveswiglu": true, + "subln": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json new file mode 100755 index 000000000..aa04e2545 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json new file mode 100755 index 000000000..747ffccc8 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/EVA02-CLIP-bigE-14.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 64, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.571428571428571, + "patch_size": 14, + "eva_model_name": "eva-clip-4b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": true, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24, + "xattn": false, + "fusedLN": true + } +} \ No newline at end of file diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json b/llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json new file mode 100755 index 000000000..ad71aff86 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14-448.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": 77, + "width": 2304, + "head_width": 144, + "mlp_ratio": 10.9722, + "patch_size": 14, + "eva_model_name": "eva-clip-10b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": false, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json b/llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json new file mode 100755 index 000000000..21b206807 --- /dev/null +++ b/llava/model/multimodal_encoder/eva_clip/model_configs/Internal-EVA02-CLIP-10B-14.json @@ -0,0 +1,25 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 77, + "width": 2304, + "head_width": 144, + "mlp_ratio": 10.9722, + "patch_size": 14, + "eva_model_name": "eva-clip-10b-14-x", + "drop_path_rate": 0, + "xattn": true, + "postnorm": false, + "fusedLN": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32, + "xattn": false, + "fusedLN": true + } +} diff --git a/llava/model/multimodal_encoder/hf_vision.py b/llava/model/multimodal_encoder/hf_vision.py new file mode 100755 index 000000000..a413208e4 --- /dev/null +++ b/llava/model/multimodal_encoder/hf_vision.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from transformers import AutoModel, AutoImageProcessor, AutoConfig, CLIPImageProcessor +from llava.utils import rank0_print + + +class HFVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower.replace("hf:", "", 1) + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + self.load_model() + else: + self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) + + def load_model(self): + try: + self.image_processor = AutoImageProcessor.from_pretrained(self.vision_tower_name) + except Exception as e: + if "448" in self.vision_tower_name: + image_size = 448 + # use image processor with conig + self.image_processor = CLIPImageProcessor(size={"shortest_edge": image_size}, do_center_crop=True, crop_size=image_size) + else: + self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + rank0_print(f"Loaded image processor: {self.image_processor}") + self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda") + self.device = self.vision_tower.device + self.dtype = self.vision_tower.dtype + self.config = self.vision_tower.config + + if hasattr(self.vision_tower, "vision_model"): + self.vision_tower = self.vision_tower.vision_model + self.vision_tower.requires_grad_(False) + # self.vision_tower.eval() + self.is_loaded = True + + def feature_select(self, image_forward_outs): + select_feature_type = self.select_feature + + if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + select_every_k_layer = len(image_forward_outs.hidden_states) // 4 + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) + select_feature_type = select_feature_type.replace("slicefour_", "") + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + # @property + # def dtype(self): + # return self.vision_tower.dtype + + # @property + # def device(self): + # return self.vision_tower.device + + @property + def hidden_size(self): + try: + _hidden_size = self.config.hidden_size + except: + _hidden_size = self.config.vision_config.hidden_size + if "slicefour" in self.select_feature: + _hidden_size *= 4 + return _hidden_size + + @property + def num_patches(self): + _num_patches = (self.config.image_size // self.config.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def image_size(self): + return self.config.image_size diff --git a/llava/model/multimodal_encoder/imagebind.py b/llava/model/multimodal_encoder/imagebind.py new file mode 100755 index 000000000..8bbe71c7b --- /dev/null +++ b/llava/model/multimodal_encoder/imagebind.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn + +from transformers import CLIPImageProcessor + +try: + from imagebind.models import imagebind_model + from imagebind.models.imagebind_model import ModalityType + from imagebind.data import load_and_transform_audio_data +except ImportError: + pass + + +class ImageBindWrapper(nn.Module): + def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = select_layer + self.select_feature = select_feature + + if not delay_load: + self.load_model() + + def load_model(self): + self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.vision_tower = imagebind_model.imagebind_huge(pretrained=True) + for p in self.vision_tower.parameters(): + p.requires_grad = False + self.vision_tower.eval() + self.is_loaded = True + + def train(self, mode=True): + self.training = mode + + if self.is_loaded: + self.vision_tower.eval() + + @torch.no_grad() + def forward(self, x): + if type(x) == dict: + if x["audios"] is not None: + inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()} + embeddings = self.vision_tower(inputs) + audio_embedding = embeddings[ModalityType.AUDIO] + return audio_embedding.unsqueeze(1) + else: + inputs = {ModalityType.VISION: x.to(dtype=self.dtype)} + embeddings = self.vision_tower(inputs) + vision_embedding = embeddings[ModalityType.VISION] + if vision_embedding.ndim == 2: + return vision_embedding.unsqueeze(1) + if vision_embedding.shape[1] == 257: + return vision_embedding[:, 1:] + raise ValueError(f"Unexpected shape: {vision_embedding.shape}") + + @property + def dummy_feature(self): + return torch.zeros(1, 1024, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.modality_preprocessors.vision.cls_token.dtype + + @property + def device(self): + return self.vision_tower.modality_preprocessors.vision.cls_token.device + + @property + def hidden_size(self): + return 1024 diff --git a/llava/model/multimodal_encoder/mlcd/vit_rope2d_hf.py b/llava/model/multimodal_encoder/mlcd/vit_rope2d_hf.py new file mode 100644 index 000000000..5ddedae5b --- /dev/null +++ b/llava/model/multimodal_encoder/mlcd/vit_rope2d_hf.py @@ -0,0 +1,452 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from transformers.models.clip.modeling_clip import (CLIPMLP, BaseModelOutput, + BaseModelOutputWithPooling, + CLIPVisionConfig, + PreTrainedModel) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class MLCDVisionConfig(CLIPVisionConfig): + + model_type = "mlcd_vision_model" + + def __init__(self,**kwargs): + super().__init__(**kwargs) + + +class MLCDMLP(CLIPMLP): + def __init__(self, config: MLCDVisionConfig): + super().__init__(config) + + +class MLCDVisionEmbeddings(torch.nn.Module): + def __init__(self, config: MLCDVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + return embeddings + + +class MLCDSdpaAttention(torch.nn.Module): + """Multi-headed attention from these papers + + - Attention is all you need: + https://arxiv.org/abs/1706.03762 + + - RoFormer: Enhanced Transformer with Rotary Position Embedding: + https://arxiv.org/abs/2104.09864 + """ + + def __init__(self, config: MLCDVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Seq x Hidden Size""" + batch_size, seq_length , hidden_size = hidden_states.size() + # Each of shape: [batch_size, seq_length, num_heads, head_dim] + q = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + k = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + v = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim)) + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + q = q.permute(0, 2, 1, 3).contiguous() + k = k.permute(0, 2, 1, 3).contiguous() + v = v.permute(0, 2, 1, 3).contiguous() + # q (batch_size, num_heads, seq_length, head_dim) + # k (batch_size, num_heads, seq_length, head_dim) + # v (batch_size, num_heads, seq_length, head_dim) + attn_output = F.scaled_dot_product_attention(q, k, v, None, dropout_p=0.0) + attn_output = attn_output.permute(2, 0, 1, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim] + attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim] + attn_output = self.out_proj(attn_output) + attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim] + return attn_output, None + + +class MLCDEncoderLayer(nn.Module): + def __init__(self, config: MLCDVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = MLCDSdpaAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = MLCDMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: torch.Tensor, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + )[0] + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + return outputs + + +class MLCDEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`MLCDEncoderLayer`]. + + Args: + config: MLCDVisionConfig + """ + + def __init__(self, config: MLCDVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + rotary_pos_emb, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + rotary_pos_emb + ) + else: + layer_outputs = encoder_layer( + hidden_states, + rotary_pos_emb + ) + + hidden_states = layer_outputs[0] + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, None] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=None, + ) + + +class MLCDVisionTransformer(nn.Module): + def __init__(self, config: MLCDVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = MLCDVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = MLCDEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.vision_rotary_embedding = VisionRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2) + self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2)) + + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape(h, 1, w, 1) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape(h, 1, w, 1) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.vision_rotary_embedding(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + # output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + + twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size) + rotary_pos_emb = self.rot_pos_emb(torch.tensor([twh], device=pixel_values.device)) + rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0) + + # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + rotary_pos_emb=rotary_pos_emb, + # output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + # attentions=encoder_outputs.attentions, + ) + + +class MLCDPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + config_class = MLCDVisionConfig + base_model_prefix = "mlcd" + supports_gradient_checkpointing = True + _supports_sdpa = True + # _supports_flash_attn_2 = True + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor + if isinstance(module, MLCDVisionEmbeddings): + factor = self.config.initializer_factor + nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) + nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) + elif isinstance(module, MLCDSdpaAttention): + factor = self.config.initializer_factor + in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + out_proj_std = (module.embed_dim**-0.5) * factor + nn.init.normal_(module.q_proj.weight, std=in_proj_std) + nn.init.normal_(module.k_proj.weight, std=in_proj_std) + nn.init.normal_(module.v_proj.weight, std=in_proj_std) + nn.init.normal_(module.out_proj.weight, std=out_proj_std) + elif isinstance(module, MLCDMLP): + factor = self.config.initializer_factor + in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor + fc_std = (2 * module.config.hidden_size) ** -0.5 * factor + nn.init.normal_(module.fc1.weight, std=fc_std) + nn.init.normal_(module.fc2.weight, std=in_proj_std) + + + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class MLCDVisionModel(MLCDPreTrainedModel): + config_class = MLCDVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["MLCDEncoderLayer"] + + def __init__(self, config: MLCDVisionConfig): + super().__init__(config) + self.vision_model = MLCDVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, MLCDVisionModel + + >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14") + >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + # output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) \ No newline at end of file diff --git a/llava/model/multimodal_encoder/mlcd_encoder.py b/llava/model/multimodal_encoder/mlcd_encoder.py new file mode 100644 index 000000000..7f1041f0c --- /dev/null +++ b/llava/model/multimodal_encoder/mlcd_encoder.py @@ -0,0 +1,179 @@ +import torch +import torch.nn as nn +from llava.utils import rank0_print +# from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig +from transformers import CLIPImageProcessor +from .mlcd.vit_rope2d_hf import MLCDVisionModel, MLCDVisionConfig + +try: + from s2wrapper import forward as multiscale_forward +except: + pass + + +# class CLIPVisionTower(nn.Module): +class MLCDVisionTower(nn.Module): + def __init__(self, vision_tower, args, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.vision_tower_name = vision_tower + self.select_layer = args.mm_vision_select_layer + self.select_feature = getattr(args, "mm_vision_select_feature", "patch") + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + else: + # self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) + self.cfg_only = MLCDVisionConfig.from_pretrained(self.vision_tower_name) + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + # self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower = MLCDVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def feature_select(self, image_forward_outs): + select_feature_type = self.select_feature + + if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]: + select_every_k_layer = len(image_forward_outs.hidden_states) // 4 + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1) + select_feature_type = select_feature_type.replace("slicefour_", "") + elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]: + select_layers = [-2, -5, -8, -11, 6] + image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1) + select_feature_type = select_feature_type.replace("slice_m25811_f6_", "") + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if select_feature_type == "patch": + image_features = image_features[:, 1:] + elif select_feature_type == "cls_patch": + image_features = image_features + else: + raise ValueError(f"Unexpected select feature: {select_feature_type}") + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = self.feature_select(image_forward_out).to(image.dtype) + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + return self.vision_tower.dtype + + @property + def device(self): + return self.vision_tower.device + + @property + def config(self): + if self.is_loaded: + return self.vision_tower.config + else: + return self.cfg_only + + @property + def hidden_size(self): + _hidden_size = self.config.hidden_size + if "slicefour" in self.select_feature: + _hidden_size *= 4 + if "slice_m25811_f6" in self.select_feature: + _hidden_size *= 5 + return _hidden_size + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + + @property + def num_patches(self): + _num_patches = (self.config.image_size // self.config.patch_size) ** 2 + if "cls_patch" in self.select_feature: + _num_patches += 1 + return _num_patches + + @property + def image_size(self): + return self.config.image_size + + +# class CLIPVisionTowerS2(CLIPVisionTower): +class MLCDVisionTowerS2(MLCDVisionTower): + def __init__(self, vision_tower, args, delay_load=False): + + self.s2_scales = getattr(args, "s2_scales", "336,672,1008") + self.s2_scales = list(map(int, self.s2_scales.split(","))) + self.s2_scales.sort() + self.s2_split_size = self.s2_scales[0] + self.s2_image_size = self.s2_scales[-1] + + super().__init__(vision_tower, args, delay_load) + + # change resize/crop size in preprocessing to the largest image size in s2_scale + if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False): + self.image_processor.size["shortest_edge"] = self.s2_image_size + self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) + self.vision_tower = MLCDVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + self.vision_tower.requires_grad_(False) + + self.image_processor.size["shortest_edge"] = self.s2_image_size + self.image_processor.crop_size["height"] = self.image_processor.crop_size["width"] = self.s2_image_size + + self.is_loaded = True + + def forward_feature(self, images): + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = self.feature_select(image_forward_outs).to(images.dtype) + return image_features + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_feature = multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) + image_features.append(image_feature) + else: + image_features = multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size, split_forward=True) + + return image_features + + @property + def hidden_size(self): + return self.config.hidden_size * len(self.s2_scales) diff --git a/llava/model/multimodal_encoder/open_clip_encoder.py b/llava/model/multimodal_encoder/open_clip_encoder.py index cd6b0599a..17a3277f9 100755 --- a/llava/model/multimodal_encoder/open_clip_encoder.py +++ b/llava/model/multimodal_encoder/open_clip_encoder.py @@ -1,17 +1,20 @@ import torch import torch.nn as nn -import torchvision -import torch.utils.checkpoint as checkpoint from transformers import CLIPImageProcessor from llava.utils import rank0_print try: import open_clip + import torchvision from open_clip.transformer import _expand_token except ImportError: print("OpenCLIP not installed") open_clip = None +HIDDEN_SIZE_DICT = { + "ViT-H-14-378-quickgelu": 1280, +} + class OpenCLIPVisionTower(nn.Module): def __init__(self, vision_tower, args, delay_load=False): @@ -24,12 +27,20 @@ def __init__(self, vision_tower, args, delay_load=False): self.select_feature = getattr(args, "mm_vision_select_feature", "patch") if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(args, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts: + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") self.load_model() def load_model(self, device_map="auto"): rank0_print(f"Loading OpenCLIP model: {self.model_name}") rank0_print(f"Pretrained: {self.pretrained}") - vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="bf16", device="cuda") + vision_tower, _, image_processor = open_clip.create_model_and_transforms(model_name=self.model_name, pretrained=self.pretrained, precision="fp32", device="cuda") resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0] normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0] @@ -91,7 +102,6 @@ def forward_openclip(self, x: torch.Tensor): return forward_openclip(self.vision_tower, x) - @torch.no_grad() def forward(self, images): if type(images) is list: image_features = [] @@ -131,11 +141,10 @@ def config(self): @property def hidden_size(self): - if hasattr(self.vision_tower, "ln_post"): - return self.vision_tower.ln_post.weight.shape[0] - if hasattr(self.vision_tower, "trunk"): - return self.vision_tower.trunk.num_features - raise NotImplementedError + if self.model_name in HIDDEN_SIZE_DICT: + return HIDDEN_SIZE_DICT[self.model_name] + else: + raise NotImplementedError @property def num_patches(self): @@ -148,3 +157,7 @@ def num_patches(self): @property def image_size(self): return self.resize_transform_size + + @property + def num_patches_per_side(self): + return self.resize_transform_size // self.patch_size diff --git a/llava/model/multimodal_encoder/siglip_encoder.py b/llava/model/multimodal_encoder/siglip_encoder.py new file mode 100755 index 000000000..f1e101a2f --- /dev/null +++ b/llava/model/multimodal_encoder/siglip_encoder.py @@ -0,0 +1,620 @@ +""" +# Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py +""" + +from typing import Optional, Tuple, Union, Dict +from dataclasses import dataclass +from functools import partial, reduce +from PIL import Image +import torch +import torch.utils.checkpoint +from torch import nn +import os +from transformers.image_processing_utils import BatchFeature, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + normalize, + rescale, + resize, + to_channel_dimension_format, +) +from transformers.image_utils import ( + ChannelDimension, + PILImageResampling, + to_numpy_array, +) +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers import PretrainedConfig +from transformers.utils import ModelOutput +from llava.utils import rank0_print + + +class SigLipImageProcessor: + def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Dict[str, int] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST): + crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.image_mean = image_mean + self.image_std = image_std + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.data_format = data_format + self.crop_size = crop_size + + def preprocess(self, images, return_tensors): + if isinstance(images, Image.Image): + images = [images] + else: + # to adapt video data + images = [to_numpy_array(image) for image in images] + assert isinstance(images, list) + + transforms = [ + convert_to_rgb, + to_numpy_array, + partial(resize, size=self.size, resample=self.resample, data_format=self.data_format), + partial(rescale, scale=self.rescale_factor, data_format=self.data_format), + partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format), + partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format), + ] + + images = reduce(lambda x, f: [*map(f, x)], transforms, images) + data = {"pixel_values": images} + + return BatchFeature(data=data, tensor_type=return_tensors) + + +class SigLipVisionConfig(PretrainedConfig): + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=1152, + image_mean=(0.5, 0.5, 0.5), + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=384, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.image_mean = image_mean + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SigLipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.") + + return cls.from_dict(config_dict, **kwargs) + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip +class SigLipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SigLipVisionEmbeddings(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class SigLipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip +class SigLipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip +class SigLipEncoderLayer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SigLipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SigLipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SigLipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + pass + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip +class SigLipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SigLipEncoderLayer`]. + + Args: + config: SigLipVisionConfig + """ + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) + + +class SigLipVisionTransformer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SigLipVisionEmbeddings(config) + self.encoder = SigLipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SigLipMultiheadAttentionPoolingHead(config) + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SigLipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SigLipVisionModel(SigLipPreTrainedModel): + config_class = SigLipVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["SigLipEncoderLayer"] + + def __init__(self, config: SigLipVisionConfig): + super().__init__(config) + + self.vision_model = SigLipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SigLipVisionModel + + >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SigLipVisionTower(nn.Module): + def __init__(self, vision_tower, vision_tower_cfg, delay_load=False): + super().__init__() + + self.is_loaded = False + + self.config = SigLipVisionConfig() + + self.vision_tower_name = vision_tower + + self.image_processor = SigLipImageProcessor() + + if not delay_load: + rank0_print(f"Loading vision tower: {vision_tower}") + self.load_model() + elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False): + # TODO: better detector is needed. + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.") + self.load_model() + elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts: + rank0_print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.") + self.load_model() + else: + self.cfg_only = self.config + + def load_model(self, device_map=None): + if self.is_loaded: + rank0_print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name)) + return + + self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) + + del self.vision_tower.vision_model.encoder.layers[-1:] + self.vision_tower.vision_model.head = nn.Identity() + self.vision_tower.requires_grad_(False) + + self.is_loaded = True + + def forward(self, images): + if type(images) is list: + image_features = [] + for image in images: + image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) + image_feature = image_forward_out.hidden_states[-1].to(image.dtype) + assert image_features.shape[-2] == 729 + image_features.append(image_feature) + else: + image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) + image_features = image_forward_outs.hidden_states[-1].to(images.dtype) + assert image_features.shape[-2] == 729 + + return image_features + + @property + def dummy_feature(self): + return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) + + @property + def dtype(self): + for p in self.vision_tower.parameters(): + return p.dtype + + @property + def device(self): + for p in self.vision_tower.parameters(): + return p.device + + @property + def hidden_size(self): + return self.config.hidden_size + + @property + def num_patches(self): + return (self.config.image_size // self.config.patch_size) ** 2 + + @property + def num_patches_per_side(self): + return self.config.image_size // self.config.patch_size + # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"] + + @property + def image_size(self): + return self.config.image_size diff --git a/llava/model/multimodal_projector/builder.py b/llava/model/multimodal_projector/builder.py old mode 100644 new mode 100755 diff --git a/llava/model/utils.py b/llava/model/utils.py old mode 100644 new mode 100755 diff --git a/llava/serve/__init__.py b/llava/serve/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/llava/serve/cli.py b/llava/serve/cli.py new file mode 100755 index 000000000..88fbfe85d --- /dev/null +++ b/llava/serve/cli.py @@ -0,0 +1,111 @@ +import argparse +import torch + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + + +def load_image(image_file): + if image_file.startswith("http") or image_file.startswith("https"): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert("RGB") + else: + image = Image.open(image_file).convert("RGB") + return image + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) + + if "llama-2" in model_name.lower(): + conv_mode = "llava_llama_2" + elif "v1" in model_name.lower(): + conv_mode = "llava_v1" + elif "mpt" in model_name.lower(): + conv_mode = "mpt" + else: + conv_mode = "llava_v0" + + if args.conv_mode is not None and conv_mode != args.conv_mode: + print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode)) + else: + args.conv_mode = conv_mode + + conv = conv_templates[args.conv_mode].copy() + if "mpt" in model_name.lower(): + roles = ("user", "assistant") + else: + roles = conv.roles + + image = load_image(args.image_file) + image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().cuda() + + while True: + try: + inp = input(f"{roles[0]}: ") + except EOFError: + inp = "" + if not inp: + print("exit...") + break + + print(f"{roles[1]}: ", end="") + + if image is not None: + # first message + if model.config.mm_use_im_start_end: + inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp + else: + inp = DEFAULT_IMAGE_TOKEN + "\n" + inp + conv.append_message(conv.roles[0], inp) + image = None + else: + # later messages + conv.append_message(conv.roles[0], inp) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + + with torch.inference_mode(): + output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]) + + outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip() + conv.messages[-1][-1] = outputs + + if args.debug: + print("\n", {"prompt": prompt, "outputs": outputs}, "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--image-file", type=str, required=True) + parser.add_argument("--num-gpus", type=int, default=1) + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() + main(args) diff --git a/llava/serve/controller.py b/llava/serve/controller.py new file mode 100755 index 000000000..261f8c6bd --- /dev/null +++ b/llava/serve/controller.py @@ -0,0 +1,287 @@ +""" +A controller manages distributed workers. +It sends worker addresses to clients. +""" + +import argparse +import asyncio +import dataclasses +from enum import Enum, auto +import json +import logging +import time +from typing import List, Union +import threading + +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse +import numpy as np +import requests +import uvicorn + +from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION +from llava.utils import build_logger, server_error_msg + + +logger = build_logger("controller", "controller.log") + + +class DispatchMethod(Enum): + LOTTERY = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, name): + if name == "lottery": + return cls.LOTTERY + elif name == "shortest_queue": + return cls.SHORTEST_QUEUE + else: + raise ValueError(f"Invalid dispatch method") + + +@dataclasses.dataclass +class WorkerInfo: + model_names: List[str] + speed: int + queue_length: int + check_heart_beat: bool + last_heart_beat: str + + +def heart_beat_controller(controller): + while True: + time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) + controller.remove_stable_workers_by_expiration() + + +class Controller: + def __init__(self, dispatch_method: str): + # Dict[str -> WorkerInfo] + self.worker_info = {} + self.dispatch_method = DispatchMethod.from_str(dispatch_method) + + self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self,)) + self.heart_beat_thread.start() + + logger.info("Init controller") + + def register_worker(self, worker_name: str, check_heart_beat: bool, worker_status: dict): + if worker_name not in self.worker_info: + logger.info(f"Register a new worker: {worker_name}") + else: + logger.info(f"Register an existing worker: {worker_name}") + + if not worker_status: + worker_status = self.get_worker_status(worker_name) + if not worker_status: + return False + + self.worker_info[worker_name] = WorkerInfo(worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], check_heart_beat, time.time()) + + logger.info(f"Register done: {worker_name}, {worker_status}") + return True + + def get_worker_status(self, worker_name: str): + try: + r = requests.post(worker_name + "/worker_get_status", timeout=5) + except requests.exceptions.RequestException as e: + logger.error(f"Get status fails: {worker_name}, {e}") + return None + + if r.status_code != 200: + logger.error(f"Get status fails: {worker_name}, {r}") + return None + + return r.json() + + def remove_worker(self, worker_name: str): + del self.worker_info[worker_name] + + def refresh_all_workers(self): + old_info = dict(self.worker_info) + self.worker_info = {} + + for w_name, w_info in old_info.items(): + if not self.register_worker(w_name, w_info.check_heart_beat, None): + logger.info(f"Remove stale worker: {w_name}") + + def list_models(self): + model_names = set() + + for w_name, w_info in self.worker_info.items(): + model_names.update(w_info.model_names) + + return list(model_names) + + def get_worker_address(self, model_name: str): + if self.dispatch_method == DispatchMethod.LOTTERY: + worker_names = [] + worker_speeds = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_speeds.append(w_info.speed) + worker_speeds = np.array(worker_speeds, dtype=np.float32) + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + if True: # Directly return address + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + return worker_name + + # Check status before returning + while True: + pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) + worker_name = worker_names[pt] + + if self.get_worker_status(worker_name): + break + else: + self.remove_worker(worker_name) + worker_speeds[pt] = 0 + norm = np.sum(worker_speeds) + if norm < 1e-4: + return "" + worker_speeds = worker_speeds / norm + continue + return worker_name + elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: + worker_names = [] + worker_qlen = [] + for w_name, w_info in self.worker_info.items(): + if model_name in w_info.model_names: + worker_names.append(w_name) + worker_qlen.append(w_info.queue_length / w_info.speed) + if len(worker_names) == 0: + return "" + min_index = np.argmin(worker_qlen) + w_name = worker_names[min_index] + self.worker_info[w_name].queue_length += 1 + logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") + return w_name + else: + raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") + + def receive_heart_beat(self, worker_name: str, queue_length: int): + if worker_name not in self.worker_info: + logger.info(f"Receive unknown heart beat. {worker_name}") + return False + + self.worker_info[worker_name].queue_length = queue_length + self.worker_info[worker_name].last_heart_beat = time.time() + logger.info(f"Receive heart beat. {worker_name}") + return True + + def remove_stable_workers_by_expiration(self): + expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION + to_delete = [] + for worker_name, w_info in self.worker_info.items(): + if w_info.check_heart_beat and w_info.last_heart_beat < expire: + to_delete.append(worker_name) + + for worker_name in to_delete: + self.remove_worker(worker_name) + + def worker_api_generate_stream(self, params): + worker_addr = self.get_worker_address(params["model"]) + if not worker_addr: + logger.info(f"no worker: {params['model']}") + ret = { + "text": server_error_msg, + "error_code": 2, + } + yield json.dumps(ret).encode() + b"\0" + + try: + response = requests.post(worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + yield chunk + b"\0" + except requests.exceptions.RequestException as e: + logger.info(f"worker timeout: {worker_addr}") + ret = { + "text": server_error_msg, + "error_code": 3, + } + yield json.dumps(ret).encode() + b"\0" + + # Let the controller act as a worker to achieve hierarchical + # management. This can be used to connect isolated sub networks. + def worker_api_get_status(self): + model_names = set() + speed = 0 + queue_length = 0 + + for w_name in self.worker_info: + worker_status = self.get_worker_status(w_name) + if worker_status is not None: + model_names.update(worker_status["model_names"]) + speed += worker_status["speed"] + queue_length += worker_status["queue_length"] + + return { + "model_names": list(model_names), + "speed": speed, + "queue_length": queue_length, + } + + +app = FastAPI() + + +@app.post("/register_worker") +async def register_worker(request: Request): + data = await request.json() + controller.register_worker(data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)) + + +@app.post("/refresh_all_workers") +async def refresh_all_workers(): + models = controller.refresh_all_workers() + + +@app.post("/list_models") +async def list_models(): + models = controller.list_models() + return {"models": models} + + +@app.post("/get_worker_address") +async def get_worker_address(request: Request): + data = await request.json() + addr = controller.get_worker_address(data["model"]) + return {"address": addr} + + +@app.post("/receive_heart_beat") +async def receive_heart_beat(request: Request): + data = await request.json() + exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"]) + return {"exist": exist} + + +@app.post("/worker_generate_stream") +async def worker_api_generate_stream(request: Request): + params = await request.json() + generator = controller.worker_api_generate_stream(params) + return StreamingResponse(generator) + + +@app.post("/worker_get_status") +async def worker_api_get_status(request: Request): + return controller.worker_api_get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21001) + parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue") + args = parser.parse_args() + logger.info(f"args: {args}") + + controller = Controller(args.dispatch_method) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/llava/serve/examples/extreme_ironing.jpg b/llava/serve/examples/extreme_ironing.jpg new file mode 100755 index 000000000..638b07883 Binary files /dev/null and b/llava/serve/examples/extreme_ironing.jpg differ diff --git a/llava/serve/examples/waterview.jpg b/llava/serve/examples/waterview.jpg new file mode 100755 index 000000000..6f44ebaba Binary files /dev/null and b/llava/serve/examples/waterview.jpg differ diff --git a/llava/serve/gradio_multi_image.py b/llava/serve/gradio_multi_image.py new file mode 100755 index 000000000..ca0e4206f --- /dev/null +++ b/llava/serve/gradio_multi_image.py @@ -0,0 +1,448 @@ +import argparse +import datetime +import json +import os +import time + +import gradio as gr +import requests + +from llava.conversation import default_conversation, conv_templates, SeparatorStyle +from llava.constants import LOGDIR +from llava.utils import build_logger, server_error_msg, violates_moderation, moderation_msg +import hashlib + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "LLaVA Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + +priority = { + "vicuna-13b": "aaaaaaa", + "koala-13b": "aaaaaab", +} + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_model_list(): + ret = requests.post(args.controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(args.controller_url + "/list_models") + models = ret.json()["models"] + models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"Models: {models}") + return models + + +get_window_url_params = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + } +""" + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update(value=model, visible=True) + + state = default_conversation.copy() + return (state, dropdown_update, gr.Chatbot.update(visible=True), gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Row.update(visible=True), gr.Accordion.update(visible=True)) + + +def load_demo_refresh_model_list(request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}") + models = get_model_list() + state = default_conversation.copy() + return ( + state, + gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else ""), + gr.Chatbot.update(visible=True), + gr.Textbox.update(visible=True), + gr.Button.update(visible=True), + gr.Row.update(visible=True), + gr.Accordion.update(visible=True), + ) + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +def upvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, image_process_mode, request: gr.Request): + logger.info(f"regenerate. ip: {request.client.host}") + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + + +def add_text(state, text, image, image2, image_process_mode, request: gr.Request): + logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") + if len(text) <= 0 and image is None: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + if args.moderate: + flagged = violates_moderation(text) + if flagged: + state.skip_next = True + return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5 + + text = text[:3072] # Hard cut-off + images = [x for x in [image, image2] if x is not None] + num_images = len(images) + if num_images > 0: + text = text.replace("", "").strip() + text = text[: 3072 - 512 * num_images] + text = "\n" * num_images + text + text = (text, images, image_process_mode) + if len(state.get_images(return_pil=True)) > 0: + state = default_conversation.copy() + state.append_message(state.roles[0], text) + state.append_message(state.roles[1], None) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + + +def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request): + logger.info(f"http_bot. ip: {request.client.host}") + start_tstamp = time.time() + model_name = model_selector + + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + if len(state.messages) == state.offset + 2: + # First round of conversation + if "llava" in model_name.lower(): + if "llama-2" in model_name.lower(): + if "sharegpt" in model_name.lower(): + if "mmtag" in model_name.lower(): + template_name = "v1_mmtag" + elif "plain" in model_name.lower() and "finetune" not in model_name.lower(): + template_name = "v1_mmtag" + else: + template_name = "llava_v1" + else: + if "mmtag" in model_name.lower(): + template_name = "llava_llama_2_mmtag" + elif "simple" in model_name.lower(): + template_name = "llava_llama_2_simple" + elif "plain" in model_name.lower() and "finetune" not in model_name.lower(): + template_name = "llava_llama_2_mmtag" + elif "simple" in model_name.lower(): + template_name = "llava_llama_2_simple" + else: + template_name = "llava_llama_2" + elif "v1" in model_name.lower(): + if "mmtag" in model_name.lower(): + template_name = "v1_mmtag" + elif "plain" in model_name.lower() and "finetune" not in model_name.lower(): + template_name = "v1_mmtag" + else: + template_name = "llava_v1" + elif "mpt" in model_name.lower(): + template_name = "mpt" + else: + if "mmtag" in model_name.lower(): + template_name = "v0_mmtag" + elif "plain" in model_name.lower() and "finetune" not in model_name.lower(): + template_name = "v0_mmtag" + else: + template_name = "llava_v0" + elif "mpt" in model_name.lower(): + template_name = "mpt_text" + elif "llama-2" in model_name.lower(): + if "sharegpt" in model_name.lower(): + template_name = "vicuna_v1" + else: + template_name = "llama_2" + else: + template_name = "vicuna_v1" + new_state = conv_templates[template_name].copy() + new_state.append_message(new_state.roles[0], state.messages[-2][1]) + new_state.append_message(new_state.roles[1], None) + state = new_state + + # Query worker address + controller_url = args.controller_url + ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name}) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + state.messages[-1][-1] = server_error_msg + yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + # Construct prompt + prompt = state.get_prompt() + + all_images = state.get_images(return_pil=True) + all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] + for image, hash in zip(all_images, all_image_hash): + t = datetime.datetime.now() + filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") + if not os.path.isfile(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + image.save(filename) + + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "temperature": float(temperature), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 1536), + "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, + "images": f"List of {len(state.get_images())} images: {all_image_hash}", + } + logger.info(f"==== request ====\n{pload}") + + pload["images"] = state.get_images() + + state.messages[-1][-1] = "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + # Stream output + response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=10) + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][len(prompt) :].strip() + state.messages[-1][-1] = output + "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + time.sleep(0.03) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(start_tstamp, 4), + "state": state.dict(), + "images": all_image_hash, + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +title_markdown = """ +# 🌋 LLaVA: Large Language and Vision Assistant +[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] +""" + +tos_markdown = """ +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. +Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. +For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. +""" + + +learn_more_markdown = """ +### License +The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. +""" + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} + +#chatbot img { + display: inline-block; +} + +""" + + +def build_demo(embed_mode): + textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) + with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + + if not embed_mode: + gr.Markdown(title_markdown) + + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown(choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False) + + with gr.Row(elem_id="images"): + imagebox = gr.Image(type="pil") + imagebox_2 = gr.Image(type="pil") + image_process_mode = gr.Radio(["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + gr.Examples( + examples=[ + [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"], + [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], + ], + inputs=[imagebox, textbox], + ) + + with gr.Accordion("Parameters", open=False, visible=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.2, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=0, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) + + with gr.Column(scale=8): + chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", visible=False, height=550) + with gr.Row(): + with gr.Column(scale=8): + textbox.render() + with gr.Column(scale=1, min_width=50): + submit_btn = gr.Button(value="Submit", visible=False) + with gr.Row(visible=False) as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + + if not embed_mode: + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click(upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) + downvote_btn.click(downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) + flag_btn.click(flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn]) + regenerate_btn.click(regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list) + clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, imagebox_2] + btn_list) + + textbox.submit(add_text, [state, textbox, imagebox, imagebox_2, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then( + http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list + ) + submit_btn.click(add_text, [state, textbox, imagebox, imagebox_2, image_process_mode], [state, chatbot, textbox, imagebox, imagebox_2] + btn_list).then( + http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list + ) + + if args.model_list_mode == "once": + demo.load(load_demo, [url_params], [state, model_selector, chatbot, textbox, submit_btn, button_row, parameter_row], _js=get_window_url_params) + elif args.model_list_mode == "reload": + demo.load(load_demo_refresh_model_list, None, [state, model_selector, chatbot, textbox, submit_btn, button_row, parameter_row]) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=8) + parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"]) + parser.add_argument("--share", action="store_true") + parser.add_argument("--moderate", action="store_true") + parser.add_argument("--embed", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + models = get_model_list() + + logger.info(args) + demo = build_demo(args.embed) + demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False).launch(server_name=args.host, server_port=args.port, share=args.share) diff --git a/llava/serve/gradio_web_server.py b/llava/serve/gradio_web_server.py new file mode 100755 index 000000000..4070dfaa3 --- /dev/null +++ b/llava/serve/gradio_web_server.py @@ -0,0 +1,442 @@ +import argparse +import datetime +import json +import os +import time + +import gradio as gr +import requests + +from llava.conversation import default_conversation, conv_templates, SeparatorStyle +from llava.constants import LOGDIR +from llava.utils import build_logger, server_error_msg, violates_moderation, moderation_msg +import hashlib + + +logger = build_logger("gradio_web_server", "gradio_web_server.log") + +headers = {"User-Agent": "LLaVA Client"} + +no_change_btn = gr.Button.update() +enable_btn = gr.Button.update(interactive=True) +disable_btn = gr.Button.update(interactive=False) + +priority = { + "vicuna-13b": "aaaaaaa", + "koala-13b": "aaaaaab", +} + + +def get_conv_log_filename(): + t = datetime.datetime.now() + name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") + return name + + +def get_model_list(): + ret = requests.post(args.controller_url + "/refresh_all_workers") + assert ret.status_code == 200 + ret = requests.post(args.controller_url + "/list_models") + models = ret.json()["models"] + models.sort(key=lambda x: priority.get(x, x)) + logger.info(f"Models: {models}") + return models + + +get_window_url_params = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log(url_params); + return url_params; + } +""" + + +def load_demo(url_params, request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") + + dropdown_update = gr.Dropdown.update(visible=True) + if "model" in url_params: + model = url_params["model"] + if model in models: + dropdown_update = gr.Dropdown.update(value=model, visible=True) + + state = default_conversation.copy() + return state, dropdown_update + + +def load_demo_refresh_model_list(request: gr.Request): + logger.info(f"load_demo. ip: {request.client.host}") + models = get_model_list() + state = default_conversation.copy() + dropdown_update = gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else "") + return state, dropdown_update + + +def vote_last_response(state, vote_type, model_selector, request: gr.Request): + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(time.time(), 4), + "type": vote_type, + "model": model_selector, + "state": state.dict(), + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +def upvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"upvote. ip: {request.client.host}") + vote_last_response(state, "upvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def downvote_last_response(state, model_selector, request: gr.Request): + logger.info(f"downvote. ip: {request.client.host}") + vote_last_response(state, "downvote", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def flag_last_response(state, model_selector, request: gr.Request): + logger.info(f"flag. ip: {request.client.host}") + vote_last_response(state, "flag", model_selector, request) + return ("",) + (disable_btn,) * 3 + + +def regenerate(state, image_process_mode, request: gr.Request): + logger.info(f"regenerate. ip: {request.client.host}") + state.messages[-1][-1] = None + prev_human_msg = state.messages[-2] + if type(prev_human_msg[1]) in (tuple, list): + prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def clear_history(request: gr.Request): + logger.info(f"clear_history. ip: {request.client.host}") + state = default_conversation.copy() + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def add_text(state, text, image, image_process_mode, request: gr.Request): + logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") + if len(text) <= 0 and image is None: + state.skip_next = True + return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 + if args.moderate: + flagged = violates_moderation(text) + if flagged: + state.skip_next = True + return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5 + + text = text[:1536] # Hard cut-off + if image is not None: + text = text[:1200] # Hard cut-off for images + if "" not in text: + # text = '' + text + text = text + "\n" + text = (text, image, image_process_mode) + if len(state.get_images(return_pil=True)) > 0: + state = default_conversation.copy() + state.append_message(state.roles[0], text) + state.append_message(state.roles[1], None) + state.skip_next = False + return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 + + +def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request, template_name=None): + logger.info(f"http_bot. ip: {request.client.host}") + start_tstamp = time.time() + model_name = model_selector + + if state.skip_next: + # This generate call is skipped due to invalid inputs + yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 + return + + if len(state.messages) == state.offset + 2: + # First round of conversation + if "llava" in model_name.lower(): + if "llama-2" in model_name.lower(): + template_name = "llava_llama_2" + elif "mistral" in model_name.lower() or "mixtral" in model_name.lower(): + if "orca" in model_name.lower(): + template_name = "mistral_orca" + elif "hermes" in model_name.lower(): + template_name = "mistral_direct" + else: + template_name = "mistral_instruct" + elif "zephyr" in model_name.lower(): + template_name = "mistral_zephyr" + elif "hermes" in model_name.lower(): + template_name = "mistral_direct" + elif "v1" in model_name.lower(): + if "mmtag" in model_name.lower(): + template_name = "llava_v1_mmtag" + elif "plain" in model_name.lower() and "finetune" not in model_name.lower(): + template_name = "llava_v1_mmtag" + else: + template_name = "llava_v1" + elif "mpt" in model_name.lower(): + template_name = "mpt" + else: + if "mmtag" in model_name.lower(): + template_name = "v0_plain" + elif "plain" in model_name.lower() and "finetune" not in model_name.lower(): + template_name = "v0_plain" + else: + template_name = "llava_v0" + elif "mistral" in model_name.lower() or "mixtral" in model_name.lower(): + if "orca" in model_name.lower(): + template_name = "mistral_orca" + elif "hermes" in model_name.lower(): + template_name = "mistral_direct" + else: + template_name = "mistral_instruct" + elif "hermes" in model_name.lower(): + template_name = "mistral_direct" + elif "zephyr" in model_name.lower(): + template_name = "mistral_zephyr" + elif "mpt" in model_name: + template_name = "mpt_text" + elif "llama-2" in model_name: + template_name = "llama_2" + else: + template_name = "vicuna_v1" + new_state = conv_templates[template_name].copy() + new_state.append_message(new_state.roles[0], state.messages[-2][1]) + new_state.append_message(new_state.roles[1], None) + state = new_state + + # Query worker address + controller_url = args.controller_url + ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name}) + worker_addr = ret.json()["address"] + logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") + + # No available worker + if worker_addr == "": + state.messages[-1][-1] = server_error_msg + yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + # Construct prompt + prompt = state.get_prompt() + + all_images = state.get_images(return_pil=True) + all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] + for image, hash in zip(all_images, all_image_hash): + t = datetime.datetime.now() + filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") + if not os.path.isfile(filename): + os.makedirs(os.path.dirname(filename), exist_ok=True) + image.save(filename) + + # Make requests + pload = { + "model": model_name, + "prompt": prompt, + "temperature": float(temperature), + "top_p": float(top_p), + "max_new_tokens": min(int(max_new_tokens), 1536), + "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, + "images": f"List of {len(state.get_images())} images: {all_image_hash}", + } + logger.info(f"==== request ====\n{pload}") + + pload["images"] = state.get_images() + + state.messages[-1][-1] = "▌" + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + + try: + # Stream output + response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=100) + last_print_time = time.time() + for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode()) + if data["error_code"] == 0: + output = data["text"][len(prompt) :].strip() + state.messages[-1][-1] = output + "▌" + if time.time() - last_print_time > 0.05: + last_print_time = time.time() + yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 + else: + output = data["text"] + f" (error_code: {data['error_code']})" + state.messages[-1][-1] = output + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + time.sleep(0.03) + except requests.exceptions.RequestException as e: + state.messages[-1][-1] = server_error_msg + yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) + return + + state.messages[-1][-1] = state.messages[-1][-1][:-1] + yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 + + finish_tstamp = time.time() + logger.info(f"{output}") + + with open(get_conv_log_filename(), "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": model_name, + "start": round(start_tstamp, 4), + "finish": round(start_tstamp, 4), + "state": state.dict(), + "images": all_image_hash, + "ip": request.client.host, + } + fout.write(json.dumps(data) + "\n") + + +title_markdown = """ +# 🌋 LLaVA: Large Language and Vision Assistant +[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] +""" + +tos_markdown = """ +### Terms of use +By using this service, users are required to agree to the following terms: +The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. +Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. +For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. +""" + + +learn_more_markdown = """ +### License +The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. +""" + +block_css = """ + +#buttons button { + min-width: min(120px,100%); +} + +""" + + +def build_demo(embed_mode): + textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) + with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo: + state = gr.State() + + if not embed_mode: + gr.Markdown(title_markdown) + + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(elem_id="model_selector_row"): + model_selector = gr.Dropdown(choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False) + + imagebox = gr.Image(type="pil") + image_process_mode = gr.Radio(["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False) + + cur_dir = os.path.dirname(os.path.abspath(__file__)) + gr.Examples( + examples=[ + [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"], + [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], + ], + inputs=[imagebox, textbox], + ) + + with gr.Accordion("Parameters", open=False) as parameter_row: + temperature = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.2, + step=0.1, + interactive=True, + label="Temperature", + ) + top_p = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.7, + step=0.1, + interactive=True, + label="Top P", + ) + max_output_tokens = gr.Slider( + minimum=0, + maximum=1024, + value=512, + step=64, + interactive=True, + label="Max output tokens", + ) + + with gr.Column(scale=8): + chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550) + with gr.Row(): + with gr.Column(scale=8): + textbox.render() + with gr.Column(scale=1, min_width=50): + submit_btn = gr.Button(value="Send", variant="primary") + with gr.Row(elem_id="buttons") as button_row: + upvote_btn = gr.Button(value="👍 Upvote", interactive=False) + downvote_btn = gr.Button(value="👎 Downvote", interactive=False) + flag_btn = gr.Button(value="⚠️ Flag", interactive=False) + # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) + regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) + clear_btn = gr.Button(value="🗑️ Clear", interactive=False) + + if not embed_mode: + gr.Markdown(tos_markdown) + gr.Markdown(learn_more_markdown) + url_params = gr.JSON(visible=False) + + # Register listeners + btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] + upvote_btn.click(upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False) + downvote_btn.click(downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False) + flag_btn.click(flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False) + + regenerate_btn.click(regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list) + + clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False) + + textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then( + http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list + ) + + submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False).then( + http_bot, [state, model_selector, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list + ) + + if args.model_list_mode == "once": + demo.load(load_demo, [url_params], [state, model_selector], _js=get_window_url_params, queue=False) + elif args.model_list_mode == "reload": + demo.load(load_demo_refresh_model_list, None, [state, model_selector], queue=False) + else: + raise ValueError(f"Unknown model list mode: {args.model_list_mode}") + + return demo + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int) + parser.add_argument("--controller-url", type=str, default="http://localhost:21001") + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"]) + parser.add_argument("--share", action="store_true") + parser.add_argument("--moderate", action="store_true") + parser.add_argument("--embed", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + models = get_model_list() + + logger.info(args) + demo = build_demo(args.embed) + demo.queue(concurrency_count=args.concurrency_count, api_open=False).launch(server_name=args.host, server_port=args.port, share=args.share) diff --git a/llava/serve/model_worker.py b/llava/serve/model_worker.py new file mode 100755 index 000000000..cd2256189 --- /dev/null +++ b/llava/serve/model_worker.py @@ -0,0 +1,271 @@ +""" +A model worker executes the model. +""" + +import argparse +import asyncio +import json +import time +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +import requests +import torch +import uvicorn +from functools import partial + +from llava.constants import WORKER_HEART_BEAT_INTERVAL +from llava.utils import build_logger, server_error_msg, pretty_print_semaphore +from llava.model.builder import load_pretrained_model +from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from transformers import TextIteratorStreamer +from threading import Thread + + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 + +model_semaphore = None + + +def heart_beat_worker(controller): + + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + + +class ModelWorker: + def __init__(self, controller_addr, worker_addr, worker_id, no_register, model_path, model_base, model_name, load_8bit, load_4bit): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + if model_path.endswith("/"): + model_path = model_path[:-1] + if model_name is None: + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + self.model_name = model_paths[-2] + "_" + model_paths[-1] + else: + self.model_name = model_paths[-1] + else: + self.model_name = model_name + + logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") + self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, model_base, self.model_name, load_8bit, load_4bit) + self.is_multimodal = "llava" in self.model_name.lower() + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,)) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()} + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " f"global_counter: {global_counter}") + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if model_semaphore is None: + return 0 + else: + return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + @torch.inference_mode() + def generate_stream(self, params): + tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor + + prompt = params["prompt"] + ori_prompt = prompt + images = params.get("images", None) + num_image_tokens = 0 + if images is not None and len(images) > 0 and self.is_multimodal: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + images = [load_image_from_base64(image) for image in images] + image_sizes = [image.size for image in images] + images = process_images(images, image_processor, model.config) + + if type(images) is list: + images = [image.to(self.model.device, dtype=torch.float16) for image in images] + else: + images = images.to(self.model.device, dtype=torch.float16) + + replace_token = DEFAULT_IMAGE_TOKEN + if getattr(self.model.config, "mm_use_im_start_end", False): + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) + + num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches + else: + images = None + image_sizes = None + image_args = {"images": images, "image_sizes": image_sizes} + else: + images = None + image_args = {} + + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_context_length = getattr(model.config, "max_position_embeddings", 2048) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + do_sample = True if temperature > 0.001 else False + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) + + max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) + + if max_new_tokens < 1: + yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" + return + + thread = Thread( + target=model.generate, + kwargs=dict( + inputs=input_ids, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + streamer=streamer, + # stopping_criteria=[stopping_criteria], + use_cache=True, + **image_args, + ), + ) + thread.start() + + start_time = time.time() + generated_text = ori_prompt + for new_text in streamer: + generated_text += new_text + if generated_text.endswith(stop_str): + generated_text = generated_text[: -len(stop_str)] + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" + + end_time = time.time() + + new_generated = generated_text[len(ori_prompt) :] + new_generated_tokens = tokenizer(new_generated).input_ids + token_per_second = len(new_generated_tokens) / (end_time - start_time) + print(f"token_per_second: {token_per_second}") + + def generate_stream_gate(self, params): + try: + for x in self.generate_stream(params): + yield x + except ValueError as e: + print("Caught ValueError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except torch.cuda.CudaError as e: + print("Caught torch.cuda.CudaError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + print("Caught Unknown Error", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + + +app = FastAPI() + + +def release_model_semaphore(fn=None): + model_semaphore.release() + if fn is not None: + fn() + + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + worker.send_heart_beat() + generator = worker.generate_stream_gate(params) + background_tasks = BackgroundTasks() + background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument("--controller-address", type=str, default="http://localhost:21001") + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--model-name", type=str) + parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=1) + parser.add_argument("--no-register", action="store_true") + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + if args.multi_modal: + logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") + + worker = ModelWorker(args.controller_address, args.worker_address, worker_id, args.no_register, args.model_path, args.model_base, args.model_name, args.load_8bit, args.load_4bit) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/llava/serve/register_worker.py b/llava/serve/register_worker.py new file mode 100755 index 000000000..2c2c40295 --- /dev/null +++ b/llava/serve/register_worker.py @@ -0,0 +1,26 @@ +""" +Manually register workers. + +Usage: +python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str) + parser.add_argument("--worker-name", type=str) + parser.add_argument("--check-heart-beat", action="store_true") + args = parser.parse_args() + + url = args.controller_address + "/register_worker" + data = { + "worker_name": args.worker_name, + "check_heart_beat": args.check_heart_beat, + "worker_status": None, + } + r = requests.post(url, json=data) + assert r.status_code == 200 diff --git a/llava/serve/sglang_worker.py b/llava/serve/sglang_worker.py new file mode 100755 index 000000000..09c047f48 --- /dev/null +++ b/llava/serve/sglang_worker.py @@ -0,0 +1,237 @@ +""" +A model worker executes the model. +""" + +import argparse +import asyncio +from concurrent.futures import ThreadPoolExecutor +import json +import time +import threading +import uuid + +from fastapi import FastAPI, Request, BackgroundTasks +from fastapi.responses import StreamingResponse +import requests +import re +import uvicorn +from functools import partial + +from llava.constants import WORKER_HEART_BEAT_INTERVAL +from llava.utils import build_logger, server_error_msg, pretty_print_semaphore +from llava.model.builder import load_pretrained_model +from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from transformers import AutoTokenizer + +import sglang as sgl +from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend +from sglang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.utils import read_jsonl, dump_state_text +from sglang.lang.interpreter import ProgramState + + +GB = 1 << 30 + +worker_id = str(uuid.uuid4())[:6] +logger = build_logger("model_worker", f"model_worker_{worker_id}.log") +global_counter = 0 + +model_semaphore = None + + +def heart_beat_worker(controller): + while True: + time.sleep(WORKER_HEART_BEAT_INTERVAL) + controller.send_heart_beat() + + +@sgl.function +def pipeline(s, prompt, max_tokens): + for p in prompt: + if type(p) is str: + s += p + else: + s += sgl.image(p) + s += sgl.gen("response", max_tokens=max_tokens) + + +class ModelWorker: + def __init__(self, controller_addr, worker_addr, sgl_endpoint, worker_id, no_register, model_name): + self.controller_addr = controller_addr + self.worker_addr = worker_addr + self.worker_id = worker_id + + # Select backend + backend = RuntimeEndpoint(sgl_endpoint) + sgl.set_default_backend(backend) + model_path = backend.model_info["model_path"] + + if model_path.endswith("/"): + model_path = model_path[:-1] + if model_name is None: + model_paths = model_path.split("/") + if model_paths[-1].startswith("checkpoint-"): + self.model_name = model_paths[-2] + "_" + model_paths[-1] + else: + self.model_name = model_paths[-1] + else: + self.model_name = model_name + + logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...") + + if not no_register: + self.register_to_controller() + self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,)) + self.heart_beat_thread.start() + + def register_to_controller(self): + logger.info("Register to controller") + + url = self.controller_addr + "/register_worker" + data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()} + r = requests.post(url, json=data) + assert r.status_code == 200 + + def send_heart_beat(self): + logger.info(f"Send heart beat. Models: {[self.model_name]}. " f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " f"global_counter: {global_counter}") + + url = self.controller_addr + "/receive_heart_beat" + + while True: + try: + ret = requests.post(url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5) + exist = ret.json()["exist"] + break + except requests.exceptions.RequestException as e: + logger.error(f"heart beat error: {e}") + time.sleep(5) + + if not exist: + self.register_to_controller() + + def get_queue_length(self): + if model_semaphore is None: + return 0 + else: + return args.limit_model_concurrency - model_semaphore._value + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0) + + def get_status(self): + return { + "model_names": [self.model_name], + "speed": 1, + "queue_length": self.get_queue_length(), + } + + async def generate_stream(self, params): + ori_prompt = prompt = params["prompt"] + images = params.get("images", None) + if images is not None and len(images) > 0: + if len(images) > 0: + if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): + raise ValueError("Number of images does not match number of tokens in prompt") + + images = [load_image_from_base64(image) for image in images] + # FIXME: hacky padding + images = [expand2square(image, tuple(int(x * 255) for x in [0.48145466, 0.4578275, 0.40821073])) for image in images] + + # FIXME: for image-start/end token + # replace_token = DEFAULT_IMAGE_TOKEN + # if getattr(self.model.config, 'mm_use_im_start_end', False): + # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) + prompt = prompt.replace(" " + DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN) + prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN) + prompt = [] + for i in range(len(prompt_split)): + prompt.append(prompt_split[i]) + if i < len(images): + prompt.append(images[i]) + else: + prompt = [prompt] + + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + # max_context_length = getattr(model.config, 'max_position_embeddings', 2048) + max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) + stop_str = params.get("stop", None) + stop_str = [stop_str] if stop_str is not None else None + + if max_new_tokens < 1: + yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" + return + + # print(prompt) + state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True) + + generated_text = ori_prompt + async for text_outputs in state.text_async_iter(var_name="response"): + generated_text += text_outputs + yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" + + async def generate_stream_gate(self, params): + try: + async for x in self.generate_stream(params): + yield x + except ValueError as e: + print("Caught ValueError:", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + except Exception as e: + print("Caught Unknown Error", e) + ret = { + "text": server_error_msg, + "error_code": 1, + } + yield json.dumps(ret).encode() + b"\0" + + +app = FastAPI() + + +def release_model_semaphore(fn=None): + model_semaphore.release() + if fn is not None: + fn() + + +@app.post("/worker_generate_stream") +async def generate_stream(request: Request): + global model_semaphore, global_counter + global_counter += 1 + params = await request.json() + + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) + await model_semaphore.acquire() + worker.send_heart_beat() + generator = worker.generate_stream_gate(params) + background_tasks = BackgroundTasks() + background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) + return StreamingResponse(generator, background=background_tasks) + + +@app.post("/worker_get_status") +async def get_status(request: Request): + return worker.get_status() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=21002) + parser.add_argument("--worker-address", type=str, default="http://localhost:21002") + parser.add_argument("--controller-address", type=str, default="http://localhost:21001") + parser.add_argument("--model-name", type=str) + parser.add_argument("--sgl-endpoint", type=str) + parser.add_argument("--limit-model-concurrency", type=int, default=5) + parser.add_argument("--stream-interval", type=int, default=1) + parser.add_argument("--no-register", action="store_true") + args = parser.parse_args() + logger.info(f"args: {args}") + + worker = ModelWorker(args.controller_address, args.worker_address, args.sgl_endpoint, worker_id, args.no_register, args.model_name) + uvicorn.run(app, host=args.host, port=args.port, log_level="info") diff --git a/llava/serve/test_message.py b/llava/serve/test_message.py new file mode 100755 index 000000000..45acd534f --- /dev/null +++ b/llava/serve/test_message.py @@ -0,0 +1,59 @@ +import argparse +import json + +import requests + +from llava.conversation import default_conversation + + +def main(): + if args.worker_address: + worker_addr = args.worker_address + else: + controller_addr = args.controller_address + ret = requests.post(controller_addr + "/refresh_all_workers") + ret = requests.post(controller_addr + "/list_models") + models = ret.json()["models"] + models.sort() + print(f"Models: {models}") + + ret = requests.post(controller_addr + "/get_worker_address", json={"model": args.model_name}) + worker_addr = ret.json()["address"] + print(f"worker_addr: {worker_addr}") + + if worker_addr == "": + return + + conv = default_conversation.copy() + conv.append_message(conv.roles[0], args.message) + prompt = conv.get_prompt() + + headers = {"User-Agent": "LLaVA Client"} + pload = { + "model": args.model_name, + "prompt": prompt, + "max_new_tokens": args.max_new_tokens, + "temperature": 0.7, + "stop": conv.sep, + } + response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True) + + print(prompt.replace(conv.sep, "\n"), end="") + for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"].split(conv.sep)[-1] + print(output, end="\r") + print("") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--controller-address", type=str, default="http://localhost:21001") + parser.add_argument("--worker-address", type=str) + parser.add_argument("--model-name", type=str, default="facebook/opt-350m") + parser.add_argument("--max-new-tokens", type=int, default=32) + parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.") + args = parser.parse_args() + + main() diff --git a/llava/train/llama_flash_attn_monkey_patch.py b/llava/train/llama_flash_attn_monkey_patch.py new file mode 100755 index 000000000..c88fe3426 --- /dev/null +++ b/llava/train/llama_flash_attn_monkey_patch.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple +import warnings + +import torch + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +except ImportError: + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn("Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.") + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # shape: (b, num_heads, s, head_dim) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) + qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) + max_s = q_len + output = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) + output = output.view(bsz, q_len, -1) + else: + qkv = qkv.reshape(bsz, q_len, -1) + qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + output_unpad = flash_attn_unpadded_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593") + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py new file mode 100755 index 000000000..e65473cd0 --- /dev/null +++ b/llava/train/llava_trainer.py @@ -0,0 +1,527 @@ +import os +import torch +import torch.nn as nn +import datetime + +from accelerate import Accelerator +from accelerate.utils import InitProcessGroupKwargs, GradientAccumulationPlugin +from torch.utils.data import Dataset, Sampler, DataLoader + +from trl.trainer import DPOTrainer +from trl.trainer.utils import DPODataCollatorWithPadding + +from transformers import Trainer +from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, logger, is_accelerate_available, is_datasets_available, GradientAccumulationPlugin +from transformers.trainer_utils import seed_worker +from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf +from transformers.trainer_pt_utils import AcceleratorConfig +from typing import List, Optional +from datetime import timedelta + +if is_accelerate_available(): + from accelerate import Accelerator, skip_first_batches, InitProcessGroupKwargs + +if is_datasets_available(): + import datasets + +from llava.utils import rank0_print + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + print(name, "no ignore status") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} + return to_return + + +def split_to_even_chunks(indices, lengths, num_chunks): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + return [indices[i::num_chunks] for i in range(num_chunks)] + + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float("inf") + + return chunks + + +def get_variable_length_grouped_indices(lengths, batch_size, world_size, megabatch_mult=8, generator=None): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True) + megabatch_size = world_size * batch_size * megabatch_mult + megabatches = [sorted_indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: indices[i], reverse=True) for megabatch in megabatches] + shuffled_indices = [i for megabatch in megabatches for i in megabatch] + world_batch_size = world_size * batch_size + batches = [shuffled_indices[i : i + world_batch_size] for i in range(0, len(lengths), world_batch_size)] + batch_indices = torch.randperm(len(batches), generator=generator) + batches = [batches[i] for i in batch_indices] + + return [i for batch in batches for i in batch] + + +def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): + """ + Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - reorder by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." + if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] + lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] + megabatch_size = world_size * batch_size + mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] + lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(additional_batch) > 0: + megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): + """ + Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar + lengths. To do this, the indices are: + + - randomly permuted + - grouped in mega-batches of size `mega_batch_mult * batch_size` + - reorder by length in each mega-batch + + The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of + maximum length placed first, so that an OOM happens sooner rather than later. + """ + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +def get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=None): + indices = get_length_grouped_indices_hf(lengths, batch_size * world_size, generator=generator) + + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + batch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in batch_indices] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +def get_modality_length_grouped_indices_auto(lengths, batch_size, world_size, generator=None): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." + if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=generator) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices_auto_single(mm_lengths, batch_size, world_size, generator=None)] + lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices_auto_single(lang_lengths, batch_size, world_size, generator=None)] + megabatch_size = world_size * batch_size + mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] + lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + # FIXME: Hard code to avoid last batch mixed with different modalities + # if len(additional_batch) > 0: + # megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. + """ + + def __init__( + self, + batch_size: int, + world_size: int, + lengths: Optional[List[int]] = None, + generator=None, + variable_length: bool = False, + group_by_modality: bool = False, + group_by_modality_auto: bool = False, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") + + self.batch_size = batch_size + self.world_size = world_size + self.lengths = lengths + self.generator = generator + self.variable_length = variable_length + self.group_by_modality = group_by_modality + self.group_by_modality_auto = group_by_modality_auto + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + if self.variable_length: + assert not self.group_by_modality, "Variable length grouping is not supported with modality grouping." + indices = get_variable_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + else: + if self.group_by_modality: + indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + elif self.group_by_modality_auto: + indices = get_modality_length_grouped_indices_auto(self.lengths, self.batch_size, self.world_size, generator=self.generator) + else: + indices = get_length_grouped_indices_auto_single(self.lengths, self.batch_size, self.world_size, generator=self.generator) + return iter(indices) + + +class LLaVATrainer(Trainer): + + def create_accelerator_and_postprocess(self): + grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps} + grad_acc_kwargs["sync_with_dataloader"] = False + gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) + + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + rank0_print("Setting NCCL timeout to INF to avoid running errors.") + + # create accelerator object + self.accelerator = Accelerator( + dispatch_batches=self.args.dispatch_batches, split_batches=self.args.split_batches, deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin, kwargs_handlers=[accelerator_kwargs] + ) + # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag + self.gather_function = self.accelerator.gather_for_metrics + + # deepspeed and accelerate flags covering both trainer args and accelerate launcher + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + + # post accelerator creation setup + if self.is_fsdp_enabled: + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", fsdp_plugin.limit_all_gathers) + if is_accelerate_available("0.23.0"): + fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get("activation_checkpointing", fsdp_plugin.activation_checkpointing) + if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: + raise ValueError("The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " "when using FSDP.") + + if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: + self.propagate_args_to_deepspeed() + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_length: + lengths = self.train_dataset.lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + ) + elif self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + group_by_modality=True, + ) + elif self.args.group_by_modality_length_auto: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + # world_size=self.args.world_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + group_by_modality_auto=True, + ) + elif self.args.group_by_varlen: + lengths = self.train_dataset.lengths + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + # self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps + # world_size=self.args.world_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work? + lengths=lengths, + variable_length=True, + ) + else: + return super()._get_train_sampler() + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None + + dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + return dataloader + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + if is_sagemaker_mp_enabled(): + return super().create_optimizer() + + opt_model = self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + lr_mapper = {} + if self.args.mm_projector_lr is not None: + lr_mapper["mm_projector"] = self.args.mm_projector_lr + if self.args.mm_vision_tower_lr is not None: + lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr + if len(lr_mapper) > 0: + special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + ] + for module_keyword, lr in lr_mapper.items(): + module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name] + optimizer_grouped_parameters.extend( + [ + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + "lr": lr, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)], + "weight_decay": 0.0, + "lr": lr, + }, + ] + ) + else: + optimizer_grouped_parameters = [ + { + "params": [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)], + "weight_decay": self.args.weight_decay, + }, + { + "params": [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + return self.optimizer + + def _save_checkpoint(self, model, trial, metrics=None): + if getattr(self.args, "tune_mm_mlp_adapter", False) or ( + hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts)) + ): + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + + # Only save Adapter + keys_to_match = ["mm_projector", "vision_resampler"] + if getattr(self.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) + + if self.args.local_rank == 0 or self.args.local_rank == -1: + self.model.config.save_pretrained(output_dir) + torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin")) + else: + super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + if getattr(self.args, "tune_mm_mlp_adapter", False): + pass + else: + super(LLaVATrainer, self)._save(output_dir, state_dict) + + +class LLaVADPOTrainer(DPOTrainer): + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps + self.args.train_batch_size, + world_size=self.args.world_size, + lengths=lengths, + group_by_modality=True, + ) + else: + return super()._get_train_sampler() + + def _save_checkpoint(self, model, trial, metrics=None): + if getattr(self.args, "tune_mm_mlp_adapter", False) or ( + hasattr(self.args, "mm_tunable_parts") and (len(self.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in self.args.mm_tunable_parts or "mm_vision_resampler" in self.args.mm_tunable_parts)) + ): + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + + # Only save Adapter + keys_to_match = ["mm_projector", "vision_resampler"] + if getattr(self.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) + + if self.args.local_rank == 0 or self.args.local_rank == -1: + self.model.config.save_pretrained(output_dir) + torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin")) + else: + # super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics) + # print(type(model)) + # from transformers.modeling_utils import unwrap_model + # print(type(unwrap_model(model))) + # print(unwrap_model(model).config) + if self.args.lora_enable: + from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR + + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + from transformers.modeling_utils import unwrap_model + + unwrapped_model = unwrap_model(model) + self.save_my_lora_ckpt(output_dir, self.args, unwrapped_model) + else: + super(LLaVADPOTrainer, self)._save_checkpoint(model, trial, metrics) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + if getattr(self.args, "tune_mm_mlp_adapter", False): + pass + else: + super(LLaVADPOTrainer, self)._save(output_dir, state_dict) diff --git a/llava/train/llava_trainer_eval.py b/llava/train/llava_trainer_eval.py new file mode 100755 index 000000000..e82225852 --- /dev/null +++ b/llava/train/llava_trainer_eval.py @@ -0,0 +1,76 @@ +import json +import subprocess + +from llava.train.llava_trainer import LLaVATrainer + + +class LLaVAEvalTrainer(LLaVATrainer): + def evaluate(self, evaluate_args): + cmd = f"accelerate launch --num_processes {evaluate_args.eval_num_processes} -m lmms_eval \ + --model {evaluate_args.model} \ + --model_args {evaluate_args.model_args} \ + --tasks {evaluate_args.task_names} \ + --batch_size {evaluate_args.batch_size} \ + --log_samples_suffix {evaluate_args.log_samples_suffix} \ + --output_path {evaluate_args.output_path}" + if evaluate_args.limit: + cmd += f" --limit {evaluate_args.limit}" + if evaluate_args.num_fewshot: + cmd += f" --num_fewshot {evaluate_args.num_fewshot}" + if evaluate_args.gen_kwargs != "": + cmd += f" --gen_kwargs {evaluate_args.gen_kwargs}" + if evaluate_args.log_samples: + cmd += f" --log_samples" + else: + assert False, "Please log samples so that the result can be parsed" + results = subprocess.run([cmd], shell=True, capture_output=True, text=True) + try: + result_file_index_start = results.stdout.index("Saved samples to ") + result_file_index_end = results.stdout.index(f".json") + result_file_index_start += len("Saved samples to ") + file = results.stdout[result_file_index_start:result_file_index_end] + except: + result_file_index_start = results.stderr.index("Saved samples to ") + result_file_index_end = results.stderr.index(f".json") + result_file_index_start += len("Saved samples to ") + file = results.stderr[result_file_index_start:result_file_index_end] + file = file.split("/")[:-1] + file = "/".join(file) + "/results.json" + with open(file, "r") as f: + lmms_eval_results = json.load(f) + result_dict = {} + tasks_list = evaluate_args.task_names.split(",") + for task in tasks_list: + task_results = lmms_eval_results["results"][task] + for k, v in task_results.items(): + if k != "alias" and "stderr" not in k: + metric = k.split(",")[0] + result_dict[f"{task}_{metric}"] = v + return result_dict + + """def evaluate(self, evaluate_args): + initialize_tasks() + tasks_list = evaluate_args.task_names.split(",") + result_dict = {} + results = evaluator.simple_evaluate( + model=evaluate_args.model, + model_args=evaluate_args.model_args, + tasks=tasks_list, + num_fewshot=evaluate_args.num_fewshot, + batch_size=evaluate_args.batch_size, + device=evaluate_args.device, + limit=evaluate_args.limit, + check_integrity=evaluate_args.check_integrity, + show_task_to_terminal=evaluate_args.show_task_to_terminal, + log_samples=evaluate_args.log_samples, + gen_kwargs=evaluate_args.gen_kwargs, + cli_args=evaluate_args, + ) + for task in tasks_list: + task_results = results["results"][task] + for k,v in task_results.items(): + if k != "alias" and "stderr" not in k: + metric = k.split(",")[0] + result_dict[f"{task}_{metric}"] = v + + return result_dict""" diff --git a/llava/train/train.py b/llava/train/train.py new file mode 100755 index 000000000..c342d88e9 --- /dev/null +++ b/llava/train/train.py @@ -0,0 +1,1721 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ast +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List +from PIL import Image, ImageFile +from packaging import version +import numpy as np + +import time +import random +import yaml +import math +import re +import torch + +import transformers +import tokenizers +import deepspeed + +from transformers import AutoConfig +from torch.utils.data import Dataset +from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX +from llava.train.llava_trainer import LLaVATrainer + +from llava import conversation as conversation_lib +from llava.model import * +from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token +from llava.utils import rank0_print, process_video_with_pyav, process_video_with_decord + +torch.multiprocessing.set_sharing_strategy("file_system") + +ImageFile.LOAD_TRUNCATED_IMAGES = True +local_rank = None + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse("0.14") + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_class_name: Optional[str] = field(default=None, metadata={"help": "Used to init model class, format is XXXXForCausalLM. e.g. currently XXXX is chosen from LlavaLlama, LlavaMixtral, LlavaMistral, Llama"}) + + mm_tunable_parts: Optional[str] = field( + default=None, metadata={"help": 'Could be "mm_mlp_adapter", "mm_vision_resampler", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_mlp_adapter,mm_language_model"'} + ) + # deciding which part of the multimodal model to tune, will overwrite other previous settings + + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + tune_mm_vision_resampler: bool = field(default=False) + vision_tower: Optional[str] = field(default=None) + vision_tower_pretrained: Optional[str] = field(default=None) # default to the last layer + + unfreeze_mm_vision_tower: bool = field(default=False) + unfreeze_language_model: bool = field(default=False) + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default="linear") + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default="flat") + mm_vision_select_feature: Optional[str] = field(default="patch") + mm_resampler_type: Optional[str] = field(default=None) + mm_mask_drop_mode: str = field(default="fixed") + mm_mask_drop_skip_percentage: float = field(default=0.0) + mm_mask_drop_ratio: float = field(default=0.25) + mm_mask_drop_ratio_upper: Optional[float] = field(default=None) + mm_mask_drop_ratio_lower: Optional[float] = field(default=None) + mm_spatial_pool_stride: Optional[int] = field(default=None) + mm_spatial_pool_mode: str = field(default="bilinear") + mm_spatial_pool_out_channels: Optional[int] = field(default=None) + mm_perceiver_depth: Optional[int] = field(default=3) + mm_perceiver_latents: Optional[int] = field(default=32) + mm_perceiver_ff_mult: Optional[float] = field(default=4) + mm_perceiver_pretrained: Optional[str] = field(default=None) + mm_qformer_depth: Optional[int] = field(default=3) + mm_qformer_latents: Optional[int] = field(default=32) + mm_qformer_pretrained: Optional[str] = field(default=None) + + rope_scaling_factor: Optional[float] = field(default=None) + rope_scaling_type: Optional[str] = field(default=None) + + s2: Optional[bool] = field(default=False) + s2_scales: Optional[str] = field(default="336,672,1008") + + use_pos_skipping: Optional[bool] = field(default=False) + pos_skipping_range: Optional[int] = field(default=4096) + + + mm_newline_position: Optional[str] = field(default="grid") + delay_load: Optional[bool] = field(default=True) + add_faster_video: Optional[bool] = field(default=False) + faster_token_stride: Optional[int] = field(default=10) + + + +@dataclass +class DataArguments: + data_path: str = field(default=None, metadata={"help": "Path to the training data, in llava's instruction.json format. Supporting multiple json files via /path/to/{a,b,c}.json"}) + lazy_preprocess: bool = False + is_multimodal: bool = False + early_mix_text: bool = False + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = "square" + image_grid_pinpoints: Optional[str] = field(default=None) + image_crop_resolution: Optional[int] = field(default=None) + image_split_resolution: Optional[int] = field(default=None) + + video_folder: Optional[str] = field(default=None) + video_fps: Optional[int] = field(default=1) + frames_upbound: Optional[int] = field(default=0) + add_time_instruction: Optional[bool] = field(default=False) + force_sample: Optional[bool] = field(default=False) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + freeze_mm_vision_resampler: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=4096, + metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, + ) + double_quant: bool = field(default=True, metadata={"help": "Compress the quantization statistics through double quantization."}) + quant_type: str = field(default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}) + bits: int = field(default=16, metadata={"help": "How many bits to use."}) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + mm_vision_tower_lr: Optional[float] = None + group_by_varlen: bool = field(default=False) + group_by_modality_length: bool = field(default=False) + group_by_modality_length_auto: bool = field(default=False) + auto_find_batch_size: bool = field(default=False) + gradient_checkpointing: bool = field(default=True) + verbose_logging: bool = field(default=False) + attn_implementation: str = field(default="flash_attention_2", metadata={"help": "Use transformers attention implementation."}) + + +# @dataclass +# class EvaluationArguments: +# eval_num_processes: int = field(default=1) +# task_names: str = field(default=None) +# model: str = field(default="llava") +# model_args: Optional[str] = field(default=None) +# num_fewshot: Optional[int] = field(default=None) +# batch_size: int = field(default=1) +# device: Optional[str] = field(default=None) +# limit: Optional[int] = field(default=None) +# check_integrity: Optional[bool] = field(default=False) +# show_task_to_terminal: Optional[bool] = field(default=False) +# log_samples: Optional[bool] = field(default=True) +# gen_kwargs: Optional[str] = field(default="") +# log_samples_suffix: Optional[str] = field(default="") +# output_path: Optional[str] = field(default="./logs/") + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter: + check_only_save_mm_adapter_tunnable = True + # only has mm_mlp_adapter and mm_vision_resampler in the tuneable parts + elif hasattr(trainer.args, "mm_tunable_parts") and (len(trainer.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in trainer.args.mm_tunable_parts or "mm_vision_resampler" in trainer.args.mm_tunable_parts)): + check_only_save_mm_adapter_tunnable = True + else: + check_only_save_mm_adapter_tunnable = False + + trainer.accelerator.wait_for_everyone() + torch.cuda.synchronize() + rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}") + if check_only_save_mm_adapter_tunnable: + # Only save Adapter + keys_to_match = ["mm_projector", "vision_resampler"] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split("/")[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith("checkpoint-"): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin")) + else: + torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin")) + return + + if trainer.deepspeed: + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = "unknown" + sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + # TODO maybe this should be changed for interleaved data? + # if DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN): + # only check for num_im=1 + num_im = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"])) + if num_im == 1 and DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN): + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() + sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"] + sentence["value"] = sentence["value"].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "" + DEFAULT_IMAGE_TOKEN + "") + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + # For videoInstruct-100k noisy_data. TODO: Ask Yuanhan to clean the data instead of leaving the noise code here. + sentence["value"] = sentence["value"].replace("QA_GT_caption_based_noisy", "") + + return sources + + +def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_gemma(sources: List[List[Dict[str, str]]], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + conv: conversation_lib.Conversation = conversation_lib.default_conversation.copy() + roles: Dict[str, str] = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations: List[str] = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source: List[Dict[str, str]] = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role: str = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + if has_image: + input_ids: torch.Tensor = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0) + else: + input_ids: torch.Tensor = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets: torch.Tensor = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA + + # Mask target + sep: str = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len: int = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds: List[str] = conversation.split(conv.sep) + re_rounds = [] + for conv_idx in range(0, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) + + cur_len = 1 # Ignore + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep # Re-append sep because split on this + # Now "".join(parts)==rou + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 # Ignore + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # Ignore + else: + round_len = len(tokenizer(rou).input_ids) - 1 # Ignore + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # Ignore + + round_len += 2 # sep: \n takes 2 tokens + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + cur_len += round_len + + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print(f"warning: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict: + # roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"} + roles = {"human": "user", "gpt": "assistant"} + + # Add image tokens to tokenizer as a special tokens + # Use a deepcopy of tokenizer so that we don't modify on the tokenizer + tokenizer = copy.deepcopy(tokenizer) + # When there is actually an image, we add the image tokens as a special token + if has_image: + tokenizer.add_tokens([""], special_tokens=True) + + image_token_index = tokenizer.convert_tokens_to_ids("") + im_start, im_end = tokenizer.additional_special_tokens_ids + # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"] + unmask_tokens_idx = [198, im_start, im_end] + nl_tokens = tokenizer("\n").input_ids + + # Reset Qwen chat templates so that it won't include system message every time we apply + chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + tokenizer.chat_template = chat_template + + # _system = tokenizer("system").input_ids + nl_tokens + # _user = tokenizer("user").input_ids + nl_tokens + # _assistant = tokenizer("assistant").input_ids + nl_tokens + + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["human"]: + source = source[1:] + + input_id, target = [], [] + + # New version, use apply chat template + # Build system message for each sentence + input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) + target += [IGNORE_INDEX] * len(input_id) + + for conv in source: + # Make sure llava data can load + try: + role = conv["role"] + content = conv["content"] + except: + role = conv["from"] + content = conv["value"] + + role = roles.get(role, role) + + conv = [{"role" : role, "content" : content}] + encode_id = tokenizer.apply_chat_template(conv) + input_id += encode_id + if role in ["user", "system"]: + target += [IGNORE_INDEX] * len(encode_id) + else: + target += encode_id + + + + assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" + for idx, encode_id in enumerate(input_id): + if encode_id in unmask_tokens_idx: + target[idx] = encode_id + if encode_id == image_token_index: + input_id[idx] = IMAGE_TOKEN_INDEX + input_ids.append(input_id) + targets.append(target) + input_ids = torch.tensor(input_ids, dtype=torch.long) + targets = torch.tensor(targets, dtype=torch.long) + + return dict( + input_ids=input_ids, # tensor(bs x seq_len) + labels=targets, # tensor(bs x seq_len) + ) + + +def preprocess_llama3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + max_len=2048, + system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", +) -> Dict: + # roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"} + roles = {"human": "user", "gpt": "assistant"} + + # Add image tokens to tokenizer as a special tokens + # Use a deepcopy of tokenizer so that we don't modify on the tokenizer + tokenizer = copy.deepcopy(tokenizer) + # When there is actually an image, we add the image tokens as a special token + if has_image: + tokenizer.add_tokens([""], special_tokens=True) + image_token_index = tokenizer.convert_tokens_to_ids("") + bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>") + start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>") + end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>") + eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") + + unmask_tokens = ["<|begin_of_text|>", "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>", "\n\n"] + unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens] + + # After update, calling tokenizer of llama3 will + # auto add bos id for the tokens. ヽ(`⌒´)ノ + def safe_tokenizer_llama3(text): + input_ids = tokenizer(text).input_ids + if input_ids[0] == bos_token_id: + input_ids = input_ids[1:] + return input_ids + + nl_tokens = tokenizer.convert_tokens_to_ids("\n\n") + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["human"]: + source = source[1:] + + input_id, target = [], [] + + # New version, use apply chat template + # Build system message for each sentence + input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) + target += [IGNORE_INDEX] * len(input_id) + + for conv in source: + # Make sure llava data can load + try: + role = conv["role"] + content = conv["content"] + except: + role = conv["from"] + content = conv["value"] + + role = roles.get(role, role) + + conv = [{"role" : role, "content" : content}] + # First is bos token we don't need here + encode_id = tokenizer.apply_chat_template(conv)[1:] + input_id += encode_id + if role in ["user", "system"]: + target += [IGNORE_INDEX] * len(encode_id) + else: + target += encode_id + + + + assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" + for idx, encode_id in enumerate(input_id): + if encode_id in unmask_tokens_idx: + target[idx] = encode_id + if encode_id == image_token_index: + input_id[idx] = IMAGE_TOKEN_INDEX + input_ids.append(input_id) + targets.append(target) + input_ids = torch.tensor(input_ids, dtype=torch.long) + targets = torch.tensor(targets, dtype=torch.long) + + return dict( + input_ids=input_ids, # tensor(bs x seq_len) + labels=targets, # tensor(bs x seq_len) + ) + + +def preprocess_v1(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_mpt(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + + if i != 0 and getattr(tokenizer, "legacy", False) and IS_TOKENIZER_GREATER_THAN_0_14: + round_len += 1 + instruction_len += 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f"(#turns={len(re_rounds)} ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_IMAGE_TOKEN in source[0]["value"] + source[0]["value"] = DEFAULT_IMAGE_TOKEN + conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess(sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "mpt": + return preprocess_mpt(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "qwen": + return preprocess_qwen(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "gemma": + return preprocess_gemma(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "llama_v3": + return preprocess_llama3(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + + +class LazySupervisedDataset(Dataset): + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + self.list_data_dict = [] + + # Handle multiple JSON files specified in the data_path + if "{" in data_path and "}" in data_path: + base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups() + file_names = file_pattern.split(",") + rank0_print(f"Loading {file_names} from {base_path}") + data_args.dataset_paths = [] + for file_name in file_names: + data_args.dataset_paths.append(f"{base_path}{file_name}.json") + full_path = f"{base_path}{file_name}.json" + rank0_print(f"Loading {full_path}") + with open(full_path, "r") as file: + cur_data_dict = json.load(file) + rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}") + self.list_data_dict.extend(cur_data_dict) + elif data_path.endswith(".yaml"): + with open(data_path, "r") as file: + yaml_data = yaml.safe_load(file) + datasets = yaml_data.get("datasets") + # file should be in the format of: + # datasets: + # - json_path: xxxx1.json + # sampling_strategy: first:1000 + # - json_path: xxxx2.json + # sampling_strategy: end:3000 + # - json_path: xxxx3.json + # sampling_strategy: random:999 + data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets] + for dataset in datasets: + json_path = dataset.get("json_path") + sampling_strategy = dataset.get("sampling_strategy", "all") + sampling_number = None + + rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy") + + if json_path.endswith(".jsonl"): + cur_data_dict = [] + with open(json_path, "r") as json_file: + for line in json_file: + cur_data_dict.append(json.loads(line.strip())) + elif json_path.endswith(".json"): + with open(json_path, "r") as json_file: + cur_data_dict = json.load(json_file) + else: + raise ValueError(f"Unsupported file type: {json_path}") + + if ":" in sampling_strategy: + sampling_strategy, sampling_number = sampling_strategy.split(":") + if "%" in sampling_number: + sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100) + else: + sampling_number = int(sampling_number) + + # Apply the sampling strategy + if sampling_strategy == "first" and sampling_number is not None: + cur_data_dict = cur_data_dict[:sampling_number] + elif sampling_strategy == "end" and sampling_number is not None: + cur_data_dict = cur_data_dict[-sampling_number:] + elif sampling_strategy == "random" and sampling_number is not None: + random.shuffle(cur_data_dict) + cur_data_dict = cur_data_dict[:sampling_number] + + rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}") + self.list_data_dict.extend(cur_data_dict) + else: + data_args.dataset_paths = [data_path] + rank0_print(f"Loading {data_path}") + with open(data_path, "r") as file: + cur_data_dict = json.load(file) + rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}") + self.list_data_dict.extend(cur_data_dict) + + rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}") + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.data_args = data_args + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if "image" in sample else 0 + length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"]) + assert cur_len > 0, f"Conversation length is 0 for {sample}" + if "image" in sample or "video" in sample or self.data_args.early_mix_text: + length_list.append(cur_len) + else: + length_list.append(-cur_len) + return length_list + + def process_image(self, image_file, overwrite_image_aspect_ratio=None): + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + # print(f"\n\nInspecting the image path, folder = {image_folder}, image={image_file}\n\n") + try: + image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") + except Exception as exn: + print(f"Failed to open image {image_file}. Exception:", exn) + raise exn + + image_size = image.size + image_aspect_ratio = self.data_args.image_aspect_ratio + if overwrite_image_aspect_ratio is not None: + image_aspect_ratio = overwrite_image_aspect_ratio + if image_aspect_ratio == "highres": + image = process_highres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints) + elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: + image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints) + elif image_aspect_ratio == "crop_split": + image = process_highres_image_crop_split(image, self.data_args) + elif image_aspect_ratio == "pad": + + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + else: + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + return image, image_size, "image" + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + # TODO: define number of retries somewhere else + num_base_retries = 3 + num_final_retries = 300 + + # try the current sample first + for attempt_idx in range(num_base_retries): + try: + sample = self._get_item(i) + return sample + except Exception as e: + # sleep 1s in case it is a cloud disk issue + print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e) + time.sleep(1) + + # try other samples, in case it is file corruption issue + for attempt_idx in range(num_base_retries): + try: + next_index = min(i + 1, len(self.list_data_dict) - 1) + # sample_idx = random.choice(range(len(self))) + sample = self._get_item(next_index) + return sample + except Exception as e: + # no need to sleep + print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e) + pass + + try: + sample = self._get_item(i) + return sample + except Exception as e: + raise e + + def _get_item(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + + if "image" in sources[0]: + image_file = self.list_data_dict[i]["image"] + if type(image_file) is list: + image = [self.process_image(f) for f in image_file] + # Handling multi images + # overwrite to process with simple pad + if len(image_file) > 1: + image = [self.process_image(f, "pad") for f in image_file] + image = [[im[0], im[1], "image"] for im in image] + else: + image = [self.process_image(image_file)] + sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args) + + elif "video" in sources[0]: + video_file = self.list_data_dict[i]["video"] + video_folder = self.data_args.video_folder + video_file = os.path.join(video_folder, video_file) + suffix = video_file.split(".")[-1] + if not os.path.exists(video_file): + print("File {} not exist!".format(video_file)) + + try: + if "shareVideoGPTV" in video_file: + frame_files = [os.path.join(video_file, f) for f in os.listdir(video_file) if os.path.isfile(os.path.join(video_file, f))] + frame_files.sort() # Ensure the frames are sorted if they are named sequentially + + # TODO: Hard CODE: Determine the indices for uniformly sampling 10 frames + if self.data_args.force_sample: + num_frames_to_sample = self.data_args.frames_upbound + else: + num_frames_to_sample = 10 + + avg_fps = 2 + + total_frames = len(frame_files) + sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int) + + + frame_time = [i/2 for i in sampled_indices] + frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) + + video_time = total_frames / avg_fps + + # Read and store the sampled frames + video = [] + for idx in sampled_indices: + frame_path = frame_files[idx] + try: + with Image.open(frame_path) as img: + frame = img.convert("RGB") + video.append(frame) + except IOError: + print(f"Failed to read frame at path: {frame_path}") + else: + video, video_time, frame_time, num_frames_to_sample = process_video_with_decord(video_file, self.data_args) + + processor = self.data_args.image_processor + image = processor.preprocess(video, return_tensors="pt")["pixel_values"] + if self.data_args.add_time_instruction: + time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {num_frames_to_sample} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video." + sources[0]["conversations"][0]["value"] = f'{DEFAULT_IMAGE_TOKEN}\n{time_instruciton}\n{sources[0]["conversations"][0]["value"].replace(DEFAULT_IMAGE_TOKEN, "")}' + image = [(image, video[0].size, "video")] + sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args) + # print(sources) + except Exception as e: + print(f"Error: {e}") + print(f"Failed to read video file: {video_file}") + return self._get_item(i + 1) + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + + has_image = ("image" in self.list_data_dict[i]) or ("video" in self.list_data_dict[i]) + data_dict = preprocess(sources, self.tokenizer, has_image=has_image) + + if "prompt" in data_dict: + prompt = data_dict["prompt"] + else: + prompt = None + + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) + + # image exist in the data + if "image" in self.list_data_dict[i]: + data_dict["image"] = image + elif "video" in self.list_data_dict[i]: + data_dict["image"] = image + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + crop_size = self.data_args.image_processor.crop_size + data_dict["image"] = [ + (torch.zeros(1, 3, crop_size["height"], crop_size["width"]), (crop_size["width"], crop_size["height"]), "text"), + ] + # prompt exist in the data + if prompt is not None: + data_dict["prompt"] = prompt + + data_dict["id"] = self.list_data_dict[i].get("id", i) + + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def pad_sequence(self, input_ids, batch_first, padding_value): + if self.tokenizer.padding_side == "left": + input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids] + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value) + if self.tokenizer.padding_side == "left": + input_ids = torch.flip(input_ids, [1]) + return input_ids + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + # input_ids, labels, ids = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "id")) + input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids] + labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels] + if self.tokenizer.pad_token_id is None: + # self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model. + self.tokenizer.pad_token_id = 0 # This gets the best result. Don't know why. + input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) + labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + batch = dict(input_ids=input_ids, labels=labels.long() if labels.dtype == torch.int32 else labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id)) + # batch = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), ids=ids) + + if "image" in instances[0]: + images = [instance["image"] for instance in instances] + + batch["image_sizes"] = [im[1] for im_list in images for im in im_list] + batch["modalities"] = [im[2] for im_list in images for im in im_list] + images = [im[0] for im_list in images for im in im_list] + + # if all(x is not None and x.shape == images[0].shape for x in images): + # Image: (N, P, C, H, W) + # Video: (N, F, C, H, W) + # batch["images"] = torch.stack(images) + # else: + batch["images"] = images + + if "prompt" in instances[0]: + batch["prompts"] = [instance["prompt"] for instance in instances] + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) + + +def get_model(model_args, training_args, bnb_model_from_pretrained_args): + assert training_args.attn_implementation + if training_args.attn_implementation == "sdpa" and torch.__version__ < "2.1.2": + raise ValueError("The 'sdpa' attention implementation requires torch version 2.1.2 or higher.") + + customized_kwargs = dict() + customized_kwargs.update(bnb_model_from_pretrained_args) + cfg_pretrained = None + + overwrite_config = {} + if any( + [ + model_args.rope_scaling_factor is not None, + model_args.rope_scaling_type is not None, + model_args.mm_spatial_pool_stride is not None, + model_args.mm_spatial_pool_out_channels is not None, + model_args.mm_spatial_pool_mode is not None, + model_args.mm_resampler_type is not None, + ] + ): + cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path) + + if model_args.use_pos_skipping is not None and model_args.pos_skipping_range is not None: + overwrite_config["use_pos_skipping"] = model_args.use_pos_skipping + overwrite_config["pos_skipping_range"] = model_args.pos_skipping_range + + if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None: + overwrite_config["rope_scaling"] = { + "factor": model_args.rope_scaling_factor, + "type": model_args.rope_scaling_type, + } + if training_args.model_max_length is None: + training_args.model_max_length = cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor + overwrite_config["max_sequence_length"] = training_args.model_max_length + assert training_args.model_max_length == int(cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor), print( + f"model_max_length: {training_args.model_max_length}, max_position_embeddings: {cfg_pretrained.max_position_embeddings}, rope_scaling_factor: {model_args.rope_scaling_factor}" + ) + # overwrite_config["max_sequence_length"] = model_args.max_sequence_length + # overwrite_config["tokenizer_model_max_length"] = model_args.tokenizer_model_max_length + + if model_args.mm_spatial_pool_stride is not None and model_args.mm_spatial_pool_out_channels is not None and model_args.mm_spatial_pool_mode is not None and model_args.mm_resampler_type is not None: + overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type + overwrite_config["mm_spatial_pool_stride"] = model_args.mm_spatial_pool_stride + overwrite_config["mm_spatial_pool_out_channels"] = model_args.mm_spatial_pool_out_channels + overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode + + if model_args.mm_spatial_pool_mode is not None: + overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode + + if overwrite_config: + assert cfg_pretrained is not None, "cfg_pretrained is None" + + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(cfg_pretrained, k, v) + + customized_kwargs["config"] = cfg_pretrained + + if model_args.model_class_name is not None: + actual_model_class_name = f"{model_args.model_class_name}ForCausalLM" + model_class = getattr(transformers, actual_model_class_name) + rank0_print(f"Using model class {model_class} from {model_args.model_class_name}") + model = model_class.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + elif model_args.vision_tower is not None: + if "mixtral" in model_args.model_name_or_path.lower(): + model = LlavaMixtralForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower(): + model = LlavaMistralForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + elif ( + "wizardlm-2" in model_args.model_name_or_path.lower() + or "vicuna" in model_args.model_name_or_path.lower() + or "llama" in model_args.model_name_or_path.lower() + or "yi" in model_args.model_name_or_path.lower() + or "nous-hermes" in model_args.model_name_or_path.lower() + and "wizard-2" in model_args.model_name_or_path.lower() + ): + model = LlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + elif "qwen" in model_args.model_name_or_path.lower(): + if "moe" in model_args.model_name_or_path.lower() or "A14B" in model_args.model_name_or_path: + model = LlavaQwenMoeForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock + + deepspeed.utils.set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) + else: + model = LlavaQwenForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + elif "gemma" in model_args.model_name_or_path.lower(): + model = LlavaGemmaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + else: + raise ValueError(f"Unknown model class {model_args}") + else: + model = transformers.LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + return model + + +def train(attn_implementation=None): + global local_rank + + parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.verbose_logging: + rank0_print(f"Inspecting experiment hyperparameters:\n") + rank0_print(f"model_args = {vars(model_args)}\n\n") + rank0_print(f"data_args = {vars(data_args)}\n\n") + rank0_print(f"training_args = {vars(training_args)}\n\n") + # rank0_print(f"evaluation_args = {vars(evaluation_args)}\n\n") + + local_rank = training_args.local_rank + compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + + bnb_model_from_pretrained_args.update( + dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'} + ), + ) + ) + + model = get_model(model_args, training_args, bnb_model_from_pretrained_args) + model.config.use_cache = False + if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None: + model.config.rope_scaling = { + "factor": model_args.rope_scaling_factor, + "type": model_args.rope_scaling_type, + } + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + + model.config.torch_dtype = torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + if "mistral" in model_args.model_name_or_path.lower() or "mixtral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="left") + elif "qwen" in model_args.model_name_or_path.lower(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right") + elif ( + "wizardlm-2" in model_args.model_name_or_path.lower() + or "vicuna" in model_args.model_name_or_path.lower() + or "llama" in model_args.model_name_or_path.lower() + or "yi" in model_args.model_name_or_path.lower() + or "nous-hermes" in model_args.model_name_or_path.lower() + and "wizard-2" in model_args.model_name_or_path.lower() + ): + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + rank0_print(f"Prompt version: {model_args.version}") + if model_args.version == "v0": + if tokenizer.pad_token is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token="[PAD]"), + tokenizer=tokenizer, + model=model, + ) + elif model_args.version == "v0.5": + tokenizer.pad_token = tokenizer.unk_token + else: + if tokenizer.unk_token is not None: + tokenizer.pad_token = tokenizer.unk_token + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] + + if model_args.vision_tower is not None: + model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp) + + vision_tower = model.get_vision_tower() + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = vision_tower.image_processor + data_args.is_multimodal = True + + model.config.image_aspect_ratio = data_args.image_aspect_ratio + if data_args.image_grid_pinpoints is not None: + if isinstance(data_args.image_grid_pinpoints, str) and "x" in data_args.image_grid_pinpoints: + try: + patch_size = data_args.image_processor.size[0] + except Exception as e: + patch_size = data_args.image_processor.size["shortest_edge"] + + assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", data_args.image_grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + # Multiply all elements by patch_size + data_args.image_grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + elif isinstance(data_args.image_grid_pinpoints, str): + data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints) + + model.config.image_grid_pinpoints = data_args.image_grid_pinpoints + model.config.image_crop_resolution = data_args.image_crop_resolution + model.config.image_split_resolution = data_args.image_split_resolution + model.config.tokenizer_padding_side = tokenizer.padding_side + model.config.tokenizer_model_max_length = tokenizer.model_max_length + model.config.mm_newline_position = model_args.mm_newline_position + model.config.add_faster_video = model_args.add_faster_video + model.config.faster_token_stride = model_args.faster_token_stride + model.config.add_time_instruction = data_args.add_time_instruction + model.config.force_sample = data_args.force_sample + model.config.mm_spatial_pool_stride = model_args.mm_spatial_pool_stride + + ### Deciding train which part of the model + if model_args.mm_tunable_parts is None: # traditional way of deciding which part to train + model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + model.config.tune_mm_vision_resampler = training_args.tune_mm_vision_resampler = model_args.tune_mm_vision_resampler + if model_args.tune_mm_mlp_adapter or model_args.tune_mm_vision_resampler: + model.requires_grad_(False) + if model_args.tune_mm_mlp_adapter: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + if model_args.tune_mm_vision_resampler: + for p in model.get_model().vision_resampler.parameters(): + p.requires_grad = True + + model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter + if training_args.freeze_mm_mlp_adapter: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler + if training_args.freeze_mm_vision_resampler: + for p in model.get_model().vision_resampler.parameters(): + p.requires_grad = False + + model.config.unfreeze_mm_vision_tower = model_args.unfreeze_mm_vision_tower + if model_args.unfreeze_mm_vision_tower: + vision_tower.requires_grad_(True) + else: + vision_tower.requires_grad_(False) + + else: + rank0_print(f"Using mm_tunable_parts: {model_args.mm_tunable_parts}") + model.config.mm_tunable_parts = training_args.mm_tunable_parts = model_args.mm_tunable_parts + # Set the entire model to not require gradients by default + model.requires_grad_(False) + vision_tower.requires_grad_(False) + model.get_model().mm_projector.requires_grad_(False) + model.get_model().vision_resampler.requires_grad_(False) + # Parse the mm_tunable_parts to decide which parts to unfreeze + tunable_parts = model_args.mm_tunable_parts.split(",") + if "mm_mlp_adapter" in tunable_parts: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + if "mm_vision_resampler" in tunable_parts: + for p in model.get_model().vision_resampler.parameters(): + p.requires_grad = True + if "mm_vision_tower" in tunable_parts: + for name, param in model.named_parameters(): + if "vision_tower" in name: + param.requires_grad_(True) + if "mm_language_model" in tunable_parts: + for name, param in model.named_parameters(): + if "vision_tower" not in name and "mm_projector" not in name and "vision_resampler" not in name: + param.requires_grad_(True) + + total_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters()) + trainable_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters() if p.requires_grad) + rank0_print(f"Total parameters: ~{total_params/1e6:.2f} MB)") + rank0_print(f"Trainable parameters: ~{trainable_params/1e6:.2f} MB)") + if training_args.bits in [4, 8]: + model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) + + model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr + model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr + training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) + + if training_args.bits in [4, 8]: + from peft.tuners.lora import LoraLayer + + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if "norm" in name: + module = module.to(torch.float32) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + trainer = LLaVATrainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters()) + if training_args.local_rank == 0 or training_args.local_rank == -1: + if hasattr(model, "config"): + model.config.save_pretrained(training_args.output_dir) + if hasattr(model, "generation_config"): + model.generation_config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_trainables.bin")) + else: + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + rank0_print(f"Model saved to {training_args.output_dir}") + + +if __name__ == "__main__": + train() diff --git a/llava/train/train_dpo.py b/llava/train/train_dpo.py new file mode 100755 index 000000000..037eec42f --- /dev/null +++ b/llava/train/train_dpo.py @@ -0,0 +1,1782 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +import deepspeed +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List +import ast + +import yaml +import time +import random +import yaml +import math +import re +import torch + +import transformers +import tokenizers + +from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX +from torch.utils.data import Dataset +from llava.train.llava_trainer import LLaVADPOTrainer +from data_processing.utils import load_jsonl, load_json +from llava import conversation as conversation_lib +from llava.model import * +from llava.model.language_model.llava_qwen import LlavaQwenConfig +from llava.model.language_model.llava_llama import LlavaConfig +from llava.model.language_model.llava_mistral import LlavaMistralConfig +from llava.mm_utils import process_highres_image, process_anyres_image, process_highres_image_crop_split, tokenizer_image_token +from llava.utils import rank0_print +from transformers import AutoConfig +import pickle + +from trl.trainer.utils import DPODataCollatorWithPadding +from PIL import Image, ImageFile +from decord import VideoReader, cpu + +ImageFile.LOAD_TRUNCATED_IMAGES = True +from packaging import version +from typing import Any + +local_rank = None +import numpy as np + +IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse("0.14") + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + model_class_name: Optional[str] = field(default=None, metadata={"help": "Used to init model class, format is XXXXForCausalLM. e.g. currently XXXX is chosen from LlavaLlama, LlavaMixtral, LlavaMistral, Llama"}) + + mm_tunable_parts: Optional[str] = field( + default=None, metadata={"help": 'Could be "mm_mlp_adapter", "mm_vision_resampler", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_vision_tower,mm_mlp_adapter,mm_language_model", "mm_mlp_adapter,mm_language_model"'} + ) + # deciding which part of the multimodal model to tune, will overwrite other previous settings + + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + tune_mm_mlp_adapter: bool = field(default=False) + tune_mm_vision_resampler: bool = field(default=False) + vision_tower: Optional[str] = field(default=None) + vision_tower_pretrained: Optional[str] = field(default=None) # default to the last layer + + unfreeze_mm_vision_tower: bool = field(default=False) + unfreeze_language_model: bool = field(default=False) + mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer + pretrain_mm_mlp_adapter: Optional[str] = field(default=None) + mm_projector_type: Optional[str] = field(default="linear") + mm_use_im_start_end: bool = field(default=False) + mm_use_im_patch_token: bool = field(default=True) + mm_patch_merge_type: Optional[str] = field(default="flat") + mm_vision_select_feature: Optional[str] = field(default="patch") + mm_resampler_type: Optional[str] = field(default=None) + mm_mask_drop_mode: str = field(default="fixed") + mm_mask_drop_skip_percentage: float = field(default=0.0) + mm_mask_drop_ratio: float = field(default=0.25) + mm_mask_drop_ratio_upper: Optional[float] = field(default=None) + mm_mask_drop_ratio_lower: Optional[float] = field(default=None) + mm_spatial_pool_stride: Optional[int] = field(default=None) + mm_spatial_pool_mode: str = field(default="average") + mm_spatial_pool_out_channels: Optional[int] = field(default=None) + mm_perceiver_depth: Optional[int] = field(default=3) + mm_perceiver_latents: Optional[int] = field(default=32) + mm_perceiver_ff_mult: Optional[float] = field(default=4) + mm_perceiver_pretrained: Optional[str] = field(default=None) + mm_qformer_depth: Optional[int] = field(default=3) + mm_qformer_latents: Optional[int] = field(default=32) + mm_qformer_pretrained: Optional[str] = field(default=None) + + rope_scaling_factor: Optional[float] = field(default=None) + rope_scaling_type: Optional[str] = field(default=None) + + s2: Optional[bool] = field(default=False) + s2_scales: Optional[str] = field(default="336,672,1008") + + +@dataclass +class DataArguments: + data_path: str = field(default=None, metadata={"help": "Path to the training data, in llava's instruction.json format. Supporting multiple json files via /path/to/{a,b,c}.json"}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + video_folder: Optional[str] = field(default=None) + video_fps: Optional[int] = field(default=1) + image_aspect_ratio: str = "square" + image_grid_pinpoints: Optional[str] = field(default=None) + image_crop_resolution: int = 384 + image_split_resolution: int = 384 + input_prompt: Optional[str] = field(default=None) + refine_prompt: Optional[bool] = field(default=False) + frames_upbound: Optional[int] = field(default=0) + num_sample: Optional[int] = field(default=None) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + freeze_mm_mlp_adapter: bool = field(default=False) + freeze_mm_vision_resampler: bool = field(default=False) + mpt_attn_impl: Optional[str] = field(default="triton") + model_max_length: int = field( + default=4096, + metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, + ) + double_quant: bool = field(default=True, metadata={"help": "Compress the quantization statistics through double quantization."}) + quant_type: str = field(default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}) + bits: int = field(default=16, metadata={"help": "How many bits to use."}) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + mm_projector_lr: Optional[float] = None + mm_vision_tower_lr: Optional[float] = None + group_by_varlen: bool = field(default=False) + group_by_modality_length: bool = field(default=False) + group_by_modality_length_auto: bool = field(default=False) + auto_find_batch_size: bool = field(default=False) + gradient_checkpointing: bool = field(default=True) + verbose_logging: bool = field(default=False) + attn_implementation: str = field(default="flash_attention_2", metadata={"help": "Use transformers attention implementation."}) + dpo_alpha: float = field(default=1.0) + beta: float = field(default=0.1) + gamma: float = field(default=1.0) + generate_during_eval: bool = field(default=False) + precompute_ref_log_probs: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + if hasattr(trainer.args, "tune_mm_mlp_adapter") and trainer.args.tune_mm_mlp_adapter: + check_only_save_mm_adapter_tunnable = True + # only has mm_mlp_adapter and mm_vision_resampler in the tuneable parts + elif hasattr(trainer.args, "mm_tunable_parts") and (len(trainer.args.mm_tunable_parts.split(",")) == 1 and ("mm_mlp_adapter" in trainer.args.mm_tunable_parts or "mm_vision_resampler" in trainer.args.mm_tunable_parts)): + check_only_save_mm_adapter_tunnable = True + else: + check_only_save_mm_adapter_tunnable = False + + trainer.accelerator.wait_for_everyone() + torch.cuda.synchronize() + rank0_print(f"Only save projectors: {check_only_save_mm_adapter_tunnable}") + if check_only_save_mm_adapter_tunnable: + # Only save Adapter + keys_to_match = ["mm_projector", "vision_resampler"] + if getattr(trainer.args, "use_im_start_end", False): + keys_to_match.extend(["embed_tokens", "embed_in"]) + + weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) + trainer.model.config.save_pretrained(output_dir) + + current_folder = output_dir.split("/")[-1] + parent_folder = os.path.dirname(output_dir) + if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: + if current_folder.startswith("checkpoint-"): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f"{current_folder}.bin")) + else: + torch.save(weight_to_save, os.path.join(output_dir, f"mm_projector.bin")) + return + + if trainer.deepspeed: + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = "unknown" + sentence["value"] = BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence["value"] and not sentence["value"].startswith(DEFAULT_IMAGE_TOKEN): + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() + sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"] + sentence["value"] = sentence["value"].strip() + if "mmtag" in conversation_lib.default_conversation.version: + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "" + DEFAULT_IMAGE_TOKEN + "") + replace_token = DEFAULT_IMAGE_TOKEN + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_multimodal_movie(sources: Sequence[str], data_args: DataArguments, video_inputs: str) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence["value"]: + prompt = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip() + replace_token = video_inputs + if data_args.mm_use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources, prompt + + +def preprocess_llama_2(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 + + # Mask targets + sep = "[/INST] " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + rank0_print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def make_conv(prompt, answer): + return [ + { + "from": "human", + "value": prompt, + }, + { + "from": "gpt", + "value": answer, + }, + ] + + +def preprocess_gemma(sources: List[List[Dict[str, str]]], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + conv: conversation_lib.Conversation = conversation_lib.default_conversation.copy() + roles: Dict[str, str] = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations: List[str] = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source: List[Dict[str, str]] = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role: str = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + if has_image: + input_ids: torch.Tensor = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0) + else: + input_ids: torch.Tensor = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets: torch.Tensor = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA + + # Mask target + sep: str = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len: int = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds: List[str] = conversation.split(conv.sep) + re_rounds = [] + for conv_idx in range(0, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) + + cur_len = 1 # Ignore + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep # Re-append sep because split on this + # Now "".join(parts)==rou + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 # Ignore + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # Ignore + else: + round_len = len(tokenizer(rou).input_ids) - 1 # Ignore + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # Ignore + + round_len += 2 # sep: \n takes 2 tokens + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + cur_len += round_len + + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + rank0_print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict: + roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"} + + im_start, im_end = tokenizer.additional_special_tokens_ids + nl_tokens = tokenizer("\n").input_ids + _system = tokenizer("system").input_ids + nl_tokens + _user = tokenizer("user").input_ids + nl_tokens + _assistant = tokenizer("assistant").input_ids + nl_tokens + + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["human"]: + source = source[1:] + + input_id, target = [], [] + system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens + input_id += system + target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens + assert len(input_id) == len(target) + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + if has_image and "" in sentence["value"]: + assert sentence["value"].startswith(""), print(sentence["value"]) + + _input_id = tokenizer(role).input_ids + nl_tokens + [IMAGE_TOKEN_INDEX] + nl_tokens + tokenizer(sentence["value"][len("") :]).input_ids + [im_end] + nl_tokens + else: + _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens + input_id += _input_id + if role == "<|im_start|>user": + _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens + elif role == "<|im_start|>assistant": + _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens + else: + raise NotImplementedError + target += _target + assert len(input_id) == len(target) + # input_id += [tokenizer.pad_token_id] * (max_len - len(input_id)) + # target += [IGNORE_INDEX] * (max_len - len(target)) + input_ids.append(input_id) + targets.append(target) + input_ids = torch.tensor(input_ids, dtype=torch.long) + targets = torch.tensor(targets, dtype=torch.long) + + return dict( + input_ids=input_ids, # tensor(bs x seq_len) + labels=targets, # tensor(bs x seq_len) + # attention_mask=input_ids.ne(tokenizer.pad_token_id), # tensor(bs x seq_len) + ) + + +def preprocess_llama3( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False, + max_len=2048, + system_message: str = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", +) -> Dict: + roles = {"human": "<|start_header_id|>user<|end_header_id|>", "gpt": "<|start_header_id|>assistant<|end_header_id|>"} + + eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>") + nl_tokens = tokenizer("\n").input_ids + + # Apply prompt templates + input_ids, targets = [], [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != roles["human"]: + source = source[1:] + + input_id, target = [], [] + system = tokenizer("<|begin_of_text|>").input_ids + tokenizer("<|start_header_id|>system<|end_header_id|>").input_ids + nl_tokens * 2 + tokenizer(system_message).input_ids + [eot_id] + input_id += system + target += [IGNORE_INDEX] * len(system) + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + if has_image and "" in sentence["value"]: + assert sentence["value"].startswith(""), print(sentence["value"]) + _input_id = tokenizer(role).input_ids + nl_tokens * 2 + [IMAGE_TOKEN_INDEX] + tokenizer(sentence["value"][len("") :]).input_ids + [eot_id] + else: + _input_id = tokenizer(role).input_ids + nl_tokens * 2 + tokenizer(sentence["value"]).input_ids + [eot_id] + input_id += _input_id + if role == "<|start_header_id|>user<|end_header_id|>": + _target = [IGNORE_INDEX] * len(_input_id) + elif role == "<|start_header_id|>assistant<|end_header_id|>": + _target = [IGNORE_INDEX] * (len(tokenizer(role).input_ids) + 2) + _input_id[len(tokenizer(role).input_ids) + 2 : -1] + [eot_id] + else: + raise NotImplementedError + target += _target + assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" + input_ids.append(input_id) + targets.append(target) + input_ids = torch.tensor(input_ids, dtype=torch.long) + targets = torch.tensor(targets, dtype=torch.long) + + return dict( + input_ids=input_ids, # tensor(bs x seq_len) + labels=targets, # tensor(bs x seq_len) + ) + + +def preprocess_v1(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: + round_len -= 1 + instruction_len -= 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_mpt(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" + conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + assert conv.sep_style == conversation_lib.SeparatorStyle.MPT + + # Mask targets + sep = conv.sep + conv.roles[1] + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep) + re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt + for conv_idx in range(3, len(rounds), 2): + re_rounds.append(conv.sep.join(rounds[conv_idx : conv_idx + 2])) # user + gpt + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(re_rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + + if i != 0 and getattr(tokenizer, "legacy", False) and IS_TOKENIZER_GREATER_THAN_0_14: + round_len += 1 + instruction_len += 1 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f"(#turns={len(re_rounds)} ignored)") + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_IMAGE_TOKEN in source[0]["value"] + source[0]["value"] = DEFAULT_IMAGE_TOKEN + conversation = source[0]["value"] + source[1]["value"] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess(sources: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. + """ + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: + return preprocess_llama_2(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "mpt": + return preprocess_mpt(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "qwen": + return preprocess_qwen(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "gemma": + return preprocess_gemma(sources, tokenizer, has_image=has_image) + if conversation_lib.default_conversation.version == "llama_v3": + return preprocess_llama3(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors="pt") for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + + +def load_data(data_path): + if "jsonl" in data_path: + data_list = load_jsonl(data_path) + else: + data_list = load_json(data_path) + return data_list + + +class DPODataset(Dataset): + """Dataset for DPODataset fine-tuning.""" + + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, data_args: DataArguments): + super(DPODataset, self).__init__() + # Handle multiple JSON files specified in the data_path + self.list_data_dict = [] + + if "{" in data_path and "}" in data_path: + base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups() + file_names = file_pattern.split(",") + rank0_print(f"Loading {file_names} from {base_path}") + data_args.dataset_paths = [] + for file_name in file_names: + data_args.dataset_paths.append(f"{base_path}{file_name}.json") + full_path = f"{base_path}{file_name}.json" + rank0_print(f"Loading {full_path}") + cur_data_dict = load_data(full_path) + rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}") + self.list_data_dict.extend(cur_data_dict) + elif data_path.endswith(".yaml"): + with open(data_path, "r") as file: + yaml_data = yaml.safe_load(file) + datasets = yaml_data.get("datasets") + # file should be in the format of: + # datasets: + # - json_path: xxxx1.json + # sampling_strategy: first:1000 + # - json_path: xxxx2.json + # sampling_strategy: end:3000 + # - json_path: xxxx3.json + # sampling_strategy: random:999 + data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets] + for dataset in datasets: + json_path = dataset.get("json_path") + sampling_strategy = dataset.get("sampling_strategy", "all") + sampling_number = None + + rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy") + cur_data_dict = load_data(json_path) + + if ":" in sampling_strategy: + sampling_strategy, sampling_number = sampling_strategy.split(":") + if "%" in sampling_number: + sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100) + else: + sampling_number = int(sampling_number) + + # Apply the sampling strategy + if sampling_strategy == "first" and sampling_number is not None: + cur_data_dict = cur_data_dict[:sampling_number] + elif sampling_strategy == "end" and sampling_number is not None: + cur_data_dict = cur_data_dict[-sampling_number:] + elif sampling_strategy == "random" and sampling_number is not None: + random.shuffle(cur_data_dict) + cur_data_dict = cur_data_dict[:sampling_number] + + rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}") + self.list_data_dict.extend(cur_data_dict) + else: + data_args.dataset_paths = [data_path] + rank0_print(f"Loading {data_path}") + cur_data_dict = load_data(data_path) + rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}") + self.list_data_dict.extend(cur_data_dict) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.data_args = data_args + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + # Calculate the length of the prompt, answer, chosen, and rejected text + cur_len = len(sample["prompt"].split()) + len(sample["answer"].split()) + len(sample["chosen"].split()) + len(sample["rejected"].split()) + # Add additional tokens if an image is present + img_tokens = 128 if "image" in sample else 0 + length_list.append(cur_len + img_tokens) + return length_list + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + # Calculate the length of the prompt, answer, chosen, and rejected text + cur_len = len(sample["prompt"].split()) + len(sample["answer"].split()) + len(sample["chosen"].split()) + len(sample["rejected"].split()) + # If the sample includes a video, the length is positive; otherwise, it is negative + cur_len = cur_len if ("video" in sample or "image" in sample) else -cur_len + length_list.append(cur_len) + return length_list + + def process_image(self, image_file): + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + # print(f"\n\nInspecting the image path, folder = {image_folder}, image={image_file}\n\n") + try: + image = Image.open(os.path.join(image_folder, image_file)).convert("RGB") + except Exception as exn: + print(f"Failed to open image {image_file}. Exception:", exn) + raise exn + + image_size = image.size + if self.data_args.image_aspect_ratio == "highres": + image = process_highres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints) + elif self.data_args.image_aspect_ratio == "anyres" or "anyres" in self.data_args.image_aspect_ratio: + image = process_anyres_image(image, self.data_args.image_processor, self.data_args.image_grid_pinpoints) + elif self.data_args.image_aspect_ratio == "crop_split": + image = process_highres_image_crop_split(image, self.data_args) + elif self.data_args.image_aspect_ratio == "pad": + + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + else: + image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0] + return image, image_size, "image" + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + # TODO: define number of retries somewhere else + num_base_retries = 3 + num_final_retries = 300 + + # try the current sample first + for attempt_idx in range(num_base_retries): + try: + sample = self._get_item(i) + return sample + except Exception as e: + # sleep 1s in case it is a cloud disk issue + print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e) + time.sleep(1) + + # try other samples, in case it is file corruption issue + for attempt_idx in range(num_base_retries): + try: + next_index = min(i + 1, len(self.list_data_dict) - 1) + # sample_idx = random.choice(range(len(self))) + sample = self._get_item(next_index) + return sample + except Exception as e: + # no need to sleep + print(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:", e) + pass + + # still fail, most likely to be path issue or cloud disk issue, retry the same sample for longer + # for attempt_idx in range(num_final_retries): + # try: + # sample = self._get_item(i) + # return sample + # except Exception as e: + # # sleep 1s in case it is a cloud disk issue + # print(f"[Final try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e) + # time.sleep(1) + + # Finally raise exception on failing. + assert False, "Failed to fetch sample." + + def _get_item(self, i) -> Dict[str, torch.Tensor]: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + + suffix = None + if "image" in sources[0]: + image_file = self.list_data_dict[i]["image"] + if type(image_file) is list: + image = [self.process_image(f) for f in image_file] + else: + image = [self.process_image(image_file)] + # sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args) + + elif "video" in sources[0]: # FIXME: This logic should be largely improved by Yuanhan. It's too messy now. + video_file = self.list_data_dict[i]["video"] + video_folder = self.data_args.video_folder + video_file = os.path.join(video_folder, video_file) + suffix = video_file.split(".")[-1] + if not os.path.exists(video_file): + print("File {} not exist!".format(video_file)) + + if suffix == "pkl": + video_info = pickle.load(open(video_file, "rb")) + image = torch.from_numpy(video_info["feats"][:, 1:]) + input_prompt = video_info["inputs"].replace("...", "") + # replace the default image token with multiple tokens + input_prompt = input_prompt.replace(DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_TOKEN * self.data_args.video_token) + sources, query_prompt = preprocess_multimodal_movie(copy.deepcopy([e["conversations"] for e in sources]), self.data_args, input_prompt) + else: # using videoreader + if "shareVideoGPTV" not in video_file and "liangke" not in video_file: + vr = VideoReader(video_file, ctx=cpu(0)) + total_frame_num = len(vr) + avg_fps = round(vr.get_avg_fps() / self.data_args.video_fps) + frame_idx = [i for i in range(0, total_frame_num, avg_fps)] + if self.data_args.frames_upbound > 0: + if len(frame_idx) > self.data_args.frames_upbound: + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, self.data_args.frames_upbound, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + video = vr.get_batch(frame_idx).asnumpy() + video = np.array(video) + else: + if "liangke" in video_file: + video_file = self.list_data_dict[i]["video"] + frame_files = [os.path.join(video_file, f) for f in os.listdir(video_file) if os.path.isfile(os.path.join(video_file, f))] + frame_files.sort() # Ensure the frames are sorted if they are named sequentially + + # TODO: Hard CODE: Determine the indices for uniformly sampling 10 frames + num_frames_to_sample = 10 + + total_frames = len(frame_files) + + sampled_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int) + + # Read and store the sampled frames + video = [] + for idx in sampled_indices: + frame_path = frame_files[idx] + try: + with Image.open(frame_path) as img: + frame = img.convert("RGB") + video.append(frame) + except IOError: + print(f"Failed to read frame at path: {frame_path}") + + processor = self.data_args.image_processor + image = processor.preprocess(video, return_tensors="pt")["pixel_values"] + image = [(image, video[0].size, "video")] + # sources = preprocess_multimodal(copy.deepcopy([e["conversations"] for e in sources]), self.data_args) + + else: + sources = copy.deepcopy([e["conversations"] for e in sources]) + + has_image = ("image" in self.list_data_dict[i]) or ("video" in self.list_data_dict[i]) + # data_dict = preprocess(sources, self.tokenizer, has_image=has_image) + data_dict = copy.deepcopy(self.list_data_dict[i]) # inplace modification following + + if "prompt" in data_dict: + prompt = data_dict["prompt"] + prompt = prompt.replace("", "").strip() + prompt = "\n" + prompt + data_dict["prompt"] = prompt + else: + prompt = None + + if suffix == "pkl": + prompt = [query_prompt] + + # image exist in the data + if "image" in self.list_data_dict[i]: + data_dict["image"] = image + elif "video" in self.list_data_dict[i]: + data_dict["image"] = image + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + crop_size = self.data_args.image_processor.crop_size + data_dict["image"] = [ + (torch.zeros(1, 3, crop_size["height"], crop_size["width"]), (crop_size["width"], crop_size["height"]), "text"), + ] + # prompt exist in the data + data_dict["has_image"] = has_image + return data_dict + + +@dataclass +class DPODataCollator(DPODataCollatorWithPadding): + """Collate examples for DPO fine-tuning.""" + + # tokenizer: transformers.PreTrainedTokenizer + + def collate(self, batch): + # first, pad everything to the same length + # input_ids, labels = tuple([instance[key] for instance in instances] + # for key in ("input_ids", "labels")) + # input_ids = torch.nn.utils.rnn.pad_sequence( + # input_ids, + # batch_first=True, + # padding_value=self.tokenizer.pad_token_id) + # labels = torch.nn.utils.rnn.pad_sequence(labels, + # batch_first=True, + # padding_value=IGNORE_INDEX) + # input_ids = input_ids[:, :self.tokenizer.model_max_length] + # labels = labels[:, :self.tokenizer.model_max_length] + # batch = dict( + # input_ids=input_ids, + # labels=labels, + # attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + # ) + padded_batch = {} + for k in batch[0].keys(): + if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + # if "prompt" in k: + # to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] + # else: + to_pad = [torch.LongTensor(ex[k]) for ex in batch] + if k.endswith("_input_ids"): + padding_value = self.tokenizer.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + else: + continue + # elif k.endswith("_attention_mask"): + # padding_value = self.padding_value + # else: + # raise ValueError(f"Unexpected key in batch '{k}'") + + padded_batch[k] = torch.nn.utils.rnn.pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + # for the prompt, flip back so padding is on left side + # if "prompt" in k: + # padded_batch[k] = padded_batch[k].flip(dims=[1]) + else: + padded_batch[k] = [ex[k] for ex in batch] + for k in ["chosen_input_ids", "rejected_input_ids"]: + attn_k = k.replace("input_ids", "attention_mask") + padded_batch[attn_k] = padded_batch[k].ne(self.tokenizer.pad_token_id) + return padded_batch + + def tokenize_batch_element(self, prompt: str, chosen: str, rejected: str, has_image: bool = True) -> Dict: + """Tokenize a single batch element. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + # import pdb; pdb.set_trace() + batch = {} + + chosen_sources = make_conv(prompt, chosen) + rejected_sources = make_conv(prompt, rejected) + chosen_data_dict = preprocess([chosen_sources], self.tokenizer, has_image=has_image) + # chosen_data_dict['attention_mask'] = chosen_data_dict["input_ids"].ne(self.tokenizer.pad_token_id) + + rejected_data_dict = preprocess([rejected_sources], self.tokenizer, has_image=has_image) + # rejected_data_dict['attention_mask'] = rejected_data_dict["input_ids"].ne(self.tokenizer.pad_token_id) + + chosen_data_dict = {k: v[0] for k, v in chosen_data_dict.items()} + rejected_data_dict = {k: v[0] for k, v in rejected_data_dict.items()} + + for k, toks in { + "chosen": chosen_data_dict, + "rejected": rejected_data_dict, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}_{type_key}"] = tokens + return batch + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + + tokenized_batch = [] + Xs, keys = [], [] + for feature in features: + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + has_image = feature["has_image"] + # Xs.append(feature[has_X]) + # keys.append(has_X) + + batch_element = self.tokenize_batch_element(prompt, chosen, rejected, has_image=has_image) + tokenized_batch.append(batch_element) + + # return collated batch + padded_batch = self.collate(tokenized_batch) + # import pdb;pdb.set_trace() + if "image" in features[0]: + # instances[1]['image'][0][0].shape + # torch.Size([5, 3, 224, 224]) + images = [instance["image"] for instance in features] + + padded_batch["image_sizes"] = [im[1] for im_list in images for im in im_list] + padded_batch["modalities"] = [im[2] for im_list in images for im in im_list] + images = [im[0] for im_list in images for im in im_list] + # import pdb;pdb.set_trace() + + padded_batch["images"] = images + # padded_batch["images"] =[padded_batch["modalities"], images] + + return padded_batch + + +def make_dpo_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = DPODataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args) + return train_dataset + + +def get_model(model_args, training_args, bnb_model_from_pretrained_args): + assert training_args.attn_implementation + if training_args.attn_implementation == "sdpa" and torch.__version__ < "2.1.2": + raise ValueError("The 'sdpa' attention implementation requires torch version 2.1.2 or higher.") + + ######################### Overwrite config ######################### + customized_kwargs = dict() + customized_kwargs.update(bnb_model_from_pretrained_args) + overwrite_config = {} + cfg_pretrained = None + if "qwen" in model_args.model_name_or_path.lower(): + cfg_pretrained = LlavaQwenConfig.from_pretrained(model_args.model_name_or_path) + elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower(): + cfg_pretrained = LlavaMistralConfig.from_pretrained(model_args.model_name_or_path) + elif ( + "wizardlm-2" in model_args.model_name_or_path.lower() + or "vicuna" in model_args.model_name_or_path.lower() + or "llama" in model_args.model_name_or_path.lower() + or "yi" in model_args.model_name_or_path.lower() + or "nous-hermes" in model_args.model_name_or_path.lower() + and "wizard-2" in model_args.model_name_or_path.lower() + ): + cfg_pretrained = LlavaConfig.from_pretrained(model_args.model_name_or_path) + else: + cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path) + + if model_args.rope_scaling_factor is not None and model_args.rope_scaling_type is not None and cfg_pretrained is not None: + overwrite_config["rope_scaling"] = { + "factor": model_args.rope_scaling_factor, + "type": model_args.rope_scaling_type, + } + if training_args.model_max_length is None: + training_args.model_max_length = cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor + overwrite_config["max_sequence_length"] = training_args.model_max_length + assert training_args.model_max_length == int(cfg_pretrained.max_position_embeddings * model_args.rope_scaling_factor), print( + f"model_max_length: {training_args.model_max_length}, max_position_embeddings: {cfg_pretrained.max_position_embeddings}, rope_scaling_factor: {model_args.rope_scaling_factor}" + ) + # overwrite_config["max_sequence_length"] = model_args.max_sequence_length + # overwrite_config["tokenizer_model_max_length"] = model_args.tokenizer_model_max_length + + if model_args.mm_spatial_pool_stride is not None and model_args.mm_spatial_pool_out_channels is not None and model_args.mm_spatial_pool_mode is not None and model_args.mm_resampler_type is not None and cfg_pretrained is not None: + overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type + overwrite_config["mm_spatial_pool_stride"] = model_args.mm_spatial_pool_stride + overwrite_config["mm_spatial_pool_out_channels"] = model_args.mm_spatial_pool_out_channels + overwrite_config["mm_spatial_pool_mode"] = model_args.mm_spatial_pool_mode + + if overwrite_config: + rank0_print(f"Overwriting config with {overwrite_config}") + for k, v in overwrite_config.items(): + setattr(cfg_pretrained, k, v) + + customized_kwargs["config"] = cfg_pretrained + + ######################### Finish Overwrite ########################### + + ref_model = None + if model_args.model_class_name is not None: + actual_model_class_name = f"{model_args.model_class_name}ForCausalLM" + model_class = getattr(transformers, actual_model_class_name) + rank0_print(f"Using model class {model_class} from {model_args.model_class_name}") + model = model_class.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + elif model_args.vision_tower is not None: + if "mixtral" in model_args.model_name_or_path.lower(): + model = LlavaMixtralForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + elif "mistral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower(): + model = LlavaMistralForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + elif ( + "wizardlm-2" in model_args.model_name_or_path.lower() + or "vicuna" in model_args.model_name_or_path.lower() + or "llama" in model_args.model_name_or_path.lower() + or "yi" in model_args.model_name_or_path.lower() + or "nous-hermes" in model_args.model_name_or_path.lower() + and "wizard-2" in model_args.model_name_or_path.lower() + ): + model = LlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + + if "zero3" in training_args.deepspeed: + rank0_print("#### Initialize reference model #####") + ref_model = LlavaLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + + elif "qwen" in model_args.model_name_or_path.lower() or "quyen" in model_args.model_name_or_path.lower(): + if "moe" in model_args.model_name_or_path.lower(): + model = LlavaQwenMoeForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock + + deepspeed.utils.set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) + else: + model = LlavaQwenForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + + if "zero3" in training_args.deepspeed: + rank0_print("#### Initialize reference model #####") + ref_model = LlavaQwenForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + + elif "gemma" in model_args.model_name_or_path.lower(): + model = LlavaGemmaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + attn_implementation=training_args.attn_implementation, + torch_dtype=(torch.bfloat16 if training_args.bf16 else None), + low_cpu_mem_usage=False, + **customized_kwargs, + ) + else: + raise ValueError(f"Unknown model class {model_args}") + else: + model = transformers.LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, cache_dir=training_args.cache_dir, attn_implementation=training_args.attn_implementation, torch_dtype=(torch.bfloat16 if training_args.bf16 else None), **customized_kwargs + ) + return model, ref_model + + +def train(attn_implementation=None): + global local_rank + + parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.verbose_logging: + rank0_print(f"Inspecting experiment hyperparameters:\n") + rank0_print(f"model_args = {vars(model_args)}\n\n") + rank0_print(f"data_args = {vars(data_args)}\n\n") + rank0_print(f"training_args = {vars(training_args)}\n\n") + # rank0_print(f"evaluation_args = {vars(evaluation_args)}\n\n") + + local_rank = training_args.local_rank + compute_dtype = torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + + bnb_model_from_pretrained_args.update( + dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type, # {'fp4', 'nf4'} + ), + ) + ) + + model, ref_model = get_model(model_args, training_args, bnb_model_from_pretrained_args) + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + + model.config.torch_dtype = torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if ref_model is not None: + ref_model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if ref_model is not None: + ref_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + if "mpt" in model_args.model_name_or_path: + tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right") + elif "mistral" in model_args.model_name_or_path.lower() or "mixtral" in model_args.model_name_or_path.lower() or "zephyr" in model_args.model_name_or_path.lower(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="left") + elif "qwen" in model_args.model_name_or_path.lower(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=training_args.cache_dir, model_max_length=training_args.model_max_length, padding_side="right") + else: # for all other models + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + rank0_print(f"Prompt version: {model_args.version}") + if model_args.version == "v0": + if tokenizer.pad_token is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token="[PAD]"), + tokenizer=tokenizer, + model=model, + ) + elif model_args.version == "v0.5": + tokenizer.pad_token = tokenizer.unk_token + else: + if tokenizer.unk_token is not None: + tokenizer.pad_token = tokenizer.unk_token + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] + + if model_args.vision_tower is not None: + model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp) + + vision_tower = model.get_vision_tower() + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = vision_tower.image_processor + data_args.is_multimodal = True + + model.config.image_aspect_ratio = data_args.image_aspect_ratio + if data_args.image_grid_pinpoints is not None: + # for input like "(1x1)...(3x3)", convert to [(1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (3, 2), (1, 3), (2, 3), (3, 3)] + if "x" in data_args.image_grid_pinpoints and "..." in data_args.image_grid_pinpoints: + vis_encoder_size = data_args.image_processor.size[0] + matches = re.findall(r"\((\d+)x(\d+)\)", data_args.image_grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)] + grid_pinpoints = [[dim * vis_encoder_size for dim in pair] for pair in grid_pinpoints] + data_args.image_grid_pinpoints = grid_pinpoints + elif "x" in data_args.image_grid_pinpoints: + vis_encoder_size = data_args.image_processor.size[0] + assert vis_encoder_size in [224, 336, 384, 448, 512], "vis_encoder_size should be in [224, 336, 384, 448, 512]" + grid_pinpoints = data_args.image_grid_pinpoints.replace(" ", "").replace("x", ",")[1:-1].split("),(") + data_args.image_grid_pinpoints = [[int(x) * vis_encoder_size for x in item.split(",")] for item in grid_pinpoints] + else: + data_args.image_grid_pinpoints = ast.literal_eval(data_args.image_grid_pinpoints) # for backward compatibility + model.config.image_grid_pinpoints = data_args.image_grid_pinpoints + model.config.image_crop_resolution = data_args.image_crop_resolution + model.config.image_split_resolution = data_args.image_split_resolution + model.config.tokenizer_padding_side = tokenizer.padding_side + model.config.tokenizer_model_max_length = tokenizer.model_max_length + + ### Deciding train which part of the model + if model_args.mm_tunable_parts is None: # traditional way of deciding which part to train + model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + model.config.tune_mm_vision_resampler = training_args.tune_mm_vision_resampler = model_args.tune_mm_vision_resampler + if model_args.tune_mm_mlp_adapter or model_args.tune_mm_vision_resampler: + model.requires_grad_(False) + if model_args.tune_mm_mlp_adapter: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + if model_args.tune_mm_vision_resampler: + for p in model.get_model().vision_resampler.parameters(): + p.requires_grad = True + + model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter + if training_args.freeze_mm_mlp_adapter: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = False + + model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler + if training_args.freeze_mm_vision_resampler: + for p in model.get_model().vision_resampler.parameters(): + p.requires_grad = False + + model.config.unfreeze_mm_vision_tower = model_args.unfreeze_mm_vision_tower + if model_args.unfreeze_mm_vision_tower: + vision_tower.requires_grad_(True) + else: + vision_tower.requires_grad_(False) + + else: + rank0_print(f"Using mm_tunable_parts: {model_args.mm_tunable_parts}") + model.config.mm_tunable_parts = training_args.mm_tunable_parts = model_args.mm_tunable_parts + # Set the entire model to not require gradients by default + model.requires_grad_(False) + vision_tower.requires_grad_(False) + model.get_model().mm_projector.requires_grad_(False) + model.get_model().vision_resampler.requires_grad_(False) + # Parse the mm_tunable_parts to decide which parts to unfreeze + tunable_parts = model_args.mm_tunable_parts.split(",") + if "mm_mlp_adapter" in tunable_parts: + for p in model.get_model().mm_projector.parameters(): + p.requires_grad = True + if "mm_vision_resampler" in tunable_parts: + for p in model.get_model().vision_resampler.parameters(): + p.requires_grad = True + if "mm_vision_tower" in tunable_parts: + for name, param in model.named_parameters(): + if "vision_tower" in name: + param.requires_grad_(True) + if "mm_language_model" in tunable_parts: + for name, param in model.named_parameters(): + if "vision_tower" not in name and "mm_projector" not in name and "vision_resampler" not in name: + param.requires_grad_(True) + + total_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters()) + trainable_params = sum(p.ds_numel if hasattr(p, "ds_numel") else p.numel() for p in model.parameters() if p.requires_grad) + rank0_print(f"Total parameters: ~{total_params/1e6:.2f} MB)") + rank0_print(f"Trainable parameters: ~{trainable_params/1e6:.2f} MB)") + if training_args.bits in [4, 8]: + model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) + + model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_projector_lr = training_args.mm_projector_lr + model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr + training_args.use_im_start_end = model_args.mm_use_im_start_end + model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) + + if ref_model is not None: + ref_model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp) + ref_vision_tower = ref_model.get_vision_tower() + ref_vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + ref_model.config.image_aspect_ratio = data_args.image_aspect_ratio + ref_model.config.image_grid_pinpoints = data_args.image_grid_pinpoints + ref_model.config.image_crop_resolution = data_args.image_crop_resolution + ref_model.config.image_split_resolution = data_args.image_split_resolution + ref_model.config.tokenizer_padding_side = tokenizer.padding_side + ref_model.config.tokenizer_model_max_length = tokenizer.model_max_length + ref_model.config.mm_use_im_start_end = data_args.mm_use_im_start_end + ref_model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token + ref_model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer) + parameter_names = [n for n, _ in ref_model.named_parameters()] + for param_name in parameter_names: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + ref_model.eval() + + if training_args.bits in [4, 8]: + from peft.tuners.lora import LoraLayer + + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if "norm" in name: + module = module.to(torch.float32) + if "lm_head" in name or "embed_tokens" in name: + if hasattr(module, "weight"): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + train_dataset = make_dpo_data_module(tokenizer=tokenizer, data_args=data_args) + data_collator = DPODataCollator( + tokenizer, + label_pad_token_id=IGNORE_INDEX, + pad_token_id=tokenizer.pad_token_id, + ) + + trainer = LLaVADPOTrainer( + model, + ref_model, + args=training_args, + dpo_alpha=training_args.dpo_alpha, + beta=training_args.beta, + gamma=training_args.gamma, + train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator, + tokenizer=tokenizer, + max_length=training_args.model_max_length, + generate_during_eval=False, # training_args.generate_during_eval, + precompute_ref_log_probs=training_args.precompute_ref_log_probs, + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters()) + if training_args.local_rank == 0 or training_args.local_rank == -1: + if hasattr(model, "config"): + model.config.save_pretrained(training_args.output_dir) + if hasattr(model, "generation_config"): + model.generation_config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, "non_lora_trainables.bin")) + else: + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + rank0_print(f"Model saved to {training_args.output_dir}") + + +if __name__ == "__main__": + train() diff --git a/llava/train/train_mem.py b/llava/train/train_mem.py new file mode 100755 index 000000000..6135ca413 --- /dev/null +++ b/llava/train/train_mem.py @@ -0,0 +1,4 @@ +from llava.train.train import train + +if __name__ == "__main__": + train() diff --git a/llava/utils.py b/llava/utils.py old mode 100644 new mode 100755 index 2bf3bd19d..1e004c33b --- a/llava/utils.py +++ b/llava/utils.py @@ -3,24 +3,86 @@ import logging.handlers import os import sys +import numpy as np import requests from llava.constants import LOGDIR server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" -moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." +moderation_msg = "I am sorry. Your input may violate our content moderation guidelines. Please avoid using harmful or offensive content." handler = None import torch.distributed as dist +try: + import av + from decord import VideoReader, cpu +except ImportError: + print("Please install pyav to use video processing functions.") + +def process_video_with_decord(video_file, data_args): + vr = VideoReader(video_file, ctx=cpu(0), num_threads=1) + total_frame_num = len(vr) + video_time = total_frame_num / vr.get_avg_fps() + avg_fps = round(vr.get_avg_fps() / data_args.video_fps) + frame_idx = [i for i in range(0, total_frame_num, avg_fps)] + frame_time = [i/avg_fps for i in frame_idx] + + + if data_args.frames_upbound > 0: + if len(frame_idx) > data_args.frames_upbound or data_args.force_sample: + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + frame_time = [i/vr.get_avg_fps() for i in frame_idx] + + video = vr.get_batch(frame_idx).asnumpy() + frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) + + num_frames_to_sample = num_frames = len(frame_idx) + # https://github.com/dmlc/decord/issues/208 + vr.seek(0) + return video, video_time, frame_time, num_frames_to_sample + +def process_video_with_pyav(video_file, data_args): + container = av.open(video_file) + # !!! This is the only difference. Using auto threading + container.streams.video[0].thread_type = "AUTO" + + video_frames = [] + for packet in container.demux(): + if packet.stream.type == 'video': + for frame in packet.decode(): + video_frames.append(frame) + total_frame_num = len(video_frames) + video_time = video_frames[-1].time + avg_fps = round(total_frame_num / video_time / data_args.video_fps) + frame_idx = [i for i in range(0, total_frame_num, avg_fps)] + + if data_args.frames_upbound > 0: + if len(frame_idx) > data_args.frames_upbound: + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, data_args.frames_upbound, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + + + frames = [video_frames[i] for i in frame_idx] + return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + def rank0_print(*args): if dist.is_initialized(): if dist.get_rank() == 0: print(f"Rank {dist.get_rank()}: ", *args) + else: + print(*args) + +def rank_print(*args): + if dist.is_initialized(): + print(f"Rank {dist.get_rank()}: ", *args) + else: + print(*args) def build_logger(logger_name, logger_filename): global handler @@ -121,8 +183,10 @@ def violates_moderation(text): ret = requests.post(url, headers=headers, data=data, timeout=5) flagged = ret.json()["results"][0]["flagged"] except requests.exceptions.RequestException as e: + print(f"######################### Moderation Error: {e} #########################") flagged = False except KeyError as e: + print(f"######################### Moderation Error: {e} #########################") flagged = False return flagged diff --git a/playground/2d_hist.py b/playground/2d_hist.py new file mode 100644 index 000000000..3592348e2 --- /dev/null +++ b/playground/2d_hist.py @@ -0,0 +1,132 @@ +import json +import os +from PIL import Image +from tqdm import tqdm +import matplotlib.pyplot as plt +import numpy as np +from multiprocessing import Pool +import functools +import argparse + + +def load_data(json_path): + with open(json_path, "r") as f: + return json.load(f) + + +def filter_data(data): + filtered_data = [item for item in data if "image" in item] + return filtered_data + + +def calculate_image_dimension(image_path, images_folder): + full_path = os.path.join(images_folder, image_path) + try: + with Image.open(full_path) as img: + width, height = img.size + return width, height + except Exception as e: + print(f"Error opening {full_path}: {e}") + return None, None + + +def calculate_image_dimensions_multiprocess(filtered_data, images_folder, num_processes=256): + image_paths = [] + for item in filtered_data: + if isinstance(item["image"], list): + image_paths.extend(item["image"]) + else: + image_paths.append(item["image"]) + + with Pool(num_processes) as p: + dimensions = list( + tqdm( + p.imap(functools.partial(calculate_image_dimension, images_folder=images_folder), image_paths), + total=len(image_paths), + desc="Calculating image dimensions", + ) + ) + widths, heights = zip(*[dim for dim in dimensions if dim[0] is not None]) + return list(widths), list(heights) + + +def tokenize(text): + return text.split() + + +def calculate_tokenized_lengths(data): + lengths = [] + for item in tqdm(data, desc="Tokenizing conversations"): + for conversation in item["conversations"]: + tokenized_value = tokenize(conversation["value"]) + lengths.append(len(tokenized_value)) + return lengths + + +def main(): + parser = argparse.ArgumentParser(description="Process data for LLaVA_Next project.") + parser.add_argument( + "--json_path", + type=str, + help="Path to the JSON file containing data.", + default="/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_DEMON-FULL.json", + ) + parser.add_argument( + "--images_folder", + type=str, + default="/mnt/bn/vl-research/data/llava_data", + help="Path to the folder containing images.", + ) + args = parser.parse_args() + + llava_instruct_name = os.path.basename(args.json_path).replace(".json", "") + images_folder = args.images_folder + + data = load_data(args.json_path) + filtered_data = filter_data(data) + + print(f"Total data items: {len(data)}, Filtered data items: {len(filtered_data)}") + widths, heights = calculate_image_dimensions_multiprocess(filtered_data, images_folder) + max_width, max_height = max(widths), max(heights) + print(f"Max width: {max_width}, Max height: {max_height}") + + tokenized_lengths = calculate_tokenized_lengths(filtered_data) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12)) + + # Plot 2D histogram + widths_bins = [min(widths), max(widths) + 1] if min(widths) == max(widths) else np.arange(min(widths), max(widths) + 100, 100) + heights_bins = [min(heights), max(heights) + 1] if min(heights) == max(heights) else np.arange(min(heights), max(heights) + 100, 100) + + h, xedges, yedges, image = ax1.hist2d(widths, heights, bins=[widths_bins, heights_bins], cmap=plt.cm.jet, density=True) + fig.colorbar(image, ax=ax1) + ax1.set_xlabel("Width") + ax1.set_ylabel("Height") + ax1.set_title( + f"dist_{llava_instruct_name}_2d_w_h\nMax width: {max(widths)}, Max height: {max(heights)}", + fontsize=10, + ) + + # Plot histogram + hist, bin_edges = np.histogram(tokenized_lengths, bins=np.arange(0, max(tokenized_lengths) + 10, 10)) + bins = np.arange(0, max(tokenized_lengths) + 10, 10) + ax2.bar(bin_edges[:-1], hist, width=7, edgecolor="black", log=True) + + # Display every nth label on the x-axis + n = 8 # Adjust this value to control the number of labels displayed + ticks = bins[::n] + tick_labels = [int(tick) for tick in ticks] + ax2.set_xticks(ticks) + ax2.set_xticklabels(tick_labels, rotation=90, fontsize=8) + + ax2.set_xlim(min(bin_edges), max(bin_edges)) + ax2.set_xlabel("Tokenized Length") + ax2.set_ylabel("Count (log scale)") + ax2.set_title(f"dist_{llava_instruct_name}_tokenized_length", fontsize=8) + + plt.tight_layout() + plt.savefig(f"./dist_{llava_instruct_name}_combined.png") + + +if __name__ == "__main__": + main() diff --git a/playground/data_checker.py b/playground/data_checker.py new file mode 100644 index 000000000..9d145ca47 --- /dev/null +++ b/playground/data_checker.py @@ -0,0 +1,364 @@ +import json +import os +from tqdm import tqdm +from multiprocessing import Pool, cpu_count +import yaml + + +class DataProcessor: + def __init__(self, file_path, image_root, video_root): + self.file_path = file_path + self.image_root = image_root + self.data = None + self.video_root = video_root + self.load_data() + + def load_data(self): + if self.file_path.endswith(".json"): + with open(self.file_path, "r") as f: + self.data = json.load(f) + elif self.file_path.endswith(".yaml"): + with open(self.file_path, "r") as f: + self.data = yaml.safe_load(f) + elif self.file_path.endswith(".jsonl"): + with open(self.file_path, "r") as f: + self.data = [json.loads(line) for line in f.readlines()] + else: + raise ValueError("Unsupported file format") + + def load_json_data(self, json_path): + if json_path.endswith(".jsonl"): + cur_data_dict = [] + with open(json_path, "r") as json_file: + for line in json_file: + cur_data_dict.append(json.loads(line.strip())) + return cur_data_dict + elif json_path.endswith(".json"): + with open(json_path, "r") as f: + return json.load(f) + else: + raise ValueError("Unsupported file format") + + def check_image_existence(self, data): + if "image" in data: + if type(data["image"]) == list: + images = data["image"] + else: + images = [data["image"]] + + for image in images: + full_image_path = os.path.join(self.image_root, image) + if not os.path.exists(full_image_path): + print(f"WARNING!!! {full_image_path} not exists !!!") + + if "video" in data: + full_video_path = os.path.join(self.video_root, data["video"]) + if not os.path.exists(full_video_path): + print(f"WARNING!!! {full_video_path} not exists !!!") + + # if data["conversations"][0]["value"].count("") > 1: + # print(f"WARNING!!! {data['conversations'][0]['value']} has more than one !!!") + + def check_item_structure(self, item): + if not all(key in item for key in ["conversations"]): + print(f"WARNING!!! Item {item.get('id', 'unknown')} is missing required fields!") + return False + + conversations = item["conversations"] + if not isinstance(conversations, list) or len(conversations) < 2 or len(conversations) % 2 != 0: + print(f"WARNING!!! Item {item['id']} has invalid conversations structure!") + return False + + for i, conv in enumerate(conversations): + if not all(key in conv for key in ["from", "value"]): + print(f"WARNING!!! Item {item['id']} has invalid conversation format!") + return False + + expected_from = "human" if i % 2 == 0 else "gpt" + if conv["from"] != expected_from: + print(f"WARNING!!! Item {item['id']} has incorrect conversation order!") + return False + + return True + + def check_image_and_structure(self, item): + if not self.check_item_structure(item): + return + + # self.check_image_existence(item) + + def process_images(self): + if isinstance(self.data, list): + args = [d for d in self.data] + with Pool(processes=cpu_count()) as pool: + list(tqdm(pool.imap(self.check_image_and_structure, args), total=len(self.data))) + elif isinstance(self.data, dict): + for d in self.data["datasets"]: + dd_json_path = d["json_path"] + data = self.load_json_data(dd_json_path) + args = [d for d in data] + with Pool(processes=cpu_count()) as pool: + list(tqdm(pool.imap(self.check_image_and_structure, args), total=len(data), desc=f"Processing {dd_json_path}")) + + def count_items(self): + if isinstance(self.data, list): # Assuming JSON data loaded directly + return len(self.data) + elif isinstance(self.data, dict): # Assuming YAML data loaded + total_items_count = 0 + for d in self.data["datasets"]: + dd_json_path = d["json_path"] + data = self.load_json_data(dd_json_path) + current_items_count = len(data) + + sampling_strategy = d["sampling_strategy"] + try: + if sampling_strategy != "all": + percentage = float(sampling_strategy.split(":")[-1].replace("%", "")) / 100.0 + else: + percentage = 1.0 + except Exception as e: + print(f"Error: {e}") + percentage = 1.0 + + sampling_count = int(current_items_count * percentage) + total_items_count += sampling_count + print(f"{dd_json_path}: {sampling_count}") + return total_items_count + + def stat_data(self): + if isinstance(self.data, dict): + cur_lens_list = [] + single_image_count = 0 + multiple_image_count = 0 + video_count = 0 + total_count = 0 + text_count = 0 + max_tokens_item = None + max_tokens = 0 + + for d in self.data["datasets"]: + dd_json_path = d["json_path"] + data = self.load_json_data(dd_json_path) + sampling_strategy = d["sampling_strategy"] + + try: + if sampling_strategy != "all": + percentage = float(sampling_strategy.split(":")[-1].replace("%", "")) / 100.0 + else: + percentage = 1.0 + except Exception as e: + print(f"Error parsing sampling strategy: {e}") + percentage = 1.0 + + sampled_count = int(len(data) * percentage) + print(f"{dd_json_path}: {sampled_count} (sampled from {len(data)})") + + for item in data[:sampled_count]: + conversations = item["conversations"] + cur_len = sum([len(conv["value"].split()) for conv in conversations]) + cur_lens_list.append(cur_len) + + if cur_len > max_tokens: + max_tokens = cur_len + max_tokens_item = item + + total_count += 1 + if "image" in item: + if isinstance(item["image"], list): + if len(item["image"]) > 1: + multiple_image_count += 1 + else: + single_image_count += 1 + else: + single_image_count += 1 + elif "video" in item: + video_count += 1 + else: + text_count += 1 + + print(f"Max length: {max(cur_lens_list)}, Min length: {min(cur_lens_list)}, Average length: {sum(cur_lens_list) / len(cur_lens_list)}") + print(f"Total items: {total_count}") + print(f"Text items: {text_count} ({text_count/total_count*100:.2f}%)") + print(f"Single image items: {single_image_count} ({single_image_count/total_count*100:.2f}%)") + print(f"Multiple image items: {multiple_image_count} ({multiple_image_count/total_count*100:.2f}%)") + print(f"Video items: {video_count} ({video_count/total_count*100:.2f}%)") + + print("\nItem with the largest number of tokens:") + print(f"Token count: {max_tokens}") + print("Item content:") + print(json.dumps(max_tokens_item, indent=2)) + + def filter_data(self): + if isinstance(self.data, dict): + for d in self.data["datasets"]: + dd_json_path = d["json_path"] + print(f"Processing {dd_json_path}") + data = self.load_json_data(dd_json_path) + + filtered_data = [] + mismatch_data = [] + mismatch_flag = False + for item in data: + try: + if "image" in item: + num_image = len(item["image"]) if isinstance(item["image"], list) else 1 + else: + num_image = 0 + + if "video" in item: + num_video = len(item["video"]) if isinstance(item["video"], list) else 1 + else: + num_video = 0 + + num_visuals = num_image + num_video + conv_text = "" + for conv in item["conversations"]: + conv_text += conv["value"] + + num_img_token_appearance = conv_text.count("") + if len(conv_text) == 0: + print(f"Conversation text is empty for {item}") + + if num_img_token_appearance == num_visuals or num_img_token_appearance < num_visuals and len(conv_text) > 0: + filtered_data.append(item) + elif num_img_token_appearance > num_visuals: + item["num_img_token_appearance"] = num_img_token_appearance + item["num_visuals"] = num_visuals + mismatch_data.append(item) + + if not mismatch_flag: + print(f"Data mismatch for {item}") + + mismatch_flag = True + except Exception as e: + print(f"Error: {e}") + print() + + if mismatch_flag: + print(f"Data mismatch for {dd_json_path}") + + if len(filtered_data) < len(data): + saving_dd_json_path = dd_json_path.replace(".jsonl", f"fltd_{len(filtered_data)}.json").replace(".json", f"fltd_{len(filtered_data)}.json") + with open(saving_dd_json_path, "w") as f: + json.dump(filtered_data, f, indent=2) + print(f"Filtered data count: {len(filtered_data)}") + else: + pass + + def stat_and_filter_data(self, threshold): + if isinstance(self.data, dict): + cur_lens_list = [] + single_image_count = 0 + multiple_image_count = 0 + video_count = 0 + total_count = 0 + text_count = 0 + + for d in self.data["datasets"]: + dd_json_path = d["json_path"] + data = self.load_json_data(dd_json_path) + sampling_strategy = d["sampling_strategy"] + filtered_data = [] + + try: + if sampling_strategy != "all": + percentage = float(sampling_strategy.split(":")[-1].replace("%", "")) / 100.0 + else: + percentage = 1.0 + except Exception as e: + print(f"Error parsing sampling strategy: {e}") + percentage = 1.0 + + sampled_count = int(len(data) * percentage) + print(f"{dd_json_path}: {sampled_count} (sampled from {len(data)})") + + save_flag = False + for item in data: + total_count += 1 + conversations = item["conversations"] + filtered_conversations = [] + current_token_count = 0 + + for i in range(0, len(conversations), 2): + if i + 1 < len(conversations): + human_conv = conversations[i] + gpt_conv = conversations[i + 1] + pair_tokens = len(human_conv["value"].split()) + len(gpt_conv["value"].split()) + + if current_token_count + pair_tokens <= threshold: + filtered_conversations.extend([human_conv, gpt_conv]) + current_token_count += pair_tokens + else: + save_flag = True + break + + if filtered_conversations: + item["conversations"] = filtered_conversations + cur_len = sum([len(conv["value"].split()) for conv in filtered_conversations]) + cur_lens_list.append(cur_len) + filtered_data.append(item) + + if "image" in item: + if isinstance(item["image"], list): + if len(item["image"]) > 1: + multiple_image_count += 1 + else: + single_image_count += 1 + else: + single_image_count += 1 + elif "video" in item: + video_count += 1 + else: + text_count += 1 + + # Save filtered data for each dataset + if filtered_data and save_flag: + if dd_json_path.endswith(".jsonl"): + output_file = dd_json_path.replace(".jsonl", f"_filtered_{threshold}tokens_{len(filtered_data)}.jsonl") + with open(output_file, "w") as f: + for item in filtered_data: + f.write(json.dumps(item) + "\n") + else: + output_file = dd_json_path.replace(".json", f"_filtered_{threshold}tokens_{len(filtered_data)}.json") + with open(output_file, "w") as f: + json.dump(filtered_data, f, indent=2) + print(f"Filtered data for {dd_json_path} saved to: {output_file}") + + print(f"Max length: {max(cur_lens_list)}, Min length: {min(cur_lens_list)}, Average length: {sum(cur_lens_list) / len(cur_lens_list)}") + print(f"Total items: {total_count}") + print(f"Text items: {text_count} ({text_count/total_count*100:.2f}%)") + print(f"Single image items: {single_image_count} ({single_image_count/total_count*100:.2f}%)") + print(f"Multiple image items: {multiple_image_count} ({multiple_image_count/total_count*100:.2f}%)") + print(f"Video items: {video_count} ({video_count/total_count*100:.2f}%)") + + +def main(file_path, image_root, operation, video_root, threshold=None): + processor = DataProcessor(file_path, image_root, video_root) + if operation == "check": + processor.process_images() + elif operation == "count": + total_items = processor.count_items() + print(f"Total items: {total_items}") + elif operation == "filter": + processor.filter_data() + elif operation == "stat": + processor.stat_data() + elif operation == "stat_and_filter": + if threshold is None: + raise ValueError("Threshold must be provided for stat_and_filter operation") + processor.stat_and_filter_data(threshold) + else: + raise ValueError("Unsupported operation") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--file_path", type=str, default="/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/scripts/i18n/scale_llms/next_continual.yaml") + parser.add_argument("--image_root", type=str, default="/mnt/bn/vl-research/data/llava_data") + parser.add_argument("--video_root", type=str, default="/mnt/bn/vl-research/data/llava_video") + parser.add_argument("--operation", type=str, default="filter") + parser.add_argument("--threshold", type=int, default=None, help="Threshold for stat_and_filter operation") + args = parser.parse_args() + main(args.file_path, args.image_root, args.operation, args.video_root, args.threshold) diff --git a/playground/demo/video_demo.py b/playground/demo/video_demo.py new file mode 100644 index 000000000..d93f65826 --- /dev/null +++ b/playground/demo/video_demo.py @@ -0,0 +1,335 @@ +import argparse +import torch + +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN +from llava.conversation import conv_templates, SeparatorStyle +from llava.model.builder import load_pretrained_model +from llava.utils import disable_torch_init +from llava.mm_utils import process_anyres_image,tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +import json +import os +import math +from tqdm import tqdm +from decord import VideoReader, cpu + +from transformers import AutoConfig + +import cv2 +import base64 +import openai + +from PIL import Image + + + +import numpy as np + +def split_list(lst, n): + """Split a list into n (roughly) equal-sized chunks""" + chunk_size = math.ceil(len(lst) / n) # integer division + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + +def get_chunk(lst, n, k): + chunks = split_list(lst, n) + return chunks[k] + + +def parse_args(): + """ + Parse command-line arguments. + """ + parser = argparse.ArgumentParser() + + # Define the command-line arguments + parser.add_argument("--video_path", help="Path to the video files.", required=True) + parser.add_argument("--output_dir", help="Directory to save the model results JSON.", required=True) + parser.add_argument("--output_name", help="Name of the file for storing results JSON.", required=True) + parser.add_argument("--model-path", type=str, default="facebook/opt-350m") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--chunk-idx", type=int, default=0) + parser.add_argument("--mm_resampler_type", type=str, default="spatial_pool") + parser.add_argument("--mm_spatial_pool_stride", type=int, default=4) + parser.add_argument("--mm_spatial_pool_out_channels", type=int, default=1024) + parser.add_argument("--mm_spatial_pool_mode", type=str, default="average") + parser.add_argument("--image_aspect_ratio", type=str, default="anyres") + parser.add_argument("--image_grid_pinpoints", type=str, default="[(224, 448), (224, 672), (224, 896), (448, 448), (448, 224), (672, 224), (896, 224)]") + parser.add_argument("--mm_patch_merge_type", type=str, default="spatial_unpad") + parser.add_argument("--overwrite", type=lambda x: (str(x).lower() == 'true'), default=True) + parser.add_argument("--for_get_frames_num", type=int, default=4) + parser.add_argument("--load_8bit", type=lambda x: (str(x).lower() == 'true'), default=False) + parser.add_argument("--prompt", type=str, default=None) + parser.add_argument("--api_key", type=str, help="OpenAI API key") + parser.add_argument("--mm_newline_position", type=str, default="no_token") + parser.add_argument("--force_sample", type=lambda x: (str(x).lower() == 'true'), default=False) + parser.add_argument("--add_time_instruction", type=str, default=False) + return parser.parse_args() + +def load_video(video_path,args): + if args.for_get_frames_num == 0: + return np.zeros((1, 336, 336, 3)) + vr = VideoReader(video_path, ctx=cpu(0),num_threads=1) + total_frame_num = len(vr) + video_time = total_frame_num / vr.get_avg_fps() + fps = round(vr.get_avg_fps()) + frame_idx = [i for i in range(0, len(vr), fps)] + frame_time = [i/fps for i in frame_idx] + if len(frame_idx) > args.for_get_frames_num or args.force_sample: + sample_fps = args.for_get_frames_num + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + frame_time = [i/vr.get_avg_fps() for i in frame_idx] + frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) + spare_frames = vr.get_batch(frame_idx).asnumpy() + # import pdb;pdb.set_trace() + + return spare_frames,frame_time,video_time + + + + +def load_video_base64(path): + video = cv2.VideoCapture(path) + + base64Frames = [] + while video.isOpened(): + success, frame = video.read() + if not success: + break + _, buffer = cv2.imencode(".jpg", frame) + base64Frames.append(base64.b64encode(buffer).decode("utf-8")) + + video.release() + # print(len(base64Frames), "frames read.") + return base64Frames + + +def run_inference(args): + """ + Run inference on ActivityNet QA DataSet using the Video-ChatGPT model. + + Args: + args: Command-line arguments. + """ + # Initialize the model + if "gpt4v" != args.model_path: + model_name = get_model_name_from_path(args.model_path) + # Set model configuration parameters if they exist + if args.overwrite == True: + overwrite_config = {} + overwrite_config["mm_spatial_pool_mode"] = args.mm_spatial_pool_mode + overwrite_config["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride + overwrite_config["mm_newline_position"] = args.mm_newline_position + + cfg_pretrained = AutoConfig.from_pretrained(args.model_path) + + # import pdb;pdb.set_trace() + if "qwen" not in args.model_path.lower(): + if "224" in cfg_pretrained.mm_vision_tower: + # suppose the length of text tokens is around 1000, from bo's report + least_token_number = args.for_get_frames_num*(16//args.mm_spatial_pool_stride)**2 + 1000 + else: + least_token_number = args.for_get_frames_num*(24//args.mm_spatial_pool_stride)**2 + 1000 + + scaling_factor = math.ceil(least_token_number/4096) + if scaling_factor >= 2: + if "vicuna" in cfg_pretrained._name_or_path.lower(): + print(float(scaling_factor)) + overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"} + overwrite_config["max_sequence_length"] = 4096 * scaling_factor + overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor + + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, load_8bit=args.load_8bit, overwrite_config=overwrite_config) + else: + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) + else: + pass + + # import pdb;pdb.set_trace() + if getattr(model.config, "force_sample", None) is not None: + args.force_sample = model.config.force_sample + else: + args.force_sample = False + + # import pdb;pdb.set_trace() + + if getattr(model.config, "add_time_instruction", None) is not None: + args.add_time_instruction = model.config.add_time_instruction + else: + args.add_time_instruction = False + + # Create the output directory if it doesn't exist + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + output_name = args.output_name + answers_file = os.path.join(args.output_dir, f"{output_name}.json") + ans_file = open(answers_file, "w") + + video_path = args.video_path + + all_video_pathes = [] + + # Check if the video_path is a directory or a file + if os.path.isdir(video_path): + # If it's a directory, loop over all files in the directory + for filename in os.listdir(video_path): + # Load the video file + cur_video_path = os.path.join(video_path, f"{filename}") + all_video_pathes.append(os.path.join(video_path, cur_video_path)) + else: + # If it's a file, just process the video + all_video_pathes.append(video_path) + + # import pdb;pdb.set_trace() + for video_path in all_video_pathes: + + sample_set = {} + question = args.prompt + sample_set["Q"] = question + sample_set["video_name"] = video_path + + + # Check if the video exists + if os.path.exists(video_path): + if "gpt4v" != args.model_path: + video,frame_time,video_time = load_video(video_path, args) + video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half().cuda() + video = [video] + else: + spare_frames,frame_time,video_time = load_video_base64(video_path) + interval = int(len(video) / args.for_get_frames_num) + + # try: + # Run inference on the video and add the output to the list + if "gpt4v" != args.model_path: + qs = question + if args.add_time_instruction: + time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}.Please answer the following questions related to this video." + qs = f'{time_instruciton}\n{qs}' + if model.config.mm_use_im_start_end: + qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs + else: + qs = DEFAULT_IMAGE_TOKEN + "\n" + qs + + conv = conv_templates[args.conv_mode].copy() + conv.append_message(conv.roles[0], qs) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() + if tokenizer.pad_token_id is None: + if "qwen" in tokenizer.name_or_path.lower(): + print("Setting pad token to bos token for qwen model.") + tokenizer.pad_token_id = 151643 + + attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda() + + stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 + keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) + + cur_prompt = question + else: + prompt = question + + system_error = "" + + if "gpt4v" != args.model_path: + + + with torch.inference_mode(): + # model.update_prompt([[cur_prompt]]) + # import pdb;pdb.set_trace() + # output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria]) + if "mistral" not in cfg_pretrained._name_or_path.lower(): + output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=False, temperature=0.0, max_new_tokens=1024, top_p=0.1,num_beams=1,use_cache=True, stopping_criteria=[stopping_criteria]) + # output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria]) + else: + output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=False, temperature=0.0, max_new_tokens=1024, top_p=0.1, num_beams=1, use_cache=True) + # output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True) + else: + openai.api_key = args.api_key # Your API key here + + max_num_retries = 0 + retry = 5 + PROMPT_MESSAGES = [ + { + "role": "user", + "content": [ + f"These are frames from a video that I want to upload. Answer me one question of this video: {prompt}", + *map(lambda x: {"image": x, "resize": 336}, video[0::interval]), + ], + }, + ] + params = { + "model": "gpt-4-vision-preview", #gpt-4-1106-vision-preview + "messages": PROMPT_MESSAGES, + "max_tokens": 1024, + } + sucess_flag=False + while max_num_retries < retry: + try: + result = openai.ChatCompletion.create(**params) + outputs = result.choices[0].message.content + sucess_flag = True + break + except Exception as inst : + if 'error' in dir(inst): + # import pdb;pdb.set_trace() + if inst.error.code == 'rate_limit_exceeded': + if "TPM" in inst.error.message: + time.sleep(30) + continue + else: + import pdb;pdb.set_trace() + elif inst.error.code == 'insufficient_quota': + print(f'insufficient_quota key') + exit() + elif inst.error.code == 'content_policy_violation': + print(f'content_policy_violation') + system_error = "content_policy_violation" + + break + print('Find error message in response: ',str(inst.error.message), 'error code: ', str(inst.error.code)) + + continue + if not sucess_flag: + print(f'Calling OpenAI failed after retrying for {max_num_retries} times. Check the logs for details.') + exit() + + if "gpt4v" != args.model_path: + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + else: + print(len(video[0::interval])) + + print(f"Question: {prompt}\n") + print(f"Response: {outputs}\n") + + if "gpt4v" == args.model_path: + if system_error == 'content_policy_violation': + continue + elif system_error == "": + continue + else: + import pdb;pdb.set_trace() + + # import pdb;pdb.set_trace() + if "mistral" not in cfg_pretrained._name_or_path.lower(): + if outputs.endswith(stop_str): + outputs = outputs[: -len(stop_str)] + + outputs = outputs.strip() + + sample_set["pred"] = outputs + ans_file.write(json.dumps(sample_set, ensure_ascii=False) + "\n") + ans_file.flush() + + ans_file.close() + + +if __name__ == "__main__": + args = parse_args() + run_inference(args) \ No newline at end of file diff --git a/playground/demo/xU25MMA2N4aVtYay.mp4 b/playground/demo/xU25MMA2N4aVtYay.mp4 new file mode 100644 index 000000000..5bc7f52ce Binary files /dev/null and b/playground/demo/xU25MMA2N4aVtYay.mp4 differ diff --git a/playground/equal_splitter.py b/playground/equal_splitter.py new file mode 100644 index 000000000..8f89f2d2d --- /dev/null +++ b/playground/equal_splitter.py @@ -0,0 +1,38 @@ +import json +from math import ceil + + +def split_json_file(input_file, n_splits): + # Read the JSON file + with open(input_file, "r") as file: + data = json.load(file) + + # Calculate the size of each split + total_items = len(data) + items_per_split = ceil(total_items / n_splits) + + # Split the data and save into separate files + for i in range(n_splits): + start_index = i * items_per_split + end_index = min((i + 1) * items_per_split, total_items) + split_data = data[start_index:end_index] + + # Write the split data to a new JSON file + with open(f"{input_file.split('.')[0]}_split_{i}.json", "w") as split_file: + json.dump(split_data, split_file, indent=4) + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Split a JSON file into multiple parts.") + parser.add_argument("--input_file", type=str, help="The JSON file to split") + parser.add_argument("--n_splits", type=int, help="The number of splits") + + args = parser.parse_args() + + split_json_file(args.input_file, args.n_splits) + + +if __name__ == "__main__": + main() diff --git a/playground/remove_mid_ckpt.py b/playground/remove_mid_ckpt.py new file mode 100644 index 000000000..612b717d6 --- /dev/null +++ b/playground/remove_mid_ckpt.py @@ -0,0 +1,35 @@ +import os +import shutil +import glob + + +def remove_checkpoints(directory, pattern): + # Walk through the directory + for root, dirs, files in os.walk(directory): + # Use glob to find paths matching the pattern + for file_path in glob.glob(os.path.join(root, pattern)): + # Check if it is a directory + if "llava-1.6-mistral-7b" in file_path: + continue + if os.path.isdir(file_path): + # Remove the directory + print(f"Removing {file_path}") + input("Press Enter to continue...") + shutil.rmtree(file_path) + print(f"Removed directory: {file_path}") + else: + print(f"Removing {file_path}") + input("Press Enter to continue...") + # Remove the file + os.remove(file_path) + print(f"Removed file: {file_path}") + + +# Directory containing the checkpoints +directory = "/mnt/bn/vl-research/checkpoints/feng/" + +# Pattern to match in the file names +pattern = "global_step*" + +# Call the function +remove_checkpoints(directory, pattern) diff --git a/playground/sgl_llava_inference_multinode.py b/playground/sgl_llava_inference_multinode.py new file mode 100755 index 000000000..869099d73 --- /dev/null +++ b/playground/sgl_llava_inference_multinode.py @@ -0,0 +1,125 @@ +import argparse +import json +import time +import os +import tqdm +import sglang as sgl +from sglang.test.test_utils import select_sglang_backend +from sglang.utils import dump_state_text + + +@sgl.function +def image_description(s, image_file): + prompt = "Please generate detailed descriptions of the given image." + s += sgl.user(sgl.image(image_file) + prompt) + s += sgl.assistant(sgl.gen("answer", max_tokens=1024, temperature=0.0)) + + +def load_progress(progress_file): + print(f"Load progress from {progress_file}") + if os.path.exists(progress_file): + with open(progress_file, "r") as f: + return json.load(f) + return {"last_index": -1, "last_chunk": -1, "results": [], "annotations": []} + + +def save_progress(progress_file, progress_data): + with open(progress_file, "w") as f: + json.dump(progress_data, f, indent=2) + + +def find_images_in_subfolders(folder_path): + image_extensions = (".png", ".jpg", ".jpeg", ".gif", ".bmp") + image_files = [] + for root, dirs, files in os.walk(folder_path): + for file in files: + if file.endswith(image_extensions): + image_files.append(os.path.join(root, file)) + return image_files + + +def main(args): + dist_rank = args.dist + dist_size = args.total_dist + + base_dir = os.path.dirname(args.result_file) + os.makedirs(base_dir, exist_ok=True) # Ensure the base directory exists + progress_file = f"{base_dir}/progress_{dist_rank}_or_{dist_size}.json" + progress_data = load_progress(progress_file) + + with open(args.json_path, "r") as fp: + data = json.load(fp) + + image_files = [os.path.join(args.images_root, item["image"]) for item in data] + image_files = image_files[: args.limit] if args.limit > 0 else image_files + + # Shard the data + shard_size = len(image_files) // dist_size + start_index = shard_size * dist_rank + end_index = start_index + shard_size if dist_rank < dist_size - 1 else len(image_files) + shard_files = image_files[start_index:end_index] + + print(f"Querying {len(shard_files)} images from index {start_index} to {end_index - 1}") + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + tic = time.time() + batch_size = args.parallel + for batch_start in tqdm.tqdm(range(0, len(shard_files), batch_size)): + batch_end = min(batch_start + batch_size, len(shard_files)) + if batch_start <= progress_data.get("last_index", -1): + print(f"Skipping already processed batch starting at {batch_start}") + continue + batch_arguments = [{"image_file": image_file} for image_file in shard_files[batch_start:batch_end]] + try: + batch_states = image_description.run_batch(batch_arguments, temperature=0, num_threads=args.parallel, progress_bar=False) + for i, ret in enumerate(batch_states): + image_file = batch_arguments[i]["image_file"] + caption = ret.text().split("ASSISTANT:")[-1].strip() + progress_data["annotations"].append({"image_file": image_file, "caption": caption}) + progress_data["last_index"] = batch_start + i # Update last_index relative to this rank's shard + + save_progress(progress_file, progress_data) + except Exception as e: + print(f"Error during batch processing: {e}") + save_progress(progress_file, progress_data) + break + + latency = time.time() - tic + print(f"Latency: {latency:.3f}") + + value = { + "task": "image_captioning", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": len(shard_files), + "parallel": args.parallel, + "results": progress_data["annotations"], + } + + result_file = args.result_file.replace(".json", f"_shard_{dist_rank}_or_{dist_size}.json") + print(f"Write output to {result_file}") + with open(result_file, "w") as fout: + json.dump(value, fout, indent=2) + + save_progress(progress_file, progress_data) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--images_root", type=str, default="/mnt/bn/vl-research/data/llava_data/cc3m") + parser.add_argument("--json_path", type=str, default="/mnt/bn/vl-research/data/llava_instruct/cc3m_recap_requery_363707.json") + parser.add_argument("--max_tokens", type=int, default=1024) + parser.add_argument("--parallel", type=int, default=32) + parser.add_argument("--backend", type=str, default="srt") + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + parser.add_argument("--result_file", type=str, default="/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/playground/sgl_llava_inference.json") + parser.add_argument("--limit", type=int, default=-1) + parser.add_argument("--dist", type=int, default=0, help="The rank of the distributed machine") + parser.add_argument("--total_dist", type=int, default=6, help="Total number of distributed machines") + args = parser.parse_args() + main(args) diff --git a/playground/upload_data.py b/playground/upload_data.py new file mode 100644 index 000000000..d1acf2ba4 --- /dev/null +++ b/playground/upload_data.py @@ -0,0 +1,217 @@ +from datasets import Dataset, Features, Value, ClassLabel, Sequence, Image +import json +import PIL.Image as pil_image +from io import BytesIO +from tqdm import tqdm + +json_paths = [ + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_metagen_87358.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_rule_geo_100000.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/k12_printing_train_256646.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/iiit5k_annotations_2000.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/hme100k_train_clean_74502.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_azuregpt_detailed_understanding_4874.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_vqa_4404.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_azuregpt4v_1992.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_chart_1787.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_normal_gpt4v_filtered_10500.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/scienceqa_nona_context_19218.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_vflan4v_20000.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_laion4v_50000.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textocr_gpt4v_train_converted_25114.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_train_internvl_single_12413.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textcaps_train_21952.json", + # "/mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_qa_sft.json", + # "/mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_cap_sft.json", + # "/mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_ie_sft.json", + # "/mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_kg_sft.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/vision_flan_filtered_186070.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mathqa_29837.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo3k_2101.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_qa_converted_67833.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_align_converted_60252.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-coco-50k.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-knowledge-2k.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-llava-30k.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-sam-20k.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_CLEVR-Math_5290.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_FigureQA_17597.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Geometry3K_9734.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GeoQA+_17172.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GEOS_508.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_IconQA_22599.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_MapQA_5235.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PMC-VQA_35958.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Super-CLEVR_8652.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TabMWP_22462.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_UniGeo_11959.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VizWiz_6614.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_qwen2_72b_st_300000_sp_token_fltd_299992.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_st_300000.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_mt_300000_sp_token_fltd_299998.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/image_textualization_dataset_filtered.json", + # "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/cambrian_filtered_gpt4vo_sp_token_fltd_max10k.json", + "/mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4o_dataset.jsonl", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/ai2d_llava_format_2434.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/aokvqa_16539_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/chart2text_26961.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/chartqa_18265_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/clevr_70000_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/diagram_image_to_text_300.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/dvqa_200000_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/figureqa_100000_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/geomverse_9303.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/hateful_memes_8500_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/hitab_2500_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/iam_5663.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/raven_42000.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/iconqa_llava_format_27307.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/infographic_vqa_2118_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/intergps_1280_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/mapqa_37417_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/multihiertt_7619.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/rendered_text_10000.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/robut_sqa_8514.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wikisql_74989.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wtq_38246_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/screen2words_15730.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/scienceqa_llava_format_4976.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/tabmwp_22722.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/tallyqa_98680_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/st_vqa_17247_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/tqa_llava_format_27307.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/visual7w_llava_format_14366.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/visualmrc_3027.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/vqarad_313_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/vsr_2157_llava_format.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/vistext_9969.json", + "/mnt/bn/vl-research/data/llava_instruct/cauldron/websight_10000.json" +] + +short_names = [ + # "mavis_math_metagen", + # "mavis_math_rule_geo", + # "k12_printing", + # "iiit5k", + # "hme100k", + # "ai2d(gpt4v)", + # "infographic_vqa", + # "infographic(gpt4v)", + # "lrv_chart", + # "lrv_normal(filtered)", + # "scienceqa(nona_context)", + # "allava_instruct_vflan4v", + # "allava_instruct_laion4v", + # "textocr(gpt4v)", + # "ai2d(internvl)", + # "textcaps", + # "ureader_qa", # need to re-upload + # "ureader_cap", # need to re-upload + # "ureader_ie", # need to re-upload + # "ureader_kg", # need to re-upload + # "vision_flan(filtered)", + # "mathqa", + # "geo3k", + # "geo170k(qa)", + # "geo170k(align)", + # "sharegpt4v(coco)", + # "sharegpt4v(knowledge)", + # "sharegpt4v(llava)", + # "sharegpt4v(sam)", + # "CLEVR-Math(MathV360K)", + # "FigureQA(MathV360K)", + # "Geometry3K(MathV360K)", + # "GeoQA+(MathV360K)", + # "GEOS(MathV360K)", + # "IconQA(MathV360K)", + # "MapQA(MathV360K)", + # "PMC-VQA(MathV360K)", + # "Super-CLEVR(MathV360K)", + # "TabMWP(MathV360K)", + # "UniGeo(MathV360K)", + # "VizWiz(MathV360K)", + # "magpie_pro(qwen2_72b_st)", + # "magpie_pro(l3_80b_st)", + # "magpie_pro(l3_80b_mt)", + # "image_textualization(filtered)", + # "cambrian(filtered_gpt4vo)", # need to re-upload + "sharegpt4o", + "ai2d(cauldron,llava_format)", + "aokvqa(cauldron,llava_format)", + "chart2text(cauldron)", + "chartqa(cauldron,llava_format)", + "clevr(cauldron,llava_format)", + "diagram_image_to_text(cauldron)", + "dvqa(cauldron,llava_format)", + "figureqa(cauldron,llava_format)", + "geomverse(cauldron)", + "hateful_memes(cauldron,llava_format)", + "hitab(cauldron,llava_format)", + "iam(cauldron)", + "raven(cauldron)", + "iconqa(cauldron,llava_format)", + "infographic_vqa_llava_format", + "intergps(cauldron,llava_format)", + "mapqa(cauldron,llava_format)", + "multihiertt(cauldron)", + "rendered_text(cauldron)", + "robut_sqa(cauldron)", + "robut_wikisql(cauldron)", + "robut_wtq(cauldron,llava_format)", + "screen2words(cauldron)", + "scienceqa(cauldron,llava_format)", + "tabmwp(cauldron)", + "tallyqa(cauldron,llava_format)", + "st_vqa(cauldron,llava_format)", + "tqa(cauldron,llava_format)", + "visual7w(cauldron,llava_format)", + "visualmrc(cauldron)", + "vqarad(cauldron,llava_format)", + "vsr(cauldron,llava_format)", + "vistext(cauldron)", + "websight(cauldron)" +] + +def upload_data(json_path, short_name): + def gen(): + if json_path.endswith(".jsonl"): + with open(json_path, "r") as f: + data = [json.loads(line) for line in f] + else: + with open(json_path, "r") as f: + data = json.load(f) + + preview_index = 5 + idx = 0 + for item in tqdm(data): + if preview_index > 0: + preview_index -= 1 + print(item) + continue + + try: + if "image" in item: + image_path = f"/mnt/bn/vl-research/data/llava_data/{item['image']}" + try: + with open(image_path, "rb") as img_file: + image = pil_image.open(BytesIO(img_file.read())) + except: + print(f"Failed to load image {item['image']}") + continue + else: + image = None + + item_id = item["id"] if "id" in item else f"{idx:06d}" + yield {"id": item_id, "image": image, "conversations": item["conversations"], "data_source": short_name} + idx += 1 + + except Exception as e: + print(e) + continue + + + hf_dataset = Dataset.from_generator(generator=gen, num_proc=32) + hf_dataset.push_to_hub("lmms-lab/LLaVA-OneVision-Data", config_name=short_name, split="train") + +for json_path, short_name in zip(json_paths, short_names): + upload_data(json_path, short_name) \ No newline at end of file diff --git a/predict.py b/predict.py old mode 100644 new mode 100755 index 25367e95e..ae53e11ac --- a/predict.py +++ b/predict.py @@ -4,7 +4,7 @@ from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init -from llava.mm_utils import tokenizer_image_token +from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria from transformers.generation.streamers import TextIteratorStreamer from PIL import Image @@ -38,19 +38,16 @@ "special_tokens_map.json", "tokenizer.model", "tokenizer_config.json", - ] + ], }, { "dest": "openai/clip-vit-large-patch14-336", "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1", - "files": [ - "config.json", - "preprocessor_config.json", - "pytorch_model.bin" - ], - } + "files": ["config.json", "preprocessor_config.json", "pytorch_model.bin"], + }, ] + def download_json(url: str, dest: Path): res = requests.get(url, allow_redirects=True) if res.status_code == 200 and res.content: @@ -81,7 +78,6 @@ def setup(self) -> None: for weight in weights: download_weights(weight["src"], weight["dest"], weight["files"]) disable_torch_init() - self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model("liuhaotian/llava-v1.5-13b", model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False) def predict( @@ -93,37 +89,33 @@ def predict( max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0), ) -> ConcatenateIterator[str]: """Run a single prediction on the model""" - + conv_mode = "llava_v1" conv = conv_templates[conv_mode].copy() - + image_data = load_image(str(image)) - image_tensor = self.image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().cuda() - + image_tensor = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"].half().cuda() + # loop start - + # just one turn, always prepend image token - inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt + inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt conv.append_message(conv.roles[0], inp) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() - - input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + + input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda() stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 keywords = [stop_str] + stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids) streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0) - + with torch.inference_mode(): - thread = Thread(target=self.model.generate, kwargs=dict( - inputs=input_ids, - images=image_tensor, - do_sample=True, - temperature=temperature, - top_p=top_p, - max_new_tokens=max_tokens, - streamer=streamer, - use_cache=True)) + thread = Thread( + target=self.model.generate, + kwargs=dict(inputs=input_ids, images=image_tensor, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=max_tokens, streamer=streamer, use_cache=True, stopping_criteria=[stopping_criteria]), + ) thread.start() # workaround: second-to-last token is always " " # but we want to keep it if it's not the second-to-last token @@ -133,7 +125,7 @@ def predict( prepend_space = True continue if new_text.endswith(stop_str): - new_text = new_text[:-len(stop_str)].strip() + new_text = new_text[: -len(stop_str)].strip() prepend_space = False elif prepend_space: new_text = " " + new_text @@ -143,13 +135,12 @@ def predict( if prepend_space: yield " " thread.join() - + def load_image(image_file): - if image_file.startswith('http') or image_file.startswith('https'): + if image_file.startswith("http") or image_file.startswith("https"): response = requests.get(image_file) - image = Image.open(BytesIO(response.content)).convert('RGB') + image = Image.open(BytesIO(response.content)).convert("RGB") else: - image = Image.open(image_file).convert('RGB') + image = Image.open(image_file).convert("RGB") return image - diff --git a/pyproject.toml b/pyproject.toml old mode 100644 new mode 100755 index 6a8ccb34d..23044a9df --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "llava" version = "1.7.0.dev0" -description = "Towards GPT-4 like large language and visual assistant." +description = "LLaVA OneVision: The Next Generation of LLaVA with Better Image and Video Understanding Capabilities" readme = "README.md" requires-python = ">=3.8" classifiers = [ @@ -20,17 +20,16 @@ classifiers = [ standalone = [ "shortuuid", "httpx==0.24.0", - "accelerate>=0.29.1", - "tokenizers~=0.15.2", + "einops", "ftfy", - "open_clip_torch" ] + train = [ "llava[standalone]", - "einops", + "numpy==1.26.1", + "open_clip_torch", "fastapi", - "gradio==3.35.2", "markdown2[all]", "numpy", "requests", @@ -38,26 +37,36 @@ train = [ "torch==2.1.2", "torchvision==0.16.2", "uvicorn", - "wandb==0.16.5", - "deepspeed==0.12.2", + "wandb", + "deepspeed==0.14.4", "peft==0.4.0", - "transformers>=4.39.0", + "accelerate>=0.29.1", + "tokenizers~=0.15.2", + "transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4", "bitsandbytes==0.41.0", "scikit-learn==1.2.2", "sentencepiece~=0.1.99", "einops==0.6.1", "einops-exts==0.0.4", "gradio_client==0.2.9", - "pydantic", + "urllib3<=2.0.0", + "datasets==2.16.1", + "pydantic==1.10.8", "timm", "hf_transfer", + "opencv-python", + "av", + "decord", + "tyro", + "scipy", ] [project.urls] "Homepage" = "https://llava-vl.github.io" -"Bug Tracker" = "https://github.com/EvolvingLMMs-Lab/LLaVA-NEXT/issues" +"Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues" [tool.setuptools.packages.find] +include = ["llava*", "trl*"] exclude = [ "assets*", "benchmark*", diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 000000000..66ea557ef --- /dev/null +++ b/requirements.txt @@ -0,0 +1,324 @@ +Babel==2.14.0 +DataProperty==1.0.1 +Deprecated==1.2.14 +GitPython==3.1.43 +Jinja2==3.1.3 +Levenshtein==0.25.1 +MarkupSafe==2.1.5 +PyJWT==2.8.0 +PyYAML==6.0.1 +Pygments==2.17.2 +QtPy==2.4.1 +Send2Trash==1.8.3 +absl-py==2.1.0 +accelerate==0.29.3 +aiofiles==22.1.0 +aiohttp==3.9.5 +aiosignal==1.3.1 +aiosqlite==0.20.0 +altair==5.3.0 +anyio==4.3.0 +appdirs==1.4.4 +argon2-cffi-bindings==21.2.0 +argon2-cffi==23.1.0 +arrow==1.3.0 +asttokens==2.4.1 +async-timeout==4.0.3 +attrs==23.1.0 +beautifulsoup4==4.12.3 +bidict==0.23.1 +bitsandbytes==0.41.0 +black==24.1.0 +bleach==6.1.0 +byted-remote-ikernel==0.4.8 +byted-torch-monitor==0.0.1 +byted-wandb==0.13.72 +bytedance-context==0.7.1 +bytedance-metrics==0.5.1 +bytedance.modelhub==0.0.64 +bytedance.servicediscovery==0.1.2 +bytedbackgrounds==0.0.6 +byteddatabus==1.0.6 +byteddps==0.1.2 +bytedenv==0.6.2 +bytedlogger==0.15.1 +bytedmemfd==0.2 +bytedmetrics==0.10.2 +bytedpymongo==2.0.5 +bytedrh2==1.18.7a2 +bytedservicediscovery==0.17.4 +bytedtcc==1.4.2 +bytedtos==1.1.16 +bytedtrace==0.3.0 +bytedztijwthelper==0.0.22 +bytedztispiffe==0.0.11 +certifi==2024.2.2 +cffi==1.16.0 +cfgv==3.4.0 +chardet==5.2.0 +charset-normalizer==3.3.2 +click==8.1.7 +colorama==0.4.6 +comm==0.2.2 +contourpy==1.2.1 +crcmod==1.7 +cryptography==38.0.4 +cycler==0.12.1 +datasets==2.16.1 +debugpy==1.8.1 +decorator==5.1.1 +decord==0.6.0 +deepspeed==0.12.2 +defusedxml==0.7.1 +dill==0.3.7 +distlib==0.3.8 +distro==1.9.0 +dnspython==2.6.1 +docker-pycreds==0.4.0 +docstring_parser==0.16 +einops-exts==0.0.4 +einops==0.6.1 +entrypoints==0.4 +et-xmlfile==1.1.0 +eval_type_backport==0.2.0 +evaluate==0.4.1 +exceptiongroup==1.2.1 +executing==2.0.1 +fastapi==0.110.2 +fastjsonschema==2.19.1 +ffmpy==0.3.2 +filelock==3.13.4 +flash-attn==2.5.7 +fonttools==4.51.0 +fqdn==1.5.1 +frozenlist==1.4.1 +fsspec==2023.10.0 +ftfy==6.2.0 +gitdb==4.0.11 +gradio==3.35.2 +gradio_client==0.2.9 +grpcio==1.62.2 +h11==0.14.0 +hf_transfer==0.1.6 +hjson==3.1.0 +httpcore==0.17.3 +httpx==0.24.0 +huggingface-hub==0.22.2 +identify==2.5.36 +idna==3.7 +importlib_metadata==7.1.0 +importlib_resources==6.4.0 +iniconfig==2.0.0 +ipaddress==1.0.23 +ipykernel==6.29.4 +ipython-genutils==0.2.0 +ipython==8.18.1 +ipywidgets==8.1.2 +isoduration==20.11.0 +jedi==0.19.1 +joblib==1.4.0 +json5==0.9.25 +jsonlines==4.0.0 +jsonpointer==2.4 +jsonschema-specifications==2023.12.1 +jsonschema==4.21.1 +jupyter-client==7.0.0 +jupyter-console==6.6.3 +jupyter-events==0.10.0 +jupyter-ydoc==0.2.5 +jupyter==1.0.0 +jupyter_core==5.7.2 +jupyter_server==2.14.0 +jupyter_server_fileid==0.9.2 +jupyter_server_terminals==0.5.3 +jupyter_server_ydoc==0.8.0 +jupyterlab==3.6.4 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.1 +jupyterlab_widgets==3.0.10 +kiwisolver==1.4.5 +linkify-it-py==2.0.3 +llava==1.7.0.dev0 +llava==1.7.0.dev0 +lmms_eval==0.1.1 +lxml==5.2.1 +markdown-it-py==2.2.0 +markdown2==2.4.13 +matplotlib-inline==0.1.7 +matplotlib==3.8.4 +mbstrdecoder==1.1.3 +mdit-py-plugins==0.3.3 +mdurl==0.1.2 +mistune==3.0.2 +mpmath==1.3.0 +msgpack==1.0.8 +multidict==6.0.5 +multiprocess==0.70.15 +mypy-extensions==1.0.0 +nbclassic==1.0.0 +nbclient==0.10.0 +nbconvert==7.16.3 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.2.1 +ninja==1.11.1.1 +nltk==3.8.1 +nodeenv==1.8.0 +notebook==6.5.6 +notebook_shim==0.2.4 +numexpr==2.10.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.1.105 +open-clip-torch==2.24.0 +openai==1.23.6 +opencv-python-headless==4.9.0.80 +openpyxl==3.1.2 +orjson==3.10.1 +overrides==7.7.0 +packaging==24.0 +pandas==2.2.2 +pandocfilters==1.5.1 +parso==0.8.4 +pathlib2==2.3.7.post1 +pathspec==0.12.1 +pathtools==0.1.2 +pathvalidate==3.2.0 +peft==0.4.0 +pexpect==4.8.0 +pillow==10.3.0 +pip==23.3.1 +pip==24.0 +platformdirs==4.2.1 +pluggy==1.5.0 +ply==3.11 +portalocker==2.8.2 +pre-commit==3.7.0 +prometheus_client==0.20.0 +promise==2.3 +prompt-toolkit==3.0.43 +protobuf==3.20.3 +psutil==5.9.8 +ptyprocess==0.7.0 +pure-eval==0.2.2 +py-cpuinfo==9.0.0 +py-spy==0.3.14 +py==1.11.0 +pyOpenSSL==22.1.0 +pyarrow-hotfix==0.6 +pyarrow==16.0.0 +pybind11==2.12.0 +pycocoevalcap==1.2 +pycocotools==2.0.7 +pycparser==2.22 +pycryptodomex==3.20.0 +pydantic==1.10.8 +pydub==0.25.1 +pynvml==11.5.0 +pyparsing==3.1.2 +pytablewriter==1.2.0 +pytest==6.2.5 +python-consul==1.1.0 +python-dateutil==2.9.0.post0 +python-engineio==4.9.0 +python-etcd==0.4.5 +python-json-logger==2.0.7 +python-multipart==0.0.9 +python-socketio==5.11.2 +pytz==2024.1 +pyzmq==24.0.1 +qtconsole==5.5.1 +rapidfuzz==3.8.1 +referencing==0.35.0 +regex==2024.4.16 +requests==2.31.0 +responses==0.18.0 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==13.7.1 +rouge-score==0.1.2 +rpds-py==0.18.0 +sacrebleu==2.4.2 +safetensors==0.4.3 +schedule==1.2.1 +scikit-learn==1.2.2 +scipy==1.13.0 +semantic-version==2.10.0 +sentencepiece==0.1.99 +sentry-sdk==2.0.0 +setproctitle==1.3.3 +setuptools==68.2.2 +shortuuid==1.0.13 +shtab==1.7.1 +simple-websocket==1.0.0 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +soupsieve==2.5 +sqlitedict==2.1.0 +stack-data==0.6.3 +starlette==0.37.2 +svgwrite==1.4.3 +sympy==1.12 +tabledata==1.3.3 +tabulate==0.9.0 +tcolorpy==0.1.4 +tenacity==8.2.3 +terminado==0.18.1 +threadpoolctl==3.4.0 +thriftpy2==0.4.20 +tiktoken==0.6.0 +timm==0.9.16 +tinycss2==1.3.0 +tokenizers==0.15.2 +toml==0.10.2 +tomli==2.0.1 +toolz==0.12.1 +torch==2.1.2 +torchvision==0.16.2 +tornado==6.4 +tox==3.28.0 +tqdm-multiprocess==0.0.11 +tqdm==4.66.2 +traitlets==5.14.3 +transformers-stream-generator==0.0.5 +transformers==4.40.0.dev0 +triton==2.1.0 +typepy==1.3.2 +types-python-dateutil==2.9.0.20240316 +typing_extensions==4.11.0 +tyro==0.8.3 +tzdata==2024.1 +uc-micro-py==1.0.3 +uri-template==1.3.0 +urllib3==2.2.1 +uvicorn==0.29.0 +virtualenv==20.26.0 +wandb==0.16.5 +watchdog==4.0.0 +wavedrom==2.0.3.post3 +wcwidth==0.2.13 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.8.0 +websockets==12.0 +wheel==0.41.2 +widgetsnbextension==4.0.10 +wrapt==1.16.0 +wsproto==1.2.0 +xxhash==3.4.1 +y-py==0.6.2 +yarl==1.9.4 +ypy-websocket==0.8.4 +zipp==3.18.1 +zstandard==0.22.0 \ No newline at end of file diff --git a/scripts/archived/convert_gqa_for_eval.py b/scripts/archived/convert_gqa_for_eval.py new file mode 100755 index 000000000..2842975e7 --- /dev/null +++ b/scripts/archived/convert_gqa_for_eval.py @@ -0,0 +1,18 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--src", type=str) +parser.add_argument("--dst", type=str) +args = parser.parse_args() + +all_answers = [] +for line_idx, line in enumerate(open(args.src)): + res = json.loads(line) + question_id = res["question_id"] + text = res["text"].rstrip(".").lower() + all_answers.append({"questionId": question_id, "prediction": text}) + +with open(args.dst, "w") as f: + json.dump(all_answers, f) diff --git a/scripts/archived/convert_mmvet_for_eval.py b/scripts/archived/convert_mmvet_for_eval.py new file mode 100755 index 000000000..80ff31520 --- /dev/null +++ b/scripts/archived/convert_mmvet_for_eval.py @@ -0,0 +1,18 @@ +import os +import json +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--src", type=str) +parser.add_argument("--dst", type=str) +args = parser.parse_args() + +cur_result = {} + +for line in open(args.src): + data = json.loads(line) + qid = data["question_id"] + cur_result[f"v1_{qid}"] = data["text"] + +with open(args.dst, "w") as f: + json.dump(cur_result, f, indent=2) diff --git a/scripts/archived/convert_sqa_to_llava.py b/scripts/archived/convert_sqa_to_llava.py new file mode 100755 index 000000000..b0d242234 --- /dev/null +++ b/scripts/archived/convert_sqa_to_llava.py @@ -0,0 +1,88 @@ +import json +import os +import fire +import re +from convert_sqa_to_llava_base_prompt import build_prompt_chatbot + + +def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"): + split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] + problems = json.load(open(os.path.join(base_dir, "problems.json"))) + + split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False) + + target_format = [] + for prob_id, (input, output) in split_problems.items(): + if input.startswith("Question: "): + input = input.replace("Question: ", "") + if output.startswith("Answer: "): + output = output.replace("Answer: ", "") + + raw_prob_data = problems[prob_id] + if raw_prob_data["image"] is None: + target_format.append( + { + "id": prob_id, + "conversations": [ + {"from": "human", "value": f"{input}"}, + {"from": "gpt", "value": f"{output}"}, + ], + } + ) + + else: + target_format.append( + { + "id": prob_id, + "image": os.path.join(prob_id, raw_prob_data["image"]), + "conversations": [ + {"from": "human", "value": f"{input}\n"}, + {"from": "gpt", "value": f"{output}"}, + ], + } + ) + + print(f"Number of samples: {len(target_format)}") + + with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f: + json.dump(target_format, f, indent=2) + + +def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"): + split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] + problems = json.load(open(os.path.join(base_dir, "problems.json"))) + + split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False) + + writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w") + for prob_id, (input, output) in split_problems.items(): + if input.startswith("Question: "): + input = input.replace("Question: ", "") + if output.startswith("Answer: "): + output = output.replace("Answer: ", "") + + raw_prob_data = problems[prob_id] + if raw_prob_data["image"] is None: + data = { + "id": prob_id, + "instruction": f"{input}", + "output": f"{output}", + } + + else: + data = { + "id": prob_id, + "image": os.path.join(prob_id, raw_prob_data["image"]), + "instruction": f"{input}\n", + "output": f"{output}", + } + writer.write(json.dumps(data) + "\n") + writer.close() + + +def main(task, **kwargs): + globals()[task](**kwargs) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/scripts/archived/convert_sqa_to_llava_base_prompt.py b/scripts/archived/convert_sqa_to_llava_base_prompt.py new file mode 100755 index 000000000..9a457dc4b --- /dev/null +++ b/scripts/archived/convert_sqa_to_llava_base_prompt.py @@ -0,0 +1,298 @@ +def get_question_text(problem): + question = problem["question"] + return question + + +def get_context_text(problem, use_caption): + txt_context = problem["hint"] + img_context = problem["caption"] if use_caption else "" + context = " ".join([txt_context, img_context]).strip() + if context == "": + context = "N/A" + return context + + +def get_choice_text(probelm, options): + choices = probelm["choices"] + choice_list = [] + for i, c in enumerate(choices): + choice_list.append("({}) {}".format(options[i], c)) + choice_txt = " ".join(choice_list) + # print(choice_txt) + return choice_txt + + +def get_answer(problem, options): + return options[problem["answer"]] + + +def get_lecture_text(problem): + # \\n: GPT-3 can generate the lecture with more tokens. + lecture = problem["lecture"].replace("\n", "\\n") + return lecture + + +def get_solution_text(problem): + # \\n: GPT-3 can generate the solution with more tokens + solution = problem["solution"].replace("\n", "\\n") + return solution + + +def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True): + + input_format, output_format = format.split("-") + + ## Inputs + if input_format == "CQM": + input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n" + elif input_format == "QCM": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n" + # upper bound experiment + elif input_format == "QCML": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n" + elif input_format == "QCME": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n" + elif input_format == "QCMLE": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n" + + elif input_format == "QCLM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n" + elif input_format == "QCEM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n" + elif input_format == "QCLEM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n" + + # Outputs + if test_example: + output = "Answer:" + elif output_format == "A": + output = f"Answer: The answer is {answer}." + + elif output_format == "AL": + output = f"Answer: The answer is {answer}. BECAUSE: {solution}" + elif output_format == "AE": + output = f"Answer: The answer is {answer}. BECAUSE: {lecture}" + elif output_format == "ALE": + output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}" + elif output_format == "AEL": + output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}" + + elif output_format == "LA": + output = f"Answer: {lecture} The answer is {answer}." + elif output_format == "EA": + output = f"Answer: {solution} The answer is {answer}." + elif output_format == "LEA": + output = f"Answer: {lecture} {solution} The answer is {answer}." + elif output_format == "ELA": + output = f"Answer: {solution} {lecture} The answer is {answer}." + elif output_format == "LEPA": + output = "" + if len(lecture.strip()) > 0: + output += f"LECTURE: {lecture}\n" + if len(solution.strip()) > 0: + output += f"SOLUTION: {solution}\n" + output += "###\n" + output += f"ANSWER: {answer}." + + input = input.replace(" ", " ").strip() + output = output.replace(" ", " ").strip() + if input.endswith("BECAUSE:"): + input = input.replace("BECAUSE:", "").strip() + if output.endswith("BECAUSE:"): + output = output.replace("BECAUSE:", "").strip() + return input, output + + +def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True): + + input_format, output_format = format.split("-") + + ## Inputs + if input_format == "CQM": + input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n" + elif input_format == "QCM": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n" + # upper bound experiment + elif input_format == "QCML": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n" + elif input_format == "QCME": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n" + elif input_format == "QCMLE": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n" + + elif input_format == "QCLM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n" + elif input_format == "QCEM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n" + elif input_format == "QCLEM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n" + + # Outputs + if test_example: + output = "Answer:" + elif output_format == "A": + output = f"Answer: The answer is {answer}." + + elif output_format == "AL": + output = f"Answer: The answer is {answer}. BECAUSE: {solution}" + elif output_format == "AE": + output = f"Answer: The answer is {answer}. BECAUSE: {lecture}" + elif output_format == "ALE": + output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}" + elif output_format == "AEL": + output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}" + + elif output_format == "LA": + output = f"Answer: {lecture} The answer is {answer}." + elif output_format == "EA": + output = f"Answer: {solution} The answer is {answer}." + elif output_format == "LEA": + output = f"Answer: {lecture} {solution} The answer is {answer}." + elif output_format == "ELA": + output = f"Answer: {solution} {lecture} The answer is {answer}." + + text = input + output + text = text.replace(" ", " ").strip() + if text.endswith("BECAUSE:"): + text = text.replace("BECAUSE:", "").strip() + return text + + +def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True): + + input_format, output_format = format.split("-") + + ## Inputs + if input_format == "CQM": + input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n" + elif input_format == "QCM": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n" + # upper bound experiment + elif input_format == "QCML": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n" + elif input_format == "QCME": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n" + elif input_format == "QCMLE": + input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n" + + elif input_format == "QCLM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n" + elif input_format == "QCEM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n" + elif input_format == "QCLEM": + input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n" + + # Outputs + if test_example: + output = "Answer:" + elif output_format == "A": + output = f"Answer: The answer is {answer}." + + elif output_format == "AL": + output = f"Answer: The answer is {answer}. BECAUSE: {solution}" + elif output_format == "AE": + output = f"Answer: The answer is {answer}. BECAUSE: {lecture}" + elif output_format == "ALE": + output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}" + elif output_format == "AEL": + output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}" + + elif output_format == "LA": + output = f"Answer: {lecture} The answer is {answer}." + elif output_format == "EA": + output = f"Answer: {solution} The answer is {answer}." + elif output_format == "LEA": + output = f"Answer: {lecture} {solution} The answer is {answer}." + elif output_format == "ELA": + output = f"Answer: {solution} {lecture} The answer is {answer}." + + input = input.replace(" ", " ").strip() + output = output.replace(" ", " ").strip() + if output.endswith("BECAUSE:"): + output = output.replace("BECAUSE:", "").strip() + + user_prompt = {"role": "user", "content": f"Can you explain {input}?"} + assistant_prompt = {"role": "assistant", "content": f"{output}"} + + return user_prompt, assistant_prompt + + +def build_prompt_chatbot(problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False): + examples = {} + + for qid in shot_qids: + question = get_question_text(problems[qid]) + context = get_context_text(problems[qid], use_caption) + choice = get_choice_text(problems[qid], options) + answer = get_answer(problems[qid], options) + lecture = get_lecture_text(problems[qid]).replace("\\n", "\n") + solution = get_solution_text(problems[qid]).replace("\\n", "\n") + + train_example = create_one_example_chatbot(prompt_format, question, context, choice, answer, lecture, solution, test_example=is_test) + examples[qid] = train_example + return examples + + +def build_prompt(problems, shot_qids, test_qid, args): + + examples = [] + + # n-shot training examples + for qid in shot_qids: + question = get_question_text(problems[qid]) + context = get_context_text(problems[qid], args.use_caption) + choice = get_choice_text(problems[qid], args.options) + answer = get_answer(problems[qid], args.options) + lecture = get_lecture_text(problems[qid]) + solution = get_solution_text(problems[qid]) + + train_example = create_one_example(args.prompt_format, question, context, choice, answer, lecture, solution, test_example=False) + examples.append(train_example) + + # test example + question = get_question_text(problems[test_qid]) + context = get_context_text(problems[test_qid], args.use_caption) + choice = get_choice_text(problems[test_qid], args.options) + answer = get_answer(problems[test_qid], args.options) + lecture = get_lecture_text(problems[test_qid]) + solution = get_solution_text(problems[test_qid]) + + test_example = create_one_example(args.prompt_format, question, context, choice, answer, lecture, solution, test_example=True) + examples.append(test_example) + + # create the prompt input + prompt_input = "\n\n".join(examples) + + return prompt_input + + +def build_prompt_gpt4(problems, shot_qids, test_qid, args): + + prompt_array = [{"role": "system", "content": "You are a helpful assistant."}] + + # n-shot training examples + for qid in shot_qids: + question = get_question_text(problems[qid]) + context = get_context_text(problems[qid], args.use_caption) + choice = get_choice_text(problems[qid], args.options) + answer = get_answer(problems[qid], args.options) + lecture = get_lecture_text(problems[qid]) + solution = get_solution_text(problems[qid]) + + user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format, question, context, choice, answer, lecture, solution, test_example=False) + prompt_array.append(user_prompt) + prompt_array.append(assistant_prompt) + + # test example + question = get_question_text(problems[test_qid]) + context = get_context_text(problems[test_qid], args.use_caption) + choice = get_choice_text(problems[test_qid], args.options) + answer = get_answer(problems[test_qid], args.options) + lecture = get_lecture_text(problems[test_qid]) + solution = get_solution_text(problems[test_qid]) + + user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format, question, context, choice, answer, lecture, solution, test_example=True) + prompt_array.append(user_prompt) + prompt_array.append(assistant_prompt) + + return prompt_array diff --git a/scripts/archived/convert_vizwiz_for_submission.py b/scripts/archived/convert_vizwiz_for_submission.py new file mode 100755 index 000000000..80b825398 --- /dev/null +++ b/scripts/archived/convert_vizwiz_for_submission.py @@ -0,0 +1,45 @@ +import os +import argparse +import json + +from llava.eval.m4c_evaluator import EvalAIAnswerProcessor + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--annotation-file", type=str, required=True) + parser.add_argument("--result-file", type=str, required=True) + parser.add_argument("--result-upload-file", type=str, required=True) + return parser.parse_args() + + +if __name__ == "__main__": + + args = parse_args() + + os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True) + + results = [] + error_line = 0 + for line_idx, line in enumerate(open(args.result_file)): + try: + results.append(json.loads(line)) + except: + error_line += 1 + results = {x["question_id"]: x["text"] for x in results} + test_split = [json.loads(line) for line in open(args.annotation_file)] + split_ids = set([x["question_id"] for x in test_split]) + + print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}") + + all_answers = [] + + answer_processor = EvalAIAnswerProcessor() + + for x in test_split: + # import pdb; pdb.set_trace() + assert x["question_id"] in results, print(x) + all_answers.append({"image": x["image"], "answer": answer_processor(results[x["question_id"]])}) + + with open(args.result_upload_file, "w") as f: + json.dump(all_answers, f) diff --git a/scripts/archived/convert_vqav2_for_submission.py b/scripts/archived/convert_vqav2_for_submission.py new file mode 100755 index 000000000..f0a805795 --- /dev/null +++ b/scripts/archived/convert_vqav2_for_submission.py @@ -0,0 +1,50 @@ +import os +import argparse +import json + +from llava.eval.m4c_evaluator import EvalAIAnswerProcessor + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--dir", type=str, default="./playground/data/eval/vqav2") + parser.add_argument("--ckpt", type=str, required=True) + parser.add_argument("--split", type=str, required=True) + return parser.parse_args() + + +if __name__ == "__main__": + + args = parse_args() + + src = os.path.join(args.dir, "answers", args.split, args.ckpt, "merge.jsonl") + test_split = os.path.join(args.dir, "llava_vqav2_mscoco_test2015.jsonl") + dst = os.path.join(args.dir, "answers_upload", args.split, f"{args.ckpt}.json") + os.makedirs(os.path.dirname(dst), exist_ok=True) + + results = [] + error_line = 0 + for line_idx, line in enumerate(open(src)): + try: + results.append(json.loads(line)) + except: + error_line += 1 + + results = {x["question_id"]: x["text"] for x in results} + test_split = [json.loads(line) for line in open(test_split)] + split_ids = set([x["question_id"] for x in test_split]) + + print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}") + + all_answers = [] + + answer_processor = EvalAIAnswerProcessor() + + for x in test_split: + if x["question_id"] not in results: + all_answers.append({"question_id": x["question_id"], "answer": ""}) + else: + all_answers.append({"question_id": x["question_id"], "answer": answer_processor(results[x["question_id"]])}) + + with open(dst, "w") as f: + json.dump(all_answers, open(dst, "w")) diff --git a/scripts/archived/data_info.py b/scripts/archived/data_info.py new file mode 100755 index 000000000..d4aa7a5ba --- /dev/null +++ b/scripts/archived/data_info.py @@ -0,0 +1,124 @@ +import json +import os +from PIL import Image +from tqdm import tqdm +import matplotlib.pyplot as plt +import numpy as np + + +def load_data(json_path): + with open(json_path, "r") as f: + return json.load(f) + + +def filter_data(data): + # filtered_data = [item for item in data if "image" in item and "text" in item["image"]] + filtered_data = [item for item in data if "image" in item] + return filtered_data + + +from multiprocessing import Pool +import functools + + +def calculate_image_dimension(item, images_folder): + image_path = os.path.join(images_folder, item["image"]) + try: + with Image.open(image_path) as img: + width, height = img.size + return width, height + except Exception as e: + print(f"Error opening {image_path}: {e}") + return None, None + + +def calculate_image_dimensions_multiprocess(filtered_data, images_folder, num_processes=256): + with Pool(num_processes) as p: + dimensions = list(tqdm(p.imap(functools.partial(calculate_image_dimension, images_folder=images_folder), filtered_data), total=len(filtered_data), desc="Calculating image dimensions")) + widths, heights = zip(*[dim for dim in dimensions if dim[0] is not None]) + return list(widths), list(heights) + + +def tokenize(text): + return text.split() + + +def calculate_tokenized_lengths(data): + lengths = [] + for item in tqdm(data, desc="Tokenizing conversations"): + for conversation in item["conversations"]: + tokenized_value = tokenize(conversation["value"]) + lengths.append(len(tokenized_value)) + return lengths + + +import argparse + + +def main(): + parser = argparse.ArgumentParser(description="Process data for LLaVA_Next project.") + parser.add_argument("--json_path", type=str, help="Path to the JSON file containing data.") + parser.add_argument("--images_folder", type=str, default="/mnt/bn/vl-research/data/llava_data", help="Path to the folder containing images.") + args = parser.parse_args() + + llava_instruct_name = args.json_path.split("/")[-1].replace(".json", "") + json_path = args.json_path + llava_instruct_name = os.path.basename(json_path).replace(".json", "") + images_folder = args.images_folder + + data = load_data(json_path) + filtered_data = filter_data(data) + + if len(filtered_data) != 0: + print(f"Total data items: {len(data)}, Filtered data items: {len(filtered_data)}") + widths, heights = calculate_image_dimensions_multiprocess(filtered_data, images_folder) + max_width = max(widths) + max_height = max(heights) + print(f"Max width: {max_width}, Max height: {max_height}") + + tokenized_lengths = calculate_tokenized_lengths(data) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12)) + + if len(filtered_data) != 0: + # Plot 2D histogram + if min(widths) == max(widths): + widths_bins = [min(widths), max(widths) + 1] + else: + widths_bins = np.arange(min(widths), max(widths) + 100, 100) + + if min(heights) == max(heights): + heights_bins = [min(heights), max(heights) + 1] + else: + heights_bins = np.arange(min(heights), max(heights) + 100, 100) + + h, xedges, yedges, image = ax1.hist2d(widths, heights, bins=[widths_bins, heights_bins], cmap=plt.cm.jet, density=True) + fig.colorbar(image, ax=ax1) + ax1.set_xlabel("Width") + ax1.set_ylabel("Height") + ax1.set_title(f"dist_{llava_instruct_name}_2d_w_h\nMax width: {max(widths)}, Max height: {max(heights)}", fontsize=10) + + # Plot histogram + hist, bin_edges = np.histogram(tokenized_lengths, bins=np.arange(0, max(tokenized_lengths) + 10, 100)) + bins = np.arange(0, max(tokenized_lengths) + 10, 100) + ax2.bar(bin_edges[:-1], hist, width=7, edgecolor="black", log=True) + + # Display every nth label on the x-axis + n = 8 # Adjust this value to control the number of labels displayed + ticks = bins[::n] + tick_labels = [int(tick) for tick in ticks] + ax2.set_xticks(ticks) + ax2.set_xticklabels(tick_labels, rotation=90, fontsize=8) + + ax2.set_xlim(min(bin_edges), max(bin_edges)) + ax2.set_xlabel("Tokenized Length") + ax2.set_ylabel("Count (log scale)") + ax2.set_title(f"dist_{llava_instruct_name}_tokenized_length", fontsize=8) + + plt.tight_layout() + plt.savefig(f"/mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/notebooks/sft_data/dist_{llava_instruct_name}_combined.png") + print(f"Plots saved to /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/notebooks/sft_data/dist_{llava_instruct_name}_combined.png") + + +if __name__ == "__main__": + main() diff --git a/scripts/archived/dpo_data_info.py b/scripts/archived/dpo_data_info.py new file mode 100755 index 000000000..ea60ec192 --- /dev/null +++ b/scripts/archived/dpo_data_info.py @@ -0,0 +1,67 @@ +import json +import re + +json_path = "/mnt/bn/vl-research/workspace/boli01/projects/sft_data_workspace/vlfeedback_80k.jsonl" + +with open(json_path, "r") as f: + data = f.readlines() + +data = [json.loads(d) for d in data] + + +def convert_format(original_data, dimension="Visual Faithfulness"): + converted_data = [] + for item in original_data: + # Assuming the best response is the one with the highest helpfulness rating + best_completion = max(item["completions"], key=lambda x: int(x["annotations"]["Helpfulness"]["Rating"])) + best_response = best_completion["response"] + best_model = best_completion["model"] + + if "†source" in best_response: + print(best_response) + # Regex pattern to match the pattern 【digit†source】 + pattern = r"【\d+†source】" + # Replace the matched patterns with an empty string + cleaned_text = re.sub(pattern, "", best_response) + best_response = cleaned_text + print(f"*****************************************") + print(best_response) + + # Assuming the worst response is the one with the lowest helpfulness rating + worst_completion = min(item["completions"], key=lambda x: int(x["annotations"]["Helpfulness"]["Rating"])) + worst_response = worst_completion["response"] + + if "†source" in worst_response: + print(worst_response) + # Regex pattern to match the pattern ��digit†source】 + pattern = r"【\d+†source】" + # Replace the matched patterns with an empty string + cleaned_text = re.sub(pattern, "", worst_response) + worst_response = cleaned_text + print(f"*****************************************") + print(worst_response) + + # Extract scores + best_score = int(best_completion["annotations"][dimension]["Rating"]) + worst_score = int(worst_completion["annotations"][dimension]["Rating"]) + + # Construct the new format + new_item = { + "id": item["id"], + "prompt": item["prompt"], + "answer": "", + "image": f"silkie_dpo/{item['id']}.jpg", # Assuming the video ID is the last part of the original ID + "chosen": best_response, + "rejected": worst_response, + "chosen_score": best_score, + "rejected_score": worst_score, + } + converted_data.append(new_item) + + return converted_data + + +for dimension in ["Visual Faithfulness", "Helpfulness", "Ethical Considerations"]: + converted_data = convert_format(data, dimension=dimension) + with open(f"/mnt/bn/vl-research/data/llava_instruct/dpo_data/silkie_dpo_data_{dimension.replace(' ', '_').lower()}_{len(converted_data)}.json", "w") as f: + json.dump(converted_data, f, indent=4) diff --git a/scripts/archived/entry_cmd.sh b/scripts/archived/entry_cmd.sh new file mode 100755 index 000000000..52a38170c --- /dev/null +++ b/scripts/archived/entry_cmd.sh @@ -0,0 +1,30 @@ +python3 -m pip install --upgrade pip; + +export http_proxy=http://sys-proxy-rd-relay.byted.org:8118; +export https_proxy=http://sys-proxy-rd-relay.byted.org:8118; + +export HF_HOME=/mnt/bn/vl-research-boli01-cn/.cache/huggingface; +export HF_TOKEN="HF_Token"; +export HF_HUB_ENABLE_HF_TRANSFER="1"; + +cd /mnt/bn/vl-research-boli01-cn/projects/zzz/lmms-eval; +pip install -e .; + +cd /mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next; +pip install -e .; + +python3 -m pip install ninja; +python3 -m pip install flash-attn --no-build-isolation; + +bash /mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next/cn_scripts/vicuna/internal0.6m_finetune_llava1.6mix_7b_v0.2_unfreeze.sh + + +accelerate launch --num_processes 8 --main_process_port 12345 -m lmms_eval \ + --model llava \ + --model_args pretrained="/mnt/bn/vl-research-boli01-cn/projects/zzz/LLaVA_Next/internal_project_checkpoints/llavanext-lmsys_vicuna-7b-v1.5-clip-vit-large-patch14-336-mlp2x_gelu-pretrain_internal0.6m_vicuna_v1_finetune_llava1.6_datamix_unfreezeVIS_1e" \ + --tasks ok_vqa,textcaps_val,mme_test,mmmu,cmmmu,coco2017_cap_val,vizwiz_vqa_val,ai2d,chartqa,pope \ + --batch_size 1 \ + --log_samples \ + --log_samples_suffix debug \ + --output_path ./logs/ \ + --wandb_args 'project=llava-next-lmms-eval,job_type=eval'; \ No newline at end of file diff --git a/scripts/archived/finetune.sh b/scripts/archived/finetune.sh new file mode 100755 index 000000000..9d1503682 --- /dev/null +++ b/scripts/archived/finetune.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA + +# Install yolk3k if not installed +if ! pip show yolk3k > /dev/null 2>&1; then + pip install yolk3k +fi + +# Get the installed version of transformers +installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) + +# Get the latest version of transformers from PyPI +latest_version=$(yolk -V transformers | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "$latest_version" ]; then + pip install -U transformers +fi + +# Get the installed version of deepspeed +installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) + +# Get the latest version of deepspeed from PyPI +latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) + +# Check if the installed version is not the latest + # pip install deepspeed==0.12.2 +if [ "$installed_version" != "$latest_version" ]; then + pip install deepspeed==0.12.2 +fi + +# Install flash-attn if not installed +if ! pip show flash-attn > /dev/null 2>&1; then + pip install flash-attn --no-build-isolation +fi + +################## VICUNA ################## +PROMPT_VERSION=v1 +MODEL_VERSION="vicuna-7b-v1-5" +################## VICUNA ################## + + +################## project ################## +PROJECT_NAME="ds_llava-vicuna-7b-v1-5-mlp2x_gelu-pretrain_blip558k_plain" + +################## data ################## +DATA_NAME="mixtral_instruct_158K_V1" + +# wandb configure +export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" +wandb login $WANDB_API_KEY + +export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME + +export WANDB_PROJECT=LLaVA_Mixtral + +export WANDB_MODE=online + +# wandb online + +deepspeed --master_port 26000 \ + llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/$DATA_NAME.json \ + --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data/coco/train2017 \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_projector_type mlp2x_gelu \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/archived/finetune_1.5.sh b/scripts/archived/finetune_1.5.sh new file mode 100755 index 000000000..13cb818b7 --- /dev/null +++ b/scripts/archived/finetune_1.5.sh @@ -0,0 +1,99 @@ +#!/bin/bash +dataset_name=$1 + +# Uncomment and set the following variables correspondingly to run this script: + +cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA + +# Install yolk3k if not installed +if ! pip show yolk3k > /dev/null 2>&1; then + pip install yolk3k +fi + +# Get the installed version of transformers +installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) + +# Get the latest version of transformers from PyPI +latest_version=$(yolk -V transformers | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "$latest_version" ]; then + pip install -U transformers +fi + +# Get the installed version of deepspeed +installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) + +# Get the latest version of deepspeed from PyPI +latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "$latest_version" ]; then + pip install deepspeed==0.12.2 +fi + +# Install yolk3k if not installed +if ! pip show flash-attn > /dev/null 2>&1; then + pip install flash-attn --no-build-isolation +fi + + +################## VICUNA ################## +PROMPT_VERSION=v1 +MODEL_VERSION="vicuna-7b-v1-5" +################## VICUNA ################## + +################## project ################## +PROJECT_NAME="ds_llava-vicuna-7b-v1-5-mlp2x_gelu-pretrain_blip558k_plain" + +################## data ################## +DATA_NAME=$dataset_name + + +# wandb configure +export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" +wandb login $WANDB_API_KEY + +export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME + +export WANDB_PROJECT=LLaVA_Mixtral + +export WANDB_MODE=online + +wandb online + + +deepspeed --master_port 26000 \ + llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/$DATA_NAME.json \ + --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_projector_type mlp2x_gelu \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/archived/finetune_full_schedule.sh b/scripts/archived/finetune_full_schedule.sh new file mode 100755 index 000000000..9769666db --- /dev/null +++ b/scripts/archived/finetune_full_schedule.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# Uncomment and set the following variables correspondingly to run this script: + +################## VICUNA ################## +# PROMPT_VERSION=v1 +# MODEL_VERSION="vicuna-v1-3-7b" +################## VICUNA ################## + +################## LLaMA-2 ################## +# PROMPT_VERSION="llava_llama_2" +# MODEL_VERSION="llama-2-7b-chat" +################## LLaMA-2 ################## + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/llava_instruct_158k.json \ + --image_folder /path/to/coco/train2017 \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \ + --num_train_epochs 3 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/archived/finetune_lora.sh b/scripts/archived/finetune_lora.sh new file mode 100755 index 000000000..5ff90c024 --- /dev/null +++ b/scripts/archived/finetune_lora.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Uncomment and set the following variables correspondingly to run this script: + +################## VICUNA ################## +# PROMPT_VERSION=v1 +# MODEL_VERSION="vicuna-v1-3-7b" +################## VICUNA ################## + +################## LLaMA-2 ################## +# PROMPT_VERSION="llava_llama_2" +# MODEL_VERSION="llama-2-7b-chat" +################## LLaMA-2 ################## + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --lora_enable True \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/llava_instruct_80k.json \ + --image_folder /path/to/coco/train2017 \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True \ + --dataloader_num_workers 16 \ + --report_to wandb diff --git a/scripts/archived/finetune_mixtral.sh b/scripts/archived/finetune_mixtral.sh new file mode 100755 index 000000000..f7ca38361 --- /dev/null +++ b/scripts/archived/finetune_mixtral.sh @@ -0,0 +1,96 @@ +#!/bin/bash + +cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA + +# Install yolk3k if not installed +if ! pip show yolk3k > /dev/null 2>&1; then + pip install yolk3k +fi + +# Get the installed version of transformers +installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) + +# Get the latest version of transformers from PyPI +latest_version=$(yolk -V transformers | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "$latest_version" ]; then + pip install -U transformers +fi + +# Get the installed version of deepspeed +installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) + +# Get the latest version of deepspeed from PyPI +latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "$latest_version" ]; then + pip install deepspeed==0.12.2 +fi + +# Install yolk3k if not installed +if ! pip show flash-attn > /dev/null 2>&1; then + pip install flash-attn --no-build-isolation +fi + + +################## MISTRAL ################## +PROMPT_VERSION=mistral_instruct +MODEL_VERSION="Mistral-7B-Instruct-v0.2" +################## VICUNA ################## + + +################## project ################## +PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-mlp2x_gelu-pretrain_blip558k_plain" + +################## data ################## +DATA_NAME="mixtral_instruct_158K_V1" + +# wandb configure +export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" +wandb login $WANDB_API_KEY + +export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME + +export WANDB_PROJECT=LLaVA_Mixtral + +export WANDB_MODE=online + +wandb online + + +deepspeed --master_port 26000 \ + llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/$DATA_NAME.json \ + --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data/coco/train2017 \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_projector_type mlp2x_gelu \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/archived/finetune_mixtral_1.5.sh b/scripts/archived/finetune_mixtral_1.5.sh new file mode 100755 index 000000000..71f4ed8ab --- /dev/null +++ b/scripts/archived/finetune_mixtral_1.5.sh @@ -0,0 +1,96 @@ +#!/bin/bash +dataset_name=$1 + +cd /mnt/bn/vl-research/workspace/yhzhang/LLaVA + +# Install yolk3k if not installed +if ! pip show yolk3k > /dev/null 2>&1; then + pip install yolk3k +fi + +# Get the installed version of transformers +installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) + +# Get the latest version of transformers from PyPI +latest_version=$(yolk -V transformers | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "$latest_version" ]; then + pip install -U transformers +fi + +# Get the installed version of deepspeed +installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) + +# Get the latest version of deepspeed from PyPI +# latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "0.12.2" ]; then + pip install deepspeed==0.12.2 +fi + +# Install yolk3k if not installed +if ! pip show flash-attn > /dev/null 2>&1; then + pip install flash-attn --no-build-isolation +fi + +################## MISTRAL ################## +PROMPT_VERSION=mistral_instruct +MODEL_VERSION="Mistral-7B-Instruct-v0.2" +################## MISTRAL ################## + + +################## project ################## +PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-mlp2x_gelu-pretrain_blip558k_plain" + +################## data ################## +DATA_NAME=$dataset_name + + +# wandb configure +export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" +wandb login $WANDB_API_KEY + +export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME + +export WANDB_PROJECT=LLaVA_Mixtral + +export WANDB_MODE=online + +wandb online + +deepspeed --master_port 26000 \ + llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/$DATA_NAME.json \ + --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_projector_type mlp2x_gelu \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True + # --report_to wandb diff --git a/scripts/archived/finetune_mixtral_1.6_336px_anyres.sh b/scripts/archived/finetune_mixtral_1.6_336px_anyres.sh new file mode 100755 index 000000000..28b2e0b25 --- /dev/null +++ b/scripts/archived/finetune_mixtral_1.6_336px_anyres.sh @@ -0,0 +1,100 @@ +#!/bin/bash +dataset_name=$1 + +cd /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next + +# Install yolk3k if not installed +if ! pip show yolk3k > /dev/null 2>&1; then + pip install yolk3k +fi + +pip install pydantic + +# Get the installed version of transformers +installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) + +# Get the latest version of transformers from PyPI +latest_version=$(yolk -V transformers | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "4.36.2" ]; then + pip install transformers==4.36.2 +fi + +# Get the installed version of deepspeed +installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) + + +# Check if the installed version is not the latest +if [ "$installed_version" != "0.12.2" ]; then + pip install deepspeed==0.12.2 +fi + +# Install flash-atten if not installed +if ! pip show flash-attn > /dev/null 2>&1; then + pip install flash-attn --no-build-isolation +fi + +################## MISTRAL ################## +PROMPT_VERSION=mistral_instruct +MODEL_VERSION="Mistral-7B-Instruct-v0.2" +################## MISTRAL ################## + + +################## project ################## +PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain" + +################## data ################## +DATA_NAME=$dataset_name + + +# wandb configure +export WANDB_API_KEY=e464cc107357c7b38e87f239bc3eb2ce5fb73c7c +export WANDB_PROJECT=llava + +export WANDB_NAME=$PROJECT_NAME--$DATA_NAME--336px--anyres--sft + +export WANDB_MODE=online + +wandb online + +deepspeed --master_port 26000 \ + llava/train/train_mem.py \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path /mnt/bn/vl-research/workspace/project/2023/LLaVA/checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/$DATA_NAME.json \ + --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --pretrain_mm_mlp_adapter /mnt/bn/vl-research/workspace/project/2023/LLaVA/checkpoints/ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain/mm_projector.bin \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --unfreeze_mm_vision_tower True \ + --mm_vision_tower_lr 2e-6 \ + --image_aspect_ratio anyres \ + --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --output_dir ./checkpoints/$PROJECT_NAME--$DATA_NAME--336px--anyres--sft \ + --num_train_epochs 9 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "epoch" \ + --save_steps 1500 \ + --learning_rate 5e-6 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 4096 \ + --gradient_checkpointing True \ + --dataloader_num_workers 8 \ + --lazy_preprocess True \ + --report_to wandb + diff --git a/scripts/archived/finetune_mixtral_1.6_336px_anyres_freeze_vision.sh b/scripts/archived/finetune_mixtral_1.6_336px_anyres_freeze_vision.sh new file mode 100755 index 000000000..e04f4871f --- /dev/null +++ b/scripts/archived/finetune_mixtral_1.6_336px_anyres_freeze_vision.sh @@ -0,0 +1,97 @@ +#!/bin/bash +dataset_name=$1 + +cd /mnt/bn/vl-research/workspace/yhzhang/LLaVA + +# Install yolk3k if not installed +if ! pip show yolk3k > /dev/null 2>&1; then + pip install yolk3k +fi + +pip install pydantic + +# Get the installed version of transformers +installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) + +# Get the latest version of transformers from PyPI +latest_version=$(yolk -V transformers | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "4.36.2" ]; then + pip install transformers==4.36.2 +fi + +# Get the installed version of deepspeed +installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) + + +# Check if the installed version is not the latest +if [ "$installed_version" != "0.12.2" ]; then + pip install deepspeed==0.12.2 +fi + +# Install flash-atten if not installed +if ! pip show flash-attn > /dev/null 2>&1; then + pip install flash-attn --no-build-isolation +fi + +################## MISTRAL ################## +PROMPT_VERSION=mistral_instruct +MODEL_VERSION="Mistral-7B-Instruct-v0.2" +################## MISTRAL ################## + + +################## project ################## +PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain" + +################## data ################## +DATA_NAME=$dataset_name + + +# wandb configure +export WANDB_API_KEY=e464cc107357c7b38e87f239bc3eb2ce5fb73c7c +export WANDB_PROJECT=llava + +export WANDB_NAME=$PROJECT_NAME--$DATA_NAME--336px--unfreeze--anyres--sft + +export WANDB_MODE=online + +wandb online + +deepspeed --master_port 26000 \ + llava/train/train_mem.py \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/$DATA_NAME.json \ + --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --pretrain_mm_mlp_adapter /mnt/bn/vl-research/workspace/project/2023/LLaVA/checkpoints/ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_projector_type mlp2x_gelu \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres \ + --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --output_dir ./checkpoints/$PROJECT_NAME--$DATA_NAME--336px--anyres--unfreeze--sft \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + diff --git a/scripts/archived/finetune_mixtral_1.6_336px_anyres_lmms_eval.sh b/scripts/archived/finetune_mixtral_1.6_336px_anyres_lmms_eval.sh new file mode 100755 index 000000000..c31fdc53f --- /dev/null +++ b/scripts/archived/finetune_mixtral_1.6_336px_anyres_lmms_eval.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +# set up wandb +export WANDB_API_KEY=a651c244635bc6f913ab654af3f0eebaecdc9381 +export WANDB_ENTITY=llava-vl +export WANDB_PROJECT=llava-next +export PYTHONWARNINGS="ignore" + +cd /mnt/bn/vl-research/workspace/boli01/projects/lmms-eval + +pip install -e . + +# set up llava dev env +cd /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next + +################## MISTRAL ################## +PROMPT_VERSION=mistral_instruct +MODEL_VERSION="Mistral-7B-Instruct-v0.2" +################## MISTRAL ################## + +################## project ################## +PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-clip_large_336px-mlp2x_gelu-pretrain_blip558k_plain" + +################## data ################## +DATA_NAME='llava_caps20k_chartqa19k' + +export WANDB_NAME=$PROJECT_NAME--$DATA_NAME--336px--anyres--sft +export WANDB_MODE=online + +wandb online + +CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" deepspeed --master_port 26000 --include localhost:0,1,2,3,4,5,6,7 llava/train/train_mem.py \ + --deepspeed ./scripts/zero3_offload.json \ + --model_name_or_path mistralai/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/llava_instruct/$DATA_NAME.json \ + --image_folder /mnt/bn/vl-research/data/llava \ + --vision_tower openai/clip-vit-large-patch14-336 \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --unfreeze_mm_vision_tower True \ + --mm_vision_tower_lr 2e-6 \ + --image_aspect_ratio anyres \ + --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --output_dir ./checkpoints/$PROJECT_NAME--llava1.6--336px--anyres--sft \ + --num_train_epochs 1 \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1500 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 4096 \ + --gradient_checkpointing True \ + --dataloader_num_workers 32 \ + --lazy_preprocess True \ + --report_to wandb \ + --run_name $WANDB_NAME +# starting here is the args for evaluation + --eval_num_processes 4 \ + --task_names mme,docvqa_val \ + --model_args pretrained=./checkpoints/$PROJECT_NAME--$DATA_NAME--336px--anyres--sft \ + --limit 8 \ + --batch_size 1 \ + --log_samples \ + --log_samples_suffix debug \ + --output_path ./logs/ diff --git a/scripts/archived/finetune_mixtral_copy.sh b/scripts/archived/finetune_mixtral_copy.sh new file mode 100755 index 000000000..23cab89db --- /dev/null +++ b/scripts/archived/finetune_mixtral_copy.sh @@ -0,0 +1,96 @@ +#!/bin/bash + +cd /mnt/bn/vl-research/workspace/boli01/zzzprojects/LLaVA + +# Install yolk3k if not installed +if ! pip show yolk3k > /dev/null 2>&1; then + pip install yolk3k +fi + +# Get the installed version of transformers +installed_version=$(pip show transformers | grep Version | cut -d ' ' -f 2) + +# Get the latest version of transformers from PyPI +latest_version=$(yolk -V transformers | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "$latest_version" ]; then + pip install -U transformers +fi + +# Get the installed version of deepspeed +installed_version=$(pip show deepspeed | grep Version | cut -d ' ' -f 2) + +# Get the latest version of deepspeed from PyPI +latest_version=$(yolk -V deepspeed | cut -d ' ' -f 2) + +# Check if the installed version is not the latest +if [ "$installed_version" != "$latest_version" ]; then + pip install deepspeed==0.12.2 +fi + +# Install yolk3k if not installed +if ! pip show flash-attn > /dev/null 2>&1; then + pip install flash-attn --no-build-isolation +fi + + +################## MISTRAL ################## +PROMPT_VERSION=mistral_instruct +MODEL_VERSION="Mistral-7B-Instruct-v0.2" +################## VICUNA ################## + + +################## project ################## +PROJECT_NAME="ds_llava-Mistral-7B-Instruct-v0.2-mlp2x_gelu-pretrain_blip558k_plain" + +################## data ################## +DATA_NAME="llava_instruct_150k" + +# wandb configure +export WANDB_API_KEY="03fc62d68025c9498cf6493432551badd7d4f953" +wandb login $WANDB_API_KEY + +export WANDB_NAME=$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME + +export WANDB_PROJECT=LLaVA_Mixtral + +export WANDB_MODE=online + +wandb online + + +deepspeed --master_port 26000 \ + llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/$DATA_NAME.json \ + --image_folder /mnt/bn/vl-research/workspace/boli01/data/playground/data/coco/train2017 \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/$PROJECT_NAME/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_projector_type mlp2x_gelu \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava--$PROJECT_NAME--$MODEL_VERSION--$DATA_NAME--finetune \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/archived/finetune_qlora.sh b/scripts/archived/finetune_qlora.sh new file mode 100755 index 000000000..05be85628 --- /dev/null +++ b/scripts/archived/finetune_qlora.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Uncomment and set the following variables correspondingly to run this script: + +################## VICUNA ################## +# PROMPT_VERSION=v1 +# MODEL_VERSION="vicuna-v1-3-7b" +################## VICUNA ################## + +################## LLaMA-2 ################## +# PROMPT_VERSION="llava_llama_2" +# MODEL_VERSION="llama-2-7b-chat" +################## LLaMA-2 ################## + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --lora_enable True \ + --bits 4 \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path ./playground/data/llava_instruct_80k.json \ + --image_folder /path/to/coco/train2017 \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --lazy_preprocess True \ + --dataloader_num_workers 16 \ + --report_to wandb diff --git a/scripts/archived/finetune_sqa.sh b/scripts/archived/finetune_sqa.sh new file mode 100755 index 000000000..ac1359ce9 --- /dev/null +++ b/scripts/archived/finetune_sqa.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path lmsys/vicuna-13b-v1.3 \ + --version $PROMPT_VERSION \ + --data_path /Data/ScienceQA/data/scienceqa/llava_train_QCM-LEA.json \ + --image_folder /Data/ScienceQA/data/scienceqa/images/train \ + --vision_tower openai/clip-vit-large-patch14 \ + --pretrain_mm_mlp_adapter ./checkpoints/huggingface/liuhaotian/llava-pretrain-vicuna-13b-v1.3/mm_projector.bin \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava-vicuna-13b-v1.3-pretrain_lcs558k_plain-ScienceQA_QCM_LEA-12e \ + --num_train_epochs 12 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 50000 \ + --save_total_limit 1 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/archived/merge_lora_weights.py b/scripts/archived/merge_lora_weights.py new file mode 100755 index 000000000..90b4aa0f9 --- /dev/null +++ b/scripts/archived/merge_lora_weights.py @@ -0,0 +1,22 @@ +import argparse +from llava.model.builder import load_pretrained_model +from llava.mm_utils import get_model_name_from_path + + +def merge_lora(args): + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map="cpu") + + model.save_pretrained(args.save_model_path) + tokenizer.save_pretrained(args.save_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--model-base", type=str, required=True) + parser.add_argument("--save-model-path", type=str, required=True) + + args = parser.parse_args() + + merge_lora(args) diff --git a/scripts/archived/pretrain.sh b/scripts/archived/pretrain.sh new file mode 100755 index 000000000..f3cf7d94e --- /dev/null +++ b/scripts/archived/pretrain.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# Uncomment and set the following variables correspondingly to run this script: + +# MODEL_VERSION=vicuna-v1-3-7b +# MODEL_VERSION=llama-2-7b-chat + +########### DO NOT CHANGE ########### +########### USE THIS FOR BOTH ########### +PROMPT_VERSION=plain +########### DO NOT CHANGE ########### + +deepspeed llava/train/train_mem.py \ + --deepspeed ./scripts/zero2.json \ + --model_name_or_path ./checkpoints/$MODEL_VERSION \ + --version $PROMPT_VERSION \ + --data_path /path/to/pretrain_data.json \ + --image_folder /path/to/images \ + --vision_tower openai/clip-vit-large-patch14 \ + --tune_mm_mlp_adapter True \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 24000 \ + --learning_rate 2e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb diff --git a/scripts/archived/quick_check.py b/scripts/archived/quick_check.py new file mode 100755 index 000000000..ca7cff9de --- /dev/null +++ b/scripts/archived/quick_check.py @@ -0,0 +1,57 @@ +import json +import os +import argparse +from tqdm import tqdm +import yaml + + +def check_missing_images(json_path, images_folder): + data = json.load(open(json_path, "r")) + missing_data = [] + + for i, d in enumerate(tqdm(data)): + image = d["image"] if "image" in d else "" + if image != "": + path = os.path.join(images_folder, image) + if not os.path.exists(path): + print(f"Missing image: {path}") + missing_data.append(d) + + return missing_data + + +def read_yaml_to_llava_data(yaml_path, images_folder): + print(f"Reading YAML file: {yaml_path}") + with open(yaml_path, "r") as f: + data = yaml.safe_load(f) + + llava_json_paths = data["datasets"] + for item in llava_json_paths: + json_path = item["json_path"] + missing_data = check_missing_images(json_path, images_folder) + if len(missing_data) > 0: + print(f"Missing images in {json_path}:") + for d in missing_data: + print(d) + + +def direct_check_llava_data(json_path, images_folder): + missing_data = check_missing_images(json_path, images_folder) + if len(missing_data) > 0: + print(f"Missing images in {json_path}:") + for d in missing_data: + print(d) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Check for missing images in dataset.") + parser.add_argument("--yaml_path", type=str, default="", help="Path to the YAML file containing the dataset.") + parser.add_argument("--json_path", type=str, default="", help="Path to the JSON file containing the dataset.") + parser.add_argument("--images_folder", type=str, default="/mnt/bn/vl-research/data/llava_data", help="Path to the folder containing the images.") + + args = parser.parse_args() + + if args.json_path != "": + direct_check_llava_data(args.json_path, args.images_folder) + elif args.yaml_path != "": + read_yaml_to_llava_data(args.yaml_path, args.images_folder) diff --git a/scripts/archived/sqa_eval_batch.sh b/scripts/archived/sqa_eval_batch.sh new file mode 100755 index 000000000..adbf46ef7 --- /dev/null +++ b/scripts/archived/sqa_eval_batch.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +CHUNKS=8 +for IDX in {0..7}; do + CUDA_VISIBLE_DEVICES=$IDX python -m llava.eval.model_vqa_science \ + --model-path liuhaotian/llava-lcs558k-scienceqa-vicuna-13b-v1.3 \ + --question-file ~/haotian/datasets/ScienceQA/data/scienceqa/llava_test_QCM-LEA.json \ + --image-folder ~/haotian/datasets/ScienceQA/data/scienceqa/images/test \ + --answers-file ./test_llava-13b-chunk$CHUNKS_$IDX.jsonl \ + --num-chunks $CHUNKS \ + --chunk-idx $IDX \ + --conv-mode llava_v1 & +done diff --git a/scripts/archived/sqa_eval_gather.sh b/scripts/archived/sqa_eval_gather.sh new file mode 100755 index 000000000..525bd43b8 --- /dev/null +++ b/scripts/archived/sqa_eval_gather.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +CHUNKS=8 +output_file="test_llava-13b.jsonl" + +# Clear out the output file if it exists. +> "$output_file" + +# Loop through the indices and concatenate each file. +for idx in $(seq 0 $((CHUNKS-1))); do + cat "./test_llava-13b-chunk${idx}.jsonl" >> "$output_file" +done + +python llava/eval/eval_science_qa.py \ + --base-dir ~/haotian/datasets/ScienceQA/data/scienceqa \ + --result-file ./test_llava-13b.jsonl \ + --output-file ./test_llava-13b_output.json \ + --output-result ./test_llava-13b_result.json diff --git a/scripts/interleave/eval_all.sh b/scripts/interleave/eval_all.sh new file mode 100644 index 000000000..ee04c6ab8 --- /dev/null +++ b/scripts/interleave/eval_all.sh @@ -0,0 +1,5 @@ + +# evaluate +./scripts/interleave/eval_interleave_3d.sh /path/to/ckpt /path/to/images multi_image_in_domain +./scripts/interleave/eval_interleave_3d.sh /path/to/ckpt /path/to/images multi_image_out_domain +./scripts/interleave/eval_interleave_3d.sh /path/to/ckpt /path/to/images multi_view_in_domain \ No newline at end of file diff --git a/scripts/interleave/eval_interleave_3d.sh b/scripts/interleave/eval_interleave_3d.sh new file mode 100755 index 000000000..cea796cd7 --- /dev/null +++ b/scripts/interleave/eval_interleave_3d.sh @@ -0,0 +1,28 @@ +alias python=python3 +CKPT_PATH=$1 +NAME=$(echo "$CKPT_PATH" | awk -F'/' '{print $NF}') +echo $NAME +##### set images path +DATA_PATH=$2 +EVAL_TYPE=$3 +JSON_PATH=$2/$3.json +############################### eval multi-image +RESULT_NAME="logs/${NAME}/${EVAL_TYPE}" +echo $RESULT_NAME + +mkdir -p logs/${NAME} + +file_path=${RESULT_NAME}/result.jsonl + +bash scripts/interleave/eval_multiprocess.sh \ +${CKPT_PATH} \ +${JSON_PATH} \ +${RESULT_NAME} \ +${DATA_PATH} \ +"" \ +8 0 + +python3 llava/eval/evaluate_interleave.py --result-dir ${RESULT_NAME} + + + diff --git a/scripts/interleave/eval_multiprocess.sh b/scripts/interleave/eval_multiprocess.sh new file mode 100644 index 000000000..bab4f725e --- /dev/null +++ b/scripts/interleave/eval_multiprocess.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +# Check if three arguments are passed +if [ "$#" -ne 7 ]; then + echo "Usage: $0 " + exit 1 +fi + +# Assign the command line arguments to variables +model_path=$1 +question_path=$2 +base_answer_path=$3 +image_folder=$4 +extra_prompt=$5 +N=$6 +temperature=$7 + +# Loop over each chunk/process +for (( chunk_id=0; chunk_id "${base_answer_path}.jsonl" +for ((i=0; i> "${base_answer_path}/result.jsonl" +done +# remove the unmerged files +for (( chunk_id=0; chunk_id They are basically the same as the basic training scripts, but with some modifications, such as the data yaml. + +- `finetune_clip.sh`: This could be seen as the first image version LLaVA-NeXT (2024-01) training script, with `anyres` strategy and maximum 2x2 image grids. +- `finetune_siglip.sh`: Same but with `siglip` encoder, each grid becomes 729 tokens. +- `finetune_onevision.sh`: This is our latest training script, with `anyres_max_9` strategy and image grids weaving from 1x1 to 6x6, at most to 2304x2304 resolution. Inside the script, we also incorporate the multi-image and video data into training loop. the detail token strategy could be found in our paper. + +# About the LLaVA-OneVision Data + +We need to address the fact that our data has been collected and used in different projects/people. LLaVA-OneVision is our first attempt to integrate these datasets. For the data that has already been uploaded, we will refer you to the corresponding locations. We kindly ask everyone to gather the "fragments" and piece them together into a "diamond" in your own environment. + +Here we explain the some technical details on our data. + +- **pretrain data** - BLIP558K (same as previous llava 1.5 series) +- **mid stage data mixture** + ```yaml + datasets: + - json_path: /mnt/bn/vl-research/data/llava_instruct/blip558k_stage1.5_finetune_w_prompt.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/coco118k_stage1.5_finetune_w_prompt.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/cc3m_recap_data_prompt_v2.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_tr_sft.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/instruct_azure_dc_zh_92K.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/allava/Evol-Instruct-GPT4-Turbo-143K.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/synthdog_zh/synthdog_zh_100k.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/synthdog_en/synthdog_en_100k.json + sampling_strategy: all + ``` + The first three datasets can be collected via the [LLaVA-Recap](https://huggingface.co/collections/lmms-lab/llava-next-6623288e2d61edba3ddbf5ff) series data. We did a slightly modification to make the data more compatible to other data with an added prompt behind `` at each question. You may download our dataset and process it in local environment. Each data we added the **same prompt** "\nPlease generate detailed descriptions of the given image.". Here's an example: + ```json + { + "id": "000000000009", + "image": "coco/train2017/000000000009.jpg", + "conversations": [ + { + "from": "human", + "value": "\nPlease generate detailed descriptions of the given image." + }, + { + "from": "gpt", + "value": "The image shows a meal served in a blue tray with compartments. In the top left compartment, there is a slice of bread with a spread that appears to be butter, accompanied by a few almonds and a slice of what looks like a baked potato or sweet potato. The top right compartment contains a variety of fruits, including what seems to be pineapple, orange slices, and possibly a piece of melon.\n\nIn the bottom left compartment, there is a piece of bread with a spread that could be butter or margarine, and a small portion of what might be a meatball or a similar type of meat covered in a sauce. The bottom right compartment contains a serving of broccoli, which appears to be steamed or lightly cooked.\n\nThe meal is presented in a way that suggests it is a balanced meal, with a variety of food groups represented carbohydrates (bread), protein (meatball), healthy fats (almonds and butter), and fruits and vegetables (broccoli and the fruit assortment). The image is a close-up photograph with a focus on the food, and the colors are vibrant, indicating freshness." + } + ] + } + ``` + + The rest files are being uploaded to [LLaVA-OneVision Mid-Stage Data](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Mid-Data), except from `instruct_azure_dc_zh_92K` since it's internal data due and have policy issue (sad but I could preview one example). The `instruct_azure_dc_zh_92K` images are from COCO and QAs are generated by GPT-4V service provided by AzureAI. We uses this data in early stage experiments wishing to add Chinese knowledge to our model, and just kept using it (yes, just legacy, no other specific reason). + ```json + { + "id": "000000051920", + "image": "azure_dc_zh_92K/000000051920.jpg", + "conversations": [ + { + "from": "human", + "value": "\n这幅画中的关键信息有哪些?" + }, + { + "from": "gpt", + "value": "这张图片展示了一位穿着深色(可能是黑色或深蓝色)毛衣的人坐在一张桌子前,正在用右手拿着一片披萨。披萨看起来有多种配料,包括绿色的菠菜和橙红色的番茄,放在一个银色的圆形披萨盘上。披萨盘上还剩下大约六片披萨。在背景中,可以看到一个穿着浅色衣服的人站在柜台后面,柜台上方有一个黑板菜单,上面用白色粉笔写着菜单项,但是字迹不清晰,无法阅读。整个场景看起来像是在一个披萨店内部,环境显得比较休闲。" + } + ], + "data_source": "azure_dc_zh_92K" + } + ``` + +- **single-image stage data mixture** + +We have placed our single-image stage data in [single-image-yaml](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/scripts/train/single_image.yaml) for users to review. You can download each subset from [onevision-data](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Data). + +Inside the data yaml, the first indicates the previous llava-1.6/next 790K data, you can download them in [llava-next-data](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Data). + +Inside the yaml, the naming would be different with our paper figure due to writing consideration. For users who need to explore our dataset, you can check the [upload script](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/0070d0ae4931c9b19d9cc57c38e16a87c270a61c/playground/upload_data.py#L175) to find the mapping from our local dataset to HF's version. + +- **onevision stage data mixture** + +Our onevision stage data is available in [onevision-yaml](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/scripts/train/onevision.yaml). The single-image portion can be downloaded from the above Huggingface link for onevision data. Here's a breakdown of each part: + + - Around 800K higher-quality data re-sampled from the previous stage (yes, it's data replay!). + - Multi-image data is released in [M4-Instruct Data](https://huggingface.co/datasets/lmms-lab/M4-Instruct-Data). We combine the different subsets into two jsons (as they are mainly from DEMON and Mantis) in our training yaml, the jsons are: + - /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_DEMON-FULL_filtered_311085.json + - /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_mantis-instruct_reformatted.json + + - Video Data: We have released the video part along with [llava-video-data](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K). Users can download the data, and we utilize the subset used in LLaVA-OneVision: + - We have included captions and open-ended questions in the `0_30_s_academic_v0_1` split, along with 240,000 open-ended QA items and 15,000 caption entries, as part of the video data in LLaVA-Hound for LLaVA-OneVision. + - 0_30_s_academic_v0_1 captions + - 0_30_s_academic_v0_1 open-ended QA + - LLaVA-Hound: Same as above. diff --git a/scripts/train/direct_finetune_clip.sh b/scripts/train/direct_finetune_clip.sh new file mode 100644 index 000000000..ad246b413 --- /dev/null +++ b/scripts/train/direct_finetune_clip.sh @@ -0,0 +1,65 @@ +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=INFO + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +############### Pretrain ################ + +PROMPT_VERSION="qwen_1_5" + +BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path ${LLM_VERSION} \ + --version ${PROMPT_VERSION} \ + --data_path=llava_1_6.json \ + --image_folder your_image_folder \ + --pretrain_mm_mlp_adapter="/checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres \ + --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir "/checkpoints/${MID_RUN_NAME}" \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 3000 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --attn_implementation sdpa + +# You can delete the sdpa attn_implementation if you want to use flash attn diff --git a/scripts/train/direct_finetune_siglip_a4.sh b/scripts/train/direct_finetune_siglip_a4.sh new file mode 100644 index 000000000..2a55cd914 --- /dev/null +++ b/scripts/train/direct_finetune_siglip_a4.sh @@ -0,0 +1,67 @@ +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=INFO + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +############### Pretrain ################ + +PROMPT_VERSION="qwen_1_5" + +BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +CKPT_PATH=$LLM_VERSION # this could also be the previous stage checkpoint + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path ${CKPT_PATH} \ + --version ${PROMPT_VERSION} \ + --data_path=llava_1_6.json \ + --image_folder your_image_folder \ + --pretrain_mm_mlp_adapter="/checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres \ + --image_grid_pinpoints "[(384, 768), (768, 384), (768, 768), (1152, 384), (384, 1152)]" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir "/checkpoints/${MID_RUN_NAME}" \ + --num_train_epochs 1 \ + --per_device_train_batch_size 4 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 3000 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --attn_implementation sdpa + +# You can delete the sdpa attn_implementation if you want to use flash attn diff --git a/scripts/train/dpo.sh b/scripts/train/dpo.sh new file mode 100644 index 000000000..83bfe8412 --- /dev/null +++ b/scripts/train/dpo.sh @@ -0,0 +1,64 @@ +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +# export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE} +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=INFO + +VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +############### Pretrain ################ + +# Stage 2 +PROMPT_VERSION="qwen_1_5" + +#torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ + llava/train/train_dpo.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path lmms-lab/LongVA-7B \ + --version $PROMPT_VERSION \ + --dpo_alpha 1.0 --beta 0.1 --gamma 0 \ + --data_path="/data/llava_video/shareVideoGPTV/dpo/sft_dpo_17k.jsonl" \ + --image_folder /data/llava_data \ + --video_folder /llava_video/shareVideoGPTV/frames/all_frames/ \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --mm_spatial_pool_stride 2 \ + --mm_resampler_type "spatial_pool" \ + --mm_spatial_pool_out_channels 1024 \ + --group_by_modality_length True \ + --image_aspect_ratio anyres \ + --image_grid_pinpoints "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]" \ + --mm_patch_merge_type unires \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir "/checkpoints/${MID_RUN_NAME}" \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 16 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 3000 \ + --save_total_limit 1 \ + --learning_rate 5e-7 \ + --weight_decay 0. \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "linear" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --attn_implementation sdpa \ No newline at end of file diff --git a/scripts/train/dpo_ov7b.sh b/scripts/train/dpo_ov7b.sh new file mode 100644 index 000000000..26748e4b1 --- /dev/null +++ b/scripts/train/dpo_ov7b.sh @@ -0,0 +1,69 @@ +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +# export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE} +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=INFO + +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +# DPO Stage +PROMPT_VERSION="qwen_1_5" +SFT_MODEL="lmms-lab/llava-onevision-qwen2-7b-ov" +EPOCH=1 +beta=0.1 + +DPO_RUN_NAME="llava-onevision-qwen2-7b-ov_dpo-beta${beta}-epoch${EPOCH}" +DPO_CLEAN_NAME="${DPO_RUN_NAME##*/}" +OUTPUT_DIR="/${DPO_CLEAN_NAME}" +DATA_PATH="" + +echo $DPO_RUN_NAME + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_dpo.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path=${SFT_MODEL} \ + --dpo_alpha=1.0 \ + --beta=${beta} \ + --gamma=0 \ + --version $PROMPT_VERSION \ + --data_path=$DATA_PATH \ + --image_folder "" \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --unfreeze_mm_vision_tower True \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $DPO_CLEAN_NAME \ + --output_dir $OUTPUT_DIR \ + --num_train_epochs $EPOCH \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 8 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1000 \ + --save_total_limit 1 \ + --learning_rate 5e-7 \ + --weight_decay 0. \ + --warmup_ratio 0.1 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb \ + --dataloader_drop_last True + + diff --git a/scripts/train/finetune_ov.sh b/scripts/train/finetune_ov.sh new file mode 100644 index 000000000..0a4526684 --- /dev/null +++ b/scripts/train/finetune_ov.sh @@ -0,0 +1,75 @@ +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=INFO + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +# for 7b model we recommend bs=1, accum=2, 16 nodes, 128 gpus, lr=1e-5, warmup=0.03 +# for 72b model we recommend bs=1, accum=1, 32 nodes, 256 gpus, lr=1e-5, warmup=0.03 +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +############### Pretrain ################ + +BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +############### Finetune ################ + +# Stage 2 +PROMPT_VERSION="qwen_1_5" +RUN_NAME="llava-onevision-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_stage_am9" +PREV_STAGE_CHECKPOINT="/mnt/bn/vl-research/checkpoints/onevision/llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mid_to_final_next_3m_am9_july14" # replace it with your last checkpoint training from single image collection +echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" +echo "MID_RUN_NAME: ${RUN_NAME}" + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path $PREV_STAGE_CHECKPOINT \ + --version $PROMPT_VERSION \ + --data_path /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/scripts/i18n/scale_llms/next_ov_stage_july21.yaml \ + --image_folder /mnt/bn/vl-research/data/llava_data \ + --video_folder /mnt/bn/vl-research/data/llava_video \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $RUN_NAME \ + --output_dir /mnt/bn/vl-research/checkpoints/onevision/$RUN_NAME \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1000 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --frames_upbound 32 +exit 0; + +# You can delete the sdpa attn_implementation if you want to use flash attn diff --git a/scripts/train/finetune_si.sh b/scripts/train/finetune_si.sh new file mode 100644 index 000000000..32768f306 --- /dev/null +++ b/scripts/train/finetune_si.sh @@ -0,0 +1,73 @@ +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=INFO + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +# for 7b model we recommend bs=1, accum=2, 16 nodes, 128 gpus, lr=1e-5, warmup=0.03 +# for 72b model we recommend bs=1, accum=1, 32 nodes, 256 gpus, lr=1e-5, warmup=0.03 +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +############### Pretrain ################ + +BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +############### Finetune ################ + +# Stage 2 +PROMPT_VERSION="qwen_1_5" +RUN_NAME="llava-onevision-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-si_stage_am9" +PREV_STAGE_CHECKPOINT="/mnt/bn/vl-research/checkpoints/onevision/xxxxxxxxxxxxxxxx" # replace it with your last checkpoint training from mid stage +echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" +echo "MID_RUN_NAME: ${RUN_NAME}" + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path $PREV_STAGE_CHECKPOINT \ + --version $PROMPT_VERSION \ + --data_path /mnt/bn/vl-research/workspace/boli01/projects/LLaVA_Next/scripts/i18n/scale_llms/next_3p2m_single_image.yaml \ + --image_folder /mnt/bn/vl-research/data/llava_data \ + --video_folder /mnt/bn/vl-research/data/llava_video \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $RUN_NAME \ + --output_dir /mnt/bn/vl-research/checkpoints/onevision/$RUN_NAME \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1000 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --frames_upbound 32 +exit 0; diff --git a/scripts/train/mid_stage.yaml b/scripts/train/mid_stage.yaml new file mode 100644 index 000000000..566998719 --- /dev/null +++ b/scripts/train/mid_stage.yaml @@ -0,0 +1,17 @@ +datasets: + - json_path: /mnt/bn/vl-research/data/llava_instruct/blip558k_stage1.5_finetune_w_prompt.json # released in lmms-lab/LLaVA-ReCap-* + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/coco118k_stage1.5_finetune_w_prompt.json # released in lmms-lab/LLaVA-ReCap-* + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/cc3m_recap_data_prompt_v2.json # released in lmms-lab/LLaVA-ReCap-* + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_tr_sft.json # released in lmms-lab/LLaVA-OneVision-Mid-Data + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/instruct_azure_dc_zh_92K.json # not released, explained at https://github.com/LLaVA-VL/LLaVA-NeXT/tree/main/scripts/train + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/Evol-Instruct-GPT4-Turbo-143K.json # released in lmms-lab/LLaVA-OneVision-Mid-Data + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/synthdog_zh/synthdog_zh_100k.json # released in lmms-lab/LLaVA-OneVision-Mid-Data + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/synthdog_en/synthdog_en_100k.json # released in lmms-lab/LLaVA-OneVision-Mid-Data + sampling_strategy: all \ No newline at end of file diff --git a/scripts/train/onevision.yaml b/scripts/train/onevision.yaml new file mode 100644 index 000000000..117a5e8f9 --- /dev/null +++ b/scripts/train/onevision.yaml @@ -0,0 +1,185 @@ +datasets: + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_next_fit_mix_filtered_text_wild_738590.json + sampling_strategy: "first:50%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_wild_4v_39k.json # not released + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_wild_4v_12k.json # not released + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_metagen_87358.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_rule_geo_100000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/cambrian_filtered_gpt4vo_sp_token_fltd_max10k_checked.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/VisualWebInstruct_filtered_263589.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/visual_chat_en_26048_gpt4o_coco_checked.json # not released + sampling_strategy: "all" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/gpt4o_combinations_51316.json + # sampling_strategy: "all" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/chrome_writting_train_8835.json + # sampling_strategy: "first:20%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/k12_printing_train_256646.json + # sampling_strategy: "first:1%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/iiit5k_annotations_2000.json + # sampling_strategy: "first:20%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/hme100k_train_clean_74502.json + # sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sroie_data_33626.json + # sampling_strategy: "first:1%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/orand_car_a_train_2009.json + # sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/orand_car_b_train_3000.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llavar_gpt4_20k.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_azuregpt_detailed_understanding_4874.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_vqa_4404.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_azuregpt4v_1992.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_chart_1787.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_normal_gpt4v_filtered_10500.json + sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/scienceqa_nona_context_19218.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_vflan4v_20000.json + sampling_strategy: "first:30%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_laion4v_50000.json + sampling_strategy: "first:30%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textocr_gpt4v_train_converted_25114.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_train_internvl_single_12413.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textcaps_train_21952.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_qa_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_cap_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_ie_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_kg_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/vision_flan_filtered_186070.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mathqa_29837.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo3k_2101.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_qa_converted_67833.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_align_converted_60252.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4o_dataset.jsonl + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-coco-50k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-knowledge-2k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-llava-30k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-sam-20k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_CLEVR-Math_5290.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_FigureQA_17597.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Geometry3K_9734.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GeoQA+_17172.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GEOS_508.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_IconQA_22599.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_MapQA_5235.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PlotQA_5485.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PMC-VQA_35958.json + sampling_strategy: "first:1%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Super-CLEVR_8652.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TabMWP_22462.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TQA_10181.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_UniGeo_11959.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VizWiz_6614.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VQA-AS_5907.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VQA-RAD_2130.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_qwen2_72b_st_300000_sp_token_fltd_299992.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_st_300000.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_mt_300000_sp_token_fltd_299998.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/image_textualization_dataset_filtered.json + sampling_strategy: "first:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/ai2d_llava_format_2434.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/chart2text_26961.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/chartqa_18265_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/diagram_image_to_text_300.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/hateful_memes_8500_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/hitab_2500_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/iam_5663.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/infographic_vqa_2118_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/intergps_1280_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/mapqa_37417_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/rendered_text_10000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_sqa_8514.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wikisql_74989.json + sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wtq_38246_llava_format_filtered_4000tokens_38236.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/screen2words_15730.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tabmwp_22722.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tallyqa_98680_llava_format.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/st_vqa_17247_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tqa_llava_format_27307.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/visual7w_llava_format_14366.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/visualmrc_3027.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vqarad_313_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vsr_2157_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vistext_9969.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/websight_10000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_DEMON-FULL_filtered_311085.json # released in lmms-lab/M4-Instruct + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_mantis-instruct_reformatted.json # released in lmms-lab/M4-Instruct + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/academic_source_30s_v1_all.json # will be released in next version of LLaVA-NeXT-Video + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/0718_0_30_s_academic_mc_v0_1_all.json # will be released in next version of LLaVA-NeXT-Video + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4video_255000.json # download from sharegpt4video + sampling_strategy: all diff --git a/scripts/train/pretrain_clip.sh b/scripts/train/pretrain_clip.sh new file mode 100644 index 000000000..2eb5c27e6 --- /dev/null +++ b/scripts/train/pretrain_clip.sh @@ -0,0 +1,55 @@ +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=INFO + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="openai/clip-vit-large-patch14-336" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +############### Pretrain ################ + +PROMPT_VERSION=plain + +BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path ${LLM_VERSION} \ + --version ${PROMPT_VERSION} \ + --data_path /blip_558k/blip_558k_plain.json \ + --image_folder /blip_558k/images \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_tunable_parts="mm_mlp_adapter" \ + --mm_vision_select_layer -2 \ + --mm_projector_type mlp2x_gelu \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir /checkpoints/projectors/${BASE_RUN_NAME} \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "no" \ + --save_steps 50000 \ + --learning_rate 1e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 8192 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb \ + --run_name $BASE_RUN_NAME \ + --attn_implementation sdpa + +# You can delete the sdpa attn_implementation if you want to use flash attn \ No newline at end of file diff --git a/scripts/train/pretrain_siglip.sh b/scripts/train/pretrain_siglip.sh new file mode 100644 index 000000000..ef13793d1 --- /dev/null +++ b/scripts/train/pretrain_siglip.sh @@ -0,0 +1,55 @@ +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=INFO + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +############### Pretrain ################ + +PROMPT_VERSION=plain + +BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path ${LLM_VERSION} \ + --version ${PROMPT_VERSION} \ + --data_path /blip_558k/blip_558k_plain.json \ + --image_folder /blip_558k/images \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_tunable_parts="mm_mlp_adapter" \ + --mm_vision_select_layer -2 \ + --mm_projector_type mlp2x_gelu \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --bf16 True \ + --output_dir /checkpoints/projectors/${BASE_RUN_NAME} \ + --num_train_epochs 1 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "no" \ + --save_steps 50000 \ + --learning_rate 1e-3 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 8192 \ + --gradient_checkpointing True \ + --dataloader_num_workers 16 \ + --lazy_preprocess True \ + --report_to wandb \ + --run_name $BASE_RUN_NAME \ + --attn_implementation sdpa + +# You can delete the sdpa attn_implementation if you want to use flash attn \ No newline at end of file diff --git a/scripts/train/single_image.yaml b/scripts/train/single_image.yaml new file mode 100644 index 000000000..a21e3a159 --- /dev/null +++ b/scripts/train/single_image.yaml @@ -0,0 +1,187 @@ +datasets: + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_next_fit_mix_filtered_text_wild_738590.json # released in lmms-lab/LLaVA-NeXT-Data + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_wild_4v_39k.json # not released + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_wild_4v_12k.json # not released + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llavar_gpt4_20k.json + sampling_strategy: "all" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sroie_data_33626.json + # sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/orand_car_a_train_2009.json + # sampling_strategy: "all" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/orand_car_b_train_3000.json + # sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_metagen_87358.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_rule_geo_100000.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/chrome_writting_train_8835.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/k12_printing_train_256646.json + sampling_strategy: "first:1%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/iiit5k_annotations_2000.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/hme100k_train_clean_74502.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_azuregpt_detailed_understanding_4874.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_vqa_4404.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_azuregpt4v_1992.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_chart_1787.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_normal_gpt4v_filtered_10500.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/scienceqa_nona_context_19218.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_vflan4v_20000.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_laion4v_50000.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textocr_gpt4v_train_converted_25114.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_train_internvl_single_12413.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textcaps_train_21952.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_qa_sft.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_cap_sft.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_ie_sft.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_kg_sft.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/vision_flan_filtered_186070.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mathqa_29837.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo3k_2101.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_qa_converted_67833.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_align_converted_60252.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-coco-50k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-knowledge-2k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-llava-30k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-sam-20k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_CLEVR-Math_5290.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_FigureQA_17597.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Geometry3K_9734.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GeoQA+_17172.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GEOS_508.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_IconQA_22599.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_MapQA_5235.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PMC-VQA_35958.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Super-CLEVR_8652.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TabMWP_22462.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TQA_10181.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_UniGeo_11959.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VizWiz_6614.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VQA-AS_5907.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VQA-RAD_2130.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/Evol-Instruct-GPT4-Turbo-143000.json + sampling_strategy: "first:30%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_qwen2_72b_st_300000_sp_token_fltd_299992.json + sampling_strategy: "first:50%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_st_300000.json + sampling_strategy: "first:50%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_mt_300000_sp_token_fltd_299998.json + sampling_strategy: "first:50%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/image_textualization_dataset_filtered.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/cambrian_filtered_gpt4vo_sp_token_fltd_max10k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4o_dataset.jsonl + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/ai2d_llava_format_2434.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/aokvqa_16539_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/chart2text_26961.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/chartqa_18265_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/clevr_70000_llava_format.json + sampling_strategy: "first:1%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/diagram_image_to_text_300.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/dvqa_200000_llava_format.json + sampling_strategy: "first:1%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/figureqa_100000_llava_format.json + sampling_strategy: "first:1%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/geomverse_9303.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/hateful_memes_8500_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/hitab_2500_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/iam_5663.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/raven_42000.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/iconqa_llava_format_27307.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/infographic_vqa_2118_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/intergps_1280_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/mapqa_37417_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/multihiertt_7619.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/rendered_text_10000.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_sqa_8514.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wikisql_74989.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wtq_38246_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/screen2words_15730.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/scienceqa_llava_format_4976.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tabmwp_22722.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tallyqa_98680_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/st_vqa_17247_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tqa_llava_format_27307.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/visual7w_llava_format_14366.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/visualmrc_3027.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vqarad_313_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vsr_2157_llava_format.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vistext_9969.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/websight_10000.json + sampling_strategy: "all" diff --git a/scripts/video/demo/video_demo.sh b/scripts/video/demo/video_demo.sh new file mode 100644 index 000000000..17f24c957 --- /dev/null +++ b/scripts/video/demo/video_demo.sh @@ -0,0 +1,43 @@ +#!/bin/bash +ROOT_DIR="/mnt/bn/vl-research/workspace/yhzhang/LLaVA-NeXT" + +if [ ! -e $ROOT_DIR ]; then + echo "The root dir does not exist. Exiting the script." + exit 1 +fi + +cd $ROOT_DIR + +export PYTHONWARNINGS=ignore +export TOKENIZERS_PARALLELISM=false + +CKPT=$1 +CONV_MODE=$2 +FRAMES=$3 +POOL_STRIDE=$4 +POOL_MODE=$5 +NEWLINE_POSITION=$6 +OVERWRITE=$7 +VIDEO_PATH=$8 + + +if [ "$OVERWRITE" = False ]; then + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE}_overwrite_${OVERWRITE} + +else + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} +fi + +python3 playground/demo/video_demo.py \ + --model-path $CKPT \ + --video_path ${VIDEO_PATH} \ + --output_dir ./work_dirs/video_demo/$SAVE_DIR \ + --output_name pred \ + --chunk-idx $(($IDX - 1)) \ + --overwrite ${OVERWRITE} \ + --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ + --for_get_frames_num $FRAMES \ + --conv-mode $CONV_MODE \ + --mm_spatial_pool_mode ${POOL_MODE:-average} \ + --mm_newline_position ${NEWLINE_POSITION:-grid} \ + --prompt "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes." \ No newline at end of file diff --git a/scripts/video/eval/activitynet_eval.sh b/scripts/video/eval/activitynet_eval.sh new file mode 100644 index 000000000..6e16c5a5d --- /dev/null +++ b/scripts/video/eval/activitynet_eval.sh @@ -0,0 +1,96 @@ +#!/bin/bash +ROOT_DIR="root to LLaVA-NeXT-Video" + +if [ ! -e $ROOT_DIR ]; then + echo "The root dir does not exist. Exiting the script." + exit 1 +fi + +cd $ROOT_DIR + +export PYTHONWARNINGS=ignore +export TOKENIZERS_PARALLELISM=false +CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' +gpu_list="${CUDA_VISIBLE_DEVICES}" +GPULIST=(${(s:,:)gpu_list}) + +CHUNKS=${#GPULIST[@]} +echo "Using $CHUNKS GPUs" + +CKPT=$1 +CONV_MODE=$2 +FRAMES=$3 +OVERWRITE=$4 +PREDEFINED_CONFIGURE=$5 +mm_spatial_pool_stride=$6 +MODEL_MAX_LENGTH=${7:-0} + +CKPT=$1 +CONV_MODE=$2 +FRAMES=$3 +POOL_STRIDE=$4 +OVERWRITE=$5 +CHUNKS=${6:-1} + +PATCHIFY=False + + +OPENAIKEY="INPUT YOUR OPENAI API" + + +if [ "$OVERWRITE" = False ]; then + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE}_overwrite_${OVERWRITE} + +else + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} +fi + +echo $SAVE_DIR + +# for IDX in {1..$CHUNKS}; do +# GPU_ID=${GPULIST[$IDX]} # Note: Zsh arrays are 1-indexed by default + +# # GPU_FREE=0 +# # while [ $GPU_FREE -eq 0 ]; do +# # # Using nvidia-smi to get the memory usage of the GPU with ID $GPU_ID +# # # Parsing the output to extract the memory usage, and checking if it is "0" +# # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') + +# # if [ "$MEM_USAGE" -eq 0 ]; then +# # GPU_FREE=1 +# # echo "GPU $GPU_ID is free." +# # else +# # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB. Checking again in 100 seconds..." +# # sleep 100 +# # fi +# # done + +# echo "Running on GPU $GPU_ID" +# CUDA_VISIBLE_DEVICES=$GPU_ID python3 llavavid/eval/model_activitynet_qa.py \ +# --model-path $CKPT \ +# --video_dir ./data/llava_video/ActivityNet-QA/all_test \ +# --gt_file_question ./data/llava_video/ActivityNet-QA/test_q.json \ +# --gt_file_answers ./data/llava_videoActivityNet-QA/test_a.json \ +# --output_dir ./work_dirs/eval_activitynet/$SAVE_DIR \ +# --output_name pred \ +# --num-chunks $CHUNKS \ +# --chunk-idx $(($IDX - 1)) \ +# --overwrite ${OVERWRITE} \ +# --patchify_video_feature ${PATCHIFY} \ +# --predefined_configure ${PREDEFINED_CONFIGURE} \ +# --mm_spatial_pool_stride ${mm_spatial_pool_stride:-4} \ +# --for_get_frames_num $FRAMES \ +# --model-max-length ${MODEL_MAX_LENGTH:-0} \ +# --conv-mode $CONV_MODE & + +# done + +# wait + +python3 llava/eval/eval_activitynet_qa.py \ + --pred_path ./work_dirs/eval_activitynet/$SAVE_DIR \ + --output_dir ./work_dirs/eval_activitynet/$SAVE_DIR/results \ + --output_json ./work_dirs/eval_activitynet/$SAVE_DIR/results.json \ + --num_chunks $CHUNKS \ + --api_key $OPENAIKEY \ + # --num_tasks 16 \ \ No newline at end of file diff --git a/scripts/video/eval/video_chatgpt_benchmark_eval_shard.sh b/scripts/video/eval/video_chatgpt_benchmark_eval_shard.sh new file mode 100644 index 000000000..8b6518a2d --- /dev/null +++ b/scripts/video/eval/video_chatgpt_benchmark_eval_shard.sh @@ -0,0 +1,242 @@ +#!/bin/bash +ROOT_DIR="root to LLaVA-NeXT-Video" + +if [ ! -e $ROOT_DIR ]; then + echo "The root dir does not exist. Exiting the script." + exit 1 +fi + +cd $ROOT_DIR + +export python3WARNINGS=ignore +export TOKENIZERS_PARALLELISM=false +# CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' +gpu_list="${CUDA_VISIBLE_DEVICES}" +GPULIST=(${(s:,:)gpu_list}) + +# CHUNKS=${#GPULIST[@]} +# echo "Using $CHUNKS GPUs" + +CKPT=$1 +CONV_MODE=$2 +FRAMES=$3 +POOL_STRIDE=$4 +OVERWRITE=$5 +CHUNKS=${6:-1} + +OPENAIKEY="INPUT YOUR OPENAI API" + +if [ "$OVERWRITE" = False ]; then + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE}_overwrite_${OVERWRITE} + +else + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} +fi + +echo $SAVE_DIR + +# Assuming GPULIST is a bash array containing your GPUs +GPULIST=(0 1 2 3 4 5 6 7) + +# Get the number of GPUs +NUM_GPUS=${#GPULIST[@]} + +# Calculate GPUs per chunk +GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS)) + + +for IDX in $(seq 1 $CHUNKS); do + START=$(((IDX-1) * GPUS_PER_CHUNK)) + LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index + + CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH}) + + # Convert the chunk GPUs array to a comma-separated string + CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}") + + # ALL_GPUS_FREE=0 + # while [ $ALL_GPUS_FREE -eq 0 ]; do + # ALL_GPUS_FREE=1 # Assume all GPUs are free initially + + # for GPU_ID in $CHUNK_GPUS; do + # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') + + # # Assuming a GPU is considered free if its memory usage is less than 100 MiB + # if [ "$MEM_USAGE" -ge 100 ]; then + # ALL_GPUS_FREE=0 + # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB." + # break # Exit the loop early as we found a GPU that is not free + # fi + # done + + # if [ $ALL_GPUS_FREE -eq 0 ]; then + # echo "Not all GPUs in chunk are free. Checking again in 100 seconds..." + # sleep 100 + # fi + # done + + echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR" + CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 llava/eval/model_video_chatgpt_general.py \ + --model-path $CKPT \ + --video_dir ./data/llava_video/video-chatgpt/evaluation/Test_Videos/ \ + --gt_file ./data/llava_video/video-chatgpt/evaluation/generic_qa.json \ + --output_dir ./work_dirs/eval_video_chatgpt/$SAVE_DIR \ + --output_name pred \ + --num-chunks $CHUNKS \ + --chunk-idx $(($IDX - 1)) \ + --overwrite ${OVERWRITE:-true} \ + --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ + --for_get_frames_num $FRAMES \ + --conv-mode $CONV_MODE & +done + +wait + +python3 llava/eval/evaluate_benchmark_1_correctness.py \ + --pred_path ./work_dirs/eval_video_chatgpt/$SAVE_DIR \ + --output_dir ./work_dirs/eval_video_chatgpt/$SAVE_DIR/correctness_results \ + --output_json ./work_dirs/eval_video_chatgpt/$SAVE_DIR/correctness_results.json \ + --num_chunks $CHUNKS \ + --output_name pred \ + --num_tasks 16 \ + --api_key $OPENAIKEY \ + + +python3 llava/eval/evaluate_benchmark_2_detailed_orientation.py \ + --pred_path ./work_dirs/eval_video_chatgpt/$SAVE_DIR \ + --output_dir ./work_dirs/eval_video_chatgpt/$SAVE_DIR/detail_results \ + --output_json ./work_dirs/eval_video_chatgpt/$SAVE_DIR/detail_results.json \ + --num_chunks $CHUNKS \ + --output_name pred \ + --num_tasks 16 \ + --api_key $OPENAIKEY \ + + +python3 llava/eval/evaluate_benchmark_3_context.py \ + --pred_path ./work_dirs/eval_video_chatgpt/$SAVE_DIR \ + --output_dir ./work_dirs/eval_video_chatgpt/$SAVE_DIR/context_results \ + --output_json ./work_dirs/eval_video_chatgpt/$SAVE_DIR/context_results.json \ + --num_chunks $CHUNKS \ + --output_name pred \ + --num_tasks 16 \ + --api_key $OPENAIKEY \ + + + +for IDX in $(seq 1 $CHUNKS); do + START=$(((IDX-1) * GPUS_PER_CHUNK)) + LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index + + CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH}) + + # Convert the chunk GPUs array to a comma-separated string + CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}") + + # ALL_GPUS_FREE=0 + # while [ $ALL_GPUS_FREE -eq 0 ]; do + # ALL_GPUS_FREE=1 # Assume all GPUs are free initially + + # for GPU_ID in $CHUNK_GPUS; do + # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') + + # # Assuming a GPU is considered free if its memory usage is less than 100 MiB + # if [ "$MEM_USAGE" -ge 100 ]; then + # ALL_GPUS_FREE=0 + # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB." + # break # Exit the loop early as we found a GPU that is not free + # fi + # done + + # if [ $ALL_GPUS_FREE -eq 0 ]; then + # echo "Not all GPUs in chunk are free. Checking again in 100 seconds..." + # sleep 100 + # fi + # done + + echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR" + CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 llava/eval/model_video_chatgpt_general.py \ + --model-path $CKPT \ + --video_dir ./data/llava_video/video-chatgpt/evaluation/Test_Videos/ \ + --gt_file ./data/llava_video/video-chatgpt/evaluation/temporal_qa.json \ + --output_dir ./work_dirs/eval_video_chatgpt/$SAVE_DIR \ + --output_name pred_temporal \ + --num-chunks $CHUNKS \ + --chunk-idx $(($IDX - 1)) \ + --for_get_frames_num $FRAMES \ + --overwrite ${OVERWRITE} \ + --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ + --conv-mode $CONV_MODE & + +done + +wait + + +python3 llava/eval/evaluate_benchmark_4_temporal.py \ + --pred_path ./work_dirs/eval_video_chatgpt/$SAVE_DIR \ + --output_dir ./work_dirs/eval_video_chatgpt/$SAVE_DIR/temporal_results \ + --output_json ./work_dirs/eval_video_chatgpt/$SAVE_DIR/temporal_results.json \ + --num_chunks $CHUNKS \ + --output_name pred_temporal \ + --num_tasks 16 \ + --api_key $OPENAIKEY \ + + + +for IDX in $(seq 1 $CHUNKS); do + START=$(((IDX-1) * GPUS_PER_CHUNK)) + LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index + + CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH}) + + # Convert the chunk GPUs array to a comma-separated string + CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}") + + # ALL_GPUS_FREE=0 + # while [ $ALL_GPUS_FREE -eq 0 ]; do + # ALL_GPUS_FREE=1 # Assume all GPUs are free initially + + # for GPU_ID in $CHUNK_GPUS; do + # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') + + # # Assuming a GPU is considered free if its memory usage is less than 100 MiB + # if [ "$MEM_USAGE" -ge 100 ]; then + # ALL_GPUS_FREE=0 + # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB." + # break # Exit the loop early as we found a GPU that is not free + # fi + # done + + # if [ $ALL_GPUS_FREE -eq 0 ]; then + # echo "Not all GPUs in chunk are free. Checking again in 100 seconds..." + # sleep 100 + # fi + # done + + echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR" + CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 llava/eval/model_video_chatgpt_consistency.py \ + --model-path $CKPT \ + --video_dir ./data/llava_video/video-chatgpt/evaluation/Test_Videos/ \ + --gt_file ./data/llava_video/video-chatgpt/evaluation/consistency_qa.json \ + --output_dir ./work_dirs/eval_video_chatgpt/$SAVE_DIR \ + --output_name pred_consistency \ + --num-chunks $CHUNKS \ + --chunk-idx $(($IDX - 1)) \ + --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ + --for_get_frames_num $FRAMES \ + --overwrite ${OVERWRITE} \ + --conv-mode $CONV_MODE & +done + +wait + + +python3 llava/eval/evaluate_benchmark_5_consistency.py \ + --pred_path ./work_dirs/eval_video_chatgpt/$SAVE_DIR \ + --output_dir ./work_dirs/eval_video_chatgpt/$SAVE_DIR/consistency_results \ + --output_json ./work_dirs/eval_video_chatgpt/$SAVE_DIR/consistency_results.json \ + --num_chunks $CHUNKS \ + --output_name pred_consistency \ + --num_tasks 16 \ + --api_key $OPENAIKEY \ + diff --git a/scripts/video/eval/video_description_from_t2v.sh b/scripts/video/eval/video_description_from_t2v.sh new file mode 100644 index 000000000..7b87c9408 --- /dev/null +++ b/scripts/video/eval/video_description_from_t2v.sh @@ -0,0 +1,98 @@ +#!/bin/bash +ROOT_DIR="/mnt/bn/vl-research/workspace/yhzhang/llava-next-video" + +if [ ! -e $ROOT_DIR ]; then + echo "The root dir does not exist. Exiting the script." + exit 1 +fi + +cd $ROOT_DIR + +export PYTHONWARNINGS=ignore +export TOKENIZERS_PARALLELISM=false + +CKPT=$1 +CONV_MODE=$2 +FRAMES=$3 +POOL_STRIDE=$4 +OVERWRITE=$5 +CHUNKS=${6:-1} +DO_CENTER_CROP=${7:-False} + +echo "Using $CHUNKS GPUs" + +LOAD_8BIT=False + + +if [ "$OVERWRITE" = False ]; then + if [ "$MODEL_MAX_LENGTH" = 0 ]; then + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_overwrite_${OVERWRITE} + else + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_overwrite_${OVERWRITE} + fi +else + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} +fi + +SAVE_DIR=${SAVE_DIR}_do_center_crop_${DO_CENTER_CROP} +# Assuming GPULIST is a bash array containing your GPUs +GPULIST=(0 1 2 3 4 5 6 7) +# GPULIST=(0) + +# Get the number of GPUs +NUM_GPUS=${#GPULIST[@]} + +# Calculate GPUs per chunk +GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS)) + + +for IDX in $(seq 1 $CHUNKS); do + START=$(((IDX-1) * GPUS_PER_CHUNK)) + LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index + + CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH}) + + # Convert the chunk GPUs array to a comma-separated string + CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}") + + # ALL_GPUS_FREE=0 + # while [ $ALL_GPUS_FREE -eq 0 ]; do + # ALL_GPUS_FREE=1 # Assume all GPUs are free initially + + # for GPU_ID in $CHUNK_GPUS; do + # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') + + # # Assuming a GPU is considered free if its memory usage is less than 100 MiB + # if [ "$MEM_USAGE" -ge 100 ]; then + # ALL_GPUS_FREE=0 + # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB." + # break # Exit the loop early as we found a GPU that is not free + # fi + # done + + # if [ $ALL_GPUS_FREE -eq 0 ]; then + # echo "Not all GPUs in chunk are free. Checking again in 100 seconds..." + # sleep 100 + # fi + # done + + echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR" + CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 llava/eval/model_video_description_from_t2v.py \ + --model-path $CKPT \ + --gt_file /mnt/bn/vl-research-1t/tuyen/webvid_hdvg_movie_pond5_for_captioning_evaluation/webvid_hdvg_movie_pond5_for_captioning_evaluation.processed.csv \ + --output_dir ./work_dirs/eval_video_description_from_t2v/$SAVE_DIR \ + --output_name pred \ + --num-chunks $CHUNKS \ + --chunk-idx $(($IDX - 1)) \ + --overwrite ${OVERWRITE} \ + --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ + --for_get_frames_num $FRAMES \ + --load_8bit $LOAD_8BIT \ + --do_center_crop $DO_CENTER_CROP \ + --conv-mode $CONV_MODE & +done + +wait + +cat ${ROOT_DIR}/work_dirs/eval_video_description_from_t2v/$SAVE_DIR/${CHUNKS}* > ${ROOT_DIR}/work_dirs/eval_video_description_from_t2v/$SAVE_DIR/pred.json + diff --git a/scripts/video/eval/video_detail_description_eval_only.sh b/scripts/video/eval/video_detail_description_eval_only.sh new file mode 100644 index 000000000..b75021ffb --- /dev/null +++ b/scripts/video/eval/video_detail_description_eval_only.sh @@ -0,0 +1,24 @@ +#!/bin/bash +ROOT_DIR="root to LLaVA-NeXT-Video" + +if [ ! -e $ROOT_DIR ]; then + echo "The root dir does not exist. Exiting the script." + exit 1 +fi + +cd $ROOT_DIR + +export PYTHONWARNINGS=ignore +export TOKENIZERS_PARALLELISM=false + +OPENAIKEY="INPUT YOUR OPENAI API" + +SAVE_DIR=$1 + +python3 llava/eval/evaluate_benchmark_video_detail_description.py \ + --pred_path ./work_dirs/eval_video_detail_description/$SAVE_DIR/pred.json \ + --output_dir ./work_dirs/eval_video_detail_description/$SAVE_DIR/detail_results \ + --output_json ./work_dirs/eval_video_detail_description/$SAVE_DIR/detail_results.json \ + --num_chunks 1 \ + --num_tasks 16 \ + --api_key $OPENAIKEY \ \ No newline at end of file diff --git a/scripts/video/eval/video_detail_description_eval_shard.sh b/scripts/video/eval/video_detail_description_eval_shard.sh new file mode 100644 index 000000000..3b9de5cae --- /dev/null +++ b/scripts/video/eval/video_detail_description_eval_shard.sh @@ -0,0 +1,95 @@ +#!/bin/bash +ROOT_DIR="/mnt/bn/vl-research/workspace/yhzhang/llava-next-video" + +if [ ! -e $ROOT_DIR ]; then + echo "The root dir does not exist. Exiting the script." + exit 1 +fi + +cd $ROOT_DIR + +export PYTHONWARNINGS=ignore +export TOKENIZERS_PARALLELISM=false + +OPENAIKEY="INPUT YOUR OPENAI API" + +CKPT=$1 +CONV_MODE=$2 +FRAMES=$3 +POOL_STRIDE=$4 +OVERWRITE=$5 +CHUNKS=${6:-1} + +echo "Using $CHUNKS GPUs" + +if [ "$OVERWRITE" = False ]; then + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE}_overwrite_${OVERWRITE} + +else + SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE} +fi + +# Assuming GPULIST is a bash array containing your GPUs +GPULIST=(0 1 2 3 4 5 6 7) + +# Get the number of GPUs +NUM_GPUS=${#GPULIST[@]} + +# Calculate GPUs per chunk +GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS)) + + +for IDX in $(seq 1 $CHUNKS); do + START=$(((IDX-1) * GPUS_PER_CHUNK)) + LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index + + CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH}) + + # Convert the chunk GPUs array to a comma-separated string + CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}") + + # ALL_GPUS_FREE=0 + # while [ $ALL_GPUS_FREE -eq 0 ]; do + # ALL_GPUS_FREE=1 # Assume all GPUs are free initially + + # for GPU_ID in $CHUNK_GPUS; do + # MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]') + + # # Assuming a GPU is considered free if its memory usage is less than 100 MiB + # if [ "$MEM_USAGE" -ge 100 ]; then + # ALL_GPUS_FREE=0 + # echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB." + # break # Exit the loop early as we found a GPU that is not free + # fi + # done + + # if [ $ALL_GPUS_FREE -eq 0 ]; then + # echo "Not all GPUs in chunk are free. Checking again in 100 seconds..." + # sleep 100 + # fi + # done + + echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR" + CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 llava/eval/model_video_detail_description.py \ + --model-path $CKPT \ + --video_dir ./data/llava_video/video-chatgpt/evaluation/Test_Videos/ \ + --output_dir ./work_dirs/eval_video_detail_description/$SAVE_DIR \ + --output_name pred \ + --num-chunks $CHUNKS \ + --chunk-idx $(($IDX - 1)) \ + --overwrite ${OVERWRITE} \ + --mm_spatial_pool_stride ${POOL_STRIDE:-4} \ + --for_get_frames_num $FRAMES \ + --conv-mode $CONV_MODE & +done + +wait + +python3 llava/eval/evaluate_benchmark_video_detail_description.py \ + --pred_path ./work_dirs/eval_video_detail_description/$SAVE_DIR \ + --output_dir ./work_dirs/eval_video_detail_description/$SAVE_DIR/detail_results \ + --output_json ./work_dirs/eval_video_detail_description/$SAVE_DIR/detail_results.json \ + --num_chunks $CHUNKS \ + --num_tasks 16 \ + --api_key $OPENAIKEY \ + diff --git a/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh b/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh new file mode 100755 index 000000000..5e053ffb5 --- /dev/null +++ b/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# Set up the data folder +IMAGE_FOLDER="XXX" +VIDEO_FOLDER="XXX" +DATA_YAML="XXX" # e.g exp.yaml + +############### Prepare Envs ################# +python3 -m pip install flash-attn --no-build-isolation +alias python=python3 +############### Show Envs #################### + +nvidia-smi + +################ Arnold Jobs ################ + +LLM_VERSION="Qwen/Qwen2-72B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + + +BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-72B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +# Stage 2 +PROMPT_VERSION="qwen_1_5" +MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9" +PREV_STAGE_CHECKPOINT="lmms-lab/llava-onevision-qwen2-72b-ov-si" +echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" +echo "MID_RUN_NAME: ${MID_RUN_NAME}" + + +# ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ +deepspeed --master_port 30000 \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path $PREV_STAGE_CHECKPOINT \ + --version $PROMPT_VERSION \ + --data_path $DATA_YAML \ + --image_folder $IMAGE_FOLDER \ + --video_folder $VIDEO_FOLDER \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir ./work_dirs/$MID_RUN_NAME \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 500 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 2 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --frames_upbound 32 \ + --mm_newline_position grid \ + --add_time_instruction True \ + --force_sample True \ + --mm_spatial_pool_stride 2 +exit 0; \ No newline at end of file diff --git a/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh b/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh new file mode 100755 index 000000000..14b915a0c --- /dev/null +++ b/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# Set up the data folder +IMAGE_FOLDER="XXX" +VIDEO_FOLDER="XXX" +DATA_YAML="XXX" # e.g exp.yaml + +############### Prepare Envs ################# +python3 -m pip install flash-attn --no-build-isolation +alias python=python3 +############### Show Envs #################### + +nvidia-smi + +################ Arnold Jobs ################ + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" +# + +BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +# Stage 2 +PROMPT_VERSION="qwen_1_5" +MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9" +PREV_STAGE_CHECKPOINT="lmms-lab/llava-onevision-qwen2-7b-ov-si" +echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" +echo "MID_RUN_NAME: ${MID_RUN_NAME}" + + +# ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ +deepspeed --master_port 30000 \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path $PREV_STAGE_CHECKPOINT \ + --version $PROMPT_VERSION \ + --data_path $DATA_YAML \ + --image_folder $IMAGE_FOLDER \ + --video_folder $VIDEO_FOLDER \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir ./work_dirs/$MID_RUN_NAME \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 500 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 2 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --frames_upbound 64 \ + --mm_newline_position grid \ + --add_time_instruction True \ + --force_sample True \ + --mm_spatial_pool_stride 2 +exit 0; \ No newline at end of file diff --git a/scripts/video/train/exp.yaml b/scripts/video/train/exp.yaml new file mode 100644 index 000000000..a0f5cb59c --- /dev/null +++ b/scripts/video/train/exp.yaml @@ -0,0 +1,263 @@ +datasets: + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_next_fit_mix_filtered_text_wild_738590.json + sampling_strategy: "first:50%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_wild_4v_39k.json + # sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_wild_4v_12k.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_metagen_87358.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_rule_geo_100000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/cambrian_filtered_gpt4vo_sp_token_fltd_max10k_checked.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/VisualWebInstruct_filtered_263589.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/visual_chat_en_26048_gpt4o_coco_checked.json + sampling_strategy: "all" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/gpt4o_combinations_51316.json + # sampling_strategy: "all" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/chrome_writting_train_8835.json + # sampling_strategy: "first:20%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/k12_printing_train_256646.json + # sampling_strategy: "first:1%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/iiit5k_annotations_2000.json + # sampling_strategy: "first:20%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/hme100k_train_clean_74502.json + # sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sroie_data_33626.json + # sampling_strategy: "first:1%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/orand_car_a_train_2009.json + # sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/orand_car_b_train_3000.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llavar_gpt4_20k.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_azuregpt_detailed_understanding_4874.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_vqa_4404.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_azuregpt4v_1992.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_chart_1787.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_normal_gpt4v_filtered_10500.json + sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/scienceqa_nona_context_19218.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_vflan4v_20000.json + sampling_strategy: "first:30%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_laion4v_50000.json + sampling_strategy: "first:30%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textocr_gpt4v_train_converted_25114.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_train_internvl_single_12413.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textcaps_train_21952.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_qa_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_cap_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_ie_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_kg_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/vision_flan_filtered_186070.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mathqa_29837.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo3k_2101.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_qa_converted_67833.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_align_converted_60252.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4o_dataset.jsonl + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-coco-50k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-knowledge-2k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-llava-30k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-sam-20k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_CLEVR-Math_5290.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_FigureQA_17597.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Geometry3K_9734.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GeoQA+_17172.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GEOS_508.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_IconQA_22599.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_MapQA_5235.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PlotQA_5485.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PMC-VQA_35958.json + sampling_strategy: "first:1%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Super-CLEVR_8652.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TabMWP_22462.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TQA_10181.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_UniGeo_11959.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VizWiz_6614.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VQA-AS_5907.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VQA-RAD_2130.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_qwen2_72b_st_300000_sp_token_fltd_299992.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_st_300000.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_mt_300000_sp_token_fltd_299998.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/image_textualization_dataset_filtered.json + sampling_strategy: "first:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/ai2d_llava_format_2434.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/chart2text_26961.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/chartqa_18265_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/diagram_image_to_text_300.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/hateful_memes_8500_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/hitab_2500_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/iam_5663.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/infographic_vqa_2118_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/intergps_1280_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/mapqa_37417_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/rendered_text_10000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_sqa_8514.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wikisql_74989.json + sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wtq_38246_llava_format_filtered_4000tokens_38236.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/screen2words_15730.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tabmwp_22722.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tallyqa_98680_llava_format.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/st_vqa_17247_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tqa_llava_format_27307.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/visual7w_llava_format_14366.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/visualmrc_3027.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vqarad_313_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vsr_2157_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vistext_9969.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/websight_10000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_DEMON-FULL_filtered_311085.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_mantis-instruct_reformatted.json + sampling_strategy: all + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/academic_source_30s_v1_all.json + # sampling_strategy: all + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/0718_0_30_s_academic_mc_v0_1_all.json + # sampling_strategy: all + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4video_255000.json + # sampling_strategy: all + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_academic_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_youtube_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_academic_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_youtube_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_academic_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_youtube_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_academic_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_academic_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_youtube_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_youtube_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_activitynetqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_nextqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_nextqa_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_perceptiontest_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_academic_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_academic_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_youtube_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_youtube_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_activitynetqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_nextqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_nextqa_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_perceptiontest_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_academic_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_academic_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_youtube_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_youtube_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_activitynetqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_nextqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_nextqa_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/sharegptvideo_qa_255k.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_academic_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_youtube_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_academic_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_academic_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_youtube_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_youtube_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_nextqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_nextqa_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_activitynetqa_oe_qa.json + sampling_strategy: "all" diff --git a/scripts/zero2.json b/scripts/zero2.json new file mode 100755 index 000000000..b5ba7ebea --- /dev/null +++ b/scripts/zero2.json @@ -0,0 +1,41 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": false, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/scripts/zero2_fused_adamw.json b/scripts/zero2_fused_adamw.json new file mode 100755 index 000000000..0c37172f6 --- /dev/null +++ b/scripts/zero2_fused_adamw.json @@ -0,0 +1,41 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/scripts/zero2_offload.json b/scripts/zero2_offload.json new file mode 100755 index 000000000..2d24e895b --- /dev/null +++ b/scripts/zero2_offload.json @@ -0,0 +1,31 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "train_micro_batch_size_per_gpu": "auto", + "train_batch_size": "auto", + "gradient_accumulation_steps": "auto", + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto" + } +} \ No newline at end of file diff --git a/scripts/zero3.json b/scripts/zero3.json new file mode 100755 index 000000000..02d343165 --- /dev/null +++ b/scripts/zero3.json @@ -0,0 +1,41 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "offload_param": { + "device": "none", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/scripts/zero3_offload.json b/scripts/zero3_offload.json new file mode 100755 index 000000000..9da12de56 --- /dev/null +++ b/scripts/zero3_offload.json @@ -0,0 +1,48 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "steps_per_print": 1e5, + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/scripts/zero3pp.json b/scripts/zero3pp.json new file mode 100755 index 000000000..f7ca45f95 --- /dev/null +++ b/scripts/zero3pp.json @@ -0,0 +1,53 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "offload_param": { + "device": "none", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "zero_quantized_weights": true, + "zero_hpz_partition_size": 16, + "zero_quantized_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/trl/__init__.py b/trl/__init__.py new file mode 100755 index 000000000..c9a1a6b85 --- /dev/null +++ b/trl/__init__.py @@ -0,0 +1,44 @@ +# flake8: noqa + +__version__ = "0.7.11.dev0" + +from .core import set_seed +from .environment import TextEnvironment, TextHistory +from .extras import BestOfNSampler +from .import_utils import ( + is_bitsandbytes_available, + is_diffusers_available, + is_npu_available, + is_peft_available, + is_wandb_available, + is_xpu_available, +) +from .models import ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PreTrainedModelWrapper, + create_reference_model, + setup_chat_format, +) +from .trainer import ( + DataCollatorForCompletionOnlyLM, + DPOTrainer, + IterativeSFTTrainer, + ModelConfig, + PPOConfig, + PPOTrainer, + RewardConfig, + RewardTrainer, + SFTTrainer, +) +from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config + + +if is_diffusers_available(): + from .models import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) + from .trainer import DDPOConfig, DDPOTrainer diff --git a/trl/core.py b/trl/core.py new file mode 100755 index 000000000..4cf481e0f --- /dev/null +++ b/trl/core.py @@ -0,0 +1,329 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import random +import warnings +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +# from transformers import top_k_top_p_filtering + +from .import_utils import is_npu_available, is_xpu_available + + +try: + from collections.abc import Mapping +except ImportError: + from collections import Mapping + + +WANDB_PADDING = -1 + + +def top_k_top_p_filtering( + logits: torch.FloatTensor, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +) -> torch.FloatTensor: + """ + Filter a distribution of logits using top-k and/or nucleus (top-p) filtering. + + Args: + logits: logits distribution shape (batch size, vocabulary size) + top_k (`int`, *optional*, defaults to 0): + If > 0, only keep the top k tokens with highest probability (top-k filtering) + top_p (`float`, *optional*, defaults to 1.0): + If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus + filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimumber of tokens we keep per batch example in the output. + + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + + if top_k > 0: + logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(None, logits) + + if 0 <= top_p <= 1.0: + logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(None, logits) + + return logits + + +def flatten_dict(nested: Dict, sep: str = "/") -> Dict: + """Flatten dictionary and concatenate nested keys with separator.""" + + def recurse(nest: Dict, prefix: str, into: Dict) -> None: + for k, v in nest.items(): + if sep in k: + raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") + if isinstance(v, Mapping): + recurse(v, prefix + k + sep, into) + else: + into[prefix + k] = v + + flat = {} + recurse(nested, "", flat) + return flat + + +def convert_to_scalar(stats: Dict) -> Dict: + """ + Converts the stats from a flattened dict to single scalar dicts + """ + tensorboard_stats = {} + for k, v in stats.items(): + # for tensorboard compatibility - arrays and tensors are ignored with tensorboard + # therefore we convert single element tensors to scalars + if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)): + v = v.item() + tensorboard_stats[k] = v + return tensorboard_stats + + +def stack_dicts(stats_dicts: List[Dict]) -> Dict: + """Stack the values of a dict.""" + results = dict() + for k in stats_dicts[0]: + stats_list = [torch.flatten(d[k]) for d in stats_dicts] + results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING) + return results + + +def add_suffix(input_dict: Dict, suffix: str) -> Dict: + """Add suffix to dict keys.""" + return dict((k + suffix, v) for k, v in input_dict.items()) + + +def pad_to_size(tensor: torch.Tensor, size: int, dim: int = 1, padding: int = 50256) -> torch.Tensor: + """Pad tensor to size.""" + t_size = tensor.size()[dim] + if t_size == size: + return tensor + else: + return torch.nn.functional.pad(tensor, (0, size - t_size), "constant", padding) + + +def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor: + """ + See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 + """ + logp = F.log_softmax(logits, dim=2) + + if not gather: + return logp + logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1) + return logpy + + +def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values.""" + mean, var = torch.mean(values), torch.var(values) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError("The sum of the mask is zero, which can happen when `mini_batch_size=1`;" "try increase the `mini_batch_size` or `gradient_accumulation_steps`") + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor: + """ + Tensor extension to torch.clamp + https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713 + """ + clipped = torch.max(torch.min(x, tensor_max), tensor_min) + return clipped + + +def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor: + """Calculate entropy from logits.""" + pd = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1) + return entropy + + +def average_torch_dicts(list_of_dicts: List[Dict]) -> Dict: + """Average values of a list of dicts with torch tensors.""" + average_dict = dict() + for key in list_of_dicts[0].keys(): + average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0) + return average_dict + + +def stats_to_np(stats_dict: Dict) -> Dict: + """Cast all torch.tensors in dict to numpy arrays.""" + new_dict = dict() + for k, v in stats_dict.items(): + if isinstance(v, torch.Tensor): + new_dict[k] = v.detach().cpu() + if new_dict[k].dtype == torch.bfloat16: + new_dict[k] = new_dict[k].float() + new_dict[k] = new_dict[k].numpy() + else: + new_dict[k] = v + if np.isscalar(new_dict[k]): + new_dict[k] = float(new_dict[k]) + return new_dict + + +def respond_to_batch(model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0) -> torch.LongTensor: + """Sample text from language model.""" + input_ids = queries + for i in range(txt_len): + # Get Logits + outputs = model(input_ids) + next_token_logits = outputs[0][:, -1, :] + next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + # Sample + probs = F.softmax(next_token_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(1) + input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1) + return input_ids[:, -txt_len:] + + +def set_seed(seed: int) -> None: + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`. + + Args: + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if is_xpu_available(): + torch.xpu.manual_seed_all(seed) + elif is_npu_available(): + torch.npu.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) + + +class LengthSampler: + """ + Samples a length + """ + + def __init__(self, min_value: int, max_value: int): + self.values = list(range(min_value, max_value)) + + def __call__(self) -> int: + return np.random.choice(self.values) + + +class PPODecorators(object): + optimize_device_cache = False + + @classmethod + @contextmanager + def empty_device_cache(cls): + yield + if cls.optimize_device_cache: + if is_xpu_available(): + gc.collect() + torch.xpu.empty_cache() + gc.collect() + elif is_npu_available(): + gc.collect() + torch.npu.empty_cache() + gc.collect() + elif torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List[torch.Generator], torch.Generator]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, +) -> torch.Tensor: + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + warnings.warn( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size)] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py new file mode 100755 index 000000000..ae1cda4ec --- /dev/null +++ b/trl/environment/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base_environment import TextEnvironment, TextHistory diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py new file mode 100755 index 000000000..4c06b1c8d --- /dev/null +++ b/trl/environment/base_environment.py @@ -0,0 +1,463 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import warnings + +import torch +from accelerate.utils import extract_model_from_parallel +from transformers import StoppingCriteria, StoppingCriteriaList + +from ..import_utils import is_rich_available + + +if is_rich_available(): + from rich import print + from rich.text import Text + + +class StringStoppingCriteria(StoppingCriteria): + """Custom `StoppingCriteria` which checks if all generations in the batch are completed.""" + + def __init__(self, stop_strings, tokenizer): + self.stop_strings = stop_strings + self.tokenizer = tokenizer + self.first_call = True + + def __call__(self, input_ids, scores, **kwargs): + """Returns true if all generated sequences contain any of the stop strings.""" + if self.first_call: + self.generated_tokens = [1 for _ in range(input_ids.shape[0])] + self.start_length = input_ids.shape[-1] - 1 + self.first_call = False + decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :]) + done = [] + + for i, decoded_generation in enumerate(decoded_generations): + sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings]) + done.append(sequence_complete) + if not sequence_complete: + self.generated_tokens[i] += 1 + + if all(done): + self.first_call = True + + return all(done) + + +class TextHistory: + """The TextHistory class keeps track of the history of an interaction between the language model and the environment.""" + + def __init__(self, text, tokens, system=True): + """ + Initialize TextHistory. + + args: + text (`str`): The text of the first segment. + tokens (`torch.LongTensor`): The tokens of the first segment. + system (`bool`, *optional*): Whether the first segment is a system or user segment. + """ + self.system_spans = [] + self.text_spans = [] + self.token_spans = [] + self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device) + self.text = "" + self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device) + self.completed = False + self.truncated = False + self.reward = 0.0 + + self.prompt_color = "black on grey85" + self.system_color = "black on cyan3" + self.model_color = "black on deep_sky_blue1" + self.reward_color = "black on plum1" + + self.append_segment(text, tokens, system=system) + + def append_segment(self, text, tokens, system=True): + """ + Append a new segment to the history. + + args: + text (`str`): The text of the new segment. + tokens (`torch.LongTensor`): The tokens of the new segment. + system (`bool`, *optional*): Whether the new segment is a system or user segment. + """ + + if len(text) == 0 or len(tokens) == 0: + raise ValueError("Can't append empty text or token list to history.") + + original_text_length = len(self.text) + + self.text += text + self.text_spans.append((original_text_length, len(self.text))) + self.system_spans.append(system) + + original_token_length = len(self.tokens) + + self.tokens = torch.cat((self.tokens, tokens)) + if system: + self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens))) + else: + self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens))) + self.token_spans.append((original_token_length, len(self.tokens))) + + def complete(self, truncated=False): + """ + Mark the history as completed. + """ + self.completed = True + self.truncated = truncated + + @property + def last_text_segment(self): + """ + Get the last text segment. + """ + start, end = self.text_spans[-1] + return self.text[start:end] + + def split_query_response_tokens(self): + """ + Split the tokens into query and response tokens. + """ + split_index = self.token_spans[0][1] + query = self.tokens[:split_index] + response = self.tokens[split_index:] + mask = self.token_masks[split_index:] + + return query, response, mask + + def show_text(self, show_legend=False): + """ + Print the text history. + """ + if not is_rich_available(): + warnings.warn("install rich to display text") + return + + text = Text(self.text) + text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0]) + for i, (start, end) in enumerate(self.text_spans[1:]): + if self.system_spans[i + 1]: + text.stylize(self.system_color, start, end) + else: + text.stylize(self.model_color, start, end) + + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + + if show_legend: + self.show_colour_legend() + + def show_tokens(self, tokenizer, show_legend=False): + """ + Print the history tokens. + """ + if not is_rich_available(): + warnings.warn("install rich to display tokens") + return + + text = Text() + prompt_end = self.token_spans[0][1] + for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)): + if i < prompt_end: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color) + text.append(" ") + elif mask == 0: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color) + text.append(" ") + else: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color) + text.append(" ") + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + if show_legend: + self.show_colour_legend() + + def show_colour_legend(self): + """ + Print the colour legend. + """ + if not is_rich_available(): + warnings.warn("install rich to display colour legend") + return + text = Text("\n\n(Colour Legend: ") + text.append("Prompt", style=self.prompt_color) + text.append("|") + text.append("System", style=self.system_color) + text.append("|") + text.append("Model", style=self.model_color) + text.append("|") + text.append("Reward", style=self.reward_color) + text.append(")") + print(text) + + +class TextEnvironment: + """ + The TextEnvironment enables interaction of a LLM with an environment using tools. + """ + + def __init__( + self, + model=None, + tokenizer=None, + tools=None, + reward_fn=None, + prompt=None, + max_turns=4, + max_tool_reponse=100, + max_length=None, + generation_kwargs=None, + ): + """ + Initialize TextEnvironment. + + Args: + model (`PreTrainedModelWrapper`): The model to use for generation. + tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation. + tools (list): A list of tools to use for interaction. + reward_fn (function): A function that takes a string and returns a reward. + prompt (str): The base prompt to use for generation. Is prepended to the tasks. + max_turns (Optional[int]): The maximum number of turns to allow. + max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response. + max_length (Optional[int]): The maximum number of tokens to allow in an episode. + generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method. + """ + self.model = model + self.tokenizer = tokenizer + self.prompt = prompt + if isinstance(tools, dict): + self.tools = tools + else: + self.tools = dict([(tool.__class__.__name__, tool) for tool in tools]) + self.reward_fn = reward_fn + self.max_length = max_length + self.request_token = "" + self.call_token = "" + self.response_token = "" + self.submit_token = "" + self.max_turns = max_turns + self.max_tool_response = max_tool_reponse + + if generation_kwargs is None: + self.generation_kwargs = dict() + else: + self.generation_kwargs = generation_kwargs + + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.current_device = extract_model_from_parallel(self.model).pretrained_model.device + + def run(self, queries, **rewards_kwargs): + """ + Run the environment on a list of queries. + + Args: + queries (list[str]): A list of queries to run the model in the environment on. + """ + turns = 0 + + queries = [self.prompt + task for task in queries] + queries_tokens = [self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device) for query in queries] + + histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)] + + while any([not history.completed for history in histories]) and turns < self.max_turns: + histories = self.generate(histories) + histories = self.tasks_end_check(histories) + # TODO: make this parallel rather than for-loop + for i in range(len(histories)): + histories[i] = self.step(histories[i]) + histories = self.tasks_end_check(histories, model_turn=False) + turns += 1 + self.compute_reward(histories, **rewards_kwargs) + + # convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively + queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories])) + + rewards = [history.reward for history in histories] + return queries, responses, masks, rewards, histories + + def step(self, history): + """ + Step the environment forward one turn. + + Args: + history (`TextHistory`): The history to step forward. + """ + truncated, ended = self.task_end_check(history) + if ended: + history.complete(truncated=truncated) + if history.completed: + return history + + tool, query = self.parse_tool_call(history.last_text_segment) + if tool is None or query is None: + response = f"Unknown tool call: {history.last_text_segment}" + else: + if tool not in self.tools: + response = f"Unknown tool {tool}." + try: + response = self.tools[tool](query) + except Exception as error: + response = f"Tool error: {str(error)}" + + if len(response) > self.max_tool_response: + response = response[: (self.max_tool_response - 3)] + "..." + + history.append_segment( + response + self.response_token, + self.tokenizer(response + self.response_token, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device), + system=True, + ) + + return history + + def parse_tool_call(self, text): + """ + Parse request string. Expected format: query + """ + result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL) + + # if we can't find a / span we return none + if result is None: + return None, None + else: + extracted_text = result.group() + + result = re.search(r"<(.*?)>", extracted_text) + + # if we can't find a tool name we return none + if result is None: + return None, None + else: + tool = result.group(1) + + # split off the tool name + query = ">".join(extracted_text.split(">")[1:]) + + return tool, query + + def compute_reward(self, histories, **reward_kwargs): + """ + Compute the reward for a list of histories. + """ + rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs) + for history, reward in zip(histories, rewards): + history.reward = reward + return histories + + def generate(self, histories): + """ + Generate responses for a list of histories. + """ + active_histories = [i for i, history in enumerate(histories) if not history.completed] + + query_tensors = [histories[i].tokens for i in active_histories] + response_tensors = self._generate_batched(query_tensors) + response_texts = self.tokenizer.batch_decode(response_tensors) + + for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): + histories[i].append_segment(response_text, response_tensor, system=False) + + return histories + + def tasks_end_check(self, histories, model_turn=True): + """ + Check if the current generation sequences have finished. + """ + for history in histories: + if not history.completed: + truncated, ended = self.task_end_check(history, model_turn=model_turn) + if ended: + history.complete(truncated=truncated) + return histories + + def task_end_check(self, history, model_turn=True): + """ + Check if the current generation sequence has finished. + """ + truncated = False + ended = False + if history.completed: + return truncated, ended + if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length: + truncated = True + ended = True + elif self.tokenizer.eos_token in history.text: + ended = True + elif model_turn and not ((self.request_token in history.last_text_segment and self.call_token in history.last_text_segment) or self.submit_token in history.last_text_segment): + ended = True + elif self.submit_token in history.last_text_segment: + ended = True + return truncated, ended + + def _generate_batched( + self, + query_tensors, + batch_size: int = 16, + pad_to_multiple_of: int = None, + ): + """ + Generate responses for a list of query tensors. + + args: + query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for. + batch_size (int): The batch size to use for generation. + pad_to_multiple_of (int): The padding length to use for generation. + """ + outputs = [] + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) + + self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) + + generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) + + for generation, mask, generated_tokens in zip(generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + # remove chunk generated after stopping criteria in batch mode + outputs.append(output[:generated_tokens]) + self.tokenizer.padding_side = padding_side_default + return outputs diff --git a/trl/extras/__init__.py b/trl/extras/__init__.py new file mode 100755 index 000000000..6b3035db9 --- /dev/null +++ b/trl/extras/__init__.py @@ -0,0 +1,16 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .best_of_n_sampler import BestOfNSampler diff --git a/trl/extras/best_of_n_sampler.py b/trl/extras/best_of_n_sampler.py new file mode 100755 index 000000000..b64231f73 --- /dev/null +++ b/trl/extras/best_of_n_sampler.py @@ -0,0 +1,113 @@ +from typing import Any, Callable, List, Optional, Union + +import torch +from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast + +from ..core import set_seed +from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper + + +class BestOfNSampler(object): + def __init__( + self, + model: PreTrainedModelWrapper, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + queries_to_scores: Callable[[List[str]], List[float]], + length_sampler: Any, + sample_size: int = 4, + seed: Optional[int] = None, + n_candidates: int = 1, + generation_config: Optional[GenerationConfig] = None, + ) -> None: + r""" + Initialize the sampler for best-of-n generation + + Args: + model (`PreTrainedModelWrapper`): + The pretrained model to use for generation + tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`): + Tokenizer associated with the pretrained model + queries_to_scores (`Callable[[List[str]], List[float]]`): + Callable that takes a list of generated texts and returns the associated reward scores + length_sampler (`Any`): + Sampler used to sample the length of the generated text + sample_size (`int`): + Number of samples to generate for each query + seed (`int`, *optional*): + Random seed used to control generation + n_candidates (`int`): + Number of candidates to return for each query + generation_config (`GenerationConfig`, *optional*): + Generation config passed to the underlying model's `generate` method. + See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details + """ + if seed is not None: + set_seed(seed) + + if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + raise ValueError(f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}") + if not isinstance(model, (SUPPORTED_ARCHITECTURES)): + raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}") + + self.model = model + self.tokenizer = tokenizer + + self.queries_to_scores = queries_to_scores + self.length_sampler = length_sampler + self.gen_config = generation_config + self.sample_size = sample_size + self.n_candidates = n_candidates + + def generate( + self, + tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]], + skip_special_tokens: bool = True, + device: Optional[Union[str, torch.device]] = None, + **generation_kwargs, + ) -> List[List[str]]: + r""" + Generate the best of n samples for input queries + + Args: + tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`): + represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers) + skip_special_tokens (`bool`): + Whether to remove the special tokens from the output + device (`str` or `torch.device`, *optional*): + The device on which the model will be loaded + **generation_kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's `generate` method. + This is used to override generation config + + Returns: + List[List[str]]: A list of lists of generated texts + """ + queries = None + + if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1: + queries = tokenized_query.unsqueeze(0) + elif isinstance(tokenized_query, List): + element_type = type(tokenized_query[0]) + if element_type == int: + queries = torch.tensor(tokenized_query).unsqueeze(0) + elif element_type == torch.Tensor: + queries = [tensor.reshape((1, -1)) for tensor in tokenized_query] + else: + queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query] + + result = [] + + for query in queries: + queries = query.repeat((self.sample_size, 1)) + output = self.model.generate( + queries.to(device), + max_new_tokens=self.length_sampler(), + generation_config=self.gen_config, + **generation_kwargs, + ).squeeze() + output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens) + scores = torch.tensor(self.queries_to_scores(output)) + output = [output[i] for i in scores.topk(self.n_candidates).indices] + result.append(output) + + return result diff --git a/trl/extras/dataset_formatting.py b/trl/extras/dataset_formatting.py new file mode 100755 index 000000000..b9a691bf5 --- /dev/null +++ b/trl/extras/dataset_formatting.py @@ -0,0 +1,86 @@ +import logging +from typing import Callable, Literal, Optional, Union + +from datasets import Dataset, Value +from transformers import AutoTokenizer + +from ..trainer.utils import ConstantLengthDataset + + +FORMAT_MAPPING = { + "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], + "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, +} + + +def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]): + r""" + return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer + apply chat template to the dataset + """ + + def format_dataset(examples): + if isinstance(examples[messages_field][0], list): + output_texts = [] + for i in range(len(examples[messages_field])): + output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False)) + return output_texts + else: + return tokenizer.apply_chat_template(examples[messages_field], tokenize=False) + + return format_dataset + + +def instructions_formatting_function(tokenizer: AutoTokenizer): + r""" + return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer + apply chat template to the dataset + """ + + def format_dataset(examples): + if isinstance(examples["prompt"], list): + output_texts = [] + for i in range(len(examples["prompt"])): + converted_sample = [ + {"role": "user", "content": examples["prompt"][i]}, + {"role": "assistant", "content": examples["completion"][i]}, + ] + output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False)) + return output_texts + else: + converted_sample = [ + {"role": "user", "content": examples["prompt"]}, + {"role": "assistant", "content": examples["completion"]}, + ] + return tokenizer.apply_chat_template(converted_sample, tokenize=False) + + return format_dataset + + +def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]: + r""" + Finds the correct formatting function based on the dataset structure. Currently supported datasets are: + - `ChatML` with [{"role": str, "content": str}] + - `instruction` with [{"prompt": str, "completion": str}] + + Args: + dataset (Dataset): User dataset + tokenizer (AutoTokenizer): Tokenizer used for formatting + + Returns: + Callable: Formatting function if the dataset format is supported else None + """ + if isinstance(dataset, Dataset): + if "messages" in dataset.features: + if dataset.features["messages"] == FORMAT_MAPPING["chatml"]: + logging.info("Formatting dataset with chatml format") + return conversations_formatting_function(tokenizer, "messages") + if "conversations" in dataset.features: + if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]: + logging.info("Formatting dataset with chatml format") + return conversations_formatting_function(tokenizer, "conversations") + elif dataset.features == FORMAT_MAPPING["instruction"]: + logging.info("Formatting dataset with instruction format") + return instructions_formatting_function(tokenizer) + + return None diff --git a/trl/import_utils.py b/trl/import_utils.py new file mode 100755 index 000000000..88a04f7d1 --- /dev/null +++ b/trl/import_utils.py @@ -0,0 +1,108 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import sys + + +if sys.version_info < (3, 8): + _is_python_greater_3_8 = False +else: + _is_python_greater_3_8 = True + + +def is_peft_available() -> bool: + return importlib.util.find_spec("peft") is not None + + +def is_unsloth_available() -> bool: + return importlib.util.find_spec("unsloth") is not None + + +def is_accelerate_greater_20_0() -> bool: + if _is_python_greater_3_8: + from importlib.metadata import version + + accelerate_version = version("accelerate") + else: + import pkg_resources + + accelerate_version = pkg_resources.get_distribution("accelerate").version + return accelerate_version >= "0.20.0" + + +def is_transformers_greater_than(version: str) -> bool: + _transformers_version = importlib.metadata.version("transformers") + return _transformers_version > version + + +def is_torch_greater_2_0() -> bool: + if _is_python_greater_3_8: + from importlib.metadata import version + + torch_version = version("torch") + else: + import pkg_resources + + torch_version = pkg_resources.get_distribution("torch").version + return torch_version >= "2.0" + + +def is_diffusers_available() -> bool: + return importlib.util.find_spec("diffusers") is not None + + +def is_bitsandbytes_available() -> bool: + import torch + + # bnb can be imported without GPU but is not usable. + return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available() + + +def is_torchvision_available() -> bool: + return importlib.util.find_spec("torchvision") is not None + + +def is_rich_available() -> bool: + return importlib.util.find_spec("rich") is not None + + +def is_wandb_available() -> bool: + return importlib.util.find_spec("wandb") is not None + + +def is_xpu_available() -> bool: + if is_accelerate_greater_20_0(): + import accelerate + + return accelerate.utils.is_xpu_available() + else: + if importlib.util.find_spec("intel_extension_for_pytorch") is None: + return False + try: + import torch + + return hasattr(torch, "xpu") and torch.xpu.is_available() + except RuntimeError: + return False + + +def is_npu_available() -> bool: + """Checks if `torch_npu` is installed and potentially if a NPU is in the environment""" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + + import torch + import torch_npu # noqa: F401 + + return hasattr(torch, "npu") and torch.npu.is_available() diff --git a/trl/models/__init__.py b/trl/models/__init__.py new file mode 100755 index 000000000..ec2034553 --- /dev/null +++ b/trl/models/__init__.py @@ -0,0 +1,35 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_base import PreTrainedModelWrapper, create_reference_model +from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead +from .utils import setup_chat_format + + +SUPPORTED_ARCHITECTURES = ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, +) + +from ..import_utils import is_diffusers_available + + +if is_diffusers_available(): + from .modeling_sd_base import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py new file mode 100755 index 000000000..9e6c4fceb --- /dev/null +++ b/trl/models/modeling_base.py @@ -0,0 +1,640 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +import os +from copy import deepcopy + +import torch +import torch.nn as nn +from accelerate import PartialState +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + HFValidationError, + LocalEntryNotFoundError, + RepositoryNotFoundError, +) +from safetensors.torch import load_file as safe_load_file +from transformers import PreTrainedModel + +from ..import_utils import is_npu_available, is_peft_available, is_transformers_greater_than, is_xpu_available + + +if is_peft_available(): + from peft import ( + PeftConfig, + PeftModel, + PeftModelForCausalLM, + PeftModelForSeq2SeqLM, + PromptLearningConfig, + get_peft_model, + prepare_model_for_kbit_training, + ) + +if is_transformers_greater_than("4.33.0"): + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + +LAYER_PATTERNS = [ + "transformer.h.{layer}", + "model.decoder.layers.{layer}", + "gpt_neox.layers.{layer}", + "model.layers.{layer}", +] + + +class PreTrainedModelWrapper(nn.Module): + r""" + A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the + (`~transformers.PreTrained`) class in order to keep some attributes and methods of the + (`~transformers.PreTrainedModel`) class. + + Attributes: + pretrained_model: (`transformers.PreTrainedModel`) + The model to be wrapped. + parent_class: (`transformers.PreTrainedModel`) + The parent class of the model to be wrapped. + supported_args: (`list`) + The list of arguments that are supported by the wrapper class. + """ + + transformers_parent_class = None + supported_args = None + supported_modules = ("v_head",) + supported_rm_modules = ("score",) + supported_pretrained_model_architectures = (PreTrainedModel) if not is_peft_available() else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM) + + def __init__(self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs): + super().__init__() + self.pretrained_model = pretrained_model + + self.config = pretrained_model.config + self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation + self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False) + self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False) + self.is_sequential_parallel = False + + if hasattr(pretrained_model, "gradient_checkpointing_disable"): + self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable + + if hasattr(pretrained_model, "gradient_checkpointing_enable"): + self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable + + self.supports_rm_adapter = supports_rm_adapter + self.rm_adapter_name = rm_adapter_name + self.policy_adapter_name = "default" + if score_module is not None: + self.score = score_module + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Instantiates a new model from a pretrained model from `transformers`. The + pretrained model is loaded using the `from_pretrained` method of the + `transformers.PreTrainedModel` class. The arguments that are specific to the + `transformers.PreTrainedModel` class are passed along this method and filtered + out from the `kwargs` argument. + + + Args: + pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`): + The path to the pretrained model or its name. + *model_args (`list`, *optional*)): + Additional positional arguments passed along to the underlying model's + `from_pretrained` method. + **kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's + `from_pretrained` method. We also pre-process the kwargs to extract + the arguments that are specific to the `transformers.PreTrainedModel` + class and the arguments that are specific to trl models. The kwargs + also support `prepare_model_for_kbit_training` arguments from + `peft` library. + """ + if kwargs is not None: + peft_config = kwargs.pop("peft_config", None) + reward_adapter = kwargs.pop("reward_adapter", None) + reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter") + is_trainable = kwargs.pop("is_trainable", False) + trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs) + token = pretrained_kwargs.get("token", None) + else: + peft_config = None + is_trainable = False + trl_model_args = {} + pretrained_kwargs = {} + peft_quantization_kwargs = {} + token = None + + if reward_adapter is not None and not isinstance(reward_adapter, str): + raise ValueError("The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter.") + + is_peft_model = False + + current_device = cls._get_current_device() + if isinstance(pretrained_model_name_or_path, str): + is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False + is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False + else: + is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False) + is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False) + + if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs: + # warn users + logging.warning( + "The `device_map` argument is not provided. We will override the device_map argument." + " to set the entire" + " model on the current device. If you want to set the model on multiple devices, please provide" + " a custom `device_map` argument." + ) + pretrained_kwargs["device_map"] = {"": current_device} + + if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig): + raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.") + + # First, load the pre-trained model using the parent-class + # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` + if isinstance(pretrained_model_name_or_path, str): + if is_peft_available(): + try: + # If there is a trained peft adapter in the hub, load its config. + remote_adapter_config = hf_hub_download( + pretrained_model_name_or_path, + "adapter_config.json", + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + remote_adapter_config = None + else: + remote_adapter_config = None + + local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json")) + + if (local_adapter_present or remote_adapter_config is not None) and is_peft_available(): + if peft_config is not None: + logging.warning("`peft_config` argument ignored since a peft config file was found in " f"{pretrained_model_name_or_path}") + + # Load the trained peft adapter config + if local_adapter_present: + trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path) + else: + remote_adapter_dir = os.path.dirname(remote_adapter_config) + trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir) + + # Load the pretrained base model + pretrained_model = cls.transformers_parent_class.from_pretrained(trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs) + + # Wrap the pretrained model with the trained peft adapter + pretrained_model = PeftModel.from_pretrained(pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable) + logging.info("Trained peft adapter loaded") + else: + pretrained_model = cls.transformers_parent_class.from_pretrained(pretrained_model_name_or_path, *model_args, **pretrained_kwargs) + + if peft_config is not None: + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + + elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures): + pretrained_model = pretrained_model_name_or_path + + if peft_config is not None and isinstance(pretrained_model, PreTrainedModel): + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + else: + raise ValueError("pretrained_model_name_or_path should be a string or a PreTrainedModel, " f"but is {type(pretrained_model_name_or_path)}") + + if is_peft_available(): + if isinstance(pretrained_model, PeftModel): + is_peft_model = True + # for backward compatibility + if hasattr(pretrained_model, "active_peft_config") and isinstance(pretrained_model.active_peft_config, PromptLearningConfig): + raise ValueError("PromptLearningConfig is not supported for PPO training.") + + # Add reward modeling adapter if specified + if not is_peft_model and reward_adapter is not None: + raise ValueError("reward_adapter can only be used with a PeftModel. ") + elif is_peft_model and reward_adapter is not None: + score_module = cls.add_and_load_reward_modeling_adapter(pretrained_model, reward_adapter, reward_adapter_name, token=token) + multi_adapter_args = { + "score_module": score_module, + "supports_rm_adapter": True, + "rm_adapter_name": reward_adapter_name, + } + else: + multi_adapter_args = {"supports_rm_adapter": False} + + # Then, create the full model by instantiating the wrapper class + model = cls(pretrained_model, **multi_adapter_args, **trl_model_args) + + # if resume_training, load the state_dict again - this is ok since the + # state_dict is removed from the model after loading it. + is_resuming_training = True + if isinstance(pretrained_model_name_or_path, str): + safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors") + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + + sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") + safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") + is_sharded = False + use_safe = os.path.exists(safe_filename) + + if not (os.path.exists(filename) or os.path.exists(safe_filename)): + # Try with `pytorch_model.bin` + filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + sharded_index_filename, + token=token, + ) + # Try with safetensors + if filename is None and files_to_download is None: + safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + safe_sharded_index_filename, + token=token, + model_name="model.safetensors", + model_index_name="model.safetensors.index.json", + ) + use_safe = True + else: + use_safe = False + + loading_func = safe_load_file if use_safe else torch.load + load_kwargs = {} if use_safe else {"map_location": "cpu"} + + if is_resuming_training: + if is_sharded: + # download each file and add it to the state_dict + state_dict = {} + + for shard_file in files_to_download: + filename = hf_hub_download( + pretrained_model_name_or_path, + shard_file, + token=token, + ) + state_dict.update(loading_func(filename, **load_kwargs)) + else: + state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs) + + else: + state_dict = pretrained_model_name_or_path.state_dict() + + model.is_peft_model = is_peft_model + model.current_device = current_device + + if is_resuming_training: + model.post_init(state_dict=state_dict) + + return model + + @classmethod + def _get_checkpoint_from_hub( + cls, + pretrained_model, + pretrained_model_name_or_path, + index_filename, + token=None, + model_name="pytorch_model.bin", + model_index_name="pytorch_model.bin.index.json", + ): + files_to_download = None + filename = None + is_resuming_training = True + is_sharded = False + + try: + filename = hf_hub_download( + pretrained_model_name_or_path, + model_name, + token=token, + ) + # sharded + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + if os.path.exists(index_filename): + index_file_name = index_filename + else: + try: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + model_index_name, + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + # not continue training, do not have v_head weight + is_resuming_training = False + logging.warning(f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " f"and no v_head weight is found. This IS expected if you are not resuming PPO training.") + # load json + if is_resuming_training: + with open(index_file_name, "r") as f: + index = json.load(f) + # check filename with `v_head` or any known extra module: + files_to_download = set() + for k, v in index["weight_map"].items(): + if any([module in k for module in cls.supported_modules]): + files_to_download.add(v) + is_sharded = True + + return filename, files_to_download, is_sharded, is_resuming_training + + @classmethod + def _get_current_device(cls): + r""" + Get the current device. For GPU, we return the local process index using the `accelerate.PartialState` + object to handle corner cases when running scripts in distributed environments. + + Returns: + current_device (`Union[int, str]`): + The current device. + """ + state = PartialState() + if is_xpu_available(): + return f"xpu:{state.local_process_index}" + elif is_npu_available(): + return f"npu:{state.local_process_index}" + else: + return state.local_process_index if torch.cuda.is_available() else "cpu" + + @classmethod + def _split_kwargs(cls, kwargs): + """ + Separate the kwargs from the arguments that we support inside + `supported_args` and the ones that we don't. + """ + check_peft_kwargs = False + + if is_peft_available(): + from peft import prepare_model_for_kbit_training + + check_peft_kwargs = True + + supported_kwargs = {} + unsupported_kwargs = {} + peft_kwargs = {} + + for key, value in kwargs.items(): + if key in cls.supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + + if check_peft_kwargs: + if key in prepare_model_for_kbit_training.__code__.co_varnames: + peft_kwargs[key] = value + if key in unsupported_kwargs: + unsupported_kwargs.pop(key) + + return supported_kwargs, unsupported_kwargs, peft_kwargs + + @classmethod + def add_and_load_reward_modeling_adapter(cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None): + r""" + Add and load a reward modeling adapter. This method can only be used if the + model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id` + argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the + score head in order to produce the reward. + """ + pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False) + pretrained_model.train() + + filename = os.path.join(adapter_model_id, "adapter_model.bin") + safe_loading = False + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.bin", + token=token, + ) + except: # noqa + filename = os.path.join(adapter_model_id, "adapter_model.safetensors") + safe_loading = True + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.safetensors", + token=token, + ) + except: # noqa + raise ValueError("Could not find adapter model in the Hub, make sure you have the correct adapter model id.") + else: + local_filename = filename + else: + local_filename = filename + + loading_func = safe_load_file if safe_loading else torch.load + load_kwargs = {} if safe_loading else {"map_location": "cpu"} + + adapter_state_dict = loading_func(local_filename, **load_kwargs) + + for score_name_candidate in cls.supported_rm_modules: + if any([score_name_candidate in name for name in adapter_state_dict.keys()]): + score_name = score_name_candidate + # we have found the correct head name and can break + break + + score_dict = {} + + for name, param in adapter_state_dict.items(): + if score_name in name: + key_name = ".".join(name.split(".")[-1:]) + score_dict[key_name] = param.to(cls._get_current_device()) + + num_labels, hidden_dim = score_dict["weight"].shape + has_bias = any(["bias" in name for name in adapter_state_dict.keys()]) + + score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to( + device=cls._get_current_device(), + dtype=pretrained_model.dtype, + ) + score.load_state_dict(score_dict) + for param in score.parameters(): + param.requires_grad = False + + return score + + def push_to_hub(self, *args, **kwargs): + r""" + Push the pretrained model to the hub. This method is a wrapper around + `transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation + of `transformers.PreTrainedModel.push_to_hub` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `push_to_hub` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `push_to_hub` method. + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + r""" + Save the pretrained model to a directory. This method is a wrapper around + `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation + of `transformers.PreTrainedModel.save_pretrained` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `save_pretrained` method. + """ + state_dict = kwargs.get("state_dict") + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + # if it is a peft model only save the `v_head` state_dict and + # pop the `state_dict` from the kwargs to avoid slient bugs with `peft` + if self.is_peft_model: + save_path = args[0] + save_path = os.path.join(save_path, "pytorch_model.bin") + torch.save(state_dict, save_path) + _ = kwargs.pop("state_dict", None) + + return self.pretrained_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Return the state_dict of the pretrained model. + """ + raise NotImplementedError + + def post_init(self, *args, **kwargs): + r""" + Post initialization method. This method is called after the model is + instantiated and loaded from a checkpoint. It can be used to perform + additional operations such as loading the state_dict. + """ + raise NotImplementedError + + def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): + r""" + Computes the reward score for a given input. The method has first to enable the adapter + and then compute the reward score. After that the model disables the reward modeling + adapter and enables the default ppo adapter again. + """ + if not self.supports_rm_adapter: + raise ValueError("This model does not support reward modeling adapter.") + + # enable rm adapter + self.pretrained_model.set_adapter(self.rm_adapter_name) + self.pretrained_model.eval() + + with torch.no_grad(): + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + + last_hidden_states = base_model_output.hidden_states[-1] + scores = self.score(last_hidden_states) + + self.pretrained_model.set_adapter(self.policy_adapter_name) + self.pretrained_model.eval() + + return scores + + +def create_reference_model(model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None) -> PreTrainedModelWrapper: + """ + Creates a static reference copy of a model. Note that model will be in `.eval()` mode. + + Args: + model (`PreTrainedModelWrapper`): The model to be copied. + num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and kept frozen. + pattern (`str`, *optional*): The shared layers are selected with a string pattern + (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here. + + Returns + `PreTrainedModelWrapper` + """ + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoCausalLM.from_pretrained()`.") + + parameter_names = [n for n, _ in model.named_parameters()] + ref_model = deepcopy(model) + + # if no layers are shared, return copy of model + if num_shared_layers is None: + for param_name in parameter_names: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + return ref_model.eval() + + # identify layer name pattern + if pattern is not None: + pattern = pattern.format(layer=num_shared_layers) + else: + for pattern_candidate in LAYER_PATTERNS: + pattern_candidate = pattern_candidate.format(layer=num_shared_layers) + if any([pattern_candidate in name for name in parameter_names]): + pattern = pattern_candidate + break + + if pattern is None: + raise ValueError("Layer pattern could not be matched.") + + # divide parameters in shared and unshared parameter lists + shared_param_list = [] + unshared_param_list = [] + + shared_parameter = True + for name, param in model.named_parameters(): + if pattern in name: + shared_parameter = False + if shared_parameter: + shared_param_list.append(name) + else: + unshared_param_list.append(name) + + # create reference of the original parameter if they are shared + for param_name in shared_param_list: + param = model.get_parameter(param_name) + param.requires_grad = False + + ref_param = ref_model.get_parameter(param_name) # noqa + ref_param = param # noqa + + # for all other parameters just make sure they don't use gradients + for param_name in unshared_param_list: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + + if pattern is not None and len(unshared_param_list) == 0: + logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.") + + return ref_model.eval() diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py new file mode 100755 index 000000000..5cf6a1d68 --- /dev/null +++ b/trl/models/modeling_sd_base.py @@ -0,0 +1,624 @@ +# Copyright 2023 DDPO-pytorch authors (Kevin Black), The HuggingFace Team, metric-space. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg +from diffusers.utils import convert_state_dict_to_diffusers + +from ..core import randn_tensor +from ..import_utils import is_peft_available + + +if is_peft_available(): + from peft import LoraConfig + from peft.utils import get_peft_model_state_dict + + +@dataclass +class DDPOPipelineOutput(object): + """ + Output class for the diffusers pipeline to be finetuned with the DDPO trainer + + Args: + images (`torch.Tensor`): + The generated images. + latents (`List[torch.Tensor]`): + The latents used to generate the images. + log_probs (`List[torch.Tensor]`): + The log probabilities of the latents. + + """ + + images: torch.Tensor + latents: torch.Tensor + log_probs: torch.Tensor + + +@dataclass +class DDPOSchedulerOutput(object): + """ + Output class for the diffusers scheduler to be finetuned with the DDPO trainer + + Args: + latents (`torch.Tensor`): + Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)` + log_probs (`torch.Tensor`): + Log probability of the above mentioned sample. Shape: `(batch_size)` + """ + + latents: torch.Tensor + log_probs: torch.Tensor + + +class DDPOStableDiffusionPipeline(object): + """ + Main class for the diffusers pipeline to be finetuned with the DDPO trainer + """ + + def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: + raise NotImplementedError + + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: + raise NotImplementedError + + @property + def unet(self): + """ + Returns the 2d U-Net model used for diffusion. + """ + raise NotImplementedError + + @property + def vae(self): + """ + Returns the Variational Autoencoder model used from mapping images to and from the latent space + """ + raise NotImplementedError + + @property + def tokenizer(self): + """ + Returns the tokenizer used for tokenizing text inputs + """ + raise NotImplementedError + + @property + def scheduler(self): + """ + Returns the scheduler associated with the pipeline used for the diffusion process + """ + raise NotImplementedError + + @property + def text_encoder(self): + """ + Returns the text encoder used for encoding text inputs + """ + raise NotImplementedError + + @property + def autocast(self): + """ + Returns the autocast context manager + """ + raise NotImplementedError + + def set_progress_bar_config(self, *args, **kwargs): + """ + Sets the progress bar config for the pipeline + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + """ + Saves all of the model weights + """ + raise NotImplementedError + + def get_trainable_layers(self, *args, **kwargs): + """ + Returns the trainable parameters of the pipeline + """ + raise NotImplementedError + + def save_checkpoint(self, *args, **kwargs): + """ + Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state + """ + raise NotImplementedError + + def load_checkpoint(self, *args, **kwargs): + """ + Light wrapper around accelerate's register_lad_state_pre_hook which is run before loading state + """ + raise NotImplementedError + + +def _left_broadcast(input_tensor, shape): + """ + As opposed to the default direction of broadcasting (right to left), this function broadcasts + from left to right + Args: + input_tensor (`torch.FloatTensor`): is the tensor to broadcast + shape (`Tuple[int]`): is the shape to broadcast to + """ + input_ndim = input_tensor.ndim + if input_ndim > len(shape): + raise ValueError("The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to") + return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape) + + +def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ).to(timestep.device) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + +def scheduler_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + prev_sample: Optional[torch.FloatTensor] = None, +) -> DDPOSchedulerOutput: + """ + + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + + Returns: + `DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample + """ + + if self.num_inference_steps is None: + raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler") + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + # to prevent OOB on gather + prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ) + alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device) + alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`") + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = _get_variance(self, timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if prev_sample is not None and generator is not None: + raise ValueError("Cannot pass both generator and prev_sample. Please make sure that either `generator` or" " `prev_sample` stays `None`.") + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + prev_sample = prev_sample_mean + std_dev_t * variance_noise + + # log prob of prev_sample given prev_sample_mean and std_dev_t + log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) - torch.log(std_dev_t) - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi))) + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob) + + +# 1. The output type for call is different as the logprobs are now returned +# 2. An extra method called `scheduler_step` is added which is used to constraint the scheduler output +@torch.no_grad() +def pipeline_step( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, +): + r""" + Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + all_latents = [latents] + all_log_probs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta) + latents = scheduler_output.latents + log_prob = scheduler_output.log_probs + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return DDPOPipelineOutput(image, all_latents, all_log_probs) + + +class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline): + def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True): + self.sd_pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name, revision=pretrained_model_revision) + + self.use_lora = use_lora + self.pretrained_model = pretrained_model_name + self.pretrained_revision = pretrained_model_revision + + try: + self.sd_pipeline.load_lora_weights( + pretrained_model_name, + weight_name="pytorch_lora_weights.safetensors", + revision=pretrained_model_revision, + ) + self.use_lora = True + except OSError: + if use_lora: + warnings.warn("If you are aware that the pretrained model has no lora weights to it, ignore this message. " "Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder.") + + self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config) + self.sd_pipeline.safety_checker = None + + # memory optimization + self.sd_pipeline.vae.requires_grad_(False) + self.sd_pipeline.text_encoder.requires_grad_(False) + self.sd_pipeline.unet.requires_grad_(not self.use_lora) + + def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: + return pipeline_step(self.sd_pipeline, *args, **kwargs) + + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: + return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs) + + @property + def unet(self): + return self.sd_pipeline.unet + + @property + def vae(self): + return self.sd_pipeline.vae + + @property + def tokenizer(self): + return self.sd_pipeline.tokenizer + + @property + def scheduler(self): + return self.sd_pipeline.scheduler + + @property + def text_encoder(self): + return self.sd_pipeline.text_encoder + + @property + def autocast(self): + return contextlib.nullcontext if self.use_lora else None + + def save_pretrained(self, output_dir): + if self.use_lora: + state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.sd_pipeline.unet)) + self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict) + self.sd_pipeline.save_pretrained(output_dir) + + def set_progress_bar_config(self, *args, **kwargs): + self.sd_pipeline.set_progress_bar_config(*args, **kwargs) + + def get_trainable_layers(self): + if self.use_lora: + lora_config = LoraConfig( + r=4, + lora_alpha=4, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + self.sd_pipeline.unet.add_adapter(lora_config) + + # To avoid accelerate unscaling problems in FP16. + for param in self.sd_pipeline.unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + return self.sd_pipeline.unet + else: + return self.sd_pipeline.unet + + def save_checkpoint(self, models, weights, output_dir): + if len(models) != 1: + raise ValueError("Given how the trainable params were set, this should be of length 1") + if self.use_lora and hasattr(models[0], "peft_config") and getattr(models[0], "peft_config", None) is not None: + state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(models[0])) + self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict) + elif not self.use_lora and isinstance(models[0], UNet2DConditionModel): + models[0].save_pretrained(os.path.join(output_dir, "unet")) + else: + raise ValueError(f"Unknown model type {type(models[0])}") + + def load_checkpoint(self, models, input_dir): + if len(models) != 1: + raise ValueError("Given how the trainable params were set, this should be of length 1") + if self.use_lora: + lora_state_dict, network_alphas = self.sd_pipeline.lora_state_dict(input_dir, weight_name="pytorch_lora_weights.safetensors") + self.sd_pipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=models[0]) + + elif not self.use_lora and isinstance(models[0], UNet2DConditionModel): + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + models[0].register_to_config(**load_model.config) + models[0].load_state_dict(load_model.state_dict()) + del load_model + else: + raise ValueError(f"Unknown model type {type(models[0])}") diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py new file mode 100755 index 000000000..f2ca25ee3 --- /dev/null +++ b/trl/models/modeling_value_head.py @@ -0,0 +1,421 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from .modeling_base import PreTrainedModelWrapper + + +class ValueHead(nn.Module): + r""" + The ValueHead class implements a head for GPT2 that returns a scalar for each output token. + """ + + def __init__(self, config, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + + self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + + # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m + if hasattr(config, "hidden_size"): + hidden_size = config.hidden_size + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + elif hasattr(config, "is_encoder_decoder"): + if config.is_encoder_decoder and hasattr(config, "decoder"): + if hasattr(config.decoder, "hidden_size"): + hidden_size = config.decoder.hidden_size + + self.summary = nn.Linear(hidden_size, 1) + + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + + # For now force upcast in fp32 if needed. Let's keep the + # output in fp32 for numerical stability. + if output.dtype != self.summary.weight.dtype: + output = output.to(self.summary.weight.dtype) + + output = self.summary(output) + return output + + +class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): + r""" + An autoregressive model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped + model, simply manipulate the `pretrained_model` attribute of this class. + + Class attributes: + - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This + should be set to `transformers.AutoModelForCausalLM` for this class. + - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the + wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models + in the future + - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported + by the `ValueHead` class. Currently, the supported args are: + - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the + `ValueHead` class. + - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the + `ValueHead` if a specific initialization strategy is selected. + - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the + `ValueHead`. Currently, the supported strategies are: + - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default + strategy. + - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution. + + """ + + transformers_parent_class = AutoModelForCausalLM + lm_head_namings = ["lm_head", "embed_out"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + r""" + Initializes the model. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForCausalLM` class. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. + """ + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + + if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings): + raise ValueError("The model does not have a language model head, please use a model that has one.") + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _init_weights(self, **kwargs): + r""" + Initializes the weights of the value head. The default initialization strategy is random. + Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument + when calling `.from_pretrained`. Supported strategies are: + - `normal`: initializes the weights with a normal distribution. + + Args: + **kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. These arguments + can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range` + argument. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + r""" + Applies a forward pass to the wrapped model and returns the logits of the value head. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the wrapped model. + """ + kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples + kwargs["past_key_values"] = past_key_values + + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = base_model_output.hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + if last_hidden_state.device != self.v_head.summary.weight.device: + last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + A simple wrapper around the `generate` method of the wrapped model. + Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) + method of the wrapped model for more information about the supported arguments. + + Args: + *args (`list`, *optional*): + Positional arguments passed to the `generate` method of the wrapped model. + **kwargs (`dict`, *optional*): + Keyword arguments passed to the `generate` method of the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + setattr(self.pretrained_model, "v_head", self.v_head) + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if "cpu" in self.pretrained_model.hf_device_map.values() or "disk" in self.pretrained_model.hf_device_map.values(): + raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.") + + first_device = list(set(self.pretrained_model.hf_device_map.values()))[0] + + self.v_head = self.v_head.to(first_device) + + def set_device_hook(module, input, outputs): + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(first_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + + self.is_sequential_parallel = True + + +class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): + r""" + A seq2seq model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained` and `push_to_hub` and also provides some additional + functionalities such as `generate`. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForSeq2SeqLM` class. + kwargs: + Additional keyword arguments passed along to the `ValueHead` class. + """ + + transformers_parent_class = AutoModelForSeq2SeqLM + lm_head_namings = ["lm_head", "embed_out", "output_projection"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + self.is_encoder_decoder = True + + if not self._has_lm_head(): + raise ValueError("The model does not have a language model head, please use a model that has one.") + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _has_lm_head(self): + # check module names of all modules inside `pretrained_model` to find the language model head + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + return True + return False + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if "cpu" in self.pretrained_model.hf_device_map.values() or "disk" in self.pretrained_model.hf_device_map.values(): + raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.") + + # get the lm_head device + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + lm_head_device = module.weight.device + break + + # put v_head on the same device as the lm_head to avoid issues + self.v_head = self.v_head.to(lm_head_device) + + def set_device_hook(module, input, outputs): + r""" + A hook that sets the device of the output of the model to the device of the first + parameter of the model. + + Args: + module (`nn.Module`): + The module to which the hook is attached. + input (`tuple`): + The input to the module. + outputs (`tuple`): + The output of the module. + """ + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(lm_head_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + self.is_sequential_parallel = True + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + setattr(self.pretrained_model, "v_head", self.v_head) + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def _init_weights(self, **kwargs): + r""" + We initialize the weights of the value head. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + kwargs["past_key_values"] = past_key_values + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, # We force the model to output hidden states + **kwargs, + ) + + last_hidden_state = base_model_output.decoder_hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + We call `generate` on the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) diff --git a/trl/models/utils.py b/trl/models/utils.py new file mode 100755 index 000000000..d9caf757f --- /dev/null +++ b/trl/models/utils.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass +from typing import Literal, Optional, Tuple + +from transformers import PreTrainedModel, PreTrainedTokenizer + + +# TODO: Add Abstract Base Class if more formats are added +@dataclass +class ChatMlSpecialTokens: + """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens.""" + + bos_token: str = "<|im_start|>" + eos_token: str = "<|im_end|>" + pad_token: str = "<|im_end|>" + + @property + def system(self): + return f"{self.bos_token}system" + + @property + def user(self): + return f"{self.bos_token}user" + + @property + def assistant(self): + return f"{self.bos_token}assistant" + + @property + def chat_template(self): + return ( + "{% for message in messages %}" + f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}" + "{% endfor %}" + "{% if add_generation_prompt %}" + f"{{{{ '{self.assistant}\n' }}}}" + "{% endif %}" + ) + + +FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens} + + +def setup_chat_format( + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + format: Optional[Literal["chatml"]] = "chatml", + resize_to_multiple_of: Optional[int] = None, +) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: + """ + Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. + + Args: + model (`~transformers.PreTrainedModel`): The model to be modified. + tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified. + format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml". + resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None. + Returns: + model (`~transformers.PreTrainedModel`): The modified model. + tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer. + """ + # check if format available and retrieve + if format not in FORMAT_MAPPING: + raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}") + + chat_format = FORMAT_MAPPING[format]() + + # set special tokens and them + tokenizer.eos_token = chat_format.eos_token + tokenizer.pad_token = chat_format.pad_token + tokenizer.bos_token = chat_format.bos_token + tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]}) + # set chat format for tokenizer + tokenizer.chat_template = chat_format.chat_template + + # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377 + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None) + # Make sure to update the generation config to use the new eos & bos token + if getattr(model, "generation_config", None) is not None: + model.generation_config.bos_token_id = tokenizer.bos_token_id + model.generation_config.eos_token_id = tokenizer.eos_token_id + model.generation_config.pad_token_id = tokenizer.pad_token_id + + return model, tokenizer diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py new file mode 100755 index 000000000..d35460196 --- /dev/null +++ b/trl/trainer/__init__.py @@ -0,0 +1,46 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# There is a circular import in the PPOTrainer if we let isort sort these +# isort: off +from .utils import ( + AdaptiveKLController, + FixedKLController, + ConstantLengthDataset, + DataCollatorForCompletionOnlyLM, + RunningMoments, + disable_dropout_in_model, + peft_module_casting_to_bf16, +) + +# isort: on + +from ..import_utils import is_diffusers_available +from .base import BaseTrainer +from .ddpo_config import DDPOConfig + + +if is_diffusers_available(): + from .ddpo_trainer import DDPOTrainer + +from .dpo_trainer import DPOTrainer +from .iterative_sft_trainer import IterativeSFTTrainer +from .model_config import ModelConfig +from .ppo_config import PPOConfig +from .ppo_trainer import PPOTrainer +from .reward_config import RewardConfig +from .reward_trainer import RewardTrainer, compute_accuracy +from .sft_trainer import SFTTrainer diff --git a/trl/trainer/base.py b/trl/trainer/base.py new file mode 100755 index 000000000..f0314cb98 --- /dev/null +++ b/trl/trainer/base.py @@ -0,0 +1,46 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub import PyTorchModelHubMixin + + +class BaseTrainer(PyTorchModelHubMixin): + r""" + Base class for all trainers - this base class implements the basic functions that we + need for a trainer. + + The trainer needs to have the following functions: + - step: takes in a batch of data and performs a step of training + - loss: takes in a batch of data and returns the loss + - compute_rewards: takes in a batch of data and returns the rewards + - _build_models_and_tokenizer: builds the models and tokenizer + - _build_dataset: builds the dataset + Each user is expected to implement their own trainer class that inherits from this base + if they want to use a new training algorithm. + """ + + def __init__(self, config): + self.config = config + + def step(self, *args): + raise NotImplementedError("Not implemented") + + def loss(self, *args): + raise NotImplementedError("Not implemented") + + def compute_rewards(self, *args): + raise NotImplementedError("Not implemented") + + def _save_pretrained(self, save_directory): + raise NotImplementedError("Not implemented") diff --git a/trl/trainer/ddpo_config.py b/trl/trainer/ddpo_config.py new file mode 100755 index 000000000..2c910950f --- /dev/null +++ b/trl/trainer/ddpo_config.py @@ -0,0 +1,115 @@ +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Literal, Optional + +from ..core import flatten_dict +from ..import_utils import is_bitsandbytes_available, is_torchvision_available + + +@dataclass +class DDPOConfig: + """ + Configuration class for DDPOTrainer + """ + + # common parameters + exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] + """the name of this experiment (by default is the file name without the extension name)""" + run_name: Optional[str] = "" + """Run name for wandb logging and checkpoint saving.""" + seed: int = 0 + """Seed value for random generations""" + log_with: Optional[Literal["wandb", "tensorboard"]] = None + """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" + tracker_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. wandb_project)""" + accelerator_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the accelerator""" + project_kwargs: dict = field(default_factory=dict) + """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" + tracker_project_name: str = "trl" + """Name of project to use for tracking""" + logdir: str = "logs" + """Top-level logging directory for checkpoint saving.""" + + # hyperparameters + num_epochs: int = 100 + """Number of epochs to train.""" + save_freq: int = 1 + """Number of epochs between saving model checkpoints.""" + num_checkpoint_limit: int = 5 + """Number of checkpoints to keep before overwriting old ones.""" + mixed_precision: str = "fp16" + """Mixed precision training.""" + allow_tf32: bool = True + """Allow tf32 on Ampere GPUs.""" + resume_from: Optional[str] = "" + """Resume training from a checkpoint.""" + sample_num_steps: int = 50 + """Number of sampler inference steps.""" + sample_eta: float = 1.0 + """Eta parameter for the DDIM sampler.""" + sample_guidance_scale: float = 5.0 + """Classifier-free guidance weight.""" + sample_batch_size: int = 1 + """Batch size (per GPU!) to use for sampling.""" + sample_num_batches_per_epoch: int = 2 + """Number of batches to sample per epoch.""" + train_batch_size: int = 1 + """Batch size (per GPU!) to use for training.""" + train_use_8bit_adam: bool = False + """Whether to use the 8bit Adam optimizer from bitsandbytes.""" + train_learning_rate: float = 3e-4 + """Learning rate.""" + train_adam_beta1: float = 0.9 + """Adam beta1.""" + train_adam_beta2: float = 0.999 + """Adam beta2.""" + train_adam_weight_decay: float = 1e-4 + """Adam weight decay.""" + train_adam_epsilon: float = 1e-8 + """Adam epsilon.""" + train_gradient_accumulation_steps: int = 1 + """Number of gradient accumulation steps.""" + train_max_grad_norm: float = 1.0 + """Maximum gradient norm for gradient clipping.""" + train_num_inner_epochs: int = 1 + """Number of inner epochs per outer epoch.""" + train_cfg: bool = True + """Whether or not to use classifier-free guidance during training.""" + train_adv_clip_max: float = 5 + """Clip advantages to the range.""" + train_clip_range: float = 1e-4 + """The PPO clip range.""" + train_timestep_fraction: float = 1.0 + """The fraction of timesteps to train on.""" + per_prompt_stat_tracking: bool = False + """Whether to track statistics for each prompt separately.""" + per_prompt_stat_tracking_buffer_size: int = 16 + """Number of reward values to store in the buffer for each prompt.""" + per_prompt_stat_tracking_min_count: int = 16 + """The minimum number of reward values to store in the buffer.""" + async_reward_computation: bool = False + """Whether to compute rewards asynchronously.""" + max_workers: int = 2 + """The maximum number of workers to use for async reward computation.""" + negative_prompts: Optional[str] = "" + """Comma-separated list of prompts to use as negative examples.""" + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) + + def __post_init__(self): + if self.log_with not in ["wandb", "tensorboard"]: + warnings.warn(("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'.")) + + if self.log_with == "wandb" and not is_torchvision_available(): + warnings.warn("Wandb image logging requires torchvision to be installed") + + if self.train_use_8bit_adam and not is_bitsandbytes_available(): + raise ImportError("You need to install bitsandbytes to use 8bit Adam. " "You can install it with `pip install bitsandbytes`.") diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py new file mode 100755 index 000000000..f1f897aae --- /dev/null +++ b/trl/trainer/ddpo_trainer.py @@ -0,0 +1,604 @@ +# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +from collections import defaultdict +from concurrent import futures +from typing import Any, Callable, Optional, Tuple +from warnings import warn + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import whoami + +from ..models import DDPOStableDiffusionPipeline +from . import BaseTrainer, DDPOConfig +from .utils import PerPromptStatTracker + + +logger = get_logger(__name__) + + +MODEL_CARD_TEMPLATE = """--- +license: apache-2.0 +tags: +- trl +- ddpo +- diffusers +- reinforcement-learning +- text-to-image +- stable-diffusion +--- + +# {model_name} + +This is a diffusion model that has been fine-tuned with reinforcement learning to + guide the model outputs according to a value, function, or human feedback. The model can be used for image generation conditioned with text. + +""" + + +class DDPOTrainer(BaseTrainer): + """ + The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. + Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch + As of now only Stable Diffusion based pipelines are supported + + Attributes: + **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more + details. + **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used + **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model + **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training. + **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images + """ + + _tag_names = ["trl", "ddpo"] + + def __init__( + self, + config: DDPOConfig, + reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor], + prompt_function: Callable[[], Tuple[str, Any]], + sd_pipeline: DDPOStableDiffusionPipeline, + image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None, + ): + if image_samples_hook is None: + warn("No image_samples_hook provided; no images will be logged") + + self.prompt_fn = prompt_function + self.reward_fn = reward_function + self.config = config + self.image_samples_callback = image_samples_hook + + accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs) + + if self.config.resume_from: + self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from)) + if "checkpoint_" not in os.path.basename(self.config.resume_from): + # get the most recent checkpoint in this directory + checkpoints = list( + filter( + lambda x: "checkpoint_" in x, + os.listdir(self.config.resume_from), + ) + ) + if len(checkpoints) == 0: + raise ValueError(f"No checkpoints found in {self.config.resume_from}") + checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints]) + self.config.resume_from = os.path.join( + self.config.resume_from, + f"checkpoint_{checkpoint_numbers[-1]}", + ) + + accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 + + # number of timesteps within each trajectory to train on + self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction) + + self.accelerator = Accelerator( + log_with=self.config.log_with, + mixed_precision=self.config.mixed_precision, + project_config=accelerator_project_config, + # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the + # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get + # the total number of optimizer steps to accumulate across. + gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps, + **self.config.accelerator_kwargs, + ) + + is_okay, message = self._config_check() + if not is_okay: + raise ValueError(message) + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + + if self.accelerator.is_main_process: + self.accelerator.init_trackers( + self.config.tracker_project_name, + config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + init_kwargs=self.config.tracker_kwargs, + ) + + logger.info(f"\n{config}") + + set_seed(self.config.seed, device_specific=True) + + self.sd_pipeline = sd_pipeline + + self.sd_pipeline.set_progress_bar_config( + position=1, + disable=not self.accelerator.is_local_main_process, + leave=False, + desc="Timestep", + dynamic_ncols=True, + ) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + if self.accelerator.mixed_precision == "fp16": + inference_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + inference_dtype = torch.bfloat16 + else: + inference_dtype = torch.float32 + + self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype) + + trainable_layers = self.sd_pipeline.get_trainable_layers() + + self.accelerator.register_save_state_pre_hook(self._save_model_hook) + self.accelerator.register_load_state_pre_hook(self._load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + self.optimizer = self._setup_optimizer(trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers) + + self.neg_prompt_embed = self.sd_pipeline.text_encoder( + self.sd_pipeline.tokenizer( + [""] if self.config.negative_prompts is None else self.config.negative_prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + )[0] + + if config.per_prompt_stat_tracking: + self.stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking_buffer_size, + config.per_prompt_stat_tracking_min_count, + ) + + # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses + # more memory + self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast + + if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora: + unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters())) + else: + self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + + if self.config.async_reward_computation: + self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers) + + if config.resume_from: + logger.info(f"Resuming from {config.resume_from}") + self.accelerator.load_state(config.resume_from) + self.first_epoch = int(config.resume_from.split("_")[-1]) + 1 + else: + self.first_epoch = 0 + + def compute_rewards(self, prompt_image_pairs, is_async=False): + if not is_async: + rewards = [] + for images, prompts, prompt_metadata in prompt_image_pairs: + reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata) + rewards.append( + ( + torch.as_tensor(reward, device=self.accelerator.device), + reward_metadata, + ) + ) + else: + rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs) + rewards = [(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result()) for reward, reward_metadata in rewards] + + return zip(*rewards) + + def step(self, epoch: int, global_step: int): + """ + Perform a single step of training. + + Args: + epoch (int): The current epoch. + global_step (int): The current global step. + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker. + + Returns: + global_step (int): The updated global step. + + """ + samples, prompt_image_data = self._generate_samples( + iterations=self.config.sample_num_batches_per_epoch, + batch_size=self.config.sample_batch_size, + ) + + # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) + samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} + rewards, rewards_metadata = self.compute_rewards(prompt_image_data, is_async=self.config.async_reward_computation) + + for i, image_data in enumerate(prompt_image_data): + image_data.extend([rewards[i], rewards_metadata[i]]) + + if self.image_samples_callback is not None: + self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0]) + + rewards = torch.cat(rewards) + rewards = self.accelerator.gather(rewards).cpu().numpy() + + self.accelerator.log( + { + "reward": rewards, + "epoch": epoch, + "reward_mean": rewards.mean(), + "reward_std": rewards.std(), + }, + step=global_step, + ) + + if self.config.per_prompt_stat_tracking: + # gather the prompts across processes + prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy() + prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) + advantages = self.stat_tracker.update(prompts, rewards) + else: + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + # ungather advantages; keep the entries corresponding to the samples on this process + samples["advantages"] = torch.as_tensor(advantages).reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index].to(self.accelerator.device) + + del samples["prompt_ids"] + + total_batch_size, num_timesteps = samples["timesteps"].shape + + for inner_epoch in range(self.config.train_num_inner_epochs): + # shuffle samples along batch dimension + perm = torch.randperm(total_batch_size, device=self.accelerator.device) + samples = {k: v[perm] for k, v in samples.items()} + + # shuffle along time dimension independently for each sample + # still trying to understand the code below + perms = torch.stack([torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]) + + for key in ["timesteps", "latents", "next_latents", "log_probs"]: + samples[key] = samples[key][ + torch.arange(total_batch_size, device=self.accelerator.device)[:, None], + perms, + ] + + original_keys = samples.keys() + original_values = samples.values() + # rebatch them as user defined train_batch_size is different from sample_batch_size + reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values] + + # Transpose the list of original values + transposed_values = zip(*reshaped_values) + # Create new dictionaries for each row of transposed values + samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values] + + self.sd_pipeline.unet.train() + global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched) + # ensure optimization step at the end of the inner epoch + if not self.accelerator.sync_gradients: + raise ValueError("Optimization step should have been performed by this point. Please check calculated gradient accumulation settings.") + + if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: + self.accelerator.save_state() + + return global_step + + def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds): + """ + Calculate the loss for a batch of an unpacked sample + + Args: + latents (torch.Tensor): + The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] + timesteps (torch.Tensor): + The timesteps sampled from the diffusion model, shape: [batch_size] + next_latents (torch.Tensor): + The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] + log_probs (torch.Tensor): + The log probabilities of the latents, shape: [batch_size] + advantages (torch.Tensor): + The advantages of the latents, shape: [batch_size] + embeds (torch.Tensor): + The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] + Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds + + Returns: + loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) + (all of these are of shape (1,)) + """ + with self.autocast(): + if self.config.train_cfg: + noise_pred = self.sd_pipeline.unet( + torch.cat([latents] * 2), + torch.cat([timesteps] * 2), + embeds, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = self.sd_pipeline.unet( + latents, + timesteps, + embeds, + ).sample + # compute the log prob of next_latents given latents under the current model + + scheduler_step_output = self.sd_pipeline.scheduler_step( + noise_pred, + timesteps, + latents, + eta=self.config.sample_eta, + prev_sample=next_latents, + ) + + log_prob = scheduler_step_output.log_probs + + advantages = torch.clamp( + advantages, + -self.config.train_adv_clip_max, + self.config.train_adv_clip_max, + ) + + ratio = torch.exp(log_prob - log_probs) + + loss = self.loss(advantages, self.config.train_clip_range, ratio) + + approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2) + + clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float()) + + return loss, approx_kl, clipfrac + + def loss( + self, + advantages: torch.Tensor, + clip_range: float, + ratio: torch.Tensor, + ): + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - clip_range, + 1.0 + clip_range, + ) + return torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + def _setup_optimizer(self, trainable_layers_parameters): + if self.config.train_use_8bit_adam: + import bitsandbytes + + optimizer_cls = bitsandbytes.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + return optimizer_cls( + trainable_layers_parameters, + lr=self.config.train_learning_rate, + betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), + weight_decay=self.config.train_adam_weight_decay, + eps=self.config.train_adam_epsilon, + ) + + def _save_model_hook(self, models, weights, output_dir): + self.sd_pipeline.save_checkpoint(models, weights, output_dir) + weights.pop() # ensures that accelerate doesn't try to handle saving of the model + + def _load_model_hook(self, models, input_dir): + self.sd_pipeline.load_checkpoint(models, input_dir) + models.pop() # ensures that accelerate doesn't try to handle loading of the model + + def _generate_samples(self, iterations, batch_size): + """ + Generate samples from the model + + Args: + iterations (int): Number of iterations to generate samples for + batch_size (int): Batch size to use for sampling + + Returns: + samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]]) + """ + samples = [] + prompt_image_pairs = [] + self.sd_pipeline.unet.eval() + + sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) + + for _ in range(iterations): + prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) + + prompt_ids = self.sd_pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] + + with self.autocast(): + sd_output = self.sd_pipeline( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + output_type="pt", + ) + + images = sd_output.images + latents = sd_output.latents + log_probs = sd_output.log_probs + + latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...) + log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) + timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps) + + samples.append( + { + "prompt_ids": prompt_ids, + "prompt_embeds": prompt_embeds, + "timesteps": timesteps, + "latents": latents[:, :-1], # each entry is the latent before timestep t + "next_latents": latents[:, 1:], # each entry is the latent after timestep t + "log_probs": log_probs, + "negative_prompt_embeds": sample_neg_prompt_embeds, + } + ) + prompt_image_pairs.append([images, prompts, prompt_metadata]) + + return samples, prompt_image_pairs + + def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples): + """ + Train on a batch of samples. Main training segment + + Args: + inner_epoch (int): The current inner epoch + epoch (int): The current epoch + global_step (int): The current global step + batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + + Returns: + global_step (int): The updated global step + """ + info = defaultdict(list) + for i, sample in enumerate(batched_samples): + if self.config.train_cfg: + # concat negative prompts to sample prompts to avoid two forward passes + embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]]) + else: + embeds = sample["prompt_embeds"] + + for j in range(self.num_train_timesteps): + with self.accelerator.accumulate(self.sd_pipeline.unet): + loss, approx_kl, clipfrac = self.calculate_loss( + sample["latents"][:, j], + sample["timesteps"][:, j], + sample["next_latents"][:, j], + sample["log_probs"][:, j], + sample["advantages"], + embeds, + ) + info["approx_kl"].append(approx_kl) + info["clipfrac"].append(clipfrac) + info["loss"].append(loss) + + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.trainable_layers.parameters() if not isinstance(self.trainable_layers, list) else self.trainable_layers, + self.config.train_max_grad_norm, + ) + self.optimizer.step() + self.optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if self.accelerator.sync_gradients: + # log training-related stuff + info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} + info = self.accelerator.reduce(info, reduction="mean") + info.update({"epoch": epoch, "inner_epoch": inner_epoch}) + self.accelerator.log(info, step=global_step) + global_step += 1 + info = defaultdict(list) + return global_step + + def _config_check(self) -> Tuple[bool, str]: + samples_per_epoch = self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch + total_train_batch_size = self.config.train_batch_size * self.accelerator.num_processes * self.config.train_gradient_accumulation_steps + + if not self.config.sample_batch_size >= self.config.train_batch_size: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})", + ) + if not self.config.sample_batch_size % self.config.train_batch_size == 0: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})", + ) + if not samples_per_epoch % total_train_batch_size == 0: + return ( + False, + f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})", + ) + return True, "" + + def train(self, epochs: Optional[int] = None): + """ + Train the model for a given number of epochs + """ + global_step = 0 + if epochs is None: + epochs = self.config.num_epochs + for epoch in range(self.first_epoch, epochs): + global_step = self.step(epoch, global_step) + + def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None: + """Creates and saves a model card for a TRL model. + + Args: + path (`str`): The path to save the model card to. + model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`. + """ + try: + user = whoami()["name"] + # handle the offline case + except: # noqa + warnings.warn("Cannot retrieve user information assuming you are running in offline mode.") + return + + if not os.path.exists(path): + os.makedirs(path) + + model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}") + with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: + f.write(model_card_content) + + def _save_pretrained(self, save_directory): + self.sd_pipeline.save_pretrained(save_directory) + self.create_model_card(save_directory) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py new file mode 100755 index 000000000..ae2df8980 --- /dev/null +++ b/trl/trainer/dpo_trainer.py @@ -0,0 +1,1186 @@ +# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import random +import warnings +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from copy import deepcopy +from functools import wraps +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState +from accelerate.utils import is_deepspeed_available, tqdm +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput + +from ..import_utils import is_peft_available, is_wandb_available +from ..models import PreTrainedModelWrapper, create_reference_model +from .utils import ( + DPODataCollatorWithPadding, + disable_dropout_in_model, + pad_to_length, + peft_module_casting_to_bf16, + trl_sanitze_kwargs_for_tagging, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + +if is_deepspeed_available(): + import deepspeed + +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled + + +class DPOTrainer(Trainer): + r""" + Initialize DPOTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + beta (`float`, defaults to 0.1): + The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper. + label_smoothing (`float`, defaults to 0): + The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5. + loss_type (`str`, defaults to `"sigmoid"`): + The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf). + args (`transformers.TrainingArguments`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + label_pad_token_id (`int`, defaults to `-100`): + The label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int`, defaults to `0`): + The padding value if it is different to the tokenizer's pad_token_id. + truncation_mode (`str`, defaults to `keep_end`): + The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + max_length (`int`, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + max_prompt_length (`int`, defaults to `None`): + The maximum length of the prompt. This argument is required if you want to use the default data collator. + max_target_length (`int`, defaults to `None`): + The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): + If no model is provided, we need to know if the model_init returns an encoder-decoder. + disable_dropout (`bool`, defaults to `True`): + Whether or not to disable dropouts in `model` and `ref_model`. + generate_during_eval (`bool`, defaults to `False`): + Whether to sample and log generations during evaluation step. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + precompute_ref_log_probs (`bool`, defaults to `False`): + Flag to precompute reference model log probabilities and evaluation datasets. This is useful if you want to train + without the reference model and reduce the total GPU memory needed. + dataset_num_proc (`Optional[int]`, *optional*): + The number of workers to use to tokenize the data. Defaults to None. + model_init_kwargs (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + ref_model_init_kwargs (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the ref model from a string + model_adapter_name (`str`, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str`, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + reference_free (`bool`): + If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. + """ + + _tag_names = ["trl", "dpo"] + + def __init__( + self, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + dpo_alpha: float = 1.0, + beta: float = 0.1, + gamma: float = 0.1, + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid", + args: Optional[TrainingArguments] = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: Optional[int] = None, + truncation_mode: str = "keep_end", + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_target_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + is_encoder_decoder: Optional[bool] = None, + disable_dropout: bool = True, + generate_during_eval: bool = False, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + precompute_ref_log_probs: bool = False, + dataset_num_proc: Optional[int] = None, + model_init_kwargs: Optional[Dict] = None, + ref_model_init_kwargs: Optional[Dict] = None, + model_adapter_name: Optional[str] = None, + ref_adapter_name: Optional[str] = None, + reference_free: bool = False, + ): + # import pdb;pdb.set_trace() + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") + + if ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError("You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated.") + + if isinstance(model, str): + warnings.warn("You passed a model_id to the DPOTrainer. This will automatically create an " "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you.") + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + warnings.warn("You passed a ref model_id to the DPOTrainer. This will automatically create an " "`AutoModelForCausalLM`") + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if generate_during_eval and not is_wandb_available(): + raise ValueError("`generate_during_eval=True` requires Weights and Biases to be installed." " Please install `wandb` to resolve.") + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = model_adapter_name + self.ref_adapter_name = ref_adapter_name + self.reference_free = reference_free + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + if is_deepspeed_zero3_enabled(): + self.ref_model = AutoModelForCausalLM.from_pretrained(model) + else: + self.ref_model = create_reference_model(model) + + if tokenizer is None: + raise ValueError("tokenizer must be specified to tokenize a DPO dataset.") + if max_length is None: + warnings.warn( + "`max_length` is not set in the DPOTrainer's init" " it will default to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "`max_prompt_length` is not set in the DPOTrainer's init" " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_target_length is None and self.is_encoder_decoder: + warnings.warn( + "When using an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + if disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = generate_during_eval + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = truncation_mode + self.max_target_length = max_target_length + self.tokenizer = tokenizer + self.precompute_ref_log_probs = precompute_ref_log_probs + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0: + warnings.warn("You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter.") + + self.dpo_alpha = dpo_alpha + self.beta = beta + self.gamma = gamma + self.label_smoothing = label_smoothing + self.loss_type = loss_type + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + self.dataset_num_proc = dataset_num_proc + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + # with PartialState().local_main_process_first(): + # # tokenize the dataset + # train_dataset = train_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc) + # if eval_dataset is not None: + # eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if not hasattr(self, "accelerator"): + raise AttributeError("Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`.") + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError("You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`.") + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError("No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`") + else: + if self.is_deepspeed_enabled: + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + + if model is not None: + if hasattr(model, "config"): + hidden_size = max(model.config.hidden_sizes) if getattr(model.config, "hidden_sizes", None) else getattr(model.config, "hidden_size", None) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + + reference_chosen_logps = [] + reference_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch) + reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics((reference_chosen_logp, reference_rejected_logp)) + reference_chosen_logps.append(reference_chosen_logp.cpu()) + reference_rejected_logps.append(reference_rejected_logp.cpu()) + + all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy() + all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy() + + self.train_dataset = self.train_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps) + self.train_dataset = self.train_dataset.add_column(name="reference_rejected_logps", column=all_reference_rejected_logps) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_eval_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + + reference_chosen_logps = [] + reference_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + reference_chosen_logp, reference_rejected_logp = self.compute_reference_log_probs(padded_batch) + reference_chosen_logp, reference_rejected_logp = self.accelerator.gather_for_metrics((reference_chosen_logp, reference_rejected_logp)) + reference_chosen_logps.append(reference_chosen_logp.cpu()) + reference_rejected_logps.append(reference_rejected_logp.cpu()) + + all_reference_chosen_logps = torch.cat(reference_chosen_logps).float().numpy() + all_reference_rejected_logps = torch.cat(reference_rejected_logps).float().numpy() + + eval_dataset = eval_dataset.add_column(name="reference_chosen_logps", column=all_reference_chosen_logps) + eval_dataset = eval_dataset.add_column(name="reference_rejected_logps", column=all_reference_rejected_logps) + + # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return super().get_eval_dataloader(eval_dataset=eval_dataset) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. + It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. + Reference: + https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + """ + + full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + # Prepare input tokens for token by token comparison + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = len(prompt_input_ids) + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict: + """Tokenize a single row from a DPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum([a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError("Chosen and rejected prompt_input_ids might only differ on the " "last token due to tokenizer merge ops.") + + # add BOS token to head of prompt + prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"] + chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"] + rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"] + + prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] + chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] + rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] + + # add EOS token to end of answer + chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] + + # Create labels + chosen_sequence_tokens = {k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]} + rejected_sequence_tokens = {k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]} + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [self.label_pad_token_id] * len(chosen_tokens["prompt_input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [self.label_pad_token_id] * len(rejected_tokens["prompt_input_ids"]) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.tokenizer(chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True) + rejected_tokens = self.tokenizer(rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True) + prompt_tokens = self.tokenizer(prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(labels=batch["rejected_labels"]) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(labels=batch["chosen_labels"]) + + return batch + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with self.accelerator.unwrap_model(self.model).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext(): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + + def compute_reference_log_probs(self, padded_batch: Dict) -> Dict: + """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" + compte_ref_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + # compute reference logps + with torch.no_grad(), compte_ref_context_manager(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.model, padded_batch) + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.ref_model, padded_batch) + + return reference_chosen_logps, reference_rejected_logps + + @staticmethod + def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + ) -> Dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + is_encoder_decoder: Whether the model is an encoder-decoder model. + label_pad_token_id: The label pad token id. + padding_value: The padding value to use for the concatenated inputs_ids. + device: The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + # import pdb; pdb.set_trace() + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + # import pdb; pdb.set_trace() + # repeated_list = [ + # batch['images'][0] * 2, + # batch['images'][1] * 2 + # ] + concatenated_batch["concatenated_images"] = batch["images"] * 2 + concatenated_batch["image_sizes"] = batch["image_sizes"] * 2 + concatenated_batch["modalities"] = batch["modalities"] * 2 + return concatenated_batch + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the DPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the DPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + if self.reference_free: + ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device) + else: + ref_logratios = reference_chosen_logps - reference_rejected_logps + + pi_logratios = pi_logratios.to(self.accelerator.device) + ref_logratios = ref_logratios.to(self.accelerator.device) + logits = pi_logratios - ref_logratios + # print(f"pi log ratios: {pi_logratios}") + # print(f"ref log ratios: {ref_logratios}") + # print(f"logits: {logits}") + # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and + # calculates a conservative DPO loss. + if self.loss_type == "sigmoid": + losses = -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) - F.logsigmoid(-self.beta * logits) * self.label_smoothing + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + elif self.loss_type == "ipo": + # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. + losses = (logits - 1 / (2 * self.beta)) ** 2 + elif self.loss_type == "kto_pair": + # eqn (7) of the HALOs paper + chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0) + rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0) + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half. + losses = torch.cat( + ( + 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)), + 1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)), + ), + 0, + ) + else: + raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']") + + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device) - reference_rejected_logps.to(self.accelerator.device)).detach() + + return losses, chosen_rewards, rejected_rewards + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + label_pad_token_id: The label pad token id. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def get_sft_loss(self, logits, labels): + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, shift_logits.size(-1)) + shift_labels = shift_labels.view(-1) + # Enable model/pipeline parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + return loss + + def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + # import pdb; pdb.set_trace() + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + # import pdb; pdb.set_trace() + all_logits, new_labels = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + labels=concatenated_batch["concatenated_labels"], + images=concatenated_batch["concatenated_images"], + image_sizes=concatenated_batch["image_sizes"], + modalities=concatenated_batch["modalities"], + use_cache=False, + dpo_forward=True, + ) + all_logits = all_logits.to(torch.float32) + all_logps = self.get_batch_logps( + all_logits, + new_labels, + average_log_prob=self.loss_type == "ipo", + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + # don't count image embeds logits + # loss_mask = new_labels != -100 + # logits = [all_logits[i][loss_mask[i]] for i in range(loss_mask.shape[0])] + # chosen_logits = logits[:len_chosen] + # rejected_logits = logits[len_chosen:] + # chosen_logits = [l.detach().cpu().mean() for l in chosen_logits] + # rejected_logits = [l.detach().cpu().mean() for l in rejected_logits] + # chosen_logits = sum(chosen_logits)/len_chosen + # rejected_logits = sum(rejected_logits)/len_chosen + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + chosen_labels = new_labels[:len_chosen] + rejected_labels = new_labels[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_labels, rejected_labels) + + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test. + CHANGE: 1. add sft loss + 2. all gather metrics + """ + metrics = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + chosen_labels, + rejected_labels, + ) = self.concatenated_forward(model, batch) + + # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model + if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch: + reference_chosen_logps = batch["reference_chosen_logps"] + reference_rejected_logps = batch["reference_rejected_logps"] + else: + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + ) = self.concatenated_forward( + self.model, batch + )[:2] + else: + ( + reference_chosen_logps, + reference_rejected_logps, + ) = self.concatenated_forward( + self.ref_model, batch + )[:2] + + unscaled_dpo_losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + unscaled_dpo_losses = unscaled_dpo_losses.mean() + dpo_losses = unscaled_dpo_losses * self.dpo_alpha + unscaled_sft_loss = self.get_sft_loss(policy_chosen_logits, chosen_labels) + sft_loss = unscaled_sft_loss * self.gamma + + # print(sft_loss.shape, dpo_losses.shape) + losses = dpo_losses + sft_loss + # losses = sft_loss # sft only + # losses = dpo_losses # dpo only + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + def all_gather_tensor(tensor): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + tensor = tensor.detach() + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(gathered_tensor, tensor) + tensor = torch.cat(gathered_tensor, dim=0) + # else: + # print('not distributed') + return tensor + + # gather chosen_rewards across devices + chosen_rewards = all_gather_tensor(chosen_rewards) + rejected_rewards = all_gather_tensor(rejected_rewards) + reward_accuracies = all_gather_tensor(reward_accuracies) + policy_chosen_logps = all_gather_tensor(policy_chosen_logps) + policy_rejected_logps = all_gather_tensor(policy_rejected_logps) + reference_chosen_logps = all_gather_tensor(reference_chosen_logps) + reference_rejected_logps = all_gather_tensor(reference_rejected_logps) + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}losses/dpo"] = unscaled_dpo_losses.cpu() + metrics[f"{prefix}losses/sft"] = unscaled_sft_loss.cpu() + metrics[f"{prefix}losses/total"] = losses.cpu() + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu() + # policy logps + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() + # policy logits (exclude image tokens) + # metrics[f"{prefix}logits/rejected"] =policy_rejected_logits + # metrics[f"{prefix}logits/chosen"] = policy_chosen_logits + # reference logps + metrics[f"{prefix}ref_logps/rejected"] = reference_rejected_logps.mean().cpu() + metrics[f"{prefix}ref_logps/chosen"] = reference_chosen_logps.mean().cpu() + + # metrics all pick .4 digits + # for k in metrics: + # metrics[k] = round(metrics[k].item(), 4) + + return losses, metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + if not self.use_dpo_data_collator: + warnings.warn( + "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + + compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with compute_loss_context_manager(): + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explictly call generate with + # the torch cuda amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast + + with generate_context_manager(): + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + # if reference_output in batch use that otherwise use the reference model + if "reference_output" in batch: + reference_output = batch["reference_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id) + policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) + + reference_output = pad_to_length(reference_output, self.max_length, self.tokenizer.pad_token_id) + reference_output_decoded = self.tokenizer.batch_decode(reference_output, skip_special_tokens=True) + + return policy_output_decoded, reference_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ): + if not self.use_dpo_data_collator: + warnings.warn( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with torch.no_grad(), prediction_context_manager(): + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys) + logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded, ref_output_decoded = self.get_batch_samples(self.model, random_batch) + + self.log( + { + "game_log": wandb.Table( + columns=["Prompt", "Policy", "Ref Model"], + rows=[[prompt, pol[len(prompt) :], ref[len(prompt) :]] for prompt, pol, ref in zip(random_batch["prompt"], policy_output_decoded, ref_output_decoded)], + ) + } + ) + self.state.log_history.pop() + + # Base evaluation + initial_output = super().evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix) + + return initial_output + + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs) + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py new file mode 100755 index 000000000..1fd55442c --- /dev/null +++ b/trl/trainer/iterative_sft_trainer.py @@ -0,0 +1,334 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + DataCollator, + DataCollatorForLanguageModeling, + DataCollatorForSeq2Seq, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.trainer_utils import EvalLoopOutput + +from ..core import PPODecorators +from ..import_utils import is_peft_available + + +if is_peft_available(): + from peft import PeftModel + + +class IterativeSFTTrainer(Trainer): + """ + The IterativeSFTTrainer can be used to finetune models with methods that requires some steps between optimization. + + Attributes: + **model** (`PreTrainedModel`) -- Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'. + Check the documentation of `PreTrainedModel` for more details. + **args** (`transformers.TrainingArguments`): -- The arguments to use for training. + **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the + data. Check the documentation of `transformers.PreTrainedTokenizer` and + `transformers.PreTrainedTokenizerFast` for more details. + **optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): -- The optimizer and scheduler to use for training. + **data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and + passed along the dataloader. + **eval_dataset** (`datasets.Dataset`): The dataset to use for evaluation. + **max_length** (`int`, defaults to `None`): -- The maximum length of the input. + **truncation_mode** (`str`, defaults to `keep_end`): -- The truncation mode to use, either `keep_end` or `keep_start`. + **preprocess_logits_for_metrics** (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): -- The function to use to preprocess the logits before computing the metrics. + **compute_metrics** (`Callable[[EvalPrediction], Dict]`, *optional*): -- The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. + **optimize_device_cache ** (`bool`, *optional*, defaults to `False`) -- Optimize CUDA cache for slightly more memory-efficient training. + """ + + def __init__( + self, + model: PreTrainedModel = None, + args: TrainingArguments = None, + tokenizer: PreTrainedTokenizerBase = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + data_collator: Optional[DataCollator] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + max_length: Optional[int] = None, + truncation_mode: Optional[str] = "keep_end", + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + optimize_device_cache: Optional[bool] = False, + ): + # Step 0: check positional arguments validity + if not isinstance(tokenizer, (PreTrainedTokenizerBase)): + raise ValueError(f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}") + if not isinstance(model, PreTrainedModel): + raise ValueError(f"model must be a PreTrainedModel, got {type(model)}") + if not model.can_generate(): + warnings.warn(f"The current model class {type(model)} is not compatible with `.generate()`" "Please make sure that this is intended.") + if optimizers[1] is None and args.max_steps == -1: + raise ValueError("When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`") + + self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False) + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + self.tokenizer = tokenizer + + if data_collator is None: + if self.is_encoder_decoder: + warnings.warn("No data collator is provided. Using 'DataCollatorForSeq2Seq' with" "'labels_pad_token_id' set to '-100' and 'pad_to_multiple_of' set to 8.") + self.data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100, pad_to_multiple_of=8) + else: + warnings.warn("No data collator is provided. Using 'DataCollatorForLanguageModeling'") + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + else: + self.data_collator = data_collator + + self.max_length = max_length + self.truncation_mode = truncation_mode + self.optimize_device_cache = optimize_device_cache + + super().__init__( + model=model, + args=args, + data_collator=self.data_collator, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self.create_optimizer_and_scheduler(self.args.max_steps) + + # prepare model, optimizer and lr_scheduler + self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(self.model, self.optimizer, self.lr_scheduler) + + self.tokenizer.truncation_side = "left" if self.truncation_mode == "keep_end" else "right" + + if not hasattr(self, "accelerator"): + raise AttributeError("Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`.") + + PPODecorators.optimize_device_cache = self.optimize_device_cache + + def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor): + if attention_mask is None: + attention_mask = [torch.ones_like(ids) for ids in input_ids] + + if self.is_encoder_decoder: + input_data = self.data_collator([{"input_ids": ids, "attention_mask": att, "labels": lab} for ids, att, lab in zip(input_ids, attention_mask, labels)]).to(self.model.device) + + input_data.pop("decoder_input_ids", None) # This is directly computed inside the model + + input_data["labels"][input_data["labels"] == self.tokenizer.pad_token_id] = -100 + + else: + input_data = self.data_collator([{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]).to(self.model.device) + + # truncate in case the user has provided input_ids, attention_mask and labels + if self.max_length is not None: + if self.truncation_mode == "keep_start": + input_data = {k: v[: self.max_length] for k, v in input_data.items()} + elif self.truncation_mode == "keep_end": + input_data = {k: v[-self.max_length :] for k, v in input_data.items()} + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + return input_data + + @staticmethod + def _step_safety_checker( + input_ids: List[torch.LongTensor], + attention_mask: List[torch.LongTensor], + labels: List[torch.LongTensor], + texts: List[str], + texts_labels: List[str], + ): + """ + Check if the input data is valid for training. + + Args: + input_ids (List[`torch.LongTensor`]): + List of tensors containing the input_ids + attention_mask (List[`torch.LongTensor`]): + List of tensors containing the attention_mask + labels (List[`torch.FloatTensor`]): + List of tensors containing the labels + texts (List[`str`]): + List of string containing the text input. + texts_labels (List[`str`]): + List of string containing the text labels. + Returns: + `tuple`: The input data. + """ + if texts is None: + if attention_mask is None: + for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + else: + for name, tensor_list in zip(["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + else: + if not isinstance(texts, list): + raise ValueError(f"'text' must be a list of strings - got {type(texts)}") + if not isinstance(texts[0], str): + raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}") + if texts_labels is not None: + if not isinstance(texts_labels, list): + raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}") + if not isinstance(texts_labels[0], str): + raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}") + + return input_ids, attention_mask, labels, texts, texts_labels + + @PPODecorators.empty_device_cache() + def step( + self, + input_ids: Optional[List[torch.LongTensor]] = None, + attention_mask: Optional[List[torch.LongTensor]] = None, + labels: Optional[List[torch.LongTensor]] = None, + texts: Optional[List[str]] = None, + texts_labels: Optional[List[str]] = None, + ): + """ + Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and text_labels. + Args: + input_ids (List[`torch.LongTensor`]): + List of tensors containing the input_ids (if not provided, text will be used) + attention_mask (List[`torch.LongTensor`], , *optional*): + List of tensors containing the attention_mask + labels (List[`torch.FloatTensor`], *optional*): + List of tensors containing the labels (if set to None, will default to input_ids) + texts (List[`str`], *optional*): + List of strings containing the text input (if not provided, input_ids will directly be used) + texts_labels (List[`str`], *optional*): + List of strings containing the text labels (if set to None, will default to text) + Returns: + `dict[str, Any]`: A summary of the training statistics + """ + self.model.train() + + if self.state.global_step == 0: + self.tr_loss = torch.tensor(0.0).to(self.args.device) + self._globalstep_last_logged = self.state.global_step + + if input_ids is None and texts is None: + raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.") + elif input_ids is not None and texts is not None: + warnings.warn("Both 'input_ids' and 'texts' are provided. 'input_ids' will be overwritten using inputs provided by the 'texts' keyword argument.") + + if labels is None and texts_labels is None and self.is_encoder_decoder: + raise ValueError("No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed.") + + input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(input_ids, attention_mask, labels, texts, texts_labels) + + if texts is not None: + model_inputs = self.tokenizer(texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt") + + input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"] + + if texts_labels is not None: + labels = self.tokenizer(texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt")["input_ids"] + + if labels is None: + warnings.warn("No labels are provided. Setting labels to input_ids") + labels = input_ids + + model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels) + + model_inputs_names = list(model_inputs.keys()) + + batch_dict = {} + batch_dict.update(model_inputs) + + def collator(data): + return_dict = dict() + for key in data[0]: + if key in ["input_ids", "attention_mask", "labels"]: + return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device) + return return_dict + + batch_data = Dataset.from_dict(batch_dict) + batch_data.set_format("torch") + + step_dataloader = DataLoader( + batch_data, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + collate_fn=collator, + ) + + for _, batch in enumerate(step_dataloader): + with self.accelerator.accumulate(self.model): + model_inputs = {k: batch[k] for k in model_inputs_names} + loss = self.compute_loss(self.model, model_inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() + + tr_loss_step = loss.detach() + + self.accelerator.backward(loss) + + if self.accelerator.sync_gradients and self.args.max_grad_norm is not None: + self.accelerator.clip_grad_norm_( + self.model.parameters(), + self.args.max_grad_norm, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + self.state.global_step += 1 + + # update stats etc + self.tr_loss += tr_loss_step + + self._maybe_log_save_evaluate() + + def _maybe_log_save_evaluate(self): + # check if eval is required + if self.args.eval_steps is not None: + if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0: + self.evaluate(self.eval_dataset) + + # check if logging is required + if self.args.logging_steps is not None: + if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0: + logs: Dict[str, float] = {} + + tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item() + + # reset tr_loss to zero + self.tr_loss -= self.tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._globalstep_last_logged = self.state.global_step + + self.log(logs) diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py new file mode 100755 index 000000000..e6df85921 --- /dev/null +++ b/trl/trainer/model_config.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +from ..core import flatten_dict + + +@dataclass +class ModelConfig: + """ + Arguments which define the model and tokenizer to load. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={"help": ("The model checkpoint for weights initialization.")}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + torch_dtype: Optional[str] = field( + default=None, + metadata={ + "help": ("Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " "dtype will be automatically derived from the model's weights."), + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) + attn_implementation: Optional[str] = field( + default=None, + metadata={"help": ("Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`")}, + ) + use_peft: bool = field( + default=False, + metadata={"help": ("Whether to use PEFT or not for training.")}, + ) + lora_r: Optional[int] = field( + default=16, + metadata={"help": ("LoRA R value.")}, + ) + lora_alpha: Optional[int] = field( + default=32, + metadata={"help": ("LoRA alpha.")}, + ) + lora_dropout: Optional[float] = field( + default=0.05, + metadata={"help": ("LoRA dropout.")}, + ) + lora_target_modules: Optional[List[str]] = field( + default=None, + metadata={"help": ("LoRA target modules.")}, + ) + lora_modules_to_save: Optional[List[str]] = field( + default=None, + metadata={"help": ("Model layers to unfreeze & train")}, + ) + load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}) + load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}) + + bnb_4bit_quant_type: Optional[str] = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}) + use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) + + def __post_init__(self): + if self.load_in_8bit and self.load_in_4bit: + raise ValueError("You can't use 8 bit and 4 bit precision at the same time") diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py new file mode 100755 index 000000000..8642eab51 --- /dev/null +++ b/trl/trainer/ppo_config.py @@ -0,0 +1,175 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Literal, Optional + +import numpy as np +import tyro +from typing_extensions import Annotated + +from trl.trainer.utils import exact_div + +from ..core import flatten_dict +from ..import_utils import is_wandb_available + + +JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)] + + +@dataclass +class PPOConfig: + """ + Configuration class for PPOTrainer + """ + + # common parameters + exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] + """the name of this experiment (by default is the file name without the extension name)""" + seed: int = 0 + """Seed value for random generations""" + log_with: Optional[Literal["wandb", "tensorboard"]] = None + """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details""" + task_name: Optional[str] = None + """Name of task to use - used only for tracking purposes""" + model_name: Optional[str] = "gpt2" + """Name of model to use - used only for tracking purposes""" + query_dataset: Optional[str] = "imdb" + """Name of dataset to query - used only for tracking purposes""" + reward_model: Optional[str] = "sentiment-analysis:lvwerra/distilbert-imdb" + """The reward model to use - used only for tracking purposes""" + remove_unused_columns: bool = True + """Remove unused columns from the dataset if `datasets.Dataset` is used""" + tracker_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'""" + accelerator_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the accelerator""" + project_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" + tracker_project_name: str = "trl" + """Name of project to use for tracking""" + push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for pushing model to the hub during training (e.g. repo_id)""" + + # hyperparameters + steps: int = 20000 + """Number of training steps""" + learning_rate: float = 1.41e-5 + """Adam learning rate""" + adap_kl_ctrl: bool = True + """Use adaptive KL control, otherwise linear""" + init_kl_coef: Optional[float] = 0.2 + """Initial KL penalty coefficient (used for adaptive and linear control)""" + kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl" + """kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full': the actual kl for all tokens in the distribution""" + target: Optional[float] = 6 + """Target KL value for adaptive KL control""" + horizon: Optional[float] = 10000 + """Horizon for adaptive KL control""" + gamma: float = 1 + """Gamma parameter for advantage calculation""" + lam: float = 0.95 + """Lambda parameter for advantage calculation""" + cliprange: float = 0.2 + """Range for clipping in PPO policy gradient loss""" + cliprange_value: float = 0.2 + """Range for clipping values in loss calculation""" + vf_coef: float = 0.1 + """Scaling factor for value loss""" + batch_size: int = 128 + """Number of samples per optimisation step""" + forward_batch_size: Optional[int] = None + """DEPRECATED: use `mini_batch_size` instead, which does the same thing.""" + mini_batch_size: int = 128 + """Number of samples optimized in each mini batch""" + gradient_accumulation_steps: int = 1 + """The number of gradient accumulation steps""" + world_size: tyro.conf.Suppress[int] = None + """The world size for distributed training""" + ppo_epochs: int = 4 + """Number of optimisation epochs per batch of samples""" + max_grad_norm: Optional[float] = None + """Maximum gradient norm for gradient clipping""" + optimize_cuda_cache: Optional[bool] = None + """DEPRECATED: use `optimize_device_cache` instead, which does the same thing.""" + optimize_device_cache: Optional[bool] = False + """Optimize device cache for slightly more memory-efficient training""" + early_stopping: bool = False + """Whether to stop the PPO optimization loop early is the KL too high""" + target_kl: float = 1 + """Stop early if we exceed this value by over 50%""" + compare_steps: int = 1 + """Number of steps between comparison of the current reward with the best seen so far""" + ratio_threshold: float = 10.0 + """Skip mini-batches with high PPO ratios that can cause loss spikes""" + use_score_scaling: bool = False + """Use score scaling""" + use_score_norm: bool = False + """Use score normalization. Only applicable if use_score_scaling is True""" + score_clip: Optional[float] = None + """Score clipping""" + whiten_rewards: bool = False + """Whiten the rewards before compute advantages""" + + # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text + is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None + """TO BE FILLED In RUNTIME: Whether the model is an encoder-decoder model""" + is_peft_model: Optional[tyro.conf.Suppress[bool]] = None + """TO BE FILLED In RUNTIME: Whether the model is a PEFT model""" + backward_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: Number of samples optimized in an `optimizer.step()` call""" + global_backward_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: the effective `backward_batch_size` across all processes""" + global_batch_size: tyro.conf.Suppress[int] = None + """TO BE FILLED In RUNTIME: the effective `batch_size` across all processes""" + + if optimize_cuda_cache is not None: + warnings.warn("The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead.") + optimize_device_cache = optimize_cuda_cache + else: + optimize_device_cache = False + + def __post_init__(self): + if self.forward_batch_size is not None: + warnings.warn( + "Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization." + ) + self.mini_batch_size = self.forward_batch_size + + self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps + exact_div( + self.batch_size, + self.backward_batch_size, + "`batch_size`", + "`mini_batch_size * gradient_accumulation_steps`", + "`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`", + ) + + # check if wandb is installed + if self.log_with == "wandb": + # raise error if wandb is not installed + if not is_wandb_available(): + raise ImportError("Please install wandb to use wandb logging. You can do this by running `pip install wandb`.") + + self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size)) + assert self.kl_penalty in ["kl", "abs", "mse", "full"] + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py new file mode 100755 index 000000000..ce7e09397 --- /dev/null +++ b/trl/trainer/ppo_trainer.py @@ -0,0 +1,1397 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +import os +import time +import typing +import warnings +from contextlib import nullcontext +from typing import Callable, List, Optional, Union + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available +from datasets import Dataset +from huggingface_hub import whoami +from packaging import version +from torch.optim import Adam +from transformers import ( + DataCollatorForLanguageModeling, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +from ..core import ( + WANDB_PADDING, + PPODecorators, + clip_by_value, + convert_to_scalar, + entropy_from_logits, + flatten_dict, + logprobs_from_logits, + masked_mean, + masked_var, + masked_whiten, + set_seed, + stack_dicts, + stats_to_np, +) +from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available +from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model +from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments + + +if is_deepspeed_available(): + import deepspeed + +MODEL_CARD_TEMPLATE = """--- +license: apache-2.0 +tags: +- trl +- ppo +- transformers +- reinforcement-learning +--- + +# {model_name} + +This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to + guide the model outputs according to a value, function, or human feedback. The model can be used for text generation. + +## Usage + +To use this model for inference, first install the TRL library: + +```bash +python -m pip install trl +``` + +You can then generate text as follows: + +```python +from transformers import pipeline + +generator = pipeline("text-generation", model="{model_id}") +outputs = generator("Hello, my llama is cute") +``` + +If you want to use the model for training or to obtain the outputs from the value head, load the model as follows: + +```python +from transformers import AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead + +tokenizer = AutoTokenizer.from_pretrained("{model_id}") +model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}") + +inputs = tokenizer("Hello, my llama is cute", return_tensors="pt") +outputs = model(**inputs, labels=inputs["input_ids"]) +``` +""" + + +class PPOTrainer(BaseTrainer): + """ + The PPOTrainer uses Proximal Policy Optimization to optimise language models. + Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here: + https://github.com/openai/summarize-from-feedback + + Attributes: + **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more + details. + **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head. + Check the documentation of `PreTrainedModelWrapper` for more details. + **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face + transformer model with a casual language modelling head. Check the documentation of `PreTrainedModelWrapper` + for more details. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized with shared layers. + **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the + data. Check the documentation of `transformers.PreTrainedTokenizer` and + `transformers.PreTrainedTokenizerFast` for more details. + **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging + Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be + created outside the trainer users needs to design their own dataloader and make sure the batch + size that is used is the same as the one specified in the configuration object. + **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is + provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration + object. + **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and + passed along the dataloader + **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference + model, if no reference model is passed. If no number is provided, all the layers will be shared. + **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training. + """ + + _tag_names = ["trl", "ppo"] + + def __init__( + self, + config: PPOConfig = None, + model: PreTrainedModelWrapper = None, + ref_model: Optional[PreTrainedModelWrapper] = None, + tokenizer: PreTrainedTokenizerBase = None, + dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + data_collator: Optional[typing.Callable] = None, + num_shared_layers: Optional[int] = None, + lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + ): + """ + Initialize PPOTrainer. + + Args: + config (`PPOConfig`): + Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details. + model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a value head. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for KL penalty + tokenizer (`transformers.PreTrainedTokenizerBase`): + Hugging Face tokenizer + dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. If none is passed, + a warning will be raised in a multi-GPU setting. + optimizer (Optional[`torch.optim.Optimizer`]): + Optimizer used for training. If `None`, the `Adam` is used as default. + data_collator (Optional[function]): + Data collator function. + num_shared_layers (Optional[int]): + Number of shared layers between the model and the reference model. If `None`, all layers are shared. + used only if `ref_model` is `None`. + lr_scheduler (Optional[`torch.optim.lr_scheduler`]): + Learning rate scheduler used for training. + """ + super().__init__(config) + + # initial seed for reproducible experiments + set_seed(config.seed) + + # Step 0: check positional arguments validity + if not isinstance(config, PPOConfig): + raise ValueError(f"config must be a PPOConfig, got {type(config)}") + if not isinstance(tokenizer, (PreTrainedTokenizerBase)): + raise ValueError(f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}") + if not isinstance(model, (SUPPORTED_ARCHITECTURES)): + raise ValueError(f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}") + # Step 1: Initialize Accelerator + self.accelerator = Accelerator( + log_with=config.log_with, + gradient_accumulation_steps=config.gradient_accumulation_steps, + project_config=ProjectConfiguration(**config.project_kwargs), + **config.accelerator_kwargs, + ) + + # Step 1.1 Runtime variables filled by the accelerator + config.world_size = self.accelerator.num_processes + config.global_backward_batch_size = config.backward_batch_size * config.world_size + config.global_batch_size = config.batch_size * config.world_size + + self.model = model + self.model_params = filter(lambda p: p.requires_grad, self.model.parameters()) + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.is_peft_model = getattr(self.model, "is_peft_model", False) + config.is_encoder_decoder = self.is_encoder_decoder + config.is_peft_model = self.is_peft_model + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + self.accelerator.init_trackers( + config.tracker_project_name, + config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + init_kwargs=config.tracker_kwargs, + ) + self.is_using_text_environment = getattr(config, "use_text_environment", False) + + if isinstance(ref_model, SUPPORTED_ARCHITECTURES): + self.ref_model = ref_model + if num_shared_layers is not None: + warnings.warn( + "num_shared_layers is ignored when ref_model is provided. Two different models are used for the " "model and the reference model and no layers are shared.", + UserWarning, + ) + elif ref_model is None and not self.is_peft_model: + self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers) + elif self.is_peft_model: + self.ref_model = None + else: + raise ValueError(f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported " f"architectures are: {SUPPORTED_ARCHITECTURES} ") + self.optional_peft_ctx = self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter if self.is_peft_model else nullcontext + + if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)): + raise ValueError("tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast") + self.tokenizer = tokenizer + + if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)): + raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset") + elif dataset is None: + warnings.warn( + "No dataset is provided. Make sure to set config.batch_size to the correct value before training.", + UserWarning, + ) + self.dataset = dataset + self._signature_columns = None + if self.dataset is not None: + self.dataloader = self.prepare_dataloader(self.dataset, data_collator) + elif self.dataset is None and self.accelerator.num_processes > 1: + warnings.warn( + "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should" + " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`" + " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please " + " refer to the documentation for more details.", + UserWarning, + ) + self.dataloader = None + else: + self.dataloader = None + + # Step 3: Initialize optimizer and data collator + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + if optimizer is None: + self.optimizer = Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.config.learning_rate, + ) + else: + self.optimizer = optimizer + + self.lr_scheduler = lr_scheduler + if self.lr_scheduler is not None: + lr_scheduler_class = torch.optim.lr_scheduler._LRScheduler if not is_torch_greater_2_0() else torch.optim.lr_scheduler.LRScheduler + + if not isinstance(self.lr_scheduler, lr_scheduler_class): + raise ValueError("lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)") + + if self.config.adap_kl_ctrl: + self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon) + else: + self.kl_ctl = FixedKLController(self.config.init_kl_coef) + + # Safety checkers for DS integration + is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(self.accelerator.state, "deepspeed_plugin") + + ( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) = self.accelerator.prepare( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) + if is_deepspeed_used: + # Quantized models are already set on the correct device + if not self.is_peft_model and not (getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False) or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)): + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare(self.ref_model) + + # In a distributed setup, only logging needs to be performed on the main process + # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html + # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11 + self.is_distributed = self.accelerator.num_processes > 1 + + # init the current step + self.current_step = 0 + + # init variables for pushing model to hub + if config.push_to_hub_if_best_kwargs: + if "repo_id" not in config.push_to_hub_if_best_kwargs: + raise ValueError("You have to specify repo_id in order to push the model to the hub!") + self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs + self.compare_step = 0 + self.highest_reward = torch.tensor(-float("inf")) + + # post process for PP + if not getattr(self.model, "is_sequential_parallel", False): + self.current_device = self.accelerator.device + else: + if is_xpu_available(): + self.current_device = torch.device("xpu:0") + elif is_npu_available(): + self.current_device = torch.device("npu:0") + else: + self.current_device = torch.device("cuda:0") + + PPODecorators.optimize_device_cache = self.config.optimize_device_cache + + self.running = RunningMoments(self.accelerator) + + def _filter_kwargs(self, kwargs, target_func): + """ + filter the keyword arguments that are supported by the target function. + + Args: + kwargs (dict): + Keyword arguments + target_func (function): + Target function + """ + return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()} + + def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None): + """ + Prepare the dataloader for training. + + Args: + dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. + data_collator (Optional[function]): + Data collator function. + + Returns: + `torch.utils.data.DataLoader`: PyTorch dataloader + """ + if isinstance(dataset, Dataset): + dataset = self._remove_unused_columns(dataset) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.batch_size, + collate_fn=data_collator, + shuffle=True, + drop_last=True, + ) + return dataloader + + # Adapted from transformers.Trainer._set_signature_columns_if_needed + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + signature = inspect.signature(self.model.forward) + self._signature_columns = list(signature.parameters.keys()) + # label => sentiment | we need query and response for logging purpose + self._signature_columns += ["label", "query", "response"] + + # Adapted from transformers.Trainer._remove_unused_columns + def _remove_unused_columns(self, dataset: "Dataset"): + if not self.config.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + + columns = [k for k in signature_columns if k in dataset.column_names] + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], + columns=columns, + format_kwargs=dataset.format["format_kwargs"], + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def generate( + self, + query_tensor: Union[torch.Tensor, List[torch.Tensor]], + length_sampler: Callable = None, + batch_size: int = 4, + return_prompt: bool = True, + generate_ref_response: bool = False, + **generation_kwargs, + ): + """ + Generate response with the model given the query tensor. + call the `generate` method of the model. + + Args: + query_tensor (`torch.LongTensor`): + A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`). + length_sampler (`Callable`, *optional*): + Callable that returns the number of newly generated tokens. + batch_size (`int`, *optional): + Batch size used for generation, defaults to `4`. + return_prompt (`bool`, *optional*): + If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`. + generate_ref_response (`bool`, *optional*): + If set to `True` the reference response is also generated, defaults to `False`. + generation_kwargs (dict[str, Any]): + Keyword arguments for generation. + + Returns: + `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens. + """ + if generate_ref_response: + ref_model = self.model if self.is_peft_model else self.ref_model + if isinstance(query_tensor, List): + response = self._generate_batched( + self.model, + query_tensor, + length_sampler=length_sampler, + batch_size=batch_size, + return_prompt=return_prompt, + **generation_kwargs, + ) + if generate_ref_response: + with self.optional_peft_ctx(): + ref_response = self._generate_batched( + ref_model, + query_tensor, + length_sampler=length_sampler, + batch_size=batch_size, + return_prompt=return_prompt, + **generation_kwargs, + ) + + else: + if len(query_tensor.shape) == 2: + raise ValueError("query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)") + + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + response = self.accelerator.unwrap_model(self.model).generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs) + if generate_ref_response: + with self.optional_peft_ctx(): + ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs) + + if not return_prompt and not self.is_encoder_decoder: + response = response[:, query_tensor.shape[0] :] + if generate_ref_response: + ref_response = ref_response[:, query_tensor.shape[0] :] + + if generate_ref_response: + return response, ref_response + return response + + def _generate_batched( + self, + model: PreTrainedModelWrapper, + query_tensors: List[torch.Tensor], + length_sampler: Callable = None, + batch_size: int = 4, + return_prompt: bool = True, + pad_to_multiple_of: int = None, + remove_padding: bool = True, + **generation_kwargs, + ): + outputs = [] + + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs) + + for generation, mask in zip(generations, padded_inputs["attention_mask"]): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not return_prompt and not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + if remove_padding and self.tokenizer.eos_token_id in output: + pad_mask = output == self.tokenizer.eos_token_id + pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item() + output = output[: pad_start + 1] # keep the eos token at the end + + outputs.append(output) + + self.tokenizer.padding_side = padding_side_default + return outputs + + def _step_safety_checker( + self, + batch_size: int, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + masks: Optional[List[torch.LongTensor]] = None, + ): + """ + Check if the input data is valid for training. + + Args: + batch_size (int): + Batch size from the config file. + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + masks (List[`torch.LongTensor`], *optional*): + list of optional tensors containing the masks of shape (`query_length` + `response_length`) + Returns: + `tuple`: The input processed data. + """ + for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + if batch_size is not None and len(tensor_list) != batch_size: + raise ValueError(f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}") + + # add queries, scores and responses on the correct device + queries = [tensor.to(self.current_device) for tensor in queries] + responses = [tensor.to(self.current_device) for tensor in responses] + scores = [tensor.to(self.current_device) for tensor in scores] + masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None + + # squeeze scores if needed + for i, score in enumerate(scores): + if score.dim() > 1: + raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}") + elif score.dim() == 1: + scores[i] = score.squeeze() + + return queries, responses, scores, masks + + @PPODecorators.empty_device_cache() + def step( + self, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + response_masks: Optional[List[torch.LongTensor]] = None, + ): + """ + Run a PPO optimisation step given a list of queries, model responses, and rewards. + + Args: + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + response_masks (List[`torch.FloatTensor`], *optional*)): + List of tensors containing masks of the response tokens. + + Returns: + `dict[str, Any]`: A summary of the training statistics + """ + bs = self.config.batch_size + + queries, responses, scores, response_masks = self._step_safety_checker(bs, queries, responses, scores, response_masks) + scores = torch.tensor(scores, device=self.current_device) + if self.config.use_score_scaling: + # Score scaling + scores_mean, scores_std = self.running.update(scores) + tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device) + score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps + if self.config.use_score_norm: + scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor + else: + scores /= score_scaling_factor + + if self.config.score_clip is not None: + # Score clipping + scores_dtype = scores.dtype + scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype) + + # if we want to push best model to the hub + if hasattr(self, "highest_reward"): + if self.compare_step % self.config.compare_steps == 0: + curr_mean_reward = scores.mean() + # if the best reward ever seen + if curr_mean_reward > self.highest_reward: + self.highest_reward = curr_mean_reward + # push model to hub + self.push_to_hub(**self.push_to_hub_kwargs) + self.compare_step += 1 + + timing = dict() + t0 = time.time() + + t = time.time() + + model_inputs = self.prepare_model_inputs(queries, responses) + + if self.is_distributed: + pad_first = self.tokenizer.padding_side == "left" + + model_inputs["input_ids"] = self.accelerator.pad_across_processes( + model_inputs["input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["attention_mask"] = self.accelerator.pad_across_processes(model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first) + if self.is_encoder_decoder: + model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes( + model_inputs["decoder_input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes( + model_inputs["decoder_attention_mask"], + dim=1, + pad_index=0, + pad_first=pad_first, + ) + + model_inputs_names = list(model_inputs.keys()) + + full_kl_penalty = self.config.kl_penalty == "full" + + with torch.no_grad(): + all_logprobs, logits_or_none, values, masks = self.batched_forward_pass( + self.model, + queries, + responses, + model_inputs, + response_masks=response_masks, + return_logits=full_kl_penalty, + ) + with self.optional_peft_ctx(): + ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass( + self.model if self.is_peft_model else self.ref_model, + queries, + responses, + model_inputs, + return_logits=full_kl_penalty, + ) + + timing["time/ppo/forward_pass"] = time.time() - t + + with torch.no_grad(): + t = time.time() + if full_kl_penalty: + active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False) + ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False) + + rewards, non_score_reward, kls = self.compute_rewards(scores, active_full_logprobs, ref_full_logprobs, masks) + else: + rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks) + timing["time/ppo/compute_rewards"] = time.time() - t + + t = time.time() + values, advantages, returns = self.compute_advantages(values, rewards, masks) + timing["time/ppo/compute_advantages"] = time.time() - t + + # upcast to float32 to avoid dataset issues + batch_dict = { + "queries": queries, + "responses": responses, + "logprobs": all_logprobs.to(torch.float32), + "values": values.to(torch.float32), + "masks": masks, + "advantages": advantages, + "returns": returns, + } + batch_dict.update(model_inputs) + + t = time.time() + all_stats = [] + early_stop = False + for _ in range(self.config.ppo_epochs): + if early_stop: + break + b_inds = np.random.permutation(bs) + for backward_batch_start in range(0, bs, self.config.backward_batch_size): + backward_batch_end = backward_batch_start + self.config.backward_batch_size + backward_batch_inds = b_inds[backward_batch_start:backward_batch_end] + + for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size): + mini_batch_end = mini_batch_start + self.config.mini_batch_size + mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end] + mini_batch_dict = { + "logprobs": batch_dict["logprobs"][mini_batch_inds], + "values": batch_dict["values"][mini_batch_inds], + "masks": batch_dict["masks"][mini_batch_inds], + # hacks: the queries and responses are ragged. + "queries": [batch_dict["queries"][i] for i in mini_batch_inds], + "responses": [batch_dict["responses"][i] for i in mini_batch_inds], + "advantages": batch_dict["advantages"][mini_batch_inds], + "returns": batch_dict["returns"][mini_batch_inds], + } + for k in model_inputs_names: + mini_batch_dict[k] = batch_dict[k][mini_batch_inds] + with self.accelerator.accumulate(self.model): + model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names} + + logprobs, logits, vpreds, _ = self.batched_forward_pass( + self.model, + mini_batch_dict["queries"], + mini_batch_dict["responses"], + model_inputs, + return_logits=True, + ) + train_stats = self.train_minibatch( + mini_batch_dict["logprobs"], + mini_batch_dict["values"], + logprobs, + logits, + vpreds, + mini_batch_dict["masks"], + mini_batch_dict["advantages"], + mini_batch_dict["returns"], + ) + all_stats.append(train_stats) + + # typically, early stopping is done at the epoch level + if self.config.early_stopping: + policykl = train_stats["policy/policykl"] + early_stop = self._early_stop(policykl) + if early_stop: + break + + timing["time/ppo/optimize_step"] = time.time() - t + + t = time.time() + train_stats = stack_dicts(all_stats) + + # reshape advantages/ratios such that they are not averaged. + train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0) + train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING) + train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0) + + stats = self.record_step_stats( + scores=scores, + logprobs=all_logprobs, + ref_logprobs=ref_logprobs, + non_score_reward=non_score_reward, + train_stats=train_stats, + kl_coef=self.kl_ctl.value, + masks=masks, + queries=queries, + responses=responses, + kls=kls, + ) + # Gather/Reduce stats from all processes + if self.is_distributed: + stats = self.gather_stats(stats) + stats = stats_to_np(stats) + timing["time/ppo/calc_stats"] = time.time() - t + stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"] + + # Update the KL control - multiply the batch_size by the number of processes + self.kl_ctl.update( + stats["objective/kl"], + self.config.batch_size * self.accelerator.num_processes, + ) + + # Log the total ppo time + timing["time/ppo/total"] = time.time() - t0 + stats.update(timing) + + # post-process stats for tensorboard and other loggers + if self.config.log_with != "wandb": + stats = convert_to_scalar(stats) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + return stats + + def _early_stop(self, policykl): + r""" + Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed and + the optimization step is skipped. + This also handles the multi-gpu case where the policy KL is averaged across all processes. + + Args: + policy_kl (torch.Tensor): + the policy KL + + Returns: + `bool`: whether to early stop or not + """ + early_stop = False + if not self.config.early_stopping: + return early_stop + + if not self.is_distributed and policykl > 1.5 * self.config.target_kl: + self.optimizer.zero_grad() + early_stop = True + elif self.is_distributed: + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + # all gather the policykl + dist.all_reduce(policykl, dist.ReduceOp.SUM) + policykl /= self.accelerator.num_processes + + if policykl > 1.5 * self.config.target_kl: + self.optimizer.zero_grad() + early_stop = True + return early_stop + + def gather_stats(self, stats): + """ + Gather stats from all processes. Useful in the context of distributed training. + + Args: + stats (dict[str, Any]): + a dictionary of stats to be gathered. The stats should contain torch tensors. + + Returns: + `dict[str, Any]`: A dictionary of stats with the tensors gathered. + """ + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + for k, v in stats.items(): + if isinstance(v, torch.Tensor): + dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM) + v /= self.accelerator.num_processes + stats[k] = v + return stats + + def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): + if self.is_encoder_decoder: + input_data = self.data_collator([{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]).to(self.current_device) + + decoder_inputs = self.data_collator([{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]).to(self.current_device) + + input_data["decoder_input_ids"] = decoder_inputs["input_ids"] + input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"] + else: + input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] + input_data = self.data_collator([{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]).to(self.current_device) + + input_data.pop("labels", None) # we don't want to compute LM losses + return input_data + + @PPODecorators.empty_device_cache() + def batched_forward_pass( + self, + model: PreTrainedModelWrapper, + queries: torch.Tensor, + responses: torch.Tensor, + model_inputs: dict, + return_logits: bool = False, + response_masks: Optional[torch.Tensor] = None, + ): + """ + Calculate model outputs in multiple batches. + + Args: + queries (`torch.LongTensor`): + List of tensors containing the encoded queries, shape (`batch_size`, `query_length`) + responses (`torch.LongTensor`): + List of tensors containing the encoded responses, shape (`batch_size`, `response_length`) + return_logits (`bool`, *optional*, defaults to `False`): + Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption. + Returns: + (tuple): + - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses, + shape (`batch_size`, `response_length`) + - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses, + shape (`batch_size`, `response_length`) + - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`) + """ + bs = len(queries) + fbs = self.config.mini_batch_size + all_logprobs = [] + all_logits = [] + all_masks = [] + all_values = [] + + model.eval() + + for i in range(math.ceil(bs / fbs)): + input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} + query_batch = queries[i * fbs : (i + 1) * fbs] + response_batch = responses[i * fbs : (i + 1) * fbs] + if response_masks is not None: + response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] + logits, _, values = model(**input_kwargs) + + if self.is_encoder_decoder: + input_ids = input_kwargs["decoder_input_ids"] + attention_mask = input_kwargs["decoder_attention_mask"] + else: + input_ids = input_kwargs["input_ids"] + attention_mask = input_kwargs["attention_mask"] + + logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + masks = torch.zeros_like(attention_mask) + masks[:, :-1] = attention_mask[:, 1:] + + for j in range(len(query_batch)): + if self.is_encoder_decoder: + # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models + start = 1 + end = attention_mask[j, :].sum() - 1 + else: + start = len(query_batch[j]) - 1 # logprobs starts from the second query token + if attention_mask[j, 0] == 0: # offset left padding + start += attention_mask[j, :].nonzero()[0] + end = start + len(response_batch[j]) + if response_masks is not None: + response_masks_batch[j] = torch.cat((torch.zeros_like(query_batch[j]), response_masks_batch[j]))[1:] + + masks[j, :start] = 0 + masks[j, end:] = 0 + if response_masks is not None: + masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end] + + if return_logits: + all_logits.append(logits) + else: + del logits + all_values.append(values) + all_logprobs.append(logprobs) + all_masks.append(masks) + + return ( + torch.cat(all_logprobs), + torch.cat(all_logits)[:, :-1] if return_logits else None, + torch.cat(all_values)[:, :-1], + torch.cat(all_masks)[:, :-1], + ) + + @PPODecorators.empty_device_cache() + def train_minibatch( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + logprobs: torch.FloatTensor, + logits: torch.FloatTensor, + vpreds: torch.FloatTensor, + mask: torch.LongTensor, + advantages: torch.FloatTensor, + returns: torch.FloatTensor, + ): + """ + Train one PPO minibatch + + Args: + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape [mini_batch_size, response_length] + values (`torch.FloatTensor`): + Values of the value head, shape [mini_batch_size, response_length] + query (`torch.LongTensor`): + Encoded queries, shape [mini_batch_size, query_length] + response (`torch.LongTensor`): + Encoded responses, shape [mini_batch_size, response_length] + model_input (`torch.LongTensor`): + Concatenated queries and responses, shape [mini_batch_size, query_length+response_length] + + Returns: + train_stats (dict[str, `torch.Tensor`]): + Dictionary of training statistics + """ + self.model.train() + loss_p, loss_v, train_stats = self.loss(old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns) + loss = loss_p + loss_v + self.accelerator.backward(loss) + if self.config.max_grad_norm is not None: + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm) + self.optimizer.step() + # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation + # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code + self.optimizer.zero_grad() + return train_stats + + def compute_rewards( + self, + scores: torch.FloatTensor, + logprobs: torch.FloatTensor, + ref_logprobs: torch.FloatTensor, + masks: torch.LongTensor, + ): + """ + Compute per token rewards from scores and KL-penalty. + + Args: + scores (`torch.FloatTensor`): + Scores from the reward model, shape (`batch_size`) + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + ref_logprobs (`torch.FloatTensor`): + Log probabilities of the reference model, shape (`batch_size`, `response_length`) + + Returns: + `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`) + `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`) + `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`) + """ + rewards, non_score_rewards, kls = [], [], [] + for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks): + # compute KL penalty (from difference in logprobs) + kl = self._kl_penalty(logprob, ref_logprob) + kls.append(kl) + non_score_reward = -self.kl_ctl.value * kl + non_score_rewards.append(non_score_reward) + reward = non_score_reward.clone() + last_non_masked_index = mask.nonzero()[-1] + + # reward is preference model score + KL penalty + reward[last_non_masked_index] += score + rewards.append(reward) + return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls) + + def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor: + if self.config.kl_penalty == "kl": + return logprob - ref_logprob + + if self.config.kl_penalty == "abs": + return (logprob - ref_logprob).abs() + + if self.config.kl_penalty == "mse": + return 0.5 * (logprob - ref_logprob).square() + + if self.config.kl_penalty == "full": + # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459 + return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1) + + raise NotImplementedError + + def compute_advantages( + self, + values: torch.FloatTensor, + rewards: torch.FloatTensor, + mask: torch.FloatTensor, + ): + lastgaelam = 0 + advantages_reversed = [] + gen_len = rewards.shape[-1] + + values = values * mask + rewards = rewards * mask + + if self.config.whiten_rewards: + rewards = masked_whiten(rewards, mask, shift_mean=False) + + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) + + returns = advantages + values + advantages = masked_whiten(advantages, mask) + advantages = advantages.detach() + return values, advantages, returns + + def loss( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + logits: torch.FloatTensor, + vpreds: torch.FloatTensor, + logprobs: torch.FloatTensor, + mask: torch.LongTensor, + advantages: torch.FloatTensor, + returns: torch.FloatTensor, + ): + """ + Calculate policy and value losses. + + Args: + old_logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + values (`torch.FloatTensor`): + Values of the value head, shape (`batch_size`, `response_length`) + rewards (`torch.FloatTensor`): + Rewards from the reward model, shape (`batch_size`, `response_length`) + logits (`torch.FloatTensor`): + Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`) + v_pred (`torch.FloatTensor`): + Values of the value head, shape (`batch_size`, `response_length`) + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + """ + + vpredclipped = clip_by_value( + vpreds, + values - self.config.cliprange_value, + values + self.config.cliprange_value, + ) + + vf_losses1 = (vpreds - returns) ** 2 + vf_losses2 = (vpredclipped - returns) ** 2 + vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask) + vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask) + + ratio = torch.exp(logprobs - old_logprobs) + + pg_losses = -advantages * ratio + pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange) + + pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask) + pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask) + + loss = pg_loss + self.config.vf_coef * vf_loss + + avg_ratio = masked_mean(ratio, mask).item() + if avg_ratio > self.config.ratio_threshold: + warnings.warn(f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch.") + pg_loss = pg_loss * 0.0 + vf_loss = vf_loss * 0.0 + loss = loss * 0.0 + + entropy = masked_mean(entropy_from_logits(logits), mask) + + approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask) + policykl = masked_mean(old_logprobs - logprobs, mask) + + return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask) + value_mean, value_var = masked_mean(values, mask), masked_var(values, mask) + + stats = dict( + loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()), + policy=dict( + entropy=entropy.detach(), + approxkl=approxkl.detach(), + policykl=policykl.detach(), + clipfrac=pg_clipfrac.detach(), + advantages=advantages.detach(), + advantages_mean=masked_mean(advantages, mask).detach(), + ratio=ratio.detach(), + ), + returns=dict(mean=return_mean.detach(), var=return_var.detach()), + val=dict( + vpred=masked_mean(vpreds, mask).detach(), + error=masked_mean((vpreds - returns) ** 2, mask).detach(), + clipfrac=vf_clipfrac.detach(), + mean=value_mean.detach(), + var=value_var.detach(), + ), + ) + return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats) + + def record_step_stats(self, kl_coef: float, **data): + """ + Record training step statistics. + + + Args: + kl_coef (`float`): + KL coefficient + data (`dict`): + Dictionary of training step data + + Returns: + stats (`dict`): + Dictionary of training step statistics + """ + mask = data.pop("masks") + + kls = data.pop("kls") + kl_list = ((kls) * mask).sum(axis=-1) + mean_kl = kl_list.mean() + mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean() + + mean_non_score_reward = masked_mean(data["non_score_reward"], mask) # non_score_reward is size `batch_size`, `response_length` + mean_scores = data["scores"].mean() # scores is size `batch_size` + std_scores = data["scores"].std() + + if mean_kl.item() < -1.0: + # warn users + warnings.warn( + f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training." + " sometimes this happens because the generation kwargs are not correctly set. Please make sure" + " that the generation kwargs are set correctly, or review your training hyperparameters." + ) + + stats = { + "objective/kl": mean_kl, + "objective/kl_dist": kl_list, + "objective/logprobs": data["logprobs"], + "objective/ref_logprobs": data["ref_logprobs"], + "objective/kl_coef": kl_coef, + "objective/entropy": mean_entropy, + "ppo/mean_non_score_reward": mean_non_score_reward, + "ppo/mean_scores": mean_scores, + "ppo/std_scores": std_scores, + } + + # Log text properties + query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float) + response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float) + + stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item() + stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item() + stats["tokens/queries_dist"] = query_lens.cpu().numpy() + stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item() + stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item() + stats["tokens/responses_dist"] = response_lens.cpu().numpy() + + for k, v in data["train_stats"].items(): + stats[f"ppo/{k}"] = torch.mean(v, axis=0) + stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"] + return stats + + def log_stats( + self, + stats: dict, + batch: dict, + rewards: List[torch.FloatTensor], + columns_to_log: List[str] = ["query", "response"], + ): + """ + A function that logs all the training stats. Call it at the end of each epoch. + + Args: + stats (dict[str, Any]): + A dictionary of training stats. + batch (dict[str, Any]): + A dictionary of batch data, this contains the queries and responses. + rewards (`List[torch.FloatTensor]`): + A tensor of rewards. + """ + + # all gather stats + if not isinstance(rewards, torch.Tensor): + rewards = torch.tensor(rewards).to(self.current_device) + rewards = self.accelerator.gather(rewards).flatten() + + if self.config.log_with == "wandb": + import wandb + + if any([column_to_log not in batch.keys() for column_to_log in columns_to_log]): + raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.") + + batch_list = [batch[column_to_log] for column_to_log in columns_to_log] + if self.is_distributed: + gathered_batch_list = [] + for b in batch_list: + flattened = gather_object(b) + gathered_batch_list.append(flattened) + batch_list = gathered_batch_list + + # Log only if we are in the main process + if self.accelerator.is_main_process: + logs = {} + + # Log stats + if "query" not in batch.keys() and "response" not in batch.keys(): + # warn the user that the game logs will not be logged + warnings.warn("The game logs will not be logged because the batch does not contain the keys 'query' and " "'response'. ") + elif self.config.log_with == "wandb": + table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())] + logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)}) + + logs.update(stats) + + # manually cast in fp32 for bf16 torch tensors + for k, v in logs.items(): + if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16: + logs[k] = v.float() + + logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item() + logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item() + logs["env/reward_dist"] = rewards.cpu().numpy() + + if self.config.log_with == "tensorboard": + # update the current step + self.current_step += 1 + + self.accelerator.log( + logs, + step=self.current_step if self.config.log_with == "tensorboard" else None, + ) + + def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None: + """Creates and saves a model card for a TRL model. + + Args: + path (`str`): The path to save the model card to. + model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`. + """ + try: + user = whoami()["name"] + # handle the offline case + except: # noqa + warnings.warn("Cannot retrieve user information assuming you are running in offline mode.") + return + + if not os.path.exists(path): + os.makedirs(path) + + model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}") + with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: + f.write(model_card_content) + + def _save_pretrained(self, save_directory: str) -> None: + self.accelerator.unwrap_model(self.model).save_pretrained(save_directory) + self.tokenizer.save_pretrained(save_directory) + self.create_model_card(save_directory) + + def _show_tokens(self, tokens, masks): + from rich import print + from rich.text import Text + + text = Text() + + for i, (token, mask) in enumerate(zip(tokens, masks)): + if mask == 1: + text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1") + text.append(" ") + else: + text.append(self.tokenizer.decode(token.item()), style="black on cyan3") + text.append(" ") + print(text) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepspeed_plugin.deepspeed_config + if model is not None: + if hasattr(model, "config"): + hidden_size = max(model.config.hidden_sizes) if getattr(model.config, "hidden_sizes", None) else getattr(model.config, "hidden_size", None) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py new file mode 100755 index 000000000..32c7e264d --- /dev/null +++ b/trl/trainer/reward_config.py @@ -0,0 +1,38 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class RewardConfig(TrainingArguments): + """ + RewardConfig collects all training arguments related to the [`RewardTrainer`] class. + + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int`, *optional*, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + gradient_checkpointing (`bool`, *optional*, defaults to `True`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + """ + + max_length: Optional[int] = None + """The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.""" diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py new file mode 100755 index 000000000..eff4eb7ab --- /dev/null +++ b/trl/trainer/reward_trainer.py @@ -0,0 +1,257 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import warnings +from dataclasses import FrozenInstanceError, replace +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from datasets import Dataset +from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_pt_utils import nested_detach +from transformers.trainer_utils import EvalPrediction + +from ..import_utils import is_peft_available +from .reward_config import RewardConfig +from .utils import RewardDataCollatorWithPadding, compute_accuracy + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +class RewardTrainer(Trainer): + r""" + The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the + `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use + an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset + of paired examples, where each example is a tuple of two sequences. The reward model should be trained to + predict which example in the pair is more relevant to the task at hand. + + The reward trainer expects a very specific format for the dataset. The dataset should contain two 4 entries at least + if you don't use the default `RewardDataCollatorWithPadding` data collator. The entries should be named + - `input_ids_chosen` + - `attention_mask_chosen` + - `input_ids_rejected` + - `attention_mask_rejected` + + Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the + loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. + If you don't pass a margin, no margin will be used. + """ + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + args: Optional[RewardConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + ): + """ + Initialize RewardTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + args (`RewardConfig`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`): + The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + max_length (`int`, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + """ + if type(args) == TrainingArguments: + warnings.warn( + "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.", + FutureWarning, + ) + if max_length is not None: + warnings.warn( + "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.", + FutureWarning, + ) + else: + if max_length is not None and args.max_length is not None: + raise ValueError("You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once.") + if max_length is not None and args.max_length is None: + warnings.warn( + "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.", + FutureWarning, + ) + if not is_peft_available() and peft_config is not None: + raise ValueError("PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models") + elif is_peft_available() and peft_config is not None: + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False): + _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(inspect.signature(prepare_model_for_kbit_training).parameters) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + warnings.warn("You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.") + elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + model = get_peft_model(model, peft_config) + + if compute_metrics is None: + compute_metrics = compute_accuracy + + if data_collator is None: + if tokenizer is None: + raise ValueError("max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding") + if type(args) == TrainingArguments: + if max_length is None: + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." " It will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + else: + if max_length is None and args.max_length is None: + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." " It will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_length is None and args.max_length is not None: + max_length = args.max_length + + data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length) + + if args.remove_unused_columns: + try: # for bc before https://github.com/huggingface/transformers/pull/25435 + args.remove_unused_columns = False + except FrozenInstanceError: + args = replace(args, remove_unused_columns=False) + # warn users + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_reward_data_collator = True + else: + self.use_reward_data_collator = False + super().__init__( + model, + args, + data_collator, + train_dataset, + eval_dataset, + tokenizer, + model_init, + compute_metrics, + callbacks, + optimizers, + preprocess_logits_for_metrics, + ) + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + if not self.use_reward_data_collator: + warnings.warn("The current compute_loss is implemented for RewardDataCollatorWithPadding," " if you are using a custom data collator make sure you know what you are doing or" " implement your own compute_loss method.") + rewards_chosen = model( + input_ids=inputs["input_ids_chosen"], + attention_mask=inputs["attention_mask_chosen"], + return_dict=True, + )["logits"] + rewards_rejected = model( + input_ids=inputs["input_ids_rejected"], + attention_mask=inputs["attention_mask_rejected"], + return_dict=True, + )["logits"] + # calculate loss, optionally modulate with margin + if "margin" in inputs: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() + else: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + + if return_outputs: + return loss, { + "rewards_chosen": rewards_chosen, + "rewards_rejected": rewards_rejected, + } + return loss + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + with torch.no_grad(): + loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True) + + if prediction_loss_only: + return (loss, None, None) + + loss = loss.detach() + logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) + logits = nested_detach(logits) + # Stack accepted against rejected, mean over logits + # and softmax to get preferences between accepted and rejected to sum to 1 + logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T + + labels = torch.zeros(logits.shape[0]) + labels = self._prepare_inputs(labels) + + return loss, logits, labels diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py new file mode 100755 index 000000000..69f32750b --- /dev/null +++ b/trl/trainer/sft_trainer.py @@ -0,0 +1,480 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +import inspect +import warnings +from functools import wraps +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from accelerate.state import PartialState +from datasets import Dataset +from datasets.arrow_writer import SchemaInferenceError +from datasets.builder import DatasetGenerationError +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollator, + DataCollatorForLanguageModeling, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, +) +from transformers.modeling_utils import unwrap_model +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction + +from ..extras.dataset_formatting import get_formatting_func_from_dataset +from ..import_utils import is_peft_available +from .utils import ( + ConstantLengthDataset, + DataCollatorForCompletionOnlyLM, + neftune_post_forward_hook, + peft_module_casting_to_bf16, + trl_sanitze_kwargs_for_tagging, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + + +class SFTTrainer(Trainer): + r""" + Class definition of the Supervised Finetuning Trainer (SFT Trainer). + This class is a wrapper around the `transformers.Trainer` class and inherits all of its attributes and methods. + The trainer takes care of properly initializing the PeftModel in case a user passes a `PeftConfig` object. + + Args: + model (Union[`transformers.PreTrainedModel`, `nn.Module`, `str`]): + The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to + load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is + passed to the `peft_config` argument. + args (Optional[`transformers.TrainingArguments`]): + The arguments to tweak for training. Please refer to the official documentation of `transformers.TrainingArguments` + for more information. + data_collator (Optional[`transformers.DataCollator`]): + The data collator to use for training. + train_dataset (Optional[`datasets.Dataset`]): + The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. + eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]): + The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset. + tokenizer (Optional[`transformers.PreTrainedTokenizer`]): + The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to None): + The function used to compute metrics during evaluation. It should return a dictionary mapping metric names to metric values. + If not specified, only the loss will be computed during evaluation. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`Optional[PeftConfig]`): + The PeftConfig object to use to initialize the PeftModel. + dataset_text_field (`Optional[str]`): + The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a + `ConstantLengthDataset` based on the `dataset_text_field` argument. + formatting_func (`Optional[Callable]`): + The formatting function to be used for creating the `ConstantLengthDataset`. + max_seq_length (`Optional[int]`): + The maximum sequence length to use for the `ConstantLengthDataset` and for automatically creating the Dataset. Defaults to `512`. + infinite (`Optional[bool]`): + Whether to use an infinite dataset or not. Defaults to `False`. + num_of_sequences (`Optional[int]`): + The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`. + chars_per_token (`Optional[float]`): + The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the + stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53. + packing (`Optional[bool]`): + Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences + of the dataset. + dataset_num_proc (`Optional[int]`): + The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None. + dataset_batch_size (`int`): + The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None, + tokenize the full dataset as a single batch. Defaults to 1000. + neftune_noise_alpha (`Optional[float]`): + If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instruction + fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune + model_init_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + dataset_kwargs: (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when creating packed or non-packed datasets + """ + + _tag_names = ["trl", "sft"] + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + dataset_text_field: Optional[str] = None, + packing: Optional[bool] = False, + formatting_func: Optional[Callable] = None, + max_seq_length: Optional[int] = None, + infinite: Optional[bool] = None, + num_of_sequences: Optional[int] = 1024, + chars_per_token: Optional[float] = 3.6, + dataset_num_proc: Optional[int] = None, + dataset_batch_size: int = 1000, + neftune_noise_alpha: Optional[float] = None, + model_init_kwargs: Optional[Dict] = None, + dataset_kwargs: Optional[Dict] = None, + ): + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.") + + if infinite is not None: + warnings.warn("The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length.") + + if isinstance(model, str): + warnings.warn("You passed a model_id to the SFTTrainer. This will automatically create an " "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you.") + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): + raise ValueError("You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument.") + + if is_peft_available() and peft_config is not None: + if not isinstance(peft_config, PeftConfig): + raise ValueError("If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." f" and you passed a {type(peft_config)}.") + + if not isinstance(model, PeftModel): + _support_gc_kwargs = hasattr(args, "gradient_checkpointing_kwargs") and "gradient_checkpointing_kwargs" in list(inspect.signature(prepare_model_for_kbit_training).parameters) + gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + preprare_model_kwargs = {"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)} + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + if args is not None: + args = dataclasses.replace(args, gradient_checkpointing=False) + elif getattr(args, "gradient_checkpointing", False) and ("use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + if max_seq_length is None: + # to overcome some issues with broken tokenizers + max_seq_length = min(tokenizer.model_max_length, 1024) + + warnings.warn(f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}") + + self.dataset_num_proc = dataset_num_proc + self.dataset_batch_size = dataset_batch_size + + self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") + + if neftune_noise_alpha is not None and self._trainer_supports_neftune: + args.neftune_noise_alpha = neftune_noise_alpha + warnings.warn("You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`.") + # self.neftune_noise_alpha is done at Trainer level + elif not self._trainer_supports_neftune: + self.neftune_noise_alpha = neftune_noise_alpha + + if formatting_func is None and dataset_text_field is None: + # check if dataset has ChatML format or instruction format and is supported + # if not stays #None + formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer) + + if not packing: + if dataset_text_field is None and formatting_func is None: + raise ValueError("You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument.") + + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + # Pre-process the datasets only once per node. The remaining processes will use the cache. + with PartialState().local_main_process_first(): + if dataset_kwargs is None: + dataset_kwargs = {} + if train_dataset is not None: + train_dataset = self._prepare_dataset( + train_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + remove_unused_columns=args.remove_unused_columns if args is not None else True, + **dataset_kwargs, + ) + if eval_dataset is not None: + _multiple = isinstance(eval_dataset, dict) + _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} + for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): + _eval_datasets[_eval_dataset_name] = self._prepare_dataset( + _eval_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + remove_unused_columns=args.remove_unused_columns if args is not None else True, + **dataset_kwargs, + ) + if not _multiple: + eval_dataset = _eval_datasets["singleton"] + + if tokenizer.padding_side is not None and tokenizer.padding_side != "right": + warnings.warn( + "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to " + "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code." + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if self.args.max_steps > 0 and packing: + warnings.warn("You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached.") + self.train_dataset.infinite = True + elif self.args.max_steps == -1 and packing: + self.train_dataset.infinite = False + + @wraps(Trainer.train) + def train(self, *args, **kwargs): + # Activate neftune right before training. + if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + self.model = self._trl_activate_neftune(self.model) + + output = super().train(*args, **kwargs) + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune: + unwrapped_model = unwrap_model(self.model) + if is_peft_available() and isinstance(unwrapped_model, PeftModel): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + self.neftune_hook_handle.remove() + del embeddings.neftune_noise_alpha + + return output + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + + def _prepare_dataset( + self, + dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + remove_unused_columns=True, + append_concat_token=True, + add_special_tokens=True, + ): + if dataset is None: + raise ValueError("The dataset should not be None") + + # check if torch dataset / dataloader and do nothing + if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)): + return dataset + + if not packing: + return self._prepare_non_packed_dataloader( + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + formatting_func, + add_special_tokens, + remove_unused_columns, + ) + + else: + return self._prepare_packed_dataloader( + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + num_of_sequences, + chars_per_token, + formatting_func, + append_concat_token, + add_special_tokens, + ) + + def _prepare_non_packed_dataloader( + self, + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + formatting_func=None, + add_special_tokens=True, + remove_unused_columns=True, + ): + use_formatting_func = formatting_func is not None and dataset_text_field is None + self._dataset_sanity_checked = False + + # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt + def tokenize(element): + outputs = tokenizer( + element[dataset_text_field] if not use_formatting_func else formatting_func(element), + add_special_tokens=add_special_tokens, + truncation=True, + padding=False, + max_length=max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + + if use_formatting_func and not self._dataset_sanity_checked: + if not isinstance(formatting_func(element), list): + raise ValueError("The `formatting_func` should return a list of processed strings since it can lead to silent bugs.") + else: + self._dataset_sanity_checked = True + + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + signature_columns = ["input_ids", "labels", "attention_mask"] + + extra_columns = list(set(dataset.column_names) - set(signature_columns)) + + if not remove_unused_columns and len(extra_columns) > 0: + warnings.warn( + "You passed `remove_unused_columns=False` on a non-packed dataset. This might create some issues with the default collator and yield to errors. If you want to " + f"inspect dataset other columns (in this case {extra_columns}), you can subclass `DataCollatorForLanguageModeling` in case you used the default collator and create your own data collator in order to inspect the unused dataset columns." + ) + + tokenized_dataset = dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names if remove_unused_columns else None, + num_proc=self.dataset_num_proc, + batch_size=self.dataset_batch_size, + ) + + return tokenized_dataset + + def _prepare_packed_dataloader( + self, + tokenizer, + dataset, + dataset_text_field, + max_seq_length, + num_of_sequences, + chars_per_token, + formatting_func=None, + append_concat_token=True, + add_special_tokens=True, + ): + if dataset_text_field is not None or formatting_func is not None: + if tokenizer is None: + raise ValueError("You need to pass a tokenizer when using `dataset_text_field` with `SFTTrainer`.") + + constant_length_iterator = ConstantLengthDataset( + tokenizer, + dataset, + dataset_text_field=dataset_text_field, + formatting_func=formatting_func, + seq_length=max_seq_length, + infinite=False, + num_of_sequences=num_of_sequences, + chars_per_token=chars_per_token, + eos_token_id=tokenizer.eos_token_id, + append_concat_token=append_concat_token, + add_special_tokens=add_special_tokens, + ) + + def data_generator(constant_length_iterator): + for i in constant_length_iterator: + yield i + + try: + packed_dataset = Dataset.from_generator(data_generator, gen_kwargs={"constant_length_iterator": constant_length_iterator}) + except (DatasetGenerationError, SchemaInferenceError): + raise ValueError("Error occurred while packing the dataset. Make sure that your dataset has enough samples to at least yield one packed sequence.") + return packed_dataset + else: + raise ValueError("You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`.") + + def _trl_activate_neftune(self, model): + r""" + Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914 + Since in transformers Trainer we do have an `_activate_neftune` method, we need to rename this method to avoid conflicts. + """ + unwrapped_model = unwrap_model(model) + if is_peft_available() and isinstance(unwrapped_model, PeftModel): + embeddings = unwrapped_model.base_model.model.get_input_embeddings() + else: + embeddings = unwrapped_model.get_input_embeddings() + + embeddings.neftune_noise_alpha = self.neftune_noise_alpha + hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook) + self.neftune_hook_handle = hook_handle + return model diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py new file mode 100755 index 000000000..1c566fa22 --- /dev/null +++ b/trl/trainer/utils.py @@ -0,0 +1,703 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +import warnings +from collections import deque +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from accelerate import PartialState +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import IterableDataset +from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase + +from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available +from ..trainer.model_config import ModelConfig + + +if is_peft_available(): + from peft import LoraConfig, PeftConfig + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target, horizon): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current, n_steps): + target = self.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current, n_steps): + pass + + +class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): + """ + Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index' + when they do not come from the assistant. This ensure that the loss is only + calculated on the completion made by the assistant. + + Args: + response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like + '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response + differently if it does not have proper context. + instruction_template (`Union[str, List[int]]`): the template form that indicates the start of the human instruction, typically something like + '### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids. + mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying + `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present + for flexibility and backwards-compatibility. + ignore_index (`int`, *optional*, defaults to `-100`): + The index to use to ignore the initial tokens with + """ + + def __init__( + self, + response_template: Union[str, List[int]], + instruction_template: Union[str, List[int]] = None, + *args, + mlm: bool = False, + ignore_index: int = -100, + **kwargs, + ): + super().__init__(*args, mlm=mlm, **kwargs) + + self.instruction_template = instruction_template + if isinstance(instruction_template, str): + # The user provides a string, must tokenize + self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.instruction_token_ids = instruction_template + + self.response_template = response_template + if isinstance(response_template, str): + # The user provides a string, must tokenize + self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.response_token_ids = response_template + + if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + warnings.warn( + "The pad_token_id and eos_token_id values of this tokenizer are identical. " + "If you are planning for multi-turn training, " + "it can result in the model continuously generating questions and answers without eos token. " + "To avoid this, set the pad_token_id to a different value." + ) + + self.ignore_index = ignore_index + + def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: + batch = super().torch_call(examples) + + if self.instruction_template is None: + for i in range(len(examples)): + response_token_ids_start_idx = None + + for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match + if self.response_token_ids == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist(): + response_token_ids_start_idx = idx + + if response_token_ids_start_idx is None: + warnings.warn( + f"Could not find response key `{self.response_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + else: + response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids) + + # Make pytorch loss function ignore all tokens up through the end of the response key + batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index + + else: + for i in range(len(examples)): + response_token_ids_idxs = [] + human_token_ids_idxs = [] + + for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # find the indexes of the start of a response. + if self.response_token_ids == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist(): + response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids)) + + if len(response_token_ids_idxs) == 0: + warnings.warn( + f"Could not find response key `{self.response_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + + human_token_ids = self.instruction_token_ids + for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]: + # find the indexes of the start of a human answer. + if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist(): + human_token_ids_idxs.append(human_idx) + + if len(human_token_ids_idxs) == 0: + warnings.warn( + f"Could not find instruction key `{self.instruction_template}` in the " + f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} ' + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." + ) + batch["labels"][i, :] = self.ignore_index + + if len(human_token_ids_idxs) > 0 and len(response_token_ids_idxs) > 0 and human_token_ids_idxs[0] > response_token_ids_idxs[0]: + human_token_ids_idxs = [0] + human_token_ids_idxs + + for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)): + # Make pytorch loss function ignore all non response tokens + if idx != 0: + batch["labels"][i, start:end] = self.ignore_index + else: + batch["labels"][i, :end] = self.ignore_index + + if len(response_token_ids_idxs) < len(human_token_ids_idxs): + batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index + + return batch + + +@dataclass +class RewardDataCollatorWithPadding: + r""" + Reward DataCollator class that pads the inputs to the maximum length of the batch. + Args: + tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used for encoding the data. + padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): + padding_strategy to pass to the tokenizer. + max_length (`Optional[int]`, `optional`, defaults to `None`): + The maximum length of the sequence to be processed. + pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`): + If set will pad the sequence to a multiple of the provided value. + return_tensors (`str`, `optional`, defaults to `"pt"`): + The tensor type to use. + """ + + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + features_chosen = [] + features_rejected = [] + margin = [] + # check if we have a margin. If we do, we need to batch it as well + has_margin = "margin" in features[0] + for feature in features: + # check if the keys are named as expected + if "input_ids_chosen" not in feature or "input_ids_rejected" not in feature or "attention_mask_chosen" not in feature or "attention_mask_rejected" not in feature: + raise ValueError("The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`") + + features_chosen.append( + { + "input_ids": feature["input_ids_chosen"], + "attention_mask": feature["attention_mask_chosen"], + } + ) + features_rejected.append( + { + "input_ids": feature["input_ids_rejected"], + "attention_mask": feature["attention_mask_rejected"], + } + ) + if has_margin: + margin.append(feature["margin"]) + batch_chosen = self.tokenizer.pad( + features_chosen, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch_rejected = self.tokenizer.pad( + features_rejected, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch = { + "input_ids_chosen": batch_chosen["input_ids"], + "attention_mask_chosen": batch_chosen["attention_mask"], + "input_ids_rejected": batch_rejected["input_ids"], + "attention_mask_rejected": batch_rejected["attention_mask"], + "return_loss": True, + } + if has_margin: + margin = torch.tensor(margin, dtype=torch.float) + batch["margin"] = margin + return batch + + +@dataclass +class DPODataCollatorWithPadding: + r""" + DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. + Args: + pad_token_id (`int` defaults to 0): + The tokenizer's pad_token_id. + label_pad_token_id (`int`, defaults to -100): + The label used for masking. + is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): + Whether or not you model has an encoder_decoder architecture. + """ + + tokenizer: PreTrainedTokenizerBase + pad_token_id: int = 0 + label_pad_token_id: int = -100 + is_encoder_decoder: Optional[bool] = False + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + # first, pad everything to the same length + padded_batch = {} + for k in features[0].keys(): + if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + if self.is_encoder_decoder: + to_pad = [torch.LongTensor(ex[k]) for ex in features] + + if (k.startswith("prompt")) and (k.endswith("input_ids")): + if self.pad_token_id is None: + raise ValueError( + "Padding is enabled, but the tokenizer is not configured with a padding token." " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" " before calling the trainer." + ) + padding_value = self.pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k): + padding_value = self.label_pad_token_id + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + else: + # adapted from https://stackoverflow.com/questions/73256206 + if "prompt" in k: + to_pad = [torch.LongTensor(ex[k][::-1]) for ex in features] + else: + to_pad = [torch.LongTensor(ex[k]) for ex in features] + if k.endswith("_input_ids"): + if self.pad_token_id is None: + raise ValueError( + "Padding is enabled, but the tokenizer is not configured with a padding token." " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" " before calling the trainer." + ) + padding_value = self.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + else: + raise ValueError(f"Unexpected key in batch '{k}'") + + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + # for the prompt, flip back so padding is on left side + if "prompt" in k: + padded_batch[k] = padded_batch[k].flip(dims=[1]) + elif k.endswith("_logps"): + # the cached reference model logprobs + padded_batch[k] = torch.tensor([ex[k] for ex in features]) + else: + padded_batch[k] = [ex[k] for ex in features] + + return padded_batch + + +class ConstantLengthDataset(IterableDataset): + """ + Iterable dataset that returns constant length chunks of tokens from stream of text files. + The dataset also formats the text before tokenization with a specific format that is provided + by the user. + + Args: + tokenizer (`transformers.PreTrainedTokenizer`): + The processor used for processing the data. + dataset (`dataset.Dataset`): + Dataset with text files. + dataset_text_field (`str`, **optional**): + Name of the field in the dataset that contains the text. Used only if `formatting_func` is `None`. + formatting_func (`Callable`, **optional**): + Function that formats the text before tokenization. Usually it is recommended to have follows a certain + pattern such as `"### Question: {question} ### Answer: {answer}"` + infinite (`bool`, *optional*, defaults to `False`): + If True the iterator is reset after dataset reaches end else stops. + seq_length (`int`, *optional*, defaults to `1024`): + Length of token sequences to return. + num_of_sequences (`int`, *optional*, defaults to `1024`): + Number of token sequences to keep in buffer. + chars_per_token (`int`, *optional*, defaults to `3.6`): + Number of characters per token used to estimate number of tokens in text buffer. + eos_token_id (`int`, *optional*, defaults to `0`): + Id of the end of sequence token if the passed tokenizer does not have an EOS token. + shuffle ('bool', *optional*, defaults to True) + Shuffle the examples before they are returned + append_concat_token ('bool', *optional*, defaults to True) + If true, appends `eos_token_id` at the end of each sample being packed. + add_special_tokens ('bool', *optional*, defaults to True) + If true, tokenizers adds special tokens to each sample being packed. + """ + + def __init__( + self, + tokenizer, + dataset, + dataset_text_field=None, + formatting_func=None, + infinite=False, + seq_length=1024, + num_of_sequences=1024, + chars_per_token=3.6, + eos_token_id=0, + shuffle=True, + append_concat_token=True, + add_special_tokens=True, + ): + self.tokenizer = tokenizer + + if tokenizer.eos_token_id is None: + warnings.warn( + "The passed tokenizer does not have an EOS token. We will use the passed eos_token_id instead which corresponds" f" to {eos_token_id}. If this is not the correct EOS token, make sure to pass the correct eos_token_id." + ) + + self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id + self.dataset = dataset + self.seq_length = seq_length + self.infinite = infinite + self.current_size = 0 + self.max_buffer_size = seq_length * chars_per_token * num_of_sequences + self.shuffle = shuffle + self.append_concat_token = append_concat_token + self.add_special_tokens = add_special_tokens + if formatting_func is None: + self.formatting_func = lambda x: x[dataset_text_field] + else: + self.formatting_func = formatting_func + + if formatting_func is not None: + if formatting_func.__code__.co_argcount > 1: + warnings.warn( + "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`" + " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing." + ) + + def __len__(self): + return len(self.dataset) + + def __iter__(self): + iterator = iter(self.dataset) + more_examples = True + while more_examples: + buffer, buffer_len = [], 0 + while True: + if buffer_len >= self.max_buffer_size: + break + try: + buffer.append(self.formatting_func(next(iterator))) + buffer_len += len(buffer[-1]) + except StopIteration: + if self.infinite: + iterator = iter(self.dataset) + warnings.warn("The dataset reached end and the iterator is reset to the start.") + else: + more_examples = False + break + tokenized_inputs = self.tokenizer(buffer, add_special_tokens=self.add_special_tokens, truncation=False)["input_ids"] + all_token_ids = [] + for tokenized_input in tokenized_inputs: + if self.append_concat_token: + tokenized_input = tokenized_input + [self.concat_token_id] + all_token_ids.extend(tokenized_input) + examples = [] + for i in range(0, len(all_token_ids), self.seq_length): + input_ids = all_token_ids[i : i + self.seq_length] + if len(input_ids) == self.seq_length: + examples.append(input_ids) + if self.shuffle: + random.shuffle(examples) + for example in examples: + self.current_size += 1 + yield { + "input_ids": torch.LongTensor(example), + "labels": torch.LongTensor(example), + } + + +class RunningMoments: + def __init__(self, accelerator): + """ + Calculates the running mean and standard deviation of a data stream. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75 + """ + self.mean = 0 + self.std = 1 + self.var = 1 + self.count = 1e-24 + self.accelerator = accelerator + + @torch.no_grad() + def update(self, xs: torch.Tensor) -> Tuple[float, float]: + """ + Updates running moments from batch's moments computed across ranks + """ + if self.accelerator.use_distributed: + xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs) + else: + xs_count = xs.numel() + xs_var, xs_mean = torch.var_mean(xs, unbiased=False) + xs_mean, xs_var = xs_mean.float(), xs_var.float() + + delta = xs_mean - self.mean + tot_count = self.count + xs_count + + new_sum = xs_var * xs_count + # correct old_sum deviation accounting for the new mean + old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count + tot_sum = old_sum + new_sum + + self.mean += delta * xs_count / tot_count + self.var = tot_sum / tot_count + self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt() + self.count = tot_count + + return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item() + + +@torch.no_grad() +def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]: + """ + Computes element-wise mean and variance of the tensor across processes. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 + """ + xs = xs.to(accelerator.device) + sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device) + sum_and_count = accelerator.reduce(sum_and_count) + global_sum, count = sum_and_count + global_mean = global_sum / count + + sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask)) + sum_var = accelerator.reduce(sum_var) + global_var = sum_var / count + + return global_mean.to(device), global_var.to(device), count.to(device) + + +def compute_accuracy(eval_pred) -> Dict[str, float]: + predictions, labels = eval_pred + # Here, predictions is rewards_chosen and rewards_rejected. + # We want to see how much of the time rewards_chosen > rewards_rejected. + if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0: + warnings.warn(f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading.") + predictions = np.argmax(predictions, axis=1) + + accuracy = np.array(predictions == labels, dtype=float).mean().item() + return {"accuracy": accuracy} + + +def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + return torch.cat( + [ + tensor, + pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), + ], + dim=dim, + ) + + +def disable_dropout_in_model(model: torch.nn.Module) -> None: + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0 + + +def exact_div(a, b, a_str, b_str, custom_error_message=""): + q = a // b + if a != q * b: + raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}") + return q + + +# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5 +class PerPromptStatTracker: + r""" + Class for tracking statistics per prompt. Mainly used to calculate advantage for the DPPO algorithm + + Args: + buffer_size (`int`): + Size of the buffer to keep for each prompt. + min_count (`int`): + Minimum number of samples to keep in the buffer before calculating the mean and std. + """ + + def __init__(self, buffer_size, min_count): + self.buffer_size = buffer_size + self.min_count = min_count + self.stats = {} + + def update(self, prompts, rewards): + prompts = np.array(prompts) + rewards = np.array(rewards) + unique = np.unique(prompts) + advantages = np.empty_like(rewards) + for prompt in unique: + prompt_rewards = rewards[prompts == prompt] + if prompt not in self.stats: + self.stats[prompt] = deque(maxlen=self.buffer_size) + self.stats[prompt].extend(prompt_rewards) + + if len(self.stats[prompt]) < self.min_count: + mean = np.mean(rewards) + std = np.std(rewards) + 1e-6 + else: + mean = np.mean(self.stats[prompt]) + std = np.std(self.stats[prompt]) + 1e-6 + advantages[prompts == prompt] = (prompt_rewards - mean) / std + + return advantages + + def get_stats(self): + return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()} + + +def neftune_post_forward_hook(module, input, output): + """ + Implements the NEFTune forward pass for the model using forward hooks. Note this works only for + torch.nn.Embedding layers. This method is slightly adapted from the original source code + that can be found here: https://github.com/neelsjain/NEFTune + + Simply add it to your model as follows: + ```python + model = ... + model.embed_tokens.neftune_noise_alpha = 0.1 + model.embed_tokens.register_forward_hook(neftune_post_forward_hook) + ``` + + Args: + module (`torch.nn.Module`): + The embedding module where the hook is attached. Note that you need to set + `module.neftune_noise_alpha` to the desired noise alpha value. + input (`torch.Tensor`): + The input tensor to the model. + output (`torch.Tensor`): + The output tensor of the model (i.e. the embeddings). + """ + if module.training: + dims = torch.tensor(output.size(1) * output.size(2)) + mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output + + +def peft_module_casting_to_bf16(model): + from peft.tuners.tuners_utils import BaseTunerLayer + + for name, module in model.named_modules(): + if isinstance(module, BaseTunerLayer): + module = module.to(torch.bfloat16) + elif isinstance(module, torch.nn.LayerNorm) or "norm" in name: + module = module.to(torch.float32) + elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): + if hasattr(module, "weight"): + if module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + +def trl_sanitze_kwargs_for_tagging(model, tag_names, kwargs=None): + if is_unsloth_available(): + # Unsloth adds a new attribute in the model config `unsloth_version` + # to keep track of models that have been patched with unsloth. + if hasattr(model, "config") and getattr(model.config, "unsloth_version", None) is not None: + tag_names.append("unsloth") + + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = tag_names + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].extend(tag_names) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + tag_names.append(kwargs["tags"]) + kwargs["tags"] = tag_names + return kwargs + + +def get_quantization_config(model_config: ModelConfig) -> Optional[BitsAndBytesConfig]: + if model_config.load_in_4bit: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_config.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype` + bnb_4bit_quant_type=model_config.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=model_config.use_bnb_nested_quant, + ) + elif model_config.load_in_8bit: + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + else: + quantization_config = None + + return quantization_config + + +def get_kbit_device_map() -> Optional[Dict[str, int]]: + if is_xpu_available(): + return {"": f"xpu:{PartialState().local_process_index}"} + elif torch.cuda.is_available(): + return {"": PartialState().local_process_index} + else: + return None + + +def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]": + if model_config.use_peft is False: + return None + + peft_config = LoraConfig( + r=model_config.lora_r, + lora_alpha=model_config.lora_alpha, + lora_dropout=model_config.lora_dropout, + bias="none", + task_type="CAUSAL_LM", + target_modules=model_config.lora_target_modules, + modules_to_save=model_config.lora_modules_to_save, + ) + + return peft_config