Skip to content

Commit 010a2ff

Browse files
rueckstiessThomas Rueckstiess
andauthored
Improve cli (#8)
* small copy change in notebook. * mapping learning rate, making --set-parameter a global option and lowercase, documentation of train command. * predict fixes, default STRUCTURE_AND_VALUES * CLI documentation, fixed issue where target field is ignored. Instead it returns Symbol.UNKNOWN in the TargetFieldPipe. --------- Co-authored-by: Thomas Rueckstiess <[email protected]>
1 parent eb9e05f commit 010a2ff

File tree

11 files changed

+503
-73
lines changed

11 files changed

+503
-73
lines changed

CLI.md

Lines changed: 374 additions & 0 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,33 +50,15 @@ ORiGAMi comes with a command line interface (CLI) and a Python SDK.
5050

5151
### Usage from the Command Line
5252

53-
The CLI allows to train a model and make predictions and generate synthetic data from a trained model. After installation, run `origami` from your shell to see an overview of available commands.
53+
The CLI allows you to train a model and make predictions with a trained model. After installation, run `origami` from your shell to see an overview of available commands.
5454

55-
Help for specific commands is available with `origami <command> --help`, where `<command>` is one of `train`, `predict`, `generate`.
55+
Help for specific commands is available with `origami <command> --help`, where `<command>` is currently one of `train` or `predict`.
5656

57-
#### Model Training
58-
59-
To train a model, use the `origami train` command. ORiGAMi works well with MongoDB. For example, to train a model on the `shop.orders` collection on a locally running MongoDB instance on standard port 27017, use the following command:
60-
61-
```
62-
origami train "mongodb://localhost:27017" --source-db shop --source-coll orders
63-
```
64-
65-
#### Making Predictions
66-
67-
...TBD...
68-
69-
#### Generating Synthetic Data
70-
71-
...TBD...
57+
Detailed documentation for the CLI and available options can be found in [`CLI.md`](CLI.md).
7258

7359
### Usage with Python
7460

75-
...TBD...
76-
77-
```python
78-
from origami.model import ORIGAMI
79-
```
61+
To see an example of how to use ORiGAMi from Python, take a look at the provided [./notebooks](./notebooks/) folder, e.g. the [`example_origami_dungeons.ipynb`](./notebooks/example_origami_dungeons.ipynb) notebook.
8062

8163
## Experiment Reproduction
8264

notebooks/example_rf_dungeons.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@
325325
"\n",
326326
"We will attempt to learn the same Dungeons dataset as used in `example_origami_dungeons.ipynb` with a\n",
327327
"[RandomForestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)\n",
328-
"from scikit-learn.\n",
328+
"from scikit-learn. However, this will not generalize to the test set, as we discuss in the paper.\n",
329329
"\n",
330330
"We recursively flatten the dataset, creating a column for each field path (e.g. `corridor.2.blue_key`). Then we\n",
331331
"transform all features through one-hot encoding, including the numeric fields (`door` and `door_no`) as these are\n",

origami/cli/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import click
22

3-
from .generate import generate
4-
from .predict import predict
5-
from .train import train
3+
from origami.cli.generate import generate
4+
from origami.cli.predict import predict
5+
from origami.cli.train import train
66

77
CONTEXT_SETTINGS = dict(max_content_width=120)
88

origami/cli/predict.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
from click_option_group import optgroup
66
from omegaconf import OmegaConf
77

8+
from origami.cli.utils import create_projection, load_data
89
from origami.inference import Predictor
910
from origami.model import ORIGAMI
1011
from origami.model.vpda import ObjectVPDA
1112
from origami.preprocessing import DFDataset, TargetFieldPipe
1213
from origami.utils import Symbol, count_parameters, load_origami_model
1314
from origami.utils.config import GuardrailsMethod
1415

15-
from .utils import create_projection, load_data
16-
1716

1817
@click.command()
1918
@click.argument("source", type=str)
@@ -35,7 +34,7 @@
3534
@optgroup.option("--limit", "-l", type=int, default=0, help="limit the number of documents to load")
3635
@optgroup.group("Output Options")
3736
@optgroup.option("--json", "-j", is_flag=True, default=False, help="output full JSON objects including target field")
38-
@click.option("--verbose", "-v", is_flag=True, default=True)
37+
@click.option("--verbose", "-v", is_flag=True, default=False)
3938
def predict(source, **kwargs):
4039
"""Predict target fields with a trained ORIGAMI model."""
4140

@@ -58,8 +57,6 @@ def predict(source, **kwargs):
5857
case GuardrailsMethod.NONE:
5958
vpda = None
6059

61-
click.echo(f"config:\n {OmegaConf.to_yaml(config)}")
62-
6360
model = ORIGAMI(config.model, config.train, vpda=vpda)
6461
model.load_state_dict(state_dict)
6562

@@ -77,7 +74,8 @@ def predict(source, **kwargs):
7774
# update or create new target pipe with new target_field
7875
test_pipeline = pipelines["test"]
7976

80-
if "target" in test_pipeline:
77+
# update pipeline parameters and transform data
78+
if "target" in test_pipeline.named_steps:
8179
test_pipeline["target"].target_field = config.data.target_field
8280
else:
8381
test_pipeline.steps.insert(0, ["target", TargetFieldPipe(config.data.target_field)])
@@ -89,9 +87,9 @@ def predict(source, **kwargs):
8987
if kwargs["verbose"]:
9088
# report number of parameters (note we don't count the decoder parameters in lm_head)
9189
n_params = count_parameters(model)
92-
click.echo(f"running on device: {model.device}")
93-
click.echo(f"number of parameters: {n_params / 1e6:.2f}M")
94-
click.echo(f"config:\n {OmegaConf.to_yaml(config)}")
90+
click.echo(f"running on device: {model.device}", err=True)
91+
click.echo(f"number of parameters: {n_params / 1e6:.2f}M", err=True)
92+
click.echo(f"config:\n {OmegaConf.to_yaml(config)}", err=True)
9593

9694
# predict target field
9795
predictor = Predictor(model, encoder, config.data.target_field, max_batch_size=config.train.batch_size)

origami/cli/train.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
help="path to write trained model",
3030
)
3131
@click.option("--seed", type=int, default=1234, show_default=True, help="random seed")
32+
@click.option(
33+
"--set-parameter",
34+
"-p",
35+
type=str,
36+
multiple=True,
37+
help="set additional config parameters, format: key.subkey=value. Multiple parameters can be set.",
38+
)
3239
@click.option("--verbose", "-v", is_flag=True, default=False)
3340
@optgroup.group("Source Options")
3441
@optgroup.option("--source-db", "-d", type=str, help="database name, only used when SOURCE is a MongoDB URI.")
@@ -38,14 +45,7 @@
3845
@optgroup.option("--skip", "-s", type=int, default=0, help="number of documents to skip")
3946
@optgroup.option("--limit", "-l", type=int, default=0, help="limit the number of documents to load")
4047
@optgroup.group("Config Options")
41-
@optgroup.option("--config-file", "-C", type=click.File("r"), help="path to config file")
42-
@optgroup.option(
43-
"--set-parameter",
44-
"-P",
45-
type=str,
46-
multiple=True,
47-
help="set additional config parameters, format: key.subkey=value. Multiple parameters can be set.",
48-
)
48+
# @optgroup.option("--config-file", "-C", type=click.File("r"), help="path to config file")
4949
@optgroup.option(
5050
"--max-vocab-size",
5151
"-V",
@@ -56,7 +56,7 @@
5656
)
5757
@optgroup.option(
5858
"--num-layers",
59-
"-L",
59+
"-T",
6060
type=int,
6161
default=4,
6262
show_default=True,
@@ -78,6 +78,14 @@
7878
show_default=True,
7979
help="hidden dimensionality of transformer layers",
8080
)
81+
@optgroup.option(
82+
"--learning-rate",
83+
"-L",
84+
type=float,
85+
default=1e-3,
86+
show_default=True,
87+
help="max. learning rate of the model",
88+
)
8189
@optgroup.option(
8290
"--num-batches", "-N", type=int, default=10000, show_default=True, help="number of batches to train on"
8391
)
@@ -102,7 +110,7 @@
102110
"--guardrails",
103111
"-G",
104112
type=click.Choice(["NONE", "STRUCTURE_ONLY", "STRUCTURE_AND_VALUES"]),
105-
default="STRUCTURE_ONLY",
113+
default="STRUCTURE_AND_VALUES",
106114
help="guardrails settings",
107115
)
108116
@optgroup.option(
@@ -152,8 +160,7 @@ def train(source: str, **kwargs):
152160
# train configs
153161
config.train.n_batches = kwargs["num_batches"]
154162
config.train.batch_size = kwargs["batch_size"]
155-
config.train.learning_rate = 1e-3
156-
config.train.n_warmup_batches = 1000
163+
config.train.learning_rate = kwargs["learning_rate"]
157164
config.train.print_every = 10
158165
config.train.eval_every = 100
159166
config.train.test_split = kwargs["val_split_ratio"]

origami/preprocessing/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def consume_doc(gen):
9999

100100

101101
def target_collate_fn(target_token_id: int):
102-
def collate_fn(tokens: torch.tensor) -> torch.tensor:
102+
def collate_fn(tokens: list[torch.tensor]) -> torch.tensor:
103103
"""collate function that only returns sequences up to a target token (incl.). Assumes
104104
the target token is at the same position in each sequence. (use with TargetTokenBatchSampler)"""
105105
tokens = default_collate(tokens)

origami/utils/common.py

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -282,14 +282,11 @@ def get_value_at_path(d: dict, path: List[str]) -> Tuple[Any, bool]:
282282
def reorder_with_target_last(d: dict, target_path: str) -> Tuple[OrderedDict, Any]:
283283
"""
284284
Reorder dictionary so target field appears last, maintaining nested structure.
285-
If target field doesn't exist, returns (OrderedDict(d), Symbol.UNKNOWN).
285+
Creates missing intermediate paths and sets target to Symbol.UNKNOWN if not found.
286286
"""
287287
path_components = parse_path(target_path)
288288
target_value, found = get_value_at_path(d, path_components)
289289

290-
if not found:
291-
return OrderedDict(d), Symbol.UNKNOWN
292-
293290
def reorder_level(current_dict: dict, remaining_path: List[str]) -> OrderedDict:
294291
if not remaining_path:
295292
return OrderedDict(current_dict)
@@ -302,18 +299,82 @@ def reorder_level(current_dict: dict, remaining_path: List[str]) -> OrderedDict:
302299
if k != current_target:
303300
result[k] = v if not isinstance(v, dict) else reorder_level(v, [])
304301

305-
# Add target field last
302+
# Handle the target path
306303
if current_target in current_dict:
307304
target_dict = current_dict[current_target]
308-
if len(remaining_path) > 1:
309-
# If we have more path components, recurse with remaining path
310-
result[current_target] = reorder_level(target_dict, remaining_path[1:])
311-
else:
312-
# If this is the final path component, add it last
313-
result[current_target] = (
314-
target_dict if not isinstance(target_dict, dict) else reorder_level(target_dict, [])
315-
)
305+
else:
306+
# Create empty dict for missing intermediate paths
307+
target_dict = {} if len(remaining_path) > 1 else Symbol.UNKNOWN
308+
309+
if len(remaining_path) > 1:
310+
# If we have more path components, recurse with remaining path
311+
result[current_target] = reorder_level(target_dict, remaining_path[1:])
312+
else:
313+
# If this is the final path component, add it last
314+
result[current_target] = (
315+
target_dict if not isinstance(target_dict, dict) else reorder_level(target_dict, [])
316+
)
316317

317318
return result
318319

319-
return reorder_level(d, path_components), target_value
320+
# Check if we're trying to traverse through a non-dict value
321+
current = d
322+
for i, component in enumerate(path_components[:-1]):
323+
if component in current and not isinstance(current[component], dict):
324+
# If we hit a non-dict value in the path, treat the entire remaining path
325+
# as a top-level field
326+
new_target = ".".join(path_components[i:])
327+
result = OrderedDict()
328+
for k, v in d.items():
329+
if k != new_target:
330+
result[k] = v if not isinstance(v, dict) else reorder_level(v, [])
331+
result[new_target] = Symbol.UNKNOWN
332+
return result, Symbol.UNKNOWN
333+
334+
if component not in current:
335+
break
336+
current = current[component]
337+
338+
return reorder_level(d, path_components), target_value if found else Symbol.UNKNOWN
339+
340+
341+
# def reorder_with_target_last(d: dict, target_path: str) -> Tuple[OrderedDict, Any]:
342+
# """
343+
# Reorder dictionary so target field appears last, maintaining nested structure.
344+
# If target field doesn't exist, returns (OrderedDict(d), Symbol.UNKNOWN).
345+
# """
346+
# path_components = parse_path(target_path)
347+
# target_value, found = get_value_at_path(d, path_components)
348+
349+
# if not found:
350+
# od = OrderedDict(d)
351+
# od[target_path] = Symbol.UNKNOWN
352+
# return od, Symbol.UNKNOWN
353+
354+
# def reorder_level(current_dict: dict, remaining_path: List[str]) -> OrderedDict:
355+
# if not remaining_path:
356+
# return OrderedDict(current_dict)
357+
358+
# current_target = remaining_path[0]
359+
# result = OrderedDict()
360+
361+
# # Add all non-target fields first
362+
# for k, v in current_dict.items():
363+
# if k != current_target:
364+
# result[k] = v if not isinstance(v, dict) else reorder_level(v, [])
365+
366+
# # Add target field last
367+
# if current_target in current_dict:
368+
# target_dict = current_dict[current_target]
369+
# if len(remaining_path) > 1:
370+
# # If we have more path components, recurse with remaining path
371+
# result[current_target] = reorder_level(target_dict, remaining_path[1:])
372+
# else:
373+
# # If this is the final path component, add it last
374+
# result[current_target] = (
375+
# target_dict if not isinstance(target_dict, dict) else reorder_level(target_dict, [])
376+
# )
377+
378+
# return result
379+
380+
# return reorder_level(d, path_components), target_value

origami/utils/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class ModelConfig(BaseConfig):
6666
mask_field_token_losses: bool = False
6767

6868
# whether or not to use guardrails (requires a ObjectVPDA to be passed into model)
69-
guardrails: GuardrailsMethod = GuardrailsMethod.STRUCTURE_ONLY
69+
guardrails: GuardrailsMethod = GuardrailsMethod.STRUCTURE_AND_VALUES
7070

7171
@staticmethod
7272
def from_preset(size: str, **kwargs) -> "ModelConfig":

tests/preprocessing/test_pipes.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,9 @@ def test_supervised_target_pipe(self):
153153

154154
self.assertIn("target", df.columns)
155155

156-
for i, (doc, target) in enumerate(zip(df["docs"], df["target"])):
157-
if target == Symbol.UNKNOWN:
158-
self.assertNotIn("b", doc)
159-
else:
160-
self.assertIn("b", doc)
161-
self.assertEqual(doc["b"], target)
156+
for doc, target in zip(df["docs"], df["target"]):
157+
self.assertIn("b", doc)
158+
self.assertEqual(doc["b"], target)
162159

163160

164161
class TestDocTokenizerPipe(unittest.TestCase):

0 commit comments

Comments
 (0)