Skip to content

Commit 156b0f9

Browse files
author
Thomas Rueckstiess
committed
install origami in editable mode, fix linting and format issues
1 parent 575bacb commit 156b0f9

File tree

14 files changed

+201
-228
lines changed

14 files changed

+201
-228
lines changed

.github/workflows/python-package.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ jobs:
2020
with:
2121
python-version: "3.10"
2222
cache: "pip"
23-
- name: Install dependencies
24-
run: |
25-
pip install -r requirements.txt
23+
- name: Install dependencies
24+
run: |
25+
python -m pip install --upgrade pip
26+
pip install -r requirements.txt
27+
pip install -e .
2628
- name: Python Ruff Lint and Format
2729
uses: adityabhangle658/[email protected]
2830
- name: Run tests with pytest

notebooks/dungeon-results.ipynb

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -592,15 +592,15 @@
592592
" plt.text(\n",
593593
" i - width / 2,\n",
594594
" stats_df[\"train_acc_mean\"].iloc[i] + 0.02,\n",
595-
" f'{stats_df[\"train_acc_mean\"].iloc[i]:.2f}',\n",
595+
" f\"{stats_df['train_acc_mean'].iloc[i]:.2f}\",\n",
596596
" ha=\"center\",\n",
597597
" va=\"bottom\",\n",
598598
" fontsize=\"x-small\",\n",
599599
" )\n",
600600
" plt.text(\n",
601601
" i + width / 2,\n",
602602
" stats_df[\"test_acc_mean\"].iloc[i] + 0.02,\n",
603-
" f'{stats_df[\"test_acc_mean\"].iloc[i]:.2f}',\n",
603+
" f\"{stats_df['test_acc_mean'].iloc[i]:.2f}\",\n",
604604
" ha=\"center\",\n",
605605
" va=\"bottom\",\n",
606606
" fontsize=\"x-small\",\n",
@@ -1714,7 +1714,6 @@
17141714
"\n",
17151715
"from origami.utils.guild import plot_scalar_history\n",
17161716
"\n",
1717-
"\n",
17181717
"runs_gr = guild.runs(labels=[\"ablation-6-dungeons-easy\"], filter_expr=\"model.guardrails=STRUCTURE_AND_VALUES\")\n",
17191718
"runs_no_gr = guild.runs(labels=[\"ablation-6-dungeons-easy\"], filter_expr=\"model.guardrails=NONE\")\n",
17201719
"\n",
@@ -1924,7 +1923,7 @@
19241923
"name": "python",
19251924
"nbconvert_exporter": "python",
19261925
"pygments_lexer": "ipython3",
1927-
"version": "3.11.9"
1926+
"version": "3.10.14"
19281927
}
19291928
},
19301929
"nbformat": 4,

notebooks/example_dungeons.ipynb

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,44 @@
66
"source": [
77
"## Training an ORiGAMi model on the Dungeons dataset\n",
88
"\n",
9-
"The Dungeons dataset is a (dungeons-themed) challenging synthetic dataset for supervised classification on \n",
10-
"semi-structured data. \n",
9+
"The Dungeons dataset is a (dungeons-themed) challenging synthetic dataset for supervised classification on\n",
10+
"semi-structured data.\n",
1111
"\n",
12-
"Each instance constains a corridor array with several rooms. Each room has a door number and contains multiple \n",
12+
"Each instance constains a corridor array with several rooms. Each room has a door number and contains multiple\n",
1313
"treasure chests with different-colored keys. All but one of the treasures are fake though.\n",
1414
"\n",
15-
"The goal is to find the correct room number and key color in each dungeon based on some clues and return the \n",
16-
"only non-fake treasure. \n",
15+
"The goal is to find the correct room number and key color in each dungeon based on some clues and return the\n",
16+
"only non-fake treasure.\n",
1717
"\n",
18-
"The clues are given at the top-level of the object with keys `door`, `key_color`. \n",
18+
"The clues are given at the top-level of the object with keys `door`, `key_color`.\n",
1919
"\n",
20-
"To make it even harder, the `corridor` array may be shuffled, and room objects may have a number of monsters as \n",
21-
"their first field, shifting the token positions of the serialized object by a variable amount. \n",
20+
"To make it even harder, the `corridor` array may be shuffled, and room objects may have a number of monsters as\n",
21+
"their first field, shifting the token positions of the serialized object by a variable amount.\n",
2222
"\n",
2323
"The following dictionary represents one example JSON instance:\n",
2424
"\n",
2525
"```json\n",
2626
"{\n",
27-
" \"door\": 1, // clue which door is the correct one\n",
28-
" \"key_color\": \"blue\", // clue which key is the correct one\n",
29-
" \"corridor\": [\n",
30-
" {\n",
31-
" \"monsters\": [\"troll\", \"wolf\"], // optional monsters in front of the door\n",
32-
" \"door_no\": 1, // door number in the corridor\n",
33-
" \"red_key\": \"gemstones\", // different keys return different treasures,\n",
34-
" \"blue_key\": \"spellbooks\", // but only one is real, the others are fake\n",
35-
" \"green_key\": \"artifacts\"\n",
36-
" },\n",
37-
" { // another room\n",
38-
" \"door_no\": 0, // rooms can be shuffled, here room 0 comes after 1 \n",
39-
" \"red_key\": \"diamonds\", \n",
40-
" \"blue_key\": \"gold\", \n",
41-
" \"green_key\": \"gemstones\"\n",
42-
" },\n",
43-
" // ... more doors ...\n",
44-
" ],\n",
45-
" \"treasure\": \"spellbooks\" // correct treasure (target label)\n",
27+
" \"door\": 1, // clue which door is the correct one\n",
28+
" \"key_color\": \"blue\", // clue which key is the correct one\n",
29+
" \"corridor\": [\n",
30+
" {\n",
31+
" \"monsters\": [\"troll\", \"wolf\"], // optional monsters in front of the door\n",
32+
" \"door_no\": 1, // door number in the corridor\n",
33+
" \"red_key\": \"gemstones\", // different keys return different treasures,\n",
34+
" \"blue_key\": \"spellbooks\", // but only one is real, the others are fake\n",
35+
" \"green_key\": \"artifacts\"\n",
36+
" },\n",
37+
" {\n",
38+
" // another room\n",
39+
" \"door_no\": 0, // rooms can be shuffled, here room 0 comes after 1\n",
40+
" \"red_key\": \"diamonds\",\n",
41+
" \"blue_key\": \"gold\",\n",
42+
" \"green_key\": \"gemstones\"\n",
43+
" }\n",
44+
" // ... more doors ...\n",
45+
" ],\n",
46+
" \"treasure\": \"spellbooks\" // correct treasure (target label)\n",
4647
"}\n",
4748
"```\n",
4849
"\n",
@@ -133,8 +134,8 @@
133134
" num_doors_range=(5, 10),\n",
134135
" num_colors=3,\n",
135136
" num_treasures=5,\n",
136-
" with_monsters=True, # makes it harder as token positions get shifted by variable amount\n",
137-
" shuffle_rooms=True # makes it harder because rooms are in random order\n",
137+
" with_monsters=True, # makes it harder as token positions get shifted by variable amount\n",
138+
" shuffle_rooms=True, # makes it harder because rooms are in random order\n",
138139
")\n",
139140
"\n",
140141
"# print example dictionary\n",
@@ -463,7 +464,7 @@
463464
"name": "python",
464465
"nbconvert_exporter": "python",
465466
"pygments_lexer": "ipython3",
466-
"version": "3.11.9"
467+
"version": "3.10.14"
467468
}
468469
},
469470
"nbformat": 4,

notebooks/example_origami_dungeons.ipynb

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,42 +8,44 @@
88
"\n",
99
"### The Dungeons Dataset\n",
1010
"\n",
11-
"The Dungeons dataset is a (dungeons-themed) challenging synthetic dataset for supervised classification on \n",
12-
"semi-structured data. \n",
11+
"The Dungeons dataset is a (dungeons-themed) challenging synthetic dataset for supervised classification on\n",
12+
"semi-structured data.\n",
1313
"\n",
14-
"Each instance constains a corridor array with several rooms. Each room has a door number and contains multiple \n",
14+
"Each instance constains a corridor array with several rooms. Each room has a door number and contains multiple\n",
1515
"treasure chests with different-colored keys. All but one of the treasures are fake though.\n",
1616
"\n",
17-
"The goal is to find the correct room number and key color in each dungeon based on some clues and return the \n",
18-
"only real treasure. The clues are given at the top-level of the object in the fields `door` and `key_color`. \n",
17+
"The goal is to find the correct room number and key color in each dungeon based on some clues and return the\n",
18+
"only real treasure. The clues are given at the top-level of the object in the fields `door` and `key_color`.\n",
1919
"\n",
20-
"To make it even harder, the `corridor` array may be shuffled (`shuffle_rooms=True`), and room objects may \n",
21-
"have a number of monsters as their first field (`with_monsters=True`), shifting the token positions of the \n",
22-
"serialized object by a variable amount. \n",
20+
"To make it even harder, the `corridor` array may be shuffled (`shuffle_rooms=True`), and room objects may\n",
21+
"have a number of monsters as their first field (`with_monsters=True`), shifting the token positions of the\n",
22+
"serialized object by a variable amount.\n",
2323
"\n",
2424
"The following dictionary represents one example JSON instance:\n",
2525
"\n",
2626
"```json\n",
2727
"{\n",
28-
" \"door\": 1, // clue which door is the correct one\n",
29-
" \"key_color\": \"blue\", // clue which key is the correct one\n",
30-
" \"corridor\": [ // a corridor with many doors\n",
31-
" {\n",
32-
" \"monsters\": [\"troll\", \"wolf\"], // optional monsters in front of the door\n",
33-
" \"door_no\": 1, // door number in the corridor\n",
34-
" \"red_key\": \"gemstones\", // different keys return different treasures,\n",
35-
" \"blue_key\": \"spellbooks\", // but only one is real, the others are fake\n",
36-
" \"green_key\": \"artifacts\"\n",
37-
" },\n",
38-
" { // another room, here without monsters\n",
39-
" \"door_no\": 0, // rooms can be shuffled, here room 0 comes after 1 \n",
40-
" \"red_key\": \"diamonds\", \n",
41-
" \"blue_key\": \"gold\", \n",
42-
" \"green_key\": \"gemstones\"\n",
43-
" },\n",
44-
" // ... more rooms ...\n",
45-
" ],\n",
46-
" \"treasure\": \"spellbooks\" // correct treasure (target label)\n",
28+
" \"door\": 1, // clue which door is the correct one\n",
29+
" \"key_color\": \"blue\", // clue which key is the correct one\n",
30+
" \"corridor\": [\n",
31+
" // a corridor with many doors\n",
32+
" {\n",
33+
" \"monsters\": [\"troll\", \"wolf\"], // optional monsters in front of the door\n",
34+
" \"door_no\": 1, // door number in the corridor\n",
35+
" \"red_key\": \"gemstones\", // different keys return different treasures,\n",
36+
" \"blue_key\": \"spellbooks\", // but only one is real, the others are fake\n",
37+
" \"green_key\": \"artifacts\"\n",
38+
" },\n",
39+
" {\n",
40+
" // another room, here without monsters\n",
41+
" \"door_no\": 0, // rooms can be shuffled, here room 0 comes after 1\n",
42+
" \"red_key\": \"diamonds\",\n",
43+
" \"blue_key\": \"gold\",\n",
44+
" \"green_key\": \"gemstones\"\n",
45+
" }\n",
46+
" // ... more rooms ...\n",
47+
" ],\n",
48+
" \"treasure\": \"spellbooks\" // correct treasure (target label)\n",
4749
"}\n",
4850
"```\n",
4951
"\n",
@@ -56,8 +58,8 @@
5658
"source": [
5759
"### Preprocessing\n",
5860
"\n",
59-
"The JSON objects are tokenized by recursively walking through them depth-first and extracting key and value tokens. \n",
60-
"Additionally, when encountering arrays or nested objects, special grammar tokens are included in the sequence. \n",
61+
"The JSON objects are tokenized by recursively walking through them depth-first and extracting key and value tokens.\n",
62+
"Additionally, when encountering arrays or nested objects, special grammar tokens are included in the sequence.\n",
6163
"This diagram illustrates tokenization.\n",
6264
"\n",
6365
"<img src=\"../assets/preprocessing-diagram.png\" width=\"600px\" />\n"
@@ -163,25 +165,24 @@
163165
"import json\n",
164166
"\n",
165167
"from sklearn.model_selection import train_test_split\n",
166-
"from sklearn.pipeline import Pipeline\n",
167168
"\n",
168-
"from origami.utils.config import PipelineConfig\n",
169-
"from origami.utils import set_seed\n",
170169
"from origami.datasets.dungeons import generate_data\n",
171-
"from origami.preprocessing import docs_to_df, build_prediction_pipelines\n",
170+
"from origami.preprocessing import build_prediction_pipelines, docs_to_df\n",
171+
"from origami.utils import set_seed\n",
172+
"from origami.utils.config import PipelineConfig\n",
172173
"\n",
173174
"# for reproducibility\n",
174-
"# set_seed(123)\n",
175+
"set_seed(123)\n",
175176
"\n",
176177
"# generate Dungeons dataset (see origami/datasets/dungeons.py)\n",
177178
"data = generate_data(\n",
178179
" num_instances=10_000,\n",
179180
" num_doors_range=(4, 8),\n",
180181
" num_colors=3,\n",
181182
" num_treasures=5,\n",
182-
" with_monsters=True, # makes it harder as token positions get shifted by variable amount\n",
183-
" shuffle_rooms=True, # makes it harder because rooms are in random order\n",
184-
" shuffle_keys=True # makes it harder because keys are in random order\n",
183+
" with_monsters=True, # makes it harder as token positions get shifted by variable amount\n",
184+
" shuffle_rooms=True, # makes it harder because rooms are in random order\n",
185+
" shuffle_keys=True, # makes it harder because keys are in random order\n",
185186
")\n",
186187
"\n",
187188
"# print example dictionary\n",
@@ -195,18 +196,17 @@
195196
"\n",
196197
"# create train and test pipelines\n",
197198
"pipelines = build_prediction_pipelines(\n",
198-
" pipeline_config=PipelineConfig(sequence_order=\"ORDERED\", upscale=1),\n",
199-
" target_field=TARGET_FIELD\n",
199+
" pipeline_config=PipelineConfig(sequence_order=\"ORDERED\", upscale=1), target_field=TARGET_FIELD\n",
200200
")\n",
201201
"\n",
202202
"# process train, eval and test data\n",
203-
"train_df = pipelines['train'].fit_transform(train_docs_df)\n",
204-
"test_df = pipelines['test'].transform(test_docs_df)\n",
203+
"train_df = pipelines[\"train\"].fit_transform(train_docs_df)\n",
204+
"test_df = pipelines[\"test\"].transform(test_docs_df)\n",
205205
"\n",
206206
"# get stateful objects\n",
207-
"schema = pipelines['train'][\"schema\"].schema\n",
208-
"encoder = pipelines['train'][\"encoder\"].encoder\n",
209-
"block_size = pipelines['train'][\"padding\"].length\n",
207+
"schema = pipelines[\"train\"][\"schema\"].schema\n",
208+
"encoder = pipelines[\"train\"][\"encoder\"].encoder\n",
209+
"block_size = pipelines[\"train\"][\"padding\"].length\n",
210210
"\n",
211211
"# print data stats\n",
212212
"print(f\"len train: {len(train_df)}, len test: {len(test_df)}\")\n",
@@ -231,8 +231,8 @@
231231
}
232232
],
233233
"source": [
234-
"# save dungeon dataset to MongoDB \n",
235-
"from pymongo import MongoClient\n",
234+
"# save dungeon dataset to MongoDB\n",
235+
"from pymongo import MongoClient\n",
236236
"\n",
237237
"client = MongoClient(\"mongodb://localhost:27017/\")\n",
238238
"collection = client.dungeons.dungeon_10k_4_8_3_5_mkr\n",
@@ -247,7 +247,7 @@
247247
"\n",
248248
"Here we instantiate an ORiGAMi model, a modified transformer trained on the token sequences created above.\n",
249249
"We use a standard \"medium\" configuration. ORiGAMi models are relatively robust to the choice of hyper-parameter\n",
250-
"and default configurations often work well for mid-sized datasets. "
250+
"and default configurations often work well for mid-sized datasets.\n"
251251
]
252252
},
253253
{
@@ -270,7 +270,7 @@
270270
"from origami.utils import ModelConfig, TrainConfig, count_parameters\n",
271271
"\n",
272272
"# model and train configs\n",
273-
"model_config = ModelConfig.from_preset(\"medium\") # see origami/utils/config.py for different presets\n",
273+
"model_config = ModelConfig.from_preset(\"medium\") # see origami/utils/config.py for different presets\n",
274274
"model_config.position_encoding = \"SINE_COSINE\"\n",
275275
"model_config.vocab_size = encoder.vocab_size\n",
276276
"model_config.block_size = block_size\n",
@@ -284,12 +284,12 @@
284284
"train_dataset = DFDataset(train_df)\n",
285285
"test_dataset = DFDataset(test_df)\n",
286286
"\n",
287-
"# create PDA and pass it to the model \n",
287+
"# create PDA and pass it to the model\n",
288288
"vpda = ObjectVPDA(encoder, schema)\n",
289289
"model = ORIGAMI(model_config, train_config, vpda=vpda)\n",
290290
"\n",
291291
"n_params = count_parameters(model)\n",
292-
"print(f\"Number of parameters: {n_params/1e6:.2f}M\")"
292+
"print(f\"Number of parameters: {n_params / 1e6:.2f}M\")"
293293
]
294294
},
295295
{
@@ -878,7 +878,7 @@
878878
"name": "python",
879879
"nbconvert_exporter": "python",
880880
"pygments_lexer": "ipython3",
881-
"version": "3.11.9"
881+
"version": "3.10.14"
882882
}
883883
},
884884
"nbformat": 4,

0 commit comments

Comments
 (0)