magma.py
import os
import uuid
import warnings
from typing import List, Optional, Tuple, Union
import torch
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState
from tqdm import tqdm
import PIL
from torchvision.transforms.functional import to_pil_image
from decord import VideoReader, cpu
import numpy as np
from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model
from lmms_eval.models.model_utils.qwen.qwen_generate_utils import make_context
warnings.simplefilter("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore")
from loguru import logger as eval_logger
from transformers import AutoModelForCausalLM, AutoProcessor
@register_model("magma")
class Magma(lmms):
"""
Magma Model
"""
def __init__(
self,
pretrained: str = "Magma/Magma-8b",
device: str = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto",
batch_size: int = 1,
trust_remote_code: Optional[bool] = True,
attn_implementation: Optional[str] = None,
device_map: str = "",
max_frames_num: Optional[int] = 32,
**kwargs,
) -> None:
super().__init__()
# Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
        self.batch_size_per_gpu = int(batch_size)
accelerator = Accelerator()
if accelerator.num_processes >= 1 and device_map == "":
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.device_map = f"cuda:{accelerator.local_process_index}"
else:
self._device = torch.device(device)
self.device_map = device_map
        if isinstance(dtype, str) and dtype != "auto":
            dtype = getattr(torch, dtype)
        # Inputs are cast to bfloat16 before generation (see generate_until),
        # regardless of the dtype the checkpoint weights were loaded in.
        self.dtype = torch.bfloat16
self.max_frames_num = max_frames_num
self._model = AutoModelForCausalLM.from_pretrained(pretrained, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
self.model.eval()
        self.processor = AutoProcessor.from_pretrained(pretrained, trust_remote_code=trust_remote_code)
        self._tokenizer = self.processor.tokenizer
        self._config = self._model.config
        # Context length: read from the config when available; 2048 is an assumed conservative fallback.
        self._max_length = getattr(self._config, "max_position_embeddings", 2048)
if accelerator.num_processes > 1 and device_map == "":
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP, FSDP and DeepSpeed (ZeRO stage 0) are supported."
            # To use DistributedType.DEEPSPEED, run `accelerate config` before using the model
            # and select ZeRO stage 0 (equivalent to DDP) so that `accelerator.prepare` works.
            # Setting other kwargs to make the default ZeRO stage 2 work did not succeed.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
"train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")
if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == "auto":
eval_logger.info(f"Using {accelerator.num_processes} devices with pipeline parallelism")
self._rank = 0
            self._world_size = 1
else:
eval_logger.info(f"Using single device: {self._device}")
self.model.to(self._device)
self._rank = 0
            self._world_size = 1
self.accelerator = accelerator
@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config
@property
def tokenizer(self):
return self._tokenizer
@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
return self._max_length
@property
def batch_size(self):
return self.batch_size_per_gpu
@property
def device(self):
return self._device
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
""" """
add_special_tokens = False if add_special_tokens is None else add_special_tokens
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Loglikelihood evaluation is not implemented for Magma")
    def flatten(self, input):
        # Flatten a list of lists into a single list.
        return [item for sublist in input for item in sublist]
    def load_video(self, video_path, max_frames_num):
        """Decode a video and uniformly sample at most `max_frames_num` frames."""
        if isinstance(video_path, str):
            vr = VideoReader(video_path, ctx=cpu(0))
        else:
            vr = VideoReader(video_path[0], ctx=cpu(0))
        total_frame_num = len(vr)
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
        sampled_frames = vr.get_batch(frame_idx).asnumpy()
        return sampled_frames  # (frames, height, width, channels)
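    # Minimal sketch of the sampling rule above, assuming a hypothetical 300-frame
    # clip and the default max_frames_num of 32: np.linspace always includes the
    # first and last frame, so the clip is covered end to end.
    #
    #   idx = np.linspace(0, 300 - 1, 32, dtype=int)
    #   # idx[0] == 0, idx[-1] == 299, len(idx) == 32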
def generate_until(self, requests: List[Instance]) -> List[str]:
res = []
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
visuals = self.flatten(visuals)
            images = []
            for visual in visuals:
                if isinstance(visual, str):
                    frames = self.load_video(visual, self.max_frames_num)
                    frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
                    images.extend([to_pil_image(frame) for frame in frames])
                elif isinstance(visual, PIL.Image.Image):
                    images.append(visual)
            # Build the conversation: one "<image>\n" placeholder per image, then the text prompt.
            convs = [
                {
                    "role": "system",
                    "content": "You are agent that can see, talk and act.",
                },
                {"role": "user", "content": "".join(["<image>\n"] * len(images)) + contexts},
            ]
prompt = self.processor.tokenizer.apply_chat_template(
convs,
tokenize=False,
add_generation_prompt=True
)
if self.model.config.mm_use_image_start_end:
prompt = prompt.replace("<image>", "<image_start><image><image_end>")
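            # For example, with two images and mm_use_image_start_end enabled, the templated
            # prompt contains "<image_start><image><image_end>\n" twice before the question
            # text; the surrounding chat markers depend on the tokenizer's chat template.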
inputs = self.processor(images=images, texts=prompt, return_tensors="pt").to(self.model.device)
            # Add the batch dimension the model expects, then cast floating-point tensors to the generation dtype.
            inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
            inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
            inputs = inputs.to(self.dtype)
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
if "do_sample" not in gen_kwargs:
gen_kwargs["do_sample"] = False
self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id
with torch.no_grad():
output = self.model.generate(
**inputs,
max_new_tokens=gen_kwargs["max_new_tokens"],
temperature=gen_kwargs["temperature"],
do_sample=gen_kwargs["do_sample"],
)
output = output[:, inputs["input_ids"].shape[-1] :]
if 'Phi-3-mini-128k-instruct' in self.processor.tokenizer.name_or_path:
decoded_text = self.processor.decode(output[0], skip_special_tokens=False).strip()
res.append(decoded_text.split('<|end|>')[0])
else:
res.append(self.processor.decode(output[0], skip_special_tokens=True).strip())
pbar.update(1)
pbar.close()
return res
def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation for Magma")