|
| 1 | +import atexit |
| 2 | +import os |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import yaml |
| 6 | +from PIL import Image |
| 7 | +from transformers import AutoTokenizer |
| 8 | + |
| 9 | +from vllm import LLM, SamplingParams |
| 10 | + |
# Path to the YAML config describing the model under test; overridable via
# the TEST_DATA_FILE environment variable (defaults to the Llama 3.2 Vision
# Jenkins config).
TEST_DATA_FILE = os.environ.get(
    "TEST_DATA_FILE",
    ".jenkins/vision/configs/Meta-Llama-3.2-11B-Vision-Instruct.yaml")

# Tensor-parallel degree passed to the vLLM engine; overridable via env.
TP_SIZE = int(os.environ.get("TP_SIZE", 1))
| 16 | + |
| 17 | + |
def fail_on_exit():
    """Force a non-zero process exit status.

    Registered via atexit when the test body raises, so the CI run is
    reported as failed even on code paths that bypass pytest's normal
    exit handling.  os._exit skips any remaining cleanup handlers.
    """
    os._exit(1)
| 20 | + |
| 21 | + |
def launch_enc_dec_model(config, question):
    """Build a vLLM engine for the configured model plus a chat prompt.

    The prompt contains one image placeholder followed by *question*,
    rendered through the model's chat template (untokenized, with the
    generation prompt appended).

    Returns:
        Tuple of (llm, prompt).
    """
    llm = LLM(
        model=config.get('model_name'),
        dtype=config.get('dtype', 'bfloat16'),
        tensor_parallel_size=TP_SIZE,
        num_scheduler_steps=config.get('num_scheduler_steps', 1),
        max_model_len=config.get('max_model_len', 4096),
        max_num_seqs=config.get('max_num_seqs', 128),
    )

    chat = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": f"{question}"},
        ],
    }]
    tokenizer = AutoTokenizer.from_pretrained(config.get('model_name'))
    prompt = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False)
    return llm, prompt
| 52 | + |
| 53 | + |
def get_input():
    """Load the test image and its accompanying question as a dict."""
    return {
        # Force RGB so downstream processors see a consistent image mode.
        "image": Image.open("data/cherry_blossom.jpg").convert("RGB"),
        "question": "What is the content of this image?",
    }
| 62 | + |
| 63 | + |
def get_current_gaudi_platform():
    """Return the Gaudi generation name ("Gaudi1"/"Gaudi2"/"Gaudi3").

    Raises ValueError for any other device type.

    Inspired by: https://github.com/HabanaAI/Model-References/blob/a87c21f14f13b70ffc77617b9e80d1ec989a3442/PyTorch/computer_vision/classification/torchvision/utils.py#L274
    """
    import habana_frameworks.torch.utils.experimental as htexp

    device_type = htexp._get_device_type()

    names = {
        htexp.synDeviceType.synDeviceGaudi: "Gaudi1",
        htexp.synDeviceType.synDeviceGaudi2: "Gaudi2",
        htexp.synDeviceType.synDeviceGaudi3: "Gaudi3",
    }
    name = names.get(device_type)
    if name is None:
        raise ValueError(
            f"Unsupported device: the device type is {device_type}.")
    return name
| 81 | + |
| 82 | + |
def test_enc_dec_model(record_xml_attribute, record_property):
    """Smoke-test a vision encoder-decoder model on Gaudi via vLLM.

    Loads the YAML config named by TEST_DATA_FILE, generates completions
    for a single image+question prompt (replicated ``num_prompts`` times),
    and asserts each completion is non-empty.

    Args:
        record_xml_attribute: pytest fixture used to set the JUnitXML
            test name so CI reports include the platform and TP size.
        record_property: pytest fixture (accepted for harness
            compatibility; currently unused).
    """
    try:
        config = yaml.safe_load(
            Path(TEST_DATA_FILE).read_text(encoding="utf-8"))
        # Record JUnitXML test name
        platform = get_current_gaudi_platform()
        testname = (f'test_{Path(TEST_DATA_FILE).stem}_{platform}_'
                    f'tp{TP_SIZE}')
        record_xml_attribute("name", testname)

        mm_input = get_input()
        image = mm_input["image"]
        question = mm_input["question"]
        llm, prompt = launch_enc_dec_model(config, question)

        # temperature=0.0 -> greedy decoding, deterministic across runs.
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=100,
                                         stop_token_ids=None)

        num_prompts = config.get('num_prompts', 1)
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            },
        } for _ in range(num_prompts)]

        outputs = llm.generate(inputs, sampling_params=sampling_params)

        for o in outputs:
            generated_text = o.outputs[0].text
            assert generated_text, "Generated text is empty"
            print(generated_text)
        # NOTE(review): hard-exits with success on the happy path,
        # presumably to skip engine teardown — confirm this is intentional;
        # it also ends the pytest session immediately.
        os._exit(0)

    except Exception:
        # Make sure the process still exits non-zero: the os._exit paths
        # above bypass pytest's normal failure reporting.
        atexit.register(fail_on_exit)
        # Bare `raise` re-raises with the original traceback intact
        # (clearer than `raise exc`, which resets the raise site).
        raise
0 commit comments