diff --git a/vllm_gaudi/models/qwen2_5_vl.py b/vllm_gaudi/models/qwen2_5_vl.py
index ef53aaa58..d8f607210 100644
--- a/vllm_gaudi/models/qwen2_5_vl.py
+++ b/vllm_gaudi/models/qwen2_5_vl.py
@@ -67,9 +67,40 @@ def forward(q, k, v, mask, q_block_size, softmax_mode):
         return attn_output
 
 
-def create_block_diagonal_attention_mask_outerprod(indices):
-    maxsize = indices[-1]
-    range_to_max_for_each_img = torch.arange(maxsize,
+class HPU_Attention:
+
+    softmax_mode = 'fp32' if \
+        os.environ.get('VLLM_FP32_SOFTMAX_VISION', 'false').lower() \
+        in ['true', '1'] else 'None'
+
+    @classmethod
+    def forward(cls, q, k, v, mask, q_block_size=64):
+        """
+        Support long sequence at prompt phase
+        """
+        q_len = q.size(-2)
+        if q_len <= 65536:  # need to investigate this crossover point
+            return FusedSDPA.apply(q, k, v, mask, 0.0, False, None, cls.softmax_mode)
+
+        assert q_len % q_block_size == 0
+        q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
+        attn_output = torch.zeros_like(q)
+
+        for i in range(q_tiles):
+            s, e = i * q_block_size, (i + 1) * q_block_size
+            row_q = q[:, :, s:e, :]
+            row_mask = mask[:, :, s:e, :]
+            attn_output[:, :, s:e, :] = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, False, None, cls.softmax_mode)
+            # TODO: mark_step after a couple of iterations;
+            # need to experiment to find the optimal number.
+            if i % 75 == 0:
+                htcore.mark_step()
+        return attn_output
+
+
+def create_block_diagonal_attention_mask(indices):
+    max_size = indices[-1]
+    range_to_max_for_each_img = torch.arange(max_size,
                                              device=indices.device).unsqueeze(0).repeat(indices.shape[0] - 1, 1)
     lesser = range_to_max_for_each_img < indices[1:].unsqueeze(1)
     greater_eq = range_to_max_for_each_img >= indices[:-1].unsqueeze(1)
@@ -122,18 +153,17 @@ def __init__(
             attn_backend_override=attn_backend_override,
         )
 
-        self.softmax_mode = 'fp32' if os.environ.get('VLLM_FP32_SOFTMAX_VISION', 'false').lower() in ['true', '1'
-                                                                                                      ] else 'None'
         assert_msg = ("Flash Attention backend is not supported on HPU for Vision Transformer "
                       "in Qwen2_5_VL model. Please use TORCH_SDPA backend.")
         assert self.attn_backend != AttentionBackendEnum.FLASH_ATTN, assert_msg
 
     def forward(
-            self,
-            x: torch.Tensor,
-            cu_seqlens: torch.Tensor,
-            rotary_pos_emb_cos: torch.Tensor,
-            rotary_pos_emb_sin: torch.Tensor,
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb_cos: torch.Tensor,
+        rotary_pos_emb_sin: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,  # Only used for HPU
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -162,46 +192,11 @@ def forward(
         else:
             q, k, v = qkv.unbind(dim=2)
 
-        fullattn_mask = cu_seqlens
-
-        if fullattn_mask is None:  # performs window attention
-            # we assume image is 112 aligned in both h/w dims
-            # in other words, x % 64 = 0
-            # that simplifies the slicing of window attention
-            # in patches of 64
-            outputs = []
-            cu_seqlens = list(range(0, x.shape[0] + 1, 64))
-            for i in range(1, len(cu_seqlens)):
-                # For large image, we add mark step here
-                # for every 100th step to make compile time shorter
-                if i % 100 == 0:
-                    htcore.mark_step()
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
-                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i])
-                output_i = FusedSDPA.apply(q_i, k_i, v_i, None, 0.0, False, None, self.softmax_mode)
-                output_i = rearrange(output_i, "b h s d -> b s h d ")
-                outputs.append(output_i)
-            context_layer = torch.cat(outputs, dim=1)
-        else:
-            # performs full attention using the previous computed mask
-            fullatt_block_attn_mask = fullattn_mask
-            q1, k1, v1 = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-            (batch_size, _, seq_len_N_t, _) = q1.shape
-            (batch_size, _, seq_len_N_s, _) = k1.shape
-            mask_shape = (batch_size, 1, seq_len_N_t, seq_len_N_s)
-            attn_mask = fullatt_block_attn_mask.reshape(batch_size, 1, seq_len_N_t, seq_len_N_s,
-                                                        -1)[:, :, :, :, 0]  # reshapes the mask to be Bx1xNxN
-            assert attn_mask.shape == mask_shape
-
-            if q1.shape[2] <= 65536:  # need to investigate this crosspoint
-                fused_out = FusedSDPA.apply(q1, k1, v1, attn_mask, 0.0, False, None, self.softmax_mode)
-            else:
-                fused_out = AttentionLongSequence.forward(q1, k1, v1, attn_mask, 64, self.softmax_mode)
-            context_layer = rearrange(fused_out, "b h s d -> b s h d ")
+        # performs full attention using the previous computed mask
+        q1, k1, v1 = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
+        output = HPU_Attention.forward(q1, k1, v1, attn_mask)
+        context_layer = rearrange(output, "b h s d -> b s h d ")
+
         context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
 
         output, _ = self.proj(context_layer)
@@ -249,16 +244,18 @@ def __init__(
 
     def forward(
         self,
         x: torch.Tensor,
-        cu_seqlens: torch.Tensor,
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,  # Only used for Flash Attention
         seqlens: Optional[list[int]] = None,  # Only used for xFormers
+        attn_mask: Optional[torch.Tensor] = None,  # Only used for HPU
    ) -> torch.Tensor:
         x = x + self.attn(self.norm1(x),
                           cu_seqlens=cu_seqlens,
                           rotary_pos_emb_cos=rotary_pos_emb_cos,
-                          rotary_pos_emb_sin=rotary_pos_emb_sin)
+                          rotary_pos_emb_sin=rotary_pos_emb_sin,
+                          attn_mask=attn_mask)
         x = x + self.mlp(self.norm2(x))
         return x
@@ -315,117 +312,63 @@ def __init__(
             ) for layer_idx in range(depth)
         ])
 
-    def rot_pos_emb(self, grid_thw: torch.Tensor):  # -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        pos_ids = []
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            ).permute(0, 2, 1, 3).flatten()
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            ).permute(0, 2, 1, 3).flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
-        pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
-
-        # Use pre-computed cos_sin_cache from RotaryEmbedding
-        cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
-
-        cos_combined = cos[pos_ids].flatten(1)
-        sin_combined = sin[pos_ids].flatten(1)
-
-        cos_combined = cos_combined.reshape(
-            cos_combined.shape[0] // self.spatial_merge_unit,
-            self.spatial_merge_unit,
-            -1,
-        )
-        sin_combined = sin_combined.reshape(
-            sin_combined.shape[0] // self.spatial_merge_unit,
-            self.spatial_merge_unit,
-            -1,
-        )
-
-        return cos_combined, sin_combined
-
-    def get_window_index(self, grid_thw):
-        window_index: list = []
-        cu_window_seqlens: list = [0]
-        window_index_id = 0
-        vit_merger_window_size = (self.window_size // self.spatial_merge_size // self.patch_size)
-
-        for grid_t, grid_h, grid_w in grid_thw:
-            llm_grid_h = grid_h // self.spatial_merge_size
-            llm_grid_w = grid_w // self.spatial_merge_size
-            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
-            pad_h = \
-                vit_merger_window_size - llm_grid_h % vit_merger_window_size
-            pad_w = \
-                vit_merger_window_size - llm_grid_w % vit_merger_window_size
-            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
-            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
-            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
-            index_padded = index_padded.reshape(grid_t, num_windows_h, vit_merger_window_size, num_windows_w,
-                                                vit_merger_window_size)
-            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(grid_t, num_windows_h * num_windows_w,
-                                                                       vit_merger_window_size, vit_merger_window_size)
-            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
-            index_padded = index_padded.reshape(-1)
-            index_new = index_padded[index_padded != -100]
-
-            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
-
-            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
-            window_index.append(index_new + window_index_id)
-            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
-        window_index = torch.cat(window_index, dim=0)
-
-        return window_index, cu_window_seqlens
-
     def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor):
         # patchify
+        seq_len, _ = x.size()
+        cos_list = []
+        sin_list = []
+        window_index: list = []
+        cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
+        cu_seqlens: list = []
+
         hidden_states = x.to(device=self.device, dtype=self.dtype)
         hidden_states = self.patch_embed(hidden_states)
 
-        # compute position embedding
-        cos_combined, sin_combined = self.rot_pos_emb(grid_thw)
-
-        # windows attention
-        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
-
-        # NOTE: unique_consecutive is a dynamic operation
-        # we are using `remove_duplicates_cpu` instead
-        def remove_duplicates_cpu(a):
-            return [a[i] for i in range(len(a)) if i == 0 or a[i - 1] != a[i]]
+        window_index_id = 0
+        cu_window_seqlens_last = 0
+        for t, h, w in grid_thw:
+            t, h, w = int(t), int(h), int(w)
+            llm_h = h // self.spatial_merge_size
+            llm_w = w // self.spatial_merge_size
+
+            (
+                cos_thw,
+                sin_thw,
+                window_index_thw,
+                cu_seqlens_window_thw,
+                cu_seqlens_thw,
+            ) = self.get_rope_by_thw(t, h, w)
+
+            window_index.append(window_index_thw + window_index_id)
+            window_index_id += (t * llm_h * llm_w)
+
+            cu_seqlens_window_thw = (cu_seqlens_window_thw + cu_window_seqlens_last)
+            cu_window_seqlens_last = cu_seqlens_window_thw[-1]
+            cu_window_seqlens.append(cu_seqlens_window_thw)
+
+            # accumulate RoPE and THW seqlens
+            cos_list.append(cos_thw)
+            sin_list.append(sin_thw)
+            cu_seqlens.append(cu_seqlens_thw)
+
+        # concatenate
+        cos_combined = torch.cat(cos_list).to(self.device, non_blocking=True)
+        sin_combined = torch.cat(sin_list).to(self.device, non_blocking=True)
+
+        window_index = torch.cat(window_index).to(self.device, non_blocking=True)
+        cu_window_seqlens = torch.cat(cu_window_seqlens)
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        cu_seqlens = torch.cat(cu_seqlens)
+        cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
 
-        cu_window_seqlens = remove_duplicates_cpu(cu_window_seqlens)
-        cu_window_seqlens = torch.tensor(cu_window_seqlens,
-                                         device=hidden_states.device,
-                                         dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
+        cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
+        cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True)
 
-        seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         hidden_states = hidden_states[window_index, :, :]
         hidden_states = hidden_states.reshape(seq_len, -1)
-        cos_combined = cos_combined[window_index, :, :]
-        cos_combined = cos_combined.flatten(start_dim=0, end_dim=1)
-        sin_combined = sin_combined[window_index, :, :]
-        sin_combined = sin_combined.flatten(start_dim=0, end_dim=1)
-        cos_combined = cos_combined.to(device=self.device, non_blocking=True)
-        sin_combined = sin_combined.to(device=self.device, non_blocking=True)
-
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0,
-                                                                                                     dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
-
         return (
             hidden_states,
             cos_combined,
@@ -435,21 +378,27 @@ def remove_duplicates_cpu(a):
-    def forward(self, x: torch.Tensor, fullattn_mask: Optional[torch.Tensor], rotary_pos_emb_cos: torch.Tensor,
-                rotary_pos_emb_sin: torch.Tensor) -> torch.Tensor:
-        assert_msg = ("Expect inputs to be 112x112 aligned. "
-                      "Please align before sending image and "
-                      "check PR #1163 description for more details")
-        assert x.shape[0] % 64 == 0, assert_msg
-        hidden_states = x.unsqueeze(1)
+    def forward(self, hidden_states: torch.Tensor, rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor,
+                padding_attn_mask_window: torch.Tensor, padding_attn_mask_full: torch.Tensor) -> torch.Tensor:
+        hidden_states = hidden_states.unsqueeze(1)
         for layer_num, blk in enumerate(self.blocks):
-            htcore.mark_step()
+            if layer_num in self.fullatt_block_indexes:
+                padding_attn_mask_now = padding_attn_mask_full
+            else:
+                padding_attn_mask_now = padding_attn_mask_window
+
             hidden_states = blk(
                 hidden_states,
-                cu_seqlens=fullattn_mask if layer_num in self.fullatt_block_indexes else None,
+                cu_seqlens=padding_attn_mask_now,
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
             )
+
+        # For Qwen2.5-VL-3B, float16 will overflow at last block
+        # for long visual tokens sequences.
+        if hidden_states.dtype == torch.float16:
+            hidden_states = cast_overflow_tensors(hidden_states)
+
         return hidden_states
 
     def post_attn(self, hidden_states: torch.Tensor, window_index: torch.Tensor):
@@ -460,80 +409,13 @@ def post_attn(self, hidden_states: torch.Tensor, window_index: torch.Tensor):
         hidden_states = hidden_states[reverse_indices, :]
         return hidden_states
 
-    def pad_multimodal_data(self, pixel_values, image_grid_thw, vision_buckets):
-        assert pixel_values.shape[0] % 64 == 0, 'needs 64 aligned resolution'
-
-        desired_number_of_pixels = vision_buckets.get_multimodal_bucket(pixel_values.shape[0])
-        padding_len = desired_number_of_pixels - pixel_values.shape[0]
-        if padding_len <= 0:
-            return pixel_values, image_grid_thw
-
-        logger_msg = "Padding current number pixel " \
-            + str(pixel_values.shape[0]) \
-            + " to " \
-            + str(desired_number_of_pixels)
-        logger.info(logger_msg)
-
-        assert padding_len % 64 == 0, 'padding needs to be multiple of 64'
-
-        constant_value = -100
-        pixel_values = torch.cat([
-            pixel_values,
-            torch.ones((padding_len, pixel_values.shape[1]), device=pixel_values.device) * constant_value
-        ])
-
-        image_grid_thw = torch.cat(
-            [image_grid_thw, torch.tensor([[1, 8, padding_len // 8]], device=image_grid_thw.device)])
-
-        assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels
-        return pixel_values, image_grid_thw
-
     def get_image_embeds(
         self,
         pixel_values: torch.Tensor,
         grid_thw: torch.Tensor,
         vision_buckets,
     ) -> torch.Tensor:
-
-        num_patches = pixel_values.shape[0]
-        if num_patches % 64 != 0:
-            assert num_patches > 64, "Image needs to be at least 112 x 112"
-            logger_msg = ("QWEN 2.5VL for HPU is under development. "
-                          "Image height and width need to be multiples of 112 pixels. "
-                          "We are prunning the last visual tokens to comply with this "
-                          "requirement but this leads to accuracy degradation. "
-                          "Please, reshape the images or use this custom transformer "
-                          "that does the resizing/alignment automatically: "
-                          "pip install "
-                          "git+https://github.com/malkomes/transformers.git"
-                          "@ac372cd18f836c41f57cdce46094db00019d4280"
-                          "See PR #1163 description, for more details")
-            logger.warning_once(logger_msg)
-
-            # reshape grid_thw with multiples of 8
-            old_img_sizes = []
-            new_img_sizes = []
-            for img_idx in range(grid_thw.shape[0]):
-                img_shape = grid_thw[img_idx, :].tolist()
-                tt, hh, ww = img_shape
-                hh_new = (hh // 8) * 8
-                ww_new = (ww // 8) * 8
-                old_img_sizes.append(tt * hh * ww)
-                new_img_sizes.append(tt * hh_new * ww_new)
-                grid_thw[img_idx, 1] = hh_new
-                grid_thw[img_idx, 2] = ww_new
-
-            # truncate pixel_values to new shapes
-            copy_pointer = 0
-            paste_pointer = 0
-            for old_img_size, new_img_size in zip(old_img_sizes, new_img_sizes):
-                pixel_values[paste_pointer:paste_pointer + new_img_size, :] = \
-                    pixel_values[copy_pointer:copy_pointer + new_img_size, :]
-                copy_pointer += old_img_size
-                paste_pointer += new_img_size
-
-            pixel_values = pixel_values[:paste_pointer, :]
-
+        seq_len, _ = pixel_values.size()
         offset = 0
         results = []
         # process each image one by one
@@ -544,38 +426,48 @@ def get_image_embeds(
             pixel_values_curr_img = pixel_values[offset:offset + curr_img_size, :]
             offset += curr_img_size
 
-            pixel_values_curr_img_padded, img_shape_padded = \
-                self.pad_multimodal_data(
-                    pixel_values_curr_img,
-                    img_shape,
-                    vision_buckets=vision_buckets
-                )
-
-            pixel_values_curr_img_padded, rot_pos_emb_cos, rot_pos_emb_sin, \
+            # pre-attention block
+            hidden_states, rot_pos_emb_cos, rot_pos_emb_sin, \
                 cu_seqlens, cu_window_seqlens, window_index = self.pre_attn(
-                    pixel_values_curr_img_padded, img_shape_padded)
-
-            # Create full attention block mask before VisionTransformer
-            # to save memory/time
-            fullatt_block_attn_mask = \
-                create_block_diagonal_attention_mask_outerprod(cu_seqlens)
-
-            assert pixel_values_curr_img_padded.shape[0] == cu_seqlens[-1]
-            assert pixel_values_curr_img_padded.shape[0] == rot_pos_emb_cos.shape[0] == rot_pos_emb_sin.shape[0]
-
+                    pixel_values_curr_img, img_shape)
+
+            # add padding
+            bucket_size = vision_buckets.get_multimodal_bucket(curr_img_size)
+            num_pad_tokens = bucket_size - curr_img_size
+            if num_pad_tokens > 0:
+                logger_msg = "Padding current image size " \
+                    + str(curr_img_size.item()) \
+                    + " to " \
+                    + str(bucket_size)
+                logger.info(logger_msg)
+                cu_seqlens = F.pad(cu_seqlens, (0, 1), "constant", bucket_size)
+                cu_window_seqlens = F.pad(cu_window_seqlens, (0, 1), "constant", bucket_size)
+                hidden_states = F.pad(hidden_states, (0, 0, 0, num_pad_tokens), "constant", -100)
+                rot_pos_emb_cos = F.pad(
+                    rot_pos_emb_cos,  # [seq, dim]
+                    (0, 0, 0, num_pad_tokens),
+                    "constant",
+                    0.0)
+                rot_pos_emb_sin = F.pad(rot_pos_emb_sin, (0, 0, 0, num_pad_tokens), "constant", 0.0)
+
+            padding_attn_mask_full = create_block_diagonal_attention_mask(cu_seqlens)
+            padding_attn_mask_window = create_block_diagonal_attention_mask(cu_window_seqlens)
+
+            # static part
             htcore.mark_step()
-            hidden_states = self.forward(
-                pixel_values_curr_img_padded,
-                rotary_pos_emb_cos=rot_pos_emb_cos,
-                rotary_pos_emb_sin=rot_pos_emb_sin,
-                fullattn_mask=fullatt_block_attn_mask,
-            )
+            hidden_states = self.forward(hidden_states,
+                                         rotary_pos_emb_cos=rot_pos_emb_cos,
+                                         rotary_pos_emb_sin=rot_pos_emb_sin,
+                                         padding_attn_mask_window=padding_attn_mask_window,
+                                         padding_attn_mask_full=padding_attn_mask_full)
             htcore.mark_step()
+            # remove padding
+            hidden_states = hidden_states[:curr_img_size, :, :]
+
+            # after attention
             image_embeds = self.post_attn(hidden_states, window_index)
-            # slice image_embeds to remove the padded parts
-            pad_index = img_shape_padded[0].prod() // self.spatial_merge_unit
-            results += [image_embeds[:pad_index, :]]
+            results += [image_embeds]
         results_cat = torch.concat(results)
         image_embeds = results_cat
         return image_embeds
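
For reference, a standalone sketch of the idea behind create_block_diagonal_attention_mask (the function body is only partially shown in this hunk, so the helper name block_diagonal_mask and the sample values below are illustrative, not the repo's implementation): cumulative sequence lengths are expanded into per-segment membership rows, and an outer product over the segment axis produces a block-diagonal mask, so padded full or window attention only mixes tokens that belong to the same image or window.

    # Standalone sketch, not the patch's implementation; names are illustrative.
    import torch

    def block_diagonal_mask(cu_seqlens: torch.Tensor) -> torch.Tensor:
        # cu_seqlens = [0, n1, n1 + n2, ...]; result[t, u] is True only when
        # tokens t and u fall inside the same segment.
        total = int(cu_seqlens[-1])
        positions = torch.arange(total, device=cu_seqlens.device)
        # membership[s, t] is True when token t belongs to segment s
        membership = (positions.unsqueeze(0) >= cu_seqlens[:-1].unsqueeze(1)) \
            & (positions.unsqueeze(0) < cu_seqlens[1:].unsqueeze(1))
        # outer product over the segment axis gives the block-diagonal structure
        return torch.einsum("st,su->tu", membership.float(), membership.float()).bool()

    if __name__ == "__main__":
        cu = torch.tensor([0, 3, 5])
        print(block_diagonal_mask(cu).int())
        # tensor([[1, 1, 1, 0, 0],
        #         [1, 1, 1, 0, 0],
        #         [1, 1, 1, 0, 0],
        #         [0, 0, 0, 1, 1],
        #         [0, 0, 0, 1, 1]], dtype=torch.int32)

A mask of this shape, broadcast to [batch, 1, seq, seq], is what the patch presumably feeds through attn_mask into FusedSDPA inside HPU_Attention.forward, both for the padded full-attention blocks (cu_seqlens) and the windowed blocks (cu_window_seqlens).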