Multimodality/sft extension #91
base: multimodality/images
```
@@ -4,6 +4,8 @@
import os
import logging
import numpy as np
import json
from pathlib import Path
import torch
import torch.nn.functional as F
```
```
@@ -33,14 +35,16 @@ def __init__(
        # Call Megatron Dataset init instead of direct parent, as we initialize index differently
        MegatronDataset.__init__(self, dataset, dataset_path, indexed_indices, num_samples, index_split, config)

        self.debug_writer = DebugDataWriter(output_dir="/users/rkreft/debug_data")  # Initialize once, outside the loop

        self.tokenizer = config.tokenizer
        # Set pad token
        try:
            self._pad_token_id = self.tokenizer.pad
        except Exception:
            self._pad_token_id = _PAD_TOKEN_ID

        # End of Document token to add end to truncated samples
        # End of Document token to add end to truncated samples  TODO: currently works with HF tokenizers only
        self._eod_token_id = self.tokenizer.eod
        self._bos_token_id = self.tokenizer.bos
        # TODO: Pass sequences dynamically
```
```
@@ -49,13 +53,18 @@ def __init__(
        self._sft_turn_end_sequence = self.tokenizer.tokenize('<|eot_id|>', add_special_tokens=False)
        self._sft_assistant_begin_sequence = self.tokenizer.tokenize('<|start_header_id|>assistant<|end_header_id|>',
                                                                     add_special_tokens=False)
        # TODO: Make multimodality optional - configurable or just leave this as long as there is only one option?
        self._img_begin_sequence = self.tokenizer.tokenize('<|img_start|>', add_special_tokens=False)
        self._img_end_sequence = self.tokenizer.tokenize('<|img_end|>', add_special_tokens=False)

        # Configure tokens (or token sequences) to remove from loss calculations
        self.tokens_to_mask = []  # a list of token ids or sequences of token ids to mask
        if self.config.sft_mask_special_tokens:
            # add tokenizer special tokens like EOS, BOS to be masked
            self.tokens_to_mask += list(self.tokenizer.special_tokens)
        self.tokens_to_mask.append(self._sft_user_assistant_sequence)
        logger.warning(f"Masking the following tokens/token-sequences: {self.tokens_to_mask}")
        self.tokens_to_mask.append([self._eod_token_id])
        self.tokens_to_mask.append([self._bos_token_id])
        self.tokens_to_mask.append(self._sft_assistant_begin_sequence)
        log_single_rank(logger, logging.WARNING, f"Masking the following tokens/token-sequences: {self.tokens_to_mask}")

        # Build shuffle indices
        self.document_index = self._build_single_document_indices()
```
```
@@ -212,7 +221,8 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
            text = np.concatenate([trunc_doc, np.array([self._eod_token_id])])
        else:
            padding_length = target_length - len(document)
            text = np.concatenate([document, np.full(padding_length, self._pad_token_id, dtype=np.int64)])
            # Pad on left side with pad token
            text = np.concatenate([np.full(padding_length, self._pad_token_id, dtype=np.int64), document])

        text = torch.from_numpy(text).long()
```
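As an aside, a minimal sketch (hypothetical token ids and pad id, not taken from this PR) of what switching from right- to left-padding does to a short document:

```python
import numpy as np

pad_token_id = 0
document = np.array([11, 12, 13], dtype=np.int64)  # hypothetical token ids
target_length = 6
padding_length = target_length - len(document)

# Previous behaviour: pad on the right
right_padded = np.concatenate([document, np.full(padding_length, pad_token_id, dtype=np.int64)])
# -> [11 12 13  0  0  0]

# New behaviour: pad on the left, so the real tokens sit at the end of the sequence
left_padded = np.concatenate([np.full(padding_length, pad_token_id, dtype=np.int64), document])
# -> [ 0  0  0 11 12 13]
print(right_padded, left_padded)
```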
```
@@ -250,17 +260,39 @@ def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
        loss_mask[labels == self._pad_token_id] = 0.0

        # DEBUG: Log how many tokens are actually being trained on
        #num_unmasked = loss_mask.sum().item()
        #total_tokens = loss_mask.numel()
        #if idx is not None and idx % 100 == 0:  # Log every 100 samples
        #    logger.warning(f"Sample {idx}: {num_unmasked}/{total_tokens} tokens unmasked "
        #                   f"({100*num_unmasked/total_tokens:.1f}%), "
        #                   f"doc_length={len(torch.from_numpy(document))}, "
        #                   f"num_pad_tokens={(labels == self._pad_token_id).sum().item()}")
        #    user_begin_seq = torch.tensor(self._sft_user_begin_sequence, dtype=tokens.dtype, device=tokens.device)
        #    logger.warning(f"Sample {idx}: Looking for user_begin pattern: {user_begin_seq.tolist()}")
        #    logger.warning(f"Sample {idx}: Token sequence sample: {tokens[:100].tolist()}")
        #    logger.warning(f"Sample {idx}: Loss mask sample: {loss_mask[:100].tolist()}")
        num_unmasked = loss_mask.sum().item()
        total_tokens = loss_mask.numel()

        if False and idx is not None and idx % 100 == 0:  # Log every 100 samples
            logger.warning(f"Sample {idx} - DOC {actual_doc_id}: {num_unmasked}/{total_tokens} tokens unmasked "
                           f"({100*num_unmasked/total_tokens:.1f}%), "
                           f"doc_length={len(torch.from_numpy(document))}, "
                           f"num_pad_tokens={(labels == self._pad_token_id).sum().item()}")

        # Store to files
        self.debug_writer.append_sample(
            idx=idx,
            actual_doc_id=actual_doc_id,
            tokens=tokens,
            loss_mask=loss_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            labels=labels,
            document_length=len(torch.from_numpy(document)),
            pad_token_id=self._pad_token_id
        )

        # Continue with existing logging
        user_begin_seq = torch.tensor(self._sft_user_begin_sequence, dtype=tokens.dtype, device=tokens.device)
        logger.warning(f"Sample {idx}: Looking for user_begin pattern: {user_begin_seq.tolist()}")
        logger.warning(f"Sample {idx}: Token sequence sample: {tokens[:100].tolist()}")
        logger.warning(f"Sample {idx}: Loss mask sample: {loss_mask[:100].tolist()}")
        logger.warning(f"Sample {idx}: Loss mask sample last 100: {loss_mask[-100:].tolist()}")
        logger.warning(f"Sample {idx}: Position ids sample: {position_ids[:100].tolist()}")
        if attention_mask is not None:
            logger.warning(f"Sample {idx}: Attention mask sample (1=masked): {attention_mask[0,0,:100].long().tolist()}")
        # END DEBUG

        # Map pad tokens to valid embedding indices
        tokens[tokens == self._pad_token_id] = 0
```
```
@@ -315,50 +347,8 @@ def _get_ltor_masks_and_position_ids(self, tokens,
        begin_seq = torch.tensor(self._sft_user_begin_sequence, dtype=tokens.dtype, device=tokens.device)
        end_seq = torch.tensor(self._sft_turn_end_sequence, dtype=tokens.dtype, device=tokens.device)

        begin_len = len(begin_seq)
        end_len = len(end_seq)

        if 0 < begin_len <= len(tokens):
            matches_begin = get_matching_mask(tokens, begin_seq, only_begin=True)

            if end_len > 0:
                matches_end = get_matching_mask(tokens, end_seq, only_begin=True)

                begin_indices = torch.where(matches_begin)[0]
                end_indices = torch.where(matches_end)[0]

                # Vectorized masking
                if len(begin_indices) > 0 and len(end_indices) > 0:
                    # For each begin, find the next ends (vectorized)
                    end_matrix = end_indices.unsqueeze(0) > begin_indices.unsqueeze(1)
                    has_valid_end = end_matrix.any(dim=1)
                    first_end_idx = end_matrix.int().argmax(dim=1)

                    # Compute end positions for each begin
                    end_positions = torch.where(
                        has_valid_end,
                        end_indices[first_end_idx] + end_len,
                        len(loss_mask)
                    )

                    # Create ranges and mask in one go, Shape: (num_begins, max_range_len)
                    max_len = (end_positions - begin_indices).max().item()
                    ranges = torch.arange(max_len, device=tokens.device).unsqueeze(0)
                    lengths = (end_positions - begin_indices).unsqueeze(1)

                    # Get all indices to mask
                    mask_positions = begin_indices.unsqueeze(1) + ranges
                    valid_mask = ranges < lengths
                    indices_to_mask = mask_positions[valid_mask]

                    loss_mask[indices_to_mask] = 0.0
                elif len(begin_indices) > 0:
                    # No end sequences, mask from each begin to the end
                    max_len = len(loss_mask) - begin_indices.min().item()
                    ranges = torch.arange(max_len, device=tokens.device).unsqueeze(0)
                    mask_positions = begin_indices.unsqueeze(1) + ranges
                    valid = mask_positions < len(loss_mask)
                    loss_mask[mask_positions[valid]] = 0.0
        user_seq_mask = get_matching_mask_by_start_end(tokens, begin_seq, end_seq)
        loss_mask[user_seq_mask] = 0.0

        # 2) Mask other tokens/sequences as configured in init (might contain BOS, EOS, assistant begin)
        for t in self.tokens_to_mask:
```
```
@@ -371,19 +361,27 @@ def _get_ltor_masks_and_position_ids(self, tokens,
                raise ValueError(f"Invalid token to mask: {t}")
            loss_mask[mask] = 0.0

        # 3) Unmask image tokens if configured
        if self.config.sft_do_not_mask_image_tokens:
            img_begin_tensor = torch.tensor(self._img_begin_sequence, dtype=tokens.dtype, device=tokens.device)
            img_end_tensor = torch.tensor(self._img_end_sequence, dtype=tokens.dtype, device=tokens.device)
            img_seq_mask = get_matching_mask_by_start_end(tokens, img_begin_tensor, img_end_tensor)
            loss_mask[img_seq_mask] = 1.0

        # 4) Create attention mask, excluding padding tokens from attention
        if create_attention_mask:
            # Here we mask attention from all padding tokens to all other tokens and vice versa
            attention_mask = torch.tril(
                torch.ones((self.config.sequence_length, self.config.sequence_length), device=tokens.device)
            )
            # Mask padding tokens in attention mask:
            #no_padding_mask = (tokens != self._pad_token_id).float()  # 1=real, 0=padding
            no_padding_mask = (tokens != self._pad_token_id).float()  # 1=real, 0=padding

            # Mask both rows (queries from padding) and columns (keys to padding)
            # Row masking: padding tokens shouldn't attend to anything
            #attention_mask = attention_mask * no_padding_mask.unsqueeze(1)
            attention_mask = attention_mask * no_padding_mask.unsqueeze(1)
            # Column masking: nothing should attend to padding tokens
            #attention_mask = attention_mask * no_padding_mask.unsqueeze(0)
            attention_mask = attention_mask * no_padding_mask.unsqueeze(0)

            # Convert attention mask to binary:
            attention_mask = attention_mask.unsqueeze(0)
```
```
@@ -419,3 +417,135 @@ def get_matching_mask(sequence, query: torch.Tensor, only_begin:bool=True):
    return matches


def get_matching_mask_by_start_end(tokens, begin_seq: torch.Tensor, end_seq: torch.Tensor):
    """
    Given a sequence and a start and end query, return a mask indicating which positions in the sequence
    are between the start and end queries (inclusive).
    """
    mask = torch.zeros(len(tokens), dtype=torch.bool, device=tokens.device)
    begin_len = len(begin_seq)
    end_len = len(end_seq)

    if 0 < begin_len <= len(tokens):
        matches_begin = get_matching_mask(tokens, begin_seq, only_begin=True)

        if end_len > 0:
            matches_end = get_matching_mask(tokens, end_seq, only_begin=True)

            begin_indices = torch.where(matches_begin)[0]
            end_indices = torch.where(matches_end)[0]

            # Vectorized masking
            if len(begin_indices) > 0 and len(end_indices) > 0:
                # For each begin, find the next ends (vectorized)
                end_matrix = end_indices.unsqueeze(0) > begin_indices.unsqueeze(1)
                has_valid_end = end_matrix.any(dim=1)
                first_end_idx = end_matrix.int().argmax(dim=1)

                # Compute end positions for each begin
                end_positions = torch.where(
                    has_valid_end,
                    end_indices[first_end_idx] + end_len,
                    len(mask)
                )

                # Create ranges and mask in one go, Shape: (num_begins, max_range_len)
                max_len = (end_positions - begin_indices).max().item()
                ranges = torch.arange(max_len, device=tokens.device).unsqueeze(0)
                lengths = (end_positions - begin_indices).unsqueeze(1)

                # Get all indices to mask
                mask_positions = begin_indices.unsqueeze(1) + ranges
                valid_mask = ranges < lengths
                indices_to_mask = mask_positions[valid_mask]

                mask[indices_to_mask] = True
            elif len(begin_indices) > 0:
                # No end sequences, mask from each begin to the end
                max_len = len(mask) - begin_indices.min().item()
                ranges = torch.arange(max_len, device=tokens.device).unsqueeze(0)
                mask_positions = begin_indices.unsqueeze(1) + ranges
                valid = mask_positions < len(mask)
                mask[mask_positions[valid]] = True
    return mask


class DebugDataWriter:
    """Helper class to write debug data to files."""

    def __init__(self, output_dir="debug_data"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        self.metadata_file = self.output_dir / "metadata.jsonl"
        self.tokens_file = self.output_dir / "tokens.npy"
        self.loss_mask_file = self.output_dir / "loss_mask.npy"
        self.attention_mask_file = self.output_dir / "attention_mask.npy"

        # Initialize files if they don't exist
        if not self.tokens_file.exists():
            self._init_array_file(self.tokens_file)
        if not self.loss_mask_file.exists():
            self._init_array_file(self.loss_mask_file)
        if not self.attention_mask_file.exists():
            self._init_array_file(self.attention_mask_file)

    def _init_array_file(self, filepath):
        """Initialize an empty numpy array file."""
        np.save(filepath, np.array([]))

    def append_sample(self, idx, actual_doc_id, tokens, loss_mask,
                      attention_mask=None, position_ids=None,
                      labels=None, document_length=None, pad_token_id=None):
        """Append a sample to the debug files."""

        # Convert tensors to numpy
        tokens_np = tokens.cpu().numpy() if torch.is_tensor(tokens) else tokens
        loss_mask_np = loss_mask.cpu().numpy() if torch.is_tensor(loss_mask) else loss_mask

        # Load existing data
        tokens_data = np.load(self.tokens_file, allow_pickle=True)
        loss_mask_data = np.load(self.loss_mask_file, allow_pickle=True)

        # Append new data
        if tokens_data.size == 0:
            tokens_data = np.array([tokens_np], dtype=object)
            loss_mask_data = np.array([loss_mask_np], dtype=object)
        else:
            tokens_data = np.append(tokens_data, [tokens_np])
            loss_mask_data = np.append(loss_mask_data, [loss_mask_np])

        # Save updated arrays
        np.save(self.tokens_file, tokens_data)
        np.save(self.loss_mask_file, loss_mask_data)

        # Handle attention mask
        if attention_mask is not None:
            attention_mask_np = attention_mask.cpu().numpy() if torch.is_tensor(attention_mask) else attention_mask
            attention_mask_data = np.load(self.attention_mask_file, allow_pickle=True)

            if attention_mask_data.size == 0:
                attention_mask_data = np.array([attention_mask_np], dtype=object)
            else:
                attention_mask_data = np.append(attention_mask_data, [attention_mask_np])

            np.save(self.attention_mask_file, attention_mask_data)

        # Save metadata
        num_unmasked = loss_mask_np.sum()
        total_tokens = loss_mask_np.size

        metadata = {
            "sample_idx": int(idx) if idx is not None else None,
            "doc_id": int(actual_doc_id) if actual_doc_id is not None else None,
            "num_unmasked": int(num_unmasked),
            "total_tokens": int(total_tokens),
            "unmasked_percentage": float(100 * num_unmasked / total_tokens),
            "document_length": int(document_length) if document_length is not None else None,
            "num_pad_tokens": int((labels == pad_token_id).sum()) if labels is not None and pad_token_id is not None else None,
            "sequence_length": int(len(tokens_np)),
            "has_attention_mask": attention_mask is not None
        }

        # Append metadata as JSONL
        with open(self.metadata_file, 'a') as f:
            f.write(json.dumps(metadata) + '\n')
```
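To make the intended behaviour of the new helper concrete, here is a naive reference sketch. `mask_between` is a hypothetical stand-in written for illustration, not code from this PR; for simple inputs it should mark the same positions as `get_matching_mask_by_start_end`, i.e. each span from a begin sequence through the next end sequence inclusive, or to the end of the data when no end follows:

```python
import torch

def mask_between(tokens: torch.Tensor, begin_seq: torch.Tensor, end_seq: torch.Tensor) -> torch.Tensor:
    """Naive reference: mark every span starting with begin_seq and running through the next end_seq (inclusive)."""
    mask = torch.zeros(len(tokens), dtype=torch.bool)
    i = 0
    while i <= len(tokens) - len(begin_seq):
        if torch.equal(tokens[i:i + len(begin_seq)], begin_seq):
            j = i + len(begin_seq)
            end = len(tokens)  # no end found -> mask to the end of the sequence
            while j <= len(tokens) - len(end_seq):
                if torch.equal(tokens[j:j + len(end_seq)], end_seq):
                    end = j + len(end_seq)
                    break
                j += 1
            mask[i:end] = True
            i = end
        else:
            i += 1
    return mask

# Made-up ids: begin sequence ~ [7, 8], end sequence ~ [9]
tokens = torch.tensor([1, 7, 8, 4, 5, 9, 6, 6])
print(mask_between(tokens, torch.tensor([7, 8]), torch.tensor([9])))
# tensor([False,  True,  True,  True,  True,  True, False, False])
```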
CRUCIAL BUG: We're Searching in the Wrong Place

Let's make this concrete with a simple example.

Each `loss_mask[i]` controls whether we train on predicting `labels[i]` from `tokens[i]`, i.e.:

- `loss_mask[0]`: a → b
- `loss_mask[1]`: b → c
- `loss_mask[2]`: c → d
- `loss_mask[3]`: d → e

The Problem

Say we want to mask the prediction of 'c' (don't train the model to predict 'c').

Current (incorrect) logic:
- Search for 'c' in `tokens` → find index 2.
- Set `loss_mask[2] = 0`.

Result: this disables the c → d prediction instead of b → c.

Correct logic

We should search for 'c' in `labels`, not `tokens`:
- 'c' is at label index 1.
- Set `loss_mask[1] = 0`.

Result: this disables the b → c prediction, as intended.
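A tiny self-contained illustration of the shift described above (made-up token ids, with a–e mapped to 1–5):

```python
import torch

# Next-token setup: tokens = a b c d, labels = b c d e
tokens = torch.tensor([1, 2, 3, 4])
labels = torch.tensor([2, 3, 4, 5])
c = 3

# Incorrect: matching against tokens masks position 2, i.e. the c -> d prediction
wrong = torch.ones(4)
wrong[tokens == c] = 0.0   # tensor([1., 1., 0., 1.])

# Correct: matching against labels masks position 1, i.e. the b -> c prediction
right = torch.ones(4)
right[labels == c] = 0.0   # tensor([1., 0., 1., 1.])
print(wrong, right)
```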
I have now renamed the arguments of the sequence-matching methods (to "data" and "sequence") and, more importantly, pass them the labels to calculate the mask.