|
8 | 8 | "\n", |
9 | 9 | "### The Dungeons Dataset\n", |
10 | 10 | "\n", |
11 | | - "The Dungeons dataset is a (dungeons-themed) challenging synthetic dataset for supervised classification on \n", |
12 | | - "semi-structured data. \n", |
| 11 | + "The Dungeons dataset is a challenging (dungeons-themed) synthetic dataset for supervised classification on\n", |
| 12 | + "semi-structured data.\n", |
13 | 13 | "\n", |
14 | | - "Each instance constains a corridor array with several rooms. Each room has a door number and contains multiple \n", |
| 14 | + "Each instance contains a corridor array with several rooms. Each room has a door number and contains multiple\n", |
15 | 15 | "treasure chests with different-colored keys. All but one of the treasures are fake, though.\n", |
16 | 16 | "\n", |
17 | | - "The goal is to find the correct room number and key color in each dungeon based on some clues and return the \n", |
18 | | - "only real treasure. The clues are given at the top-level of the object in the fields `door` and `key_color`. \n", |
| 17 | + "The goal is to find the correct room number and key color in each dungeon based on some clues and return the\n", |
| 18 | + "only real treasure. The clues are given at the top-level of the object in the fields `door` and `key_color`.\n", |
19 | 19 | "\n", |
20 | | - "To make it even harder, the `corridor` array may be shuffled (`shuffle_rooms=True`), and room objects may \n", |
21 | | - "have a number of monsters as their first field (`with_monsters=True`), shifting the token positions of the \n", |
22 | | - "serialized object by a variable amount. \n", |
| 20 | + "To make it even harder, the `corridor` array may be shuffled (`shuffle_rooms=True`), and room objects may\n", |
| 21 | + "have a number of monsters as their first field (`with_monsters=True`), shifting the token positions of the\n", |
| 22 | + "serialized object by a variable amount.\n", |
23 | 23 | "\n", |
24 | 24 | "The following dictionary represents one example JSON instance:\n", |
25 | 25 | "\n", |
26 | 26 | "```json\n", |
27 | 27 | "{\n", |
28 | | - " \"door\": 1, // clue which door is the correct one\n", |
29 | | - " \"key_color\": \"blue\", // clue which key is the correct one\n", |
30 | | - " \"corridor\": [ // a corridor with many doors\n", |
31 | | - " {\n", |
32 | | - " \"monsters\": [\"troll\", \"wolf\"], // optional monsters in front of the door\n", |
33 | | - " \"door_no\": 1, // door number in the corridor\n", |
34 | | - " \"red_key\": \"gemstones\", // different keys return different treasures,\n", |
35 | | - " \"blue_key\": \"spellbooks\", // but only one is real, the others are fake\n", |
36 | | - " \"green_key\": \"artifacts\"\n", |
37 | | - " },\n", |
38 | | - " { // another room, here without monsters\n", |
39 | | - " \"door_no\": 0, // rooms can be shuffled, here room 0 comes after 1 \n", |
40 | | - " \"red_key\": \"diamonds\", \n", |
41 | | - " \"blue_key\": \"gold\", \n", |
42 | | - " \"green_key\": \"gemstones\"\n", |
43 | | - " },\n", |
44 | | - " // ... more rooms ...\n", |
45 | | - " ],\n", |
46 | | - " \"treasure\": \"spellbooks\" // correct treasure (target label)\n", |
| 28 | + " \"door\": 1, // clue which door is the correct one\n", |
| 29 | + " \"key_color\": \"blue\", // clue which key is the correct one\n", |
| 30 | + " \"corridor\": [\n", |
| 31 | + " // a corridor with many doors\n", |
| 32 | + " {\n", |
| 33 | + " \"monsters\": [\"troll\", \"wolf\"], // optional monsters in front of the door\n", |
| 34 | + " \"door_no\": 1, // door number in the corridor\n", |
| 35 | + " \"red_key\": \"gemstones\", // different keys return different treasures,\n", |
| 36 | + " \"blue_key\": \"spellbooks\", // but only one is real, the others are fake\n", |
| 37 | + " \"green_key\": \"artifacts\"\n", |
| 38 | + " },\n", |
| 39 | + " {\n", |
| 40 | + " // another room, here without monsters\n", |
| 41 | + " \"door_no\": 0, // rooms can be shuffled, here room 0 comes after 1\n", |
| 42 | + " \"red_key\": \"diamonds\",\n", |
| 43 | + " \"blue_key\": \"gold\",\n", |
| 44 | + " \"green_key\": \"gemstones\"\n", |
| 45 | + " }\n", |
| 46 | + " // ... more rooms ...\n", |
| 47 | + " ],\n", |
| 48 | + " \"treasure\": \"spellbooks\" // correct treasure (target label)\n", |
47 | 49 | "}\n", |
48 | 50 | "```\n", |
49 | 51 | "\n", |
|
56 | 58 | "source": [ |
57 | 59 | "### Preprocessing\n", |
58 | 60 | "\n", |
59 | | - "The JSON objects are tokenized by recursively walking through them depth-first and extracting key and value tokens. \n", |
60 | | - "Additionally, when encountering arrays or nested objects, special grammar tokens are included in the sequence. \n", |
| 61 | + "The JSON objects are tokenized by recursively walking through them depth-first and extracting key and value tokens.\n", |
| 62 | + "Additionally, when encountering arrays or nested objects, special grammar tokens are included in the sequence.\n", |
61 | 63 | "This diagram illustrates tokenization.\n", |
62 | 64 | "\n", |
63 | 65 | "<img src=\"../assets/preprocessing-diagram.png\" width=\"600px\" />\n" |
|
163 | 165 | "import json\n", |
164 | 166 | "\n", |
165 | 167 | "from sklearn.model_selection import train_test_split\n", |
166 | | - "from sklearn.pipeline import Pipeline\n", |
167 | 168 | "\n", |
168 | | - "from origami.utils.config import PipelineConfig\n", |
169 | | - "from origami.utils import set_seed\n", |
170 | 169 | "from origami.datasets.dungeons import generate_data\n", |
171 | | - "from origami.preprocessing import docs_to_df, build_prediction_pipelines\n", |
| 170 | + "from origami.preprocessing import build_prediction_pipelines, docs_to_df\n", |
| 171 | + "from origami.utils import set_seed\n", |
| 172 | + "from origami.utils.config import PipelineConfig\n", |
172 | 173 | "\n", |
173 | 174 | "# for reproducibility\n", |
174 | | - "# set_seed(123)\n", |
| 175 | + "set_seed(123)\n", |
175 | 176 | "\n", |
176 | 177 | "# generate Dungeons dataset (see origami/datasets/dungeons.py)\n", |
177 | 178 | "data = generate_data(\n", |
178 | 179 | " num_instances=10_000,\n", |
179 | 180 | " num_doors_range=(4, 8),\n", |
180 | 181 | " num_colors=3,\n", |
181 | 182 | " num_treasures=5,\n", |
182 | | - " with_monsters=True, # makes it harder as token positions get shifted by variable amount\n", |
183 | | - " shuffle_rooms=True, # makes it harder because rooms are in random order\n", |
184 | | - " shuffle_keys=True # makes it harder because keys are in random order\n", |
| 183 | + " with_monsters=True, # makes it harder as token positions get shifted by variable amount\n", |
| 184 | + " shuffle_rooms=True, # makes it harder because rooms are in random order\n", |
| 185 | + " shuffle_keys=True, # makes it harder because keys are in random order\n", |
185 | 186 | ")\n", |
186 | 187 | "\n", |
187 | 188 | "# print example dictionary\n", |
|
195 | 196 | "\n", |
196 | 197 | "# create train and test pipelines\n", |
197 | 198 | "pipelines = build_prediction_pipelines(\n", |
198 | | - " pipeline_config=PipelineConfig(sequence_order=\"ORDERED\", upscale=1),\n", |
199 | | - " target_field=TARGET_FIELD\n", |
| 199 | + " pipeline_config=PipelineConfig(sequence_order=\"ORDERED\", upscale=1), target_field=TARGET_FIELD\n", |
200 | 200 | ")\n", |
201 | 201 | "\n", |
202 | 202 | "# process train, eval and test data\n", |
203 | | - "train_df = pipelines['train'].fit_transform(train_docs_df)\n", |
204 | | - "test_df = pipelines['test'].transform(test_docs_df)\n", |
| 203 | + "train_df = pipelines[\"train\"].fit_transform(train_docs_df)\n", |
| 204 | + "test_df = pipelines[\"test\"].transform(test_docs_df)\n", |
205 | 205 | "\n", |
206 | 206 | "# get stateful objects\n", |
207 | | - "schema = pipelines['train'][\"schema\"].schema\n", |
208 | | - "encoder = pipelines['train'][\"encoder\"].encoder\n", |
209 | | - "block_size = pipelines['train'][\"padding\"].length\n", |
| 207 | + "schema = pipelines[\"train\"][\"schema\"].schema\n", |
| 208 | + "encoder = pipelines[\"train\"][\"encoder\"].encoder\n", |
| 209 | + "block_size = pipelines[\"train\"][\"padding\"].length\n", |
210 | 210 | "\n", |
211 | 211 | "# print data stats\n", |
212 | 212 | "print(f\"len train: {len(train_df)}, len test: {len(test_df)}\")\n", |
|
231 | 231 | } |
232 | 232 | ], |
233 | 233 | "source": [ |
234 | | - "# save dungeon dataset to MongoDB \n", |
235 | | - "from pymongo import MongoClient\n", |
| 234 | + "# save dungeon dataset to MongoDB\n", |
| 235 | + "from pymongo import MongoClient\n", |
236 | 236 | "\n", |
237 | 237 | "client = MongoClient(\"mongodb://localhost:27017/\")\n", |
238 | 238 | "collection = client.dungeons.dungeon_10k_4_8_3_5_mkr\n", |
|
247 | 247 | "\n", |
248 | 248 | "Here we instantiate an ORiGAMi model, a modified transformer trained on the token sequences created above.\n", |
249 | 249 | "We use a standard \"medium\" configuration. ORiGAMi models are relatively robust to the choice of hyper-parameters\n", |
250 | | - "and default configurations often work well for mid-sized datasets. " |
| 250 | + "and default configurations often work well for mid-sized datasets.\n" |
251 | 251 | ] |
252 | 252 | }, |
253 | 253 | { |
|
270 | 270 | "from origami.utils import ModelConfig, TrainConfig, count_parameters\n", |
271 | 271 | "\n", |
272 | 272 | "# model and train configs\n", |
273 | | - "model_config = ModelConfig.from_preset(\"medium\") # see origami/utils/config.py for different presets\n", |
| 273 | + "model_config = ModelConfig.from_preset(\"medium\") # see origami/utils/config.py for different presets\n", |
274 | 274 | "model_config.position_encoding = \"SINE_COSINE\"\n", |
275 | 275 | "model_config.vocab_size = encoder.vocab_size\n", |
276 | 276 | "model_config.block_size = block_size\n", |
|
284 | 284 | "train_dataset = DFDataset(train_df)\n", |
285 | 285 | "test_dataset = DFDataset(test_df)\n", |
286 | 286 | "\n", |
287 | | - "# create PDA and pass it to the model \n", |
| 287 | + "# create PDA and pass it to the model\n", |
288 | 288 | "vpda = ObjectVPDA(encoder, schema)\n", |
289 | 289 | "model = ORIGAMI(model_config, train_config, vpda=vpda)\n", |
290 | 290 | "\n", |
291 | 291 | "n_params = count_parameters(model)\n", |
292 | | - "print(f\"Number of parameters: {n_params/1e6:.2f}M\")" |
| 292 | + "print(f\"Number of parameters: {n_params / 1e6:.2f}M\")" |
293 | 293 | ] |
294 | 294 | }, |
295 | 295 | { |
|
878 | 878 | "name": "python", |
879 | 879 | "nbconvert_exporter": "python", |
880 | 880 | "pygments_lexer": "ipython3", |
881 | | - "version": "3.11.9" |
| 881 | + "version": "3.10.14" |
882 | 882 | } |
883 | 883 | }, |
884 | 884 | "nbformat": 4, |
|