Skip to content

Commit a8dcde2

Browse files
authored
Use ungated models for unit tests (#196)
* Fix unit test
* Fix chat template tests
* Remove deprecated test
* up
1 parent 28bf902 commit a8dcde2

File tree

3 files changed

+11
-25
lines changed

3 files changed

+11
-25
lines changed

src/alignment/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def maybe_insert_system_message(messages, tokenizer):
3232
# chat template can be one of two attributes, we check in order
3333
chat_template = tokenizer.chat_template
3434
if chat_template is None:
35-
chat_template = tokenizer.default_chat_template
35+
chat_template = tokenizer.get_chat_template()
3636

3737
# confirm the jinja template refers to a system message before inserting
3838
if "system" in chat_template or "<|im_start|>" in chat_template:

tests/test_data.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,21 +122,21 @@ def setUp(self):
122122
)
123123

124124
def test_maybe_insert_system_message(self):
125-
# does not accept system prompt
126-
mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
127-
# accepts system prompt. use codellama since it has no HF token requirement
128-
llama_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
125+
# Chat template that does not accept system prompt. Use community checkpoint since it has no HF token requirement
126+
tokenizer_sys_excl = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-Instruct-v0.3")
127+
# Chat template that accepts system prompt
128+
tokenizer_sys_incl = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
129129
messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}]
130130
messages_sys_incl = [{"role": "system", "content": ""}, {"role": "user", "content": "Tell me a joke."}]
131131

132-
mistral_messages = deepcopy(messages_sys_excl)
133-
llama_messages = deepcopy(messages_sys_excl)
134-
maybe_insert_system_message(mistral_messages, mistral_tokenizer)
135-
maybe_insert_system_message(llama_messages, llama_tokenizer)
132+
messages_proc_excl = deepcopy(messages_sys_excl)
133+
message_proc_incl = deepcopy(messages_sys_excl)
134+
maybe_insert_system_message(messages_proc_excl, tokenizer_sys_excl)
135+
maybe_insert_system_message(message_proc_incl, tokenizer_sys_incl)
136136

137137
# output from mistral should not have a system message, output from llama should
138-
self.assertEqual(mistral_messages, messages_sys_excl)
139-
self.assertEqual(llama_messages, messages_sys_incl)
138+
self.assertEqual(messages_proc_excl, messages_sys_excl)
139+
self.assertEqual(message_proc_incl, messages_sys_incl)
140140

141141
def test_sft(self):
142142
dataset = self.dataset.map(

tests/test_model_utils.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import unittest
1616

1717
import torch
18-
from transformers import AutoTokenizer
1918

2019
from alignment import (
2120
DataArguments,
@@ -64,19 +63,6 @@ def test_default_chat_template(self):
6463
tokenizer = get_tokenizer(self.model_args, DataArguments())
6564
self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE)
6665

67-
def test_default_chat_template_no_overwrite(self):
68-
"""
69-
If no chat template is passed explicitly in the config, then for models with a
70-
`default_chat_template` but no `chat_template` we do not set a `chat_template`,
71-
and that we do not change `default_chat_template`
72-
"""
73-
model_args = ModelArguments(model_name_or_path="m-a-p/OpenCodeInterpreter-SC2-7B")
74-
base_tokenizer = AutoTokenizer.from_pretrained("m-a-p/OpenCodeInterpreter-SC2-7B")
75-
processed_tokenizer = get_tokenizer(model_args, DataArguments())
76-
77-
assert getattr(processed_tokenizer, "chat_template") is None
78-
self.assertEqual(base_tokenizer.default_chat_template, processed_tokenizer.default_chat_template)
79-
8066
def test_chatml_chat_template(self):
8167
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
8268
tokenizer = get_tokenizer(self.model_args, DataArguments(chat_template=chat_template))

0 commit comments

Comments (0)