-import sys
-from copy import deepcopy
-from pathlib import Path
-
-import pytest
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from transformers import (
-    AutoTokenizer,
-    BambaConfig,
-    BambaForCausalLM,
-    DataCollatorForSeq2Seq,
-    LlamaConfig,
-    LlamaForCausalLM,
-)
-
-# HACK for being able to load the collator without needing to install open-instruct
-open_instruct_dir = Path(__file__).parent.parent.absolute()
-sys.path.append(open_instruct_dir)
-from open_instruct.dataset_processor import CHAT_TEMPLATES
-from open_instruct.dataset_transformation import sft_tulu_tokenize_and_truncate_v1
-from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
-
-try:
-    import mamba_ssm  # noqa
-    import causal_conv1d  # noqa
-
-    mamba_and_causal_conv_available = True
-except ImportError:
-    mamba_and_causal_conv_available = False
-
-try:
-    import flash_attn  # noqa
-
-    flash_attn_available = True
-except ImportError:
-    flash_attn_available = False
-
-MODEL_CLASSES = {"bamba": BambaForCausalLM, "llama": LlamaForCausalLM}
-MODEL_CFGS = {"bamba": BambaConfig, "llama": LlamaConfig}
-MODEL_KWARGS = {
-    "bamba": dict(
-        attention_dropout=0.0,
-        attn_layer_indices=None,
-        attn_rotary_emb=8,
-        hidden_act="silu",
-        hidden_size=32,
-        initializer_range=0.02,
-        intermediate_size=64,
-        mamba_chunk_size=16,
-        mamba_d_conv=4,
-        mamba_d_state=16,
-        mamba_expand=2,
-        mamba_n_groups=1,
-        mamba_n_heads=16,
-        max_position_embeddings=512,
-        num_attention_heads=4,
-        num_hidden_layers=1,
-        num_key_value_heads=2,
-        pad_token_id=0,
-    ),
-    "llama": dict(
-        hidden_act="gelu",
-        hidden_size=32,
-        intermediate_size=64,
-        is_training=True,
-        max_position_embeddings=512,
-        mlp_bias=False,
-        num_attention_heads=2,
-        num_hidden_layers=1,
-        num_key_value_heads=2,
-    ),
-}
-
-
-class TestPaddingFree:
-    seqlen = 128
-    batch_size = 2
-    dtype = torch.bfloat16
-
-    def get_fa2_model_and_cfg(self, model_name: str, vocab_size: int) -> nn.Module:
-        model_cls = MODEL_CLASSES[model_name]
-        model_cfg = MODEL_CFGS[model_name]
-        model_kwargs = MODEL_KWARGS[model_name]
-        cfg = model_cfg(
-            **{
-                **model_kwargs,
-                "torch_dtype": self.dtype,
-                "attn_implementation": "flash_attention_2",
-                "vocab_size": vocab_size,
-            }
-        )
-        model = model_cls(cfg).to("cuda", dtype=self.dtype)
-        return model, cfg
-
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Padding free tests require CUDA")
-    @pytest.mark.skipif(not flash_attn_available, reason="Padding free requires flash_attn")
-    @pytest.mark.parametrize("model_name", ["bamba", "llama"])
-    @pytest.mark.parametrize("loss_type", ["mean", "sum"])
-    def test_padding_free(self, model_name: str, loss_type: str) -> None:
-        if model_name == "bamba" and not mamba_and_causal_conv_available:
-            pytest.skip("bamba padding-free tests require mamba_ssm and causal_conv1d")
-        torch.manual_seed(42)
-
-        tokenizer = AutoTokenizer.from_pretrained("ibm-ai-platform/Bamba-9B-v2")
-        tokenizer.add_special_tokens({"pad_token": "<pad>"})
-        tokenizer.chat_template = CHAT_TEMPLATES["tulu"]
-        vocab_size = len(tokenizer)
-
-        model, cfg = self.get_fa2_model_and_cfg(model_name, vocab_size)
-        model.initialize_weights()
-        pf_model = deepcopy(model)
-
-        inputs = torch.randint(cfg.vocab_size, size=(self.batch_size, self.seqlen), device="cpu")
-
-        data = {
-            0: {
-                "messages": [
-                    {"role": "user", "content": "Why did the chicken cross the road?"},
-                    {"role": "assistant", "content": "To get to the other side"},
-                ]
-            },
-            1: {
-                "messages": [
-                    {"role": "user", "content": "What is one plus two?"},
-                    {"role": "assistant", "content": "The answer is 3"},
-                ]
-            },
-        }
-
-        tok_data = {k: sft_tulu_tokenize_and_truncate_v1(v, tokenizer, max_seq_length=2**30) for k, v in data.items()}
-        for v in tok_data.values():
-            del v["messages"]
-
-        collate_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest")
-        dataloader = DataLoader(tok_data, shuffle=False, collate_fn=collate_fn, batch_size=self.batch_size)
-
-        pf_collate_fn = TensorDataCollatorWithFlattening()
-        pf_dataloader = DataLoader(tok_data, shuffle=False, collate_fn=pf_collate_fn, batch_size=self.batch_size)
-
-        batch = next(iter(dataloader))
-        pf_batch = next(iter(pf_dataloader))
-        for b in (batch, pf_batch):
-            for k in b:
-                if torch.is_tensor(b[k]):
-                    b[k] = b[k].cuda()
-
-        assert batch["input_ids"].shape[0] == 2
-        assert pf_batch["input_ids"].shape[0] == 1
-
-        # Also create a batch with the pf style concatenation, but without the pf seq markers, as a
-        # control. Passing this through the model should give incorrect results.
-
-        incorrect_pf_batch = {
-            "input_ids": pf_batch["input_ids"],
-            "labels": pf_batch["labels"],
-            "attention_mask": torch.ones_like(pf_batch["input_ids"]),
-        }
-
-        outputs = model(**batch)
-        pf_outputs = pf_model(**pf_batch)
-        with torch.no_grad():
-            incorrect_pf_outputs = model(**incorrect_pf_batch)
-
-        # Compare logits (properly reshaped and masked)
-        logits = outputs.logits.reshape(1, -1, outputs.logits.shape[-1])
-        non_masked_logits = logits[:, batch["attention_mask"].flatten().bool()]
-        pf_logits = pf_outputs.logits
-        incorrect_pf_logits = incorrect_pf_outputs.logits
-        torch.testing.assert_close(pf_logits, non_masked_logits)
-        with pytest.raises(AssertionError, match="Mismatched elements:"):
-            torch.testing.assert_close(pf_logits, incorrect_pf_logits)
-
-        if loss_type == "mean":
-            loss = outputs.loss
-            pf_loss = pf_outputs.loss
-        else:
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch["labels"].view(-1).long(), reduction="sum")
-            pf_loss = F.cross_entropy(
-                pf_logits.view(-1, pf_logits.size(-1)), pf_batch["labels"].view(-1).long(), reduction="sum"
-            )
-        torch.testing.assert_close(loss, pf_loss)
-
-        loss.backward()
-        pf_loss.backward()
-
-        grads = {n: p.grad for n, p in model.named_parameters()}
-        pf_grads = {n: p.grad for n, p in pf_model.named_parameters()}
-        non_nan_grads = set()
-        nan_grads = set()
-        for k, g in grads.items():
-            torch.testing.assert_close(g, pf_grads[k])
-            non_nan_grads.add(k)
-        print(f"{non_nan_grads=}")
-        print(f"{nan_grads=}")
+# This file has been modified to remove GPU tests.
+# GPU tests have been moved to test_padding_free_gpu.py
+# The file is kept as a placeholder for any future non-GPU tests.
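
For context on what the removed test exercised: a padding-free collator concatenates every sequence in a batch into a single row, so flash attention separates documents via per-sequence position ids (or cumulative sequence lengths) rather than a padding mask. The sketch below is a minimal illustration of that idea only; the `flatten_features` helper is hypothetical and is not the actual `TensorDataCollatorWithFlattening` implementation.

```python
# Illustrative sketch (assumption): one way a padding-free collator can flatten a batch.
# Hypothetical helper; not the open_instruct TensorDataCollatorWithFlattening code.
import torch


def flatten_features(features: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]:
    input_ids, labels, position_ids = [], [], []
    for f in features:
        input_ids.extend(f["input_ids"])
        # Mask the first label of each sequence so the loss never crosses a boundary.
        labels.extend([-100] + f["labels"][1:])
        # Position ids restart at 0 for every sequence; FA2-style kernels use the
        # resets to recover per-sequence boundaries without an attention mask.
        position_ids.extend(range(len(f["input_ids"])))
    return {
        "input_ids": torch.tensor([input_ids]),
        "labels": torch.tensor([labels]),
        "position_ids": torch.tensor([position_ids]),
    }


# Example: two short sequences collapse into one row with no pad tokens.
batch = flatten_features(
    [
        {"input_ids": [5, 6, 7], "labels": [5, 6, 7]},
        {"input_ids": [8, 9], "labels": [8, 9]},
    ]
)
assert batch["input_ids"].shape == (1, 5)
assert batch["position_ids"].tolist() == [[0, 1, 2, 0, 1]]
```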