Vllm2 #27
base: main
Changes from all commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -149,7 +149,10 @@ test/ | |
| # Logs | ||
| logs/* | ||
| !logs/.gitkeep | ||
| log/* | ||
|
|
||
| # databasegit status | ||
| db/* | ||
| !db/.gitkeep | ||
| !db/.gitkeep | ||
|
|
||
| biji.txt | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,16 +21,20 @@ | |
| from pathlib import Path | ||
|
|
||
| from lm_eval import evaluator | ||
|
|
||
| from vllm.model_executor.layers.logits_processor import _apply_logits_processors | ||
|
Contributor
I assume there is no "safe" alternative to using this internal method, right? If so, I think we can use it. |
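If it ever becomes a concern, one way to soften the dependency on the private helper is a guarded import. This is a sketch only; the fallback behaviour is an assumption, not part of this PR:

```python
# Sketch: tolerate the private helper moving between vLLM releases.
try:
    from vllm.model_executor.layers.logits_processor import _apply_logits_processors
except ImportError:  # private API; not guaranteed to be stable across versions
    _apply_logits_processors = None  # the hook would need to check for None before calling it
```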
||
| from vllm import LLM, SamplingParams | ||
| import warnings | ||
|
|
||
| warnings.filterwarnings('ignore') | ||
|
|
||
|
|
||
|
|
||
| # default value for arguments | ||
| DEFAULT_MODEL_PATH = "GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-2.2" | ||
| DEFAULT_SEQLEN = 2048 | ||
| DEFAULT_RANDOM_SEED = 0 | ||
| DTYPE = torch.half | ||
| DEFAULT_MODEL_BCKEND = ["vllm", "greenbit-engine"] | ||
|
Contributor
The name seems like it should be … |
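If the point is the spelling (`BCKEND`), a hypothetical corrected form, purely for illustration; the reviewer's exact suggestion is not visible here:

```python
# Hypothetical rename of the constant holding the supported backends.
DEFAULT_MODEL_BACKENDS = ["vllm", "greenbit-engine"]
```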
||
|
|
||
| replace_peft_lora_model_with_gba_lora_model() | ||
|
|
||
|
|
@@ -203,6 +207,18 @@ def setup_arg_parser(): | |
| help="Specify lora dir for lora merge" | ||
|
|
||
| ) | ||
| parser.add_argument( | ||
| "--backend", | ||
| type=str, | ||
| default="vllm", | ||
| help="Specify the model inference backend from [vllm, greenbit-engine]" | ||
| ) | ||
| parser.add_argument( | ||
| "--gpu-memory-utilization", | ||
| type=float, | ||
| default=0.8, | ||
| help="only useful when using vllm backend." | ||
| ) | ||
| return parser | ||
|
|
||
|
|
||
|
|
@@ -212,10 +228,10 @@ def create_device_map(cuda_device_id): | |
| device_map = {f"cuda:{id}" for id in ids} | ||
| return device_map | ||
|
|
||
| def main(args): | ||
| def evaluate_green_bit_engine(args): | ||
| if not os.path.exists(Path(args.save_dir)): | ||
| os.mkdir(Path(args.save_dir)) | ||
|
|
||
| # Building configs | ||
| tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} | ||
| pretrain_model_config = { | ||
|
|
@@ -225,7 +241,7 @@ def main(args): | |
|
|
||
| if args.eos_token is not None: | ||
| tokenizer_config["eos_token"] = args.eos_token | ||
|
|
||
| model, tokenizer, config = load( | ||
| args.model, | ||
| tokenizer_config=tokenizer_config, | ||
|
|
@@ -235,7 +251,7 @@ def main(args): | |
| model_config=pretrain_model_config, | ||
| requires_grad=False | ||
| ) | ||
|
|
||
| if args.lora_dir is not None: | ||
| config = LoraConfig( | ||
| r=64, | ||
|
|
@@ -258,7 +274,97 @@ def main(args): | |
|
|
||
| eval_results = {"{}".format(args.model): eval_results} | ||
|
|
||
| add_dict_to_json_file(file_path="{}".format(os.path.join(args.save_dir, "eval_results.json")), new_data=eval_results) | ||
| add_dict_to_json_file(file_path="{}".format(os.path.join(args.save_dir, "eval_greenbit_engine_results.json")), new_data=eval_results) | ||
|
|
||
| def evaluate_vllm(args): | ||
| logits_list = [] | ||
| def forward_hook(module, input, output): | ||
| lm_head, hidden_states, sampling_metadata, *embedding_bias = input | ||
| embedding_bias = embedding_bias[0] if embedding_bias else None | ||
| logits = module._get_logits(hidden_states, lm_head, embedding_bias) | ||
| if logits is not None: | ||
|
Contributor
I think we should use a guard clause here. |
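For illustration, the hook body with a guard clause might look like this (a sketch based on the lines in this diff, not a required change):

```python
def forward_hook(module, input, output):
    lm_head, hidden_states, sampling_metadata, *embedding_bias = input
    embedding_bias = embedding_bias[0] if embedding_bias else None
    logits = module._get_logits(hidden_states, lm_head, embedding_bias)
    if logits is None:
        # Guard clause: nothing to record for this step.
        return output
    if module.soft_cap is not None:
        logits = logits / module.soft_cap
        logits = torch.tanh(logits)
        logits = logits * module.soft_cap
    if module.scale != 1.0:
        logits *= module.scale
    logits = _apply_logits_processors(logits, sampling_metadata)
    logits_list.append(logits)
    return output
```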
||
| if module.soft_cap is not None: | ||
| logits = logits / module.soft_cap | ||
| logits = torch.tanh(logits) | ||
| logits = logits * module.soft_cap | ||
| if module.scale != 1.0: | ||
| logits *= module.scale | ||
| logits = _apply_logits_processors(logits, sampling_metadata) | ||
| logits_list.append(logits) | ||
| return output | ||
|
|
||
| @torch.no_grad() | ||
| def calculate_ppl(model, testenc, seqlen, device='cuda'): | ||
| nsamples = testenc.numel() // seqlen | ||
| nlls = [] | ||
|
|
||
| sampling_params = SamplingParams( | ||
| temperature=1.0, | ||
| max_tokens=1, | ||
| logprobs=None | ||
| ) | ||
|
|
||
| for i in tqdm(range(nsamples)): | ||
| logits_list.clear() | ||
| batch = testenc[:, (i * seqlen):((i + 1) * seqlen)] | ||
| outputs = model.generate(prompts=None, prompt_token_ids=batch.tolist(), sampling_params=sampling_params) | ||
| logits = logits_list[0].to(device) | ||
| logits = logits.unsqueeze(0) | ||
| shift_logits = logits[:, :-1, :] | ||
| shift_labels = testenc[:, (i * seqlen): ((i + 1) * seqlen)][ | ||
| :, 1: | ||
| ].to(device) | ||
|
Comment on lines +314 to +316
Contributor
Please adjust the styling (see https://peps.python.org/pep-0008/#indentation or follow black style). |
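For reference, one black-compatible way to write those three lines might look like this (sketch only; exact wrapping depends on the surrounding indentation):

```python
shift_labels = testenc[:, (i * seqlen) : ((i + 1) * seqlen)][:, 1:].to(device)
```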
||
| loss_fct = nn.CrossEntropyLoss() | ||
| loss = loss_fct( | ||
| shift_logits.view(-1, shift_logits.size(-1)), | ||
| shift_labels.view(-1), | ||
| ) | ||
| neg_log_likelihood = loss.float() * seqlen | ||
| nlls.append(neg_log_likelihood) | ||
| ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)) | ||
| return ppl.item() | ||
|
|
||
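For reference, `calculate_ppl` above computes the usual perplexity over $S$ non-overlapping windows of length $L$ (`nsamples` and `seqlen` in the code):

$$\mathrm{PPL} = \exp\!\left(\frac{1}{S\,L}\sum_{i=1}^{S} L \cdot \mathrm{CE}_i\right) = \exp\!\left(\frac{1}{S}\sum_{i=1}^{S} \mathrm{CE}_i\right),$$

where $\mathrm{CE}_i$ is the mean cross-entropy of the shifted logits against the shifted labels in window $i$.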
| print(f"Loading model from {args.model}") | ||
| model = LLM( | ||
| model=args.model, | ||
| trust_remote_code=args.trust_remote_code, | ||
| gpu_memory_utilization=args.gpu_memory_utilization | ||
| ) | ||
| model.llm_engine.model_executor.driver_worker.model_runner.model.logits_processor.register_forward_hook(forward_hook) | ||
|
|
||
| results = {} | ||
| logger = create_logger(Path(args.save_dir)) | ||
| if args.eval_ppl: | ||
| for dataset in args.ppl_tasks.split(","): | ||
| # print(f"\nEvaluating {dataset}...") | ||
| dataloader, testloader = get_loaders( | ||
| dataset.strip(), | ||
| seed=args.seed, | ||
| model=args.model, | ||
| seqlen=args.seqlen, | ||
| ) | ||
|
|
||
| if "c4" in dataset: | ||
| testenc = testloader | ||
| else: | ||
| testenc = testloader.input_ids | ||
|
|
||
| ppl = calculate_ppl(model, testenc, args.seqlen) | ||
| logger.info(f'{dataset} : {ppl}') | ||
| results[dataset] = ppl | ||
|
|
||
| eval_results = {args.model: results} | ||
|
|
||
| add_dict_to_json_file(file_path="{}".format(os.path.join(args.save_dir, "eval_vllm_results.json")), new_data=eval_results) | ||
|
|
||
| def main(args): | ||
| if args.backend not in DEFAULT_MODEL_BCKEND: | ||
| print(f"Backend is error, please set the backend from {DEFAULT_MODEL_BCKEND}") | ||
| exit(-1) | ||
| if args.backend == "vllm": | ||
| evaluate_vllm(args) | ||
| elif args.backend == "greenbit-engine": | ||
| evaluate_green_bit_engine(args) | ||
|
|
||
| if __name__ == "__main__": | ||
| if not torch.cuda.is_available(): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| import abc | ||
|
|
||
|
|
||
| class BaseInferenceBackend: | ||
| @abc.abstractmethod | ||
| def generate(self, prompt, params): | ||
| pass |
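A side note on the base class above: `@abc.abstractmethod` is only enforced when the class uses `ABCMeta`, e.g. by inheriting from `abc.ABC`. A minimal enforced variant could look like this (sketch; the docstring is mine):

```python
import abc


class BaseInferenceBackend(abc.ABC):
    @abc.abstractmethod
    def generate(self, prompt, params):
        """Generate text for a prompt (or list of prompts); params are backend-specific."""
        raise NotImplementedError
```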
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| from green_bit_llm.inference.sim_gen import DTYPE | ||
| from .base import BaseInferenceBackend | ||
| import os | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| import warnings | ||
| warnings.filterwarnings("ignore", category=UserWarning, module='torch.nn.modules.module') | ||
|
|
||
| from transformers import PreTrainedTokenizer | ||
|
|
||
| from green_bit_llm.common import generate, load | ||
| from green_bit_llm.args_parser import setup_shared_arg_parser | ||
|
|
||
| # default value for arguments | ||
| DEFAULT_PROMPT = None | ||
| DEFAULT_MAX_TOKENS = 100 | ||
| DEFAULT_TEMP = 0.8 | ||
| DEFAULT_TOP_P = 0.95 | ||
| DTYPE = torch.half | ||
|
|
||
| class GBLLMInferenceBackend(BaseInferenceBackend): | ||
| def __init__(self, model_path, **kwargs): | ||
| # Building configs | ||
| tokenizer_config = {"trust_remote_code": True if kwargs.get("trust_remote_code") else None} | ||
|
Contributor
I think we should use … |
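If the suggestion concerns the `True if kwargs.get(...) else None` pattern, one purely illustrative simplification (an assumption, not the reviewer's words):

```python
# Purely illustrative: falsy flags map to None, a truthy flag is passed through.
tokenizer_config = {"trust_remote_code": kwargs.get("trust_remote_code") or None}
```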
||
| pretrain_model_config = { | ||
| "trust_remote_code": True if kwargs.get("trust_remote_code") else None, | ||
| "attn_implementation": "flash_attention_2" if kwargs.get("use_flash_attention_2") else None | ||
| } | ||
| if kwargs.get("eos_token") is not None: | ||
| tokenizer_config["eos_token"] = kwargs.get("eos_token") | ||
|
|
||
| self.model, self.tokenizer, config = load( | ||
| model_path, | ||
| tokenizer_config=tokenizer_config, | ||
| dtype=kwargs.get("dtype", DTYPE), | ||
| device_map=kwargs.get("auto", "auto"), | ||
| seqlen=kwargs.get("seqlen", 2048), | ||
| model_config=pretrain_model_config, | ||
| requires_grad=False | ||
| ) | ||
|
|
||
| def generate(self, prompt, params=None): | ||
| if params == None: | ||
| params = {} | ||
| if isinstance(prompt, str): | ||
| prompt = [prompt] | ||
| for prom in prompt: | ||
| generate( | ||
| self.model, | ||
| self.tokenizer, | ||
| prom, | ||
| params.get("temperature", DEFAULT_TEMP), | ||
| params.get("max_tokens", DEFAULT_MAX_TOKENS), | ||
| True, | ||
| params.get("top_p", DEFAULT_TOP_P), | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| from vllm import LLM | ||
| from .base import BaseInferenceBackend | ||
|
|
||
| class VLLMInferenceBackend(BaseInferenceBackend): | ||
| def __init__(self, model_path, **kwargs): | ||
| self.model = LLM(model_path, **kwargs) | ||
|
|
||
| def do_generate(self, prompt, params): | ||
| outputs = self.model.generate(prompt, params) | ||
| return outputs | ||
|
|
||
| def generate(self, prompt, params=None): | ||
| if isinstance(prompt, str): | ||
| prompt = [prompt] | ||
| outputs = self.do_generate(prompt, params) | ||
| for i,output in enumerate(outputs): | ||
| print("Prompt:",prompt[i]) | ||
| print("Generated text:",output.outputs[0].text) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| psutil | ||
| sentencepiece # Required for LLaMA tokenizer. | ||
| numpy < 2.0.0 | ||
| requests >= 2.26.0 | ||
| tqdm | ||
| py-cpuinfo | ||
| transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL. | ||
| tokenizers >= 0.19.1 # Required for Llama 3. | ||
| protobuf # Required by LlamaTokenizer. | ||
| fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' | ||
| fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' | ||
| aiohttp | ||
| openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) | ||
| uvicorn[standard] | ||
| pydantic >= 2.9 # Required for fastapi >= 0.113.0 | ||
| pillow # Required for image processing | ||
| prometheus_client >= 0.18.0 | ||
| prometheus-fastapi-instrumentator >= 7.0.0 | ||
| tiktoken >= 0.6.0 # Required for DBRX tokenizer | ||
| lm-format-enforcer == 0.10.6 | ||
| outlines >= 0.0.43, < 0.1 | ||
| typing_extensions >= 4.10 | ||
| filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 | ||
| partial-json-parser # used for parsing partial JSON outputs | ||
| pyzmq | ||
| msgspec | ||
| gguf == 0.10.0 | ||
| importlib_metadata | ||
| mistral_common[opencv] >= 1.4.4 | ||
| pyyaml | ||
| six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 | ||
| setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 | ||
| einops # Required for Qwen2-VL. | ||
| compressed-tensors == 0.7.1 # required for compressed-tensors |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # Common dependencies | ||
| -r requirements-common.txt | ||
|
|
||
| # Dependencies for NVIDIA GPUs | ||
| ray >= 2.9 | ||
| nvidia-ml-py >= 12.560.30 # for pynvml package | ||
| torch == 2.5.1 | ||
| # These must be updated alongside torch | ||
| torchvision == 0.20.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version | ||
| xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1 |
I think some explanation is needed. We could add some information here on the significance of this choice, i.e. what vLLM is, why/when to use it, or some link. Or all of those.