
Commit 72af0cb

Merge pull request #43 from ruivieira/RHOAIENG-12606
RHOAIENG-12606: Add notes on using local models
2 parents 16ce8df + b35978a commit 72af0cb

File tree

1 file changed: +111 −0 lines changed


examples/Detoxify.ipynb

Lines changed: 111 additions & 0 deletions
@@ -188,6 +188,14 @@
    "text"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "1eb7719e30054304",
+  "metadata": {},
+  "source": [
+   "## Initializing TMaRCo"
+  ]
+ },
 {
  "cell_type": "code",
  "execution_count": 5,
@@ -198,6 +206,64 @@
    "tmarco = TMaRCo()"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "3e16ee305f4983d9",
+  "metadata": {},
+  "source": [
+   "This will initialize `TMaRCo` using the default models, downloaded from HuggingFace.\n",
+   "<div class=\"alert alert-info\">\n",
+   "To use local models with TMaRCo, the pre-initialized models must be in local storage accessible to TMaRCo.\n",
+   "</div>\n",
+   "For instance, to use the default `facebook/bart-large` model locally, we first need to retrieve it:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "614c9ff6f46a0ea9",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from huggingface_hub import snapshot_download\n",
+   "\n",
+   "snapshot_download(repo_id=\"facebook/bart-large\", local_dir=\"models/bart\")"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "95bd792e757205d6",
+  "metadata": {},
+  "source": [
+   "We now initialize the base model and tokenizer from the local files and pass them to `TMaRCo`:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "f0f24485822a7c3f",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from transformers import BartForConditionalGeneration, BartTokenizer\n",
+   "\n",
+   "tokenizer = BartTokenizer.from_pretrained(\n",
+   "    \"models/bart\",  # Or the directory where the local model is stored\n",
+   "    is_split_into_words=True, add_prefix_space=True\n",
+   ")\n",
+   "\n",
+   "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
+   "\n",
+   "base = BartForConditionalGeneration.from_pretrained(\n",
+   "    \"models/bart\",  # Or the directory where the local model is stored\n",
+   "    max_length=150,\n",
+   "    forced_bos_token_id=tokenizer.bos_token_id,\n",
+   ")\n",
+   "\n",
+   "# Initialize TMaRCo with local models\n",
+   "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)"
+  ]
+ },
 {
  "cell_type": "code",
  "execution_count": 7,
@@ -223,6 +289,32 @@
    "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "c113208c527c342e",
+  "metadata": {},
+  "source": [
+   "<div class=\"alert alert-info\">\n",
+   "To use local expert/anti-expert models with TMaRCo, they must, as before, be in local storage accessible to TMaRCo.\n",
+   "\n",
+   "However, they don't need to be initialized separately; the directory can be passed directly.\n",
+   "</div>\n",
+   "To use local models with `TMaRCo` (in this case the same default `gminus`/`gplus`):\n"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "dfa288dcb60102c",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "snapshot_download(repo_id=\"trustyai/gminus\", local_dir=\"models/gminus\")\n",
+   "snapshot_download(repo_id=\"trustyai/gplus\", local_dir=\"models/gplus\")\n",
+   "\n",
+   "tmarco.load_models([\"models/gminus\", \"models/gplus\"])"
+  ]
+ },
 {
  "cell_type": "code",
  "execution_count": 13,
@@ -362,6 +454,25 @@
    "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "b0738c324227f57",
+  "metadata": {},
+  "source": [
+   "As noted previously, to use local models, simply pass the initialized tokenizer and base model to the constructor, and the local paths as the expert/anti-expert models:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "b929e21a97ea914e",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)\n",
+   "tmarco.load_models([\"models/gminus\", \"models/gplus\"])"
+  ]
+ },
 {
  "cell_type": "markdown",
  "id": "5303f56b-85ff-40da-99bf-6962cf2f3395",

0 commit comments
